diff --git a/swh/loader/git/dumb.py b/swh/loader/git/dumb.py --- a/swh/loader/git/dumb.py +++ b/swh/loader/git/dumb.py @@ -16,6 +16,8 @@ from dulwich.pack import Pack, PackData, PackIndex, load_pack_index_file from urllib3.response import HTTPResponse +from swh.loader.git.utils import HexBytes + if TYPE_CHECKING: from .loader import RepoRepresentation @@ -128,7 +130,7 @@ buffer.seek(0) return buffer - def _get_refs(self) -> Dict[bytes, bytes]: + def _get_refs(self) -> Dict[bytes, HexBytes]: refs = {} refs_resp_bytes = self._http_get("info/refs") for ref_line in refs_resp_bytes.readlines(): @@ -136,7 +138,7 @@ - refs[ref_name] = ref_target + refs[ref_name] = HexBytes(ref_target) return refs def _get_head(self) -> Dict[bytes, bytes]: head_resp_bytes = self._http_get("HEAD") _, head_target = head_resp_bytes.readline().replace(b"\n", b"").split(b" ") return {b"HEAD": head_target} diff --git a/swh/loader/git/loader.py b/swh/loader/git/loader.py --- a/swh/loader/git/loader.py +++ b/swh/loader/git/loader.py @@ -35,6 +35,7 @@ from swh.storage.interface import StorageInterface from . import converters, dumb, utils +from .utils import HexBytes logger = logging.getLogger(__name__) @@ -53,7 +54,7 @@ else: self.base_snapshot = Snapshot(branches={}) - self.heads: Set[bytes] = set() + self.heads: Set[HexBytes] = set() def get_parents(self, commit: bytes) -> List[bytes]: """This method should return the list of known parents""" @@ -62,7 +63,7 @@ def graph_walker(self) -> ObjectStoreGraphWalker: return ObjectStoreGraphWalker(self.heads, self.get_parents) - def determine_wants(self, refs: Dict[bytes, bytes]) -> List[bytes]: + def determine_wants(self, refs: Dict[bytes, HexBytes]) -> List[HexBytes]: """Get the list of bytehex sha1s that the git loader should fetch.
This compares the remote refs sent by the server with the base snapshot @@ -73,7 +74,7 @@ return [] # Cache existing heads - local_heads: Set[bytes] = set() + local_heads: Set[HexBytes] = set() for branch_name, branch in self.base_snapshot.branches.items(): if not branch or branch.target_type == TargetType.ALIAS: continue @@ -82,7 +83,7 @@ self.heads = local_heads # Get the remote heads that we want to fetch - remote_heads: Set[bytes] = set() + remote_heads: Set[HexBytes] = set() for ref_name, ref_target in refs.items(): if utils.ignore_branch_name(ref_name): continue @@ -93,8 +94,8 @@ @dataclass class FetchPackReturn: - remote_refs: Dict[bytes, bytes] + remote_refs: Dict[bytes, HexBytes] symbolic_refs: Dict[bytes, bytes] pack_buffer: SpooledTemporaryFile pack_size: int @@ -136,8 +137,8 @@ self.pack_size_bytes = pack_size_bytes self.temp_file_cutoff = temp_file_cutoff # state initialized in fetch_data - self.remote_refs: Dict[bytes, bytes] = {} + self.remote_refs: Dict[bytes, HexBytes] = {} self.symbolic_refs: Dict[bytes, bytes] = {} self.ref_object_types: Dict[bytes, Optional[TargetType]] = {} def fetch_pack_from_origin( diff --git a/swh/loader/git/utils.py b/swh/loader/git/utils.py --- a/swh/loader/git/utils.py +++ b/swh/loader/git/utils.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017-2020 The Software Heritage developers +# Copyright (C) 2017-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: 
GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -10,11 +10,14 @@ import os import shutil import tempfile -from typing import Dict, Optional +from typing import Dict, Mapping, NewType, Optional from swh.core import tarball from swh.model.model import SnapshotBranch +# Hexadecimal representation of a git object id (e.g. a sha1), as bytes +HexBytes = NewType("HexBytes", bytes) + def init_git_repo_from_archive(project_name, archive_path, root_temp_dir="/tmp"): """Given a path to an archive containing a git repository. @@ -90,10 +93,12 @@ return False -def filter_refs(refs: Dict[bytes, bytes]) -> Dict[bytes, bytes]: +def filter_refs(refs: Mapping[bytes, bytes]) -> Dict[bytes, HexBytes]: """Filter the refs dictionary using the policy set in `ignore_branch_name`""" return { - name: target for name, target in refs.items() if not ignore_branch_name(name) + name: HexBytes(target) + for name, target in refs.items() + if not ignore_branch_name(name) }