diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -33,7 +33,7 @@
     MetadataTargetType,
     RawExtrinsicMetadata,
 )
-from swh.storage.interface import ListOrder, PagedResult
+from swh.storage.interface import ListOrder, PagedResult, VISIT_STATUSES
 from swh.storage.objstorage import ObjStorage
 from swh.storage.writer import JournalWriter
 from swh.storage.utils import map_optional, now
@@ -941,6 +941,11 @@
         allowed_statuses: Optional[List[str]] = None,
         require_snapshot: bool = False,
     ) -> Optional[OriginVisit]:
+        if allowed_statuses and not set(allowed_statuses).issubset(VISIT_STATUSES):
+            raise StorageArgumentException(
+                f"Unknown allowed statuses {','.join(allowed_statuses)}, only "
+                f"{','.join(VISIT_STATUSES)} authorized"
+            )
         # TODO: Do not fetch all visits
         rows = self._cql_runner.origin_visit_get_all(origin)
         latest_visit = None
@@ -979,6 +984,11 @@
         allowed_statuses: Optional[List[str]] = None,
         require_snapshot: bool = False,
     ) -> Optional[OriginVisitStatus]:
+        if allowed_statuses and not set(allowed_statuses).issubset(VISIT_STATUSES):
+            raise StorageArgumentException(
+                f"Unknown allowed statuses {','.join(allowed_statuses)}, only "
+                f"{','.join(VISIT_STATUSES)} authorized"
+            )
         rows = self._cql_runner.origin_visit_status_get(
             origin_url, visit, allowed_statuses, require_snapshot
         )
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -52,7 +52,7 @@
     RawExtrinsicMetadata,
 )
 from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex
-from swh.storage.interface import ListOrder, PagedResult
+from swh.storage.interface import ListOrder, PagedResult, VISIT_STATUSES
 from swh.storage.objstorage import ObjStorage
 from swh.storage.utils import now
 
@@ -939,6 +939,12 @@
         allowed_statuses: Optional[List[str]] = None,
         require_snapshot: bool = False,
     ) -> Optional[OriginVisit]:
+        if allowed_statuses and not set(allowed_statuses).issubset(VISIT_STATUSES):
+            raise StorageArgumentException(
+                f"Unknown allowed statuses {','.join(allowed_statuses)}, only "
+                f"{','.join(VISIT_STATUSES)} authorized"
+            )
+
         ori = self._origins.get(origin)
         if not ori:
             return None
@@ -1009,6 +1015,12 @@
         allowed_statuses: Optional[List[str]] = None,
         require_snapshot: bool = False,
     ) -> Optional[OriginVisitStatus]:
+        if allowed_statuses and not set(allowed_statuses).issubset(VISIT_STATUSES):
+            raise StorageArgumentException(
+                f"Unknown allowed statuses {','.join(allowed_statuses)}, only "
+                f"{','.join(VISIT_STATUSES)} authorized"
+            )
+
         ori = self._origins.get(origin_url)
         if not ori:
             return None
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -41,6 +41,10 @@
 PagedResult = CorePagedResult[TResult, str]
 
 
+# TODO: Make it an enum (too much impact)
+VISIT_STATUSES = ["created", "ongoing", "full", "partial"]
+
+
 def deprecated(f):
     f.deprecated_endpoint = True
     return f
@@ -883,6 +887,10 @@
             require_snapshot: If True, only a visit with a snapshot
                 will be returned.
 
+        Raises:
+            StorageArgumentException if values for the allowed_statuses parameters
+            are unknown
+
         Returns:
             OriginVisit matching the criteria if found, None otherwise. Note that as
             OriginVisit no longer held reference on the visit status or snapshot, you
@@ -937,6 +945,10 @@
             require_snapshot: If True, only a visit with a snapshot
                 will be returned.
 
+        Raises:
+            StorageArgumentException if values for the allowed_statuses parameters
+            are unknown
+
         Returns:
             The OriginVisitStatus matching the criteria
 
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -45,7 +45,7 @@
     RawExtrinsicMetadata,
 )
 from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex
-from swh.storage.interface import ListOrder, PagedResult
+from swh.storage.interface import ListOrder, PagedResult, VISIT_STATUSES
 from swh.storage.objstorage import ObjStorage
 from swh.storage.utils import now
 
@@ -862,6 +862,12 @@
         db=None,
         cur=None,
     ) -> Optional[OriginVisitStatus]:
+        if allowed_statuses and not set(allowed_statuses).issubset(VISIT_STATUSES):
+            raise StorageArgumentException(
+                f"Unknown allowed statuses {','.join(allowed_statuses)}, only "
+                f"{','.join(VISIT_STATUSES)} authorized"
+            )
+
         row = db.origin_visit_status_get_latest(
             origin_url, visit, allowed_statuses, require_snapshot, cur=cur
         )
@@ -953,6 +959,12 @@
         db=None,
         cur=None,
     ) -> Optional[OriginVisit]:
+        if allowed_statuses and not set(allowed_statuses).issubset(VISIT_STATUSES):
+            raise StorageArgumentException(
+                f"Unknown allowed statuses {','.join(allowed_statuses)}, only "
+                f"{','.join(VISIT_STATUSES)} authorized"
+            )
+
         row = db.origin_visit_get_latest(
             origin,
             type=type,
diff --git a/swh/storage/tests/algos/test_origin.py b/swh/storage/tests/algos/test_origin.py
--- a/swh/storage/tests/algos/test_origin.py
+++ b/swh/storage/tests/algos/test_origin.py
@@ -4,9 +4,6 @@
 # See top-level LICENSE file for more information
 
 import datetime
-import pytest
-
-from unittest.mock import patch
 
 from swh.model.model import Origin, OriginVisit, OriginVisitStatus
 
@@ -26,13 +23,6 @@
     assert list(left) == list(right), msg
 
 
-@pytest.fixture
-def swh_storage_backend_config():
-    yield {
-        "cls": "memory",
-    }
-
-
 def test_iter_origins(swh_storage):
     origins = [
         Origin(url="bar"),
@@ -79,17 +69,6 @@
     )
 
 
-@patch("swh.storage.in_memory.InMemoryStorage.origin_get_range")
-def test_iter_origins_batch_size(mock_origin_get_range, swh_storage):
-    mock_origin_get_range.return_value = []
-
-    list(iter_origins(swh_storage))
-    mock_origin_get_range.assert_called_with(origin_from=1, origin_count=10000)
-
-    list(iter_origins(swh_storage, batch_size=42))
-    mock_origin_get_range.assert_called_with(origin_from=1, origin_count=42)
-
-
 def test_origin_get_latest_visit_status_none(swh_storage, sample_data):
     """Looking up unknown objects should return nothing
 
@@ -115,11 +94,6 @@
     )
     assert actual_origin_visit is None
 
-    actual_origin_visit = origin_get_latest_visit_status(
-        swh_storage, origin.url, allowed_statuses=["unknown"]
-    )
-    assert actual_origin_visit is None
-
 
 def init_storage_with_origin_visits(swh_storage, sample_data):
     """Initialize storage with origin/origin-visit/origin-visit-status
@@ -155,7 +129,7 @@
     ovs11 = OriginVisitStatus(
         origin=origin1.url,
         visit=ov1.visit,
-        date=sample_data.date_visit1,
+        date=ov1.date + datetime.timedelta(seconds=10),  # so it's not ignored
         status="partial",
         snapshot=None,
     )
@@ -171,7 +145,7 @@
     ovs21 = OriginVisitStatus(
         origin=origin2.url,
         visit=ov2.visit,
-        date=sample_data.date_visit2,
+        date=ov2.date + datetime.timedelta(seconds=10),  # so it's not ignored
         status="ongoing",
         snapshot=None,
     )
@@ -247,10 +221,10 @@
     ov1, ov2 = objects["origin_visit"]
     ovs11, ovs12, _, ovs22 = objects["origin_visit_status"]
 
-    # no failed status for that visit
+    # no partial status for that origin visit
     assert (
         origin_get_latest_visit_status(
-            swh_storage, origin2.url, allowed_statuses=["failed"]
+            swh_storage, origin2.url, allowed_statuses=["partial"]
         )
         is None
     )
@@ -393,14 +367,14 @@
             )
         )
 
-    visit_statuses = swh_storage.origin_visit_add(new_visit_statuses)
-    reversed_visit_statuses = list(reversed(visit_statuses))
+    swh_storage.origin_visit_status_add(new_visit_statuses)
+    reversed_visit_statuses = list(reversed(new_visit_statuses))
 
     # order asc
     actual_visit_statuses = list(
         iter_origin_visit_statuses(swh_storage, ov1.origin, ov1.visit)
     )
-    assert actual_visit_statuses == visit_statuses
+    assert actual_visit_statuses == new_visit_statuses
 
     # order desc
     actual_visit_statuses = list(
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -2028,18 +2028,21 @@
         actual_visit = swh_storage.origin_visit_get_by(origin.url, 999)  # unknown visit
         assert actual_visit is None
 
-    def test_origin_visit_get_latest_none(self, swh_storage, sample_data):
-        """Origin visit get latest on unknown objects should return nothing
-
-        """
+    def test_origin_visit_get_latest_edge_cases(self, swh_storage, sample_data):
         # unknown origin so no result
         assert swh_storage.origin_visit_get_latest("unknown-origin") is None
 
-        # unknown type
+        # unknown type so no result
         origin = sample_data.origin
         swh_storage.origin_add([origin])
         assert swh_storage.origin_visit_get_latest(origin.url, type="unknown") is None
 
+        # unknown allowed statuses should raise
+        with pytest.raises(StorageArgumentException, match="Unknown allowed statuses"):
+            swh_storage.origin_visit_get_latest(
+                origin.url, allowed_statuses=["unknown"]
+            )
+
     def test_origin_visit_get_latest_filter_type(self, swh_storage, sample_data):
         """Filtering origin visit get latest with filter type should be ok
 
@@ -2265,6 +2268,19 @@
         actual_visit = swh_storage.origin_visit_get_latest(origin.url)
         assert actual_visit == ov2
 
+    def test_origin_visit_status_get_latest__validation(self, swh_storage, sample_data):
+        origin = sample_data.origin
+        swh_storage.origin_add([origin])
+        visit1 = OriginVisit(
+            origin=origin.url, date=sample_data.date_visit1, type="git",
+        )
+
+        # unknown allowed statuses should raise
+        with pytest.raises(StorageArgumentException, match="Unknown allowed statuses"):
+            swh_storage.origin_visit_status_get_latest(
+                origin.url, visit1.visit, allowed_statuses=["full", "unknown"]
+            )
+
     def test_origin_visit_status_get_latest(self, swh_storage, sample_data):
         snapshot = sample_data.snapshots[2]
         origin1 = sample_data.origin