diff --git a/swh/storage/algos/snapshot.py b/swh/storage/algos/snapshot.py
--- a/swh/storage/algos/snapshot.py
+++ b/swh/storage/algos/snapshot.py
@@ -3,10 +3,16 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from typing import List, Optional
+from typing import Iterator, List, Optional, Tuple
 
 from swh.model.hashutil import hash_to_hex
-from swh.model.model import Sha1Git, Snapshot, TargetType
+from swh.model.model import (
+    OriginVisit,
+    OriginVisitStatus,
+    Sha1Git,
+    Snapshot,
+    TargetType,
+)
 
 from swh.storage.algos.origin import (
     origin_get_latest_visit_status,
@@ -113,10 +119,34 @@
     Returns
         The snapshot id if found. None otherwise.
 
+    """
+    res = visits_and_snapshots_get_from_revision(storage, origin, revision_id)
+
+    # they are sorted by descending date, so we just need to return the first one,
+    # if any.
+    for (visit, status, snapshot) in res:
+        return snapshot.id
+
+    return None
+
+
+def visits_and_snapshots_get_from_revision(
+    storage: StorageInterface, origin: str, revision_id: bytes
+) -> Iterator[Tuple[OriginVisit, OriginVisitStatus, Snapshot]]:
+    """Retrieve all visits, visit statuses, and matching snapshots of the given
+    origin, such that the snapshot targets the revision_id.
+
+    *Warning* This is a potentially highly costly operation.
+
+    Yields:
+        Tuples of (visit, status, snapshot), most recent visit first; each
+        matching snapshot is yielded at most once per visit status.
+
     """
     revision = storage.revision_get([revision_id])
-    if not revision:
-        return None
+    if not revision or revision[0] is None:
+        # revision_get returns [None] (not []) for an unknown revision id
+        return
 
     for visit in iter_origin_visits(storage, origin, order=ListOrder.DESC):
         assert visit.visit is not None
@@ -136,6 +166,5 @@
                     and branch.target_type == TargetType.REVISION
                     and branch.target == revision_id
                 ):  # snapshot found
-                    return snapshot_id
-
-    return None
+                    yield (visit, visit_status, snapshot)
+                    break  # one matching branch is enough: do not yield it again
diff --git a/swh/storage/tests/algos/test_snapshot.py b/swh/storage/tests/algos/test_snapshot.py
--- a/swh/storage/tests/algos/test_snapshot.py
+++ b/swh/storage/tests/algos/test_snapshot.py
@@ -14,6 +14,7 @@
     snapshot_get_all_branches,
     snapshot_get_latest,
     snapshot_id_get_from_revision,
+    visits_and_snapshots_get_from_revision,
 )
 from swh.storage.utils import now
 
@@ -151,7 +152,7 @@
         snapshot_get_latest(swh_storage, origin.url, branches_count="something-wrong")
 
 
-def test_snapshot_get_id_from_revision(swh_storage, sample_data):
+def test_snapshot_id_get_from_revision(swh_storage, sample_data):
     origin = sample_data.origin
     swh_storage.origin_add([origin])
 
@@ -207,3 +208,102 @@
         swh_storage, origin.url, revision1.id
     )
     assert actual_snapshot_id == complete_snapshot.id
+
+
+def test_visits_and_snapshots_get_from_revision(swh_storage, sample_data):
+    origin = sample_data.origin
+    swh_storage.origin_add([origin])
+
+    date_visit2 = now()
+    visit1, visit2 = sample_data.origin_visits[:2]
+    assert visit1.origin == origin.url
+
+    ov1, ov2 = swh_storage.origin_visit_add([visit1, visit2])
+
+    revision1, revision2, revision3 = sample_data.revisions[:3]
+    swh_storage.revision_add([revision1, revision2])
+
+    empty_snapshot, complete_snapshot = sample_data.snapshots[1:3]
+    swh_storage.snapshot_add([complete_snapshot])
+
+    # Add complete_snapshot to visit1 which targets revision1
+    ovs1, ovs2 = [
+        OriginVisitStatus(
+            origin=origin.url,
+            visit=ov1.visit,
+            date=date_visit2,
+            status="partial",
+            snapshot=complete_snapshot.id,
+        ),
+        OriginVisitStatus(
+            origin=origin.url,
+            visit=ov2.visit,
+            date=now(),
+            status="full",
+            snapshot=empty_snapshot.id,
+        ),
+    ]
+
+    swh_storage.origin_visit_status_add([ovs1, ovs2])
+    assert ov1.date < ov2.date
+    assert ov2.date < ovs1.date
+    assert ovs1.date < ovs2.date
+
+    # revision3 does not exist so nothing is found
+    res = list(
+        visits_and_snapshots_get_from_revision(swh_storage, origin.url, revision3.id)
+    )
+    assert res == []
+
+    # no snapshot targets revision2 for origin.url so nothing is found
+    res = list(
+        visits_and_snapshots_get_from_revision(swh_storage, origin.url, revision2.id)
+    )
+    assert res == []
+
+    # complete_snapshot targets at least revision1
+    res = list(
+        visits_and_snapshots_get_from_revision(swh_storage, origin.url, revision1.id)
+    )
+    assert res == [(ov1, ovs1, complete_snapshot)]
+
+
+def test_visits_and_snapshots_get_from_revision_several_branches(
+    swh_storage, sample_data
+):
+    from swh.model.model import Snapshot, SnapshotBranch, TargetType
+
+    origin = sample_data.origin
+    swh_storage.origin_add([origin])
+
+    revision = sample_data.revisions[0]
+    swh_storage.revision_add([revision])
+
+    # two branches of the same snapshot target the same revision
+    snapshot = Snapshot(
+        branches={
+            b"refs/heads/main": SnapshotBranch(
+                target=revision.id, target_type=TargetType.REVISION
+            ),
+            b"refs/heads/master": SnapshotBranch(
+                target=revision.id, target_type=TargetType.REVISION
+            ),
+        },
+    )
+    swh_storage.snapshot_add([snapshot])
+
+    [visit] = swh_storage.origin_visit_add([sample_data.origin_visits[0]])
+    visit_status = OriginVisitStatus(
+        origin=origin.url,
+        visit=visit.visit,
+        date=now(),
+        status="full",
+        snapshot=snapshot.id,
+    )
+    swh_storage.origin_visit_status_add([visit_status])
+
+    # the snapshot is yielded once, not once per matching branch
+    res = list(
+        visits_and_snapshots_get_from_revision(swh_storage, origin.url, revision.id)
+    )
+    assert res == [(visit, visit_status, snapshot)]