diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -763,12 +763,17 @@ return [{"url": orig.url,} for orig in origins[offset : offset + limit]] - def origin_add(self, origins: Iterable[Origin]) -> List[Dict]: - results = [] - for origin in origins: + def origin_add(self, origins: Iterable[Origin]) -> Dict[str, int]: + known_origins = [ + Origin.from_dict(d) + for d in self.origin_get([origin.to_dict() for origin in origins]) + if d is not None + ] + to_add = [origin for origin in origins if origin not in known_origins] + self.journal_writer.origin_add(to_add) + for origin in to_add: self.origin_add_one(origin) - results.append(origin.to_dict()) - return results + return {"origin:add": len(to_add)} def origin_add_one(self, origin: Origin) -> str: known_origin = self.origin_get_one(origin.to_dict()) diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -753,11 +753,15 @@ ) ) - def origin_add(self, origins: Iterable[Origin]) -> List[Dict]: - origins = copy.deepcopy(list(origins)) + def origin_add(self, origins: Iterable[Origin]) -> Dict[str, int]: + origins = list(origins) + added = 0 for origin in origins: - self.origin_add_one(origin) - return [origin.to_dict() for origin in origins] + if origin.url not in self._origins: + self.origin_add_one(origin) + added += 1 + + return {"origin:add": added} def origin_add_one(self, origin: Origin) -> str: if origin.url not in self._origins: diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -1047,7 +1047,7 @@ ... 
@remote_api_endpoint("origin/add_multi") - def origin_add(self, origins: Iterable[Origin]) -> List[Dict]: + def origin_add(self, origins: Iterable[Origin]) -> Dict[str, int]: """Add origins to the storage Args: @@ -1058,7 +1058,9 @@ - url (bytes): the url the origin points to Returns: - list: given origins as dict updated with their id + Summary dict of keys with associated count as values + + origin:add: Count of objects actually stored in the db """ ... diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -1093,26 +1093,28 @@ @timed @db_transaction() - def origin_add(self, origins: Iterable[Origin], db=None, cur=None) -> List[Dict]: - origins = list(origins) - for origin in origins: - self.origin_add_one(origin, db=db, cur=cur) - - return [o.to_dict() for o in origins] + def origin_add( + self, origins: Iterable[Origin], db=None, cur=None + ) -> Dict[str, int]: + urls = [o.url for o in origins] + known_origins = set(url for (url,) in db.origin_get_by_url(urls, cur)) + # use a list here to preserve the input order of origins; some tests depend on this + to_add = [url for url in urls if url not in known_origins] + + self.journal_writer.origin_add([Origin(url=url) for url in to_add]) + added = 0 + for url in to_add: + if db.origin_add(url, cur): + added += 1 + return {"origin:add": added} @timed @db_transaction() def origin_add_one(self, origin: Origin, db=None, cur=None) -> str: - origin_row = list(db.origin_get_by_url([origin.url], cur))[0] - origin_url = dict(zip(db.origin_cols, origin_row))["url"] - if origin_url: - return origin_url - - self.journal_writer.origin_add([origin]) - - url = db.origin_add(origin.url, cur) - send_metric("origin:add", count=1, method_name="origin_add_one") - return url + stats = self.origin_add([origin]) + if stats.get("origin:add", 0): + send_metric("origin:add", count=1, method_name="origin_add_one") + return origin.url @db_transaction(statement_timeout=500) def stat_counters(self,
db=None, cur=None): diff --git a/swh/storage/tests/algos/test_origin.py b/swh/storage/tests/algos/test_origin.py --- a/swh/storage/tests/algos/test_origin.py +++ b/swh/storage/tests/algos/test_origin.py @@ -26,7 +26,12 @@ def test_iter_origins(swh_storage): - origins = swh_storage.origin_add([{"url": "bar"}, {"url": "qux"}, {"url": "quuz"},]) + origins = [ + {"url": "bar"}, + {"url": "qux"}, + {"url": "quuz"}, + ] + assert swh_storage.origin_add(origins) == {"origin:add": 3} assert_list_eq(iter_origins(swh_storage), origins) assert_list_eq(iter_origins(swh_storage, batch_size=1), origins) assert_list_eq(iter_origins(swh_storage, batch_size=2), origins) diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1314,22 +1314,21 @@ origin0 = swh_storage.origin_get([data.origin])[0] assert origin0 is None - origin1, origin2 = swh_storage.origin_add([data.origin, data.origin2]) + stats = swh_storage.origin_add([data.origin, data.origin2]) + assert stats == {"origin:add": 2} actual_origin = swh_storage.origin_get([{"url": data.origin["url"],}])[0] - assert actual_origin["url"] == origin1["url"] + assert actual_origin["url"] == data.origin["url"] actual_origin2 = swh_storage.origin_get([{"url": data.origin2["url"],}])[0] - assert actual_origin2["url"] == origin2["url"] - - if "id" in actual_origin: - del actual_origin["id"] - del actual_origin2["id"] + assert actual_origin2["url"] == data.origin2["url"] - assert list(swh_storage.journal_writer.journal.objects) == [ - ("origin", Origin.from_dict(actual_origin)), - ("origin", Origin.from_dict(actual_origin2)), - ] + assert set(swh_storage.journal_writer.journal.objects) == set( + [ + ("origin", Origin.from_dict(actual_origin)), + ("origin", Origin.from_dict(actual_origin2)), + ] + ) swh_storage.refresh_stat_counters() assert swh_storage.stat_counters()["origin"] == 2 @@ -1339,40 +1338,47 @@ yield data.origin yield 
data.origin2 - origin1, origin2 = swh_storage.origin_add(_ori_gen()) + stats = swh_storage.origin_add(_ori_gen()) + assert stats == {"origin:add": 2} actual_origin = swh_storage.origin_get([{"url": data.origin["url"],}])[0] - assert actual_origin["url"] == origin1["url"] + assert actual_origin["url"] == data.origin["url"] actual_origin2 = swh_storage.origin_get([{"url": data.origin2["url"],}])[0] - assert actual_origin2["url"] == origin2["url"] + assert actual_origin2["url"] == data.origin2["url"] if "id" in actual_origin: del actual_origin["id"] del actual_origin2["id"] - assert list(swh_storage.journal_writer.journal.objects) == [ - ("origin", Origin.from_dict(actual_origin)), - ("origin", Origin.from_dict(actual_origin2)), - ] + assert set(swh_storage.journal_writer.journal.objects) == set( + [ + ("origin", Origin.from_dict(actual_origin)), + ("origin", Origin.from_dict(actual_origin2)), + ] + ) swh_storage.refresh_stat_counters() assert swh_storage.stat_counters()["origin"] == 2 def test_origin_add_twice(self, swh_storage): add1 = swh_storage.origin_add([data.origin, data.origin2]) - assert list(swh_storage.journal_writer.journal.objects) == [ - ("origin", Origin.from_dict(data.origin)), - ("origin", Origin.from_dict(data.origin2)), - ] + assert set(swh_storage.journal_writer.journal.objects) == set( + [ + ("origin", Origin.from_dict(data.origin)), + ("origin", Origin.from_dict(data.origin2)), + ] + ) + assert add1 == {"origin:add": 2} add2 = swh_storage.origin_add([data.origin, data.origin2]) - assert list(swh_storage.journal_writer.journal.objects) == [ - ("origin", Origin.from_dict(data.origin)), - ("origin", Origin.from_dict(data.origin2)), - ] - - assert add1 == add2 + assert set(swh_storage.journal_writer.journal.objects) == set( + [ + ("origin", Origin.from_dict(data.origin)), + ("origin", Origin.from_dict(data.origin2)), + ] + ) + assert add2 == {"origin:add": 0} def test_origin_add_validation(self, swh_storage): """Incorrect formatted origin should 
fail the validation @@ -1394,12 +1400,18 @@ def test_origin_get(self, swh_storage): assert swh_storage.origin_get(data.origin) is None + assert swh_storage.origin_get([data.origin]) == [None] swh_storage.origin_add_one(data.origin) actual_origin0 = swh_storage.origin_get([{"url": data.origin["url"]}]) assert len(actual_origin0) == 1 assert actual_origin0[0]["url"] == data.origin["url"] + actual_origins = swh_storage.origin_get( + [{"url": data.origin["url"]}, {"url": "not://exists"}] + ) + assert actual_origins == [{"url": data.origin["url"]}, None] + def _generate_random_visits(self, nb_visits=100, start=0, end=7): """Generate random visits within the last 2 months (to avoid computations) @@ -1860,15 +1872,13 @@ visit_status.pop("type") expected_visit_statuses.append(OriginVisitStatus.from_dict(visit_status)) - # write twice in the journal - expected_visit_statuses += [visit_status1] * 2 + expected_visit_statuses += [visit_status1] expected_objects = ( [("origin", o) for o in expected_origins] + [("origin_visit", v) for v in expected_visits] + [("origin_visit_status", ovs) for ovs in expected_visit_statuses] ) - assert len(actual_objects) == len(expected_objects) for obj in expected_objects: assert obj in actual_objects @@ -3234,7 +3244,7 @@ origin = data.origin fetcher = data.metadata_fetcher authority = data.metadata_authority - swh_storage.origin_add([origin])[0] + assert swh_storage.origin_add([origin]) == {"origin:add": 1} swh_storage.metadata_fetcher_add(**fetcher) swh_storage.metadata_authority_add(**authority) @@ -3253,7 +3263,7 @@ origin = data.origin fetcher = data.metadata_fetcher authority = data.metadata_authority - swh_storage.origin_add([origin])[0] + assert swh_storage.origin_add([origin]) == {"origin:add": 1} new_origin_metadata2 = { **data.origin_metadata2, @@ -3278,7 +3288,7 @@ origin = data.origin fetcher = data.metadata_fetcher authority = data.metadata_authority - swh_storage.origin_add([origin])[0] + assert swh_storage.origin_add([origin]) 
== {"origin:add": 1} swh_storage.metadata_fetcher_add(**fetcher) swh_storage.metadata_authority_add(**authority) @@ -3296,8 +3306,7 @@ fetcher2 = data.metadata_fetcher2 origin_url1 = data.origin["url"] origin_url2 = data.origin2["url"] - swh_storage.origin_add([data.origin]) - swh_storage.origin_add([data.origin2]) + assert swh_storage.origin_add([data.origin, data.origin2]) == {"origin:add": 2} origin1_metadata1 = data.origin_metadata origin1_metadata2 = data.origin_metadata2 @@ -3334,7 +3343,7 @@ origin = data.origin fetcher = data.metadata_fetcher authority = data.metadata_authority - swh_storage.origin_add([origin])[0] + assert swh_storage.origin_add([origin]) == {"origin:add": 1} swh_storage.metadata_fetcher_add(**fetcher) swh_storage.metadata_authority_add(**authority) @@ -3368,7 +3377,7 @@ origin = data.origin fetcher = data.metadata_fetcher authority = data.metadata_authority - swh_storage.origin_add([origin])[0] + assert swh_storage.origin_add([origin]) == {"origin:add": 1} swh_storage.metadata_fetcher_add(**fetcher) swh_storage.metadata_authority_add(**authority) @@ -3393,7 +3402,7 @@ fetcher1 = data.metadata_fetcher fetcher2 = data.metadata_fetcher2 authority = data.metadata_authority - swh_storage.origin_add([origin])[0] + assert swh_storage.origin_add([origin]) == {"origin:add": 1} swh_storage.metadata_fetcher_add(**fetcher1) swh_storage.metadata_fetcher_add(**fetcher2) diff --git a/swh/storage/validate.py b/swh/storage/validate.py --- a/swh/storage/validate.py +++ b/swh/storage/validate.py @@ -4,7 +4,7 @@ # See top-level LICENSE file for more information import contextlib -from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar, Union +from typing import Dict, Iterable, Iterator, Optional, Tuple, Type, TypeVar, Union from swh.model.model import ( SkippedContent, @@ -138,7 +138,7 @@ def origin_visit_add(self, visits: Iterable[OriginVisit]) -> Iterable[OriginVisit]: return self.storage.origin_visit_add(visits) - def 
origin_add(self, origins: Iterable[Union[Dict, Origin]]) -> List: + def origin_add(self, origins: Iterable[Union[Dict, Origin]]) -> Dict[str, int]: return self.storage.origin_add([dict_converter(Origin, o) for o in origins]) def origin_add_one(self, origin: Union[Dict, Origin]) -> int: