diff --git a/swh/storage/algos/origin.py b/swh/storage/algos/origin.py
--- a/swh/storage/algos/origin.py
+++ b/swh/storage/algos/origin.py
@@ -44,6 +44,7 @@
     type: Optional[str] = None,
     allowed_statuses: Optional[Iterable[str]] = None,
     require_snapshot: bool = False,
+    limit: Optional[int] = 1000,
 ) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]:
     """Get the latest origin visit (and status) of an origin. Optionally, a combination
     of criteria can be provided, origin type, allowed statuses or if a visit has a
@@ -69,9 +70,9 @@
         status exist (and match the search criteria), None otherwise.
 
     """
-    # visits order are from older visit to most recent.
-    visits = list(storage.origin_visit_get(origin_url))
-    visits.reverse()
+    # Most recent visits first; only the `limit` latest ones are inspected (no bound
+    # when None), so a matching visit older than that window is not found.
+    visits = list(storage.origin_visit_get(origin_url, order="desc", limit=limit))
     if not visits:
         return None
     visit_status: Optional[OriginVisitStatus] = None
diff --git a/swh/storage/algos/snapshot.py b/swh/storage/algos/snapshot.py
--- a/swh/storage/algos/snapshot.py
+++ b/swh/storage/algos/snapshot.py
@@ -41,6 +41,7 @@
     origin: str,
     allowed_statuses: Optional[Iterable[str]] = None,
     branches_count: Optional[int] = None,
+    limit: Optional[int] = 1000,
 ) -> Optional[Snapshot]:
     """Get the latest snapshot for the given origin, optionally only from visits
     that have one of the given allowed_statuses.
@@ -58,6 +59,8 @@
         branches_count: Optional parameter to retrieve snapshot with all branches
             (default behavior when None) or not. If set to positive number, the
             snapshot will be partial with only that number of branches.
+        limit: Maximum number of most recent visits to inspect when looking for
+            a visit with a snapshot (None means no bound)
 
     Raises:
         ValueError if branches_count is not a positive value
@@ -67,7 +70,11 @@
 
     """
     visit_and_status = origin_get_latest_visit_status(
-        storage, origin, allowed_statuses=allowed_statuses, require_snapshot=True
+        storage,
+        origin,
+        allowed_statuses=allowed_statuses,
+        limit=limit,
+        require_snapshot=True,
     )
 
     if not visit_and_status:
diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -660,30 +660,38 @@
         "snapshot",
     ]
 
-    @_prepared_statement("SELECT * FROM origin_visit WHERE origin = ? AND visit > ?")
-    def _origin_visit_get_no_limit(
-        self, origin_url: str, last_visit: int, *, statement
+    def origin_visit_get(
+        self,
+        origin_url: str,
+        last_visit: Optional[int],
+        limit: Optional[int],
+        order: str = "asc",
     ) -> ResultSet:
-        return self._execute_with_retries(statement, [origin_url, last_visit])
+        order = order.lower()
+        assert order in ["asc", "desc"]
 
-    @_prepared_statement(
-        "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? LIMIT ?"
-    )
-    def _origin_visit_get_limit(
-        self, origin_url: str, last_visit: int, limit: int, *, statement
-    ) -> ResultSet:
-        return self._execute_with_retries(statement, [origin_url, last_visit, limit])
+        # The WHERE/ORDER BY/LIMIT clauses depend on the arguments, so the query is
+        # assembled here instead of going through @_prepared_statement.
+        query_parts = [
+            "SELECT * from origin_visit",
+        ]
 
-    def origin_visit_get(
-        self, origin_url: str, last_visit: Optional[int], limit: Optional[int]
-    ) -> ResultSet:
-        if last_visit is None:
-            last_visit = -1
+        query_parts.append("WHERE origin = %s")
+        query_params: List[Any] = [origin_url]
 
-        if limit is None:
-            return self._origin_visit_get_no_limit(origin_url, last_visit)
-        else:
-            return self._origin_visit_get_limit(origin_url, last_visit, limit)
+        if last_visit is not None:
+            query_parts.append("AND visit > %s")
+            query_params.append(last_visit)
+
+        if order == "desc":
+            query_parts.append("ORDER BY visit DESC")
+
+        if limit is not None and limit > 0:
+            query_parts.append("LIMIT %s")
+            query_params.append(limit)
+
+        query = " ".join(query_parts)
+        return self._execute_with_retries(query, tuple(query_params))
 
     @_prepared_insert_statement("origin_visit", _origin_visit_keys)
     def origin_visit_add_one(self, visit: OriginVisit, *, statement) -> None:
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -885,9 +885,14 @@
         }
 
     def origin_visit_get(
-        self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None
+        self,
+        origin: str,
+        last_visit: Optional[int] = None,
+        limit: Optional[int] = None,
+        order: str = "asc",
     ) -> Iterable[Dict[str, Any]]:
