diff --git a/swh/storage/cassandra/converters.py b/swh/storage/cassandra/converters.py --- a/swh/storage/cassandra/converters.py +++ b/swh/storage/cassandra/converters.py @@ -10,10 +10,9 @@ from copy import deepcopy from typing import Any, Dict, Tuple -from cassandra.cluster import ResultSet - from swh.model.model import ( ObjectType, + OriginVisit, OriginVisitStatus, Revision, RevisionType, @@ -81,9 +80,20 @@ return hashes -def row_to_visit_status(row: ResultSet) -> OriginVisitStatus: - """Format a row representing a visit_status to an actual dict representing an - OriginVisitStatus. +def row_to_visit(row) -> OriginVisit: + """Format a row representing an origin_visit to an actual OriginVisit. + + """ + return OriginVisit( + origin=row.origin, + visit=row.visit, + date=row.date.replace(tzinfo=datetime.timezone.utc), + type=row.type, + ) + + +def row_to_visit_status(row) -> OriginVisitStatus: + """Format a row representing a visit_status to an actual OriginVisitStatus. """ return OriginVisitStatus.from_dict( diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -8,7 +8,7 @@ import json import random import re -from typing import Any, Dict, List, Iterable, Optional, Union +from typing import Any, Dict, List, Iterable, Optional, Tuple, Union import attr @@ -38,13 +38,7 @@ from ..exc import StorageArgumentException, HashCollision from .common import TOKEN_BEGIN, TOKEN_END -from .converters import ( - revision_to_db, - revision_from_db, - release_to_db, - release_from_db, - row_to_visit_status, -) +from . import converters from .cql import CqlRunner from .schema import HASH_ALGORITHMS @@ -434,7 +428,7 @@ self.journal_writer.revision_add(revisions) for revision in revisions: - revobject = revision_to_db(revision) + revobject = converters.revision_to_db(revision) if revobject: # Add parents first for (rank, parent) in enumerate(revobject["parents"]): @@ -465,7 +459,7 @@ # parent_rank is the clustering key, so results are already # sorted by rank. parents = tuple(row.parent_id for row in parent_rows) - rev = revision_from_db(row, parents=parents) + rev = converters.revision_from_db(row, parents=parents) revs[rev.id] = rev.to_dict() for rev_id in revisions: @@ -501,7 +495,7 @@ if short: yield (row.id, parents) else: - rev = revision_from_db(row, parents=parents) + rev = converters.revision_from_db(row, parents=parents) yield rev.to_dict() yield from self._get_parent_revs(parents, seen, limit, short) @@ -528,7 +522,7 @@ for release in to_add: if release: - self._cql_runner.release_add_one(release_to_db(release)) + self._cql_runner.release_add_one(converters.release_to_db(release)) return {"release:add": len(to_add)} @@ -539,7 +533,7 @@ rows = self._cql_runner.release_get(releases) rels = {} for row in rows: - release = release_from_db(row) + release = converters.release_from_db(row) rels[row.id] = release.to_dict() for rel_id in releases: @@ -844,7 +838,7 @@ visit["origin"], visit["visit"] ) assert row is not None - visit_status = row_to_visit_status(row) + visit_status = converters.row_to_visit_status(row) return { # default to the values in visit **visit, @@ -857,15 +851,14 @@ "date": visit["date"], } - def _origin_visit_get_updated(self, origin: str, visit_id: int) -> Dict[str, Any]: - """Retrieve origin visit and latest origin visit status and merge them - into an origin visit. + def _origin_visit_get_latest_status(self, visit: OriginVisit) -> OriginVisitStatus: + """Retrieve the latest visit status information for the origin visit object. """ - row_visit = self._cql_runner.origin_visit_get_one(origin, visit_id) - assert row_visit is not None - visit = self._format_origin_visit_row(row_visit) - return self._origin_visit_apply_last_status(visit) + row = self._cql_runner.origin_visit_status_get_latest(visit.origin, visit.visit) + assert row is not None + visit_status = converters.row_to_visit_status(row) + return attr.evolve(visit_status, origin=visit.origin) @staticmethod def _format_origin_visit_row(visit): @@ -961,9 +954,11 @@ rows = [row for row in rows if row.snapshot is not None] if not rows: return None - return row_to_visit_status(rows[0]) + return converters.row_to_visit_status(rows[0]) - def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]: + def origin_visit_status_get_random( + self, type: str + ) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]: back_in_the_day = now() - datetime.timedelta(weeks=12) # 3 months back # Random position to start iteration at @@ -972,15 +967,11 @@ # Iterator over all visits, ordered by token(origins) then visit_id rows = self._cql_runner.origin_visit_iter(start_token) for row in rows: - visit = self._format_origin_visit_row(row) - visit_status = self._origin_visit_apply_last_status(visit) - if ( - visit_status["date"] > back_in_the_day - and visit_status["status"] == "full" - ): - return visit_status - else: - return None + visit = converters.row_to_visit(row) + visit_status = self._origin_visit_get_latest_status(visit) + if visit.date > back_in_the_day and visit_status.status == "full": + return visit, visit_status + return None def stat_counters(self): rows = self._cql_runner.stat_counters() diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -862,8 +862,10 @@ for visit_status in visit_statuses: self._origin_visit_status_add_one(visit_status) - def _origin_visit_get_updated(self, origin: str, visit_id: int) -> Dict[str, Any]: - """Merge origin visit and latest origin visit status + def _origin_visit_status_get_latest( + self, origin: str, visit_id: int + ) -> Tuple[OriginVisit, OriginVisitStatus]: + """Return a tuple of OriginVisit, latest associated OriginVisitStatus. """ assert visit_id >= 1 @@ -872,6 +874,14 @@ visit_key = (origin, visit_id) visit_update = max(self._origin_visit_statuses[visit_key], key=lambda v: v.date) + return visit, visit_update + + def _origin_visit_get_updated(self, origin: str, visit_id: int) -> Dict[str, Any]: + """Merge origin visit and latest origin visit status + + """ + visit, visit_update = self._origin_visit_status_get_latest(origin, visit_id) + assert visit is not None and visit_update is not None return { # default to the values in visit **visit.to_dict(), @@ -993,20 +1003,25 @@ if random_origin_visits[0].type == type: return url - def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]: + def origin_visit_status_get_random( + self, type: str + ) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]: + url = self._select_random_origin_visit_by_type(type) random_origin_visits = copy.deepcopy(self._origin_visits[url]) random_origin_visits.reverse() back_in_the_day = now() - timedelta(weeks=12) # 3 months back # This should be enough for tests for visit in random_origin_visits: - updated_visit = self._origin_visit_get_updated(url, visit.visit) - assert updated_visit is not None + origin_visit, latest_visit_status = self._origin_visit_status_get_latest( + url, visit.visit + ) + assert latest_visit_status is not None if ( - updated_visit["date"] > back_in_the_day - and updated_visit["status"] == "full" + origin_visit.date > back_in_the_day + and latest_visit_status.status == "full" ): - return updated_visit + return origin_visit, latest_visit_status else: return None diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -5,7 +5,7 @@ import datetime -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from swh.core.api import remote_api_endpoint from swh.model.identifiers import SWHID @@ -912,14 +912,16 @@ """ ... - @remote_api_endpoint("origin/visit/get_random") - def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]: + @remote_api_endpoint("origin/visit_status/get_random") + def origin_visit_status_get_random( + self, type: str + ) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]: """Randomly select one successful origin visit with made in the last 3 months. Returns: - dict representing an origin visit, in the same format as - :py:meth:`origin_visit_get`. + One random tuple of (OriginVisit, OriginVisitStatus) matching the + selection criteria """ ... diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -16,6 +16,7 @@ Iterable, List, Optional, + Tuple, Union, ) @@ -934,12 +935,27 @@ @timed @db_transaction() - def origin_visit_get_random( + def origin_visit_status_get_random( self, type: str, db=None, cur=None - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]: row = db.origin_visit_get_random(type, cur) - if row: - return dict(zip(db.origin_visit_get_cols, row)) + if row is not None: + row_d = dict(zip(db.origin_visit_get_cols, row)) + visit = OriginVisit( + origin=row_d["origin"], + visit=row_d["visit"], + date=row_d["date"], + type=row_d["type"], + ) + visit_status = OriginVisitStatus( + origin=row_d["origin"], + visit=row_d["visit"], + date=row_d["date"], + status=row_d["status"], + metadata=row_d["metadata"], + snapshot=row_d["snapshot"], + ) + return visit, visit_status return None @timed diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1353,7 +1353,7 @@ def test_origin_visit_get__unknown_origin(self, swh_storage): assert [] == list(swh_storage.origin_visit_get("foo")) - def test_origin_visit_get_random(self, swh_storage, sample_data): + def test_origin_visit_status_get_random(self, swh_storage, sample_data): origins = sample_data.origins[:2] swh_storage.origin_add(origins) @@ -1385,12 +1385,15 @@ assert stats["origin"] == len(origins) assert stats["origin_visit"] == len(origins) * len(visits) - random_origin_visit = swh_storage.origin_visit_get_random(visit_type) - assert random_origin_visit - assert random_origin_visit["origin"] is not None - assert random_origin_visit["origin"] in [o.url for o in origins] + random_ov, random_ovs = swh_storage.origin_visit_status_get_random(visit_type) + assert random_ov and random_ovs + assert random_ov.origin is not None + assert random_ov.origin == random_ovs.origin + assert random_ov.origin in [o.url for o in origins] - def test_origin_visit_get_random_nothing_found(self, swh_storage, sample_data): + def test_origin_visit_status_get_random_nothing_found( + self, swh_storage, sample_data + ): origins = sample_data.origins swh_storage.origin_add(origins) visit_type = "hg" @@ -1414,7 +1417,7 @@ ] ) - random_origin_visit = swh_storage.origin_visit_get_random(visit_type) + random_origin_visit = swh_storage.origin_visit_status_get_random(visit_type) assert random_origin_visit is None def test_origin_get_by_sha1(self, swh_storage, sample_data):