diff --git a/swh/storage/cassandra/converters.py b/swh/storage/cassandra/converters.py
--- a/swh/storage/cassandra/converters.py
+++ b/swh/storage/cassandra/converters.py
@@ -3,16 +3,19 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 from copy import deepcopy
+import datetime
 import json
 from typing import Any, Dict, Tuple
 
 import attr
+from cassandra.cluster import ResultSet
 
 from swh.model.model import (
-    RevisionType,
     ObjectType,
+    OriginVisitStatus,
     Revision,
+    RevisionType,
     Release,
     Sha1Git,
 )
@@ -71,3 +74,18 @@
     for algo in DEFAULT_ALGORITHMS:
         hashes[algo] = getattr(row, algo)
     return hashes
+
+
+def row_to_visit_status(row: ResultSet) -> OriginVisitStatus:
+    """Convert a row representing a visit_status into an OriginVisitStatus
+    model object (timezone-aware date, metadata decoded from JSON).
+
+    """
+    return OriginVisitStatus.from_dict(
+        {
+            **row._asdict(),
+            "origin": row.origin,
+            "date": row.date.replace(tzinfo=datetime.timezone.utc),
+            "metadata": (json.loads(row.metadata) if row.metadata else None),
+        }
+    )
diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -3,7 +3,6 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-import datetime
 import functools
 import json
 import logging
@@ -711,38 +710,34 @@
             statement, [getattr(visit_update, key) for key in keys[:-1]] + [metadata]
         )
 
-    def _format_origin_visit_status_row(
-        self, visit_status: ResultSet
-    ) -> Dict[str, Any]:
-        """Format a row visit_status into an origin_visit_status dict
+    def origin_visit_status_get_latest(self, origin: str, visit: int) -> Optional[Row]:
+        """Given an origin visit id, return its latest origin_visit_status
 
-        """
-        return {
-            **visit_status._asdict(),
-            "origin": visit_status.origin,
-            "date": visit_status.date.replace(tzinfo=datetime.timezone.utc),
-            "metadata": (
-                json.loads(visit_status.metadata) if visit_status.metadata else None
-            ),
-        }
+        """
+        rows = self.origin_visit_status_get(origin, visit)
+        return rows[0] if rows else None
 
     @_prepared_statement(
         "SELECT * FROM origin_visit_status "
         "WHERE origin = ? AND visit = ? "
-        "ORDER BY date DESC "
-        "LIMIT 1"
+        "ORDER BY date DESC"
     )
-    def origin_visit_status_get_latest(
-        self, origin: str, visit: int, *, statement
-    ) -> Optional[Dict[str, Any]]:
-        """Given an origin visit id, return its latest origin_visit_status
+    def origin_visit_status_get(
+        self,
+        origin: str,
+        visit: int,
+        allowed_statuses: Optional[List[str]] = None,
+        require_snapshot: bool = False,
+        *,
+        statement,
+    ) -> List[Row]:
+        """Return all origin visit statuses for a given visit, most recent first.
+
+        ``allowed_statuses`` and ``require_snapshot`` are not applied by this query
+        (they are ignored here); callers must filter the returned rows themselves.
 
         """
-        rows = list(self._execute_with_retries(statement, [origin, visit]))
-        if rows:
-            return self._format_origin_visit_status_row(rows[0])
-        else:
-            return None
+        return list(self._execute_with_retries(statement, [origin, visit]))
 
     @_prepared_statement("SELECT * FROM origin_visit WHERE origin = ? AND visit = ?")
     def origin_visit_get_one(
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -37,6 +37,7 @@
     revision_from_db,
     release_to_db,
     release_from_db,
+    row_to_visit_status,
 )
 from .cql import CqlRunner
 from .schema import HASH_ALGORITHMS
@@ -838,7 +839,7 @@
             self._origin_visit_status_add(visit_status)
 
     def _origin_visit_merge(
-        self, visit: Dict[str, Any], visit_status: Dict[str, Any]
+        self, visit: Dict[str, Any], visit_status: OriginVisitStatus,
     ) -> Dict[str, Any]:
         """Merge origin_visit and visit_status together.
 
@@ -848,7 +849,7 @@
             # default to the values in visit
             **visit,
             # override with the last update
-            **visit_status,
+            **visit_status.to_dict(),
             # visit['origin'] is the URL (via a join), while
             # visit_status['origin'] is only an id.
             "origin": visit["origin"],
