diff --git a/swh/loader/git/loader.py b/swh/loader/git/loader.py
--- a/swh/loader/git/loader.py
+++ b/swh/loader/git/loader.py
@@ -40,11 +40,24 @@
 logger = logging.getLogger(__name__)
 
 
+DEFAULT_NUMBER_IDS_TO_FETCH = 2000
+
+
+def do_print_progress(msg: bytes) -> None:
+    """Write a raw progress message to stderr and flush it."""
+    sys.stderr.buffer.write(msg)
+    sys.stderr.flush()
+
+
 class RepoRepresentation:
     """Repository representation for a Software Heritage origin."""
 
     def __init__(
-        self, storage, base_snapshot: Optional[Snapshot] = None, ignore_history=False
+        self,
+        storage,
+        base_snapshot: Optional[Snapshot] = None,
+        ignore_history=False,
+        limit: int = DEFAULT_NUMBER_IDS_TO_FETCH,
     ):
         self.storage = storage
         self.ignore_history = ignore_history
@@ -55,15 +68,26 @@
             self.base_snapshot = Snapshot(branches={})
 
         self.heads: Set[HexBytes] = set()
+        self.wanted_refs: Optional[List[HexBytes]] = None
+        # End (exclusive) of the last batch of wanted_refs handed out
+        self.index: int = 0
+        self.limit = limit
+        self.walker = ObjectStoreGraphWalker(self.heads, self.get_parents)
 
     def get_parents(self, commit: bytes) -> List[bytes]:
         """This method should return the list of known parents"""
         return []
 
     def graph_walker(self) -> ObjectStoreGraphWalker:
-        return ObjectStoreGraphWalker(self.heads, self.get_parents)
+        return self.walker
 
-    def determine_wants(self, refs: Dict[bytes, HexBytes]) -> List[HexBytes]:
+    def wanted_refs_fetched(self) -> bool:
+        """Return True once determine_wants has handed out every wanted ref."""
+        return self.wanted_refs is not None and self.index >= len(self.wanted_refs)
+
+    def determine_wants(
+        self, refs: Dict[bytes, HexBytes], depth=None
+    ) -> List[HexBytes]:
         """Get the list of bytehex sha1s that the git loader should fetch.
 
         This compares the remote refs sent by the server with the base snapshot
@@ -73,27 +97,44 @@
         if not refs:
+            self.wanted_refs = []  # lets wanted_refs_fetched() report completion
             return []
 
-        # Cache existing heads
-        local_heads: Set[HexBytes] = set()
-        for branch_name, branch in self.base_snapshot.branches.items():
-            if not branch or branch.target_type == TargetType.ALIAS:
-                continue
-            local_heads.add(hashutil.hash_to_hex(branch.target).encode())
+        if self.wanted_refs is None:
+            # Compute the full list of wanted refs once; it is then handed out
+            # in batches of ``self.limit`` refs per call.
 
-        self.heads = local_heads
+            # Cache existing heads
+            local_heads: Set[HexBytes] = set()
+            for branch_name, branch in self.base_snapshot.branches.items():
+                if not branch or branch.target_type == TargetType.ALIAS:
+                    continue
+                local_heads.add(hashutil.hash_to_hex(branch.target).encode())
 
-        # Get the remote heads that we want to fetch
-        remote_heads: Set[HexBytes] = set()
-        for ref_name, ref_target in refs.items():
-            if utils.ignore_branch_name(ref_name):
-                continue
-            remote_heads.add(ref_target)
+            self.heads = local_heads
+
+            # Get the remote heads that we want to fetch
+            remote_heads: Set[HexBytes] = set()
+            for ref_name, ref_target in refs.items():
+                if utils.ignore_branch_name(ref_name):
+                    continue
+                remote_heads.add(ref_target)
+
+            logger.debug("local_heads_count=%s", len(local_heads))
+            logger.debug("remote_heads_count=%s", len(remote_heads))
+            wanted_refs = list(remote_heads - local_heads)
+            logger.debug("wanted_refs_count=%s", len(wanted_refs))
+            self.wanted_refs = wanted_refs
+
+        # Hand out the next batch: wanted_refs[start:self.index]
+        start = self.index
+        self.index += self.limit
 
-        logger.debug("local_heads_count=%s", len(local_heads))
-        logger.debug("remote_heads_count=%s", len(remote_heads))
-        wanted_refs = list(remote_heads - local_heads)
-        logger.debug("wanted_refs_count=%s", len(wanted_refs))
-        return wanted_refs
+        assert self.wanted_refs is not None
+        asked_refs = self.wanted_refs[start : self.index]
+        if start > 0:
+            # Register the previous batch as walker heads so it is not resolved again
+            self.walker.heads.update(self.wanted_refs[start - self.limit : start])
+        logger.debug("asked_refs_count=%s", len(asked_refs))
+        return asked_refs
 
 
 @dataclass
@@ -102,6 +143,7 @@
     symbolic_refs: Dict[bytes, HexBytes]
     pack_buffer: SpooledTemporaryFile
     pack_size: int
+    continue_loading: bool
 
 
 class GitLoader(DVCSLoader):
@@ -120,6 +162,8 @@
         temp_file_cutoff: int = 100 * 1024 * 1024,
         save_data_path: Optional[str] = None,
         max_content_size: Optional[int] = None,
+        # Number of ids per packfile
+        packfile_chunk_size: int = DEFAULT_NUMBER_IDS_TO_FETCH,
     ):
         """Initialize the bulk updater.
 
