diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -796,7 +796,7 @@ should be used instead. Args: - origin (int): the origin's identifier + origin (Union[str,int]): the origin's URL or identifier allowed_statuses (list of str): list of visit statuses considered to find the latest snapshot for the visit. For instance, ``allowed_statuses=['full']`` will only consider visits that @@ -810,6 +810,12 @@ or :const:`None` if the snapshot has less than 1000 branches. """ + if isinstance(origin, str): + origin = self.origin_get({'url': origin}) + if not origin: + # unknown origin URL: there is no snapshot to return + return None + origin = origin['id'] visits = self._origin_visits[origin-1] if allowed_statuses is not None: visits = [visit for visit in visits diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -1049,7 +1049,7 @@ should be used instead. Args: - origin (int): the origin identifier + origin (Union[str,int]): the origin's URL or identifier allowed_statuses (list of str): list of visit statuses considered to find the latest snapshot for the visit. For instance, ``allowed_statuses=['full']`` will only consider visits that @@ -1063,6 +1063,13 @@ or :const:`None` if the snapshot has less than 1000 branches.
""" + if isinstance(origin, str): + origin = self.origin_get({'url': origin}, db=db, cur=cur) + if not origin: + # unknown origin URL: there is no snapshot to return + return None + origin = origin['id'] + + origin_visit = db.origin_visit_get_latest_snapshot( origin, allowed_statuses=allowed_statuses, cur=cur) if origin_visit: diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -2456,6 +2456,72 @@ self.assertEqual(self.complete_snapshot, self.storage.snapshot_get_latest(origin_id)) + def test_snapshot_get_latest_from_url(self): + self.storage.origin_add_one(self.origin) + origin_url = self.origin['url'] + origin_visit1 = self.storage.origin_visit_add(origin_url, + self.date_visit1) + visit1_id = origin_visit1['visit'] + origin_visit2 = self.storage.origin_visit_add(origin_url, + self.date_visit2) + visit2_id = origin_visit2['visit'] + + # Add a visit with the same date as the previous one + origin_visit3 = self.storage.origin_visit_add(origin_url, + self.date_visit2) + visit3_id = origin_visit3['visit'] + + # Three visits, none with a snapshot: latest snapshot is None + self.assertIsNone(self.storage.snapshot_get_latest(origin_url)) + + # Add snapshot to visit1, latest snapshot = visit 1 snapshot + self.storage.snapshot_add([self.complete_snapshot]) + self.storage.origin_visit_update( + origin_url, visit1_id, snapshot=self.complete_snapshot['id']) + self.assertEqual(self.complete_snapshot, + self.storage.snapshot_get_latest(origin_url)) + + # Status filter: all visits are status=ongoing, so no snapshot + # returned + self.assertIsNone( + self.storage.snapshot_get_latest(origin_url, + allowed_statuses=['full']) + ) + + # Mark the first visit as completed and check status filter again + self.storage.origin_visit_update(origin_url, visit1_id, status='full') + self.assertEqual( + self.complete_snapshot, + self.storage.snapshot_get_latest(origin_url, + allowed_statuses=['full']), + ) + + # Add snapshot to visit2 and check that the new snapshot is returned +
self.storage.snapshot_add([self.empty_snapshot]) + self.storage.origin_visit_update( + origin_url, visit2_id, snapshot=self.empty_snapshot['id']) + self.assertEqual(self.empty_snapshot, + self.storage.snapshot_get_latest(origin_url)) + + # Check that the status filter is still working + self.assertEqual( + self.complete_snapshot, + self.storage.snapshot_get_latest(origin_url, + allowed_statuses=['full']), + ) + + # Add snapshot to visit3 (same date as visit2) and check that + # the new snapshot is returned + self.storage.snapshot_add([self.complete_snapshot]) + self.storage.origin_visit_update( + origin_url, visit3_id, snapshot=self.complete_snapshot['id']) + self.assertEqual(self.complete_snapshot, + self.storage.snapshot_get_latest(origin_url)) + + # An unknown origin URL has no latest snapshot + self.assertIsNone( + self.storage.snapshot_get_latest('https://unknown.example.org/')) + def test_snapshot_get_latest__missing_snapshot(self): origin_id = self.storage.origin_add_one(self.origin) origin_visit1 = self.storage.origin_visit_add(origin_id,