diff --git a/swh/storage/api/serializers.py b/swh/storage/api/serializers.py --- a/swh/storage/api/serializers.py +++ b/swh/storage/api/serializers.py @@ -10,6 +10,8 @@ from swh.model.identifiers import SWHID, parse_swhid import swh.model.model as model +from swh.storage import interface + def _encode_model_object(obj): d = obj.to_dict() @@ -17,23 +19,35 @@ return d -def _encode_model_enum(obj): +def _encode_enum(obj): return { "value": obj.value, "__type__": type(obj).__name__, } +def _decode_model_enum(d): + """Rebuild the swh.model enum member encoded by _encode_enum.""" + return getattr(model, d.pop("__type__"))(d["value"]) + + +def _decode_storage_enum(d): + """Rebuild the swh.storage.interface enum member encoded by _encode_enum.""" + return getattr(interface, d.pop("__type__"))(d["value"]) + + ENCODERS: List[Tuple[type, str, Callable]] = [ (model.BaseModel, "model", _encode_model_object), (SWHID, "swhid", str), - (model.MetadataTargetType, "model_enum", _encode_model_enum), - (model.MetadataAuthorityType, "model_enum", _encode_model_enum), + (model.MetadataTargetType, "model_enum", _encode_enum), + (model.MetadataAuthorityType, "model_enum", _encode_enum), + (interface.ListOrder, "storage_enum", _encode_enum), ] DECODERS: Dict[str, Callable] = { "swhid": parse_swhid, "model": lambda d: getattr(model, d.pop("__type__")).from_dict(d), - "model_enum": lambda d: getattr(model, d.pop("__type__"))(d["value"]), + "model_enum": _decode_model_enum, + "storage_enum": _decode_storage_enum, } diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -43,6 +43,8 @@ Origin, ) +from swh.storage.interface import ListOrder + from .common import Row, TOKEN_BEGIN, TOKEN_END, hash_url from .schema import CREATE_TABLES_QUERIES, HASH_ALGORITHMS @@ -734,11 +736,8 @@ origin_url: str, last_visit: Optional[int], limit: Optional[int], - order: str = "asc", + order: ListOrder, ) -> ResultSet: - order = order.lower() - assert order in ["asc", "desc"] - args: List[Any] = [origin_url] if last_visit is not 
None: @@ -753,7 +752,7 @@ else: limit_name = "no_limit" - method_name = f"_origin_visit_get_{page_name}_{order}_{limit_name}" + method_name = f"_origin_visit_get_{page_name}_{order.value}_{limit_name}" origin_visit_get_method = getattr(self, method_name) return origin_visit_get_method(*args) diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -32,7 +32,7 @@ MetadataTargetType, RawExtrinsicMetadata, ) -from swh.storage.interface import PagedResult +from swh.storage.interface import ListOrder, PagedResult from swh.storage.objstorage import ObjStorage from swh.storage.writer import JournalWriter from swh.storage.utils import map_optional, now @@ -846,15 +846,11 @@ self, origin: str, page_token: Optional[str] = None, - order: str = "asc", + order: ListOrder = ListOrder.ASC, limit: int = 10, ) -> PagedResult[OriginVisit]: - order = order.lower() - allowed_orders = ["asc", "desc"] - if order not in allowed_orders: - raise StorageArgumentException( - f"order must be one of {', '.join(allowed_orders)}." 
- ) + if not isinstance(order, ListOrder): + raise StorageArgumentException("order must be a ListOrder value") if page_token and not isinstance(page_token, str): raise StorageArgumentException("page_token must be a string.") diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -12,6 +12,7 @@ from swh.core.db.db_utils import stored_procedure, jsonize as _jsonize from swh.core.db.db_utils import execute_values_generator from swh.model.model import OriginVisit, OriginVisitStatus, SHA1_SIZE +from swh.storage.interface import ListOrder def jsonize(d): @@ -620,9 +621,8 @@ yield from cur def origin_visit_get_range( - self, origin: str, visit_from: int, order: str, limit: int, cur=None, + self, origin: str, visit_from: int, order: ListOrder, limit: int, cur=None, ): - assert order in ["asc", "desc"] cur = self._cursor(cur) origin_visit_cols = ["o.url as origin", "ov.visit", "ov.date", "ov.type"] @@ -634,13 +634,13 @@ query_params: List[Any] = [origin] if visit_from > 0: - op_comparison = ">" if order == "asc" else "<" + op_comparison = ">" if order == ListOrder.ASC else "<" query_parts.append(f"and ov.visit {op_comparison} %s") query_params.append(visit_from) - if order == "asc": + if order == ListOrder.ASC: query_parts.append("ORDER BY ov.visit ASC") - elif order == "desc": + elif order == ListOrder.DESC: query_parts.append("ORDER BY ov.visit DESC") query_parts.append("LIMIT %s") diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -51,7 +51,7 @@ RawExtrinsicMetadata, ) from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex -from swh.storage.interface import PagedResult +from swh.storage.interface import ListOrder, PagedResult from swh.storage.objstorage import ObjStorage from swh.storage.utils import now @@ -864,17 +864,13 @@ self, origin: str, page_token: Optional[str] = None, - order: str = "asc", + order: ListOrder = 
ListOrder.ASC, limit: int = 10, ) -> PagedResult[OriginVisit]: next_page_token = None page_token = page_token or "0" - order = order.lower() - allowed_orders = ["asc", "desc"] - if order not in allowed_orders: - raise StorageArgumentException( - f"order must be one of {', '.join(allowed_orders)}." - ) + if not isinstance(order, ListOrder): + raise StorageArgumentException("order must be a ListOrder value") if not isinstance(page_token, str): raise StorageArgumentException("page_token must be a string.") @@ -884,12 +880,12 @@ visits = sorted( self._origin_visits.get(origin_url, []), key=lambda v: v.visit, - reverse=(order == "desc"), + reverse=(order == ListOrder.DESC), ) - if visit_from > 0 and order == "asc": + if visit_from > 0 and order == ListOrder.ASC: visits = [v for v in visits if v.visit > visit_from] - elif visit_from > 0 and order == "desc": + elif visit_from > 0 and order == ListOrder.DESC: visits = [v for v in visits if v.visit < visit_from] visits = visits[:extra_limit] diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -5,8 +5,10 @@ import datetime +from enum import Enum from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union + from swh.core.api import remote_api_endpoint from swh.core.api.classes import PagedResult as CorePagedResult from swh.model.identifiers import SWHID @@ -28,15 +30,22 @@ ) -def deprecated(f): - f.deprecated_endpoint = True - return f +class ListOrder(Enum): + """Specifies the order for paginated endpoints returning sorted results.""" + + ASC = "asc" + DESC = "desc" TResult = TypeVar("TResult") PagedResult = CorePagedResult[TResult, str] +def deprecated(f): + f.deprecated_endpoint = True + return f + + class StorageInterface: @remote_api_endpoint("check_config") def check_config(self, *, check_write): @@ -799,7 +808,7 @@ self, origin: str, page_token: Optional[str] = None, - order: str = "asc", + order: ListOrder = ListOrder.ASC, 
limit: int = 10, ) -> PagedResult[OriginVisit]: """Retrieve page of OriginVisit information. diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -45,7 +45,7 @@ RawExtrinsicMetadata, ) from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex -from swh.storage.interface import PagedResult +from swh.storage.interface import ListOrder, PagedResult from swh.storage.objstorage import ObjStorage from swh.storage.utils import now @@ -883,18 +883,14 @@ self, origin: str, page_token: Optional[str] = None, - order: str = "asc", + order: ListOrder = ListOrder.ASC, limit: int = 10, db=None, cur=None, ) -> PagedResult[OriginVisit]: page_token = page_token or "0" - order = order.lower() - allowed_orders = ["asc", "desc"] - if order not in allowed_orders: - raise StorageArgumentException( - f"order must be one of {', '.join(allowed_orders)}." - ) + if not isinstance(order, ListOrder): + raise StorageArgumentException("order must be a ListOrder value") if not isinstance(page_token, str): raise StorageArgumentException("page_token must be a string.") diff --git a/swh/storage/tests/test_serializers.py b/swh/storage/tests/test_serializers.py new file mode 100644 --- /dev/null +++ b/swh/storage/tests/test_serializers.py @@ -0,0 +1,41 @@ +# Copyright (C) 2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from swh.storage.interface import ListOrder +from swh.model import model + +from swh.storage.api.serializers import ( + _encode_enum, + _decode_model_enum, + _decode_storage_enum, +) + + +def test_model_enum_serialization(): + result_enum = model.MetadataAuthorityType.DEPOSIT_CLIENT + actual_serialized_enum = _encode_enum(result_enum) + + expected_serialized_enum = { + "value": result_enum.value, + 
"__type__": type(result_enum).__name__, + } + assert actual_serialized_enum == expected_serialized_enum + + decoded_enum = _decode_model_enum(actual_serialized_enum) + assert decoded_enum == result_enum + + +def test_storage_enum_serialization(): + result_enum = ListOrder.ASC + actual_serialized_enum = _encode_enum(result_enum) + + expected_serialized_enum = { + "value": result_enum.value, + "__type__": type(result_enum).__name__, + } + assert actual_serialized_enum == expected_serialized_enum + + decoded_enum = _decode_storage_enum(actual_serialized_enum) + assert decoded_enum == result_enum diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -41,7 +41,7 @@ from swh.storage import get_storage from swh.storage.converters import origin_url_to_sha1 as sha1 from swh.storage.exc import HashCollision, StorageArgumentException -from swh.storage.interface import StorageInterface, PagedResult # noqa +from swh.storage.interface import ListOrder, PagedResult, StorageInterface from swh.storage.utils import content_hex_hashes, now @@ -1268,7 +1268,7 @@ swh_storage.origin_visit_get(origin.url, page_token=10) # not bytes with pytest.raises( - StorageArgumentException, match="order must be one of asc, desc" + StorageArgumentException, match="order must be a ListOrder value" ): swh_storage.origin_visit_get(origin.url, order="foobar") # wrong order @@ -1337,18 +1337,20 @@ assert actual_page == PagedResult(results=[ov3]) # order desc, no pagination, no limit - actual_page = swh_storage.origin_visit_get(origin.url, order="desc") + actual_page = swh_storage.origin_visit_get(origin.url, order=ListOrder.DESC) assert actual_page.next_page_token is None assert actual_page == PagedResult(results=[ov3, ov2, ov1]) # order desc, no pagination, limit - actual_page = swh_storage.origin_visit_get(origin.url, limit=2, order="desc") + 
actual_page = swh_storage.origin_visit_get( + origin.url, limit=2, order=ListOrder.DESC + ) next_page_token = actual_page.next_page_token assert next_page_token is not None assert actual_page.results == [ov3, ov2] actual_page = swh_storage.origin_visit_get( - origin.url, page_token=next_page_token, order="desc" + origin.url, page_token=next_page_token, order=ListOrder.DESC ) assert actual_page.next_page_token is None assert actual_page.results == [ov1] @@ -1357,21 +1359,21 @@ # order desc, pagination, no limit next_page_token = str(ov3.visit) actual_page = swh_storage.origin_visit_get( - origin.url, page_token=next_page_token, order="desc" + origin.url, page_token=next_page_token, order=ListOrder.DESC ) assert actual_page.next_page_token is None assert actual_page == PagedResult(results=[ov2, ov1]) # order desc, pagination, limit actual_page = swh_storage.origin_visit_get( - origin.url, page_token=next_page_token, order="desc", limit=1 + origin.url, page_token=next_page_token, order=ListOrder.DESC, limit=1 ) next_page_token = actual_page.next_page_token assert next_page_token is not None assert actual_page.results == [ov2] actual_page = swh_storage.origin_visit_get( - origin.url, page_token=next_page_token, order="desc" + origin.url, page_token=next_page_token, order=ListOrder.DESC ) assert actual_page == PagedResult(results=[ov1])