diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -72,38 +72,40 @@ else: return Db.from_pool(self._pool) - def check_config(self, *, check_write): + def put_db(self, db): + if db is not self._db: + db.put_conn() + + @db_transaction() + def check_config(self, *, check_write, db, cur): """Check that the storage is configured and ready to go.""" if not self.objstorage.check_config(check_write=check_write): return False # Check permissions on one of the tables - with self.get_db().transaction() as cur: - if check_write: - check = 'INSERT' - else: - check = 'SELECT' - - cur.execute( - "select has_table_privilege(current_user, 'content', %s)", - (check,) - ) - return cur.fetchone()[0] + if check_write: + check = 'INSERT' + else: + check = 'SELECT' - return True + cur.execute( + "select has_table_privilege(current_user, 'content', %s)", + (check,) + ) + return cur.fetchone()[0] - def _content_unique_key(self, hash): + def _content_unique_key(self, hash, db): """Given a hash (tuple or dict), return a unique key from the aggregation of keys. """ - keys = self.get_db().content_hash_keys + keys = db.content_hash_keys if isinstance(hash, tuple): return hash return tuple([hash[k] for k in keys]) - def _filter_new_content(self, content): + def _filter_new_content(self, content, db, cur): content_by_status = defaultdict(list) for d in content: if 'status' not in d: @@ -115,17 +117,18 @@ content_with_data = content_by_status['visible'] content_without_data = content_by_status['absent'] - missing_content = set(self.content_missing(content_with_data)) - missing_skipped = set(self._content_unique_key(hashes) for hashes - in self.skipped_content_missing( - content_without_data)) + missing_content = set(self.content_missing(content_with_data, + db=db, cur=cur)) + missing_skipped = set(self._content_unique_key(hashes, db) + for hashes in self.skipped_content_missing( + content_without_data, db=db, cur=cur)) content_with_data = [ cont for cont in content_with_data if cont['sha1'] in missing_content] content_without_data = [ cont for cont in content_without_data - if self._content_unique_key(cont) in missing_skipped] + if self._content_unique_key(cont, db) in missing_skipped] summary = { 'content:add': len(missing_content), @@ -169,7 +172,8 @@ # move metadata in place db.skipped_content_add_from_temp(cur) - def content_add(self, content): + @db_transaction() + def content_add(self, content, db, cur): """Add content blobs to the storage Note: in case of DB errors, objects might have already been added to @@ -219,10 +223,8 @@ del item['data'] self.journal_writer.write_addition('content', item) - db = self.get_db() - (content_with_data, content_without_data, summary) = \ - self._filter_new_content(content) + self._filter_new_content(content, db, cur) def add_to_objstorage(): """Add to objstorage the new missing_content @@ -245,16 +247,15 @@ self.objstorage.add_batch(data) return content_bytes_added - with db.transaction() as cur: - with ThreadPoolExecutor(max_workers=1) as executor: - added_to_objstorage = executor.submit(add_to_objstorage) + with ThreadPoolExecutor(max_workers=1) as executor: + added_to_objstorage = executor.submit(add_to_objstorage) - self._content_add_metadata( - db, cur, content_with_data, content_without_data) + self._content_add_metadata( + db, cur, content_with_data, content_without_data) - # Wait for objstorage addition before returning from the - # transaction, bubbling up any exception - content_bytes_added = added_to_objstorage.result() + # Wait for objstorage addition before returning from the + # transaction, bubbling up any exception + content_bytes_added = added_to_objstorage.result() summary['content:bytes:add'] = content_bytes_added return summary @@ -293,7 +294,8 @@ db.content_update_from_temp(keys_to_update=keys, cur=cur) - def content_add_metadata(self, content): + @db_transaction() + def content_add_metadata(self, content, db, cur): """Add content metadata to the storage (like `content_add`, but without inserting to the objstorage). @@ -323,14 +325,11 @@ assert 'data' not in content self.journal_writer.write_addition('content', item) - db = self.get_db() - (content_with_data, content_without_data, summary) = \ - self._filter_new_content(content) + self._filter_new_content(content, db, cur) - with db.transaction() as cur: - self._content_add_metadata( - db, cur, content_with_data, content_without_data) + self._content_add_metadata( + db, cur, content_with_data, content_without_data) return summary @@ -521,7 +520,8 @@ return dict(zip(db.content_find_cols, c)) return None - def directory_add(self, directories): + @db_transaction() + def directory_add(self, directories, db, cur): """Add directories to the storage Args: @@ -569,33 +569,31 @@ if not dirs_missing: return summary - db = self.get_db() - with db.transaction() as cur: - # Copy directory ids - dirs_missing_dict = ({'id': dir} for dir in dirs_missing) - db.mktemp('directory', cur) - db.copy_to(dirs_missing_dict, 'tmp_directory', ['id'], cur) + # Copy directory ids + dirs_missing_dict = ({'id': dir} for dir in dirs_missing) + db.mktemp('directory', cur) + db.copy_to(dirs_missing_dict, 'tmp_directory', ['id'], cur) - # Copy entries - for entry_type, entry_list in dir_entries.items(): - entries = itertools.chain.from_iterable( - entries_for_dir - for dir_id, entries_for_dir - in entry_list.items() - if dir_id in dirs_missing) + # Copy entries + for entry_type, entry_list in dir_entries.items(): + entries = itertools.chain.from_iterable( + entries_for_dir + for dir_id, entries_for_dir + in entry_list.items() + if dir_id in dirs_missing) - db.mktemp_dir_entry(entry_type) + db.mktemp_dir_entry(entry_type) - db.copy_to( - entries, - 'tmp_directory_entry_%s' % entry_type, - ['target', 'name', 'perms', 'dir_id'], - cur, - ) + db.copy_to( + entries, + 'tmp_directory_entry_%s' % entry_type, + ['target', 'name', 'perms', 'dir_id'], + cur, + ) - # Do the final copy - db.directory_add_from_temp(cur) - summary['directory:add'] = len(dirs_missing) + # Do the final copy + db.directory_add_from_temp(cur) + summary['directory:add'] = len(dirs_missing) return summary @@ -653,7 +651,8 @@ if res: return dict(zip(db.directory_ls_cols, res)) - def revision_add(self, revisions): + @db_transaction() + def revision_add(self, revisions, db, cur): """Add revisions to the storage Args: @@ -695,32 +694,29 @@ if self.journal_writer: self.journal_writer.write_additions('revision', revisions) - db = self.get_db() - revisions_missing = set(self.revision_missing( set(revision['id'] for revision in revisions))) if not revisions_missing: return summary - with db.transaction() as cur: - db.mktemp_revision(cur) + db.mktemp_revision(cur) - revisions_filtered = ( - converters.revision_to_db(revision) for revision in revisions - if revision['id'] in revisions_missing) + revisions_filtered = ( + converters.revision_to_db(revision) for revision in revisions + if revision['id'] in revisions_missing) - parents_filtered = [] + parents_filtered = [] - db.copy_to( - revisions_filtered, 'tmp_revision', db.revision_add_cols, - cur, - lambda rev: parents_filtered.extend(rev['parents'])) + db.copy_to( + revisions_filtered, 'tmp_revision', db.revision_add_cols, + cur, + lambda rev: parents_filtered.extend(rev['parents'])) - db.revision_add_from_temp(cur) + db.revision_add_from_temp(cur) - db.copy_to(parents_filtered, 'revision_history', - ['id', 'parent_id', 'parent_rank'], cur) + db.copy_to(parents_filtered, 'revision_history', + ['id', 'parent_id', 'parent_rank'], cur) return {'revision:add': len(revisions_missing)} @@ -798,7 +794,8 @@ yield from db.revision_shortlog(revisions, limit, cur) - def release_add(self, releases): + @db_transaction() + def release_add(self, releases, db, cur): """Add releases to the storage Args: @@ -829,26 +826,26 @@ if self.journal_writer: self.journal_writer.write_additions('release', releases) - db = self.get_db() - release_ids = set(release['id'] for release in releases) - releases_missing = set(self.release_missing(release_ids)) + releases_missing = set(self.release_missing(release_ids, + db=db, cur=cur)) if not releases_missing: return summary - with db.transaction() as cur: - db.mktemp_release(cur) + db.mktemp_release(cur) - releases_filtered = ( - converters.release_to_db(release) for release in releases - if release['id'] in releases_missing - ) + releases_missing = list(releases_missing) + + releases_filtered = ( + converters.release_to_db(release) for release in releases + if release['id'] in releases_missing + ) - db.copy_to(releases_filtered, 'tmp_release', db.release_add_cols, - cur) + db.copy_to(releases_filtered, 'tmp_release', db.release_add_cols, + cur) - db.release_add_from_temp(cur) + db.release_add_from_temp(cur) return {'release:add': len(releases_missing)}