diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -297,14 +297,43 @@
                                   (origin, ts))
         return cur.fetchone()[0]
 
-    def origin_visit_update(self, origin, visit_id, status,
-                            metadata, cur=None):
-        """Update origin_visit's status."""
+    def origin_visit_update(self, origin_id, visit_id, status, metadata,
+                            snapshot_id, cur=None):
+        """Update the status, metadata and/or snapshot of an origin_visit.
+
+        Only the truthy values among status, metadata and snapshot_id are
+        written. When none is given, no query is sent.
+        """
         cur = self._cursor(cur)
+        update_cols = []
+        values = []
+        where = ['origin=%s AND visit=%s']
+        where_values = [origin_id, visit_id]
+        from_ = ''
+        if status:
+            update_cols.append('status=%s')
+            values.append(status)
+        if metadata:
+            update_cols.append('metadata=%s')
+            values.append(jsonize(metadata))
+        if snapshot_id:
+            update_cols.append('snapshot_id=snapshot.object_id')
+            from_ = 'FROM snapshot'
+            where.append('snapshot.id=%s')
+            where_values.append(snapshot_id)
+        if not update_cols:
+            # Nothing to change: "UPDATE ... SET  WHERE ..." with an
+            # empty SET clause is a SQL syntax error, so skip the query.
+            return
         update = """UPDATE origin_visit
-                    SET status=%s, metadata=%s
-                    WHERE origin=%s AND visit=%s"""
-        cur.execute(update, (status, jsonize(metadata), origin, visit_id))
+                    SET {update_cols}
+                    {from}
+                    WHERE {where}""".format(**{
+            'update_cols': ', '.join(update_cols),
+            'from': from_,
+            'where': ' AND '.join(where)
+        })
+        cur.execute(update, (*values, *where_values))
 
     origin_visit_get_cols = ['origin', 'visit', 'date', 'status', 'metadata',
                              'snapshot']
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -1077,7 +1077,8 @@
 
         return visit_ret
 
-    def origin_visit_update(self, origin, visit_id, status, metadata=None):
+    def origin_visit_update(self, origin, visit_id, status=None,
+                            metadata=None, snapshot_id=None):
         """Update an origin_visit's status.
 
         Args:
@@ -1085,6 +1086,8 @@
             visit_id (int): visit's identifier
             status: visit's new status
             metadata: data associated to the visit
+            snapshot_id (sha1_git): identifier of the snapshot to add to
+                the visit
 
         Returns:
             None
@@ -1092,18 +1095,28 @@
         """
         origin_id = origin  # TODO: rename the argument
 
+        if origin_id < 1 or visit_id < 1:
+            # Ids are 1-based: 0 or a negative id would otherwise wrap
+            # around and silently update an unrelated visit.
+            raise ValueError('Invalid origin_id or visit_id')
+        try:
+            visit = self._origin_visits[origin_id-1][visit_id-1]
+        except IndexError:
+            raise ValueError('Invalid origin_id or visit_id') from None
         if self.journal_writer:
             origin = self.origin_get([{'id': origin_id}])[0]
             self.journal_writer.write_update('origin_visit', {
-                **self._origin_visits[origin_id-1][visit_id-1],
                 'origin': origin, 'visit': visit_id,
-                'status': status, 'metadata': metadata})
-        if origin_id > len(self._origin_visits) or \
-                visit_id > len(self._origin_visits[origin_id-1]):
-            return
-        self._origin_visits[origin_id-1][visit_id-1].update({
-            'status': status,
-            'metadata': metadata})
+                'status': status or visit['status'],
+                'date': visit['date'],
+                'metadata': metadata or visit['metadata'],
+                'snapshot': snapshot_id or visit['snapshot']})
+        if status:
+            visit['status'] = status
+        if metadata:
+            visit['metadata'] = metadata
+        if snapshot_id:
+            visit['snapshot'] = snapshot_id
 
     def origin_visit_get(self, origin, last_visit=None, limit=None):
         """Retrieve all the origin's visit's information.
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -1059,7 +1059,8 @@
         }
 
     @db_transaction()
-    def origin_visit_update(self, origin, visit_id, status, metadata=None,
+    def origin_visit_update(self, origin, visit_id, status=None,
+                            metadata=None, snapshot_id=None,
                             db=None, cur=None):
         """Update an origin_visit's status.
 
@@ -1068,6 +1069,8 @@
             visit_id: Visit's id
             status: Visit's new status
             metadata: Data associated to the visit
+            snapshot_id (sha1_git): identifier of the snapshot to add to
+                the visit
 
         Returns:
             None
@@ -1078,13 +1081,17 @@
         if self.journal_writer:
             origin = self.origin_get([{'id': origin_id}], db=db, cur=cur)[0]
             visit = db.origin_visit_get(origin_id, visit_id, cur=cur)
+            if not visit:
+                raise ValueError('Invalid visit_id for this origin.')
             visit = dict(zip(db.origin_visit_get_cols, visit))
             self.journal_writer.write_update('origin_visit', {
                 'origin': origin, 'visit': visit_id,
-                'status': status, 'metadata': metadata,
-                'date': visit['date'], 'snapshot': None})
+                'status': status or visit['status'],
+                'metadata': metadata or visit['metadata'],
+                'date': visit['date'],
+                'snapshot': snapshot_id or visit['snapshot']})
         return db.origin_visit_update(
-            origin_id, visit_id, status, metadata, cur)
+            origin_id, visit_id, status, metadata, snapshot_id, cur)
 
     @db_transaction_generator(statement_timeout=500)
     def origin_visit_get(self, origin, last_visit=None, limit=None, db=None,