diff --git a/swh/vault/cookers/git_bare.py b/swh/vault/cookers/git_bare.py
--- a/swh/vault/cookers/git_bare.py
+++ b/swh/vault/cookers/git_bare.py
@@ -25,7 +25,7 @@
 import subprocess
 import tarfile
 import tempfile
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple
+from typing import Any, Dict, Iterable, Iterator, List, NoReturn, Optional, Set, Tuple
 import zlib
 
 from swh.core.api.classes import stream_results_optional
@@ -40,6 +40,8 @@
     Revision,
     RevisionType,
     Sha1Git,
+    Snapshot,
+    SnapshotBranch,
     TargetType,
     TimestampWithTimezone,
 )
@@ -63,9 +65,21 @@
     SNAPSHOT = "snapshot"
 
 
+def assert_never(value: NoReturn, msg: str) -> NoReturn:
+    """mypy makes sure this function is never called, through exhaustive checking
+    of ``value`` in the parent function. If reached anyway, raises AssertionError.
+
+    See https://mypy.readthedocs.io/en/latest/literal_types.html#exhaustive-checks
+    for details.
+    """
+    raise AssertionError(msg)
+
+
 class GitBareCooker(BaseVaultCooker):
     use_fsck = True
 
+    obj_type: RootObjectType
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.obj_type = RootObjectType(self.bundle_type.split("_")[0])
@@ -73,15 +87,15 @@
     def cache_type_key(self) -> str:
         return self.bundle_type
 
-    def check_exists(self):
-        if self.obj_type == RootObjectType.REVISION:
+    def check_exists(self) -> bool:
+        if self.obj_type is RootObjectType.REVISION:
             return not list(self.storage.revision_missing([self.obj_id]))
-        elif self.obj_type == RootObjectType.DIRECTORY:
+        elif self.obj_type is RootObjectType.DIRECTORY:
             return not list(self.storage.directory_missing([self.obj_id]))
-        elif self.obj_type == RootObjectType.SNAPSHOT:
+        elif self.obj_type is RootObjectType.SNAPSHOT:
             return not list(self.storage.snapshot_missing([self.obj_id]))
         else:
-            assert False, f"Unexpected root object type: {self.obj_type}"
+            assert_never(self.obj_type, f"Unexpected root object type: {self.obj_type}")
 
     def obj_swhid(self) -> identifiers.CoreSWHID:
         return identifiers.CoreSWHID(
@@ -226,7 +240,7 @@
                 for (branch_name, branch) in branches
             }
         else:
-            assert False, f"Unexpected root object type: {self.obj_type}"
+            assert_never(self.obj_type, f"Unexpected root object type: {self.obj_type}")
 
         for (ref_name, ref_target) in refs.items():
             path = os.path.join(self.gitdir.encode(), ref_name)
@@ -263,14 +277,14 @@
         return True
 
     def push_subgraph(self, obj_type: RootObjectType, obj_id) -> None:
-        if self.obj_type == RootObjectType.REVISION:
+        if self.obj_type is RootObjectType.REVISION:
             self.push_revision_subgraph(obj_id)
-        elif self.obj_type == RootObjectType.DIRECTORY:
+        elif self.obj_type is RootObjectType.DIRECTORY:
             self._push(self._dir_stack, [obj_id])
-        elif self.obj_type == RootObjectType.SNAPSHOT:
+        elif self.obj_type is RootObjectType.SNAPSHOT:
             self.push_snapshot_subgraph(obj_id)
         else:
-            assert False, f"Unexpected root object type: {self.obj_type}"
+            assert_never(self.obj_type, f"Unexpected root object type: {self.obj_type}")
 
     def load_objects(self) -> None:
         while self._rel_stack or self._rev_stack or self._dir_stack or self._cnt_stack:
@@ -352,25 +366,27 @@
                 object_type=identifiers.ObjectType.SNAPSHOT, object_id=obj_id,
             )
             try:
