diff --git a/requirements-swh.txt b/requirements-swh.txt
--- a/requirements-swh.txt
+++ b/requirements-swh.txt
@@ -2,4 +2,4 @@
 swh.loader.core >= 0.3.0
 swh.model >= 0.3.0
 swh.scheduler >= 0.0.39
-swh.storage >= 0.3.0
+swh.storage >= 0.5.0
diff --git a/swh/loader/git/from_disk.py b/swh/loader/git/from_disk.py
--- a/swh/loader/git/from_disk.py
+++ b/swh/loader/git/from_disk.py
@@ -40,9 +40,15 @@
         self.origin_url = url
         self.visit_date = visit_date
         self.directory = directory
+        self.last_visit = None
 
     def prepare_origin_visit(self, *args, **kwargs):
         self.origin = Origin(url=self.origin_url)
+        # Look up the latest visit that has a snapshot: a more recent visit without
+        # one must not hide the last usable snapshot.
+        self.last_visit = self.storage.origin_visit_get_latest(
+            self.origin.url, require_snapshot=True
+        )
 
     def prepare(self, *args, **kwargs):
         self.repo = dulwich.repo.Repo(self.directory)
@@ -139,13 +145,15 @@
 
     def fetch_data(self):
         """Fetch the data from the data source"""
-        previous_visit = self.storage.origin_visit_get_latest(
-            self.origin.url, require_snapshot=True
-        )
-        if previous_visit:
-            self.previous_snapshot_id = previous_visit["snapshot"]
-        else:
+        if self.last_visit is None:
             self.previous_snapshot_id = None
+        else:
+            visit_id = self.last_visit["visit"]
+            assert visit_id is not None
+            visit_status = self.storage.origin_visit_status_get_latest(
+                self.origin.url, visit_id, require_snapshot=True
+            )
+            self.previous_snapshot_id = visit_status.snapshot if visit_status else None
 
         type_to_ids = defaultdict(list)
         for oid in self.iter_objects():
diff --git a/swh/loader/git/loader.py b/swh/loader/git/loader.py
--- a/swh/loader/git/loader.py
+++ b/swh/loader/git/loader.py
@@ -134,6 +134,7 @@
         # state initialized in fetch_data
         self.remote_refs: Dict[bytes, bytes] = {}
         self.symbolic_refs: Dict[bytes, bytes] = {}
