diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -958,7 +958,7 @@ snapshot=visit.snapshot, metadata=visit.metadata, ) - self._cql_runner.origin_visit_status_add_one(visit_status) + self._origin_visit_status_add(visit_status) @staticmethod def _format_origin_visit_row(visit): diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -939,7 +939,7 @@ visit_key = (origin_url, visit.visit) with convert_validation_exceptions(): - visit_update = OriginVisitStatus( + visit_status = OriginVisitStatus( origin=origin_url, visit=visit.visit, date=date, @@ -948,12 +948,11 @@ metadata=visit.metadata, ) - self._origin_visit_statuses.setdefault(visit_key, []) while len(self._origin_visits[origin_url]) < visit.visit: self._origin_visits[origin_url].append(None) self._origin_visits[origin_url][visit.visit - 1] = visit - self._origin_visit_statuses[visit_key].append(visit_update) + self._origin_visit_status_add_one(visit_status) self._objects[visit_key].append(("origin_visit", None)) diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -1006,26 +1006,35 @@ def origin_visit_upsert( self, visits: Iterable[OriginVisit], db=None, cur=None ) -> None: + visit_statuses = [] + nb_visits = 0 for visit in visits: + nb_visits += 1 if visit.visit is None: raise StorageArgumentException(f"Missing visit id for visit {visit}") + with convert_validation_exceptions(): + visit_statuses.append( + OriginVisitStatus( + origin=visit.origin, + visit=visit.visit, + date=now(), + status=visit.status, + snapshot=visit.snapshot, + metadata=visit.metadata, + ) + ) + + assert len(visit_statuses) == nb_visits + # write in journal first self.journal_writer.origin_visit_upsert(visits) + self.journal_writer.origin_visit_status_add(visit_statuses) - for visit in 
visits: - # TODO: upsert them all in a single query + # then sync to db + for i, visit in enumerate(visits): assert visit.visit is not None db.origin_visit_upsert(visit, cur=cur) - with convert_validation_exceptions(): - visit_status = OriginVisitStatus( - origin=visit.origin, - visit=visit.visit, - date=now(), - status=visit.status, - snapshot=visit.snapshot, - metadata=visit.metadata, - ) - db.origin_visit_status_add(visit_status, cur=cur) + db.origin_visit_status_add(visit_statuses[i], cur=cur) @timed @db_transaction_generator(statement_timeout=500) diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -22,11 +22,12 @@ from hypothesis import given, strategies, settings, HealthCheck -from typing import ClassVar, Optional +from typing import ClassVar, Iterable, Optional, Tuple from swh.model import from_disk, identifiers from swh.model.hashutil import hash_to_bytes from swh.model.model import ( + BaseModel, Content, Directory, Origin, @@ -2023,7 +2024,54 @@ def test_origin_visit_get_by__unknown_origin(self, swh_storage): assert swh_storage.origin_visit_get_by("foo", 10) is None - def test_origin_visit_upsert_new(self, swh_storage): + def assert_upsert_written_objects( + self, + actual_written_objects: Iterable[Tuple[str, BaseModel]], + expected_written_objects: Iterable[Tuple[str, BaseModel]], + ): + """Helper utility to ensure written upsert objects are as expected. + + OriginVisitStatus from the origin_visit_upsert call point of view need special + treatment so we can compare actual and expected values. 
+ + """ + written_objects_by = defaultdict(list) + for obj_type, obj in actual_written_objects: + written_objects_by[obj_type].append(obj) + + expected_objects_by = defaultdict(list) + for obj_type, obj in expected_written_objects: + expected_objects_by[obj_type].append(obj) + + # straightforward comparison for those (order does not matter) + for obj_type in ["origin", "origin_visit"]: + assert set(written_objects_by[obj_type]) == set( + expected_objects_by[obj_type] + ) + + # origin-visit-status is specific though, origin_visit_upsert writes now() date. + # We cannot mock as we use multiple implementations (fully qualified name is + # thus different), we cannot open a parameter field date (not part of the + # signature), so we overwrite both actual and expected visit_status to use the + # same date so the comparison works... + + obj_type = "origin_visit_status" + expected_visit_statuses = expected_objects_by[obj_type] + for i, actual_visit_status in enumerate(written_objects_by[obj_type]): + expected_visit_status = expected_visit_statuses[i] + + test_date = now() + actual_new = OriginVisitStatus.from_dict( + {**actual_visit_status.to_dict(), "date": test_date} + ) + + expected_new = OriginVisitStatus.from_dict( + {**expected_visit_status.to_dict(), "date": test_date} + ) + + assert actual_new == expected_new + + def test_origin_visit_upsert_new(self, swh_storage, mocker): # given origin_url = swh_storage.origin_add_one(data.origin2) @@ -2082,7 +2130,6 @@ "origin": origin_url, "date": data.date_visit2, "visit": 123, - "type": data.type_visit2, "status": "full", "metadata": None, "snapshot": None, @@ -2091,16 +2138,29 @@ "origin": origin_url, "date": data.date_visit3, "visit": 1234, - "type": data.type_visit2, "status": "full", "metadata": None, "snapshot": None, } - assert list(swh_storage.journal_writer.journal.objects) == [ - ("origin", Origin.from_dict(data.origin2)), - ("origin_visit", OriginVisit.from_dict(data1)), - ("origin_visit", 
OriginVisit.from_dict(data2)), - ] + actual_written_objects = list(swh_storage.journal_writer.journal.objects) + + # Ensure we have those written to journal + self.assert_upsert_written_objects( + actual_written_objects, + [ + ("origin", Origin.from_dict(data.origin2)), + ( + "origin_visit", + OriginVisit.from_dict({**data1, "type": data.type_visit2}), + ), + ( + "origin_visit", + OriginVisit.from_dict({**data2, "type": data.type_visit2,}), + ), + ("origin_visit_status", OriginVisitStatus.from_dict(data1)), + ("origin_visit_status", OriginVisitStatus.from_dict(data2)), + ], + ) def test_origin_visit_upsert_existing(self, swh_storage): # given @@ -2156,21 +2216,27 @@ "origin": origin_url, "date": data.date_visit2, "visit": origin_visit1.visit, - "type": data.type_visit1, "status": "full", "metadata": None, "snapshot": None, } actual_written_objects = list(swh_storage.journal_writer.journal.objects) - assert actual_written_objects == [ - ("origin", Origin.from_dict(data.origin2)), - ( - "origin_visit", - OriginVisit.from_dict({**data1, "type": data.type_visit1,}), - ), - ("origin_visit_status", OriginVisitStatus.from_dict(data1)), - ("origin_visit", OriginVisit.from_dict(data2)), - ] + self.assert_upsert_written_objects( + actual_written_objects, + [ + ("origin", Origin.from_dict(data.origin2)), + ( + "origin_visit", + OriginVisit.from_dict({**data1, "type": data.type_visit1,}), + ), + ("origin_visit_status", OriginVisitStatus.from_dict(data1)), + ( + "origin_visit", + OriginVisit.from_dict({**data2, "type": data.type_visit1}), + ), + ("origin_visit_status", OriginVisitStatus.from_dict(data2)), + ], + ) def test_origin_visit_upsert_missing_visit_id(self, swh_storage): # given