diff --git a/swh/loader/git/loader.py b/swh/loader/git/loader.py --- a/swh/loader/git/loader.py +++ b/swh/loader/git/loader.py @@ -10,7 +10,7 @@ import os import pickle import sys -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable, List, Optional, Union import dulwich.client from dulwich.object_store import ObjectStoreGraphWalker @@ -33,7 +33,7 @@ self.storage = storage self._parents_cache = {} - self._type_cache = {} + self._type_cache: Dict[bytes, TargetType] = {} self.ignore_history = ignore_history @@ -136,7 +136,27 @@ return list(ret) - def get_stored_objects(self, objects): + def _get_stored_objects_batch( + self, query + ) -> Dict[bytes, List[Dict[str, Union[bytes, TargetType]]]]: + results = self.storage.object_find_by_sha1_git( + self._encode_for_storage(query) + ) + ret: Dict[bytes, List[Dict[str, Union[bytes, TargetType]]]] = {} + for (id, objects) in results.items(): + assert id not in ret + ret[id] = [ + { + 'sha1_git': obj['sha1_git'], + 'type': TargetType(obj['type']), + } + for obj in objects + ] + return ret + + def get_stored_objects( + self, objects + ) -> Dict[bytes, List[Dict[str, Union[bytes, TargetType]]]]: """Find which of these objects were stored in the archive. Do the request in packets to avoid a server timeout. @@ -146,23 +166,15 @@ packet_size = 1000 - ret = {} + ret: Dict[bytes, List[Dict[str, Union[bytes, TargetType]]]] = {} query = [] for object in objects: query.append(object) if len(query) >= packet_size: - ret.update( - self.storage.object_find_by_sha1_git( - self._encode_for_storage(query) - ) - ) + ret.update(self._get_stored_objects_batch(query)) query = [] if query: - ret.update( - self.storage.object_find_by_sha1_git( - self._encode_for_storage(query) - ) - ) + ret.update(self._get_stored_objects_batch(query)) return ret def find_remote_ref_types_in_swh( @@ -176,12 +188,14 @@ """ all_objs = set(remote_refs.values()) - set(self._type_cache) - type_by_id = {} + type_by_id: Dict[bytes, TargetType] = {} for id, objs in self.get_stored_objects(all_objs).items(): id = hashutil.hash_to_bytehex(id) if objs: - type_by_id[id] = objs[0]['type'] + type_ = objs[0]['type'] + assert isinstance(type_, TargetType) + type_by_id[id] = type_ self._type_cache.update(type_by_id)