-        rows = self._cql_runner.origin_visit_get(origin, last_visit, limit)
+        rows = self._cql_runner.origin_visit_get(origin, last_visit, limit, order)
+
         for row in rows:
             visit = self._format_origin_visit_row(row)
             yield self._origin_visit_apply_last_status(visit)
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -581,7 +581,9 @@
         row = cur.fetchone()
         return self._make_origin_visit_status(row)
 
-    def origin_visit_get_all(self, origin_id, last_visit=None, limit=None, cur=None):
+    def origin_visit_get_all(
+        self, origin_id, last_visit=None, limit=None, order="asc", cur=None
+    ):
         """Retrieve all visits for origin with id origin_id.
 
         Args:
@@ -592,29 +594,38 @@
 
         """
         cur = self._cursor(cur)
+        # Normalize and validate the sort direction. The SQL keyword is picked
+        # from a fixed pair below, so caller-provided text never reaches the query
+        # (the assert alone would vanish under ``python -O``).
+        order = order.lower()
+        assert order in ["asc", "desc"]
+        order_sql = "DESC" if order == "desc" else "ASC"
+
+        query_parts = [
+            "SELECT DISTINCT ON (ov.visit) %s "
+            % ", ".join(self.origin_visit_select_cols),
+            "FROM origin_visit ov",
+            "INNER JOIN origin o ON o.id = ov.origin",
+            "INNER JOIN origin_visit_status ovs",
+            "ON ov.origin = ovs.origin AND ov.visit = ovs.visit",
+        ]
+        query_parts.append("WHERE o.url = %s")
+        query_params: List[Any] = [origin_id]
 
         if last_visit:
-            extra_condition = "and ov.visit > %s"
-            args = (origin_id, last_visit, limit)
-        else:
-            extra_condition = ""
-            args = (origin_id, limit)
+            query_parts.append("and ov.visit > %s")
+            query_params.append(last_visit)
 
-        query = """\
-            SELECT DISTINCT ON (ov.visit) %s
-            FROM origin_visit ov
-            INNER JOIN origin o ON o.id = ov.origin
-            INNER JOIN origin_visit_status ovs
-            ON ov.origin = ovs.origin AND ov.visit = ovs.visit
-            WHERE o.url=%%s %s
-            ORDER BY ov.visit ASC, ovs.date DESC
-            LIMIT %%s""" % (
-            ", ".join(self.origin_visit_select_cols),
-            extra_condition,
-        )
+        # DISTINCT ON (ov.visit) keeps the first row of each visit, i.e. its most
+        # recent status given the secondary "ovs.date DESC" ordering.
+        query_parts.append("ORDER BY ov.visit " + order_sql + ", ovs.date DESC")
 
-        cur.execute(query, args)
+        if limit is not None:
+            query_parts.append("LIMIT %s")
+            query_params.append(limit)
 
+        query = "\n".join(query_parts)
+        cur.execute(query, tuple(query_params))
         yield from cur
 
     def origin_visit_get(self, origin_id, visit_id, cur=None):
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -865,12 +865,22 @@
         )
 
     def origin_visit_get(
-        self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None
+        self,
+        origin: str,
+        last_visit: Optional[int] = None,
+        limit: Optional[int] = None,
+        order: str = "asc",
     ) -> Iterable[Dict[str, Any]]:
