diff --git a/swh/loader/git/loader.py b/swh/loader/git/loader.py --- a/swh/loader/git/loader.py +++ b/swh/loader/git/loader.py @@ -11,7 +11,7 @@ import os import pickle import sys -from typing import Any, Dict, Iterable, List, Optional, Set +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type import dulwich.client from dulwich.object_store import ObjectStoreGraphWalker @@ -127,11 +127,11 @@ def __init__( self, - url, - base_url=None, - ignore_history=False, - repo_representation=RepoRepresentation, - config=None, + url: str, + base_url: Optional[str] = None, + ignore_history: bool = False, + repo_representation: Type[RepoRepresentation] = RepoRepresentation, + config: Optional[Dict[str, Any]] = None, ): """Initialize the bulk updater. @@ -141,6 +141,9 @@ data. """ + if config is None: + config = {} + super().__init__(logging_class="swh.loader.git.BulkLoader", config=config) self.origin_url = url self.base_url = base_url @@ -148,10 +151,15 @@ self.repo_representation = repo_representation # state initialized in fetch_data - self.remote_refs = [] - self.symbolic_refs = {} + self.remote_refs: Dict[bytes, bytes] = {} + self.symbolic_refs: Dict[bytes, bytes] = {} - def fetch_pack_from_origin(self, origin_url, base_snapshot, do_activity): + def fetch_pack_from_origin( + self, + origin_url: str, + base_snapshot: Optional[Snapshot], + do_activity: Callable[[bytes], None], + ) -> FetchPackReturn: """Fetch a pack from the origin""" pack_buffer = BytesIO() @@ -167,7 +175,7 @@ size_limit = self.config["pack_size_bytes"] - def do_pack(data): + def do_pack(data: bytes) -> None: cur_size = pack_buffer.tell() would_write = len(data) if cur_size + would_write > size_limit: @@ -201,7 +209,9 @@ pack_size=pack_size, ) - def list_pack(self, pack_data, pack_size): + def list_pack( + self, pack_data, pack_size + ) -> Tuple[Dict[bytes, bytes], Dict[bytes, Set[bytes]]]: id_to_type = {} type_to_ids = defaultdict(set) @@ -214,7 +224,7 @@ return id_to_type, type_to_ids - def prepare_origin_visit(self, *args, **kwargs): + def prepare_origin_visit(self, *args, **kwargs) -> None: self.visit_date = datetime.datetime.now(tz=datetime.timezone.utc) self.origin = Origin(url=self.origin_url) @@ -228,10 +238,12 @@ return None return Snapshot.from_dict(snapshot) - def prepare(self, *args, **kwargs): + def prepare(self, *args, **kwargs) -> None: + assert self.origin is not None + base_origin_url = origin_url = self.origin.url - prev_snapshot = None + prev_snapshot: Optional[Snapshot] = None if not self.ignore_history: prev_snapshot = self.get_full_snapshot(origin_url) @@ -248,8 +260,10 @@ else: self.base_snapshot = Snapshot(branches={}) - def fetch_data(self): - def do_progress(msg): + def fetch_data(self) -> bool: + assert self.origin is not None + + def do_progress(msg: bytes) -> None: sys.stderr.buffer.write(msg) sys.stderr.flush() @@ -263,13 +277,11 @@ self.remote_refs = fetch_info.remote_refs self.symbolic_refs = fetch_info.symbolic_refs - origin_url = self.origin.url - self.log.info( - "Listed %d refs for repo %s" % (len(self.remote_refs), origin_url), + "Listed %d refs for repo %s" % (len(self.remote_refs), self.origin.url), extra={ "swh_type": "git_repo_list_refs", - "swh_repo": origin_url, + "swh_repo": self.origin.url, "swh_num_refs": len(self.remote_refs), }, ) @@ -283,8 +295,9 @@ # No more data to fetch return False - def save_data(self): + def save_data(self) -> None: """Store a pack for archival""" + assert isinstance(self.visit_date, datetime.datetime) write_size = 8192 pack_dir = self.get_save_data_path() @@ -305,14 +318,14 @@ with open(os.path.join(pack_dir, refs_name), "xb") as f: pickle.dump(self.remote_refs, f) - def get_inflater(self): + def get_inflater(self) -> PackInflater: """Reset the pack buffer and get an object inflater from it""" self.pack_buffer.seek(0) return PackInflater.for_pack_data( PackData.from_file(self.pack_buffer, self.pack_size) ) - def has_contents(self): + def has_contents(self) -> bool: return bool(self.type_to_ids[b"blob"]) def get_content_ids(self) -> Iterable[Dict[str, Any]]: @@ -470,9 +483,9 @@ the one we retrieved at the beginning of the run""" eventful = False - if self.base_snapshot: + if self.base_snapshot and self.snapshot: eventful = self.snapshot.id != self.base_snapshot.id - else: + elif self.snapshot: eventful = bool(self.snapshot.branches) return {"status": ("eventful" if eventful else "uneventful")} @@ -493,7 +506,7 @@ help="Ignore the repository history", default=False, ) - def main(origin_url, base_url, ignore_history): + def main(origin_url: str, base_url: str, ignore_history: bool) -> Dict[str, Any]: loader = GitLoader( origin_url, base_url=base_url, ignore_history=ignore_history, )