diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -841,6 +841,34 @@
         """Add an origin visit status"""
         self._cql_runner.origin_visit_status_add_one(visit_status)
 
+    def origin_visit_status_add(
+        self,
+        visit_statuses: Iterable[OriginVisitStatus],
+        with_sanity_checks: bool = True,
+    ) -> None:
+        # Materialize the input: it is traversed more than once below, which
+        # would silently drop data if a one-shot iterator were passed.
+        visit_statuses = list(visit_statuses)
+
+        # First round to check existence (fail early if any is ko)
+        if with_sanity_checks:
+            for visit_status in visit_statuses:
+                origin_url = visit_status.origin
+                origin = self.origin_get({"url": origin_url})
+                if not origin:
+                    raise StorageArgumentException(f"Unknown origin {origin_url}")
+
+                visit_id = visit_status.visit
+                visit = self.origin_visit_get_by(origin_url, visit_id)
+                if not visit:
+                    raise StorageArgumentException(
+                        f"Unknown origin visit ({origin_url}, {visit_id})"
+                    )
+
+        self.journal_writer.origin_visit_status_add(visit_statuses)
+        for visit_status in visit_statuses:
+            self._origin_visit_status_add(visit_status)
+
     def origin_visit_update(
         self,
         origin: str,
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -827,6 +827,36 @@
         # return last visit
         return visit
 
+    def origin_visit_status_add(
+        self,
+        visit_statuses: Iterable[OriginVisitStatus],
+        with_sanity_checks: bool = True,
+    ) -> None:
+        # Materialize the input: it is traversed more than once below, which
+        # would silently drop data if a one-shot iterator were passed.
+        visit_statuses = list(visit_statuses)
+
+        # First round to check existence (fail early if any is ko)
+        if with_sanity_checks:
+            for visit_status in visit_statuses:
+                origin_url = visit_status.origin
+                origin = self.origin_get({"url": origin_url})
+                if not origin:
+                    raise StorageArgumentException(f"Unknown origin {origin_url}")
+
+                visit_id = visit_status.visit
+                visit = self.origin_visit_get_by(origin_url, visit_id)
+                if not visit:
+                    raise StorageArgumentException(
+                        f"Unknown origin visit ({origin_url}, {visit_id})"
+                    )
+
+        # Journal the whole batch at once, as the other backends do
+        self.journal_writer.origin_visit_status_add(visit_statuses)
+        for visit_status in visit_statuses:
+            visit_key = (visit_status.origin, visit_status.visit)
+            self._origin_visit_statuses[visit_key].append(visit_status)
+
     def origin_visit_update(
         self,
         origin: str,
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -13,6 +13,7 @@
     Directory,
     Origin,
     OriginVisit,
+    OriginVisitStatus,
     Revision,
     Release,
     Snapshot,
@@ -801,6 +802,30 @@
         """
         ...
 
+    @remote_api_endpoint("origin/visit_status/add")
+    def origin_visit_status_add(
+        self,
+        visit_statuses: Iterable[OriginVisitStatus],
+        with_sanity_checks: bool = True,
+    ) -> None:
+        """Add origin visit statuses if sanity checks are ok. Otherwise raise.
+
+        Args:
+            visit_statuses: origin visit statuses to add
+            with_sanity_checks: Defaults to True. When True, check that the origin
+                and the origin visit referenced by each visit status exist, and
+                raise otherwise; nothing is written if any check fails. Turn this
+                off to skip the checks (e.g. for mass replay). It is then up to
+                the caller to ensure the data sent is consistent (failures can
+                still happen, e.g. integrity errors).
+
+        Raises:
+            StorageArgumentException: if with_sanity_checks is True and either
+                the origin or the origin visit is unknown
+
+        """
+        ...
+
     @remote_api_endpoint("origin/visit/update")
     def origin_visit_update(
         self,
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -885,6 +885,39 @@
             "origin_visit_status:add", count=1, method_name="origin_visit_status"
         )
 
+    @timed
+    @db_transaction()
+    def origin_visit_status_add(
+        self,
+        visit_statuses: Iterable[OriginVisitStatus],
+        with_sanity_checks: bool = True,
+        db=None,
+        cur=None,
+    ) -> None:
+        # Materialize the input: it is traversed more than once below, which
+        # would silently drop data if a one-shot iterator were passed.
+        visit_statuses = list(visit_statuses)
+
+        # First round to check existence (fail early if any is ko)
+        if with_sanity_checks:
+            for visit_status in visit_statuses:
+                origin_url = visit_status.origin
+                # pass db/cur so lookups reuse this transaction's connection
+                origin = self.origin_get({"url": origin_url}, db=db, cur=cur)
+                if not origin:
+                    raise StorageArgumentException(f"Unknown origin {origin_url}")
+
+                visit_id = visit_status.visit
+                visit = self.origin_visit_get_by(origin_url, visit_id, db=db, cur=cur)
+                if not visit:
+                    raise StorageArgumentException(
+                        f"Unknown origin visit ({origin_url}, {visit_id})"
+                    )
+
+        self.journal_writer.origin_visit_status_add(visit_statuses)
+        for visit_status in visit_statuses:
+            self._origin_visit_status_add(visit_status, db, cur)
+
     @timed
     @db_transaction()
     def origin_visit_update(
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -31,6 +31,7 @@
     Directory,
     Origin,
     OriginVisit,
+    OriginVisitStatus,
     Release,
     Revision,
     Snapshot,
@@ -1685,6 +1686,75 @@
         if type(cm.value) == psycopg2.ProgrammingError:
             assert cm.value.pgcode == psycopg2.errorcodes.UNDEFINED_FUNCTION
 
+    def test_origin_visit_status_add_validation(self, swh_storage):
+        """Wrong origin_visit_status input should raise storage argument error"""
+        date_visit = now()
+        visit_status1 = OriginVisitStatus(
+            origin="unknown-origin-url",
+            visit=10,
+            date=date_visit,
+            status="full",
+            snapshot=None,
+        )
+        with pytest.raises(StorageArgumentException, match="Unknown origin"):
+            swh_storage.origin_visit_status_add([visit_status1])
+
+        origin_url = swh_storage.origin_add_one(data.origin2)
+        visit_status2 = OriginVisitStatus(
+            origin=origin_url, visit=10, date=date_visit, status="full", snapshot=None
+        )
+        with pytest.raises(StorageArgumentException, match="Unknown origin visit"):
+            swh_storage.origin_visit_status_add([visit_status2])
+
+    def test_origin_visit_status_add_one(self, swh_storage):
+        """Correct origin visit statuses should add a new visit status
+
+        """
+        origin_url = swh_storage.origin_add_one(data.origin2)
+        origin_visit1 = swh_storage.origin_visit_add(
+            origin_url, date=data.date_visit1, type=data.type_visit1
+        )
+        snapshot_id = data.snapshot["id"]
+        date_visit_now = now()
+        visit_status1 = OriginVisitStatus(
+            origin=origin_visit1.origin,
+            visit=origin_visit1.visit,
+            date=date_visit_now,
+            status="full",
+            snapshot=snapshot_id,
+        )
+
+        origin_url2 = swh_storage.origin_add_one({"url": "new-origin"})
+        origin_visit2 = swh_storage.origin_visit_add(
+            origin_url2, date=data.date_visit2, type=data.type_visit2
+        )
+        date_visit_now = now()
+        visit_status2 = OriginVisitStatus(
+            origin=origin_visit2.origin,
+            visit=origin_visit2.visit,
+            date=date_visit_now,
+            status="ongoing",
+            snapshot=None,
+            metadata={"intrinsic": "something"},
+        )
+        swh_storage.origin_visit_status_add([visit_status1, visit_status2])
+
+        origin_visit1 = swh_storage.origin_visit_get_latest(
+            origin_url, require_snapshot=True
+        )
+        assert origin_visit1
+        assert origin_visit1["status"] == "full"
+        assert origin_visit1["snapshot"] == snapshot_id
+
+        origin_visit2 = swh_storage.origin_visit_get_latest(
+            origin_url2, require_snapshot=False
+        )
+        assert origin_url2 != origin_url
+        assert origin_visit2
+        assert origin_visit2["status"] == "ongoing"
+        assert origin_visit2["snapshot"] is None
+        assert origin_visit2["metadata"] == {"intrinsic": "something"}
+
     def test_origin_visit_update(self, swh_storage):
         # given
         origin_url = swh_storage.origin_add_one(data.origin)
diff --git a/swh/storage/writer.py b/swh/storage/writer.py
--- a/swh/storage/writer.py
+++ b/swh/storage/writer.py
@@ -10,6 +10,7 @@
 from swh.model.model import (
     Origin,
     OriginVisit,
+    OriginVisitStatus,
     Snapshot,
     Directory,
     Revision,
@@ -84,5 +85,10 @@
     def origin_visit_upsert(self, visits: Iterable[OriginVisit]) -> None:
         self.write_additions("origin_visit", visits)
 
+    def origin_visit_status_add(
+        self, visit_statuses: Iterable[OriginVisitStatus]
+    ) -> None:
+        self.write_additions("origin_visit_status", visit_statuses)
+
     def origin_add(self, origins: Iterable[Origin]) -> None:
         self.write_additions("origin", origins)