origin_visit_get: add an `order` parameter ("asc" or "desc")

Add an `order` argument (case-insensitive, default "asc") to
origin_visit_get in the storage interface and in the in-memory,
Cassandra and SQL (storage.py / db.py) backends. Visits are sorted by
visit id in that order; when `last_visit` is given, only visits strictly
after it ("asc") or strictly before it ("desc") are returned, and
`limit`, when given, caps the number of results.

diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -657,30 +657,103 @@
         "date",
     ]
 
-    @_prepared_statement("SELECT * FROM origin_visit WHERE origin = ? AND visit > ?")
-    def _origin_visit_get_no_limit(
+    @_prepared_statement(
+        "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? "
+        "ORDER BY visit ASC"
+    )
+    def _origin_visit_get_pagination_asc_no_limit(
         self, origin_url: str, last_visit: int, *, statement
     ) -> ResultSet:
         return self._execute_with_retries(statement, [origin_url, last_visit])
 
     @_prepared_statement(
-        "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? LIMIT ?"
+        "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? "
+        "ORDER BY visit ASC "
+        "LIMIT ?"
     )
-    def _origin_visit_get_limit(
+    def _origin_visit_get_pagination_asc_limit(
         self, origin_url: str, last_visit: int, limit: int, *, statement
     ) -> ResultSet:
         return self._execute_with_retries(statement, [origin_url, last_visit, limit])
 
+    @_prepared_statement(
+        "SELECT * FROM origin_visit WHERE origin = ? AND visit < ? "
+        "ORDER BY visit DESC"
+    )
+    def _origin_visit_get_pagination_desc_no_limit(
+        self, origin_url: str, last_visit: int, *, statement
+    ) -> ResultSet:
+        return self._execute_with_retries(statement, [origin_url, last_visit])
+
+    @_prepared_statement(
+        "SELECT * FROM origin_visit WHERE origin = ? AND visit < ? "
+        "ORDER BY visit DESC "
+        "LIMIT ?"
+    )
+    def _origin_visit_get_pagination_desc_limit(
+        self, origin_url: str, last_visit: int, limit: int, *, statement
+    ) -> ResultSet:
+        return self._execute_with_retries(statement, [origin_url, last_visit, limit])
+
+    @_prepared_statement(
+        "SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit ASC LIMIT ?"
+    )
+    def _origin_visit_get_no_pagination_asc_limit(
+        self, origin_url: str, limit: int, *, statement
+    ) -> ResultSet:
+        return self._execute_with_retries(statement, [origin_url, limit])
+
+    @_prepared_statement(
+        "SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit ASC "
+    )
+    def _origin_visit_get_no_pagination_asc_no_limit(
+        self, origin_url: str, *, statement
+    ) -> ResultSet:
+        return self._execute_with_retries(statement, [origin_url])
+
+    @_prepared_statement(
+        "SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit DESC"
+    )
+    def _origin_visit_get_no_pagination_desc_no_limit(
+        self, origin_url: str, *, statement
+    ) -> ResultSet:
+        return self._execute_with_retries(statement, [origin_url])
+
+    @_prepared_statement(
+        "SELECT * FROM origin_visit WHERE origin = ? ORDER BY visit DESC LIMIT ?"
+    )
+    def _origin_visit_get_no_pagination_desc_limit(
+        self, origin_url: str, limit: int, *, statement
+    ) -> ResultSet:
+        return self._execute_with_retries(statement, [origin_url, limit])
+
     def origin_visit_get(
-        self, origin_url: str, last_visit: Optional[int], limit: Optional[int]
+        self,
+        origin_url: str,
+        last_visit: Optional[int],
+        limit: Optional[int],
+        order: str = "asc",
     ) -> ResultSet:
-        if last_visit is None:
-            last_visit = -1
+        order = order.lower()
+        assert order in ["asc", "desc"]
 
-        if limit is None:
-            return self._origin_visit_get_no_limit(origin_url, last_visit)
+        args: List[Any] = [origin_url]
+
+        if last_visit is not None:
+            page_name = "pagination"
+            args.append(last_visit)
+        else:
+            page_name = "no_pagination"
+
+        if limit is not None:
+            limit_name = "limit"
+            args.append(limit)
         else:
-            return self._origin_visit_get_limit(origin_url, last_visit, limit)
+            limit_name = "no_limit"
+
+        method_name = f"_origin_visit_get_{page_name}_{order}_{limit_name}"
+        origin_visit_get_method = getattr(self, method_name)
+        return origin_visit_get_method(*args)
 
     @_prepared_insert_statement("origin_visit", _origin_visit_keys)
     def origin_visit_add_one(self, visit: OriginVisit, *, statement) -> None:
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -888,9 +888,14 @@
         }
 
     def origin_visit_get(
-        self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None
+        self,
+        origin: str,
+        last_visit: Optional[int] = None,
+        limit: Optional[int] = None,
+        order: str = "asc",
     ) -> Iterable[Dict[str, Any]]:
-        rows = self._cql_runner.origin_visit_get(origin, last_visit, limit)
+        rows = self._cql_runner.origin_visit_get(origin, last_visit, limit, order)
+
         for row in rows:
             visit = self._format_origin_visit_row(row)
             yield self._origin_visit_apply_last_status(visit)
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -578,7 +578,9 @@
         row = cur.fetchone()
         return self._make_origin_visit_status(row)
 
-    def origin_visit_get_all(self, origin_id, last_visit=None, limit=None, cur=None):
+    def origin_visit_get_all(
+        self, origin_id, last_visit=None, limit=None, order="asc", cur=None
+    ):
         """Retrieve all visits for origin with id origin_id.
 
         Args:
@@ -589,29 +591,38 @@
 
         """
         cur = self._cursor(cur)
+        order = order.lower()
+        assert order in ["asc", "desc"]
 
-        if last_visit:
-            extra_condition = "and ov.visit > %s"
-            args = (origin_id, last_visit, limit)
-        else:
-            extra_condition = ""
-            args = (origin_id, limit)
+        query_parts = [
+            "SELECT DISTINCT ON (ov.visit) %s "
+            % ", ".join(self.origin_visit_select_cols),
+            "FROM origin_visit ov",
+            "INNER JOIN origin o ON o.id = ov.origin",
+            "INNER JOIN origin_visit_status ovs",
+            "ON ov.origin = ovs.origin AND ov.visit = ovs.visit",
+        ]
+        query_parts.append("WHERE o.url = %s")
+        query_params: List[Any] = [origin_id]
 
-        query = """\
-        SELECT DISTINCT ON (ov.visit) %s
-        FROM origin_visit ov
-        INNER JOIN origin o ON o.id = ov.origin
-        INNER JOIN origin_visit_status ovs
-        ON ov.origin = ovs.origin AND ov.visit = ovs.visit
-        WHERE o.url=%%s %s
-        ORDER BY ov.visit ASC, ovs.date DESC
-        LIMIT %%s""" % (
-            ", ".join(self.origin_visit_select_cols),
-            extra_condition,
-        )
+        if last_visit is not None:
+            op_comparison = ">" if order == "asc" else "<"
+            query_parts.append(f"and ov.visit {op_comparison} %s")
+            query_params.append(last_visit)
 
-        cur.execute(query, args)
+        # DISTINCT ON keeps the first row of each ov.visit group, so the
+        # secondary "ovs.date DESC" key selects the latest visit status.
+        if order == "asc":
+            query_parts.append("ORDER BY ov.visit ASC, ovs.date DESC")
+        else:
+            query_parts.append("ORDER BY ov.visit DESC, ovs.date DESC")
+        if limit is not None:
+            query_parts.append("LIMIT %s")
+            query_params.append(limit)
+
+        query = "\n".join(query_parts)
+        cur.execute(query, tuple(query_params))
         yield from cur
 
     def origin_visit_get(self, origin_id, visit_id, cur=None):
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -865,14 +865,23 @@
         )
 
     def origin_visit_get(
-        self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None
+        self,
+        origin: str,
+        last_visit: Optional[int] = None,
+        limit: Optional[int] = None,
+        order: str = "asc",
     ) -> Iterable[Dict[str, Any]]:
-
+        order = order.lower()
+        assert order in ["asc", "desc"]
         origin_url = self._get_origin_url(origin)
         if origin_url in self._origin_visits:
             visits = self._origin_visits[origin_url]
+            visits = sorted(visits, key=lambda v: v.visit, reverse=(order == "desc"))
             if last_visit is not None:
-                visits = visits[last_visit:]
+                if order == "asc":
+                    visits = [v for v in visits if v.visit > last_visit]
+                else:
+                    visits = [v for v in visits if v.visit < last_visit]
             if limit is not None:
                 visits = visits[:limit]
             for visit in visits:
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -781,7 +781,11 @@
 
     @remote_api_endpoint("origin/visit/get")
     def origin_visit_get(
-        self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None
+        self,
+        origin: str,
+        last_visit: Optional[int] = None,
+        limit: Optional[int] = None,
+        order: str = "asc",
     ) -> Iterable[Dict[str, Any]]:
         """Retrieve all the origin's visit's information.
 
@@ -791,6 +795,8 @@
                 Default to None
             limit: Number of results to return from the last visit.
                 Default to None
+            order: Order on visit id fields to list origin visits, either "asc"
+                or "desc" (case-insensitive). Default to "asc"
 
         Yields:
             List of visits.
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -929,11 +929,14 @@
         origin: str,
         last_visit: Optional[int] = None,
         limit: Optional[int] = None,
+        order: str = "asc",
         db=None,
         cur=None,
     ) -> Iterable[Dict[str, Any]]:
+        order = order.lower()
+        assert order in ["asc", "desc"]
         lines = db.origin_visit_get_all(
-            origin, last_visit=last_visit, limit=limit, cur=cur
+            origin, last_visit=last_visit, limit=limit, order=order, cur=cur
         )
         for line in lines:
             visit = dict(zip(db.origin_visit_get_cols, line))
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -1451,16 +1451,55 @@
             for v in visits
         ]
 
+        # order asc, no pagination, no limit
         all_visits = list(swh_storage.origin_visit_get(origin.url))
         assert all_visits == [ov1, ov2, ov3]
 
+        # order asc, no pagination, limit
         all_visits2 = list(swh_storage.origin_visit_get(origin.url, limit=2))
         assert all_visits2 == [ov1, ov2]
 
+        # order asc, pagination, no limit
         all_visits3 = list(
-            swh_storage.origin_visit_get(origin.url, last_visit=ov1["visit"], limit=1)
+            swh_storage.origin_visit_get(origin.url, last_visit=ov1["visit"])
         )
-        assert all_visits3 == [ov2]
+        assert all_visits3 == [ov2, ov3]
+
+        # order asc, pagination, limit
+        all_visits4 = list(
+            swh_storage.origin_visit_get(origin.url, last_visit=ov2["visit"], limit=1)
+        )
+        assert all_visits4 == [ov3]
+
+        # order desc, no pagination, no limit
+        all_visits5 = list(swh_storage.origin_visit_get(origin.url, order="desc"))
+        assert all_visits5 == [ov3, ov2, ov1]
+
+        # order desc, no pagination, limit
+        all_visits6 = list(
+            swh_storage.origin_visit_get(origin.url, limit=2, order="desc")
+        )
+        assert all_visits6 == [ov3, ov2]
+
+        # order desc, pagination, no limit
+        all_visits7 = list(
+            swh_storage.origin_visit_get(
+                origin.url, last_visit=ov3["visit"], order="desc"
+            )
+        )
+        assert all_visits7 == [ov2, ov1]
+
+        # order desc, pagination, limit
+        all_visits8 = list(
+            swh_storage.origin_visit_get(
+                origin.url, last_visit=ov3["visit"], order="desc", limit=1
+            )
+        )
+        assert all_visits8 == [ov2]
+
+        # order is case-insensitive on every backend
+        all_visits9 = list(swh_storage.origin_visit_get(origin.url, order="DESC"))
+        assert all_visits9 == [ov3, ov2, ov1]
 
     def test_origin_visit_get__unknown_origin(self, swh_storage):
         assert [] == list(swh_storage.origin_visit_get("foo"))