@@ -144,6 +188,7 @@
         self.remote_refs: Dict[bytes, HexBytes] = {}
         self.symbolic_refs: Dict[bytes, HexBytes] = {}
         self.ref_object_types: Dict[bytes, Optional[TargetType]] = {}
+        self.packfile_chunk_size = packfile_chunk_size
 
     def fetch_pack_from_origin(
         self,
@@ -221,6 +266,7 @@
             symbolic_refs=utils.filter_refs(symbolic_refs),
             pack_buffer=pack_buffer,
             pack_size=pack_size,
+            continue_loading=not self.base_repo.wanted_refs_fetched(),
         )
 
     def prepare_origin_visit(self) -> None:
@@ -248,23 +294,22 @@
         else:
             self.base_snapshot = Snapshot(branches={})
 
-    def fetch_data(self) -> bool:
-        assert self.origin is not None
-
-        base_repo = self.repo_representation(
+        self.base_repo = self.repo_representation(
             storage=self.storage,
             base_snapshot=self.base_snapshot,
             ignore_history=self.ignore_history,
+            limit=self.packfile_chunk_size,
         )
 
-        def do_progress(msg: bytes) -> None:
-            sys.stderr.buffer.write(msg)
-            sys.stderr.flush()
+    def fetch_data(self) -> bool:
+        continue_loading = False
+        assert self.origin is not None
 
         try:
             fetch_info = self.fetch_pack_from_origin(
-                self.origin.url, base_repo, do_progress
+                self.origin.url, self.base_repo, do_print_progress
             )
+            continue_loading = fetch_info.continue_loading
         except NotGitRepository as e:
             raise NotFound(e)
         except GitProtocolError as e:
@@ -292,17 +337,19 @@
             "Protocol used for communication: %s", "dumb" if self.dumb else "smart"
         )
         if self.dumb:
-            self.dumb_fetcher = dumb.GitObjectsFetcher(self.origin_url, base_repo)
+            self.dumb_fetcher = dumb.GitObjectsFetcher(self.origin_url, self.base_repo)
             self.dumb_fetcher.fetch_object_ids()
             self.remote_refs = utils.filter_refs(self.dumb_fetcher.refs)  # type: ignore
             self.symbolic_refs = self.dumb_fetcher.head
         else:
             self.pack_buffer = fetch_info.pack_buffer
             self.pack_size = fetch_info.pack_size
-            self.remote_refs = fetch_info.remote_refs
-            self.symbolic_refs = fetch_info.symbolic_refs
+            self.remote_refs.update(fetch_info.remote_refs)
+            self.symbolic_refs.update(fetch_info.symbolic_refs)
 
-        self.ref_object_types = {sha1: None for sha1 in self.remote_refs.values()}
+        # Keep object types already recorded in previous fetch rounds
+        for sha1 in self.remote_refs.values():
+            self.ref_object_types.setdefault(sha1, None)
 
         logger.info(
             "Listed %d refs for repo %s",
@@ -315,8 +362,7 @@
             },
         )
 
-        # No more data to fetch
-        return False
+        return continue_loading
 
     def save_data(self) -> None:
         """Store a pack for archival"""
@@ -341,6 +387,12 @@
         with open(os.path.join(pack_dir, refs_name), "xb") as f:
             pickle.dump(self.remote_refs, f)
 
+    def store_data(self, create_snapshot: bool = False):
+        """Call the parent store_data, then close the smart-protocol pack buffer."""
+        super().store_data(create_snapshot)
+        if not self.dumb:
+            self.pack_buffer.close()
+
     def iter_objects(self, object_type: bytes) -> Iterator[ShaFile]:
         """Read all the objects of type `object_type` from the packfile"""
         if self.dumb: