diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -581,6 +581,16 @@
             return None
         return line_to_bytes(r[0])
 
+    def origin_visit_exists(self, origin_id, visit_id, cur=None):
+        """Check whether an origin visit with the given ids exists"""
+        cur = self._cursor(cur)
+
+        query = "SELECT 1 FROM origin_visit WHERE origin = %s AND visit = %s"
+
+        cur.execute(query, (origin_id, visit_id))
+
+        return bool(cur.fetchone())
+
     def origin_visit_get_latest_snapshot(self, origin_id,
                                          allowed_statuses=None,
                                          cur=None):
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -550,6 +550,11 @@
                 - **target** (:class:`bytes`): identifier of the target
                   (currently a ``sha1_git`` for all object kinds, or the name
                   of the target branch for aliases)
+
+        Raises:
+            ValueError: if the origin or visit id does not exist.
         """
+        if visit not in self._origin_visits:
+            raise ValueError('Origin %s has no visit %s' % (origin, visit))
         snapshot_id = snapshot['id']
         if snapshot_id not in self._snapshots:
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -763,6 +763,12 @@
                 - **target** (:class:`bytes`): identifier of the target
                   (currently a ``sha1_git`` for all object kinds, or the name
                   of the target branch for aliases)
+
+        Raises:
+            ValueError: if the origin or visit id does not exist.
         """
+        if not db.origin_visit_exists(origin, visit, cur):
+            raise ValueError('Origin %s has no visit %s' % (origin, visit))
+
         if not db.snapshot_exists(snapshot['id'], cur):
             db.mktemp_snapshot_branch(cur)
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -1369,6 +1369,13 @@
                                                      visit_id)
         self.assertEqual(origin_visit_info['snapshot'], self.snapshot['id'])
 
+    def test_snapshot_add_nonexistent_visit(self):
+        origin_id = self.storage.origin_add_one(self.origin)
+        visit_id = 54164461156
+
+        with self.assertRaises(ValueError):
+            self.storage.snapshot_add(origin_id, visit_id, self.snapshot)
+
     def test_snapshot_add_twice(self):
         origin_id = self.storage.origin_add_one(self.origin)
         origin_visit1 = self.storage.origin_visit_add(origin_id,