diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -570,11 +570,16 @@ ', '.join('%s = ?' % key for key in _origin_visit_update_keys) + ' WHERE origin = ? AND visit = ?') def origin_visit_upsert( - self, visit: Dict[str, Any], *, statement) -> None: + self, visit: OriginVisit, *, statement) -> None: + args: List[Any] = [] + for column in self._origin_visit_update_keys: + if column == 'metadata': + args.append(json.dumps(visit.metadata)) + else: + args.append(getattr(visit, column)) + self._execute_with_retries( - statement, - [visit.get(key) for key in self._origin_visit_update_keys] - + [visit['origin'], visit['visit']]) + statement, args + [visit.origin, visit.visit]) # TODO: check if there is already one self._increment_counter('origin_visit', 1) diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -823,18 +823,9 @@ self._cql_runner.origin_visit_update(origin_url, visit_id, updates) - def origin_visit_upsert(self, visits): - visits = [visit.copy() for visit in visits] - for visit in visits: - if isinstance(visit['date'], str): - visit['date'] = dateutil.parser.parse(visit['date']) - + def origin_visit_upsert(self, visits: Iterable[OriginVisit]) -> None: self.journal_writer.origin_visit_upsert(visits) - for visit in visits: - visit = visit.copy() - if visit.get('metadata'): - visit['metadata'] = json.dumps(visit['metadata']) self._cql_runner.origin_visit_upsert(visit) @staticmethod diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -9,7 +9,7 @@ from swh.core.db import BaseDb from swh.core.db.db_utils import stored_procedure, jsonize from swh.core.db.db_utils import execute_values_generator -from swh.model.model import SHA1_SIZE +from swh.model.model import OriginVisit, SHA1_SIZE class Db(BaseDb): @@ -379,11 +379,11 @@ }) cur.execute(query, (*values, *where_values)) - def origin_visit_upsert(self, origin, visit, date, type, status, - metadata, snapshot, cur=None): + def origin_visit_upsert(self, origin_visit: OriginVisit, cur=None) -> None: # doing an extra query like this is way simpler than trying to join # the origin id in the query below - origin_id = next(self.origin_id_get_by_url([origin])) + ov = origin_visit + origin_id = next(self.origin_id_get_by_url([ov.origin])) cur = self._cursor(cur) query = """INSERT INTO origin_visit ({cols}) VALUES ({values}) @@ -394,7 +394,8 @@ updates=', '.join('{0}=excluded.{0}'.format(col) for col in self.origin_visit_get_cols)) cur.execute( - query, (origin_id, visit, date, type, status, metadata, snapshot)) + query, (origin_id, ov.visit, ov.date, ov.type, ov.status, + ov.metadata, ov.snapshot)) origin_visit_get_cols = [ 'origin', 'visit', 'date', 'type', diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -773,16 +773,7 @@ self._origin_visits[origin_url][visit_id-1] = visit - def origin_visit_upsert(self, visits): - for visit in visits: - if not isinstance(visit['origin'], str): - raise TypeError("visit['origin'] must be a string, not %r" - % (visit['origin'],)) - try: - visits = [OriginVisit.from_dict(d) for d in visits] - except (KeyError, TypeError, ValueError) as e: - raise StorageArgumentException(*e.args) - + def origin_visit_upsert(self, visits: Iterable[OriginVisit]) -> None: self.journal_writer.origin_visit_upsert(visits) for visit in visits: @@ -797,10 +788,11 @@ self._objects[(origin_url, visit_id)].append( ('origin_visit', None)) - while len(self._origin_visits[origin_url]) <= visit_id: - self._origin_visits[origin_url].append(None) + if visit_id: + while len(self._origin_visits[origin_url]) <= visit_id: + self._origin_visits[origin_url].append(None) - self._origin_visits[origin_url][visit_id-1] = visit + self._origin_visits[origin_url][visit_id-1] = visit def _convert_visit(self, visit): if visit is None: diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -7,8 +7,8 @@ from swh.core.api import remote_api_endpoint from swh.model.model import ( - SkippedContent, Content, Directory, Revision, Release, - Snapshot, Origin + Content, Directory, Origin, OriginVisit, Revision, Release, + Snapshot, SkippedContent ) @@ -806,7 +806,7 @@ ... @remote_api_endpoint('origin/visit/upsert') - def origin_visit_upsert(self, visits): + def origin_visit_upsert(self, visits: Iterable[OriginVisit]) -> None: """Add a origin_visits with a specific id and with all its data. If there is already an origin_visit with the same `(origin_id, visit_id)`, overwrites it. diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -4,7 +4,6 @@ # See top-level LICENSE file for more information import contextlib -import copy import datetime import itertools import json @@ -21,8 +20,8 @@ import psycopg2.errors from swh.model.model import ( - SkippedContent, Content, Directory, Revision, Release, - Snapshot, Origin, SHA1_SIZE + Content, Directory, Origin, OriginVisit, + Revision, Release, SkippedContent, Snapshot, SHA1_SIZE ) from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex from swh.storage.objstorage import ObjStorage @@ -870,21 +869,13 @@ @timed @db_transaction() - def origin_visit_upsert(self, visits, db=None, cur=None): - visits = copy.deepcopy(visits) - for visit in visits: - if isinstance(visit['date'], str): - visit['date'] = dateutil.parser.parse(visit['date']) - if not isinstance(visit['origin'], str): - raise StorageArgumentException( - "visit['origin'] must be a string, not %r" - % (visit['origin'],)) - + def origin_visit_upsert(self, visits: Iterable[OriginVisit], + db=None, cur=None) -> None: self.journal_writer.origin_visit_upsert(visits) for visit in visits: # TODO: upsert them all in a single query - db.origin_visit_upsert(**visit, cur=cur) + db.origin_visit_upsert(visit, cur=cur) @timed @db_transaction_generator(statement_timeout=500) diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -26,7 +26,7 @@ from swh.model import from_disk, identifiers from swh.model.hashutil import hash_to_bytes -from swh.model.model import Content, Release, Revision +from swh.model.model import Content, OriginVisit, Release, Revision from swh.model.hypothesis_strategies import objects from swh.storage import HashCollision, get_storage from swh.storage.converters import origin_url_to_sha1 as sha1 @@ -1986,24 +1986,24 @@ # when swh_storage.origin_visit_upsert([ - { - 'origin': origin_url, - 'date': data.date_visit2, - 'visit': 123, - 'type': data.type_visit2, - 'status': 'full', - 'metadata': None, - 'snapshot': None, - }, - { - 'origin': origin_url, - 'date': '2018-01-01 23:00:00+00', - 'visit': 1234, - 'type': data.type_visit2, - 'status': 'full', - 'metadata': None, - 'snapshot': None, - }, + OriginVisit.from_dict({ + 'origin': origin_url, + 'date': data.date_visit2, + 'visit': 123, + 'type': data.type_visit2, + 'status': 'full', + 'metadata': None, + 'snapshot': None, + }), + OriginVisit.from_dict({ + 'origin': origin_url, + 'date': '2018-01-01 23:00:00+00', + 'visit': 1234, + 'type': data.type_visit2, + 'status': 'full', + 'metadata': None, + 'snapshot': None, + }), ]) # then @@ -2064,15 +2064,15 @@ date=data.date_visit2, type=data.type_visit1, ) - swh_storage.origin_visit_upsert([{ - 'origin': origin_url, - 'date': data.date_visit2, - 'visit': origin_visit1['visit'], - 'type': data.type_visit1, - 'status': 'full', - 'metadata': None, - 'snapshot': None, - }]) + swh_storage.origin_visit_upsert([OriginVisit.from_dict({ + 'origin': origin_url, + 'date': data.date_visit2, + 'visit': origin_visit1['visit'], + 'type': data.type_visit1, + 'status': 'full', + 'metadata': None, + 'snapshot': None, + })]) # then assert origin_visit1['origin'] == origin_url