diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -844,15 +844,37 @@ def origin_visit_get( self, origin: str, - last_visit: Optional[int] = None, - limit: Optional[int] = None, + page_token: Optional[str] = None, order: str = "asc", - ) -> Iterable[Dict[str, Any]]: - rows = self._cql_runner.origin_visit_get(origin, last_visit, limit, order) + limit: int = 10, + ) -> Dict[str, Any]: + result: Dict[str, Any] = {} + order = order.lower() + allowed_orders = ["asc", "desc"] + if order not in allowed_orders: + raise StorageArgumentException( + f"order must be one of {', '.join(allowed_orders)}." + ) + if page_token and not isinstance(page_token, str): + raise StorageArgumentException("page_token must be a string.") + visit_from = page_token and int(page_token) + visits: List[OriginVisit] = [] + extra_limit = limit + 1 + rows = self._cql_runner.origin_visit_get(origin, visit_from, extra_limit, order) for row in rows: - visit = self._format_origin_visit_row(row) - yield self._origin_visit_apply_last_status(visit) + visits.append(converters.row_to_visit(row)) + + assert len(visits) <= extra_limit + if len(visits) == extra_limit: + visits = visits[:limit] + # token = last visit returned, since the backend filter is exclusive + if visits: + result["next_page_token"] = str(visits[-1].visit) + + if visits: + result["visits"] = visits + return result def origin_visit_find_by_date( self, origin: str, visit_date: datetime.datetime diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -481,6 +481,8 @@ + [jsonize(visit_status.metadata)], ) + origin_visit_cols = ["origin", "visit", "date", "type"] + def origin_visit_add_with_id(self, origin_visit: OriginVisit, cur=None) -> None: """Insert origin visit when id are already set @@ -488,12 +490,11 @@ ov = origin_visit assert ov.visit is not None cur = self._cursor(cur) - origin_visit_cols = ["origin", 
"visit", "date", "type"] query = """INSERT INTO origin_visit ({cols}) VALUES ((select id from origin where url=%s), {values}) ON CONFLICT (origin, visit) DO NOTHING""".format( - cols=", ".join(origin_visit_cols), - values=", ".join("%s" for col in origin_visit_cols[1:]), + cols=", ".join(self.origin_visit_cols), + values=", ".join("%s" for col in self.origin_visit_cols[1:]), ) cur.execute(query, (ov.origin, ov.visit, ov.date, ov.type)) @@ -618,6 +619,42 @@ cur.execute(query, tuple(query_params)) yield from cur + def origin_visit_get_range( + self, + origin: str, + visit_from: int = 0, + order: str = "asc", + limit: int = 10, + cur=None, + ): + assert order.lower() in ["asc", "desc"] # compared case-insensitively below too + cur = self._cursor(cur) + + origin_visit_cols = ["o.url as origin", "ov.visit", "ov.date", "ov.type"] + query_parts = [ + f"SELECT {', '.join(origin_visit_cols)} " "FROM origin_visit ov ", + "INNER JOIN origin o ON o.id = ov.origin ", + ] + query_parts.append("WHERE o.url = %s") + query_params: List[Any] = [origin] + + if visit_from > 0: + op_comparison = ">" if order.lower() == "asc" else "<" # exclusive bound + query_parts.append(f"and ov.visit {op_comparison} %s") + query_params.append(visit_from) + + if order.lower() == "asc": + query_parts.append("ORDER BY ov.visit ASC") + elif order.lower() == "desc": + query_parts.append("ORDER BY ov.visit DESC") + + query_parts.append("LIMIT %s") + query_params.append(limit) + + query = "\n".join(query_parts) + cur.execute(query, tuple(query_params)) + yield from cur + def origin_visit_get(self, origin_id, visit_id, cur=None): """Retrieve information on visit visit_id of origin origin_id. 
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -862,31 +862,48 @@ def origin_visit_get( self, origin: str, - last_visit: Optional[int] = None, - limit: Optional[int] = None, + page_token: Optional[str] = None, order: str = "asc", - ) -> Iterable[Dict[str, Any]]: + limit: int = 10, + ) -> Dict[str, Any]: + result: Dict[str, Any] = {} + page_token = page_token or "0" order = order.lower() - assert order in ["asc", "desc"] - origin_url = self._get_origin_url(origin) - if origin_url in self._origin_visits: - visits = self._origin_visits[origin_url] - visits = sorted(visits, key=lambda v: v.visit, reverse=(order == "desc")) - if last_visit is not None: - if order == "asc": - visits = [v for v in visits if v.visit > last_visit] - else: - visits = [v for v in visits if v.visit < last_visit] - if limit is not None: - visits = visits[:limit] - for visit in visits: - if not visit: - continue - visit_id = visit.visit + allowed_orders = ["asc", "desc"] + if order not in allowed_orders: + raise StorageArgumentException( + f"order must be one of {', '.join(allowed_orders)}." 
+ ) + if not isinstance(page_token, str): + raise StorageArgumentException("page_token must be a string.") - visit_update = self._origin_visit_get_updated(origin_url, visit_id) - assert visit_update is not None - yield visit_update + visit_from = int(page_token) + origin_url = self._get_origin_url(origin) + extra_limit = limit + 1 + visits = sorted( + self._origin_visits.get(origin_url, []), + key=lambda v: v.visit, + reverse=(order == "desc"), + ) + if not visits: + return result + + if visit_from > 0 and order == "asc": + visits = [v for v in visits if v.visit > visit_from] + elif visit_from > 0 and order == "desc": + visits = [v for v in visits if v.visit < visit_from] + visits = [v for v in visits if v is not None][:extra_limit] + + assert len(visits) <= extra_limit + if len(visits) == extra_limit: + visits = visits[:limit] + # token = last visit returned, since the filter above is exclusive + if visits: + result["next_page_token"] = str(visits[-1].visit) + + if visits: + result["visits"] = visits + return result def origin_visit_find_by_date( self, origin: str, visit_date: datetime.datetime diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -793,26 +793,55 @@ def origin_visit_get( self, origin: str, - last_visit: Optional[int] = None, - limit: Optional[int] = None, + page_token: Optional[str] = None, order: str = "asc", - ) -> Iterable[Dict[str, Any]]: - """Retrieve all the origin's visit's information. + limit: int = 10, + ) -> Dict[str, Any]: + """Retrieve OriginVisit information. Args: origin: The visited origin - last_visit: Starting point from which listing the next visits - Default to None - limit: Number of results to return from the last visit. - Default to None + page_token: opaque string used to get the next results of a search order: Order on visit id fields to list origin visits (default to asc) + limit: Number of visits to return - Yields: - List of visits.
+ Raises: + StorageArgumentException if wrong order or wrong page_token type + + Returns: + Dict with the following keys: + - **next_page_token** (str, optional): opaque token to be used as + `page_token` for retrieving the next page. If absent, there are + no more pages to gather. + - **visits** (Iterable[OriginVisit]): list of visits """ ... + # @remote_api_endpoint("origin/visit_status/get") + # def origin_visit_status_get( + # self, + # origin: str, + # last_visit: Optional[int] = None, + # limit: Optional[int] = None, + # order: str = "asc", + # ) -> Iterable[Optional[OriginVisitStatus]]: + # """Retrieve OriginVisit information. + + # Args: + # origin: The visited origin + # last_visit: Starting point from which listing the next visits + # Default to None + # limit: Number of results to return from the last visit. + # Default to None + # order: Order on visit id fields to list origin visits (default to asc) + + # Yields: + # List of optional (when not found) visits + + # """ + # ...
+ @remote_api_endpoint("origin/visit/find_by_date") def origin_visit_find_by_date( self, origin: str, visit_date: datetime.datetime diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -877,22 +877,54 @@ return OriginVisitStatus.from_dict(row) @timed - @db_transaction_generator(statement_timeout=500) + @db_transaction(statement_timeout=500) def origin_visit_get( self, origin: str, - last_visit: Optional[int] = None, - limit: Optional[int] = None, + page_token: Optional[str] = None, order: str = "asc", + limit: int = 10, db=None, cur=None, - ) -> Iterable[Dict[str, Any]]: - assert order in ["asc", "desc"] - lines = db.origin_visit_get_all( - origin, last_visit=last_visit, limit=limit, order=order, cur=cur - ) - for line in lines: - yield dict(zip(db.origin_visit_get_cols, line)) + ) -> Dict[str, Any]: + result: Dict[str, Any] = {} + page_token = page_token or "0" + order = order.lower() + allowed_orders = ["asc", "desc"] + if order not in allowed_orders: + raise StorageArgumentException( + f"order must be one of {', '.join(allowed_orders)}." 
+ ) + if not isinstance(page_token, str): + raise StorageArgumentException("page_token must be a string.") + + visit_from = int(page_token) + visits: List[OriginVisit] = [] + extra_limit = limit + 1 + for row in db.origin_visit_get_range( + origin, visit_from=visit_from, order=order, limit=extra_limit, cur=cur + ): + row_d = dict(zip(db.origin_visit_cols, row)) + visits.append( + OriginVisit( + origin=row_d["origin"], + visit=int(row_d["visit"]), + date=row_d["date"], + type=row_d["type"], + ) + ) + + assert len(visits) <= extra_limit + + if len(visits) == extra_limit: + visits = visits[:limit] + # token = last visit returned, since the SQL bound is exclusive + if visits: + result["next_page_token"] = str(visits[-1].visit) + + if visits: + result["visits"] = visits + return result @timed @db_transaction(statement_timeout=500) diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1253,10 +1253,27 @@ visits.append(date_visit) return visits + def test_origin_visit_get__unknown_origin(self, swh_storage): + actual_visit = swh_storage.origin_visit_get("foo") + assert actual_visit == {} + + def test_origin_visit_get__validation_failure(self, swh_storage, sample_data): + origin = sample_data.origin + swh_storage.origin_add([origin]) + with pytest.raises( + StorageArgumentException, match="page_token must be a string" + ): + swh_storage.origin_visit_get(origin.url, page_token=10) # not a string + + with pytest.raises( + StorageArgumentException, match="order must be one of asc, desc" + ): + swh_storage.origin_visit_get(origin.url, order="foobar") # wrong order + def test_origin_visit_get_all(self, swh_storage, sample_data): origin = sample_data.origin swh_storage.origin_add([origin]) - visits = swh_storage.origin_visit_add( + ov1, ov2, ov3 = swh_storage.origin_visit_add( [ OriginVisit( origin=origin.url, @@ -1275,59 +1292,54 @@ ), ] ) - ov1, ov2, ov3 = [ - {**v.to_dict(), 
"status": "created", "snapshot": None, "metadata": None,} - for v in visits - ] # order asc, no pagination, no limit - all_visits = list(swh_storage.origin_visit_get(origin.url)) - assert all_visits == [ov1, ov2, ov3] + actual_result = swh_storage.origin_visit_get(origin.url) + assert actual_result.get("next_page_token") is None + assert actual_result["visits"] == [ov1, ov2, ov3] # order asc, no pagination, limit - all_visits2 = list(swh_storage.origin_visit_get(origin.url, limit=2)) - assert all_visits2 == [ov1, ov2] + actual_result = swh_storage.origin_visit_get(origin.url, limit=2) + assert actual_result["next_page_token"] == str(ov2.visit) + assert actual_result["visits"] == [ov1, ov2] # order asc, pagination, no limit - all_visits3 = list( - swh_storage.origin_visit_get(origin.url, last_visit=ov1["visit"]) + actual_result = swh_storage.origin_visit_get( + origin.url, page_token=str(ov1.visit) ) - assert all_visits3 == [ov2, ov3] + assert actual_result.get("next_page_token") is None + assert actual_result["visits"] == [ov2, ov3] # order asc, pagination, limit - all_visits4 = list( - swh_storage.origin_visit_get(origin.url, last_visit=ov2["visit"], limit=1) + actual_result = swh_storage.origin_visit_get( + origin.url, page_token=str(ov2.visit), limit=1 ) - assert all_visits4 == [ov3] + assert actual_result.get("next_page_token") is None + assert actual_result["visits"] == [ov3] # order desc, no pagination, no limit - all_visits5 = list(swh_storage.origin_visit_get(origin.url, order="desc")) - assert all_visits5 == [ov3, ov2, ov1] + actual_result = swh_storage.origin_visit_get(origin.url, order="desc") + assert actual_result.get("next_page_token") is None + assert actual_result["visits"] == [ov3, ov2, ov1] # order desc, no pagination, limit - all_visits6 = list( - swh_storage.origin_visit_get(origin.url, limit=2, order="desc") - ) - assert all_visits6 == [ov3, ov2] + actual_result = swh_storage.origin_visit_get(origin.url, limit=2, order="desc") + assert 
actual_result["next_page_token"] == str(ov2.visit) + assert actual_result["visits"] == [ov3, ov2] # order desc, pagination, no limit - all_visits7 = list( - swh_storage.origin_visit_get( - origin.url, last_visit=ov3["visit"], order="desc" - ) + actual_result = swh_storage.origin_visit_get( + origin.url, page_token=str(ov3.visit), order="desc" ) - assert all_visits7 == [ov2, ov1] + assert actual_result.get("next_page_token") is None + assert actual_result["visits"] == [ov2, ov1] # order desc, pagination, limit - all_visits8 = list( - swh_storage.origin_visit_get( - origin.url, last_visit=ov3["visit"], order="desc", limit=1 - ) + actual_result = swh_storage.origin_visit_get( + origin.url, page_token=str(ov3.visit), order="desc", limit=1 ) - assert all_visits8 == [ov2] - - def test_origin_visit_get__unknown_origin(self, swh_storage): - assert [] == list(swh_storage.origin_visit_get("foo")) + assert actual_result["next_page_token"] == str(ov2.visit) + assert actual_result["visits"] == [ov2] def test_origin_visit_status_get_random(self, swh_storage, sample_data): origins = sample_data.origins[:2]