diff --git a/swh/storage/cassandra/converters.py b/swh/storage/cassandra/converters.py
--- a/swh/storage/cassandra/converters.py
+++ b/swh/storage/cassandra/converters.py
@@ -14,6 +14,7 @@
 
 from swh.model.model import (
     ObjectType,
+    OriginVisit,
     OriginVisitStatus,
     Revision,
     RevisionType,
@@ -81,9 +82,20 @@
     return hashes
 
 
+def row_to_visit(row: ResultSet) -> OriginVisit:
+    """Format a row representing an origin_visit to an actual OriginVisit.
+
+    """
+    return OriginVisit(
+        origin=row.origin,
+        visit=row.visit,
+        date=row.date.replace(tzinfo=datetime.timezone.utc),
+        type=row.type,
+    )
+
+
 def row_to_visit_status(row: ResultSet) -> OriginVisitStatus:
-    """Format a row representing a visit_status to an actual dict representing an
-    OriginVisitStatus.
+    """Format a row representing a visit_status to an actual OriginVisitStatus.
 
     """
     return OriginVisitStatus.from_dict(
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -8,7 +8,7 @@
 import json
 import random
 import re
-from typing import Any, Dict, List, Iterable, Optional, Union
+from typing import Any, Dict, List, Iterable, Optional, Tuple, Union
 
 import attr
 
@@ -43,6 +43,7 @@
     revision_from_db,
     release_to_db,
     release_from_db,
+    row_to_visit,
     row_to_visit_status,
 )
 from .cql import CqlRunner
@@ -857,6 +858,15 @@
             "date": visit["date"],
         }
 
+    def _origin_visit_get_latest_status(self, visit: OriginVisit) -> OriginVisitStatus:
+        """Retrieve the latest visit status information for the origin visit object.
+
+        """
+        row = self._cql_runner.origin_visit_status_get_latest(visit.origin, visit.visit)
+        assert row is not None
+        visit_status = row_to_visit_status(row)
+        return attr.evolve(visit_status, origin=visit.origin)
+
     def _origin_visit_get_updated(self, origin: str, visit_id: int) -> Dict[str, Any]:
         """Retrieve origin visit and latest origin visit status and merge them
         into an origin visit.
@@ -963,7 +973,9 @@
             return None
         return row_to_visit_status(rows[0])
 
-    def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]:
+    def origin_visit_get_random(
+        self, type: str
+    ) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]:
         back_in_the_day = now() - datetime.timedelta(weeks=12)  # 3 months back
 
         # Random position to start iteration at
@@ -972,15 +984,11 @@
         # Iterator over all visits, ordered by token(origins) then visit_id
         rows = self._cql_runner.origin_visit_iter(start_token)
         for row in rows:
-            visit = self._format_origin_visit_row(row)
-            visit_status = self._origin_visit_apply_last_status(visit)
-            if (
-                visit_status["date"] > back_in_the_day
-                and visit_status["status"] == "full"
-            ):
-                return visit_status
-        else:
-            return None
+            visit = row_to_visit(row)
+            visit_status = self._origin_visit_get_latest_status(visit)
+            if visit.date > back_in_the_day and visit_status.status == "full":
+                return visit, visit_status
+        return None
 
     def stat_counters(self):
         rows = self._cql_runner.stat_counters()
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -862,8 +862,10 @@
         for visit_status in visit_statuses:
             self._origin_visit_status_add_one(visit_status)
 
