diff --git a/debian/control b/debian/control --- a/debian/control +++ b/debian/control @@ -30,6 +30,9 @@ python3-swh.objstorage (>= 0.0.17~), ${misc:Depends}, ${python3:Depends} +Breaks: python3-swh.archiver (<< 0.0.3~), + python3-swh.indexer (<< 0.0.48~), + python3-swh.vault (<< 0.0.19~) Description: Software Heritage storage utilities Package: python3-swh.storage.listener diff --git a/swh/storage/common.py b/swh/storage/common.py --- a/swh/storage/common.py +++ b/swh/storage/common.py @@ -16,8 +16,9 @@ if 'cur' in kwargs and kwargs['cur']: return meth(self, *args, **kwargs) else: - with self.db.transaction() as cur: - return meth(self, *args, cur=cur, **kwargs) + db = self.get_db() + with db.transaction() as cur: + return meth(self, *args, db=db, cur=cur, **kwargs) return _meth @@ -33,6 +34,7 @@ if 'cur' in kwargs and kwargs['cur']: yield from meth(self, *args, **kwargs) else: - with self.db.transaction() as cur: - yield from meth(self, *args, cur=cur, **kwargs) + db = self.get_db() + with db.transaction() as cur: + yield from meth(self, *args, db=db, cur=cur, **kwargs) return _meth diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -42,14 +42,17 @@ """ try: if isinstance(db, psycopg2.extensions.connection): - self.db = Db(db) + self._db = Db(db) else: - self.db = Db.connect(db) + self._db = Db.connect(db) except psycopg2.OperationalError as e: raise StorageDBError(e) self.objstorage = get_objstorage(**objstorage) + def get_db(self): + return self._db + def check_config(self, *, check_write): """Check that the storage is configured and ready to go.""" @@ -57,7 +60,7 @@ return False # Check permissions on one of the tables - with self.db.transaction() as cur: + with self.get_db().transaction() as cur: if check_write: check = 'INSERT' else: @@ -94,7 +97,7 @@ content in """ - db = self.db + db = self.get_db() def _unique_key(hash, keys=CONTENT_HASH_KEYS): """Given a hash (tuple or dict), return a unique key from the @@ -152,7 +155,7 @@ db.skipped_content_add_from_temp(cur) @db_transaction - def content_update(self, content, keys=[], cur=None): + def content_update(self, content, keys=[], db=None, cur=None): """Update content blobs to the storage. Does nothing for unknown contents or skipped ones. @@ -172,8 +175,6 @@ new hash column """ - db = self.db - # TODO: Add a check on input keys. How to properly implement # this? We don't know yet the new columns. @@ -215,7 +216,7 @@ yield {'sha1': obj_id, 'data': data} @db_transaction_generator - def content_get_metadata(self, content, cur=None): + def content_get_metadata(self, content, db=None, cur=None): """Retrieve content metadata in bulk Args: @@ -224,15 +225,13 @@ Returns: an iterable with content metadata corresponding to the given ids """ - db = self.db - db.store_tmp_bytea(content, cur) for content_metadata in db.content_get_metadata_from_temp(cur): yield dict(zip(db.content_get_metadata_keys, content_metadata)) @db_transaction_generator - def content_missing(self, content, key_hash='sha1', cur=None): + def content_missing(self, content, key_hash='sha1', db=None, cur=None): """List content missing from storage Args: @@ -254,8 +253,6 @@ TODO: an exception when we get a hash collision. """ - db = self.db - keys = CONTENT_HASH_KEYS if key_hash not in CONTENT_HASH_KEYS: @@ -272,7 +269,7 @@ yield obj[key_hash_idx] @db_transaction_generator - def content_missing_per_sha1(self, contents, cur=None): + def content_missing_per_sha1(self, contents, db=None, cur=None): """List content missing from storage based only on sha1. Args: @@ -285,14 +282,12 @@ TODO: an exception when we get a hash collision. """ - db = self.db - db.store_tmp_bytea(contents, cur) for obj in db.content_missing_per_sha1_from_temp(cur): yield obj[0] @db_transaction_generator - def skipped_content_missing(self, content, cur=None): + def skipped_content_missing(self, content, db=None, cur=None): """List skipped_content missing from storage Args: @@ -305,8 +300,6 @@ """ keys = CONTENT_HASH_KEYS - db = self.db - db.mktemp('skipped_content', cur) db.copy_to(content, 'tmp_skipped_content', keys + ['length', 'reason'], cur) @@ -314,7 +307,7 @@ yield from db.skipped_content_missing_from_temp(cur) @db_transaction - def content_find(self, content, cur=None): + def content_find(self, content, db=None, cur=None): """Find a content hash in db. Args: @@ -331,8 +324,6 @@ nor sha256. """ - db = self.db - if not set(content).intersection(ALGORITHMS): raise ValueError('content keys must contain at least one of: ' 'sha1, sha1_git, sha256, blake2s256') @@ -384,7 +375,7 @@ if not dirs_missing: return - db = self.db + db = self.get_db() with db.transaction() as cur: # Copy directory ids dirs_missing_dict = ({'id': dir} for dir in dirs_missing) @@ -412,7 +403,7 @@ db.directory_add_from_temp(cur) @db_transaction_generator - def directory_missing(self, directories, cur): + def directory_missing(self, directories, db=None, cur=None): """List directories missing from storage Args: @@ -422,8 +413,6 @@ missing directory ids """ - db = self.db - # Create temporary table for metadata injection db.mktemp('directory', cur) @@ -435,9 +424,7 @@ yield obj[0] @db_transaction_generator - def directory_get(self, - directories, - cur=None): + def directory_get(self, directories, db=None, cur=None): """Get information on directories. Args: @@ -447,7 +434,6 @@ List of directories as dict with keys and associated values. """ - db = self.db keys = ('id', 'dir_entries', 'file_entries', 'rev_entries') db.mktemp('directory', cur) @@ -459,7 +445,7 @@ yield dict(zip(keys, line)) @db_transaction_generator - def directory_ls(self, directory, recursive=False, cur=None): + def directory_ls(self, directory, recursive=False, db=None, cur=None): """Get entries for one directory. Args: @@ -470,8 +456,6 @@ List of entries for such directory. """ - db = self.db - if recursive: res_gen = db.directory_walk(directory, cur=cur) else: @@ -481,7 +465,7 @@ yield dict(zip(db.directory_ls_cols, line)) @db_transaction - def directory_entry_get_by_path(self, directory, paths, cur=None): + def directory_entry_get_by_path(self, directory, paths, db=None, cur=None): """Get the directory entry (either file or dir) from directory with path. Args: @@ -493,8 +477,6 @@ The corresponding directory entry if found, None otherwise. """ - db = self.db - res = db.directory_entry_get_by_path(directory, paths, cur) if res: return dict(zip(db.directory_ls_cols, res)) @@ -531,7 +513,7 @@ - parents (list of sha1_git): the parents of this revision """ - db = self.db + db = self.get_db() revisions_missing = set(self.revision_missing( set(revision['id'] for revision in revisions))) @@ -559,7 +541,7 @@ ['id', 'parent_id', 'parent_rank'], cur) @db_transaction_generator - def revision_missing(self, revisions, cur=None): + def revision_missing(self, revisions, db=None, cur=None): """List revisions missing from storage Args: @@ -569,15 +551,13 @@ missing revision ids """ - db = self.db - db.store_tmp_bytea(revisions, cur) for obj in db.revision_missing_from_temp(cur): yield obj[0] @db_transaction_generator - def revision_get(self, revisions, cur): + def revision_get(self, revisions, db=None, cur=None): """Get all revisions from storage Args: @@ -588,12 +568,9 @@ revision doesn't exist) """ - - db = self.db - db.store_tmp_bytea(revisions, cur) - for line in self.db.revision_get_from_temp(cur): + for line in db.revision_get_from_temp(cur): data = converters.db_to_revision( dict(zip(db.revision_get_cols, line)) ) @@ -603,7 +580,7 @@ yield data @db_transaction_generator - def revision_log(self, revisions, limit=None, cur=None): + def revision_log(self, revisions, limit=None, db=None, cur=None): """Fetch revision entry from the given root revisions. Args: @@ -614,8 +591,6 @@ List of revision log from such revisions root. """ - db = self.db - for line in db.revision_log(revisions, limit, cur): data = converters.db_to_revision( dict(zip(db.revision_get_cols, line)) @@ -626,7 +601,7 @@ yield data @db_transaction_generator - def revision_shortlog(self, revisions, limit=None, cur=None): + def revision_shortlog(self, revisions, limit=None, db=None, cur=None): """Fetch the shortlog for the given revisions Args: @@ -638,13 +613,11 @@ """ - db = self.db - yield from db.revision_shortlog(revisions, limit, cur) @db_transaction_generator def revision_log_by(self, origin_id, branch_name=None, timestamp=None, - limit=None, cur=None): + limit=None, db=None, cur=None): """Fetch revision entry from the actual origin_id's latest revision. Args: @@ -662,8 +635,6 @@ None if no revision matching this combination is found. """ - db = self.db - # Retrieve the revision by criterion revisions = list(db.revision_get_by( origin_id, branch_name, timestamp, limit=1, cur=cur)) @@ -673,7 +644,7 @@ revision_id = revisions[0][0] # otherwise, retrieve the revision log from that revision - yield from self.revision_log([revision_id], limit, cur=cur) + yield from self.revision_log([revision_id], limit, db=db, cur=cur) def release_add(self, releases): """Add releases to the storage @@ -695,7 +666,7 @@ - author_email (bytes): the email of the release author """ - db = self.db + db = self.get_db() release_ids = set(release['id'] for release in releases) releases_missing = set(self.release_missing(release_ids)) @@ -717,7 +688,7 @@ db.release_add_from_temp(cur) @db_transaction_generator - def release_missing(self, releases, cur=None): + def release_missing(self, releases, db=None, cur=None): """List releases missing from storage Args: @@ -727,8 +698,6 @@ a list of missing release ids """ - db = self.db - # Create temporary table for metadata injection db.store_tmp_bytea(releases, cur) @@ -736,7 +705,7 @@ yield obj[0] @db_transaction_generator - def release_get(self, releases, cur=None): + def release_get(self, releases, db=None, cur=None): """Given a list of sha1, return the releases's information Args: @@ -755,8 +724,6 @@ ValueError: if the keys does not match (url and type) nor id. """ - db = self.db - # Create temporary table for metadata injection db.store_tmp_bytea(releases, cur) @@ -767,7 +734,7 @@ @db_transaction def snapshot_add(self, origin, visit, snapshot, back_compat=False, - cur=None): + db=None, cur=None): """Add a snapshot for the given origin/visit couple Args: @@ -791,8 +758,6 @@ back_compat (bool): whether to add the occurrences for backwards-compatibility """ - db = self.db - if not db.snapshot_exists(snapshot['id'], cur): db.mktemp_snapshot_branch(cur) db.copy_to( @@ -834,10 +799,10 @@ 'target_type': target_type, }) - self.occurrence_add(occurrences, cur=cur) + self.occurrence_add(occurrences, db=db, cur=cur) @db_transaction - def snapshot_get(self, snapshot_id, cur=None): + def snapshot_get(self, snapshot_id, db=None, cur=None): """Get the snapshot with the given id Args: @@ -848,8 +813,6 @@ branches:: a list of branches contained by the snapshot """ - db = self.db - branches = {} for branch in db.snapshot_get_by_id(snapshot_id, cur): branch = dict(zip(db.snapshot_get_cols, branch)) @@ -869,7 +832,7 @@ return None @db_transaction - def snapshot_get_by_origin_visit(self, origin, visit, cur=None): + def snapshot_get_by_origin_visit(self, origin, visit, db=None, cur=None): """Get the snapshot for the given origin visit Args: @@ -881,16 +844,14 @@ branches:: a dictionary containing the snapshot branch information """ - db = self.db - snapshot_id = db.snapshot_get_by_origin_visit(origin, visit, cur) if snapshot_id: - return self.snapshot_get(snapshot_id, cur=cur) + return self.snapshot_get(snapshot_id, db=db, cur=cur) else: # compatibility code during the snapshot migration origin_visit_info = self.origin_visit_get_by(origin, visit, - cur=cur) + db=db, cur=cur) if origin_visit_info is None: return None ret = {'id': None} @@ -900,7 +861,8 @@ return None @db_transaction - def snapshot_get_latest(self, origin, allowed_statuses=None, cur=None): + def snapshot_get_latest(self, origin, allowed_statuses=None, db=None, + cur=None): """Get the latest snapshot for the given origin, optionally only from visits that have one of the given allowed_statuses. @@ -916,16 +878,14 @@ id:: identifier for the snapshot branches:: a dictionary containing the snapshot branch information """ - db = self.db - origin_visit = db.origin_visit_get_latest_snapshot( origin, allowed_statuses=allowed_statuses, cur=cur) if origin_visit: origin_visit = dict(zip(db.origin_visit_get_cols, origin_visit)) - return self.snapshot_get(origin_visit['snapshot'], cur=cur) + return self.snapshot_get(origin_visit['snapshot'], db=db, cur=cur) @db_transaction - def occurrence_add(self, occurrences, cur=None): + def occurrence_add(self, occurrences, db=None, cur=None): """Add occurrences to the storage Args: @@ -943,8 +903,6 @@ occurrence """ - db = self.db - db.mktemp_occurrence_history(cur) db.copy_to(occurrences, 'tmp_occurrence_history', ['origin', 'branch', 'target', 'target_type', 'visit'], cur) @@ -952,7 +910,7 @@ db.occurrence_history_add_from_temp(cur) @db_transaction_generator - def occurrence_get(self, origin_id, cur=None): + def occurrence_get(self, origin_id, db=None, cur=None): """Retrieve occurrence information per origin_id. Args: @@ -962,7 +920,6 @@ List of occurrences matching criterion. """ - db = self.db for line in db.occurrence_get(origin_id, cur): yield { 'origin': line[0], @@ -972,7 +929,7 @@ } @db_transaction - def origin_visit_add(self, origin, ts, cur=None): + def origin_visit_add(self, origin, ts, db=None, cur=None): """Add an origin_visit for the origin at ts with status 'ongoing'. Args: @@ -992,12 +949,12 @@ return { 'origin': origin, - 'visit': self.db.origin_visit_add(origin, ts, cur) + 'visit': db.origin_visit_add(origin, ts, cur) } @db_transaction def origin_visit_update(self, origin, visit_id, status, metadata=None, - cur=None): + db=None, cur=None): """Update an origin_visit's status. Args: @@ -1010,11 +967,11 @@ None """ - return self.db.origin_visit_update(origin, visit_id, status, metadata, - cur) + return db.origin_visit_update(origin, visit_id, status, metadata, cur) @db_transaction_generator - def origin_visit_get(self, origin, last_visit=None, limit=None, cur=None): + def origin_visit_get(self, origin, last_visit=None, limit=None, db=None, + cur=None): """Retrieve all the origin's visit's information. Args: @@ -1028,14 +985,13 @@ List of visits. """ - db = self.db for line in db.origin_visit_get_all( origin, last_visit=last_visit, limit=limit, cur=cur): - data = dict(zip(self.db.origin_visit_get_cols, line)) + data = dict(zip(db.origin_visit_get_cols, line)) yield data @db_transaction - def origin_visit_get_by(self, origin, visit, cur=None): + def origin_visit_get_by(self, origin, visit, db=None, cur=None): """Retrieve origin visit's information. Args: @@ -1045,17 +1001,15 @@ The information on that particular (origin, visit) """ - db = self.db - ori_visit = db.origin_visit_get(origin, visit, cur) if not ori_visit: return None - ori_visit = dict(zip(self.db.origin_visit_get_cols, ori_visit)) + ori_visit = dict(zip(db.origin_visit_get_cols, ori_visit)) if ori_visit['snapshot']: - ori_visit['occurrences'] = self.snapshot_get(ori_visit['snapshot'], - cur=cur)['branches'] + ori_visit['occurrences'] = self.snapshot_get( + ori_visit['snapshot'], db=db, cur=cur)['branches'] return ori_visit # TODO: remove Backwards compatibility after snapshot migration @@ -1077,6 +1031,7 @@ branch_name=None, timestamp=None, limit=None, + db=None, cur=None): """Given an origin_id, retrieve occurrences' list per given criterions. @@ -1091,13 +1046,10 @@ found. """ - for line in self.db.revision_get_by(origin_id, - branch_name, - timestamp, - limit=limit, - cur=cur): + for line in db.revision_get_by(origin_id, branch_name, timestamp, + limit=limit, cur=cur): data = converters.db_to_revision( - dict(zip(self.db.revision_get_cols, line)) + dict(zip(db.revision_get_cols, line)) ) if not data['type']: yield None @@ -1105,7 +1057,7 @@ yield data @db_transaction_generator - def release_get_by(self, origin_id, limit=None, cur=None): + def release_get_by(self, origin_id, limit=None, db=None, cur=None): """Given an origin id, return all the tag objects pointing to heads of origin_id. @@ -1118,15 +1070,14 @@ found. """ - - for line in self.db.release_get_by(origin_id, limit=limit, cur=cur): + for line in db.release_get_by(origin_id, limit=limit, cur=cur): data = converters.db_to_release( - dict(zip(self.db.release_get_cols, line)) + dict(zip(db.release_get_cols, line)) ) yield data @db_transaction - def object_find_by_sha1_git(self, ids, cur=None): + def object_find_by_sha1_git(self, ids, db=None, cur=None): """Return the objects found with the given ids. Args: @@ -1142,8 +1093,6 @@ - object_id: the numeric id of the object found. """ - db = self.db - ret = {id: [] for id in ids} for retval in db.object_find_by_sha1_git(ids, cur=cur): @@ -1156,7 +1105,7 @@ origin_keys = ['id', 'type', 'url', 'lister', 'project'] @db_transaction - def origin_get(self, origin, cur=None): + def origin_get(self, origin, db=None, cur=None): """Return the origin either identified by its id or its tuple (type, url). @@ -1184,8 +1133,6 @@ ValueError: if the keys does not match (url and type) nor id. """ - db = self.db - origin_id = origin.get('id') if origin_id: # check lookup per id first ori = db.origin_get(origin_id, cur) @@ -1200,7 +1147,7 @@ @db_transaction_generator def origin_search(self, url_pattern, offset=0, limit=50, - regexp=False, cur=None): + regexp=False, db=None, cur=None): """Search for origins whose urls contain a provided string pattern or match a provided regular expression. The search is performed in a case insensitive way. @@ -1216,14 +1163,12 @@ An iterable of dict containing origin information as returned by :meth:`swh.storage.storage.Storage.origin_get`. """ - db = self.db - for origin in db.origin_search(url_pattern, offset, limit, regexp, cur): yield dict(zip(self.origin_keys, origin)) @db_transaction - def _person_add(self, person, cur=None): + def _person_add(self, person, db=None, cur=None): """Add a person in storage. Note: Internal function for now, do not use outside of this module. @@ -1238,12 +1183,10 @@ Id of the new person. """ - db = self.db - return db.person_add(person) @db_transaction_generator - def person_get(self, person, cur=None): + def person_get(self, person, db=None, cur=None): """Return the persons identified by their ids. Args: @@ -1253,13 +1196,11 @@ The array of persons corresponding of the ids. """ - db = self.db - for person in db.person_get(person): yield dict(zip(db.person_get_cols, person)) @db_transaction - def origin_add(self, origins, cur=None): + def origin_add(self, origins, db=None, cur=None): """Add origins to the storage Args: @@ -1276,12 +1217,12 @@ ret = [] for origin in origins: - ret.append(self.origin_add_one(origin, cur=cur)) + ret.append(self.origin_add_one(origin, db=db, cur=cur)) return ret @db_transaction - def origin_add_one(self, origin, cur=None): + def origin_add_one(self, origin, db=None, cur=None): """Add origin to the storage Args: @@ -1296,8 +1237,6 @@ exists. """ - db = self.db - data = db.origin_get_with(origin['type'], origin['url'], cur) if data: return data[0] @@ -1305,7 +1244,7 @@ return db.origin_add(origin['type'], origin['url'], cur) @db_transaction - def fetch_history_start(self, origin_id, cur=None): + def fetch_history_start(self, origin_id, db=None, cur=None): """Add an entry for origin origin_id in fetch_history. Returns the id of the added fetch_history entry """ @@ -1314,15 +1253,15 @@ 'date': datetime.datetime.now(tz=datetime.timezone.utc), } - return self.db.create_fetch_history(fetch_history, cur) + return db.create_fetch_history(fetch_history, cur) @db_transaction - def fetch_history_end(self, fetch_history_id, data, cur=None): + def fetch_history_end(self, fetch_history_id, data, db=None, cur=None): """Close the fetch_history entry with id `fetch_history_id`, replacing its data with `data`. """ now = datetime.datetime.now(tz=datetime.timezone.utc) - fetch_history = self.db.get_fetch_history(fetch_history_id, cur) + fetch_history = db.get_fetch_history(fetch_history_id, cur) if not fetch_history: raise ValueError('No fetch_history with id %d' % fetch_history_id) @@ -1331,16 +1270,16 @@ fetch_history.update(data) - self.db.update_fetch_history(fetch_history, cur) + db.update_fetch_history(fetch_history, cur) @db_transaction - def fetch_history_get(self, fetch_history_id, cur=None): + def fetch_history_get(self, fetch_history_id, db=None, cur=None): """Get the fetch_history entry with id `fetch_history_id`. """ - return self.db.get_fetch_history(fetch_history_id, cur) + return db.get_fetch_history(fetch_history_id, cur) @db_transaction - def entity_add(self, entities, cur=None): + def entity_add(self, entities, db=None, cur=None): """Add the given entitites to the database (in entity_history). Args: @@ -1363,8 +1302,6 @@ listed the entity. """ - db = self.db - cols = list(db.entity_history_cols) cols.remove('id') @@ -1373,7 +1310,7 @@ db.entity_history_add_from_temp() @db_transaction_generator - def entity_get_from_lister_metadata(self, entities, cur=None): + def entity_get_from_lister_metadata(self, entities, db=None, cur=None): """Fetch entities from the database, matching with the lister and associated metadata. @@ -1387,8 +1324,6 @@ """ - db = self.db - db.mktemp_entity_lister(cur) mapped_entities = [] @@ -1418,7 +1353,7 @@ } @db_transaction_generator - def entity_get(self, uuid, cur=None): + def entity_get(self, uuid, db=None, cur=None): """Returns the list of entity per its uuid identifier and also its parent hierarchy. @@ -1430,12 +1365,11 @@ hierarchy from such entity. """ - db = self.db for entity in db.entity_get(uuid, cur): yield dict(zip(db.entity_cols, entity)) @db_transaction - def entity_get_one(self, uuid, cur=None): + def entity_get_one(self, uuid, db=None, cur=None): """Returns one entity using its uuid identifier. Args: @@ -1445,7 +1379,6 @@ the object corresponding to the given entity """ - db = self.db entity = db.entity_get_one(uuid, cur) if entity: return dict(zip(db.entity_cols, entity)) @@ -1453,7 +1386,7 @@ return None @db_transaction - def stat_counters(self, cur=None): + def stat_counters(self, db=None, cur=None): """compute statistics about the number of tuples in various tables Returns: @@ -1461,11 +1394,11 @@ integer values (e.g., the number of tuples in table content) """ - return {k: v for (k, v) in self.db.stat_counters()} + return {k: v for (k, v) in db.stat_counters()} @db_transaction def origin_metadata_add(self, origin_id, ts, provider, tool, metadata, - cur=None): + db=None, cur=None): """ Add an origin_metadata for the origin at ts with provenance and metadata. @@ -1482,11 +1415,12 @@ if isinstance(ts, str): ts = dateutil.parser.parse(ts) - return self.db.origin_metadata_add(origin_id, ts, provider, tool, - metadata, cur) + return db.origin_metadata_add(origin_id, ts, provider, tool, + metadata, cur) @db_transaction_generator - def origin_metadata_get_by(self, origin_id, provider_type=None, cur=None): + def origin_metadata_get_by(self, origin_id, provider_type=None, db=None, + cur=None): """Retrieve list of all origin_metadata entries for the origin_id Args: @@ -1507,12 +1441,11 @@ - provider_url (str) """ - db = self.db for line in db.origin_metadata_get_by(origin_id, provider_type, cur): yield dict(zip(db.origin_metadata_get_cols, line)) @db_transaction_generator - def tool_add(self, tools, cur=None): + def tool_add(self, tools, db=None, cur=None): """Add new tools to the storage. Args: @@ -1530,7 +1463,6 @@ guaranteed to match the order of the initial list. """ - db = self.db db.mktemp_tool(cur) db.copy_to(tools, 'tmp_tool', ['name', 'version', 'configuration'], @@ -1541,7 +1473,7 @@ yield dict(zip(db.tool_cols, line)) @db_transaction - def tool_get(self, tool, cur=None): + def tool_get(self, tool, db=None, cur=None): """Retrieve tool information. Args: @@ -1553,7 +1485,6 @@ None otherwise. """ - db = self.db tool_conf = tool['configuration'] if isinstance(tool_conf, dict): tool_conf = json.dumps(tool_conf) @@ -1563,31 +1494,28 @@ tool_conf) if not idx: return None - return dict(zip(self.db.tool_cols, idx)) + return dict(zip(db.tool_cols, idx)) @db_transaction def metadata_provider_add(self, provider_name, provider_type, provider_url, - metadata, cur=None): - db = self.db + metadata, db=None, cur=None): return db.metadata_provider_add(provider_name, provider_type, provider_url, metadata, cur) @db_transaction - def metadata_provider_get(self, provider_id, cur=None): - db = self.db + def metadata_provider_get(self, provider_id, db=None, cur=None): result = db.metadata_provider_get(provider_id) if not result: return None - return dict(zip(self.db.metadata_provider_cols, result)) + return dict(zip(db.metadata_provider_cols, result)) @db_transaction - def metadata_provider_get_by(self, provider, cur=None): - db = self.db + def metadata_provider_get_by(self, provider, db=None, cur=None): result = db.metadata_provider_get_by(provider['provider_name'], provider['provider_url']) if not result: return None - return dict(zip(self.db.metadata_provider_cols, result)) + return dict(zip(db.metadata_provider_cols, result)) def diff_directories(self, from_dir, to_dir, track_renaming=False): """Compute the list of file changes introduced between two arbitrary diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -2483,7 +2483,7 @@ self.storage.content_update([cont], keys=['sha1_git']) - with self.storage.db.transaction() as cur: + with self.storage.get_db().transaction() as cur: cur.execute('SELECT sha1, sha1_git, sha256, length, status' ' FROM content WHERE sha1 = %s', (cont['sha1'],)) @@ -2497,7 +2497,7 @@ @istest def content_update_with_new_cols(self): - with self.storage.db.transaction() as cur: + with self.storage.get_db().transaction() as cur: cur.execute("""alter table content add column test text default null, add column test2 text default null""") @@ -2508,7 +2508,7 @@ cont['test2'] = 'value-2' self.storage.content_update([cont], keys=['test', 'test2']) - with self.storage.db.transaction() as cur: + with self.storage.get_db().transaction() as cur: cur.execute( 'SELECT sha1, sha1_git, sha256, length, status, test, test2' ' FROM content WHERE sha1 = %s', @@ -2522,6 +2522,6 @@ (cont['sha1'], cont['sha1_git'], cont['sha256'], cont['length'], 'visible', cont['test'], cont['test2'])) - with self.storage.db.transaction() as cur: + with self.storage.get_db().transaction() as cur: cur.execute("""alter table content drop column test, drop column test2""")