diff --git a/swh/storage/api/client.py b/swh/storage/api/client.py
--- a/swh/storage/api/client.py
+++ b/swh/storage/api/client.py
@@ -192,6 +192,9 @@
                                                  'metadata': metadata,
                                                  'snapshot': snapshot})
 
+    def origin_visit_upsert(self, visits):
+        return self.post('origin/visit/upsert', {'visits': visits})
+
     def origin_visit_get(self, origin, last_visit=None, limit=None):
         return self.post('origin/visit/get', {
             'origin': origin, 'last_visit': last_visit, 'limit': limit})
diff --git a/swh/storage/api/server.py b/swh/storage/api/server.py
--- a/swh/storage/api/server.py
+++ b/swh/storage/api/server.py
@@ -404,6 +404,13 @@
         **decode_request(request)))
 
 
+@app.route('/origin/visit/upsert', methods=['POST'])
+@timed
+def origin_visit_upsert():
+    return encode_data(get_storage().origin_visit_upsert(
+        **decode_request(request)))
+
+
 @app.route('/person', methods=['POST'])
 @timed
 def person_get():
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -327,6 +327,18 @@
         })
         cur.execute(query, (*values, *where_values))
 
+    def origin_visit_upsert(self, origin, visit, date, status,
+                            metadata, snapshot, cur=None):
+        cur = self._cursor(cur)
+        query = """INSERT INTO origin_visit ({cols}) VALUES ({values})
+            ON CONFLICT ON CONSTRAINT origin_visit_pkey DO
+            UPDATE SET {updates}""".format(
+                cols=', '.join(self.origin_visit_get_cols),
+                values=', '.join('%s' for col in self.origin_visit_get_cols),
+                updates=', '.join('{0}=excluded.{0}'.format(col)
+                                  for col in self.origin_visit_get_cols))
+        cur.execute(query, (origin, visit, date, status, metadata, snapshot))
+
     origin_visit_get_cols = ['origin', 'visit', 'date', 'status',
                              'metadata', 'snapshot']
 
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -1220,6 +1220,43 @@
             if snapshot:
                 visit['snapshot'] = snapshot
 
