diff --git a/swh/storage/algos/origin.py b/swh/storage/algos/origin.py
--- a/swh/storage/algos/origin.py
+++ b/swh/storage/algos/origin.py
@@ -6,7 +6,7 @@
 from typing import Iterator, List, Optional, Tuple
 
 from swh.model.model import Origin, OriginVisit, OriginVisitStatus
-from swh.storage.interface import StorageInterface
+from swh.storage.interface import ListOrder, StorageInterface
 
 
 def iter_origins(
@@ -95,3 +95,52 @@
         if visit_status:
             result = visit, visit_status
     return result
+
+
+def iter_origin_visits(
+    storage: StorageInterface, origin: str, order: ListOrder = ListOrder.ASC
+) -> Iterator[OriginVisit]:
+    """Iter over origin visits from an origin
+
+    Args:
+        storage: storage whose ``origin_visit_get`` pages are consumed
+        origin: origin passed to ``origin_visit_get``
+        order: listing order passed to ``origin_visit_get``
+
+    Yields:
+        the results of each page, until a page has no next page token
+
+    """
+    next_page_token = None
+    while True:
+        page = storage.origin_visit_get(origin, order=order, page_token=next_page_token)
+        next_page_token = page.next_page_token
+        yield from page.results
+        if next_page_token is None:
+            break
+
+
+def iter_origin_visit_statuses(
+    storage: StorageInterface, origin: str, visit: int, order: ListOrder = ListOrder.ASC
+) -> Iterator[OriginVisitStatus]:
+    """Iter over origin visit status from an origin visit
+
+    Args:
+        storage: storage whose ``origin_visit_status_get`` pages are consumed
+        origin: origin passed to ``origin_visit_status_get``
+        visit: visit id passed to ``origin_visit_status_get``
+        order: listing order passed to ``origin_visit_status_get``
+
+    Yields:
+        the results of each page, until a page has no next page token
+
+    """
+    next_page_token = None
+    while True:
+        page = storage.origin_visit_status_get(
+            origin, visit, order=order, page_token=next_page_token
+        )
+        next_page_token = page.next_page_token
+        yield from page.results
+        if next_page_token is None:
+            break
diff --git a/swh/storage/algos/snapshot.py b/swh/storage/algos/snapshot.py
--- a/swh/storage/algos/snapshot.py
+++ b/swh/storage/algos/snapshot.py
@@ -5,9 +5,14 @@
 
 from typing import List, Optional
 
-from swh.model.model import Snapshot
+from swh.model.model import Snapshot, TargetType
 
-from swh.storage.algos.origin import origin_get_latest_visit_status
+from swh.storage.algos.origin import (
+    origin_get_latest_visit_status,
+    iter_origin_visits,
+    iter_origin_visit_statuses,
+)
+from swh.storage.interface import ListOrder, StorageInterface
 
 
 def snapshot_get_all_branches(storage, snapshot_id):
@@ -93,3 +98,41 @@
     else:
snapshot = snapshot_get_all_branches(storage, snapshot_id) return Snapshot.from_dict(snapshot) if snapshot else None + + +def snapshot_id_get_from_revision( + storage: StorageInterface, origin: str, revision_id: bytes +) -> Optional[bytes]: + """Retrieve the most recent snapshot id targeting the revision_id for the given origin. + + *Warning* This is a potentially highly costly operation + + Returns + The snapshot id if found. None otherwise. + + """ + revision = storage.revision_get([revision_id]) + if not revision: + return None + + for visit in iter_origin_visits(storage, origin, order=ListOrder.DESC): + assert visit.visit is not None + for visit_status in iter_origin_visit_statuses( + storage, origin, visit.visit, order=ListOrder.DESC + ): + snapshot_id = visit_status.snapshot + if snapshot_id is None: + continue + + snapshot = snapshot_get_all_branches(storage, snapshot_id) + if not snapshot: + continue + for branch_name, branch in snapshot["branches"].items(): + if ( + branch is not None + and branch["target_type"] == TargetType.REVISION.value + and branch["target"] == revision_id + ): # snapshot found + return snapshot_id + + return None diff --git a/swh/storage/tests/algos/test_origin.py b/swh/storage/tests/algos/test_origin.py --- a/swh/storage/tests/algos/test_origin.py +++ b/swh/storage/tests/algos/test_origin.py @@ -3,13 +3,21 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import copy +import datetime import pytest from unittest.mock import patch from swh.model.model import Origin, OriginVisit, OriginVisitStatus -from swh.storage.algos.origin import iter_origins, origin_get_latest_visit_status +from swh.storage.algos.origin import ( + iter_origins, + origin_get_latest_visit_status, + iter_origin_visits, + iter_origin_visit_statuses, +) +from swh.storage.interface import ListOrder from swh.storage.utils import now from swh.storage.tests.test_storage import round_to_milliseconds 
@@ -319,3 +327,94 @@
     assert actual_ov2.visit == ov2.visit
     assert actual_ov2.type == ov2.type
     assert actual_ovs22 == ovs22
+
+
+def test_iter_origin_visits(swh_storage, sample_data):
+    """Iter over origin visits for an origin returns all visits"""
+    origin1, origin2 = sample_data.origins[:2]
+    swh_storage.origin_add([origin1, origin2])
+
+    date_past = now() - datetime.timedelta(weeks=20)
+
+    new_visits = []
+    for visit_id in range(20, 0, -1):  # will trigger pagination
+        visit = OriginVisit(
+            origin=origin1.url,
+            date=date_past - datetime.timedelta(days=visit_id),
+            type="git",
+        )
+        new_visits.append(visit)
+    visits = swh_storage.origin_visit_add(new_visits)
+    reverse_visits = copy.deepcopy(visits)
+    reverse_visits.reverse()
+
+    # no limit, order asc
+    actual_visits = list(iter_origin_visits(swh_storage, origin1.url))
+    assert actual_visits == visits
+
+    # no limit, order desc
+    actual_visits = list(
+        iter_origin_visits(swh_storage, origin1.url, order=ListOrder.DESC)
+    )
+    assert actual_visits == reverse_visits
+
+    # no result
+    actual_visits = list(iter_origin_visits(swh_storage, origin2.url))
+    assert actual_visits == []
+
+
+def test_iter_origin_visit_status(swh_storage, sample_data):
+    """Iter over visit statuses for an origin visit returns all statuses"""
+    origin1, origin2 = sample_data.origins[:2]
+    swh_storage.origin_add([origin1])
+
+    ov1 = swh_storage.origin_visit_add([sample_data.origin_visit])[0]
+    assert ov1.origin == origin1.url
+
+    date_past = now() - datetime.timedelta(weeks=20)
+
+    ovs1 = OriginVisitStatus(
+        origin=origin1.url,
+        visit=ov1.visit,
+        date=ov1.date,
+        status="created",
+        snapshot=None,
+    )
+    new_visit_statuses = [ovs1]
+    for i in range(20, 0, -1):  # will trigger pagination
+        status_date = date_past - datetime.timedelta(days=i)
+        visit_status = OriginVisitStatus(
+            origin=origin1.url,
+            visit=ov1.visit,
+            date=status_date,
+            status="created",
+            snapshot=None,
+        )
+        new_visit_statuses.append(visit_status)
+
+    # Visit statuses go through origin_visit_status_add; origin_visit_add is the
+    # endpoint for OriginVisit objects.
+    # NOTE(review): expected order assumes ov1.date predates every generated date
+    swh_storage.origin_visit_status_add(new_visit_statuses)
+    reverse_visit_statuses = copy.deepcopy(new_visit_statuses)
+    reverse_visit_statuses.reverse()
+
+    # order asc
+    actual_visit_statuses = list(
+        iter_origin_visit_statuses(swh_storage, ov1.origin, ov1.visit)
+    )
+    assert actual_visit_statuses == new_visit_statuses
+
+    # order desc
+    actual_visit_statuses = list(
+        iter_origin_visit_statuses(
+            swh_storage, ov1.origin, ov1.visit, order=ListOrder.DESC
+        )
+    )
+    assert actual_visit_statuses == reverse_visit_statuses
+
+    # no result
+    actual_visit_statuses = list(
+        iter_origin_visit_statuses(swh_storage, origin2.url, ov1.visit)
+    )
+    assert actual_visit_statuses == []
diff --git a/swh/storage/tests/algos/test_snapshot.py b/swh/storage/tests/algos/test_snapshot.py
--- a/swh/storage/tests/algos/test_snapshot.py
+++ b/swh/storage/tests/algos/test_snapshot.py
@@ -10,7 +10,11 @@
 from swh.model.hypothesis_strategies import snapshots, branch_names, branch_targets
 from swh.model.model import OriginVisit, OriginVisitStatus, Snapshot
 
-from swh.storage.algos.snapshot import snapshot_get_all_branches, snapshot_get_latest
+from swh.storage.algos.snapshot import (
+    snapshot_get_all_branches,
+    snapshot_get_latest,
+    snapshot_id_get_from_revision,
+)
 from swh.storage.utils import now
 
 
@@ -145,3 +149,62 @@
 
     with pytest.raises(ValueError, match="branches_count must be a positive integer"):
         snapshot_get_latest(swh_storage, origin.url, branches_count="something-wrong")
+
+
+def test_snapshot_get_id_from_revision(swh_storage, sample_data):
+    """Look up the id of an origin's snapshot targeting a given revision"""
+    origin = sample_data.origin
+    swh_storage.origin_add([origin])
+
+    date_visit2 = now()
+    visit1, visit2 = sample_data.origin_visits[:2]
+    assert visit1.origin == origin.url
+
+    ov1, ov2 = swh_storage.origin_visit_add([visit1, visit2])
+
+    revision1, revision2, revision3 = sample_data.revisions[:3]
+    swh_storage.revision_add([revision1, revision2])
+
+    empty_snapshot, complete_snapshot = sample_data.snapshots[1:3]
+    swh_storage.snapshot_add([complete_snapshot])
+
+    # Add complete_snapshot to visit1 which targets revision1
+    ovs1, ovs2 = [
+        OriginVisitStatus(
+            origin=origin.url,
+            visit=ov1.visit,
+            date=date_visit2,
+            status="partial",
+            snapshot=complete_snapshot.id,
+        ),
+        OriginVisitStatus(
+            origin=origin.url,
+            visit=ov2.visit,
+            date=now(),
+            status="full",
+            snapshot=empty_snapshot.id,
+        ),
+    ]
+
+    swh_storage.origin_visit_status_add([ovs1, ovs2])
+    assert ov1.date < ov2.date
+    assert ov2.date < ovs1.date
+    assert ovs1.date < ovs2.date
+
+    # revision3 does not exist so result is None
+    actual_snapshot_id = snapshot_id_get_from_revision(
+        swh_storage, origin.url, revision3.id
+    )
+    assert actual_snapshot_id is None
+
+    # no snapshot targets revision2 for origin.url so result is None
+    actual_snapshot_id = snapshot_id_get_from_revision(
+        swh_storage, origin.url, revision2.id
+    )
+    assert actual_snapshot_id is None
+
+    # complete_snapshot targets at least revision1
+    actual_snapshot_id = snapshot_id_get_from_revision(
+        swh_storage, origin.url, revision1.id
+    )
+    assert actual_snapshot_id == complete_snapshot.id