Patch summary (review notes; text placed before the first "diff --git"
header is ignored by git apply and by patch).

  * swh/storage/cassandra/storage.py, _origin_visit_status_add:
    - The row fetched via origin_visit_get_one(origin, visit) is a visit row.
      The local variable is therefore renamed from origin_row to visit_row.
    - When that lookup returns None, the StorageArgumentException message now
      names both the missing visit and its origin. Previously it said only
      "Unknown origin ...".
  * swh/storage/postgresql/storage.py, origin_visit_status_add:
    - "assert origin_visit is not None" is replaced by an explicit raise of
      StorageArgumentException, with the same message as the cassandra
      backend.
    - The check still runs under python -O, which strips assert statements.
    - Callers get the storage's argument-error type instead of
      AssertionError.
  * swh/storage/tests/storage_tests.py:
    - Adds test_origin_visit_status_add_unknown_type. It builds an
      OriginVisitStatus with no type for visit 42.
    - It checks that origin_visit_status_add raises StorageArgumentException
      in two cases: before the origin is added, and after the origin is added
      while the visit is still missing.
    - It then checks that the call goes through once the visit has been added
      with origin_visit_add.

NOTE(review): the newlines and leading indentation of this patch appear to
have been collapsed (every hunk sits on one or two physical lines), so it
will not apply as-is. TODO: regenerate it from the original commit rather
than hand-restoring whitespace; the hunk line counts depend on exact
context lines.

diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -938,12 +938,15 @@ def _origin_visit_status_add(self, visit_status: OriginVisitStatus) -> None: """Add an origin visit status""" if visit_status.type is None: - origin_row = self._cql_runner.origin_visit_get_one( + visit_row = self._cql_runner.origin_visit_get_one( visit_status.origin, visit_status.visit ) - if origin_row is None: - raise StorageArgumentException(f"Unknown origin {visit_status.origin}") - visit_status = attr.evolve(visit_status, type=origin_row.type) + if visit_row is None: + raise StorageArgumentException( + f"Unknown origin visit {visit_status.visit} " + f"of origin {visit_status.origin}" + ) + visit_status = attr.evolve(visit_status, type=visit_row.type) self.journal_writer.origin_visit_status_add([visit_status]) self._cql_runner.origin_visit_status_add_one( diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py --- a/swh/storage/postgresql/storage.py +++ b/swh/storage/postgresql/storage.py @@ -954,7 +954,11 @@ origin_visit = self.origin_visit_get_by( visit_status.origin, visit_status.visit, db=db, cur=cur ) - assert origin_visit is not None + if origin_visit is None: + raise StorageArgumentException( + f"Unknown origin visit {visit_status.visit} " + f"of origin {visit_status.origin}" + ) origin_visit_status = attr.evolve(visit_status, type=origin_visit.type) else: diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py --- a/swh/storage/tests/storage_tests.py +++ b/swh/storage/tests/storage_tests.py @@ -1606,6 +1606,33 @@ assert actual_page.next_page_token is None assert actual_page.results == [] + def test_origin_visit_status_add_unknown_type(self, swh_storage, sample_data): + ov = OriginVisit( + origin=sample_data.origin.url, + date=now(), + type=sample_data.type_visit1, + visit=42, + ) + ovs = OriginVisitStatus( + 
origin=ov.origin, + visit=ov.visit, + date=now(), + status="created", + snapshot=None, + ) + + with pytest.raises(StorageArgumentException): + swh_storage.origin_visit_status_add([ovs]) + + swh_storage.origin_add([sample_data.origin]) + + with pytest.raises(StorageArgumentException): + swh_storage.origin_visit_status_add([ovs]) + + swh_storage.origin_visit_add([ov]) + + swh_storage.origin_visit_status_add([ovs]) + def test_origin_visit_status_get_all(self, swh_storage, sample_data): origin = sample_data.origin swh_storage.origin_add([origin])