diff --git a/swh/storage/api/client.py b/swh/storage/api/client.py --- a/swh/storage/api/client.py +++ b/swh/storage/api/client.py @@ -84,10 +84,32 @@ def object_find_by_sha1_git(self, ids): return self.post('object/find_by_sha1_git', {'ids': ids}) - def snapshot_add(self, origin, visit, snapshot): - return self.post('snapshot/add', { - 'origin': origin, 'visit': visit, 'snapshot': snapshot, - }) + def snapshot_add(self, snapshot, origin=None, visit=None): + if origin: + if not visit: + raise TypeError( + 'snapshot_add expects one argument (or, as a legacy ' + 'behavior, three arguments), not two') + if isinstance(snapshot, int): + # Legacy positional call: snapshot_add(origin, visit, snapshot) + (origin, visit, snapshot) = (snapshot, origin, visit) + warnings.warn("arguments 'origin' and 'visit' of snapshot_add " + "are deprecated since v0.0.131, please use " + "snapshot_add(snapshot) + " + "origin_visit_update(origin, visit, " + "snapshot_id=snapshot['id']) instead.", + DeprecationWarning) + return self.post('snapshot/add', { + 'origin': origin, 'visit': visit, 'snapshot': snapshot, + }) + else: + if visit: + raise TypeError( + 'snapshot_add expects one argument (or, as a legacy ' + 'behavior, three arguments), not two') + return self.post('snapshot/add', { + 'snapshot': snapshot, + }) def snapshot_get(self, snapshot_id): return self.post('snapshot', { @@ -167,11 +189,13 @@ date = ts return self.post('origin/visit/add', {'origin': origin, 'date': date}) - def origin_visit_update(self, origin, visit_id, status, metadata=None): + def origin_visit_update(self, origin, visit_id, status=None, + metadata=None, snapshot_id=None): return self.post('origin/visit/update', {'origin': origin, 'visit_id': visit_id, 'status': status, - 'metadata': metadata}) + 'metadata': metadata, + 'snapshot_id': snapshot_id}) def origin_visit_get(self, origin, last_visit=None, limit=None): return self.post('origin/visit/get', { diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -138,12 +138,11 @@ return bool(cur.fetchone()) - def snapshot_add(self, origin, visit, snapshot_id, cur=None): - """Add a snapshot for origin/visit from the temporary table""" + def snapshot_add(self, snapshot_id, cur=None): + """Add a snapshot from the temporary table""" cur = 
self._cursor(cur) - cur.execute("""SELECT swh_snapshot_add(%s, %s, %s)""", - (origin, visit, snapshot_id)) + cur.execute("""SELECT swh_snapshot_add(%s)""", (snapshot_id,)) snapshot_count_cols = ['target_type', 'count'] @@ -297,14 +296,38 @@ (origin, ts)) return cur.fetchone()[0] - def origin_visit_update(self, origin, visit_id, status, - metadata, cur=None): + def origin_visit_update(self, origin_id, visit_id, status, metadata, + snapshot_id, cur=None): """Update origin_visit's status.""" cur = self._cursor(cur) + update_cols = [] + values = [] + where = ['origin=%s AND visit=%s'] + where_values = [origin_id, visit_id] + from_ = '' + if status: + update_cols.append('status=%s') + values.append(status) + if metadata: + update_cols.append('metadata=%s') + values.append(jsonize(metadata)) + if snapshot_id: + update_cols.append('snapshot_id=snapshot.object_id') + from_ = 'FROM snapshot' + where.append('snapshot.id=%s') + where_values.append(snapshot_id) + if not update_cols: + # Nothing to change: an empty SET clause is invalid SQL. + return update = """UPDATE origin_visit - SET status=%s, metadata=%s - WHERE origin=%s AND visit=%s""" - cur.execute(update, (status, jsonize(metadata), origin, visit_id)) + SET {update_cols} + {from} + WHERE {where}""".format(**{ + 'update_cols': ', '.join(update_cols), + 'from': from_, + 'where': ' AND '.join(where) + }) + cur.execute(update, (*values, *where_values)) origin_visit_get_cols = ['origin', 'visit', 'date', 'status', 'metadata', 'snapshot'] diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -556,12 +556,10 @@ for rel_id in releases: yield copy.deepcopy(self._releases.get(rel_id)) - def snapshot_add(self, origin, visit, snapshot): + def snapshot_add(self, snapshot, origin=None, visit=None): """Add a snapshot for the given origin/visit couple Args: - origin (int): id of the origin - visit (int): id of the visit snapshot (dict): the snapshot to add to the visit, containing the following keys: @@ -581,32 +579,43 @@ Raises: 
ValueError: if the origin's or visit's identifier does not exist. """ + if origin: + if not visit: + raise TypeError( + 'snapshot_add expects one argument (or, as a legacy ' + 'behavior, three arguments), not two') + if isinstance(snapshot, int): + # Legacy positional call: snapshot_add(origin, visit, snapshot) + (origin, visit, snapshot) = (snapshot, origin, visit) + else: + origin = visit = None + + if origin and ( + origin > len(self._origin_visits) or + visit > len(self._origin_visits[origin-1])): + raise ValueError('Origin with id %s does not exist or has no visit' + ' with id %s' % (origin, visit)) snapshot_id = snapshot['id'] + if self.journal_writer: + self.journal_writer.write_addition( + 'snapshot', snapshot) if snapshot_id not in self._snapshots: self._snapshots[snapshot_id] = { - 'origin': origin, - 'visit': visit, 'id': snapshot_id, 'branches': copy.deepcopy(snapshot['branches']), '_sorted_branch_names': sorted(snapshot['branches']) } self._objects[snapshot_id].append(('snapshot', snapshot_id)) - if origin <= len(self._origin_visits) and \ - visit <= len(self._origin_visits[origin-1]): + if origin: if self.journal_writer: - self.journal_writer.write_addition( - 'snapshot', snapshot) self.journal_writer.write_update('origin_visit', { **self._origin_visits[origin-1][visit-1], 'origin': self._origins[origin-1], 'snapshot': snapshot_id}) self._origin_visits[origin-1][visit-1]['snapshot'] = snapshot_id - else: - raise ValueError('Origin with id %s does not exist or has no visit' - ' with id %s' % (origin, visit)) def snapshot_get(self, snapshot_id): """Get the content, possibly partial, of a snapshot with the given id @@ -1077,7 +1086,8 @@ return visit_ret - def origin_visit_update(self, origin, visit_id, status, metadata=None): + def origin_visit_update(self, origin, visit_id, status=None, + metadata=None, snapshot_id=None): """Update an origin_visit's status.
Args: @@ -1085,6 +1095,8 @@ visit_id (int): visit's identifier status: visit's new status metadata: data associated to the visit + snapshot_id (sha1_git): identifier of the snapshot to add to + the visit Returns: None @@ -1092,18 +1104,24 @@ """ origin_id = origin # TODO: rename the argument + if not (1 <= origin_id <= len(self._origin_visits) and + 1 <= visit_id <= len(self._origin_visits[origin_id-1])): + raise ValueError('Invalid origin_id or visit_id') + visit = self._origin_visits[origin_id-1][visit_id-1] if self.journal_writer: origin = self.origin_get([{'id': origin_id}])[0] self.journal_writer.write_update('origin_visit', { - **self._origin_visits[origin_id-1][visit_id-1], 'origin': origin, 'visit': visit_id, - 'status': status, 'metadata': metadata}) + 'status': status or visit['status'], + 'date': visit['date'], + 'metadata': metadata or visit['metadata'], + 'snapshot': snapshot_id or visit['snapshot']}) - if origin_id > len(self._origin_visits) or \ - visit_id > len(self._origin_visits[origin_id-1]): - return - self._origin_visits[origin_id-1][visit_id-1].update({ - 'status': status, - 'metadata': metadata}) + if status: + visit['status'] = status + if metadata: + visit['metadata'] = metadata + if snapshot_id: + visit['snapshot'] = snapshot_id def origin_visit_get(self, origin, last_visit=None, limit=None): """Retrieve all the origin's visit's information.
diff --git a/swh/storage/sql/40-swh-func.sql b/swh/storage/sql/40-swh-func.sql --- a/swh/storage/sql/40-swh-func.sql +++ b/swh/storage/sql/40-swh-func.sql @@ -706,7 +706,7 @@ returning visit; $$; -create or replace function swh_snapshot_add(origin bigint, visit bigint, snapshot_id snapshot.id%type) +create or replace function swh_snapshot_add(snapshot_id snapshot.id%type) returns void language plpgsql as $$ @@ -740,9 +740,6 @@ where tmp.target is null and tmp.target_type is null and sb.target is null and sb.target_type is null; end if; - update origin_visit ov - set snapshot_id = snapshot_object_id - where ov.origin=swh_snapshot_add.origin and ov.visit=swh_snapshot_add.visit; end; $$; diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -768,13 +768,11 @@ yield data if data['target_type'] else None @db_transaction() - def snapshot_add(self, origin, visit, snapshot, + def snapshot_add(self, snapshot, origin=None, visit=None, db=None, cur=None): """Add a snapshot for the given origin/visit couple Args: - origin (int): id of the origin - visit (int): id of the visit snapshot (dict): the snapshot to add to the visit, containing the following keys: @@ -790,15 +788,32 @@ - **target** (:class:`bytes`): identifier of the target (currently a ``sha1_git`` for all object kinds, or the name of the target branch for aliases) + origin (int): legacy argument for backward compatibility + visit (int): legacy argument for backward compatibility Raises: ValueError: if the origin or visit id does not exist.
""" - origin_id = origin - visit_id = visit + if origin: + if not visit: + raise TypeError( + 'snapshot_add expects one argument (or, as a legacy ' + 'behavior, three arguments), not two') + if isinstance(snapshot, int): + # Called by legacy code that uses the new api/client.py + (origin_id, visit_id, snapshot) = \ + (snapshot, origin, visit) + else: + # Called by legacy code that uses the old api/client.py + origin_id = origin + visit_id = visit + else: + # Called by new code that uses the new api/client.py + # (new code using the old api/client.py would crash before + # sending the request). + origin_id = visit_id = None if not db.snapshot_exists(snapshot['id'], cur): - db.mktemp_snapshot_branch(cur) db.copy_to( ( @@ -814,29 +829,28 @@ cur, ) - if self.journal_writer: - visit = db.origin_visit_get(origin_id, visit_id, cur=cur) - visit_exists = visit is not None - else: - visit_exists = db.origin_visit_exists(origin_id, visit_id) + if visit_id: + # Legacy behavior is to abort before adding the snapshot + # if the visit does not exist + if self.journal_writer: + visit = db.origin_visit_get(origin_id, visit_id, cur=cur) + visit_exists = visit is not None + else: + visit_exists = db.origin_visit_exists(origin_id, visit_id) - if not visit_exists: - raise ValueError('Not origin visit with ids (%s, %s)' % - (origin_id, visit_id)) + if not visit_exists: + raise ValueError('Not origin visit with ids (%s, %s)' % + (origin_id, visit_id)) if self.journal_writer: - # Send the snapshot before the origin: in case of a crash, - # it's better to have an orphan snapshot than have the - # origin_visit have a dangling reference to a snapshot - origin = self.origin_get([{'id': origin_id}], db=db, cur=cur)[0] - visit = dict(zip(db.origin_visit_get_cols, visit)) self.journal_writer.write_addition('snapshot', snapshot) - self.journal_writer.write_update('origin_visit', { - 'origin': origin, 'visit': visit_id, - 'status': visit['status'], 'metadata': visit['metadata'], - 'date':
visit['date'], 'snapshot': snapshot['id']}) - db.snapshot_add(origin_id, visit_id, snapshot['id'], cur) + db.snapshot_add(snapshot['id'], cur) + + if visit_id: + self.origin_visit_update( + origin_id, visit_id, snapshot_id=snapshot['id'], + db=db, cur=cur) @db_transaction(statement_timeout=2000) def snapshot_get(self, snapshot_id, db=None, cur=None): @@ -1059,7 +1073,8 @@ } @db_transaction() - def origin_visit_update(self, origin, visit_id, status, metadata=None, + def origin_visit_update(self, origin, visit_id, status=None, + metadata=None, snapshot_id=None, db=None, cur=None): """Update an origin_visit's status. @@ -1068,6 +1083,8 @@ visit_id: Visit's id status: Visit's new status metadata: Data associated to the visit + snapshot_id (sha1_git): identifier of the snapshot to add to + the visit Returns: None @@ -1078,13 +1095,17 @@ if self.journal_writer: origin = self.origin_get([{'id': origin_id}], db=db, cur=cur)[0] visit = db.origin_visit_get(origin_id, visit_id, cur=cur) + if not visit: + raise ValueError('Invalid visit_id for this origin.') visit = dict(zip(db.origin_visit_get_cols, visit)) self.journal_writer.write_update('origin_visit', { 'origin': origin, 'visit': visit_id, - 'status': status, 'metadata': metadata, - 'date': visit['date'], 'snapshot': None}) + 'status': status or visit['status'], + 'metadata': metadata or visit['metadata'], + 'date': visit['date'], + 'snapshot': snapshot_id or visit['snapshot']}) return db.origin_visit_update( - origin_id, visit_id, status, metadata, cur) + origin_id, visit_id, status, metadata, snapshot_id, cur) @db_transaction_generator(statement_timeout=500) def origin_visit_get(self, origin, last_visit=None, limit=None, db=None, diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- 
a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1410,6 +1410,46 @@ self.date_visit1) visit_id = origin_visit1['visit'] + self.storage.snapshot_add(self.empty_snapshot) + self.storage.origin_visit_update( + origin_id, visit_id, snapshot_id=self.empty_snapshot['id']) + + by_id = self.storage.snapshot_get(self.empty_snapshot['id']) + self.assertEqual(by_id, self.empty_snapshot) + + by_ov = self.storage.snapshot_get_by_origin_visit(origin_id, visit_id) + self.assertEqual(by_ov, self.empty_snapshot) + + expected_origin = self.origin.copy() + expected_origin['id'] = origin_id + data1 = { + 'origin': expected_origin, + 'date': self.date_visit1, + 'visit': origin_visit1['visit'], + 'status': 'ongoing', + 'metadata': None, + 'snapshot': None, + } + data2 = { + 'origin': expected_origin, + 'date': self.date_visit1, + 'visit': origin_visit1['visit'], + 'status': 'ongoing', + 'metadata': None, + 'snapshot': self.empty_snapshot['id'], + } + self.assertEqual(list(self.journal_writer.objects), + [('origin', expected_origin), + ('origin_visit', data1), + ('snapshot', self.empty_snapshot), + ('origin_visit', data2)]) + + def test_snapshot_add_get_empty__legacy_add(self): + origin_id = self.storage.origin_add_one(self.origin) + origin_visit1 = self.storage.origin_visit_add(origin_id, + self.date_visit1) + visit_id = origin_visit1['visit'] + self.storage.snapshot_add(origin_id, visit_id, self.empty_snapshot) by_id = self.storage.snapshot_get(self.empty_snapshot['id']) @@ -1598,6 +1638,21 @@ self.journal_writer.objects[:] = [] + self.storage.snapshot_add(self.snapshot) + + with self.assertRaises(ValueError): + self.storage.origin_visit_update( + origin_id, visit_id, snapshot_id=self.snapshot['id']) + + self.assertEqual(list(self.journal_writer.objects), [ + ('snapshot', self.snapshot)]) + + def test_snapshot_add_nonexistent_visit__legacy_add(self): + origin_id = self.storage.origin_add_one(self.origin) + visit_id = 54164461156 + + self.journal_writer.objects[:] =
[] + with self.assertRaises(ValueError): self.storage.snapshot_add(origin_id, visit_id, self.snapshot) @@ -1618,6 +1673,74 @@ origin_visit1 = self.storage.origin_visit_add(origin_id, self.date_visit1) visit1_id = origin_visit1['visit'] + self.storage.snapshot_add(self.snapshot) + self.storage.origin_visit_update( + origin_id, visit1_id, snapshot_id=self.snapshot['id']) + + by_ov1 = self.storage.snapshot_get_by_origin_visit(origin_id, + visit1_id) + self.assertEqual(by_ov1, self.snapshot) + + origin_visit2 = self.storage.origin_visit_add(origin_id, + self.date_visit2) + visit2_id = origin_visit2['visit'] + + self.storage.snapshot_add(self.snapshot) + self.storage.origin_visit_update( + origin_id, visit2_id, snapshot_id=self.snapshot['id']) + + by_ov2 = self.storage.snapshot_get_by_origin_visit(origin_id, + visit2_id) + self.assertEqual(by_ov2, self.snapshot) + + expected_origin = self.origin.copy() + expected_origin['id'] = origin_id + data1 = { + 'origin': expected_origin, + 'date': self.date_visit1, + 'visit': origin_visit1['visit'], + 'status': 'ongoing', + 'metadata': None, + 'snapshot': None, + } + data2 = { + 'origin': expected_origin, + 'date': self.date_visit1, + 'visit': origin_visit1['visit'], + 'status': 'ongoing', + 'metadata': None, + 'snapshot': self.snapshot['id'], + } + data3 = { + 'origin': expected_origin, + 'date': self.date_visit2, + 'visit': origin_visit2['visit'], + 'status': 'ongoing', + 'metadata': None, + 'snapshot': None, + } + data4 = { + 'origin': expected_origin, + 'date': self.date_visit2, + 'visit': origin_visit2['visit'], + 'status': 'ongoing', + 'metadata': None, + 'snapshot': self.snapshot['id'], + } + self.assertEqual(list(self.journal_writer.objects), + [('origin', expected_origin), + ('origin_visit', data1), + ('snapshot', self.snapshot), + ('origin_visit', data2), + ('origin_visit', data3), + ('snapshot', self.snapshot), + ('origin_visit', data4)]) + + def test_snapshot_add_twice__legacy_add(self): + origin_id =
self.storage.origin_add_one(self.origin) + origin_visit1 = self.storage.origin_visit_add(origin_id, + self.date_visit1) + visit1_id = origin_visit1['visit'] self.storage.snapshot_add(origin_id, visit1_id, self.snapshot) by_ov1 = self.storage.snapshot_get_by_origin_visit(origin_id, @@ -1707,6 +1830,67 @@ self.assertIsNone(self.storage.snapshot_get_latest(origin_id)) # Add snapshot to visit1, latest snapshot = visit 1 snapshot + self.storage.snapshot_add(self.complete_snapshot) + self.storage.origin_visit_update( + origin_id, visit1_id, snapshot_id=self.complete_snapshot['id']) + self.assertEqual(self.complete_snapshot, + self.storage.snapshot_get_latest(origin_id)) + + # Status filter: both visits are status=ongoing, so no snapshot + # returned + self.assertIsNone( + self.storage.snapshot_get_latest(origin_id, + allowed_statuses=['full']) + ) + + # Mark the first visit as completed and check status filter again + self.storage.origin_visit_update(origin_id, visit1_id, status='full') + self.assertEqual( + self.complete_snapshot, + self.storage.snapshot_get_latest(origin_id, + allowed_statuses=['full']), + ) + + # Add snapshot to visit2 and check that the new snapshot is returned + self.storage.snapshot_add(self.empty_snapshot) + self.storage.origin_visit_update( + origin_id, visit2_id, snapshot_id=self.empty_snapshot['id']) + self.assertEqual(self.empty_snapshot, + self.storage.snapshot_get_latest(origin_id)) + + # Check that the status filter is still working + self.assertEqual( + self.complete_snapshot, + self.storage.snapshot_get_latest(origin_id, + allowed_statuses=['full']), + ) + + # Add snapshot to visit3 (same date as visit2) and check that + # the new snapshot is returned + self.storage.snapshot_add(self.complete_snapshot) + self.storage.origin_visit_update( + origin_id, visit3_id, snapshot_id=self.complete_snapshot['id']) + self.assertEqual(self.complete_snapshot, + self.storage.snapshot_get_latest(origin_id)) + + def
test_snapshot_get_latest__legacy_add(self): + origin_id = self.storage.origin_add_one(self.origin) + origin_visit1 = self.storage.origin_visit_add(origin_id, + self.date_visit1) + visit1_id = origin_visit1['visit'] + origin_visit2 = self.storage.origin_visit_add(origin_id, + self.date_visit2) + visit2_id = origin_visit2['visit'] + + # Add a visit with the same date as the previous one + origin_visit3 = self.storage.origin_visit_add(origin_id, + self.date_visit2) + visit3_id = origin_visit3['visit'] + + # Two visits, both with no snapshot: latest snapshot is None + self.assertIsNone(self.storage.snapshot_get_latest(origin_id)) + + # Add snapshot to visit1, latest snapshot = visit 1 snapshot self.storage.snapshot_add(origin_id, visit1_id, self.complete_snapshot) self.assertEqual(self.complete_snapshot, self.storage.snapshot_get_latest(origin_id))