diff --git a/sql/upgrades/130.sql b/sql/upgrades/130.sql new file mode 100644 --- /dev/null +++ b/sql/upgrades/130.sql @@ -0,0 +1,49 @@ +-- SWH DB schema upgrade +-- from_version: 129 +-- to_version: 130 +-- description: Remove origin_visit update from snapshot_add + +insert into dbversion(version, release, description) + values(130, now(), 'Work In Progress'); + +-- 'create or replace' cannot replace the version 129 function: its argument +-- list differs, so it would only add an overload. Drop the old one explicitly. +drop function swh_snapshot_add(origin bigint, visit bigint, snapshot_id snapshot.id%type); + +create or replace function swh_snapshot_add(snapshot_id snapshot.id%type) + returns void + language plpgsql +as $$ +declare + snapshot_object_id snapshot.object_id%type; +begin + select object_id from snapshot where id = snapshot_id into snapshot_object_id; + if snapshot_object_id is null then + insert into snapshot (id) values (snapshot_id) returning object_id into snapshot_object_id; + insert into snapshot_branch (name, target_type, target) + select name, target_type, target from tmp_snapshot_branch tmp + where not exists ( + select 1 + from snapshot_branch sb + where sb.name = tmp.name + and sb.target = tmp.target + and sb.target_type = tmp.target_type + ) + on conflict do nothing; + insert into snapshot_branches (snapshot_id, branch_id) + select snapshot_object_id, sb.object_id as branch_id + from tmp_snapshot_branch tmp + join snapshot_branch sb + using (name, target, target_type) + where tmp.target is not null and tmp.target_type is not null + union + select snapshot_object_id, sb.object_id as branch_id + from tmp_snapshot_branch tmp + join snapshot_branch sb + using (name) + where tmp.target is null and tmp.target_type is null + and sb.target is null and sb.target_type is null; + end if; +end; +$$; + diff --git a/swh/storage/api/client.py b/swh/storage/api/client.py --- a/swh/storage/api/client.py +++ b/swh/storage/api/client.py @@ -84,10 +84,24 @@ def object_find_by_sha1_git(self, ids): return self.post('object/find_by_sha1_git', {'ids': ids}) - def snapshot_add(self, origin, visit, snapshot): - return self.post('snapshot/add', { - 'origin': origin, 'visit': visit, 'snapshot': snapshot, - }) + def 
snapshot_add(self, snapshot, origin=None, visit=None): + if origin: + assert visit + (origin, visit, snapshot) = (snapshot, origin, visit) + warnings.warn("arguments 'origin' and 'visit' of snapshot_add " + "are deprecated since v0.0.131, please use " + "snapshot_add(snapshot) + " + "origin_visit_update(origin, visit, " + "snapshot=snapshot['id']) instead.", + DeprecationWarning) + return self.post('snapshot/add', { + 'origin': origin, 'visit': visit, 'snapshot': snapshot, + }) + else: + assert not visit + return self.post('snapshot/add', { + 'snapshot': snapshot, + }) def snapshot_get(self, snapshot_id): return self.post('snapshot', { @@ -167,11 +181,13 @@ date = ts return self.post('origin/visit/add', {'origin': origin, 'date': date}) - def origin_visit_update(self, origin, visit_id, status, metadata=None): + def origin_visit_update(self, origin, visit_id, status=None, + metadata=None, snapshot=None): return self.post('origin/visit/update', {'origin': origin, 'visit_id': visit_id, 'status': status, - 'metadata': metadata}) + 'metadata': metadata, + 'snapshot': snapshot}) def origin_visit_get(self, origin, last_visit=None, limit=None): return self.post('origin/visit/get', { diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -138,12 +138,11 @@ return bool(cur.fetchone()) - def snapshot_add(self, origin, visit, snapshot_id, cur=None): - """Add a snapshot for origin/visit from the temporary table""" + def snapshot_add(self, snapshot_id, cur=None): + """Add a snapshot from the temporary table""" cur = self._cursor(cur) - cur.execute("""SELECT swh_snapshot_add(%s, %s, %s)""", - (origin, visit, snapshot_id)) + cur.execute("""SELECT swh_snapshot_add(%s)""", (snapshot_id,)) snapshot_count_cols = ['target_type', 'count'] @@ -297,14 +296,38 @@ (origin, ts)) return cur.fetchone()[0] - def origin_visit_update(self, origin, visit_id, status, - metadata, cur=None): + def origin_visit_update(self, origin_id, visit_id, updates, 
cur=None): """Update origin_visit's status.""" cur = self._cursor(cur) - update = """UPDATE origin_visit - SET status=%s, metadata=%s - WHERE origin=%s AND visit=%s""" - cur.execute(update, (status, jsonize(metadata), origin, visit_id)) + updates = dict(updates) # don't mutate the caller's dict + update_cols = [] + values = [] + where = ['origin=%s AND visit=%s'] + where_values = [origin_id, visit_id] + from_ = '' + if 'status' in updates: + update_cols.append('status=%s') + values.append(updates.pop('status')) + if 'metadata' in updates: + update_cols.append('metadata=%s') + values.append(jsonize(updates.pop('metadata'))) + if 'snapshot' in updates: + update_cols.append('snapshot_id=snapshot.object_id') + from_ = 'FROM snapshot' + where.append('snapshot.id=%s') + where_values.append(updates.pop('snapshot')) + assert not updates, 'Unknown fields: %r' % updates + if not update_cols: + return # empty SET clause would be invalid SQL + query = """UPDATE origin_visit + SET {update_cols} + {from} + WHERE {where}""".format(**{ + 'update_cols': ', '.join(update_cols), + 'from': from_, + 'where': ' AND '.join(where) + }) + cur.execute(query, (*values, *where_values)) origin_visit_get_cols = ['origin', 'visit', 'date', 'status', 'metadata', 'snapshot'] diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -556,12 +556,10 @@ for rel_id in releases: yield copy.deepcopy(self._releases.get(rel_id)) - def snapshot_add(self, origin, visit, snapshot): + def snapshot_add(self, snapshot, origin=None, visit=None): """Add a snapshot for the given origin/visit couple Args: - origin (int): id of the origin - visit (int): id of the visit snapshot (dict): the snapshot to add to the visit, containing the following keys: @@ -581,32 +579,29 @@ Raises: ValueError: if the origin's or visit's identifier does not exist.
""" + if origin: + if not visit: + raise TypeError( + 'snapshot_add expects one argument (or, as a legacy ' + 'behavior, three arguments), not two') + if isinstance(snapshot, int): + # legacy positional call order: (origin, visit, snapshot) + (origin, visit, snapshot) = (snapshot, origin, visit) snapshot_id = snapshot['id'] + if self.journal_writer: + self.journal_writer.write_addition( + 'snapshot', snapshot) if snapshot_id not in self._snapshots: self._snapshots[snapshot_id] = { - 'origin': origin, - 'visit': visit, 'id': snapshot_id, 'branches': copy.deepcopy(snapshot['branches']), '_sorted_branch_names': sorted(snapshot['branches']) } self._objects[snapshot_id].append(('snapshot', snapshot_id)) - if origin <= len(self._origin_visits) and \ - visit <= len(self._origin_visits[origin-1]): - if self.journal_writer: - self.journal_writer.write_addition( - 'snapshot', snapshot) - self.journal_writer.write_update('origin_visit', { - **self._origin_visits[origin-1][visit-1], - 'origin': self._origins[origin-1], - 'snapshot': snapshot_id}) - - self._origin_visits[origin-1][visit-1]['snapshot'] = snapshot_id - else: - raise ValueError('Origin with id %s does not exist or has no visit' - ' with id %s' % (origin, visit)) + if origin: + self.origin_visit_update(origin, visit, snapshot=snapshot_id) def snapshot_get(self, snapshot_id): """Get the content, possibly partial, of a snapshot with the given id @@ -1077,7 +1072,8 @@ return visit_ret - def origin_visit_update(self, origin, visit_id, status, metadata=None): + def origin_visit_update(self, origin, visit_id, status=None, + metadata=None, snapshot=None): """Update an origin_visit's status.
Args: @@ -1085,6 +1081,8 @@ visit_id (int): visit's identifier status: visit's new status metadata: data associated to the visit + snapshot (sha1_git): identifier of the snapshot to add to + the visit Returns: None @@ -1092,18 +1090,25 @@ """ origin_id = origin # TODO: rename the argument + # ids <= 0 must not wrap around as negative list indices + if not (1 <= origin_id <= len(self._origin_visits)) or \ + not (1 <= visit_id <= len(self._origin_visits[origin_id-1])): + raise ValueError('Invalid origin_id or visit_id') + visit = self._origin_visits[origin_id-1][visit_id-1] if self.journal_writer: origin = self.origin_get([{'id': origin_id}])[0] self.journal_writer.write_update('origin_visit', { - **self._origin_visits[origin_id-1][visit_id-1], 'origin': origin, 'visit': visit_id, - 'status': status, 'metadata': metadata}) + 'status': status or visit['status'], + 'date': visit['date'], + 'metadata': metadata or visit['metadata'], + 'snapshot': snapshot or visit['snapshot']}) - if origin_id > len(self._origin_visits) or \ - visit_id > len(self._origin_visits[origin_id-1]): - return - self._origin_visits[origin_id-1][visit_id-1].update({ - 'status': status, - 'metadata': metadata}) + if status: + visit['status'] = status + if metadata: + visit['metadata'] = metadata + if snapshot: + visit['snapshot'] = snapshot def origin_visit_get(self, origin, last_visit=None, limit=None): """Retrieve all the origin's visit's information.
diff --git a/swh/storage/sql/30-swh-schema.sql b/swh/storage/sql/30-swh-schema.sql --- a/swh/storage/sql/30-swh-schema.sql +++ b/swh/storage/sql/30-swh-schema.sql @@ -12,7 +12,7 @@ -- latest schema version insert into dbversion(version, release, description) - values(129, now(), 'Work In Progress'); + values(130, now(), 'Work In Progress'); -- a SHA1 checksum create domain sha1 as bytea check (length(value) = 20); diff --git a/swh/storage/sql/40-swh-func.sql b/swh/storage/sql/40-swh-func.sql --- a/swh/storage/sql/40-swh-func.sql +++ b/swh/storage/sql/40-swh-func.sql @@ -706,7 +706,7 @@ returning visit; $$; -create or replace function swh_snapshot_add(origin bigint, visit bigint, snapshot_id snapshot.id%type) +create or replace function swh_snapshot_add(snapshot_id snapshot.id%type) returns void language plpgsql as $$ @@ -740,9 +740,6 @@ where tmp.target is null and tmp.target_type is null and sb.target is null and sb.target_type is null; end if; - update origin_visit ov - set snapshot_id = snapshot_object_id - where ov.origin=swh_snapshot_add.origin and ov.visit=swh_snapshot_add.visit; end; $$; diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -768,13 +768,11 @@ yield data if data['target_type'] else None @db_transaction() - def snapshot_add(self, origin, visit, snapshot, + def snapshot_add(self, snapshot, origin=None, visit=None, db=None, cur=None): """Add a snapshot for the given origin/visit couple Args: - origin (int): id of the origin - visit (int): id of the visit snapshot (dict): the snapshot to add to the visit, containing the following keys: @@ -790,15 +788,30 @@ - **target** (:class:`bytes`): identifier of the target (currently a ``sha1_git`` for all object kinds, or the name of the target branch for aliases) + origin (int): legacy argument for backward compatibility + visit (int): legacy argument for backward compatibility Raises: ValueError: if the origin or visit id does not exist.
""" - origin_id = origin - visit_id = visit + if origin: + if not visit: + raise TypeError( + 'snapshot_add expects one argument (or, as a legacy ' + 'behavior, three arguments), not two') + if isinstance(snapshot, int): + # legacy positional call order: (origin, visit, snapshot) + (origin_id, visit_id, snapshot) = \ + (snapshot, origin, visit) + else: + # legacy origin/visit given alongside a snapshot dict + origin_id = origin + visit_id = visit + else: + # new-style call: snapshot only + origin_id = visit_id = None if not db.snapshot_exists(snapshot['id'], cur): - db.mktemp_snapshot_branch(cur) db.copy_to( ( @@ -815,28 +828,14 @@ ) if self.journal_writer: - visit = db.origin_visit_get(origin_id, visit_id, cur=cur) - visit_exists = visit is not None - else: - visit_exists = db.origin_visit_exists(origin_id, visit_id) - - if not visit_exists: - raise ValueError('Not origin visit with ids (%s, %s)' % - (origin_id, visit_id)) - - if self.journal_writer: - # Send the snapshot before the origin: in case of a crash, - # it's better to have an orphan snapshot than have the - # origin_visit have a dangling reference to a snapshot - origin = self.origin_get([{'id': origin_id}], db=db, cur=cur)[0] - visit = dict(zip(db.origin_visit_get_cols, visit)) self.journal_writer.write_addition('snapshot', snapshot) - self.journal_writer.write_update('origin_visit', { - 'origin': origin, 'visit': visit_id, - 'status': visit['status'], 'metadata': visit['metadata'], - 'date': visit['date'], 'snapshot': snapshot['id']}) - db.snapshot_add(origin_id, visit_id, snapshot['id'], cur) + db.snapshot_add(snapshot['id'], cur) + + if visit_id: + self.origin_visit_update( + origin_id, visit_id, snapshot=snapshot['id'], + db=db, cur=cur) @db_transaction(statement_timeout=2000) def snapshot_get(self, snapshot_id, db=None, cur=None): @@ -1059,7 +1058,8 @@ } @db_transaction() - def origin_visit_update(self, origin, visit_id, status, metadata=None, + def origin_visit_update(self, 
origin, visit_id, status=None, + metadata=None, snapshot=None, db=None, cur=None): """Update an origin_visit's status. @@ -1068,6 +1068,8 @@ visit_id: Visit's id status: Visit's new status metadata: Data associated to the visit + snapshot (sha1_git): identifier of the snapshot to add to + the visit Returns: None @@ -1075,16 +1077,28 @@ """ origin_id = origin # TODO: rename the argument + visit = db.origin_visit_get(origin_id, visit_id, cur=cur) + + if not visit: + raise ValueError('Invalid visit_id for this origin.') + + visit = dict(zip(db.origin_visit_get_cols, visit)) + + updates = {} + if status and status != visit['status']: + updates['status'] = status + if metadata and metadata != visit['metadata']: + updates['metadata'] = metadata + if snapshot and snapshot != visit['snapshot']: + updates['snapshot'] = snapshot + if self.journal_writer: origin = self.origin_get([{'id': origin_id}], db=db, cur=cur)[0] - visit = db.origin_visit_get(origin_id, visit_id, cur=cur) - visit = dict(zip(db.origin_visit_get_cols, visit)) self.journal_writer.write_update('origin_visit', { - 'origin': origin, 'visit': visit_id, - 'status': status, 'metadata': metadata, - 'date': visit['date'], 'snapshot': None}) - return db.origin_visit_update( - origin_id, visit_id, status, metadata, cur) + **visit, **updates, 'origin': origin}) + + if updates: + db.origin_visit_update(origin_id, visit_id, updates, cur) @db_transaction_generator(statement_timeout=500) def origin_visit_get(self, origin, last_visit=None, limit=None, db=None, diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1410,6 +1410,46 @@ self.date_visit1) visit_id = origin_visit1['visit'] + self.storage.snapshot_add(self.empty_snapshot) + self.storage.origin_visit_update( + origin_id, visit_id, snapshot=self.empty_snapshot['id']) + + by_id = self.storage.snapshot_get(self.empty_snapshot['id']) +
self.assertEqual(by_id, self.empty_snapshot) + + by_ov = self.storage.snapshot_get_by_origin_visit(origin_id, visit_id) + self.assertEqual(by_ov, self.empty_snapshot) + + expected_origin = self.origin.copy() + expected_origin['id'] = origin_id + data1 = { + 'origin': expected_origin, + 'date': self.date_visit1, + 'visit': origin_visit1['visit'], + 'status': 'ongoing', + 'metadata': None, + 'snapshot': None, + } + data2 = { + 'origin': expected_origin, + 'date': self.date_visit1, + 'visit': origin_visit1['visit'], + 'status': 'ongoing', + 'metadata': None, + 'snapshot': self.empty_snapshot['id'], + } + self.assertEqual(list(self.journal_writer.objects), + [('origin', expected_origin), + ('origin_visit', data1), + ('snapshot', self.empty_snapshot), + ('origin_visit', data2)]) + + def test_snapshot_add_get_empty__legacy_add(self): + origin_id = self.storage.origin_add_one(self.origin) + origin_visit1 = self.storage.origin_visit_add(origin_id, + self.date_visit1) + visit_id = origin_visit1['visit'] + self.storage.snapshot_add(origin_id, visit_id, self.empty_snapshot) by_id = self.storage.snapshot_get(self.empty_snapshot['id']) @@ -1598,26 +1638,103 @@ self.journal_writer.objects[:] = [] + self.storage.snapshot_add(self.snapshot) + with self.assertRaises(ValueError): - self.storage.snapshot_add(origin_id, visit_id, self.snapshot) + self.storage.origin_visit_update( + origin_id, visit_id, snapshot=self.snapshot['id']) - self.assertEqual(list(self.journal_writer.objects), []) + self.assertEqual(list(self.journal_writer.objects), [ + ('snapshot', self.snapshot)]) - def test_snapshot_add_nonexistent_visit_no_journal(self): - # Same test as before, but uses a different code path for checking - # the origin visit exists.
- self.storage.journal_writer = None + def test_snapshot_add_nonexistent_visit__legacy_add(self): origin_id = self.storage.origin_add_one(self.origin) visit_id = 54164461156 + self.journal_writer.objects[:] = [] + with self.assertRaises(ValueError): self.storage.snapshot_add(origin_id, visit_id, self.snapshot) + # Note: the actual legacy behavior was to abort before adding + # the snapshot; but delaying non-existence checks makes the + # compatibility code simpler + self.assertEqual(list(self.journal_writer.objects), [ + ('snapshot', self.snapshot)]) + def test_snapshot_add_twice(self): origin_id = self.storage.origin_add_one(self.origin) origin_visit1 = self.storage.origin_visit_add(origin_id, self.date_visit1) visit1_id = origin_visit1['visit'] + self.storage.snapshot_add(self.snapshot) + self.storage.origin_visit_update( + origin_id, visit1_id, snapshot=self.snapshot['id']) + + by_ov1 = self.storage.snapshot_get_by_origin_visit(origin_id, + visit1_id) + self.assertEqual(by_ov1, self.snapshot) + + origin_visit2 = self.storage.origin_visit_add(origin_id, + self.date_visit2) + visit2_id = origin_visit2['visit'] + + self.storage.snapshot_add(self.snapshot) + self.storage.origin_visit_update( + origin_id, visit2_id, snapshot=self.snapshot['id']) + + by_ov2 = self.storage.snapshot_get_by_origin_visit(origin_id, + visit2_id) + self.assertEqual(by_ov2, self.snapshot) + + expected_origin = self.origin.copy() + expected_origin['id'] = origin_id + data1 = { + 'origin': expected_origin, + 'date': self.date_visit1, + 'visit': origin_visit1['visit'], + 'status': 'ongoing', + 'metadata': None, + 'snapshot': None, + } + data2 = { + 'origin': expected_origin, + 'date': self.date_visit1, + 'visit': origin_visit1['visit'], + 'status': 'ongoing', + 'metadata': None, + 'snapshot': self.snapshot['id'], + } + data3 = { + 'origin': expected_origin, + 'date': self.date_visit2, + 'visit': origin_visit2['visit'], + 'status': 'ongoing', + 'metadata': None, + 'snapshot': None, + } + data4 =
{ + 'origin': expected_origin, + 'date': self.date_visit2, + 'visit': origin_visit2['visit'], + 'status': 'ongoing', + 'metadata': None, + 'snapshot': self.snapshot['id'], + } + self.assertEqual(list(self.journal_writer.objects), + [('origin', expected_origin), + ('origin_visit', data1), + ('snapshot', self.snapshot), + ('origin_visit', data2), + ('origin_visit', data3), + ('snapshot', self.snapshot), + ('origin_visit', data4)]) + + def test_snapshot_add_twice__legacy_add(self): + origin_id = self.storage.origin_add_one(self.origin) + origin_visit1 = self.storage.origin_visit_add(origin_id, + self.date_visit1) + visit1_id = origin_visit1['visit'] self.storage.snapshot_add(origin_id, visit1_id, self.snapshot) by_ov1 = self.storage.snapshot_get_by_origin_visit(origin_id, @@ -1707,6 +1824,67 @@ self.assertIsNone(self.storage.snapshot_get_latest(origin_id)) # Add snapshot to visit1, latest snapshot = visit 1 snapshot + self.storage.snapshot_add(self.complete_snapshot) + self.storage.origin_visit_update( + origin_id, visit1_id, snapshot=self.complete_snapshot['id']) + self.assertEqual(self.complete_snapshot, + self.storage.snapshot_get_latest(origin_id)) + + # Status filter: both visits are status=ongoing, so no snapshot + # returned + self.assertIsNone( + self.storage.snapshot_get_latest(origin_id, + allowed_statuses=['full']) + ) + + # Mark the first visit as completed and check status filter again + self.storage.origin_visit_update(origin_id, visit1_id, status='full') + self.assertEqual( + self.complete_snapshot, + self.storage.snapshot_get_latest(origin_id, + allowed_statuses=['full']), + ) + + # Add snapshot to visit2 and check that the new snapshot is returned + self.storage.snapshot_add(self.empty_snapshot) + self.storage.origin_visit_update( + origin_id, visit2_id, snapshot=self.empty_snapshot['id']) + self.assertEqual(self.empty_snapshot, + self.storage.snapshot_get_latest(origin_id)) + + # Check that the status filter is still working + self.assertEqual( +
self.complete_snapshot, + self.storage.snapshot_get_latest(origin_id, + allowed_statuses=['full']), + ) + + # Add snapshot to visit3 (same date as visit2) and check that + # the new snapshot is returned + self.storage.snapshot_add(self.complete_snapshot) + self.storage.origin_visit_update( + origin_id, visit3_id, snapshot=self.complete_snapshot['id']) + self.assertEqual(self.complete_snapshot, + self.storage.snapshot_get_latest(origin_id)) + + def test_snapshot_get_latest__legacy_add(self): + origin_id = self.storage.origin_add_one(self.origin) + origin_visit1 = self.storage.origin_visit_add(origin_id, + self.date_visit1) + visit1_id = origin_visit1['visit'] + origin_visit2 = self.storage.origin_visit_add(origin_id, + self.date_visit2) + visit2_id = origin_visit2['visit'] + + # Add a visit with the same date as the previous one + origin_visit3 = self.storage.origin_visit_add(origin_id, + self.date_visit2) + visit3_id = origin_visit3['visit'] + + # Two visits, both with no snapshot: latest snapshot is None + self.assertIsNone(self.storage.snapshot_get_latest(origin_id)) + + # Add snapshot to visit1, latest snapshot = visit 1 snapshot self.storage.snapshot_add(origin_id, visit1_id, self.complete_snapshot) self.assertEqual(self.complete_snapshot, self.storage.snapshot_get_latest(origin_id))