diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -839,6 +839,7 @@ def _origin_visit_status_add(self, visit_status: OriginVisitStatus) -> None: """Add an origin visit status""" + self.journal_writer.origin_visit_status_add([visit_status]) self._cql_runner.origin_visit_status_add_one(visit_status) def origin_visit_status_add( @@ -956,7 +957,7 @@ snapshot=visit.snapshot, metadata=visit.metadata, ) - self._origin_visit_status_add(visit_status) + self._cql_runner.origin_visit_status_add_one(visit_status) @staticmethod def _format_origin_visit_row(visit): diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -804,6 +804,7 @@ metadata=None, visit=visit_id, ) + self.journal_writer.origin_visit_add([visit]) self._origin_visits[origin_url].append(visit) assert visit.visit is not None visit_key = (origin_url, visit.visit) @@ -817,15 +818,21 @@ snapshot=None, metadata=None, ) - self._origin_visit_statuses[visit_key] = [visit_update] - + self._origin_visit_status_add_one(visit_update) self._objects[visit_key].append(("origin_visit", None)) - self.journal_writer.origin_visit_add([visit]) - # return last visit return visit + def _origin_visit_status_add_one(self, visit_status: OriginVisitStatus) -> None: + """Add an origin visit status without checks. + + """ + self.journal_writer.origin_visit_status_add([visit_status]) + visit_key = (visit_status.origin, visit_status.visit) + self._origin_visit_statuses.setdefault(visit_key, []) + self._origin_visit_statuses[visit_key].append(visit_status) + def origin_visit_status_add( self, visit_statuses: Iterable[OriginVisitStatus], ) -> None: @@ -835,11 +842,8 @@ if not origin_url: raise StorageArgumentException(f"Unknown origin {visit_status.origin}") - # Insert for visit_status in visit_statuses: - visit_key = (visit_status.origin, visit_status.visit) - self.journal_writer.origin_visit_status_add([visit_status]) - self._origin_visit_statuses[visit_key].append(visit_status) + self._origin_visit_status_add_one(visit_status) def origin_visit_update( self, @@ -876,7 +880,7 @@ snapshot=snapshot or last_visit_update.snapshot, metadata=metadata or last_visit_update.metadata, ) - self._origin_visit_statuses[visit_key].append(visit_update) + self._origin_visit_status_add_one(visit_update) self.journal_writer.origin_visit_update( [self._origin_visit_get_updated(origin_url, visit_id)] diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -858,6 +858,7 @@ "snapshot": None, } ) + self.journal_writer.origin_visit_add([visit]) with convert_validation_exceptions(): visit_status = OriginVisitStatus( @@ -869,18 +870,15 @@ metadata=None, ) self._origin_visit_status_add(visit_status, db=db, cur=cur) - - self.journal_writer.origin_visit_add([visit]) - send_metric("origin_visit:add", count=1, method_name="origin_visit") return visit def _origin_visit_status_add( - self, origin_visit_status: OriginVisitStatus, db, cur + self, visit_status: OriginVisitStatus, db, cur ) -> None: """Add an origin visit status""" - db.origin_visit_status_add(origin_visit_status, cur=cur) - # TODO: write to the journal the origin visit status + self.journal_writer.origin_visit_status_add([visit_status]) + db.origin_visit_status_add(visit_status, cur=cur) send_metric( "origin_visit_status:add", count=1, method_name="origin_visit_status" ) @@ -896,7 +894,6 @@ if not origin_url: raise StorageArgumentException(f"Unknown origin {visit_status.origin}") - self.journal_writer.origin_visit_status_add(visit_statuses) for visit_status in visit_statuses: self._origin_visit_status_add(visit_status, db, cur) diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -2152,7 +2152,6 @@ "origin": origin_url, "date": data.date_visit2, "visit": origin_visit1.visit, - "type": data.type_visit1, "status": "ongoing", "metadata": None, "snapshot": None, @@ -2166,9 +2165,14 @@ "metadata": None, "snapshot": None, } - assert list(swh_storage.journal_writer.journal.objects) == [ + actual_written_objects = list(swh_storage.journal_writer.journal.objects) + assert actual_written_objects == [ ("origin", Origin.from_dict(data.origin2)), - ("origin_visit", OriginVisit.from_dict(data1)), + ( + "origin_visit", + OriginVisit.from_dict({**data1, "type": data.type_visit1,}), + ), + ("origin_visit_status", OriginVisitStatus.from_dict(data1)), ("origin_visit", OriginVisit.from_dict(data2)), ]