@@ -862,11 +863,11 @@
         Then merge it with the visit and return it.
 
         """
-        visit_status = self._cql_runner.origin_visit_status_get_latest(
+        row = self._cql_runner.origin_visit_status_get_latest(
             visit["origin"], visit["visit"]
         )
-        assert visit_status is not None
-        return self._origin_visit_merge(visit, visit_status)
+        assert row is not None
+        return self._origin_visit_merge(visit, row_to_visit_status(row))
 
     def _origin_visit_get_updated(self, origin: str, visit_id: int) -> Dict[str, Any]:
         """Retrieve origin visit and latest origin visit status and merge them
@@ -948,6 +949,25 @@
 
         return latest_visit
 
+    def origin_visit_status_get_latest(
+        self,
+        origin_url: str,
+        visit: int,
+        allowed_statuses: Optional[List[str]] = None,
+        require_snapshot: bool = False,
+    ) -> Optional[OriginVisitStatus]:
+        rows = self._cql_runner.origin_visit_status_get(
+            origin_url, visit, allowed_statuses, require_snapshot
+        )
+        # filtering is done python side as we cannot do it server side
+        if allowed_statuses:
+            rows = [row for row in rows if row.status in allowed_statuses]
+        if require_snapshot:
+            rows = [row for row in rows if row.snapshot is not None]
+        if not rows:
+            return None
+        return row_to_visit_status(rows[0])
+
     def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]:
         back_in_the_day = now() - datetime.timedelta(weeks=12)  # 3 months back
 
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -525,7 +525,18 @@
         "ovs.snapshot",
     ]
 
-    def _make_origin_visit_status(self, row: Tuple[Any]) -> Optional[Dict[str, Any]]:
+    origin_visit_status_select_cols = [
+        "o.url AS origin",
+        "ovs.visit",
+        "ovs.date",
+        "ovs.status",
+        "ovs.snapshot",
+        "ovs.metadata",
+    ]
+
+    def _make_origin_visit_status(
+        self, row: Optional[Tuple[Any]]
+    ) -> Optional[Dict[str, Any]]:
         """Make an origin_visit_status dict out of a row
 
         """
@@ -534,21 +545,39 @@
         return dict(zip(self.origin_visit_status_cols, row))
 
     def origin_visit_status_get_latest(
-        self, origin: str, visit: int, cur=None
+        self,
+        origin_url: str,
+        visit: int,
+        allowed_statuses: Optional[List[str]] = None,
+        require_snapshot: bool = False,
+        cur=None,
     ) -> Optional[Dict[str, Any]]:
         """Given an origin visit id, return its latest origin_visit_status
 
         """
-        cols = self.origin_visit_status_cols
         cur = self._cursor(cur)
-        cur.execute(
-            f"SELECT {', '.join(cols)} "
-            f"FROM origin_visit_status ovs "
-            f"INNER JOIN origin o on o.id=ovs.origin "
-            f"WHERE o.url=%s AND ovs.visit=%s"
-            f"ORDER BY ovs.date DESC LIMIT 1",
-            (origin, visit),
-        )
+
+        query_parts = [
+            "SELECT %s" % ", ".join(self.origin_visit_status_select_cols),
+            "FROM origin_visit_status ovs",
+            "INNER JOIN origin o ON o.id = ovs.origin",
+        ]
+        query_parts.append("WHERE o.url = %s")
+        query_params: List[Any] = [origin_url]
+        query_parts.append("AND ovs.visit = %s")
+        query_params.append(visit)
+
+        if require_snapshot:
+            query_parts.append("AND ovs.snapshot is not null")
+
+        if allowed_statuses:
+            query_parts.append("AND ovs.status IN %s")
+            query_params.append(tuple(allowed_statuses))
+
+        query_parts.append("ORDER BY ovs.date DESC LIMIT 1")
+        query = "\n".join(query_parts)
+
+        cur.execute(query, tuple(query_params))
         row = cur.fetchone()
         return self._make_origin_visit_status(row)
 
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -942,6 +942,30 @@
             return None
         return visit.to_dict()
 