-                swhids = map(
+                swhids: Iterable[identifiers.CoreSWHID] = map(
                     identifiers.CoreSWHID.from_string,
                     self.graph.visit_nodes(str(obj_swhid), edges="snp:*,rel:*,rev:rev"),
                 )
                 for swhid in swhids:
-                    if swhid.object_type == identifiers.ObjectType.REVISION:
+                    if swhid.object_type is identifiers.ObjectType.REVISION:
                         revision_ids.append(swhid.object_id)
-                    elif swhid.object_type == identifiers.ObjectType.RELEASE:
+                    elif swhid.object_type is identifiers.ObjectType.RELEASE:
                         release_ids.append(swhid.object_id)
-                    elif swhid.object_type == identifiers.ObjectType.DIRECTORY:
+                    elif swhid.object_type is identifiers.ObjectType.DIRECTORY:
                         directory_ids.append(swhid.object_id)
-                    elif swhid.object_type == identifiers.ObjectType.CONTENT:
+                    elif swhid.object_type is identifiers.ObjectType.CONTENT:
                         content_ids.append(swhid.object_id)
-                    elif swhid.object_type == identifiers.ObjectType.SNAPSHOT:
+                    elif swhid.object_type is identifiers.ObjectType.SNAPSHOT:
                         assert (
                             swhid.object_id == obj_id
                         ), f"Snapshot {obj_id.hex()} references a different snapshot"
                     else:
-                        assert False, f"Unexpected SWHID object type: {swhid}"
+                        assert_never(
+                            swhid.object_type, f"Unexpected SWHID object type: {swhid}"
+                        )
             except GraphArgumentException as e:
                 logger.info(
                     "Snapshot %s not found in swh-graph, falling back to fetching "
@@ -387,26 +403,36 @@
 
         # TODO: when self.graph is available and supports edge labels, use it
         # directly to get branch names.
-        snapshot = snapshot_get_all_branches(self.storage, obj_id)
+        snapshot: Optional[Snapshot] = snapshot_get_all_branches(self.storage, obj_id)
         assert snapshot, "Unknown snapshot"  # should have been caught by check_exists()
         for branch in snapshot.branches.values():
             if not loaded_from_graph:
                 if branch is None:
                     logging.warning("Dangling branch: %r", branch)
-                elif branch.target_type == TargetType.REVISION:
+                    continue
+                assert isinstance(branch, SnapshotBranch)  # for mypy
+                if branch.target_type is TargetType.REVISION:
                     self.push_revision_subgraph(branch.target)
-                elif branch.target_type == TargetType.RELEASE:
+                elif branch.target_type is TargetType.RELEASE:
                     self.push_releases_subgraphs([branch.target])
-                elif branch.target_type == TargetType.ALIAS:
+                elif branch.target_type is TargetType.ALIAS:
                     # Nothing to do, this for loop also iterates on the target branch
                     # (if it exists)
                     pass
-                elif branch.target_type == TargetType.DIRECTORY:
+                elif branch.target_type is TargetType.DIRECTORY:
                     self._push(self._dir_stack, [branch.target])
-                elif branch.target_type == TargetType.CONTENT:
+                elif branch.target_type is TargetType.CONTENT:
                     self._push(self._cnt_stack, [branch.target])
+                elif branch.target_type is TargetType.SNAPSHOT:
+                    if branch.target != obj_id:
+                        raise NotImplementedError(
+                            f"Snapshot {obj_id.hex()} has a snapshot as a branch."
+                        )
                 else:
-                    raise NotImplementedError(f"{branch.target_type} branches")
+                    assert_never(
+                        branch.target_type,
+                        f"Unexpected branch target type: {branch.target_type}",
+                    )
 
         self.write_refs(snapshot=snapshot)
 
@@ -445,15 +471,28 @@
         """Given a list of release ids, loads these releases and adds their
         target to the list of objects to visit"""
         for release in self.load_releases(obj_ids):
-            if release.target_type == ObjectType.REVISION:
-                assert release.target, "{release.swhid(}) has no target"
+            if release.target_type is ObjectType.REVISION:
+                assert release.target, f"{release.swhid()} has no target"
                 self.push_revision_subgraph(release.target)
-            elif release.target_type == ObjectType.DIRECTORY:
-                assert release.target, "{release.swhid(}) has no target"
+            elif release.target_type is ObjectType.DIRECTORY:
+                assert release.target, f"{release.swhid()} has no target"
                 self._push(self._dir_stack, [release.target])
-            else:
+            elif release.target_type is ObjectType.CONTENT:
                 raise NotImplementedError(
-                    f"{release.swhid()} targets {release.target_type}"
+                    f"{release.swhid()} targets a content: {release.target!r}"
+                )
+            elif release.target_type is ObjectType.RELEASE:
+                raise NotImplementedError(
+                    f"{release.swhid()} targets another release: {release.target!r}"
+                )
+            elif release.target_type is ObjectType.SNAPSHOT:
+                raise NotImplementedError(
+                    f"{release.swhid()} targets a snapshot: {release.target!r}"
+                )
+            else:
+                assert_never(
+                    release.target_type,
+                    f"Unexpected release target type: {release.target_type}",
                 )
 
     def write_release_node(self, release: Dict[str, Any]) -> bool:
@@ -507,11 +546,12 @@
             elif content.status == "hidden":
                 self.write_content(obj_id, HIDDEN_MESSAGE)
                 self._expect_mismatched_object_error(obj_id)
+            elif content.status == "absent":
+                assert False, f"content_get returned absent content {content.swhid()}"
             else:
-                assert False, (
-                    f"unexpected status {content.status!r} "
-                    f"for content {hash_to_hex(content.sha1_git)}"
-                )
+                # TODO: When content.status will have type Literal, replace this with
+                # assert_never
+                assert False, f"{content.swhid()} has status: {content.status!r}"
 
         contents_and_data: Iterator[Tuple[Content, Optional[bytes]]]
         if self.objstorage is None: