diff --git a/swh/storage/api/serializers.py b/swh/storage/api/serializers.py
--- a/swh/storage/api/serializers.py
+++ b/swh/storage/api/serializers.py
@@ -10,7 +10,7 @@
 
 from swh.model.identifiers import SWHID, parse_swhid
 import swh.model.model as model
-from swh.storage.interface import PagedResult
+from swh.storage import interface
 
 
 def _encode_model_object(obj: model.BaseModel) -> Dict[str, Any]:
@@ -23,7 +23,7 @@
     return getattr(model, d.pop("__type__")).from_dict(d)
 
 
-def _encode_paged_result(obj: PagedResult) -> Dict[str, Any]:
+def _encode_paged_result(obj: interface.PagedResult) -> Dict[str, Any]:
     return {
         "__type__": type(obj).__name__,
         "results": [_encode_model_object(o) for o in obj.results],
@@ -31,26 +31,38 @@
     }
 
 
-def _decode_paged_result(obj: Dict[str, Any]) -> PagedResult:
-    return PagedResult(
+def _decode_paged_result(obj: Dict[str, Any]) -> interface.PagedResult:
+    return interface.PagedResult(
         results=[_decode_model_object(d) for d in obj["results"]],
         next_page_token=obj["next_page_token"],
     )
 
 
-def _encode_model_enum(obj):
+def _encode_enum(obj):
+    """Serialize an Enum member as its value and its class name."""
     return {
         "value": obj.value,
         "__type__": type(obj).__name__,
     }
 
 
+def _decode_model_enum(d):
+    """Rebuild a swh.model.model enum member serialized by _encode_enum."""
+    return getattr(model, d["__type__"])(d["value"])
+
+
+def _decode_storage_enum(d):
+    """Rebuild a swh.storage.interface enum member serialized by _encode_enum."""
+    return getattr(interface, d["__type__"])(d["value"])
+
+
 ENCODERS: List[Tuple[type, str, Callable]] = [
     (model.BaseModel, "model", _encode_model_object),
-    (PagedResult, "model_paged_result", _encode_paged_result),
+    (interface.PagedResult, "model_paged_result", _encode_paged_result),
     (SWHID, "swhid", str),
-    (model.MetadataTargetType, "model_enum", _encode_model_enum),
-    (model.MetadataAuthorityType, "model_enum", _encode_model_enum),
+    (model.MetadataTargetType, "model_enum", _encode_enum),
+    (model.MetadataAuthorityType, "model_enum", _encode_enum),
+    (interface.ListOrder, "storage_enum", _encode_enum),
 ]
 
 
@@ -58,5 +70,6 @@
     "swhid": parse_swhid,
     "model": _decode_model_object,
     "model_paged_result": _decode_paged_result,
-    "model_enum": lambda d: getattr(model, d.pop("__type__"))(d["value"]),
+    "model_enum": _decode_model_enum,
+    "storage_enum": _decode_storage_enum,
 }
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -32,7 +32,7 @@
     MetadataTargetType,
     RawExtrinsicMetadata,
 )
-from swh.storage.interface import PagedResult
+from swh.storage.interface import ListOrder, PagedResult
 from swh.storage.objstorage import ObjStorage
 from swh.storage.writer import JournalWriter
 from swh.storage.utils import map_optional, now
@@ -846,15 +846,11 @@
         self,
         origin: str,
         page_token: Optional[str] = None,
-        order: str = "asc",
+        order: ListOrder = ListOrder.ASC,
         limit: int = 10,
     ) -> PagedResult[OriginVisit]:
-        order = order.lower()
-        allowed_orders = ["asc", "desc"]
-        if order not in allowed_orders:
-            raise StorageArgumentException(
-                f"order must be one of {', '.join(allowed_orders)}."
-            )
+        if not isinstance(order, ListOrder):
+            raise StorageArgumentException("order must be a ListOrder value")
 
         if page_token and not isinstance(page_token, str):
             raise StorageArgumentException("page_token must be a string.")
@@ -863,7 +859,9 @@
         visits: List[OriginVisit] = []
         extra_limit = limit + 1
 
-        rows = self._cql_runner.origin_visit_get(origin, visit_from, extra_limit, order)
+        rows = self._cql_runner.origin_visit_get(
+            origin, visit_from, extra_limit, order.value
+        )
 
         for row in rows:
             visits.append(converters.row_to_visit(row))
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -51,7 +51,7 @@
     RawExtrinsicMetadata,
 )
 from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex
-from swh.storage.interface import PagedResult
+from swh.storage.interface import ListOrder, PagedResult
 from swh.storage.objstorage import ObjStorage
 from swh.storage.utils import now
 
@@ -864,17 +864,13 @@
         self,
         origin: str,
         page_token: Optional[str] = None,
-        order: str = "asc",
+        order: ListOrder = ListOrder.ASC,
         limit: int = 10,
     ) -> PagedResult[OriginVisit]:
         next_page_token = None
         page_token = page_token or "0"
-        order = order.lower()
-        allowed_orders = ["asc", "desc"]
-        if order not in allowed_orders:
-            raise StorageArgumentException(
-                f"order must be one of {', '.join(allowed_orders)}."
-            )
+        if not isinstance(order, ListOrder):
+            raise StorageArgumentException("order must be a ListOrder value")
         if not isinstance(page_token, str):
             raise StorageArgumentException("page_token must be a string.")
 
@@ -884,12 +880,12 @@
         visits = sorted(
             self._origin_visits.get(origin_url, []),
             key=lambda v: v.visit,
-            reverse=(order == "desc"),
+            reverse=(order == ListOrder.DESC),
         )
 
-        if visit_from > 0 and order == "asc":
+        if visit_from > 0 and order == ListOrder.ASC:
             visits = [v for v in visits if v.visit > visit_from]
-        elif visit_from > 0 and order == "desc":
+        elif visit_from > 0 and order == ListOrder.DESC:
             visits = [v for v in visits if v.visit < visit_from]
         visits = [v for v in visits if v][:extra_limit]
 
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -6,6 +6,7 @@
 import attr
 import datetime
 
+from enum import Enum
 from typing import Dict, Generic, Iterable, List, Optional, Tuple, TypeVar, Union
 
 from attrs_strict import type_validator
@@ -30,6 +31,12 @@
 )
 
 
+class ListOrder(Enum):
+    """Sort order for paginated listing endpoints such as origin_visit_get."""
+    ASC = "asc"
+    DESC = "desc"
+
+
 def deprecated(f):
     f.deprecated_endpoint = True
     return f
@@ -810,7 +817,7 @@
         self,
         origin: str,
         page_token: Optional[str] = None,
-        order: str = "asc",
+        order: ListOrder = ListOrder.ASC,
         limit: int = 10,
     ) -> PagedResult[OriginVisit]:
         """Retrieve page of OriginVisit information.
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -45,7 +45,7 @@
     RawExtrinsicMetadata,
 )
 from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex
-from swh.storage.interface import PagedResult
+from swh.storage.interface import ListOrder, PagedResult
 from swh.storage.objstorage import ObjStorage
 from swh.storage.utils import now
 
@@ -883,18 +883,14 @@
         self,
         origin: str,
         page_token: Optional[str] = None,
-        order: str = "asc",
+        order: ListOrder = ListOrder.ASC,
         limit: int = 10,
         db=None,
         cur=None,
     ) -> PagedResult[OriginVisit]:
         page_token = page_token or "0"
-        order = order.lower()
-        allowed_orders = ["asc", "desc"]
-        if order not in allowed_orders:
-            raise StorageArgumentException(
-                f"order must be one of {', '.join(allowed_orders)}."
-            )
+        if not isinstance(order, ListOrder):
+            raise StorageArgumentException("order must be a ListOrder value")
         if not isinstance(page_token, str):
             raise StorageArgumentException("page_token must be a string.")
 
@@ -903,7 +899,7 @@
         visits: List[OriginVisit] = []
         extra_limit = limit + 1
         for row in db.origin_visit_get_range(
-            origin, visit_from=visit_from, order=order, limit=extra_limit, cur=cur
+            origin, visit_from=visit_from, order=order.value, limit=extra_limit, cur=cur
         ):
             row_d = dict(zip(db.origin_visit_cols, row))
             visits.append(
diff --git a/swh/storage/tests/test_serializers.py b/swh/storage/tests/test_serializers.py
--- a/swh/storage/tests/test_serializers.py
+++ b/swh/storage/tests/test_serializers.py
@@ -3,13 +3,17 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from swh.storage.interface import PagedResult
+from swh.storage.interface import ListOrder, PagedResult
+from swh.model import model
 
 from swh.storage.api.serializers import (
     _encode_model_object,
     _decode_model_object,
     _encode_paged_result,
     _decode_paged_result,
+    _encode_enum,
+    _decode_model_enum,
+    _decode_storage_enum,
 )
 
 
@@ -46,3 +50,31 @@
 
     decoded_paged_result = _decode_paged_result(actual_paged_result)
     assert decoded_paged_result == paged_result
+
+
+def test_model_enum_serialization():
+    result_enum = model.MetadataAuthorityType.DEPOSIT_CLIENT
+    actual_serialized_enum = _encode_enum(result_enum)
+
+    expected_serialized_enum = {
+        "value": result_enum.value,
+        "__type__": type(result_enum).__name__,
+    }
+    assert actual_serialized_enum == expected_serialized_enum
+
+    decoded_enum = _decode_model_enum(actual_serialized_enum)
+    assert decoded_enum == result_enum
+
+
+def test_storage_enum_serialization():
+    result_enum = ListOrder.ASC
+    actual_serialized_enum = _encode_enum(result_enum)
+
+    expected_serialized_enum = {
+        "value": result_enum.value,
+        "__type__": type(result_enum).__name__,
+    }
+    assert actual_serialized_enum == expected_serialized_enum
+
+    decoded_enum = _decode_storage_enum(actual_serialized_enum)
+    assert decoded_enum == result_enum
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -41,7 +41,7 @@
 from swh.storage import get_storage
 from swh.storage.converters import origin_url_to_sha1 as sha1
 from swh.storage.exc import HashCollision, StorageArgumentException
-from swh.storage.interface import StorageInterface, PagedResult
+from swh.storage.interface import ListOrder, PagedResult, StorageInterface
 from swh.storage.utils import content_hex_hashes, now
 
 
@@ -1266,7 +1266,7 @@
             swh_storage.origin_visit_get(origin.url, page_token=10)  # not bytes
 
         with pytest.raises(
-            StorageArgumentException, match="order must be one of asc, desc"
+            StorageArgumentException, match="order must be a ListOrder value"
         ):
             swh_storage.origin_visit_get(origin.url, order="foobar")  # wrong order
 
@@ -1329,24 +1329,26 @@
         assert actual_result == PagedResult(results=[ov3])
 
         # order desc, no pagination, no limit
-        actual_result = swh_storage.origin_visit_get(origin.url, order="desc")
+        actual_result = swh_storage.origin_visit_get(origin.url, order=ListOrder.DESC)
         assert actual_result == PagedResult(results=[ov3, ov2, ov1])
 
         # order desc, no pagination, limit
-        actual_result = swh_storage.origin_visit_get(origin.url, limit=2, order="desc")
+        actual_result = swh_storage.origin_visit_get(
+            origin.url, limit=2, order=ListOrder.DESC
+        )
         assert actual_result.next_page_token is not None
         assert actual_result.results == [ov3, ov2]
 
         # order desc, pagination, no limit
         next_page_token = str(ov3.visit)
         actual_result = swh_storage.origin_visit_get(
-            origin.url, page_token=next_page_token, order="desc"
+            origin.url, page_token=next_page_token, order=ListOrder.DESC
         )
         assert actual_result == PagedResult(results=[ov2, ov1])
 
         # order desc, pagination, limit
         actual_result = swh_storage.origin_visit_get(
-            origin.url, page_token=next_page_token, order="desc", limit=1
+            origin.url, page_token=next_page_token, order=ListOrder.DESC, limit=1
         )
         assert actual_result.next_page_token is not None
         assert actual_result.results == [ov2]