diff --git a/swh/loader/git/dumb.py b/swh/loader/git/dumb.py --- a/swh/loader/git/dumb.py +++ b/swh/loader/git/dumb.py @@ -102,7 +102,7 @@ # commit not already seen in the current load parent not in self.objects[b"commit"] # commit not already archived by a previous load - and parent not in self.base_repo.heads + and parent not in self.base_repo.local_heads ): commit_objects.append(cast(Commit, self._get_git_object(parent))) self.objects[b"commit"].add(parent) diff --git a/swh/loader/git/loader.py b/swh/loader/git/loader.py --- a/swh/loader/git/loader.py +++ b/swh/loader/git/loader.py @@ -60,14 +60,18 @@ else: self.base_snapshots = [] - self.heads: Set[HexBytes] = set() - - def get_parents(self, commit: bytes) -> List[bytes]: - """This method should return the list of known parents""" - return [] + # Cache existing heads + self.local_heads: Set[HexBytes] = set() + heads_logger.debug("Heads known in the archive:") + for base_snapshot in self.base_snapshots: + for branch_name, branch in base_snapshot.branches.items(): + if not branch or branch.target_type == TargetType.ALIAS: + continue + heads_logger.debug(" %r: %s", branch_name, branch.target.hex()) + self.local_heads.add(HexBytes(hashutil.hash_to_bytehex(branch.target))) def graph_walker(self) -> ObjectStoreGraphWalker: - return ObjectStoreGraphWalker(self.heads, self.get_parents) + return ObjectStoreGraphWalker(self.local_heads, get_parents=lambda commit: []) def determine_wants(self, refs: Dict[bytes, HexBytes]) -> List[HexBytes]: """Get the list of bytehex sha1s that the git loader should fetch. @@ -84,18 +88,6 @@ for name, value in refs.items(): heads_logger.debug(" %r: %s", name, value.decode()) - heads_logger.debug("Heads known in the archive:") - # Cache existing heads - local_heads: Set[HexBytes] = set() - for base_snapshot in self.base_snapshots: - for branch_name, branch in base_snapshot.branches.items(): - if not branch or branch.target_type == TargetType.ALIAS: - continue - heads_logger.debug(" %r: %s", branch_name, branch.target.hex()) - local_heads.add(HexBytes(hashutil.hash_to_bytehex(branch.target))) - - self.heads = local_heads - # Get the remote heads that we want to fetch remote_heads: Set[HexBytes] = set() for ref_name, ref_target in refs.items(): @@ -103,9 +95,9 @@ continue remote_heads.add(ref_target) - logger.debug("local_heads_count=%s", len(local_heads)) + logger.debug("local_heads_count=%s", len(self.local_heads)) logger.debug("remote_heads_count=%s", len(remote_heads)) - wanted_refs = list(remote_heads - local_heads) + wanted_refs = list(remote_heads - self.local_heads) logger.debug("wanted_refs_count=%s", len(wanted_refs)) if self.statsd is not None: self.statsd.histogram( @@ -115,7 +107,7 @@ ) self.statsd.histogram( "git_known_refs_percent", - len(local_heads & remote_heads) / len(remote_heads), + len(self.local_heads & remote_heads) / len(remote_heads), tags={}, ) return wanted_refs diff --git a/swh/loader/git/tests/test_loader.py b/swh/loader/git/tests/test_loader.py --- a/swh/loader/git/tests/test_loader.py +++ b/swh/loader/git/tests/test_loader.py @@ -26,7 +26,14 @@ get_stats, prepare_repository_from_archive, ) -from swh.model.model import Origin, OriginVisit, OriginVisitStatus, Snapshot +from swh.model.model import ( + Origin, + OriginVisit, + OriginVisitStatus, + Snapshot, + SnapshotBranch, + TargetType, +) class CommonGitLoaderNotFound: @@ -557,6 +564,56 @@ call("git_known_refs_percent", "h", expected_git_known_refs_percent, {}, 1), ] + def test_load_incremental_negotiation(self): + """Check that the packfile negotiated when running an incremental load only + contains the "new" commits, and not all objects.""" + + snapshot_id = b"\x01" * 20 + now = datetime.datetime.now(tz=datetime.timezone.utc) + + def ovgl(origin_url, allowed_statuses, require_snapshot, type): + if origin_url == f"base://{self.repo_url}": + return OriginVisit(origin=origin_url, visit=42, date=now, type="git") + else: + return None + + self.loader.storage.origin_visit_get_latest.side_effect = ovgl + self.loader.storage.origin_visit_status_get_latest.return_value = ( + OriginVisitStatus( + origin=f"base://{self.repo_url}", + visit=42, + snapshot=snapshot_id, + date=now, + status="full", + ) + ) + self.loader.storage.snapshot_get_branches.return_value = { + "id": snapshot_id, + "branches": { + b"refs/heads/master": SnapshotBranch( + # id of the initial commit in the git repository fixture + target=bytes.fromhex("b6f40292c4e94a8f7e7b4aff50e6c7429ab98e2a"), + target_type=TargetType.REVISION, + ), + }, + "next_branch": None, + } + + res = self.loader.load() + assert res == {"status": "eventful"} + + stats = get_stats(self.loader.storage) + assert stats == { + "content": 3, # instead of 4 for the full repository + "directory": 6, # instead of 7 + "origin": 1, + "origin_visit": 1, + "release": 0, + "revision": 6, # instead of 7 + "skipped_content": 0, + "snapshot": 1, + } + class DumbGitLoaderTestBase(FullGitLoaderTests): """Prepare a git repository to be loaded using the HTTP dumb transfer protocol."""