+    def origin_visit_upsert(self, visits):
+        """Add origin visits with specific ids and with all their data.
+        If there is already an origin_visit with the same
+        `(origin_id, visit_id)`, updates it instead of inserting a new one.
+
+        Args:
+            visits: iterable of dicts with keys:
+
+                origin: Visited Origin id
+                visit: origin visit id
+                date: timestamp of such visit
+                status: status of the visit
+                metadata: Data associated to the visit
+                snapshot (sha1_git): identifier of the snapshot to add to
+                    the visit
+        """
+        # Deep-copy up front: `visits` may be a one-shot iterator (it is
+        # walked twice below), and stored visits must neither alias nor
+        # mutate the caller's dicts (the date is parsed in place below).
+        visits = [copy.deepcopy(visit) for visit in visits]
+
+        if self.journal_writer:
+            for visit in visits:
+                visit = visit.copy()
+                visit['origin'] = self.origin_get([{'id': visit['origin']}])[0]
+                del visit['origin']['id']
+                self.journal_writer.write_addition('origin_visit', visit)
+
+        for visit in visits:
+            origin_id = visit['origin']
+            visit_id = visit['visit']
+            if isinstance(visit['date'], str):
+                visit['date'] = dateutil.parser.parse(visit['date'])
+            while len(self._origin_visits[origin_id-1]) < visit_id:
+                self._origin_visits[origin_id-1].append(None)
+            self._origin_visits[origin_id-1][visit_id-1] = visit
+
     def origin_visit_get(self, origin, last_visit=None, limit=None):
         """Retrieve all the origin's visit's information.
 
@@ -1241,6 +1278,8 @@
         if limit is not None:
             visits = visits[:limit]
         for visit in visits:
+            if visit is None:
+                continue
             visit_id = visit['visit']
             yield copy.deepcopy(self._origin_visits[origin-1][visit_id-1])
 
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -1248,6 +1248,39 @@
 
         db.origin_visit_update(origin_id, visit_id, updates, cur)
 
+    @db_transaction()
+    def origin_visit_upsert(self, visits, db=None, cur=None):
+        """Add origin visits with specific ids and with all their data.
+        If there is already an origin_visit with the same
+        `(origin_id, visit_id)`, overwrites it.
+
+        Args:
+            visits: iterable of dicts with keys:
+
+                origin: Visited Origin id
+                visit: origin visit id
+                date: timestamp of such visit
+                status: status of the visit
+                metadata: Data associated to the visit
+                snapshot (sha1_git): identifier of the snapshot to add to
+                    the visit
+        """
+        # `visits` may be a one-shot iterator; materialize it since it is
+        # walked twice below (journal, then database).
+        visits = list(visits)
+
+        if self.journal_writer:
+            for visit in visits:
+                visit = visit.copy()
+                visit['origin'] = self.origin_get(
+                    [{'id': visit['origin']}], db=db, cur=cur)[0]
+                del visit['origin']['id']
+                self.journal_writer.write_addition('origin_visit', visit)
+
+        for visit in visits:
+            # TODO: upsert them all in a single query
+            db.origin_visit_upsert(**visit, cur=cur)
+
     @db_transaction_generator(statement_timeout=500)
     def origin_visit_get(self, origin, last_visit=None, limit=None,
                          db=None, cur=None):
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -1511,6 +1511,105 @@
         # then
         self.assertEqual(actual_origin_visit1, expected_origin_visit)
 
+    def test_origin_visit_upsert_new(self):
+        # given
+        self.assertIsNone(self.storage.origin_get([self.origin2])[0])
+
+        origin_id = self.storage.origin_add_one(self.origin2)
+        self.assertIsNotNone(origin_id)
+
+        # when
+        self.storage.origin_visit_upsert([{
+            'origin': origin_id,
+            'date': self.date_visit2,
+            'visit': 123,
+            'status': 'full',
+            'metadata': None,
+            'snapshot': None,
+        }])
+
+        # then
+        actual_origin_visits = list(self.storage.origin_visit_get(origin_id))
+        self.assertEqual(actual_origin_visits,
+                         [{
+                             'origin': origin_id,
+                             'date': self.date_visit2,
+                             'visit': 123,
+                             'status': 'full',
+                             'metadata': None,
+                             'snapshot': None,
+                         }])
+
+        expected_origin = self.origin2.copy()
+        data = {
+            'origin': expected_origin,
+            'date': self.date_visit2,
+            'visit': 123,
+            'status': 'full',
+            'metadata': None,
+            'snapshot': None,
+        }
+        self.assertEqual(list(self.journal_writer.objects),
+                         [('origin', expected_origin),
+                          ('origin_visit', data)])
+
+    def test_origin_visit_upsert_existing(self):
+        # given
+        self.assertIsNone(self.storage.origin_get([self.origin2])[0])
+
+        origin_id = self.storage.origin_add_one(self.origin2)
+        self.assertIsNotNone(origin_id)
+
+        # when
+        origin_visit1 = self.storage.origin_visit_add(
+            origin_id,
+            date=self.date_visit2)
+        self.storage.origin_visit_upsert([{
+            'origin': origin_id,
+            'date': self.date_visit2,
+            'visit': origin_visit1['visit'],
+            'status': 'full',
+            'metadata': None,
+            'snapshot': None,
+        }])
+
+        # then
+        self.assertEqual(origin_visit1['origin'], origin_id)
+        self.assertIsNotNone(origin_visit1['visit'])
+
+        actual_origin_visits = list(self.storage.origin_visit_get(origin_id))
+        self.assertEqual(actual_origin_visits,
+                         [{
+                             'origin': origin_id,
+                             'date': self.date_visit2,
+                             'visit': origin_visit1['visit'],
+                             'status': 'full',
+                             'metadata': None,
+                             'snapshot': None,
+                         }])
+
+        expected_origin = self.origin2.copy()
+        data1 = {
+            'origin': expected_origin,
+            'date': self.date_visit2,
+            'visit': origin_visit1['visit'],
+            'status': 'ongoing',
+            'metadata': None,
+            'snapshot': None,
+        }
+        data2 = {
+            'origin': expected_origin,
+            'date': self.date_visit2,
+            'visit': origin_visit1['visit'],
+            'status': 'full',
+            'metadata': None,
+            'snapshot': None,
+        }
+        self.assertEqual(list(self.journal_writer.objects),
+                         [('origin', expected_origin),
+                          ('origin_visit', data1),
+                          ('origin_visit', data2)])
+
     def test_origin_visit_get_by_no_result(self):
         # No result
         actual_origin_visit = self.storage.origin_visit_get_by(