diff --git a/requirements.txt b/requirements.txt --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ tenacity cassandra-driver >= 3.19.0, != 3.21.0 deprecated +typing-extensions diff --git a/swh/storage/algos/snapshot.py b/swh/storage/algos/snapshot.py --- a/swh/storage/algos/snapshot.py +++ b/swh/storage/algos/snapshot.py @@ -27,18 +27,18 @@ * **branches**: a dict of branches contained in the snapshot whose keys are the branches' names. """ - ret = storage.snapshot_get(snapshot_id) + ret = storage.snapshot_get_branches(snapshot_id) if not ret: return - next_branch = ret.pop("next_branch", None) + next_branch = ret["next_branch"] while next_branch: data = storage.snapshot_get_branches(snapshot_id, branches_from=next_branch) ret["branches"].update(data["branches"]) - next_branch = data.get("next_branch") + next_branch = data["next_branch"] - return ret + return Snapshot(id=ret["id"], branches=ret["branches"]) def snapshot_get_latest( @@ -95,9 +95,9 @@ if snapshot is None: return None snapshot.pop("next_branch") + return Snapshot(**snapshot) else: - snapshot = snapshot_get_all_branches(storage, snapshot_id) - return Snapshot.from_dict(snapshot) if snapshot else None + return snapshot_get_all_branches(storage, snapshot_id) def snapshot_id_get_from_revision( @@ -127,11 +127,11 @@ snapshot = snapshot_get_all_branches(storage, snapshot_id) if not snapshot: continue - for branch_name, branch in snapshot["branches"].items(): + for branch_name, branch in snapshot.branches.items(): if ( branch is not None - and branch["target_type"] == TargetType.REVISION.value - and branch["target"] == revision_id + and branch.target_type == TargetType.REVISION + and branch.target == revision_id ): # snapshot found return snapshot_id diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -586,7 +586,7 @@ ) def snapshot_branch_get( self, snapshot_id: Sha1Git, from_: bytes, limit: int, *, 
statement - ) -> None: + ) -> ResultSet: return self._execute_with_retries(statement, [snapshot_id, from_, limit]) ########################## diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -26,6 +26,8 @@ OriginVisit, OriginVisitStatus, Snapshot, + SnapshotBranch, + TargetType, Origin, MetadataAuthority, MetadataAuthorityType, @@ -34,7 +36,13 @@ RawExtrinsicMetadata, Sha1Git, ) -from swh.storage.interface import ListOrder, PagedResult, Sha1, VISIT_STATUSES +from swh.storage.interface import ( + ListOrder, + PagedResult, + PartialBranches, + Sha1, + VISIT_STATUSES, +) from swh.storage.objstorage import ObjStorage from swh.storage.writer import JournalWriter from swh.storage.utils import map_optional, now @@ -612,7 +620,17 @@ return self._cql_runner.snapshot_missing(snapshots) def snapshot_get(self, snapshot_id: Sha1Git) -> Optional[Dict[str, Any]]: - return self.snapshot_get_branches(snapshot_id) + d = self.snapshot_get_branches(snapshot_id) + if d is None: + return None + return { + "id": d["id"], + "branches": { + name: branch.to_dict() if branch else None + for (name, branch) in d["branches"].items() + }, + "next_branch": d["next_branch"], + } def snapshot_get_by_origin_visit( self, origin: str, visit: int @@ -643,7 +661,7 @@ branches_from: bytes = b"", branches_count: int = 1000, target_types: Optional[List[str]] = None, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[PartialBranches]: if self._cql_runner.snapshot_missing([snapshot_id]): # Makes sure we don't fetch branches for a snapshot that is # being added. 
@@ -682,18 +700,18 @@ else: last_branch = None - branches_d = { - branch.name: {"target": branch.target, "target_type": branch.target_type,} - if branch.target - else None - for branch in branches - } - - return { - "id": snapshot_id, - "branches": branches_d, - "next_branch": last_branch, - } + return PartialBranches( + id=snapshot_id, + branches={ + branch.name: None + if branch.target is None + else SnapshotBranch( + target=branch.target, target_type=TargetType(branch.target_type) + ) + for branch in branches + }, + next_branch=last_branch, + ) def snapshot_get_random(self) -> Sha1Git: return self._cql_runner.snapshot_get_random().id diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -55,7 +55,12 @@ Sha1Git, ) from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex -from swh.storage.interface import ListOrder, PagedResult, VISIT_STATUSES +from swh.storage.interface import ( + ListOrder, + PagedResult, + PartialBranches, + VISIT_STATUSES, +) from swh.storage.objstorage import ObjStorage from swh.storage.utils import now @@ -593,7 +598,17 @@ yield id def snapshot_get(self, snapshot_id: Sha1Git) -> Optional[Dict[str, Any]]: - return self.snapshot_get_branches(snapshot_id) + d = self.snapshot_get_branches(snapshot_id) + if d is None: + return None + return { + "id": d["id"], + "branches": { + name: branch.to_dict() if branch else None + for (name, branch) in d["branches"].items() + }, + "next_branch": d["next_branch"], + } def snapshot_get_by_origin_visit( self, origin: str, visit: int @@ -627,7 +642,7 @@ branches_from: bytes = b"", branches_count: int = 1000, target_types: Optional[List[str]] = None, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[PartialBranches]: snapshot = self._snapshots.get(snapshot_id) if snapshot is None: return None @@ -659,16 +674,9 @@ else: next_branch = sorted_branch_names[to_index] - branches = { - name: branch.to_dict() if branch else 
None - for (name, branch) in branches.items() - } - - return { - "id": snapshot_id, - "branches": branches, - "next_branch": next_branch, - } + return PartialBranches( + id=snapshot_id, branches=branches, next_branch=next_branch, + ) def snapshot_get_random(self) -> Sha1Git: return random.choice(list(self._snapshots)) diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -8,6 +8,7 @@ from enum import Enum from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar, Union +from typing_extensions import TypedDict from swh.core.api import remote_api_endpoint from swh.core.api.classes import PagedResult as CorePagedResult @@ -22,6 +23,7 @@ Release, Snapshot, SkippedContent, + SnapshotBranch, MetadataAuthority, MetadataAuthorityType, MetadataFetcher, @@ -39,6 +41,19 @@ DESC = "desc" +class PartialBranches(TypedDict): + """Type of the dictionary returned by snapshot_get_branches""" + + id: Sha1Git + """Identifier of the snapshot""" + branches: Dict[bytes, Optional[SnapshotBranch]] + """A dict of branches contained in the snapshot + whose keys are the branches' names""" + next_branch: Optional[bytes] + """The name of the first branch not returned or :const:`None` if + the snapshot has fewer than the requested number of branches.""" + + TResult = TypeVar("TResult") PagedResult = CorePagedResult[TResult, str] @@ -720,7 +735,7 @@ branches_from: bytes = b"", branches_count: int = 1000, target_types: Optional[List[str]] = None, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[PartialBranches]: """Get the content, possibly partial, of a snapshot with the given id The branches of the snapshot are iterated in the lexicographical diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -40,6 +40,8 @@ Sha1, Sha1Git, Snapshot, + SnapshotBranch, + TargetType, SHA1_SIZE, MetadataAuthority, MetadataAuthorityType, @@ -48,7 +50,12 @@ 
RawExtrinsicMetadata, ) from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex -from swh.storage.interface import ListOrder, PagedResult, VISIT_STATUSES +from swh.storage.interface import ( + ListOrder, + PagedResult, + PartialBranches, + VISIT_STATUSES, +) from swh.storage.objstorage import ObjStorage from swh.storage.utils import now @@ -736,7 +743,19 @@ def snapshot_get( self, snapshot_id: Sha1Git, db=None, cur=None ) -> Optional[Dict[str, Any]]: - return self.snapshot_get_branches(snapshot_id, db=db, cur=cur) + # Keep forwarding db/cur to the nested call, as the previous code did. + d = self.snapshot_get_branches(snapshot_id, db=db, cur=cur) + if d is None: + # No such snapshot (same handling as the cassandra/in-memory backends). + return None + return { + "id": d["id"], + "branches": { + name: branch.to_dict() if branch else None + for (name, branch) in d["branches"].items() + }, + "next_branch": d["next_branch"], + } @timed @db_transaction(statement_timeout=2000) @@ -767,13 +786,9 @@ target_types: Optional[List[str]] = None, db=None, cur=None, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[PartialBranches]: if snapshot_id == EMPTY_SNAPSHOT_ID: - return { - "id": snapshot_id, - "branches": {}, - "next_branch": None, - } + return PartialBranches(id=snapshot_id, branches={}, next_branch=None,) branches = {} next_branch = None @@ -787,24 +802,27 @@ cur=cur, ) ) - for branch in fetched_branches[:branches_count]: - branch = dict(zip(db.snapshot_get_cols, branch)) - del branch["snapshot_id"] - name = branch.pop("name") - if branch == {"target": None, "target_type": None}: + for row in fetched_branches[:branches_count]: + branch_d = dict(zip(db.snapshot_get_cols, row)) + del branch_d["snapshot_id"] + name = branch_d.pop("name") + if branch_d["target"] is None and branch_d["target_type"] is None: branch = None + else: + assert branch_d["target_type"] is not None + branch = SnapshotBranch( + target=branch_d["target"], + target_type=TargetType(branch_d["target_type"]), + ) branches[name] = branch if len(fetched_branches) > branches_count: - branch = dict(zip(db.snapshot_get_cols, fetched_branches[-1])) - next_branch = branch["name"] + 
next_branch = dict(zip(db.snapshot_get_cols, fetched_branches[-1]))["name"] if branches: - return { - "id": snapshot_id, - "branches": branches, - "next_branch": next_branch, - } + return PartialBranches( + id=snapshot_id, branches=branches, next_branch=next_branch, + ) return None diff --git a/swh/storage/tests/algos/test_snapshot.py b/swh/storage/tests/algos/test_snapshot.py --- a/swh/storage/tests/algos/test_snapshot.py +++ b/swh/storage/tests/algos/test_snapshot.py @@ -31,7 +31,7 @@ swh_storage.snapshot_add([snapshot]) returned_snapshot = snapshot_get_all_branches(swh_storage, snapshot.id) - assert snapshot.to_dict() == returned_snapshot + assert snapshot == returned_snapshot @given(branch_name=branch_names(), branch_target=branch_targets(only_objects=True)) @@ -45,7 +45,7 @@ swh_storage.snapshot_add([snapshot]) returned_snapshot = snapshot_get_all_branches(swh_storage, snapshot.id) - assert snapshot.to_dict() == returned_snapshot + assert snapshot == returned_snapshot def test_snapshot_get_latest_none(swh_storage, sample_data): diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -37,6 +37,7 @@ Release, Revision, Snapshot, + TargetType, ) from swh.model.hypothesis_strategies import objects from swh.storage import get_storage @@ -2618,7 +2619,7 @@ swh_storage.snapshot_add([complete_snapshot]) snp_id = complete_snapshot.id - branches = complete_snapshot.to_dict()["branches"] + branches = complete_snapshot.branches branch_names = list(sorted(branches)) # Test branch_from @@ -2686,7 +2687,7 @@ ) snp_id = complete_snapshot.id - branches = complete_snapshot.to_dict()["branches"] + branches = complete_snapshot.branches snapshot = swh_storage.snapshot_get_branches( snp_id, target_types=["release", "revision"] @@ -2697,7 +2698,7 @@ "branches": { name: tgt for name, tgt in branches.items() - if tgt and tgt["target_type"] in ["release", "revision"] + if tgt 
and tgt.target_type in [TargetType.RELEASE, TargetType.REVISION] }, "next_branch": None, } @@ -2711,7 +2712,7 @@ "branches": { name: tgt for name, tgt in branches.items() - if tgt and tgt["target_type"] == "alias" + if tgt and tgt.target_type == TargetType.ALIAS }, "next_branch": None, } @@ -2724,7 +2725,7 @@ swh_storage.snapshot_add([complete_snapshot]) snp_id = complete_snapshot.id - branches = complete_snapshot.to_dict()["branches"] + branches = complete_snapshot.branches branch_names = list(sorted(branches)) # Test branch_from