+    def origin_visit_status_get_latest(
+        self,
+        origin_url: str,
+        visit: int,
+        allowed_statuses: Optional[List[str]] = None,
+        require_snapshot: bool = False,
+    ) -> Optional[OriginVisitStatus]:
+        ori = self._origins.get(origin_url)
+        if not ori:
+            return None
+
+        visit_key = (origin_url, visit)
+        statuses = self._origin_visit_statuses.get(visit_key)
+        if not statuses:
+            return None
+
+        if allowed_statuses:
+            statuses = [ovs for ovs in statuses if ovs.status in allowed_statuses]
+        if require_snapshot:
+            statuses = [ovs for ovs in statuses if ovs.snapshot is not None]
+
+        visit_status = max(statuses, key=lambda v: (v.date, v.visit), default=None)
+        return visit_status
+
     def _select_random_origin_visit_by_type(self, type: str) -> str:
         while True:
             url = random.choice(list(self._origin_visits.keys()))
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -871,7 +871,7 @@
     ) -> Optional[Dict[str, Any]]:
         """Get the latest origin visit for the given origin, optionally
         looking only for those with one of the given allowed_statuses
-        or for those with a known snapshot.
+        or for those with a snapshot.
 
         Args:
             origin: origin URL
@@ -896,6 +896,34 @@
         """
         ...
 
