diff --git a/mypy.ini b/mypy.ini --- a/mypy.ini +++ b/mypy.ini @@ -16,3 +16,6 @@ [mypy-pytest.*] ignore_missing_imports = True + +[mypy-swh.loader.*] +ignore_missing_imports = True diff --git a/requirements-swh.txt b/requirements-swh.txt --- a/requirements-swh.txt +++ b/requirements-swh.txt @@ -1,5 +1,5 @@ swh.core >= 0.0.7 -swh.loader.core >= 0.3.0 +swh.loader.core >= 0.4.0 swh.model >= 0.3.0 swh.scheduler >= 0.0.39 -swh.storage >= 0.3.0 +swh.storage >= 0.6.0 diff --git a/swh/loader/git/from_disk.py b/swh/loader/git/from_disk.py --- a/swh/loader/git/from_disk.py +++ b/swh/loader/git/from_disk.py @@ -40,9 +40,14 @@ self.origin_url = url self.visit_date = visit_date self.directory = directory + self.last_visit = None def prepare_origin_visit(self, *args, **kwargs): self.origin = Origin(url=self.origin_url) + # Skip visits without a snapshot so they cannot hide an earlier one. + self.last_visit = self.storage.origin_visit_get_latest( + self.origin.url, require_snapshot=True + ) def prepare(self, *args, **kwargs): self.repo = dulwich.repo.Repo(self.directory) @@ -139,13 +144,15 @@ def fetch_data(self): """Fetch the data from the data source""" - previous_visit = self.storage.origin_visit_get_latest( - self.origin.url, require_snapshot=True - ) - if previous_visit: - self.previous_snapshot_id = previous_visit["snapshot"] - else: + if self.last_visit is None: self.previous_snapshot_id = None + else: + visit_id = self.last_visit["visit"] + assert visit_id is not None + visit_status = self.storage.origin_visit_status_get_latest( + self.origin.url, visit_id, require_snapshot=True + ) + self.previous_snapshot_id = visit_status.snapshot if visit_status else None type_to_ids = defaultdict(list) for oid in self.iter_objects(): diff --git a/swh/loader/git/loader.py b/swh/loader/git/loader.py --- a/swh/loader/git/loader.py +++ b/swh/loader/git/loader.py @@ -134,6 +134,7 @@ # state initialized in fetch_data self.remote_refs: Dict[bytes, bytes] = {} self.symbolic_refs: Dict[bytes, bytes] = {} + self.last_visit = None def fetch_pack_from_origin( self, @@ -208,11 +209,29 @@ def 
prepare_origin_visit(self, *args, **kwargs) -> None: self.visit_date = datetime.datetime.now(tz=datetime.timezone.utc) self.origin = Origin(url=self.origin_url) + # Skip visits without a snapshot so they cannot hide an earlier one. + self.last_visit = self.storage.origin_visit_get_latest( + self.origin.url, require_snapshot=True + ) def get_full_snapshot(self, origin_url) -> Optional[Snapshot]: - visit = self.storage.origin_visit_get_latest(origin_url, require_snapshot=True) - if visit and visit["snapshot"]: - snapshot = snapshot_get_all_branches(self.storage, visit["snapshot"]) + if origin_url == self.origin_url: + # Reuse the visit fetched for self.origin_url in prepare_origin_visit. + last_visit = self.last_visit + else: + # self.last_visit only describes self.origin_url: query this origin. + last_visit = self.storage.origin_visit_get_latest( + origin_url, require_snapshot=True + ) + if last_visit is None: + return None + visit_id = last_visit["visit"] + assert visit_id is not None + visit_status = self.storage.origin_visit_status_get_latest( + origin_url, visit_id, require_snapshot=True + ) + if visit_status and visit_status.snapshot: + snapshot = snapshot_get_all_branches(self.storage, visit_status.snapshot) else: snapshot = None if snapshot is None: diff --git a/swh/loader/git/tests/test_from_disk.py b/swh/loader/git/tests/test_from_disk.py --- a/swh/loader/git/tests/test_from_disk.py +++ b/swh/loader/git/tests/test_from_disk.py @@ -12,6 +12,7 @@ from swh.model.hashutil import hash_to_bytes from swh.loader.core.tests import BaseLoaderTest +from swh.loader.tests.common import assert_last_visit_matches from swh.loader.git.from_disk import GitLoaderFromDisk as OrigGitLoaderFromDisk from swh.loader.git.from_disk import GitLoaderFromArchive as OrigGitLoaderFromArchive @@ -189,9 +190,13 @@ self.assertEqual(self.loader.load_status(), {"status": "eventful"}) self.assertEqual(self.loader.visit_status(), "full") - visit = self.storage.origin_visit_get_latest(self.repo_url) - self.assertEqual(visit["snapshot"], hash_to_bytes(SNAPSHOT1["id"])) - 
self.assertEqual(visit["status"], "full") + assert_last_visit_matches( + self.storage, + self.repo_url, + status="full", + type="git", + snapshot=hash_to_bytes(SNAPSHOT1["id"]), + ) def test_load_unchanged(self): """Checks loading a repository a second time does not add @@ -199,17 +204,25 @@ res = self.load() 
self.assertEqual(res["status"], "eventful") - visit = self.storage.origin_visit_get_latest(self.repo_url) - self.assertEqual(visit["snapshot"], hash_to_bytes(SNAPSHOT1["id"])) - self.assertEqual(visit["status"], "full") + assert_last_visit_matches( + self.storage, + self.repo_url, + status="full", + type="git", + snapshot=hash_to_bytes(SNAPSHOT1["id"]), + ) res = self.load() self.assertEqual(res["status"], "uneventful") self.assertCountSnapshots(1) - visit = self.storage.origin_visit_get_latest(self.repo_url) - self.assertEqual(visit["snapshot"], hash_to_bytes(SNAPSHOT1["id"])) - self.assertEqual(visit["status"], "full") + assert_last_visit_matches( + self.storage, + self.repo_url, + status="full", + type="git", + snapshot=hash_to_bytes(SNAPSHOT1["id"]), + ) class DirGitLoaderTest(BaseDirGitLoaderFromDiskTest, GitLoaderFromDiskTests): @@ -252,11 +265,12 @@ self.assertEqual(self.loader.load_status(), {"status": "eventful"}) self.assertEqual(self.loader.visit_status(), "full") - visit = self.storage.origin_visit_get_latest(self.repo_url) - self.assertIsNotNone(visit["snapshot"]) - self.assertEqual(visit["status"], "full") + visit_status = assert_last_visit_matches( + self.storage, self.repo_url, status="full", type="git" + ) + self.assertIsNotNone(visit_status.snapshot) - snapshot_id = visit["snapshot"] + snapshot_id = visit_status.snapshot snapshot = self.storage.snapshot_get(snapshot_id) branches = snapshot["branches"] assert branches[b"HEAD"] == { @@ -304,11 +318,12 @@ self.assertEqual(self.loader.load_status(), {"status": "eventful"}) self.assertEqual(self.loader.visit_status(), "full") - visit = self.storage.origin_visit_get_latest(self.repo_url) - self.assertIsNotNone(visit["snapshot"]) - self.assertEqual(visit["status"], "full") + visit_status = assert_last_visit_matches( + self.storage, self.repo_url, status="full", type="git" + ) + self.assertIsNotNone(visit_status.snapshot) - merge_snapshot_id = visit["snapshot"] + merge_snapshot_id = visit_status.snapshot 
assert merge_snapshot_id != snapshot_id merge_snapshot = self.storage.snapshot_get(merge_snapshot_id) @@ -363,9 +378,13 @@ assert self.loader.load_status() == {"status": "eventful"} assert self.loader.visit_status() == "full" - visit = self.storage.origin_visit_get_latest(self.repo_url) - assert visit["snapshot"] == expected_snapshot.id - assert visit["status"] == "full" + assert_last_visit_matches( + self.storage, + self.repo_url, + status="full", + type="git", + snapshot=expected_snapshot.id, + ) def test_load_dangling_symref(self): with open(os.path.join(self.destination_path, ".git/HEAD"), "wb") as f: @@ -380,10 +399,11 @@ self.assertCountRevisions(7) self.assertCountSnapshots(1) - visit = self.storage.origin_visit_get_latest(self.repo_url) - snapshot_id = visit["snapshot"] + visit_status = assert_last_visit_matches( + self.storage, self.repo_url, status="full", type="git" + ) + snapshot_id = visit_status.snapshot assert snapshot_id is not None - assert visit["status"] == "full" snapshot = self.storage.snapshot_get(snapshot_id) branches = snapshot["branches"]