Page Menu | Home | Software Heritage

D3359.id11926.diff
No One | Temporary

D3359.id11926.diff

diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -657,30 +657,103 @@
"date",
]
- @_prepared_statement("SELECT * FROM origin_visit WHERE origin = ? AND visit > ?")
- def _origin_visit_get_no_limit(
+ @_prepared_statement(
+ "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? "
+ "ORDER BY visit ASC"
+ )
+ def _origin_visit_get_pagination_asc_no_limit(
self, origin_url: str, last_visit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin_url, last_visit])
@_prepared_statement(
- "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? LIMIT ?"
+ "SELECT * FROM origin_visit WHERE origin = ? AND visit > ? "
+ "ORDER BY visit ASC "
+ "LIMIT ?"
)
- def _origin_visit_get_limit(
+ def _origin_visit_get_pagination_asc_limit(
self, origin_url: str, last_visit: int, limit: int, *, statement
) -> ResultSet:
return self._execute_with_retries(statement, [origin_url, last_visit, limit])
+ @_prepared_statement(
+ "SELECT * FROM origin_visit WHERE origin = ? AND visit < ? "
+ "ORDER BY visit DESC"
+ )
+ def _origin_visit_get_pagination_desc_no_limit(
+ self, origin_url: str, last_visit: int, *, statement
+ ) -> ResultSet:
+ return self._execute_with_retries(statement, [origin_url, last_visit])
+
+ @_prepared_statement(
+ "SELECT * FROM origin_visit WHERE origin = ? AND visit < ? "
+ "ORDER BY visit DESC "
+ "LIMIT ?"
+ )
+ def _origin_visit_get_pagination_desc_limit(
+ self, origin_url: str, last_visit: int, limit: int, *, statement
+ ) -> ResultSet:
+ return self._execute_with_retries(statement, [origin_url, last_visit, limit])
+
+ @_prepared_statement(
+ "SELECT * FROM origin_visit WHERE origin = ? " "ORDER BY visit ASC " "LIMIT ?"
+ )
+ def _origin_visit_get_no_pagination_asc_limit(
+ self, origin_url: str, limit: int, *, statement
+ ) -> ResultSet:
+ return self._execute_with_retries(statement, [origin_url, limit])
+
+ @_prepared_statement(
+ "SELECT * FROM origin_visit WHERE origin = ? " "ORDER BY visit ASC "
+ )
+ def _origin_visit_get_no_pagination_asc_no_limit(
+ self, origin_url: str, *, statement
+ ) -> ResultSet:
+ return self._execute_with_retries(statement, [origin_url])
+
+ @_prepared_statement(
+ "SELECT * FROM origin_visit WHERE origin = ? " "ORDER BY visit DESC"
+ )
+ def _origin_visit_get_no_pagination_desc_no_limit(
+ self, origin_url: str, *, statement
+ ) -> ResultSet:
+ return self._execute_with_retries(statement, [origin_url])
+
+ @_prepared_statement(
+ "SELECT * FROM origin_visit WHERE origin = ? " "ORDER BY visit DESC " "LIMIT ?"
+ )
+ def _origin_visit_get_no_pagination_desc_limit(
+ self, origin_url: str, limit: int, *, statement
+ ) -> ResultSet:
+ return self._execute_with_retries(statement, [origin_url, limit])
+
def origin_visit_get(
- self, origin_url: str, last_visit: Optional[int], limit: Optional[int]
+ self,
+ origin_url: str,
+ last_visit: Optional[int],
+ limit: Optional[int],
+ order: str = "asc",
) -> ResultSet:
- if last_visit is None:
- last_visit = -1
+ order = order.lower()
+ assert order in ["asc", "desc"]
- if limit is None:
- return self._origin_visit_get_no_limit(origin_url, last_visit)
+ args: List[Any] = [origin_url]
+
+ if last_visit is not None:
+ page_name = "pagination"
+ args.append(last_visit)
+ else:
+ page_name = "no_pagination"
+
+ if limit is not None:
+ limit_name = "limit"
+ args.append(limit)
else:
- return self._origin_visit_get_limit(origin_url, last_visit, limit)
+ limit_name = "no_limit"
+
+ method_name = f"_origin_visit_get_{page_name}_{order}_{limit_name}"
+ origin_visit_get_method = getattr(self, method_name)
+ return origin_visit_get_method(*args)
@_prepared_insert_statement("origin_visit", _origin_visit_keys)
def origin_visit_add_one(self, visit: OriginVisit, *, statement) -> None:
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -888,9 +888,14 @@
}
def origin_visit_get(
- self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None
+ self,
+ origin: str,
+ last_visit: Optional[int] = None,
+ limit: Optional[int] = None,
+ order: str = "asc",
) -> Iterable[Dict[str, Any]]:
- rows = self._cql_runner.origin_visit_get(origin, last_visit, limit)
+ rows = self._cql_runner.origin_visit_get(origin, last_visit, limit, order)
+
for row in rows:
visit = self._format_origin_visit_row(row)
yield self._origin_visit_apply_last_status(visit)
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -578,7 +578,9 @@
row = cur.fetchone()
return self._make_origin_visit_status(row)
- def origin_visit_get_all(self, origin_id, last_visit=None, limit=None, cur=None):
+ def origin_visit_get_all(
+ self, origin_id, last_visit=None, order="asc", limit=None, cur=None
+ ):
"""Retrieve all visits for origin with id origin_id.
Args:
@@ -589,29 +591,32 @@
"""
cur = self._cursor(cur)
+ assert order.lower() in ["asc", "desc"]
- if last_visit:
- extra_condition = "and ov.visit > %s"
- args = (origin_id, last_visit, limit)
- else:
- extra_condition = ""
- args = (origin_id, limit)
+ query_parts = [
+ "SELECT DISTINCT ON (ov.visit) %s "
+ % ", ".join(self.origin_visit_select_cols),
+ "FROM origin_visit ov",
+ "INNER JOIN origin o ON o.id = ov.origin",
+ "INNER JOIN origin_visit_status ovs",
+ "ON ov.origin = ovs.origin AND ov.visit = ovs.visit",
+ ]
+ query_parts.append("WHERE o.url = %s")
+ query_params: List[Any] = [origin_id]
- query = """\
- SELECT DISTINCT ON (ov.visit) %s
- FROM origin_visit ov
- INNER JOIN origin o ON o.id = ov.origin
- INNER JOIN origin_visit_status ovs
- ON ov.origin = ovs.origin AND ov.visit = ovs.visit
- WHERE o.url=%%s %s
- ORDER BY ov.visit ASC, ovs.date DESC
- LIMIT %%s""" % (
- ", ".join(self.origin_visit_select_cols),
- extra_condition,
- )
+ if last_visit is not None:
+ op_comparison = ">" if order == "asc" else "<"
+ query_parts.append(f"and ov.visit {op_comparison} %s")
+ query_params.append(last_visit)
+
+ query_parts.append(f"ORDER BY ov.visit {order}, ovs.date DESC")
- cur.execute(query, args)
+ if limit is not None:
+ query_parts.append("LIMIT %s")
+ query_params.append(limit)
+ query = "\n".join(query_parts)
+ cur.execute(query, tuple(query_params))
yield from cur
def origin_visit_get(self, origin_id, visit_id, cur=None):
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -865,14 +865,23 @@
)
def origin_visit_get(
- self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None
+ self,
+ origin: str,
+ last_visit: Optional[int] = None,
+ limit: Optional[int] = None,
+ order: str = "asc",
) -> Iterable[Dict[str, Any]]:
-
+ order = order.lower()
+ assert order in ["asc", "desc"]
origin_url = self._get_origin_url(origin)
if origin_url in self._origin_visits:
visits = self._origin_visits[origin_url]
+ visits = sorted(visits, key=lambda v: v.visit, reverse=(order == "desc"))
if last_visit is not None:
- visits = visits[last_visit:]
+ if order == "asc":
+ visits = [v for v in visits if v.visit > last_visit]
+ else:
+ visits = [v for v in visits if v.visit < last_visit]
if limit is not None:
visits = visits[:limit]
for visit in visits:
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -781,7 +781,11 @@
@remote_api_endpoint("origin/visit/get")
def origin_visit_get(
- self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None
+ self,
+ origin: str,
+ last_visit: Optional[int] = None,
+ limit: Optional[int] = None,
+ order: str = "asc",
) -> Iterable[Dict[str, Any]]:
"""Retrieve all the origin's visit's information.
@@ -791,6 +795,7 @@
Default to None
limit: Number of results to return from the last visit.
Default to None
+ order: Sort order of the returned visits by visit id, "asc" or "desc" (defaults to "asc")
Yields:
List of visits.
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -929,11 +929,13 @@
origin: str,
last_visit: Optional[int] = None,
limit: Optional[int] = None,
+ order: str = "asc",
db=None,
cur=None,
) -> Iterable[Dict[str, Any]]:
+ assert order in ["asc", "desc"]
lines = db.origin_visit_get_all(
- origin, last_visit=last_visit, limit=limit, cur=cur
+ origin, last_visit=last_visit, limit=limit, order=order, cur=cur
)
for line in lines:
visit = dict(zip(db.origin_visit_get_cols, line))
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -1451,16 +1451,51 @@
for v in visits
]
+ # order asc, no pagination, no limit
all_visits = list(swh_storage.origin_visit_get(origin.url))
assert all_visits == [ov1, ov2, ov3]
+ # order asc, no pagination, limit
all_visits2 = list(swh_storage.origin_visit_get(origin.url, limit=2))
assert all_visits2 == [ov1, ov2]
+ # order asc, pagination, no limit
all_visits3 = list(
- swh_storage.origin_visit_get(origin.url, last_visit=ov1["visit"], limit=1)
+ swh_storage.origin_visit_get(origin.url, last_visit=ov1["visit"])
)
- assert all_visits3 == [ov2]
+ assert all_visits3 == [ov2, ov3]
+
+ # order asc, pagination, limit
+ all_visits4 = list(
+ swh_storage.origin_visit_get(origin.url, last_visit=ov2["visit"], limit=1)
+ )
+ assert all_visits4 == [ov3]
+
+ # order desc, no pagination, no limit
+ all_visits5 = list(swh_storage.origin_visit_get(origin.url, order="desc"))
+ assert all_visits5 == [ov3, ov2, ov1]
+
+ # order desc, no pagination, limit
+ all_visits6 = list(
+ swh_storage.origin_visit_get(origin.url, limit=2, order="desc")
+ )
+ assert all_visits6 == [ov3, ov2]
+
+ # order desc, pagination, no limit
+ all_visits7 = list(
+ swh_storage.origin_visit_get(
+ origin.url, last_visit=ov3["visit"], order="desc"
+ )
+ )
+ assert all_visits7 == [ov2, ov1]
+
+ # order desc, pagination, limit
+ all_visits8 = list(
+ swh_storage.origin_visit_get(
+ origin.url, last_visit=ov3["visit"], order="desc", limit=1
+ )
+ )
+ assert all_visits8 == [ov2]
def test_origin_visit_get__unknown_origin(self, swh_storage):
assert [] == list(swh_storage.origin_visit_get("foo"))

File Metadata

Mime Type
text/plain
Expires
Sun, Aug 24, 4:53 PM (1 w, 2 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3221552

Event Timeline