Page Menu Home Software Heritage

D3629.id12808.diff
No One Temporary

D3629.id12808.diff

diff --git a/swh/storage/api/serializers.py b/swh/storage/api/serializers.py
--- a/swh/storage/api/serializers.py
+++ b/swh/storage/api/serializers.py
@@ -10,6 +10,8 @@
from swh.model.identifiers import SWHID, parse_swhid
import swh.model.model as model
+from swh.storage import interface
+
def _encode_model_object(obj):
d = obj.to_dict()
@@ -17,23 +19,33 @@
return d
-def _encode_model_enum(obj):
+def _encode_enum(obj):
return {
"value": obj.value,
"__type__": type(obj).__name__,
}
+def _decode_model_enum(d):
+ return getattr(model, d.pop("__type__"))(d["value"])
+
+
+def _decode_storage_enum(d):
+ return getattr(interface, d.pop("__type__"))(d["value"])
+
+
ENCODERS: List[Tuple[type, str, Callable]] = [
(model.BaseModel, "model", _encode_model_object),
(SWHID, "swhid", str),
- (model.MetadataTargetType, "model_enum", _encode_model_enum),
- (model.MetadataAuthorityType, "model_enum", _encode_model_enum),
+ (model.MetadataTargetType, "model_enum", _encode_enum),
+ (model.MetadataAuthorityType, "model_enum", _encode_enum),
+ (interface.ListOrder, "storage_enum", _encode_enum),
]
DECODERS: Dict[str, Callable] = {
"swhid": parse_swhid,
"model": lambda d: getattr(model, d.pop("__type__")).from_dict(d),
- "model_enum": lambda d: getattr(model, d.pop("__type__"))(d["value"]),
+ "model_enum": _decode_model_enum,
+ "storage_enum": _decode_storage_enum,
}
diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -43,6 +43,8 @@
Origin,
)
+from swh.storage.interface import ListOrder
+
from .common import Row, TOKEN_BEGIN, TOKEN_END, hash_url
from .schema import CREATE_TABLES_QUERIES, HASH_ALGORITHMS
@@ -734,11 +736,8 @@
origin_url: str,
last_visit: Optional[int],
limit: Optional[int],
- order: str = "asc",
+ order: ListOrder,
) -> ResultSet:
- order = order.lower()
- assert order in ["asc", "desc"]
-
args: List[Any] = [origin_url]
if last_visit is not None:
@@ -753,7 +752,7 @@
else:
limit_name = "no_limit"
- method_name = f"_origin_visit_get_{page_name}_{order}_{limit_name}"
+ method_name = f"_origin_visit_get_{page_name}_{order.value}_{limit_name}"
origin_visit_get_method = getattr(self, method_name)
return origin_visit_get_method(*args)
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -32,7 +32,7 @@
MetadataTargetType,
RawExtrinsicMetadata,
)
-from swh.storage.interface import PagedResult
+from swh.storage.interface import ListOrder, PagedResult
from swh.storage.objstorage import ObjStorage
from swh.storage.writer import JournalWriter
from swh.storage.utils import map_optional, now
@@ -846,15 +846,11 @@
self,
origin: str,
page_token: Optional[str] = None,
- order: str = "asc",
+ order: ListOrder = ListOrder.ASC,
limit: int = 10,
) -> PagedResult[OriginVisit]:
- order = order.lower()
- allowed_orders = ["asc", "desc"]
- if order not in allowed_orders:
- raise StorageArgumentException(
- f"order must be one of {', '.join(allowed_orders)}."
- )
+ if not isinstance(order, ListOrder):
+ raise StorageArgumentException("order must be a ListOrder value")
if page_token and not isinstance(page_token, str):
raise StorageArgumentException("page_token must be a string.")
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -12,6 +12,7 @@
from swh.core.db.db_utils import stored_procedure, jsonize as _jsonize
from swh.core.db.db_utils import execute_values_generator
from swh.model.model import OriginVisit, OriginVisitStatus, SHA1_SIZE
+from swh.storage.interface import ListOrder
def jsonize(d):
@@ -620,9 +621,8 @@
yield from cur
def origin_visit_get_range(
- self, origin: str, visit_from: int, order: str, limit: int, cur=None,
+ self, origin: str, visit_from: int, order: ListOrder, limit: int, cur=None,
):
- assert order in ["asc", "desc"]
cur = self._cursor(cur)
origin_visit_cols = ["o.url as origin", "ov.visit", "ov.date", "ov.type"]
@@ -634,13 +634,13 @@
query_params: List[Any] = [origin]
if visit_from > 0:
- op_comparison = ">" if order == "asc" else "<"
+ op_comparison = ">" if order == ListOrder.ASC else "<"
query_parts.append(f"and ov.visit {op_comparison} %s")
query_params.append(visit_from)
- if order == "asc":
+ if order == ListOrder.ASC:
query_parts.append("ORDER BY ov.visit ASC")
- elif order == "desc":
+ elif order == ListOrder.DESC:
query_parts.append("ORDER BY ov.visit DESC")
query_parts.append("LIMIT %s")
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -51,7 +51,7 @@
RawExtrinsicMetadata,
)
from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex
-from swh.storage.interface import PagedResult
+from swh.storage.interface import ListOrder, PagedResult
from swh.storage.objstorage import ObjStorage
from swh.storage.utils import now
@@ -864,17 +864,13 @@
self,
origin: str,
page_token: Optional[str] = None,
- order: str = "asc",
+ order: ListOrder = ListOrder.ASC,
limit: int = 10,
) -> PagedResult[OriginVisit]:
next_page_token = None
page_token = page_token or "0"
- order = order.lower()
- allowed_orders = ["asc", "desc"]
- if order not in allowed_orders:
- raise StorageArgumentException(
- f"order must be one of {', '.join(allowed_orders)}."
- )
+ if not isinstance(order, ListOrder):
+ raise StorageArgumentException("order must be a ListOrder value")
if not isinstance(page_token, str):
raise StorageArgumentException("page_token must be a string.")
@@ -884,12 +880,12 @@
visits = sorted(
self._origin_visits.get(origin_url, []),
key=lambda v: v.visit,
- reverse=(order == "desc"),
+ reverse=(order == ListOrder.DESC),
)
- if visit_from > 0 and order == "asc":
+ if visit_from > 0 and order == ListOrder.ASC:
visits = [v for v in visits if v.visit > visit_from]
- elif visit_from > 0 and order == "desc":
+ elif visit_from > 0 and order == ListOrder.DESC:
visits = [v for v in visits if v.visit < visit_from]
visits = visits[:extra_limit]
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -5,8 +5,10 @@
import datetime
+from enum import Enum
from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union
+
from swh.core.api import remote_api_endpoint
from swh.core.api.classes import PagedResult as CorePagedResult
from swh.model.identifiers import SWHID
@@ -28,15 +30,22 @@
)
-def deprecated(f):
- f.deprecated_endpoint = True
- return f
+class ListOrder(Enum):
+ """Specifies the order for paginated endpoints returning sorted results."""
+
+ ASC = "asc"
+ DESC = "desc"
TResult = TypeVar("TResult")
PagedResult = CorePagedResult[TResult, str]
+def deprecated(f):
+ f.deprecated_endpoint = True
+ return f
+
+
class StorageInterface:
@remote_api_endpoint("check_config")
def check_config(self, *, check_write):
@@ -799,7 +808,7 @@
self,
origin: str,
page_token: Optional[str] = None,
- order: str = "asc",
+ order: ListOrder = ListOrder.ASC,
limit: int = 10,
) -> PagedResult[OriginVisit]:
"""Retrieve page of OriginVisit information.
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -45,7 +45,7 @@
RawExtrinsicMetadata,
)
from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex
-from swh.storage.interface import PagedResult
+from swh.storage.interface import ListOrder, PagedResult
from swh.storage.objstorage import ObjStorage
from swh.storage.utils import now
@@ -883,18 +883,14 @@
self,
origin: str,
page_token: Optional[str] = None,
- order: str = "asc",
+ order: ListOrder = ListOrder.ASC,
limit: int = 10,
db=None,
cur=None,
) -> PagedResult[OriginVisit]:
page_token = page_token or "0"
- order = order.lower()
- allowed_orders = ["asc", "desc"]
- if order not in allowed_orders:
- raise StorageArgumentException(
- f"order must be one of {', '.join(allowed_orders)}."
- )
+ if not isinstance(order, ListOrder):
+ raise StorageArgumentException("order must be a ListOrder value")
if not isinstance(page_token, str):
raise StorageArgumentException("page_token must be a string.")
diff --git a/swh/storage/tests/test_serializers.py b/swh/storage/tests/test_serializers.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/tests/test_serializers.py
@@ -0,0 +1,41 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from swh.model import model
+from swh.storage.interface import ListOrder
+
+from swh.storage.api.serializers import (
+ _encode_enum,
+ _decode_model_enum,
+ _decode_storage_enum,
+)
+
+
+def test_model_enum_serialization():
+ result_enum = model.MetadataAuthorityType.DEPOSIT_CLIENT
+ actual_serialized_enum = _encode_enum(result_enum)
+
+ expected_serialized_enum = {
+ "value": result_enum.value,
+ "__type__": type(result_enum).__name__,
+ }
+ assert actual_serialized_enum == expected_serialized_enum
+
+ decoded_enum = _decode_model_enum(actual_serialized_enum)
+ assert decoded_enum == result_enum
+
+
+def test_storage_enum_serialization():
+ result_enum = ListOrder.ASC
+ actual_serialized_enum = _encode_enum(result_enum)
+
+ expected_serialized_enum = {
+ "value": result_enum.value,
+ "__type__": type(result_enum).__name__,
+ }
+ assert actual_serialized_enum == expected_serialized_enum
+
+ decoded_enum = _decode_storage_enum(actual_serialized_enum)
+ assert decoded_enum == result_enum
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -41,7 +41,7 @@
from swh.storage import get_storage
from swh.storage.converters import origin_url_to_sha1 as sha1
from swh.storage.exc import HashCollision, StorageArgumentException
-from swh.storage.interface import StorageInterface, PagedResult # noqa
+from swh.storage.interface import ListOrder, PagedResult, StorageInterface
from swh.storage.utils import content_hex_hashes, now
@@ -1268,7 +1268,7 @@
swh_storage.origin_visit_get(origin.url, page_token=10) # not bytes
with pytest.raises(
- StorageArgumentException, match="order must be one of asc, desc"
+ StorageArgumentException, match="order must be a ListOrder value"
):
swh_storage.origin_visit_get(origin.url, order="foobar") # wrong order
@@ -1337,18 +1337,20 @@
assert actual_page == PagedResult(results=[ov3])
# order desc, no pagination, no limit
- actual_page = swh_storage.origin_visit_get(origin.url, order="desc")
+ actual_page = swh_storage.origin_visit_get(origin.url, order=ListOrder.DESC)
assert actual_page.next_page_token is None
assert actual_page == PagedResult(results=[ov3, ov2, ov1])
# order desc, no pagination, limit
- actual_page = swh_storage.origin_visit_get(origin.url, limit=2, order="desc")
+ actual_page = swh_storage.origin_visit_get(
+ origin.url, limit=2, order=ListOrder.DESC
+ )
next_page_token = actual_page.next_page_token
assert next_page_token is not None
assert actual_page.results == [ov3, ov2]
actual_page = swh_storage.origin_visit_get(
- origin.url, page_token=next_page_token, order="desc"
+ origin.url, page_token=next_page_token, order=ListOrder.DESC
)
assert actual_page.next_page_token is None
assert actual_page.results == [ov1]
@@ -1357,21 +1359,21 @@
# order desc, pagination, no limit
next_page_token = str(ov3.visit)
actual_page = swh_storage.origin_visit_get(
- origin.url, page_token=next_page_token, order="desc"
+ origin.url, page_token=next_page_token, order=ListOrder.DESC
)
assert actual_page.next_page_token is None
assert actual_page == PagedResult(results=[ov2, ov1])
# order desc, pagination, limit
actual_page = swh_storage.origin_visit_get(
- origin.url, page_token=next_page_token, order="desc", limit=1
+ origin.url, page_token=next_page_token, order=ListOrder.DESC, limit=1
)
next_page_token = actual_page.next_page_token
assert next_page_token is not None
assert actual_page.results == [ov2]
actual_page = swh_storage.origin_visit_get(
- origin.url, page_token=next_page_token, order="desc"
+ origin.url, page_token=next_page_token, order=ListOrder.DESC
)
assert actual_page == PagedResult(results=[ov1])

File Metadata

Mime Type
text/plain
Expires
Thu, Jan 23, 1:24 AM (18 h, 46 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3230093

Event Timeline