diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -685,39 +685,17 @@ return results - def origin_get(self, origins): - if isinstance(origins, dict): - # Old API - return_single = True - origins = [origins] - else: - return_single = False - - if any("id" in origin for origin in origins): - raise StorageArgumentException("Origin ids are not supported.") - - results = [self.origin_get_one(origin) for origin in origins] - - if return_single: - assert len(results) == 1 - return results[0] - else: - return results + def origin_get(self, origins: Iterable[str]) -> Iterable[Optional[Origin]]: + return [self.origin_get_one(origin) for origin in origins] - def origin_get_one(self, origin: Dict[str, Any]) -> Optional[Dict[str, Any]]: - if "id" in origin: - raise StorageArgumentException("Origin ids are not supported.") - if "url" not in origin: - raise StorageArgumentException("Missing origin url") - rows = self._cql_runner.origin_get_by_url(origin["url"]) + def origin_get_one(self, origin_url: str) -> Optional[Origin]: + """Given an origin url, return the origin if it exists, None otherwise - rows = list(rows) + """ + rows = list(self._cql_runner.origin_get_by_url(origin_url)) if rows: assert len(rows) == 1 - result = rows[0]._asdict() - return { - "url": result["url"], - } + return Origin(url=rows[0].url) else: return None @@ -770,12 +748,7 @@ def origin_add(self, origins: Iterable[Origin]) -> Dict[str, int]: origins = list(origins) - known_origins = [ - Origin.from_dict(d) - for d in self.origin_get([origin.to_dict() for origin in origins]) - if d is not None - ] - to_add = [origin for origin in origins if origin not in known_origins] + to_add = [ori for ori in origins if self.origin_get_one(ori.url) is None] self.journal_writer.origin_add(to_add) for origin in to_add: self._cql_runner.origin_add_one(origin) @@ -783,7 +756,7 @@ def origin_visit_add(self, visits: 
Iterable[OriginVisit]) -> Iterable[OriginVisit]: for visit in visits: - origin = self.origin_get({"url": visit.origin}) + origin = self.origin_get_one(visit.origin) if not origin: # Cannot add a visit without an origin raise StorageArgumentException("Unknown origin %s", visit.origin) @@ -822,7 +795,7 @@ ) -> None: # First round to check existence (fail early if any is ko) for visit_status in visit_statuses: - origin_url = self.origin_get({"url": visit_status.origin}) + origin_url = self.origin_get_one(visit_status.origin) if not origin_url: raise StorageArgumentException(f"Unknown origin {visit_status.origin}") diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -670,43 +670,11 @@ return t.to_dict() - def origin_get(self, origins): - if isinstance(origins, dict): - # Old API - return_single = True - origins = [origins] - else: - return_single = False + def origin_get_one(self, origin_url: str) -> Optional[Origin]: + return self._origins.get(origin_url) - # Sanity check to be error-compatible with the pgsql backend - if any("id" in origin for origin in origins) and not all( - "id" in origin for origin in origins - ): - raise StorageArgumentException( - 'Either all origins or none at all should have an "id".' - ) - if any("url" in origin for origin in origins) and not all( - "url" in origin for origin in origins - ): - raise StorageArgumentException( - "Either all origins or none at all should have " 'an "url" key.' 
- ) - - results = [] - for origin in origins: - result = None - if "url" in origin: - if origin["url"] in self._origins: - result = self._origins[origin["url"]] - else: - raise StorageArgumentException("Origin must have an url.") - results.append(self._convert_origin(result)) - - if return_single: - assert len(results) == 1 - return results[0] - else: - return results + def origin_get(self, origins: Iterable[str]) -> Iterable[Optional[Origin]]: + return [self.origin_get_one(origin_url) for origin_url in origins] def origin_get_by_sha1(self, sha1s): return [self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s] @@ -803,7 +771,7 @@ def origin_visit_add(self, visits: Iterable[OriginVisit]) -> Iterable[OriginVisit]: for visit in visits: - origin = self.origin_get({"url": visit.origin}) + origin = self.origin_get_one(visit.origin) if not origin: # Cannot add a visit without an origin raise StorageArgumentException("Unknown origin %s", visit.origin) @@ -855,7 +823,7 @@ ) -> None: # First round to check existence (fail early if any is ko) for visit_status in visit_statuses: - origin_url = self.origin_get({"url": visit_status.origin}) + origin_url = self.origin_get_one(visit_status.origin) if not origin_url: raise StorageArgumentException(f"Unknown origin {visit_status.origin}") diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -944,28 +944,15 @@ ... @remote_api_endpoint("origin/get") - def origin_get(self, origins): - """Return origins, either all identified by their ids or all - identified by tuples (type, url). - - If the url is given and the type is omitted, one of the origins with - that url is returned. + def origin_get(self, origins: Iterable[str]) -> Iterable[Optional[Origin]]: + """Return origins. Args: - origin: a list of dictionaries representing the individual - origins to find. 
These dicts have the key url: - - - url (bytes): the url the origin points to + origins: the urls of the origins to find Returns: - dict: the origin dictionary with the keys: - - - id: origin's id - - url: origin's url - - Raises: - ValueError: if the url or the id don't exist. + a list with one entry per input url, in the same order: the matching + Origin model object, or None if no origin with that url exists. """ ... diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -802,7 +802,7 @@ self, visits: Iterable[OriginVisit], db=None, cur=None ) -> Iterable[OriginVisit]: for visit in visits: - origin = self.origin_get({"url": visit.origin}, db=db, cur=cur) + origin = self.origin_get([visit.origin], db=db, cur=cur)[0] if not origin: # Cannot add a visit without an origin raise StorageArgumentException("Unknown origin %s", visit.origin) @@ -851,7 +851,7 @@ ) -> None: # First round to check existence (fail early if any is ko) for visit_status in visit_statuses: - origin_url = self.origin_get({"url": visit_status.origin}, db=db, cur=cur) + origin_url = self.origin_get([visit_status.origin], db=db, cur=cur)[0] if not origin_url: raise StorageArgumentException(f"Unknown origin {visit_status.origin}") @@ -973,28 +973,17 @@ @timed @db_transaction(statement_timeout=500) - def origin_get(self, origins, db=None, cur=None): - if isinstance(origins, dict): - # Old API - return_single = True - origins = [origins] - elif len(origins) == 0: - return [] - else: - return_single = False - - origin_urls = [origin["url"] for origin in origins] - results = db.origin_get_by_url(origin_urls, cur) - - results = [dict(zip(db.origin_cols, result)) for result in results] - if return_single: - assert len(results) == 1 - if results[0]["url"] is not None: - return results[0] - else: - return None - else: - return [None if res["url"] is None else res for res in results] + def origin_get( + self, origins: Iterable[str], db=None, cur=None + 
) -> Iterable[Optional[Origin]]: + origin_urls = list(origins) + rows = db.origin_get_by_url(origin_urls, cur) + result: List[Optional[Origin]] = [] + for row in rows: + origin_d = dict(zip(db.origin_cols, row)) + url = origin_d["url"] + result.append(None if url is None else Origin(url=url)) + return result @timed @db_transaction_generator(statement_timeout=500) diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1169,22 +1169,19 @@ } def test_origin_add(self, swh_storage, sample_data): - origin, origin2 = sample_data.origins[:2] - origin_dict, origin2_dict = [o.to_dict() for o in [origin, origin2]] + origins = list(sample_data.origins[:2]) + origin_urls = [o.url for o in origins] - assert swh_storage.origin_get([origin_dict])[0] is None + assert swh_storage.origin_get(origin_urls) == [None, None] - stats = swh_storage.origin_add([origin, origin2]) + stats = swh_storage.origin_add(origins) assert stats == {"origin:add": 2} - actual_origin = swh_storage.origin_get([origin_dict])[0] - assert actual_origin["url"] == origin.url - - actual_origin2 = swh_storage.origin_get([origin2_dict])[0] - assert actual_origin2["url"] == origin2.url + actual_origins = swh_storage.origin_get(origin_urls) + assert actual_origins == origins assert set(swh_storage.journal_writer.journal.objects) == set( - [("origin", origin), ("origin", origin2),] + [("origin", origins[0]), ("origin", origins[1]),] ) swh_storage.refresh_stat_counters() @@ -1192,7 +1189,6 @@ def test_origin_add_from_generator(self, swh_storage, sample_data): origin, origin2 = sample_data.origins[:2] - origin_dict, origin2_dict = [o.to_dict() for o in [origin, origin2]] def _ori_gen(): yield origin @@ -1201,11 +1197,8 @@ stats = swh_storage.origin_add(_ori_gen()) assert stats == {"origin:add": 2} - actual_origin = swh_storage.origin_get([origin_dict])[0] - assert actual_origin["url"] == origin.url - - 
actual_origin2 = swh_storage.origin_get([origin2_dict])[0] - assert actual_origin2["url"] == origin2.url + actual_origins = swh_storage.origin_get([origin.url, origin2.url]) + assert actual_origins == [origin, origin2] assert set(swh_storage.journal_writer.journal.objects) == set( [("origin", origin), ("origin", origin2),] @@ -1216,7 +1209,6 @@ def test_origin_add_twice(self, swh_storage, sample_data): origin, origin2 = sample_data.origins[:2] - origin_dict, origin2_dict = [o.to_dict() for o in [origin, origin2]] add1 = swh_storage.origin_add([origin, origin2]) assert set(swh_storage.journal_writer.journal.objects) == set( @@ -1230,33 +1222,17 @@ ) assert add2 == {"origin:add": 0} - def test_origin_get_legacy(self, swh_storage, sample_data): - origin, origin2 = sample_data.origins[:2] - origin_dict, origin2_dict = [o.to_dict() for o in [origin, origin2]] - - assert swh_storage.origin_get(origin_dict) is None - swh_storage.origin_add([origin]) - - actual_origin0 = swh_storage.origin_get(origin_dict) - assert actual_origin0["url"] == origin.url - def test_origin_get(self, swh_storage, sample_data): origin, origin2 = sample_data.origins[:2] - origin_dict, origin2_dict = [o.to_dict() for o in [origin, origin2]] - assert swh_storage.origin_get(origin_dict) is None - assert swh_storage.origin_get([origin_dict]) == [None] + assert swh_storage.origin_get([origin.url]) == [None] swh_storage.origin_add([origin]) - actual_origins = swh_storage.origin_get([origin_dict]) - assert len(actual_origins) == 1 - - actual_origin0 = swh_storage.origin_get(origin_dict) - assert actual_origin0 == actual_origins[0] - assert actual_origin0["url"] == origin.url + actual_origins = swh_storage.origin_get([origin.url]) + assert actual_origins == [origin] - actual_origins = swh_storage.origin_get([origin_dict, {"url": "not://exists"}]) - assert actual_origins == [origin_dict, None] + actual_origins = swh_storage.origin_get([origin.url, "not://exists"]) + assert actual_origins == [origin, None] 
def _generate_random_visits(self, nb_visits=100, start=0, end=7): """Generate random visits within the last 2 months (to avoid @@ -1422,7 +1398,7 @@ def test_origin_get_by_sha1(self, swh_storage, sample_data): origin = sample_data.origin - assert swh_storage.origin_get(origin.to_dict()) is None + assert swh_storage.origin_get([origin.url])[0] is None swh_storage.origin_add([origin]) origins = list(swh_storage.origin_get_by_sha1([sha1(origin.url)])) @@ -1431,7 +1407,7 @@ def test_origin_get_by_sha1_not_found(self, swh_storage, sample_data): origin = sample_data.origin - assert swh_storage.origin_get(origin.to_dict()) is None + assert swh_storage.origin_get([origin.url])[0] is None origins = list(swh_storage.origin_get_by_sha1([sha1(origin.url)])) assert len(origins) == 1 assert origins[0] is None