-
+        order = order.lower()
+        assert order in ["asc", "desc"]
         origin_url = self._get_origin_url(origin)
         if origin_url in self._origin_visits:
             visits = self._origin_visits[origin_url]
+            if order == "desc":  # only enforce sort in desc case, already ok otherwise
+                visits = sorted(visits, key=lambda v: v.visit, reverse=True)
             if last_visit is not None:
-                visits = visits[last_visit:]
+                # Filter on the visit id rather than slicing by position: once sorted
+                # in descending order, list positions no longer match visit ids. This
+                # matches the "visit > last_visit" semantics of the other backends.
+                visits = [v for v in visits if v.visit > last_visit]
             if limit is not None:
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -781,7 +781,11 @@
 
     @remote_api_endpoint("origin/visit/get")
     def origin_visit_get(
-        self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None
+        self,
+        origin: str,
+        last_visit: Optional[int] = None,
+        limit: Optional[int] = None,
+        order: str = "asc",
     ) -> Iterable[Dict[str, Any]]:
         """Retrieve all the origin's visit's information.
 
@@ -791,6 +795,8 @@
                 Default to None
             limit: Number of results to return from the last visit.
                 Default to None
+            order: Order on the visit id field used to list origin visits: "asc"
+                (oldest first, the default) or "desc" (most recent first)
 
         Yields:
             List of visits.
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -925,11 +925,14 @@
         origin: str,
         last_visit: Optional[int] = None,
         limit: Optional[int] = None,
+        order: str = "asc",
         db=None,
         cur=None,
     ) -> Iterable[Dict[str, Any]]:
+        order = order.lower()
+        assert order in ["asc", "desc"]
         lines = db.origin_visit_get_all(
-            origin, last_visit=last_visit, limit=limit, cur=cur
+            origin, last_visit=last_visit, limit=limit, order=order, cur=cur
         )
         for line in lines:
             visit = dict(zip(db.origin_visit_get_cols, line))
diff --git a/swh/storage/tests/algos/test_origin.py b/swh/storage/tests/algos/test_origin.py
--- a/swh/storage/tests/algos/test_origin.py
+++ b/swh/storage/tests/algos/test_origin.py
@@ -333,3 +333,12 @@
     assert actual_ov2.visit == ov2.visit
     assert actual_ov2.type == ov2.type
     assert actual_ovs22 == ovs22
+
+    # Bounding the search to the single most recent visit must still yield ov2
+    limited_ov2, limited_ovs22 = origin_get_latest_visit_status(
+        swh_storage, origin2.url, require_snapshot=True, limit=1
+    )
+    assert limited_ov2.origin == ov2.origin
+    assert limited_ov2.visit == ov2.visit
+    assert limited_ov2.type == ov2.type
+    assert limited_ovs22 == ovs22
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -1477,6 +1477,26 @@
 
         assert all_visits3 == expected_visits3
 
+        all_visits_desc = list(swh_storage.origin_visit_get(origin.url, order="desc"))
+        expected_visits_desc = [v.to_dict() for v in [ov3, ov2, ov1]]
+
+        assert all_visits_desc == expected_visits_desc
+
+        all_visits4 = list(
+            swh_storage.origin_visit_get(origin.url, order="desc", limit=2)
+        )
+        expected_visits4 = [v.to_dict() for v in [ov3, ov2]]
+
+        assert all_visits4 == expected_visits4
+
+        # last_visit keeps its "visit > last_visit" meaning in descending order
+        all_visits5 = list(
+            swh_storage.origin_visit_get(origin.url, last_visit=ov1.visit, order="desc")
+        )
+        expected_visits5 = [v.to_dict() for v in [ov3, ov2]]
+
+        assert all_visits5 == expected_visits5
+
     def test_origin_visit_get__unknown_origin(self, swh_storage):
         assert [] == list(swh_storage.origin_visit_get("foo"))
 