diff --git a/swh/loader/git/loader.py b/swh/loader/git/loader.py
--- a/swh/loader/git/loader.py
+++ b/swh/loader/git/loader.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from collections import defaultdict
 from dataclasses import dataclass
 import datetime
 import logging
@@ -140,6 +141,7 @@
         self.remote_refs: Dict[bytes, HexBytes] = {}
         self.symbolic_refs: Dict[bytes, HexBytes] = {}
         self.ref_object_types: Dict[bytes, Optional[TargetType]] = {}
+        self.objects: Dict[bytes, Set[ShaFile]] = {}
 
     def fetch_pack_from_origin(
         self,
@@ -291,6 +293,9 @@
         self.pack_size = fetch_info.pack_size
         self.remote_refs = fetch_info.remote_refs
         self.symbolic_refs = fetch_info.symbolic_refs
+        # Inflate the pack file once and group its objects per type, so that
+        # iter_objects does not re-parse the whole pack for every object type
+        self.objects = self.group_objects_per_type()
 
         self.ref_object_types = {sha1: None for sha1 in self.remote_refs.values()}
 
@@ -331,18 +336,34 @@
         with open(os.path.join(pack_dir, refs_name), "xb") as f:
             pickle.dump(self.remote_refs, f)
 
+    def group_objects_per_type(self) -> Dict[bytes, Set[ShaFile]]:
+        """Inflate every object of the fetched packfile and group them by type.
+
+        Returns a mapping whose keys are the objects' ``type_name`` and whose
+        values are the sets of inflated objects (not their ids) of that type.
+        The pack is parsed a single time here, instead of once per object type
+        requested through :meth:`iter_objects`. Note that the pack buffer itself
+        is still referenced by ``self.pack_buffer`` afterwards.
+
+        """
+        # Rewind first: the buffer position is not guaranteed to be at the
+        # start, and the previous per-type reader also seeked to 0 here.
+        self.pack_buffer.seek(0)
+        objs: Dict[bytes, Set[ShaFile]] = defaultdict(set)
+        for obj in PackInflater.for_pack_data(
+            PackData.from_file(self.pack_buffer, self.pack_size)
+        ):
+            objs[obj.type_name].add(obj)
+
+        return objs
+
     def iter_objects(self, object_type: bytes) -> Iterator[ShaFile]:
-        """Read all the objects of type `object_type` from the packfile"""
+        """Yield all the objects of type `object_type`, either from the dumb
+        fetcher or from the per-type grouping built by
+        :meth:`group_objects_per_type` (an absent type yields nothing)."""
         if self.dumb:
             yield from self.dumb_fetcher.iter_objects(object_type)
         else:
-            self.pack_buffer.seek(0)
-            for obj in PackInflater.for_pack_data(
-                PackData.from_file(self.pack_buffer, self.pack_size)
-            ):
-                if obj.type_name != object_type:
-                    continue
-                yield obj
+            yield from self.objects.get(object_type, ())
 
     def get_contents(self) -> Iterable[BaseContent]:
         """Format the blobs from the git repository as swh contents"""