Add `origin_visit_status_add_one` to the storage API (interface, pg, cassandra,
in-memory backends) plus the journal writer method it relies on.

The endpoint validates that both the origin and the origin visit exist before
writing the status to the journal and then to the backend; otherwise it raises
StorageArgumentException. In the postgresql backend the sanity checks reuse
the caller's db/cur so they run inside the same transaction as the insert.

diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -841,6 +841,22 @@
         """Add an origin visit status"""
         self._cql_runner.origin_visit_status_add_one(visit_status)
 
+    def origin_visit_status_add_one(self, visit_status: OriginVisitStatus) -> None:
+        origin_url = visit_status.origin
+        origin = self.origin_get({"url": origin_url})
+        if not origin:
+            raise StorageArgumentException(f"Unknown origin {origin_url}")
+
+        visit_id = visit_status.visit
+        visit = self.origin_visit_get_by(origin_url, visit_id)
+        if not visit:
+            raise StorageArgumentException(
+                f"Unknown origin visit ({origin_url}, {visit_id})"
+            )
+
+        self.journal_writer.origin_visit_status_add([visit_status])
+        self._origin_visit_status_add(visit_status)
+
     def origin_visit_update(
         self,
         origin: str,
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -827,6 +827,24 @@
         # return last visit
         return visit
 
+    def origin_visit_status_add_one(self, visit_status: OriginVisitStatus) -> None:
+        origin_url = visit_status.origin
+        origin = self.origin_get({"url": origin_url})
+        if not origin:
+            raise StorageArgumentException(f"Unknown origin {origin_url}")
+
+        visit_id = visit_status.visit
+        visit = self.origin_visit_get_by(origin_url, visit_id)
+        if not visit:
+            raise StorageArgumentException(
+                f"Unknown origin visit ({origin_url}, {visit_id})"
+            )
+
+        visit_key = (origin_url, visit_id)
+
+        self.journal_writer.origin_visit_status_add([visit_status])
+        self._origin_visit_statuses[visit_key].append(visit_status)
+
     def origin_visit_update(
         self,
         origin: str,
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -13,6 +13,7 @@
     Directory,
     Origin,
     OriginVisit,
+    OriginVisitStatus,
     Revision,
     Release,
     Snapshot,
@@ -801,6 +802,19 @@
         """
         ...
 
+    @remote_api_endpoint("origin/visit_status/add")
+    def origin_visit_status_add_one(self, visit_status: OriginVisitStatus) -> None:
+        """Add origin visit status if sanity checks are ok. Otherwise raise.
+
+        Args:
+            visit_status: origin visit status to add
+
+        Raises:
+            StorageArgumentException if the origin or the origin visit is unknown
+
+        """
+        ...
+
     @remote_api_endpoint("origin/visit/update")
     def origin_visit_update(
         self,
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -885,6 +885,27 @@
             "origin_visit_status:add", count=1, method_name="origin_visit_status"
         )
 
+    @timed
+    @db_transaction()
+    def origin_visit_status_add_one(
+        self, visit_status: OriginVisitStatus, db=None, cur=None,
+    ) -> None:
+        # Reuse db/cur so the sanity checks run in this same transaction
+        origin_url = visit_status.origin
+        origin = self.origin_get({"url": origin_url}, db=db, cur=cur)
+        if not origin:
+            raise StorageArgumentException(f"Unknown origin {origin_url}")
+
+        visit_id = visit_status.visit
+        visit = self.origin_visit_get_by(origin_url, visit_id, db=db, cur=cur)
+        if not visit:
+            raise StorageArgumentException(
+                f"Unknown origin visit ({origin_url}, {visit_id})"
+            )
+
+        self.journal_writer.origin_visit_status_add([visit_status])
+        self._origin_visit_status_add(visit_status, db, cur)
+
     @timed
     @db_transaction()
     def origin_visit_update(
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -31,6 +31,7 @@
     Directory,
     Origin,
     OriginVisit,
+    OriginVisitStatus,
     Release,
     Revision,
     Snapshot,
@@ -1685,6 +1686,55 @@
         if type(cm.value) == psycopg2.ProgrammingError:
             assert cm.value.pgcode == psycopg2.errorcodes.UNDEFINED_FUNCTION
 
+    def test_origin_visit_status_add_one_validation(self, swh_storage):
+        """Wrong origin_visit_status_add input should raise validation error"""
+        date_visit = now()
+        visit_status = OriginVisitStatus(
+            origin="unknown-origin-url",
+            visit=10,
+            date=date_visit,
+            status="full",
+            snapshot=None,
+        )
+        with pytest.raises(StorageArgumentException, match="Unknown origin"):
+            swh_storage.origin_visit_status_add_one(visit_status)
+
+        origin_url = swh_storage.origin_add_one(data.origin2)
+        visit_status = OriginVisitStatus(
+            origin=origin_url, visit=10, date=date_visit, status="full", snapshot=None
+        )
+        with pytest.raises(StorageArgumentException, match="Unknown origin visit"):
+            swh_storage.origin_visit_status_add_one(visit_status)
+
+    def test_origin_visit_status_add_one(self, swh_storage):
+        """Correct origin_visit_status add instruction should add a new visit status
+
+        """
+        origin_url = swh_storage.origin_add_one(data.origin2)
+        date_visit = data.date_visit1
+        origin_visit = swh_storage.origin_visit_add(
+            origin_url, date=date_visit, type=data.type_visit1
+        )
+
+        snapshot_id = data.snapshot["id"]
+        date_visit_now = now()
+
+        visit_status = OriginVisitStatus(
+            origin=origin_visit.origin,
+            visit=origin_visit.visit,
+            date=date_visit_now,
+            status="full",
+            snapshot=snapshot_id,
+        )
+        swh_storage.origin_visit_status_add_one(visit_status)
+
+        origin_visit = swh_storage.origin_visit_get_latest(
+            origin_url, require_snapshot=True
+        )
+        assert origin_visit
+        assert origin_visit["status"] == "full"
+        assert origin_visit["snapshot"] == snapshot_id
+
     def test_origin_visit_update(self, swh_storage):
         # given
         origin_url = swh_storage.origin_add_one(data.origin)
diff --git a/swh/storage/writer.py b/swh/storage/writer.py
--- a/swh/storage/writer.py
+++ b/swh/storage/writer.py
@@ -10,6 +10,7 @@
 from swh.model.model import (
     Origin,
     OriginVisit,
+    OriginVisitStatus,
     Snapshot,
     Directory,
     Revision,
@@ -84,5 +85,10 @@
     def origin_visit_upsert(self, visits: Iterable[OriginVisit]) -> None:
         self.write_additions("origin_visit", visits)
 
+    def origin_visit_status_add(
+        self, visit_statuses: Iterable[OriginVisitStatus]
+    ) -> None:
+        self.write_additions("origin_visit_status", visit_statuses)
+
     def origin_add(self, origins: Iterable[Origin]) -> None:
         self.write_additions("origin", origins)