-    def _origin_visit_get_updated(self, origin: str, visit_id: int) -> Dict[str, Any]:
-        """Merge origin visit and latest origin visit status
+    def _origin_visit_status_get_latest(
+        self, origin: str, visit_id: int
+    ) -> Tuple[OriginVisit, OriginVisitStatus]:
+        """Return a tuple of OriginVisit, latest associated OriginVisitStatus.
 
         """
         assert visit_id >= 1
@@ -872,6 +874,14 @@
         visit_key = (origin, visit_id)
         visit_update = max(self._origin_visit_statuses[visit_key], key=lambda v: v.date)
+        return visit, visit_update
+
+    def _origin_visit_get_updated(self, origin: str, visit_id: int) -> Dict[str, Any]:
+        """Merge origin visit and latest origin visit status
+
+        """
+        visit, visit_update = self._origin_visit_status_get_latest(origin, visit_id)
+        assert visit is not None and visit_update is not None
 
         return {
             # default to the values in visit
             **visit.to_dict(),
@@ -993,20 +1003,24 @@
             if random_origin_visits[0].type == type:
                 return url
 
-    def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]:
+    def origin_visit_get_random(
+        self, type: str
+    ) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]:
         url = self._select_random_origin_visit_by_type(type)
         random_origin_visits = copy.deepcopy(self._origin_visits[url])
         random_origin_visits.reverse()
         back_in_the_day = now() - timedelta(weeks=12)  # 3 months back
         # This should be enough for tests
         for visit in random_origin_visits:
-            updated_visit = self._origin_visit_get_updated(url, visit.visit)
-            assert updated_visit is not None
+            origin_visit, latest_visit_status = self._origin_visit_status_get_latest(
+                url, visit.visit
+            )
+            assert latest_visit_status is not None
             if (
-                updated_visit["date"] > back_in_the_day
-                and updated_visit["status"] == "full"
+                origin_visit.date > back_in_the_day
+                and latest_visit_status.status == "full"
             ):
-                return updated_visit
+                return origin_visit, latest_visit_status
         else:
             return None
 
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -5,7 +5,7 @@
 
 import datetime
 
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 from swh.core.api import remote_api_endpoint
 from swh.model.identifiers import SWHID
@@ -913,13 +913,15 @@
         ...
 
     @remote_api_endpoint("origin/visit/get_random")
-    def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]:
+    def origin_visit_get_random(
+        self, type: str
+    ) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]:
         """Randomly select one successful origin visit with <type>
         made in the last 3 months.
 
         Returns:
-            dict representing an origin visit, in the same format as
-            :py:meth:`origin_visit_get`.
+            One random tuple of (OriginVisit, OriginVisitStatus) matching the
+            selection criteria, or None if no visit matches.
 
         """
         ...
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -16,6 +16,7 @@
     Iterable,
     List,
     Optional,
+    Tuple,
     Union,
 )
 
@@ -964,11 +965,22 @@
     @db_transaction()
     def origin_visit_get_random(
         self, type: str, db=None, cur=None
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]:
         row = db.origin_visit_get_random(type, cur)
-        if row:
-            visit = dict(zip(db.origin_visit_get_cols, row))
-            return self._origin_visit_apply_update(visit, db)
+        if row is not None:
+            # NOTE(review): assumes row columns are ordered as (origin, visit,
+            # date, type, status, metadata, snapshot); row[2] is reused as the
+            # status date too, TODO confirm the query exposes no status date.
+            visit = OriginVisit(origin=row[0], visit=row[1], date=row[2], type=row[3],)
+            visit_status = OriginVisitStatus(
+                origin=row[0],
+                visit=row[1],
+                date=row[2],
+                status=row[4],
+                metadata=row[5],
+                snapshot=row[6],
+            )
+            return visit, visit_status
         return None
 
     @timed
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -1386,9 +1386,13 @@
         assert stats["origin_visit"] == len(origins) * len(visits)
 
         random_origin_visit = swh_storage.origin_visit_get_random(visit_type)
-        assert random_origin_visit
-        assert random_origin_visit["origin"] is not None
-        assert random_origin_visit["origin"] in [o.url for o in origins]
+        assert random_origin_visit is not None
+        random_ov, random_ovs = random_origin_visit
+        assert random_ov.origin in [o.url for o in origins]
+        assert random_ov.type == visit_type
+        assert random_ovs.origin == random_ov.origin
+        assert random_ovs.visit == random_ov.visit
+        assert random_ovs.status == "full"
 
     def test_origin_visit_get_random_nothing_found(self, swh_storage, sample_data):
         origins = sample_data.origins