+        self.last_visit = None
 
     def fetch_pack_from_origin(
         self,
@@ -208,11 +209,22 @@
     def prepare_origin_visit(self, *args, **kwargs) -> None:
         self.visit_date = datetime.datetime.now(tz=datetime.timezone.utc)
         self.origin = Origin(url=self.origin_url)
+        # Look up the latest visit that has a snapshot: a more recent visit without
+        # one must not hide the last usable snapshot.
+        self.last_visit = self.storage.origin_visit_get_latest(
+            self.origin.url, require_snapshot=True
+        )
 
     def get_full_snapshot(self, origin_url) -> Optional[Snapshot]:
-        visit = self.storage.origin_visit_get_latest(origin_url, require_snapshot=True)
-        if visit and visit["snapshot"]:
-            snapshot = snapshot_get_all_branches(self.storage, visit["snapshot"])
+        if self.last_visit is None:
+            return None
+        visit_id = self.last_visit["visit"]
+        assert visit_id is not None
+        visit_status = self.storage.origin_visit_status_get_latest(
+            self.origin_url, visit_id, require_snapshot=True
+        )
+        if visit_status and visit_status.snapshot:
+            snapshot = snapshot_get_all_branches(self.storage, visit_status.snapshot)
         else:
             snapshot = None
         if snapshot is None:
diff --git a/swh/loader/git/tests/test_from_disk.py b/swh/loader/git/tests/test_from_disk.py
--- a/swh/loader/git/tests/test_from_disk.py
+++ b/swh/loader/git/tests/test_from_disk.py
@@ -8,7 +8,9 @@
 
 import dulwich.repo
 
-from swh.model.model import Snapshot, SnapshotBranch, TargetType
+from typing import Optional
+
+from swh.model.model import OriginVisitStatus, Snapshot, SnapshotBranch, TargetType
 from swh.model.hashutil import hash_to_bytes
 
 from swh.loader.core.tests import BaseLoaderTest
@@ -19,6 +21,50 @@
 from . import TEST_LOADER_CONFIG
 
 
+def assert_last_visit_ok(
+    storage,
+    url: str,
+    status: str,
+    type: Optional[str] = None,
+    snapshot: Optional[bytes] = None,
+) -> OriginVisitStatus:
+    """Check the latest visit of an origin and return its latest visit status.
+
+    Asserts that the latest visit exists (and has the given type, when provided),
+    then that its latest visit status has the expected status (and snapshot, when
+    provided).
+
+    Args:
+        storage: storage to query for the visit and the visit status
+        url: origin url
+        status: expected status of the latest visit status
+        type: expected visit type (not checked when None or empty)
+        snapshot: expected snapshot id (not checked when None or empty)
+
+    Raises:
+        AssertionError: if the visit or visit status is missing, or a check fails
+
+    Returns:
+        The latest visit status, for further checks in the calling test.
+
+    """
+    visit = storage.origin_visit_get_latest(url)
+    assert visit is not None, f"Visit should exist for origin {url}"
+    if type:
+        assert visit["type"] == type
+
+    visit_id = visit["visit"]
+    visit_status = storage.origin_visit_status_get_latest(url, visit_id)
+    assert (
+        visit_status is not None
+    ), f"Visit status should exist for origin {url}, visit {visit_id}"
+    assert visit_status.status == status
+    if snapshot:
+        assert visit_status.snapshot == snapshot
+
+    return visit_status
+
+
 class GitLoaderFromArchive(OrigGitLoaderFromArchive):
     def project_name_from_archive(self, archive_path):
         # We don't want the project name to be 'resources'.
@@ -189,9 +235,13 @@
         self.assertEqual(self.loader.load_status(), {"status": "eventful"})
         self.assertEqual(self.loader.visit_status(), "full")
 
-        visit = self.storage.origin_visit_get_latest(self.repo_url)
-        self.assertEqual(visit["snapshot"], hash_to_bytes(SNAPSHOT1["id"]))
-        self.assertEqual(visit["status"], "full")
+        assert_last_visit_ok(
+            self.storage,
+            self.repo_url,
+            status="full",
+            type="git",
+            snapshot=hash_to_bytes(SNAPSHOT1["id"]),
+        )
 
     def test_load_unchanged(self):
         """Checks loading a repository a second time does not add
@@ -199,17 +249,25 @@
         res = self.load()
         self.assertEqual(res["status"], "eventful")
 
-        visit = self.storage.origin_visit_get_latest(self.repo_url)
-        self.assertEqual(visit["snapshot"], hash_to_bytes(SNAPSHOT1["id"]))
-        self.assertEqual(visit["status"], "full")
+        assert_last_visit_ok(
+            self.storage,
+            self.repo_url,
+            status="full",
+            type="git",
+            snapshot=hash_to_bytes(SNAPSHOT1["id"]),
+        )
 
         res = self.load()
         self.assertEqual(res["status"], "uneventful")
         self.assertCountSnapshots(1)
 
-        visit = self.storage.origin_visit_get_latest(self.repo_url)
-        self.assertEqual(visit["snapshot"], hash_to_bytes(SNAPSHOT1["id"]))
-        self.assertEqual(visit["status"], "full")
+        assert_last_visit_ok(
+            self.storage,
+            self.repo_url,
+            status="full",
+            type="git",
+            snapshot=hash_to_bytes(SNAPSHOT1["id"]),
+        )
 
 
 class DirGitLoaderTest(BaseDirGitLoaderFromDiskTest, GitLoaderFromDiskTests):
@@ -252,11 +310,12 @@
         self.assertEqual(self.loader.load_status(), {"status": "eventful"})
         self.assertEqual(self.loader.visit_status(), "full")
 
-        visit = self.storage.origin_visit_get_latest(self.repo_url)
-        self.assertIsNotNone(visit["snapshot"])
-        self.assertEqual(visit["status"], "full")
+        visit_status = assert_last_visit_ok(
+            self.storage, self.repo_url, status="full", type="git"
+        )
+        self.assertIsNotNone(visit_status.snapshot)
 
-        snapshot_id = visit["snapshot"]
+        snapshot_id = visit_status.snapshot
         snapshot = self.storage.snapshot_get(snapshot_id)
         branches = snapshot["branches"]
         assert branches[b"HEAD"] == {
@@ -304,11 +363,12 @@
         self.assertEqual(self.loader.load_status(), {"status": "eventful"})
         self.assertEqual(self.loader.visit_status(), "full")
 
-        visit = self.storage.origin_visit_get_latest(self.repo_url)
-        self.assertIsNotNone(visit["snapshot"])
-        self.assertEqual(visit["status"], "full")
+        visit_status = assert_last_visit_ok(
+            self.storage, self.repo_url, status="full", type="git"
+        )
+        self.assertIsNotNone(visit_status.snapshot)
 
-        merge_snapshot_id = visit["snapshot"]
+        merge_snapshot_id = visit_status.snapshot
         assert merge_snapshot_id != snapshot_id
 
         merge_snapshot = self.storage.snapshot_get(merge_snapshot_id)
@@ -363,9 +423,13 @@
         assert self.loader.load_status() == {"status": "eventful"}
         assert self.loader.visit_status() == "full"
 
-        visit = self.storage.origin_visit_get_latest(self.repo_url)
-        assert visit["snapshot"] == expected_snapshot.id
-        assert visit["status"] == "full"
+        assert_last_visit_ok(
+            self.storage,
+            self.repo_url,
+            status="full",
+            type="git",
+            snapshot=expected_snapshot.id,
+        )
 
     def test_load_dangling_symref(self):
         with open(os.path.join(self.destination_path, ".git/HEAD"), "wb") as f:
@@ -380,10 +444,11 @@
         self.assertCountRevisions(7)
         self.assertCountSnapshots(1)
 
-        visit = self.storage.origin_visit_get_latest(self.repo_url)
-        snapshot_id = visit["snapshot"]
+        visit_status = assert_last_visit_ok(
+            self.storage, self.repo_url, status="full", type="git"
+        )
+        snapshot_id = visit_status.snapshot
         assert snapshot_id is not None
-        assert visit["status"] == "full"
 
         snapshot = self.storage.snapshot_get(snapshot_id)
         branches = snapshot["branches"]