diff --git a/mypy.ini b/mypy.ini --- a/mypy.ini +++ b/mypy.ini @@ -58,3 +58,6 @@ [mypy-pytest_postgresql.*] ignore_missing_imports = True + +[mypy-swh.core.*] +ignore_missing_imports = True diff --git a/requirements-swh.txt b/requirements-swh.txt --- a/requirements-swh.txt +++ b/requirements-swh.txt @@ -1,3 +1,3 @@ -swh.core[db,http] >= 0.1.0 +swh.core[db,http] >= 0.2.0 swh.model >= 0.4.0 swh.objstorage >= 0.0.40 diff --git a/swh/storage/api/serializers.py b/swh/storage/api/serializers.py --- a/swh/storage/api/serializers.py +++ b/swh/storage/api/serializers.py @@ -5,18 +5,22 @@ """Decoder and encoders for swh-model objects.""" -from typing import Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Tuple from swh.model.identifiers import SWHID, parse_swhid import swh.model.model as model -def _encode_model_object(obj): +def _encode_model_object(obj: model.BaseModel) -> Dict[str, Any]: d = obj.to_dict() d["__type__"] = type(obj).__name__ return d +def _decode_model_object(d: Dict[str, Any]) -> model.BaseModel: + return getattr(model, d.pop("__type__")).from_dict(d) + + def _encode_model_enum(obj): return { "value": obj.value, @@ -34,6 +38,6 @@ DECODERS: Dict[str, Callable] = { "swhid": parse_swhid, - "model": lambda d: getattr(model, d.pop("__type__")).from_dict(d), + "model": _decode_model_object, "model_enum": lambda d: getattr(model, d.pop("__type__"))(d["value"]), } diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -12,6 +12,7 @@ import attr +from swh.core.api.model import PagedResult from swh.core.api.serializers import msgpack_loads, msgpack_dumps from swh.model.identifiers import parse_swhid, SWHID from swh.model.hashutil import DEFAULT_ALGORITHMS @@ -844,15 +845,36 @@ def origin_visit_get( self, origin: str, - last_visit: Optional[int] = None, - limit: Optional[int] = None, + page_token: Optional[str] = None, order: str = 
"asc", - ) -> Iterable[Dict[str, Any]]: - rows = self._cql_runner.origin_visit_get(origin, last_visit, limit, order) + limit: int = 10, + ) -> PagedResult[OriginVisit]: + order = order.lower() + allowed_orders = ["asc", "desc"] + if order not in allowed_orders: + raise StorageArgumentException( + f"order must be one of {', '.join(allowed_orders)}." + ) + if page_token and not isinstance(page_token, str): + raise StorageArgumentException("page_token must be a string.") + + next_page_token = None + visit_from = page_token and int(page_token) + visits: List[OriginVisit] = [] + extra_limit = limit + 1 + rows = self._cql_runner.origin_visit_get(origin, visit_from, extra_limit, order) for row in rows: - visit = self._format_origin_visit_row(row) - yield self._origin_visit_apply_last_status(visit) + visits.append(converters.row_to_visit(row)) + + assert len(visits) <= extra_limit + if len(visits) == extra_limit: + last_visit = visits[limit - 1]  # last visit returned + visits = visits[:limit] + assert last_visit is not None + next_page_token = str(last_visit.visit) + + return PagedResult(results=visits, next_page_token=next_page_token) def origin_visit_find_by_date( self, origin: str, visit_date: datetime.datetime diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -481,6 +481,8 @@ + [jsonize(visit_status.metadata)], ) + origin_visit_cols = ["origin", "visit", "date", "type"] + def origin_visit_add_with_id(self, origin_visit: OriginVisit, cur=None) -> None: """Insert origin visit when id are already set @@ -488,12 +490,11 @@ ov = origin_visit assert ov.visit is not None cur = self._cursor(cur) - origin_visit_cols = ["origin", "visit", "date", "type"] query = """INSERT INTO origin_visit ({cols}) VALUES ((select id from origin where url=%s), {values}) ON CONFLICT (origin, visit) DO NOTHING""".format( - cols=", ".join(origin_visit_cols), - values=", ".join("%s" for col in origin_visit_cols[1:]), + cols=", ".join(self.origin_visit_cols), + values=", 
".join("%s" for col in self.origin_visit_cols[1:]), ) cur.execute(query, (ov.origin, ov.visit, ov.date, ov.type)) @@ -618,6 +619,37 @@ cur.execute(query, tuple(query_params)) yield from cur + def origin_visit_get_range( + self, origin: str, visit_from: int, order: str, limit: int, cur=None, + ): + assert order in ["asc", "desc"] + cur = self._cursor(cur) + + origin_visit_cols = ["o.url as origin", "ov.visit", "ov.date", "ov.type"] + query_parts = [ + f"SELECT {', '.join(origin_visit_cols)} FROM origin_visit ov ", + "INNER JOIN origin o ON o.id = ov.origin ", + ] + query_parts.append("WHERE o.url = %s") + query_params: List[Any] = [origin] + + if visit_from > 0: + op_comparison = ">" if order == "asc" else "<" + query_parts.append(f"and ov.visit {op_comparison} %s") + query_params.append(visit_from) + + if order == "asc": + query_parts.append("ORDER BY ov.visit ASC") + elif order == "desc": + query_parts.append("ORDER BY ov.visit DESC") + + query_parts.append("LIMIT %s") + query_params.append(limit) + + query = "\n".join(query_parts) + cur.execute(query, tuple(query_params)) + yield from cur + def origin_visit_get(self, origin_id, visit_id, cur=None): """Retrieve information on visit visit_id of origin origin_id. 
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -30,6 +30,7 @@ import attr +from swh.core.api.model import PagedResult from swh.core.api.serializers import msgpack_loads, msgpack_dumps from swh.model.identifiers import SWHID from swh.model.model import ( @@ -862,31 +863,44 @@ def origin_visit_get( self, origin: str, - last_visit: Optional[int] = None, - limit: Optional[int] = None, + page_token: Optional[str] = None, order: str = "asc", - ) -> Iterable[Dict[str, Any]]: + limit: int = 10, + ) -> PagedResult[OriginVisit]: + next_page_token = None + page_token = page_token or "0" order = order.lower() - assert order in ["asc", "desc"] + allowed_orders = ["asc", "desc"] + if order not in allowed_orders: + raise StorageArgumentException( + f"order must be one of {', '.join(allowed_orders)}." + ) + if not isinstance(page_token, str): + raise StorageArgumentException("page_token must be a string.") + + visit_from = int(page_token) origin_url = self._get_origin_url(origin) - if origin_url in self._origin_visits: - visits = self._origin_visits[origin_url] - visits = sorted(visits, key=lambda v: v.visit, reverse=(order == "desc")) - if last_visit is not None: - if order == "asc": - visits = [v for v in visits if v.visit > last_visit] - else: - visits = [v for v in visits if v.visit < last_visit] - if limit is not None: - visits = visits[:limit] - for visit in visits: - if not visit: - continue - visit_id = visit.visit + extra_limit = limit + 1 + visits = sorted( + self._origin_visits.get(origin_url, []), + key=lambda v: v.visit, + reverse=(order == "desc"), + ) + + if visit_from > 0 and order == "asc": + visits = [v for v in visits if v.visit > visit_from] + elif visit_from > 0 and order == "desc": + visits = [v for v in visits if v.visit < visit_from] + visits = [v for v in visits if v][:extra_limit] + + assert len(visits) <= extra_limit + if len(visits) == extra_limit: + last_visit = 
visits[limit - 1]  # last visit returned + visits = visits[:limit] + assert last_visit is not None + next_page_token = str(last_visit.visit) - visit_update = self._origin_visit_get_updated(origin_url, visit_id) - assert visit_update is not None - yield visit_update + return PagedResult(results=visits, next_page_token=next_page_token) def origin_visit_find_by_date( self, origin: str, visit_date: datetime.datetime diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -5,9 +5,10 @@ import datetime -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union from swh.core.api import remote_api_endpoint +from swh.core.api.model import PagedResult from swh.model.identifiers import SWHID from swh.model.model import ( Content, @@ -793,22 +794,24 @@ def origin_visit_get( self, origin: str, - last_visit: Optional[int] = None, - limit: Optional[int] = None, + page_token: Optional[str] = None, order: str = "asc", - ) -> Iterable[Dict[str, Any]]: - """Retrieve all the origin's visit's information. + limit: int = 10, + ) -> PagedResult[OriginVisit]: + """Retrieve page of OriginVisit information. Args: origin: The visited origin - last_visit: Starting point from which listing the next visits - Default to None - limit: Number of results to return from the last visit. - Default to None + page_token: opaque string used to get the next results of a search order: Order on visit id fields to list origin visits (default to asc) + limit: Number of visits to return - Yields: - List of visits. + Raises: + StorageArgumentException if the order is wrong or the page_token type is + mistyped. + + Returns: Page of OriginVisit data model objects. If next_page_token is None, + there is no more data to retrieve. """ ... 
diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -25,6 +25,7 @@ import psycopg2.pool import psycopg2.errors +from swh.core.api.model import PagedResult from swh.core.api.serializers import msgpack_loads, msgpack_dumps from swh.model.identifiers import parse_swhid, SWHID from swh.model.model import ( @@ -877,22 +878,52 @@ return OriginVisitStatus.from_dict(row) @timed - @db_transaction_generator(statement_timeout=500) + @db_transaction(statement_timeout=500) def origin_visit_get( self, origin: str, - last_visit: Optional[int] = None, - limit: Optional[int] = None, + page_token: Optional[str] = None, order: str = "asc", + limit: int = 10, db=None, cur=None, - ) -> Iterable[Dict[str, Any]]: - assert order in ["asc", "desc"] - lines = db.origin_visit_get_all( - origin, last_visit=last_visit, limit=limit, order=order, cur=cur - ) - for line in lines: - yield dict(zip(db.origin_visit_get_cols, line)) + ) -> PagedResult[OriginVisit]: + page_token = page_token or "0" + order = order.lower() + allowed_orders = ["asc", "desc"] + if order not in allowed_orders: + raise StorageArgumentException( + f"order must be one of {', '.join(allowed_orders)}." 
+ ) + if not isinstance(page_token, str): + raise StorageArgumentException("page_token must be a string.") + + next_page_token = None + visit_from = int(page_token) + visits: List[OriginVisit] = [] + extra_limit = limit + 1 + for row in db.origin_visit_get_range( + origin, visit_from=visit_from, order=order, limit=extra_limit, cur=cur + ): + row_d = dict(zip(db.origin_visit_cols, row)) + visits.append( + OriginVisit( + origin=row_d["origin"], + visit=row_d["visit"], + date=row_d["date"], + type=row_d["type"], + ) + ) + + assert len(visits) <= extra_limit + + if len(visits) == extra_limit: + last_visit = visits[limit - 1]  # last visit returned + visits = visits[:limit] + assert last_visit is not None + next_page_token = str(last_visit.visit) + + return PagedResult(results=visits, next_page_token=next_page_token) @timed @db_transaction(statement_timeout=500) diff --git a/swh/storage/tests/test_retry.py b/swh/storage/tests/test_retry.py --- a/swh/storage/tests/test_retry.py +++ b/swh/storage/tests/test_retry.py @@ -269,16 +269,15 @@ swh_storage.origin_add([origin]) - origins = list(swh_storage.origin_visit_get(origin.url)) + origins = swh_storage.origin_visit_get(origin.url).results assert not origins origin_visit = swh_storage.origin_visit_add([visit])[0] assert origin_visit.origin == origin.url assert isinstance(origin_visit.visit, int) - origin_visit = next(swh_storage.origin_visit_get(origin.url)) - assert origin_visit["origin"] == origin.url - assert isinstance(origin_visit["visit"], int) + actual_visit = swh_storage.origin_visit_get(origin.url).results[0] + assert actual_visit == visit def test_retrying_proxy_swh_storage_origin_visit_add_retry( @@ -303,7 +302,7 @@ [visit], ] - origins = list(swh_storage.origin_visit_get(origin.url)) + origins = swh_storage.origin_visit_get(origin.url).results assert not origins r = swh_storage.origin_visit_add([visit]) @@ -327,7 +326,7 @@ visit = sample_data.origin_visit assert visit.origin == origin.url - origins = 
list(swh_storage.origin_visit_get(origin.url)) + origins = swh_storage.origin_visit_get(origin.url).results assert not origins with pytest.raises(StorageArgumentException, match="Refuse to add"): diff --git a/swh/storage/tests/test_serializers.py b/swh/storage/tests/test_serializers.py new file mode 100644 --- /dev/null +++ b/swh/storage/tests/test_serializers.py @@ -0,0 +1,24 @@ +# Copyright (C) 2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from swh.storage.api.serializers import ( + _encode_model_object, + _decode_model_object, +) + + +def test_model_object_serialization(sample_data): + content = sample_data.content + + actual_content_dict = _encode_model_object(content) + + expected_content_dict = content.to_dict() + expected_content_dict["__type__"] = type(content).__name__ + + assert actual_content_dict == expected_content_dict + + decoded_content = _decode_model_object(actual_content_dict) + + assert decoded_content == content diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -41,7 +41,7 @@ from swh.storage import get_storage from swh.storage.converters import origin_url_to_sha1 as sha1 from swh.storage.exc import HashCollision, StorageArgumentException -from swh.storage.interface import StorageInterface +from swh.storage.interface import StorageInterface, PagedResult from swh.storage.utils import content_hex_hashes, now @@ -1253,10 +1253,27 @@ visits.append(date_visit) return visits + def test_origin_visit_get__unknown_origin(self, swh_storage): + actual_visit = swh_storage.origin_visit_get("foo") + assert actual_visit == PagedResult() + + def test_origin_visit_get__validation_failure(self, swh_storage, sample_data): + origin = sample_data.origin + 
swh_storage.origin_add([origin]) + with pytest.raises( + StorageArgumentException, match="page_token must be a string" + ): + swh_storage.origin_visit_get(origin.url, page_token=10) # not a string + + with pytest.raises( + StorageArgumentException, match="order must be one of asc, desc" + ): + swh_storage.origin_visit_get(origin.url, order="foobar") # wrong order + def test_origin_visit_get_all(self, swh_storage, sample_data): origin = sample_data.origin swh_storage.origin_add([origin]) - visits = swh_storage.origin_visit_add( + ov1, ov2, ov3 = swh_storage.origin_visit_add( [ OriginVisit( origin=origin.url, @@ -1275,59 +1292,64 @@ ), ] ) - ov1, ov2, ov3 = [ - {**v.to_dict(), "status": "created", "snapshot": None, "metadata": None,} - for v in visits - ] # order asc, no pagination, no limit - all_visits = list(swh_storage.origin_visit_get(origin.url)) - assert all_visits == [ov1, ov2, ov3] + actual_result = swh_storage.origin_visit_get(origin.url) + assert actual_result == PagedResult( + results=[ov1, ov2, ov3], next_page_token=None + ) # order asc, no pagination, limit - all_visits2 = list(swh_storage.origin_visit_get(origin.url, limit=2)) - assert all_visits2 == [ov1, ov2] + actual_result = swh_storage.origin_visit_get(origin.url, limit=2) + next_page_token = actual_result.next_page_token + assert next_page_token is not None + assert actual_result.results == [ov1, ov2] # order asc, pagination, no limit - all_visits3 = list( - swh_storage.origin_visit_get(origin.url, last_visit=ov1["visit"]) + actual_result = swh_storage.origin_visit_get( + origin.url, page_token=next_page_token + ) + assert actual_result == PagedResult(results=[ov3]) + + actual_result = swh_storage.origin_visit_get( + origin.url, page_token=str(ov1.visit) ) - assert all_visits3 == [ov2, ov3] + assert actual_result == PagedResult(results=[ov2, ov3]) # order asc, pagination, limit - all_visits4 = list( - swh_storage.origin_visit_get(origin.url, last_visit=ov2["visit"], limit=1) + actual_result = 
swh_storage.origin_visit_get( + origin.url, page_token=next_page_token, limit=1 + ) + assert actual_result == PagedResult(results=[ov3]) + + next_page_token = str(ov2.visit) + actual_result = swh_storage.origin_visit_get( + origin.url, page_token=next_page_token, limit=1 ) - assert all_visits4 == [ov3] + assert actual_result == PagedResult(results=[ov3]) # order desc, no pagination, no limit - all_visits5 = list(swh_storage.origin_visit_get(origin.url, order="desc")) - assert all_visits5 == [ov3, ov2, ov1] + actual_result = swh_storage.origin_visit_get(origin.url, order="desc") + assert actual_result == PagedResult(results=[ov3, ov2, ov1]) # order desc, no pagination, limit - all_visits6 = list( - swh_storage.origin_visit_get(origin.url, limit=2, order="desc") - ) - assert all_visits6 == [ov3, ov2] + actual_result = swh_storage.origin_visit_get(origin.url, limit=2, order="desc") + assert actual_result.next_page_token is not None + assert actual_result.results == [ov3, ov2] # order desc, pagination, no limit - all_visits7 = list( - swh_storage.origin_visit_get( - origin.url, last_visit=ov3["visit"], order="desc" - ) + next_page_token = str(ov3.visit) + actual_result = swh_storage.origin_visit_get( + origin.url, page_token=next_page_token, order="desc" ) - assert all_visits7 == [ov2, ov1] + assert actual_result == PagedResult(results=[ov2, ov1]) # order desc, pagination, limit - all_visits8 = list( - swh_storage.origin_visit_get( - origin.url, last_visit=ov3["visit"], order="desc", limit=1 - ) + actual_result = swh_storage.origin_visit_get( + origin.url, page_token=next_page_token, order="desc", limit=1 ) - assert all_visits8 == [ov2] - - def test_origin_visit_get__unknown_origin(self, swh_storage): - assert [] == list(swh_storage.origin_visit_get("foo")) + assert actual_result.next_page_token is not None + assert actual_result.results == [ov2] def test_origin_visit_status_get_random(self, swh_storage, sample_data): origins = sample_data.origins[:2] @@ -1562,21 
+1584,16 @@ 
snapshot=None, ) - actual_origin_visits = list(swh_storage.origin_visit_get(origin1.url)) - expected_visits = [ - {**ovs1.to_dict(), "type": ov1.type}, - {**ovs2.to_dict(), "type": ov2.type}, - ] - - assert len(expected_visits) == len(actual_origin_visits) - + actual_visits = swh_storage.origin_visit_get(origin1.url).results + expected_visits = [ov1, ov2] + assert len(expected_visits) == len(actual_visits) for visit in expected_visits: - assert visit in actual_origin_visits + assert visit in actual_visits actual_objects = list(swh_storage.journal_writer.journal.objects) expected_objects = list( [("origin", origin1)] - + [("origin_visit", visit) for visit in [ov1, ov2]] * 2 + + [("origin_visit", visit) for visit in expected_visits] * 2 + [("origin_visit_status", ovs) for ovs in [ovs1, ovs2]] ) @@ -1719,7 +1736,7 @@ status="created", snapshot=None, ) - date_visit_now = now() + date_visit_now = round_to_milliseconds(now()) visit_status1 = OriginVisitStatus( origin=ov1.origin, visit=ov1.visit, @@ -1732,13 +1749,8 @@ # second call will ignore existing entries (will send to storage though) swh_storage.origin_visit_status_add([visit_status1]) - origin_visits = list(swh_storage.origin_visit_get(ov1.origin)) - - assert len(origin_visits) == 1 - origin_visit1 = origin_visits[0] - assert origin_visit1 - assert origin_visit1["status"] == "full" - assert origin_visit1["snapshot"] == snapshot.id + visit_status = swh_storage.origin_visit_status_get_latest(ov1.origin, ov1.visit) + assert visit_status == visit_status1 actual_objects = list(swh_storage.journal_writer.journal.objects)