diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -452,11 +452,6 @@
             'SELECT %s' % ', '.join(self.origin_visit_get_cols),
             'FROM origin_visit']
 
-        if require_snapshot:
-            # Makes sure the snapshot is known
-            query_parts.append(
-                'INNER JOIN snapshot ON (origin_visit.snapshot=snapshot.id)')
-
         query_parts.append('WHERE origin = %s')
 
         if require_snapshot:
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -816,7 +816,11 @@
         visit = self.origin_visit_get_latest(
             origin, allowed_statuses=allowed_statuses, require_snapshot=True)
         if visit and visit['snapshot']:
-            return self.snapshot_get(visit['snapshot'])
+            snapshot = self.snapshot_get(visit['snapshot'])
+            if not snapshot:
+                raise ValueError(
+                    'last origin visit references an unknown snapshot')
+            return snapshot
 
     def snapshot_count_branches(self, snapshot_id, db=None, cur=None):
         """Count the number of branches in the snapshot with the given id
@@ -1356,8 +1360,7 @@
                       if visit['status'] in allowed_statuses]
         if require_snapshot:
             visits = [visit for visit in visits
-                      if visit['snapshot']
-                      and visit['snapshot'] in self._snapshots]
+                      if visit['snapshot']]
 
         return max(visits, key=lambda v: (v['date'], v['visit']), default=None)
 
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -1070,7 +1070,12 @@
             origin, allowed_statuses=allowed_statuses, require_snapshot=True,
             db=db, cur=cur)
         if origin_visit and origin_visit['snapshot']:
-            return self.snapshot_get(origin_visit['snapshot'], db=db, cur=cur)
+            snapshot = self.snapshot_get(
+                origin_visit['snapshot'], db=db, cur=cur)
+            if not snapshot:
+                raise ValueError(
+                    'last origin visit references an unknown snapshot')
+            return snapshot
 
     @db_transaction(statement_timeout=2000)
     def snapshot_count_branches(self, snapshot_id, db=None, cur=None):
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -2900,10 +2900,12 @@
         # Two visits, both with no snapshot: latest snapshot is None
         self.assertIsNone(self.storage.snapshot_get_latest(origin_id))
 
-        # Add unknown snapshot to visit1, latest snapshot = None
+        # Add unknown snapshot to visit1, check that the inconsistency is
+        # detected
         self.storage.origin_visit_update(
             origin_id, visit1_id, snapshot=self.complete_snapshot['id'])
-        self.assertIsNone(self.storage.snapshot_get_latest(origin_id))
+        with self.assertRaises(ValueError):
+            self.storage.snapshot_get_latest(origin_id)
 
         # Status filter: both visits are status=ongoing, so no snapshot
         # returned
@@ -2914,10 +2916,9 @@
 
         # Mark the first visit as completed and check status filter again
         self.storage.origin_visit_update(origin_id, visit1_id, status='full')
-        self.assertIsNone(
+        with self.assertRaises(ValueError):
             self.storage.snapshot_get_latest(origin_id,
-                                             allowed_statuses=['full']),
-        )
+                                             allowed_statuses=['full'])
 
         # Actually add the snapshot and check status filter again
         self.storage.snapshot_add([self.complete_snapshot])
@@ -2926,18 +2927,17 @@
             self.storage.snapshot_get_latest(origin_id)
         )
 
-        # Add unknown snapshot to visit2 and check that the old snapshot
-        # is still returned
+        # Add unknown snapshot to visit2 and check that the inconsistency
+        # is detected
         self.storage.origin_visit_update(
-            origin_id, visit2_id, snapshot=self.empty_snapshot['id'])
-        self.assertEqual(
-            self.complete_snapshot,
-            self.storage.snapshot_get_latest(origin_id))
+            origin_id, visit2_id, snapshot=self.snapshot['id'])
+        with self.assertRaises(ValueError):
+            self.storage.snapshot_get_latest(origin_id)
 
         # Actually add that snapshot and check that the new one is returned
-        self.storage.snapshot_add([self.empty_snapshot])
+        self.storage.snapshot_add([self.snapshot])
         self.assertEqual(
-            self.empty_snapshot,
+            self.snapshot,
             self.storage.snapshot_get_latest(origin_id)
         )
 