diff --git a/sql/upgrades/134.sql b/sql/upgrades/134.sql new file mode 100644 --- /dev/null +++ b/sql/upgrades/134.sql @@ -0,0 +1,46 @@ +-- SWH DB schema upgrade +-- from_version: 133 +-- to_version: 134 +-- description: Make swh_snapshot_add delete the temporary table at the end. + +insert into dbversion(version, release, description) + values(134, now(), 'Work In Progress'); -- version recorded must equal to_version above + +create or replace function swh_snapshot_add(snapshot_id snapshot.id%type) + returns void + language plpgsql +as $$ +declare + snapshot_object_id snapshot.object_id%type; +begin + select object_id from snapshot where id = snapshot_id into snapshot_object_id; + if snapshot_object_id is null then + insert into snapshot (id) values (snapshot_id) returning object_id into snapshot_object_id; + insert into snapshot_branch (name, target_type, target) + select name, target_type, target from tmp_snapshot_branch tmp + where not exists ( + select 1 + from snapshot_branch sb + where sb.name = tmp.name + and sb.target = tmp.target + and sb.target_type = tmp.target_type + ) + on conflict do nothing; + insert into snapshot_branches (snapshot_id, branch_id) + select snapshot_object_id, sb.object_id as branch_id + from tmp_snapshot_branch tmp + join snapshot_branch sb + using (name, target, target_type) + where tmp.target is not null and tmp.target_type is not null + union + select snapshot_object_id, sb.object_id as branch_id + from tmp_snapshot_branch tmp + join snapshot_branch sb + using (name) + where tmp.target is null and tmp.target_type is null + and sb.target is null and sb.target_type is null; + end if; + drop table if exists tmp_snapshot_branch; +end; +$$; + diff --git a/swh/storage/api/client.py b/swh/storage/api/client.py --- a/swh/storage/api/client.py +++ b/swh/storage/api/client.py @@ -84,23 +84,23 @@ def object_find_by_sha1_git(self, ids): return self.post('object/find_by_sha1_git', {'ids': ids}) - def snapshot_add(self, snapshot, origin=None, visit=None): + def snapshot_add(self, snapshots, 
origin=None, visit=None): if origin: assert visit - (origin, visit, snapshot) = (snapshot, origin, visit) + (origin, visit, snapshots) = (snapshots, origin, visit) # legacy: 'snapshots' holds a single snapshot dict here warnings.warn("arguments 'origin' and 'visit' of snapshot_add " "are deprecated since v0.0.131, please use " - "snapshot_add(snapshot) + " + "snapshot_add([snapshot]) + " "origin_visit_update(origin, visit, " "snapshot=snapshot['id']) instead.", DeprecationWarning) return self.post('snapshot/add', { - 'origin': origin, 'visit': visit, 'snapshot': snapshot, + 'origin': origin, 'visit': visit, 'snapshots': snapshots, }) else: assert not visit return self.post('snapshot/add', { - 'snapshot': snapshot, + 'snapshots': snapshots, }) def snapshot_get(self, snapshot_id): diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -556,11 +556,11 @@ for rel_id in releases: yield copy.deepcopy(self._releases.get(rel_id)) - def snapshot_add(self, snapshot, legacy_arg1=None, legacy_arg2=None): - """Add a snapshot for the given origin/visit couple + def snapshot_add(self, snapshots, legacy_arg1=None, legacy_arg2=None): + """Add snapshots to the storage. Args: - snapshot (dict): the snapshot to add to the visit, containing the + snapshots ([dict]): the snapshots to add, each containing the following keys: - **id** (:class:`bytes`): id of the snapshot @@ -581,25 +581,28 @@ """ if legacy_arg1: assert legacy_arg2 - (origin, visit, snapshot) = \ - (snapshot, legacy_arg1, legacy_arg2) + (origin, visit, snapshots) = \ + (snapshots, legacy_arg1, [legacy_arg2]) else: origin = visit = None - snapshot_id = snapshot['id'] - if self.journal_writer: - self.journal_writer.write_addition( - 'snapshot', snapshot) - if snapshot_id not in self._snapshots: - self._snapshots[snapshot_id] = { - 'id': snapshot_id, - 'branches': copy.deepcopy(snapshot['branches']), - '_sorted_branch_names': sorted(snapshot['branches']) - } - self._objects[snapshot_id].append(('snapshot', 
snapshot_id)) + for snapshot in snapshots: + snapshot_id = snapshot['id'] + if self.journal_writer: # NOTE(review): journals even already-stored snapshots, unlike storage.py + self.journal_writer.write_addition( + 'snapshot', snapshot) + if snapshot_id not in self._snapshots: + self._snapshots[snapshot_id] = { + 'id': snapshot_id, + 'branches': copy.deepcopy(snapshot['branches']), + '_sorted_branch_names': sorted(snapshot['branches']) + } + self._objects[snapshot_id].append(('snapshot', snapshot_id)) if origin: - self.origin_visit_update(origin, visit, snapshot=snapshot_id) + # Legacy API, there can be only one snapshot + self.origin_visit_update( + origin, visit, snapshot=snapshots[0]['id']) def snapshot_get(self, snapshot_id): """Get the content, possibly partial, of a snapshot with the given id diff --git a/swh/storage/sql/40-swh-func.sql b/swh/storage/sql/40-swh-func.sql --- a/swh/storage/sql/40-swh-func.sql +++ b/swh/storage/sql/40-swh-func.sql @@ -740,6 +740,7 @@ where tmp.target is null and tmp.target_type is null and sb.target is null and sb.target_type is null; end if; + drop table if exists tmp_snapshot_branch; -- so a later mktemp_snapshot_branch() in the same transaction can recreate it end; $$; diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -768,12 +768,12 @@ yield data if data['target_type'] else None @db_transaction() - def snapshot_add(self, snapshot, origin=None, visit=None, + def snapshot_add(self, snapshots, origin=None, visit=None, db=None, cur=None): - """Add a snapshot for the given origin/visit couple + """Add snapshots to the storage.
Args: - snapshot (dict): the snapshot to add to the visit, containing the + snapshots ([dict]): the snapshots to add, each containing the following keys: - **id** (:class:`bytes`): id of the snapshot @@ -799,42 +799,46 @@ raise TypeError( 'snapshot_add expects one argument (or, as a legacy ' 'behavior, three arguments), not two') - if isinstance(snapshot, int): + if isinstance(snapshots, int): # Called by legacy code that uses the new api/client.py - (origin_id, visit_id, snapshot) = \ - (snapshot, origin, visit) + (origin_id, visit_id, snapshots) = \ + (snapshots, origin, [visit]) else: # Called by legacy code that uses the old api/client.py origin_id = origin visit_id = visit + snapshots = [snapshots] else: # Called by new code that uses the new api/client.py origin_id = visit_id = None - if not db.snapshot_exists(snapshot['id'], cur): - db.mktemp_snapshot_branch(cur) - db.copy_to( - ( - { - 'name': name, - 'target': info['target'] if info else None, - 'target_type': info['target_type'] if info else None, - } - for name, info in snapshot['branches'].items() - ), - 'tmp_snapshot_branch', - ['name', 'target', 'target_type'], - cur, - ) + for snapshot in snapshots: + if not db.snapshot_exists(snapshot['id'], cur): + db.mktemp_snapshot_branch(cur) + db.copy_to( + ( + { + 'name': name, + 'target': info['target'] if info else None, + 'target_type': (info['target_type'] + if info else None), + } + for name, info in snapshot['branches'].items() + ), + 'tmp_snapshot_branch', + ['name', 'target', 'target_type'], + cur, + ) - if self.journal_writer: - self.journal_writer.write_addition('snapshot', snapshot) + if self.journal_writer: + self.journal_writer.write_addition('snapshot', snapshot) - db.snapshot_add(snapshot['id'], cur) + db.snapshot_add(snapshot['id'], cur) if visit_id: + # Legacy API, there can be only one snapshot self.origin_visit_update( - origin_id, visit_id, snapshot=snapshot['id'], + origin_id, visit_id, snapshot=snapshots[0]['id'], db=db, cur=cur)
@db_transaction(statement_timeout=2000) diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1312,7 +1312,7 @@ self.assertEqual(actual_origin_visit['snapshot'], self.snapshot['id']) # when - self.storage.snapshot_add(self.snapshot) + self.storage.snapshot_add([self.snapshot]) self.assertEqual(actual_origin_visit['snapshot'], self.snapshot['id']) def test_origin_visit_get_by(self): @@ -1432,7 +1432,7 @@ self.date_visit1) visit_id = origin_visit1['visit'] - self.storage.snapshot_add(self.empty_snapshot) + self.storage.snapshot_add([self.empty_snapshot]) self.storage.origin_visit_update( origin_id, visit_id, snapshot=self.empty_snapshot['id']) @@ -1518,6 +1518,29 @@ by_ov = self.storage.snapshot_get_by_origin_visit(origin_id, visit_id) self.assertEqual(by_ov, self.complete_snapshot) + def test_snapshot_add_many(self): + self.storage.snapshot_add([self.snapshot, self.complete_snapshot]) # two new snapshots in one call + + self.assertEqual( + self.complete_snapshot, + self.storage.snapshot_get(self.complete_snapshot['id'])) + + self.assertEqual( + self.snapshot, + self.storage.snapshot_get(self.snapshot['id'])) + + def test_snapshot_add_many_incremental(self): + self.storage.snapshot_add([self.complete_snapshot]) # pre-store one snapshot + self.storage.snapshot_add([self.snapshot, self.complete_snapshot]) # batch mixes a new and an already-stored snapshot + + self.assertEqual( + self.complete_snapshot, + self.storage.snapshot_get(self.complete_snapshot['id'])) + + self.assertEqual( + self.snapshot, + self.storage.snapshot_get(self.snapshot['id'])) + def test_snapshot_add_count_branches(self): origin_id = self.storage.origin_add_one(self.origin) origin_visit1 = self.storage.origin_visit_add(origin_id, @@ -1660,7 +1683,7 @@ self.journal_writer.objects[:] = [] - self.storage.snapshot_add(self.snapshot) + self.storage.snapshot_add([self.snapshot]) with self.assertRaises(ValueError): self.storage.origin_visit_update( @@ -1689,7 +1712,7 @@ origin_visit1 = 
self.storage.origin_visit_add(origin_id, self.date_visit1) visit1_id = origin_visit1['visit'] - self.storage.snapshot_add(self.snapshot) + self.storage.snapshot_add([self.snapshot]) self.storage.origin_visit_update( origin_id, visit1_id, snapshot=self.snapshot['id']) @@ -1701,7 +1724,7 @@ self.date_visit2) visit2_id = origin_visit2['visit'] - self.storage.snapshot_add(self.snapshot) + self.storage.snapshot_add([self.snapshot]) self.storage.origin_visit_update( origin_id, visit2_id, snapshot=self.snapshot['id']) @@ -1846,7 +1869,7 @@ self.assertIsNone(self.storage.snapshot_get_latest(origin_id)) # Add snapshot to visit1, latest snapshot = visit 1 snapshot - self.storage.snapshot_add(self.complete_snapshot) + self.storage.snapshot_add([self.complete_snapshot]) self.storage.origin_visit_update( origin_id, visit1_id, snapshot=self.complete_snapshot['id']) self.assertEqual(self.complete_snapshot, @@ -1868,7 +1891,7 @@ ) # Add snapshot to visit2 and check that the new snapshot is returned - self.storage.snapshot_add(self.empty_snapshot) + self.storage.snapshot_add([self.empty_snapshot]) self.storage.origin_visit_update( origin_id, visit2_id, snapshot=self.empty_snapshot['id']) self.assertEqual(self.empty_snapshot, @@ -1883,7 +1906,7 @@ # Add snapshot to visit3 (same date as visit2) and check that # the new snapshot is returned - self.storage.snapshot_add(self.complete_snapshot) + self.storage.snapshot_add([self.complete_snapshot]) self.storage.origin_visit_update( origin_id, visit3_id, snapshot=self.complete_snapshot['id']) self.assertEqual(self.complete_snapshot, @@ -1921,7 +1944,7 @@ ) # Actually add the snapshot and check status filter again - self.storage.snapshot_add(self.complete_snapshot) + self.storage.snapshot_add([self.complete_snapshot]) self.assertEqual( self.complete_snapshot, self.storage.snapshot_get_latest(origin_id) @@ -1936,7 +1959,7 @@ self.storage.snapshot_get_latest(origin_id)) # Actually add that snapshot and check that the new one is returned - 
self.storage.snapshot_add(self.empty_snapshot) + self.storage.snapshot_add([self.empty_snapshot]) self.assertEqual( self.empty_snapshot, self.storage.snapshot_get_latest(origin_id)