+    @remote_api_endpoint("origin/visit_status/get_latest")
+    def origin_visit_status_get_latest(
+        self,
+        origin_url: str,
+        visit: int,
+        allowed_statuses: Optional[List[str]] = None,
+        require_snapshot: bool = False,
+    ) -> Optional[OriginVisitStatus]:
+        """Get the latest origin visit status for the given origin visit, optionally
+        looking only for those with one of the given allowed_statuses or with a
+        snapshot.
+
+        Args:
+            origin_url: origin URL
+            visit: identifier of the visit whose statuses are considered
+            allowed_statuses: list of visit statuses considered to find the latest
+                visit. Possible values are {created, ongoing, partial, full}. For
+                instance, ``allowed_statuses=['full']`` will only consider visits that
+                have successfully run to completion.
+            require_snapshot: If True, only a visit status with a snapshot
+                will be returned.
+
+        Returns:
+            The latest OriginVisitStatus matching the criteria, or None
+
+        """
+        ...
+
     @remote_api_endpoint("origin/visit/get_random")
     def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]:
         """Randomly select one successful origin visit with
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -878,6 +878,24 @@
         for visit_status in visit_statuses:
             self._origin_visit_status_add(visit_status, db, cur)
 
+    @timed
+    @db_transaction()
+    def origin_visit_status_get_latest(
+        self,
+        origin_url: str,
+        visit: int,
+        allowed_statuses: Optional[List[str]] = None,
+        require_snapshot: bool = False,
+        db=None,
+        cur=None,
+    ) -> Optional[OriginVisitStatus]:
+        row = db.origin_visit_status_get_latest(
+            origin_url, visit, allowed_statuses, require_snapshot, cur=cur
+        )
+        if not row:
+            return None
+        return OriginVisitStatus.from_dict(row)
+
     def _origin_visit_get_updated(
         self, origin: str, visit_id: int, db, cur
     ) -> Optional[Dict[str, Any]]:
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -99,6 +99,17 @@
         assert actual_list == expected_list, k
 
 
+def round_to_milliseconds(date):
+    """Round datetime to milliseconds before insertion, so equality doesn't fail
+    after a round-trip through a DB (eg. Cassandra)
+
+    """
+    # round() may yield 1_000_000, which replace() rejects: carry it via timedelta
+    return date.replace(microsecond=0) + datetime.timedelta(
+        microseconds=round(date.microsecond, -3)
+    )
+
+
 class LazyContent(Content):
     def with_data(self):
         return Content.from_dict({**self.to_dict(), "data": data.cont["data"]})
@@ -1627,12 +1638,8 @@
         date_visit = now()
         date_visit2 = date_visit + datetime.timedelta(minutes=1)
 
-        # Round to milliseconds before insertion, so equality doesn't fail
-        # after a round-trip through a DB (eg. Cassandra)
-        date_visit = date_visit.replace(microsecond=round(date_visit.microsecond, -3))
-        date_visit2 = date_visit2.replace(
-            microsecond=round(date_visit2.microsecond, -3)
-        )
+        date_visit = round_to_milliseconds(date_visit)
+        date_visit2 = round_to_milliseconds(date_visit2)
 
         visit1 = OriginVisit(
             origin=origin1.url,
@@ -2136,6 +2143,122 @@
             "snapshot": data.complete_snapshot["id"],
         } == swh_storage.origin_visit_get_latest(origin_url, require_snapshot=True)
 
+    def test_origin_visit_status_get_latest(self, swh_storage):
+        origin1 = Origin.from_dict(data.origin)
+        swh_storage.origin_add_one(data.origin)
+
+        # to have some reference visits
+
+        ov1, ov2 = swh_storage.origin_visit_add(
+            [
+                OriginVisit(
+                    origin=origin1.url,
+                    date=data.date_visit1,
+                    type=data.type_visit1,
+                    status="ongoing",
+                    snapshot=None,
+                ),
+                OriginVisit(
+                    origin=origin1.url,
+                    date=data.date_visit2,
+                    type=data.type_visit2,
+                    status="ongoing",
+                    snapshot=None,
+                ),
+            ]
+        )
+
+        snapshot = Snapshot.from_dict(data.complete_snapshot)
+        swh_storage.snapshot_add([snapshot])
+
+        date_now = now()
+        date_now = round_to_milliseconds(date_now)
+        assert data.date_visit1 < data.date_visit2
+        assert data.date_visit2 < date_now
+
+        ovs1 = OriginVisitStatus(
+            origin=origin1.url,
+            visit=ov1.visit,
+            date=data.date_visit1,
+            status="partial",
+            snapshot=None,
+        )
+        ovs2 = OriginVisitStatus(
+            origin=origin1.url,
+            visit=ov1.visit,
+            date=data.date_visit2,
+            status="ongoing",
+            snapshot=None,
+        )
+        ovs3 = OriginVisitStatus(
+            origin=origin1.url,
+            visit=ov2.visit,
+            date=data.date_visit2,
+            status="ongoing",
+            snapshot=None,
+        )
+        ovs4 = OriginVisitStatus(
+            origin=origin1.url,
+            visit=ov2.visit,
+            date=date_now,
+            status="full",
+            snapshot=snapshot.id,
+            metadata={"something": "wicked"},
+        )
+
+        swh_storage.origin_visit_status_add([ovs1, ovs2, ovs3, ovs4])
+
+        # unknown origin so no result
+        actual_origin_visit = swh_storage.origin_visit_status_get_latest(
+            "unknown-origin", ov1.visit
+        )
+        assert actual_origin_visit is None
+
+        # unknown visit so no result
+        actual_origin_visit = swh_storage.origin_visit_status_get_latest(
+            ov1.origin, ov1.visit + 10
+        )
+        assert actual_origin_visit is None
+
+        # Two visit statuses, both with no snapshot, take the most recent
+        actual_origin_visit2 = swh_storage.origin_visit_status_get_latest(
+            origin1.url, ov1.visit
+        )
+        assert isinstance(actual_origin_visit2, OriginVisitStatus)
+        assert actual_origin_visit2 == ovs2
+        assert ovs2.origin == origin1.url
+        assert ovs2.visit == ov1.visit
+
+        actual_origin_visit = swh_storage.origin_visit_status_get_latest(
+            origin1.url, ov1.visit, require_snapshot=True
+        )
+        # there is no visit status with a snapshot yet for that visit
+        assert actual_origin_visit is None
+
+        actual_origin_visit2 = swh_storage.origin_visit_status_get_latest(
+            origin1.url, ov1.visit, allowed_statuses=["partial", "ongoing"]
+        )
+        # both statuses are allowed, so the most recent one (ongoing) is elected
+        assert actual_origin_visit2 == ovs2
+        assert actual_origin_visit2.status == "ongoing"
+
+        actual_origin_visit4 = swh_storage.origin_visit_status_get_latest(
+            origin1.url, ov2.visit, require_snapshot=True
+        )
+        assert actual_origin_visit4 == ovs4
+        assert actual_origin_visit4.snapshot == snapshot.id
+
+        actual_origin_visit = swh_storage.origin_visit_status_get_latest(
+            origin1.url, ov2.visit, require_snapshot=True, allowed_statuses=["ongoing"]
+        )
+        # the only status with a snapshot is "full", which is not allowed here
+        assert actual_origin_visit is None
+
+        actual_origin_visit3 = swh_storage.origin_visit_status_get_latest(
+            origin1.url, ov2.visit, allowed_statuses=["ongoing"]
+        )
+        assert actual_origin_visit3 == ovs3
+
     def test_person_fullname_unicity(self, swh_storage):
         # given (person injection through revisions for example)
         revision = data.revision