diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -657,30 +657,79 @@
         "date",
     ]
 
-    @_prepared_statement("SELECT * FROM origin_visit WHERE origin = ? AND visit > ?")
-    def _origin_visit_get_no_limit(
+    @_prepared_statement(
+        "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? "
+        "ORDER BY visit ASC"
+    )
+    def _origin_visit_get_asc_no_limit(
         self, origin_url: str, last_visit: int, *, statement
     ) -> ResultSet:
         return self._execute_with_retries(statement, [origin_url, last_visit])
 
     @_prepared_statement(
-        "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? LIMIT ?"
+        "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? "
+        "ORDER BY visit ASC "
+        "LIMIT ?"
     )
-    def _origin_visit_get_limit(
+    def _origin_visit_get_asc_limit(
+        self, origin_url: str, last_visit: int, limit: int, *, statement
+    ) -> ResultSet:
+        return self._execute_with_retries(statement, [origin_url, last_visit, limit])
+
+    @_prepared_statement(
+        "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? "
+        "ORDER BY visit DESC"
+    )
+    def _origin_visit_get_desc_no_limit(
+        self, origin_url: str, last_visit: int, *, statement
+    ) -> ResultSet:
+        return self._execute_with_retries(statement, [origin_url, last_visit])
+
+    @_prepared_statement(
+        "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? "
+        "ORDER BY visit DESC "
+        "LIMIT ?"
+    )
+    def _origin_visit_get_desc_limit(
         self, origin_url: str, last_visit: int, limit: int, *, statement
     ) -> ResultSet:
         return self._execute_with_retries(statement, [origin_url, last_visit, limit])
 
     def origin_visit_get(
-        self, origin_url: str, last_visit: Optional[int], limit: Optional[int]
+        self,
+        origin_url: str,
+        last_visit: Optional[int],
+        limit: Optional[int],
+        order: str = "asc",
     ) -> ResultSet:
+        """Return the visits of origin_url whose id is greater than last_visit,
+        sorted by visit id in the given order and capped at limit rows.
+
+        A last_visit of None is treated as -1 and a limit of None means no
+        cap; order is "asc" or "desc" (case-insensitive).
+
+        Raises:
+            ValueError: if order is neither "asc" nor "desc"
+
+        """
+        order = order.lower()
+        if order not in ("asc", "desc"):
+            raise ValueError(f"order must be 'asc' or 'desc', got {order!r}")
+
         if last_visit is None:
             last_visit = -1
 
         if limit is None:
-            return self._origin_visit_get_no_limit(origin_url, last_visit)
+            method_name = f"_origin_visit_get_{order}_no_limit"
+            args = [origin_url, last_visit]
         else:
-            return self._origin_visit_get_limit(origin_url, last_visit, limit)
+            method_name = f"_origin_visit_get_{order}_limit"
+            args = [origin_url, last_visit, limit]
+
+        # One prepared statement exists per (order, limit) combination.
+        method = getattr(self, method_name)
+        rows = method(*args)
+        return rows
 
     @_prepared_insert_statement("origin_visit", _origin_visit_keys)
     def origin_visit_add_one(self, visit: OriginVisit, *, statement) -> None:
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -888,9 +888,14 @@
         }
 
     def origin_visit_get(
-        self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None
+        self,
+        origin: str,
+        last_visit: Optional[int] = None,
+        limit: Optional[int] = None,
+        order: str = "asc",
     ) -> Iterable[Dict[str, Any]]:
-        rows = self._cql_runner.origin_visit_get(origin, last_visit, limit)
+        rows = self._cql_runner.origin_visit_get(origin, last_visit, limit, order)
+
         for row in rows:
             visit = self._format_origin_visit_row(row)
             yield self._origin_visit_apply_last_status(visit)
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -578,7 +578,9 @@
         row = cur.fetchone()
         return self._make_origin_visit_status(row)
 
-    def origin_visit_get_all(self, origin_id, last_visit=None, limit=None, cur=None):
+    def origin_visit_get_all(
+        self, origin_id, last_visit=None, limit=None, order="asc", cur=None
+    ):
         """Retrieve all visits for origin with id origin_id.
 
         Args:
@@ -589,29 +591,35 @@
 
         """
         cur = self._cursor(cur)
 
+        # The direction is spliced into the SQL text (keywords cannot be bound
+        # parameters), so map it through a whitelist instead of trusting it.
+        direction = {"asc": "ASC", "desc": "DESC"}.get(order.lower())
+        if direction is None:
+            raise ValueError(f"order must be 'asc' or 'desc', got {order!r}")
+
+        query_parts = [
+            "SELECT DISTINCT ON (ov.visit) %s "
+            % ", ".join(self.origin_visit_select_cols),
+            "FROM origin_visit ov",
+            "INNER JOIN origin o ON o.id = ov.origin",
+            "INNER JOIN origin_visit_status ovs",
+            "ON ov.origin = ovs.origin AND ov.visit = ovs.visit",
+        ]
+        query_parts.append("WHERE o.url = %s")
+        query_params: List[Any] = [origin_id]
         if last_visit:
-            extra_condition = "and ov.visit > %s"
-            args = (origin_id, last_visit, limit)
-        else:
-            extra_condition = ""
-            args = (origin_id, limit)
+            query_parts.append("and ov.visit > %s")
+            query_params.append(last_visit)
 
-        query = """\
-        SELECT DISTINCT ON (ov.visit) %s
-        FROM origin_visit ov
-        INNER JOIN origin o ON o.id = ov.origin
-        INNER JOIN origin_visit_status ovs
-        ON ov.origin = ovs.origin AND ov.visit = ovs.visit
-        WHERE o.url=%%s %s
-        ORDER BY ov.visit ASC, ovs.date DESC
-        LIMIT %%s""" % (
-            ", ".join(self.origin_visit_select_cols),
-            extra_condition,
-        )
+        query_parts.append("ORDER BY ov.visit " + direction + ", ovs.date DESC")
 
-        cur.execute(query, args)
+        if limit is not None:
+            query_parts.append("LIMIT %s")
+            query_params.append(limit)
 
+        query = "\n".join(query_parts)
+        cur.execute(query, tuple(query_params))
         yield from cur
 
     def origin_visit_get(self, origin_id, visit_id, cur=None):
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -865,12 +865,22 @@
         )
 
     def origin_visit_get(
-        self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None
+        self,
+        origin: str,
+        last_visit: Optional[int] = None,
+        limit: Optional[int] = None,
+        order: str = "asc",
     ) -> Iterable[Dict[str, Any]]:
-
+        order = order.lower()
+        if order not in ("asc", "desc"):
+            raise ValueError(f"order must be 'asc' or 'desc', got {order!r}")
         origin_url = self._get_origin_url(origin)
         if origin_url in self._origin_visits:
             visits = self._origin_visits[origin_url]
             if last_visit is not None:
                 visits = visits[last_visit:]
+            if order == "desc":
+                # Reorder only after the last_visit cut above: that slice is
+                # positional, so it must run on the list as stored.
+                visits = sorted(visits, key=lambda v: v.visit, reverse=True)
             if limit is not None:
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -781,7 +781,11 @@
 
     @remote_api_endpoint("origin/visit/get")
     def origin_visit_get(
-        self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None
+        self,
+        origin: str,
+        last_visit: Optional[int] = None,
+        limit: Optional[int] = None,
+        order: str = "asc",
     ) -> Iterable[Dict[str, Any]]:
         """Retrieve all the origin's visit's information.
 
@@ -791,6 +795,8 @@
                 Default to None
             limit: Number of results to return from the last visit.
                 Default to None
+            order: Sort order of the returned visits by visit id, either "asc"
+                (the default) or "desc"; matched case-insensitively
 
         Yields:
             List of visits.
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -929,11 +929,15 @@
         origin: str,
         last_visit: Optional[int] = None,
         limit: Optional[int] = None,
+        order: str = "asc",
         db=None,
         cur=None,
     ) -> Iterable[Dict[str, Any]]:
+        order = order.lower()
+        if order not in ("asc", "desc"):
+            raise ValueError(f"order must be 'asc' or 'desc', got {order!r}")
         lines = db.origin_visit_get_all(
-            origin, last_visit=last_visit, limit=limit, cur=cur
+            origin, last_visit=last_visit, limit=limit, order=order, cur=cur
         )
         for line in lines:
             visit = dict(zip(db.origin_visit_get_cols, line))
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -1462,6 +1462,18 @@
         )
         assert all_visits3 == [ov2]
 
+        all_visits4 = list(swh_storage.origin_visit_get(origin.url, order="desc"))
+        assert all_visits4 == [ov3, ov2, ov1]
+
+        all_visits5 = list(
+            swh_storage.origin_visit_get(origin.url, order="desc", limit=2)
+        )
+        assert all_visits5 == [ov3, ov2]
+
+        # The order keyword is matched case-insensitively by every backend.
+        all_visits6 = list(swh_storage.origin_visit_get(origin.url, order="DESC"))
+        assert all_visits6 == [ov3, ov2, ov1]
+
     def test_origin_visit_get__unknown_origin(self, swh_storage):
         assert [] == list(swh_storage.origin_visit_get("foo"))
 