Fetch snapshot branches lazily when looking up an origin's head.

origin_head.py: "ftp" visits still load every branch, because the largest
archive name has to be found. Other visit types now request only the first
100 branches, and HEAD, master and alias targets that are not in that first
page are queried one at a time through the new _get_branch() helper.

test_origin_head.py: run on the in-memory storage backend, factor the
origin/visit/snapshot setup into _add_snapshot_to_origin(), and add tests for
large snapshots, chained aliases and dangling aliases.

diff --git a/swh/indexer/origin_head.py b/swh/indexer/origin_head.py
--- a/swh/indexer/origin_head.py
+++ b/swh/indexer/origin_head.py
@@ -4,15 +4,16 @@
 # See top-level LICENSE file for more information
 
 import re
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
-from swh.model.model import SnapshotBranch, TargetType
+from swh.model.model import Snapshot, SnapshotBranch, TargetType
 from swh.model.swhids import CoreSWHID, ObjectType
 from swh.storage.algos.origin import origin_get_latest_visit_status
 from swh.storage.algos.snapshot import snapshot_get_all_branches
+from swh.storage.interface import PartialBranches, StorageInterface
 
 
-def get_head_swhid(storage, origin_url: str) -> Optional[CoreSWHID]:
+def get_head_swhid(storage: StorageInterface, origin_url: str) -> Optional[CoreSWHID]:
     """Returns the SWHID of the head revision or release of an origin"""
     visit_status = origin_get_latest_visit_status(
         storage, origin_url, allowed_statuses=["full"], require_snapshot=True
@@ -20,14 +21,24 @@
     if not visit_status:
         return None
     assert visit_status.snapshot is not None
-    snapshot = snapshot_get_all_branches(storage, visit_status.snapshot)
-    if snapshot is None:
-        return None
 
     if visit_status.type == "ftp":
-        return _try_get_ftp_head(dict(snapshot.branches))
+        # We need to fetch all branches in order to find the largest one
+        snapshot = snapshot_get_all_branches(storage, visit_status.snapshot)
+        if snapshot is None:
+            return None
+        return _try_get_ftp_head(storage, snapshot)
     else:
-        return _try_get_head_generic(dict(snapshot.branches))
+        # Peek into the snapshot, without fetching too many refs.
+        # If the snapshot is small, this gets all of it in a single request.
+        # If the snapshot is large, we will query specific branches as we need them.
+        partial_branches = storage.snapshot_get_branches(
+            visit_status.snapshot, branches_count=100
+        )
+        if partial_branches is None:
+            # Snapshot does not exist
+            return None
+        return _try_get_head_generic(storage, partial_branches)
 
 
 _archive_filename_re = re.compile(
@@ -78,31 +89,56 @@
 
 
 def _try_get_ftp_head(
-    branches: Dict[bytes, Optional[SnapshotBranch]]
+    storage: StorageInterface, snapshot: Snapshot
 ) -> Optional[CoreSWHID]:
-    archive_names = list(branches)
+    archive_names = list(snapshot.branches)
     max_archive_name = max(archive_names, key=_parse_version)
-    return _try_resolve_target(branches, max_archive_name)
+    return _try_resolve_target(
+        storage,
+        {"id": snapshot.id, "branches": dict(snapshot.branches), "next_branch": None},
+        branch_name=max_archive_name,
+    )
 
 
 def _try_get_head_generic(
-    branches: Dict[bytes, Optional[SnapshotBranch]]
+    storage: StorageInterface, partial_branches: PartialBranches
 ) -> Optional[CoreSWHID]:
     # Works on 'deposit', 'pypi', and VCSs.
-    return _try_resolve_target(branches, b"HEAD") or _try_resolve_target(
-        branches, b"master"
-    )
+    return _try_resolve_target(
+        storage, partial_branches, branch_name=b"HEAD"
+    ) or _try_resolve_target(storage, partial_branches, branch_name=b"master")
+
+
+def _get_branch(
+    storage: StorageInterface, partial_branches: PartialBranches, branch_name: bytes
+) -> Optional[SnapshotBranch]:
+    """Given a ``branch_name``, gets it from ``partial_branches`` if present,
+    and fetches it from the storage otherwise."""
+    if branch_name in partial_branches["branches"]:
+        return partial_branches["branches"][branch_name]
+    elif partial_branches["next_branch"] is not None:
+        # Branch is not in `partial_branches`, and `partial_branches` is indeed partial
+        res = storage.snapshot_get_branches(
+            partial_branches["id"], branches_from=branch_name, branches_count=1
+        )
+        assert res is not None, "Snapshot does not exist anymore"
+        return res["branches"].get(branch_name)
+    else:
+        # Branch is not in `partial_branches`, but `partial_branches` is the full
+        # list of branches, which means it is a dangling reference.
+        return None
 
 
 def _try_resolve_target(
-    branches: Dict[bytes, Optional[SnapshotBranch]], branch_name: bytes
+    storage: StorageInterface, partial_branches: PartialBranches, branch_name: bytes
 ) -> Optional[CoreSWHID]:
     try:
-        branch = branches[branch_name]
+        branch = _get_branch(storage, partial_branches, branch_name)
         if branch is None:
             return None
+
         while branch.target_type == TargetType.ALIAS:
-            branch = branches[branch.target]
+            branch = _get_branch(storage, partial_branches, branch.target)
             if branch is None:
                 return None
 
diff --git a/swh/indexer/tests/test_origin_head.py b/swh/indexer/tests/test_origin_head.py
--- a/swh/indexer/tests/test_origin_head.py
+++ b/swh/indexer/tests/test_origin_head.py
@@ -4,6 +4,7 @@
 # See top-level LICENSE file for more information
 
 from datetime import datetime, timezone
+import itertools
 
 import pytest
 
@@ -20,6 +21,13 @@
 from swh.model.swhids import CoreSWHID
 from swh.storage.utils import now
 
+
+@pytest.fixture
+def swh_storage_backend_config():
+    """In-memory storage, to make tests go faster."""
+    return {"cls": "memory"}
+
+
 SAMPLE_SNAPSHOT = Snapshot(
     branches={
         b"foo": None,
@@ -31,6 +39,28 @@
 )
 
 
+def _add_snapshot_to_origin(storage, origin_url, visit_type, snapshot):
+    storage.origin_add([Origin(url=origin_url)])
+    visit = storage.origin_visit_add(
+        [
+            OriginVisit(
+                origin=origin_url,
+                date=datetime(2019, 2, 27, tzinfo=timezone.utc),
+                type=visit_type,
+            )
+        ]
+    )[0]
+    storage.snapshot_add([snapshot])
+    visit_status = OriginVisitStatus(
+        origin=origin_url,
+        visit=visit.visit,
+        date=now(),
+        status="full",
+        snapshot=snapshot.id,
+    )
+    storage.origin_visit_status_add([visit_status])
+
+
 @pytest.fixture
 def storage(swh_storage):
     fill_storage(swh_storage)
@@ -77,31 +107,115 @@
 
 def test_pypi_missing_branch(storage):
     origin_url = "https://pypi.org/project/abcdef/"
-    storage.origin_add(
-        [
-            Origin(
-                url=origin_url,
-            )
-        ]
+    _add_snapshot_to_origin(storage, origin_url, "pypi", SAMPLE_SNAPSHOT)
+    assert get_head_swhid(storage, origin_url) is None
+
+
+@pytest.mark.parametrize(
+    "branches_start,branches_middle,branches_end",
+    itertools.product([0, 40, 99, 100, 200], [0, 40, 99, 100, 200], [0, 40, 200]),
+)
+def test_large_snapshot(storage, branches_start, branches_middle, branches_end):
+    rev_id = "8ea98e2fea7d9f6546f49ffdeecc1ab4608c8b79"
+    snapshot = Snapshot(
+        branches=dict(
+            [(f"AAAA{i}".encode(), None) for i in range(branches_start)]
+            + [
+                (
+                    b"HEAD",
+                    SnapshotBranch(
+                        target_type=TargetType.ALIAS, target=b"refs/heads/foo"
+                    ),
+                )
+            ]
+            + [(f"aaaa{i}".encode(), None) for i in range(branches_middle)]
+            + [
+                (
+                    b"refs/heads/foo",
+                    SnapshotBranch(
+                        target_type=TargetType.REVISION,
+                        target=bytes.fromhex(rev_id),
+                    ),
+                )
+            ]
+            + [(f"zzzz{i}".encode(), None) for i in range(branches_end)]
+        )
     )
-    visit = storage.origin_visit_add(
-        [
-            OriginVisit(
-                origin=origin_url,
-                date=datetime(2019, 2, 27, tzinfo=timezone.utc),
-                type="pypi",
-            )
-        ]
-    )[0]
-    storage.snapshot_add([SAMPLE_SNAPSHOT])
-    visit_status = OriginVisitStatus(
-        origin=origin_url,
-        visit=visit.visit,
-        date=now(),
-        status="full",
-        snapshot=SAMPLE_SNAPSHOT.id,
+
+    origin_url = "https://example.org/repo.git"
+    _add_snapshot_to_origin(storage, origin_url, "git", snapshot)
+
+    assert get_head_swhid(storage, origin_url) == CoreSWHID.from_string(
+        "swh:1:rev:8ea98e2fea7d9f6546f49ffdeecc1ab4608c8b79"
     )
-    storage.origin_visit_status_add([visit_status])
+
+
+def test_large_snapshot_chained_aliases(storage):
+    rev_id = "8ea98e2fea7d9f6546f49ffdeecc1ab4608c8b79"
+    snapshot = Snapshot(
+        branches=dict(
+            [(f"AAAA{i}".encode(), None) for i in range(200)]
+            + [
+                (
+                    b"HEAD",
+                    SnapshotBranch(
+                        target_type=TargetType.ALIAS, target=b"refs/heads/alias2"
+                    ),
+                )
+            ]
+            + [(f"aaaa{i}".encode(), None) for i in range(200)]
+            + [
+                (
+                    b"refs/heads/alias2",
+                    SnapshotBranch(
+                        target_type=TargetType.ALIAS, target=b"refs/heads/branch"
+                    ),
+                )
+            ]
+            + [(f"refs/heads/bbbb{i}".encode(), None) for i in range(200)]
+            + [
+                (
+                    b"refs/heads/branch",
+                    SnapshotBranch(
+                        target_type=TargetType.REVISION,
+                        target=bytes.fromhex(rev_id),
+                    ),
+                )
+            ]
+        )
+    )
+
+    origin_url = "https://example.org/repo.git"
+    _add_snapshot_to_origin(storage, origin_url, "git", snapshot)
+
+    assert get_head_swhid(storage, origin_url) == CoreSWHID.from_string(
+        "swh:1:rev:8ea98e2fea7d9f6546f49ffdeecc1ab4608c8b79"
+    )
+
+
+@pytest.mark.parametrize(
+    "branches_start,branches_end",
+    itertools.product([0, 40, 99, 100, 200], [0, 40, 200]),
+)
+def test_large_snapshot_dangling_alias(storage, branches_start, branches_end):
+    snapshot = Snapshot(
+        branches=dict(
+            [(f"AAAA{i}".encode(), None) for i in range(branches_start)]
+            + [
+                (
+                    b"HEAD",
+                    SnapshotBranch(
+                        target_type=TargetType.ALIAS, target=b"refs/heads/foo"
+                    ),
+                )
+            ]
+            + [(f"zzzz{i}".encode(), None) for i in range(branches_end)]
+        )
+    )
+
+    origin_url = "https://example.org/repo.git"
+    _add_snapshot_to_origin(storage, origin_url, "git", snapshot)
+
     assert get_head_swhid(storage, origin_url) is None
 
 