diff --git a/swh/storage/algos/snapshot.py b/swh/storage/algos/snapshot.py
--- a/swh/storage/algos/snapshot.py
+++ b/swh/storage/algos/snapshot.py
@@ -5,9 +5,10 @@
 
-from typing import List, Optional
+from typing import List, Optional, Set
 
-from swh.model.model import Snapshot
+from swh.model.model import Snapshot, TargetType
 
 from swh.storage.algos.origin import origin_get_latest_visit_status
+from swh.storage.interface import ListOrder, StorageInterface
 
 
 def snapshot_get_all_branches(storage, snapshot_id):
@@ -93,3 +94,80 @@
     else:
         snapshot = snapshot_get_all_branches(storage, snapshot_id)
     return Snapshot.from_dict(snapshot) if snapshot else None
+
+
+def snapshot_get_from_revision(
+    storage: StorageInterface, origin: str, revision_id: bytes
+) -> Optional[bytes]:
+    """Retrieve the most recent snapshot of ``origin`` targeting ``revision_id``.
+
+    Visits of the origin are walked from most to least recent and, within each
+    visit, visit statuses are walked from most to least recent. The first
+    snapshot found with a branch directly targeting the revision is returned.
+
+    *Warning*: in the worst case every visit status of the origin is inspected
+    and every distinct snapshot it references is fully fetched, so this can be
+    costly for origins with a long history.
+
+    Args:
+        storage: storage to query
+        origin: url of the origin whose visits are inspected
+        revision_id: identifier of the revision to look for
+
+    Returns:
+        The snapshot id if found. None otherwise, including when the revision
+        is unknown to the storage.
+
+    """
+    # revision_get returns one entry per requested id (None for unknown ids),
+    # so checking the truthiness of the returned collection cannot detect a
+    # missing revision: inspect the entry itself.
+    revision = next(iter(storage.revision_get([revision_id])), None)
+    if revision is None:
+        return None
+
+    # Successive visit statuses often reference the same snapshot (e.g. when
+    # the origin did not change): fetch and inspect each snapshot only once.
+    seen_snapshot_ids: Set[bytes] = set()
+    next_page_token_visit = None
+    while True:
+        visit_page = storage.origin_visit_get(
+            origin, order=ListOrder.DESC, page_token=next_page_token_visit
+        )
+        next_page_token_visit = visit_page.next_page_token
+        for visit in visit_page.results:
+            next_page_token_visit_status = None
+            while True:
+                visit_status_page = storage.origin_visit_status_get(
+                    origin,
+                    visit.visit,
+                    order=ListOrder.DESC,
+                    page_token=next_page_token_visit_status,
+                )
+                next_page_token_visit_status = visit_status_page.next_page_token
+                for visit_status in visit_status_page.results:
+                    snapshot_id = visit_status.snapshot
+                    if snapshot_id is None or snapshot_id in seen_snapshot_ids:
+                        continue
+                    seen_snapshot_ids.add(snapshot_id)
+
+                    # storage.snapshot_get may only return a first page of
+                    # branches; snapshot_get_all_branches follows the
+                    # pagination so large snapshots are fully inspected.
+                    detail_snapshot = snapshot_get_all_branches(storage, snapshot_id)
+                    if not detail_snapshot:
+                        continue
+                    for branch in detail_snapshot["branches"].values():
+                        if (
+                            branch is not None  # dangling branches are None
+                            and branch["target_type"] == TargetType.REVISION.value
+                            and branch["target"] == revision_id
+                        ):
+                            return snapshot_id
+                if next_page_token_visit_status is None:
+                    break
+
+        if next_page_token_visit is None:
+            break
+
+    return None
diff --git a/swh/storage/tests/algos/test_snapshot.py b/swh/storage/tests/algos/test_snapshot.py
--- a/swh/storage/tests/algos/test_snapshot.py
+++ b/swh/storage/tests/algos/test_snapshot.py
@@ -10,7 +10,11 @@
 from swh.model.hypothesis_strategies import snapshots, branch_names, branch_targets
 from swh.model.model import OriginVisit, OriginVisitStatus, Snapshot
 
-from swh.storage.algos.snapshot import snapshot_get_all_branches, snapshot_get_latest
+from swh.storage.algos.snapshot import (
+    snapshot_get_all_branches,
+    snapshot_get_latest,
+    snapshot_get_from_revision,
+)
 from swh.storage.utils import now
 
 
@@ -145,3 +149,61 @@
 
     with pytest.raises(ValueError, match="branches_count must be a positive integer"):
         snapshot_get_latest(swh_storage, origin.url, branches_count="something-wrong")
+
+
+def test_snapshot_get_from_revision(swh_storage, sample_data):
+    origin = sample_data.origin
+    swh_storage.origin_add([origin])
+
+    date_visit2 = now()
+    visit1, visit2 = sample_data.origin_visits[:2]
+    assert visit1.origin == origin.url
+
+    ov1, ov2 = swh_storage.origin_visit_add([visit1, visit2])
+
+    revision1, revision2, revision3 = sample_data.revisions[:3]
+    swh_storage.revision_add([revision1, revision2])
+
+    empty_snapshot, complete_snapshot = sample_data.snapshots[1:3]
+    swh_storage.snapshot_add([complete_snapshot])
+
+    # Add complete_snapshot to visit1 which targets revision1
+    ovs1, ovs2 = [
+        OriginVisitStatus(
+            origin=origin.url,
+            visit=ov1.visit,
+            date=date_visit2,
+            status="partial",
+            snapshot=complete_snapshot.id,
+        ),
+        OriginVisitStatus(
+            origin=origin.url,
+            visit=ov2.visit,
+            date=now(),
+            status="full",
+            snapshot=empty_snapshot.id,
+        ),
+    ]
+
+    swh_storage.origin_visit_status_add([ovs1, ovs2])
+    assert ov1.date < ov2.date
+    assert ov2.date < ovs1.date
+    assert ovs1.date < ovs2.date
+
+    # revision3 does not exist so result is None
+    actual_snapshot_id = snapshot_get_from_revision(
+        swh_storage, origin.url, revision3.id
+    )
+    assert actual_snapshot_id is None
+
+    # no snapshot targets revision2 for origin.url so result is None
+    actual_snapshot_id = snapshot_get_from_revision(
+        swh_storage, origin.url, revision2.id
+    )
+    assert actual_snapshot_id is None
+
+    # complete_snapshot targets at least revision1
+    actual_snapshot_id = snapshot_get_from_revision(
+        swh_storage, origin.url, revision1.id
+    )
+    assert actual_snapshot_id == complete_snapshot.id