diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import datetime
 import functools
 import json
 import logging
@@ -24,7 +25,7 @@
 
 from swh.model.model import (
     Sha1Git, TimestampWithTimezone, Timestamp, Person, Content,
-    SkippedContent, OriginVisit, Origin
+    SkippedContent, OriginVisit, OriginVisitUpdate, Origin
 )
 
 from .common import Row, TOKEN_BEGIN, TOKEN_END, hash_url
@@ -607,32 +608,58 @@
         else:
             return self._origin_visit_get_limit(origin_url, last_visit, limit)
 
-    def origin_visit_update(
-            self, origin_url: str, visit_id: int, updates: Dict[str, Any]
-    ) -> None:
-        set_parts = []
-        args: List[Any] = []
-        for (column, value) in updates.items():
-            set_parts.append(f'{column} = %s')
-            if column == 'metadata':
-                args.append(json.dumps(value))
-            else:
-                args.append(value)
-
-        if not set_parts:
-            return
-
-        query = ('UPDATE origin_visit SET ' + ', '.join(set_parts) +
-                 ' WHERE origin = %s AND visit = %s')
-        self._execute_with_retries(
-            query, args + [origin_url, visit_id])
-
     @_prepared_insert_statement('origin_visit', _origin_visit_keys)
     def origin_visit_add_one(
             self, visit: OriginVisit, *, statement) -> None:
         self._add_one(statement, 'origin_visit', visit,
                       self._origin_visit_keys)
 
+    _origin_visit_update_table_keys = [
+        'origin', 'visit', 'date', 'status', 'snapshot', 'metadata'
+    ]
+
+    @_prepared_insert_statement(
+        'origin_visit_update', _origin_visit_update_table_keys)
+    def origin_visit_update_add_one(
+            self, visit_update: OriginVisitUpdate, *, statement) -> None:
+        assert self._origin_visit_update_table_keys[-1] == 'metadata'
+        keys = self._origin_visit_update_table_keys
+
+        metadata = json.dumps(visit_update.metadata)
+        self._execute_with_retries(
+            statement,
+            [getattr(visit_update, key) for key in keys[:-1]] + [metadata])
+
+    def _format_origin_visit_update_row(
+            self, visit_update: Row) -> Dict[str, Any]:
+        """Format a row visit_update into an origin_visit_update dict
+
+        """
+        return {
+            **visit_update._asdict(),
+            'origin': visit_update.origin,
+            'date': visit_update.date.replace(tzinfo=datetime.timezone.utc),
+            'metadata': (json.loads(visit_update.metadata)
+                         if visit_update.metadata else None),
+        }
+
+    @_prepared_statement('SELECT * FROM origin_visit_update '
+                         'WHERE origin = ? AND visit = ? '
+                         'ORDER BY date DESC '
+                         'LIMIT 1')
+    def origin_visit_update_get_latest(
+            self, origin: str, visit: int,
+            *, statement) -> Optional[Dict[str, Any]]:
+        """Given an origin visit id, return its latest origin_visit_update
+
+        """
+        rows = list(self._execute_with_retries(
+            statement, [origin, visit]))
+        if rows:
+            return self._format_origin_visit_update_row(rows[0])
+        else:
+            return None
+
     @_prepared_statement(
         'UPDATE origin_visit SET ' +
         ', '.join('%s = ?' % key for key in _origin_visit_update_keys) +
@@ -669,28 +696,6 @@
     def origin_visit_get_all(self, origin_url: str, *, statement) -> ResultSet:
         return self._execute_with_retries(statement, [origin_url])
 
-    @_prepared_statement('SELECT * FROM origin_visit WHERE origin = ?')
-    def origin_visit_get_latest(
-            self, origin: str, allowed_statuses: Optional[Iterable[str]],
-            require_snapshot: bool, *, statement) -> Optional[Row]:
-        # TODO: do the ordering and filtering in Cassandra
-        rows = list(self._execute_with_retries(statement, [origin]))
-
-        rows.sort(key=lambda row: (row.date, row.visit), reverse=True)
-
-        for row in rows:
-            if require_snapshot and row.snapshot is None:
-                continue
-            if allowed_statuses is not None \
-                    and row.status not in allowed_statuses:
-                continue
-            if row.snapshot is not None and \
-                    self.snapshot_missing([row.snapshot]):
-                raise ValueError('visit references unknown snapshot')
-            return row
-        else:
-            return None
-
     @_prepared_statement('SELECT * FROM origin_visit WHERE token(origin) >= ?')
     def _origin_visit_iter_from(
             self, min_token: int, *, statement) -> Iterator[Row]:
diff --git a/swh/storage/cassandra/schema.py b/swh/storage/cassandra/schema.py
--- a/swh/storage/cassandra/schema.py
+++ b/swh/storage/cassandra/schema.py
@@ -154,6 +154,15 @@
     PRIMARY KEY ((origin), visit)
 );
 
+CREATE TABLE IF NOT EXISTS origin_visit_update (
+    origin          text,
+    visit           bigint,
+    date            timestamp,
+    status          ascii,
+    metadata        text,
+    snapshot        blob,
+    PRIMARY KEY ((origin), visit, date)
+);
 
 CREATE TABLE IF NOT EXISTS origin (
     sha1            blob PRIMARY KEY,
@@ -208,7 +217,8 @@
 
 TABLES = ('skipped_content content revision revision_parent release '
           'directory directory_entry snapshot snapshot_branch '
-          'origin_visit origin tool_by_uuid tool object_count').split()
+          'origin_visit origin tool_by_uuid tool object_count '
+          'origin_visit_update').split()
 
 HASH_ALGORITHMS = ['sha1', 'sha1_git', 'sha256', 'blake2s256']
 
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -15,7 +15,7 @@
 
 from swh.model.model import (
     Revision, Release, Directory, DirectoryEntry, Content,
-    SkippedContent, OriginVisit, Snapshot, Origin
+    SkippedContent, OriginVisit, OriginVisitUpdate, Snapshot, Origin
 )
 from swh.model.hashutil import DEFAULT_ALGORITHMS
 from swh.storage.objstorage import ObjStorage
@@ -561,11 +561,14 @@
 
     def snapshot_get_by_origin_visit(self, origin, visit):
         try:
-            visit = self._cql_runner.origin_visit_get_one(origin, visit)
+            visit = self.origin_visit_get_by(origin, visit)
         except IndexError:
             return None
+        if visit is None:
+            # origin_visit_get_by returns None for an unknown visit
+            return None
 
-        return self.snapshot_get(visit.snapshot)
+        return self.snapshot_get(visit['snapshot'])
 
     def snapshot_get_latest(self, origin, allowed_statuses=None):
         visit = self.origin_visit_get_latest(
@@ -798,13 +801,13 @@
                 'Unknown origin %s', origin_url)
 
         visit_id = self._cql_runner.origin_generate_unique_visit_id(origin_url)
-
+        visit_status = 'ongoing'
         with convert_validation_exceptions():
             visit = OriginVisit.from_dict({
                 'origin': origin_url,
                 'date': date,
                 'type': type,
-                'status': 'ongoing',
+                'status': visit_status,
                 'snapshot': None,
                 'metadata': None,
                 'visit': visit_id
@@ -812,8 +815,27 @@
 
         self.journal_writer.origin_visit_add(visit)
         self._cql_runner.origin_visit_add_one(visit)
+
+        with convert_validation_exceptions():
+            visit_update = OriginVisitUpdate(
+                origin=origin_url,
+                visit=visit_id,
+                date=date,
+                status=visit_status,
+                snapshot=None,
+                metadata=None,
+            )
+        self._origin_visit_update_add(visit_update)
+
         return visit
 
+    def _origin_visit_update_add(
+            self, origin_visit_update: OriginVisitUpdate) -> None:
+        """Add an origin visit update"""
+        # Inject origin visit update in the schema
+        self._cql_runner.origin_visit_update_add_one(
+            origin_visit_update)
+
     def origin_visit_update(
             self, origin: str, visit_id: int, status: str,
             metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None,
@@ -821,18 +843,18 @@
         origin_url = origin  # TODO: rename the argument
 
         # Get the existing data of the visit
-        row = self._cql_runner.origin_visit_get_one(origin_url, visit_id)
-        if not row:
+        visit_ = self.origin_visit_get_by(origin_url, visit_id)
+        if not visit_:
             raise StorageArgumentException('This origin visit does not exist.')
         with convert_validation_exceptions():
-            visit = OriginVisit.from_dict(self._format_origin_visit_row(row))
+            visit = OriginVisit.from_dict(visit_)
 
         updates: Dict[str, Any] = {
             'status': status
         }
-        if metadata:
+        if metadata and metadata != visit.metadata:
             updates['metadata'] = metadata
-        if snapshot:
+        if snapshot and snapshot != visit.snapshot:
             updates['snapshot'] = snapshot
 
         with convert_validation_exceptions():
@@ -840,7 +862,61 @@
 
         self.journal_writer.origin_visit_update(visit)
 
-        self._cql_runner.origin_visit_update(origin_url, visit_id, updates)
+        last_visit_update = self._origin_visit_get_updated(
+            visit.origin, visit.visit)
+        assert last_visit_update is not None
+
+        with convert_validation_exceptions():
+            visit_update = OriginVisitUpdate(
+                origin=origin_url,
+                visit=visit_id,
+                date=date or now(),
+                status=status,
+                snapshot=snapshot or last_visit_update['snapshot'],
+                metadata=metadata or last_visit_update['metadata'],
+            )
+        self._origin_visit_update_add(visit_update)
+
+    def _origin_visit_merge(
+            self, visit: Dict[str, Any],
+            visit_update: Dict[str, Any]) -> Dict[str, Any]:
+        """Merge origin_visit and origin_visit_update together.
+
+        """
+        return OriginVisit.from_dict({
+            # default to the values in visit
+            **visit,
+            # override with the last update
+            **visit_update,
+            # keep the origin as carried by the visit; the update row stores
+            # the origin URL as well, so this only makes the choice explicit.
+            'origin': visit['origin'],
+            # but keep the date of the creation of the origin visit
+            'date': visit['date']
+        }).to_dict()
+
+    def _origin_visit_apply_update(
+            self, visit: Dict[str, Any]) -> Dict[str, Any]:
+        """Retrieve the latest visit update information for the origin visit.
+        Then merge it with the visit and return it.
+
+        """
+        visit_update = self._cql_runner.origin_visit_update_get_latest(
+            visit['origin'], visit['visit'])
+        assert visit_update is not None
+        return self._origin_visit_merge(visit, visit_update)
+
+    def _origin_visit_get_updated(
+            self, origin: str, visit_id: int) -> Optional[Dict[str, Any]]:
+        """Retrieve origin visit and latest origin visit update and merge them
+        into an origin visit.
+
+        """
+        row_visit = self._cql_runner.origin_visit_get_one(origin, visit_id)
+        if row_visit is None:
+            return None
+        visit = self._format_origin_visit_row(row_visit)
+        return self._origin_visit_apply_update(visit)
 
     def origin_visit_upsert(self, visits: Iterable[OriginVisit]) -> None:
         for visit in visits:
@@ -850,7 +926,18 @@
         self.journal_writer.origin_visit_upsert(visits)
 
         for visit in visits:
+            assert visit.visit is not None
             self._cql_runner.origin_visit_upsert(visit)
+            with convert_validation_exceptions():
+                visit_update = OriginVisitUpdate(
+                    origin=visit.origin,
+                    visit=visit.visit,
+                    date=now(),
+                    status=visit.status,
+                    snapshot=visit.snapshot,
+                    metadata=visit.metadata,
+                )
+            self._origin_visit_update_add(visit_update)
 
     @staticmethod
     def _format_origin_visit_row(visit):
@@ -866,8 +953,9 @@
             self, origin: str, last_visit: Optional[int] = None,
             limit: Optional[int] = None) -> Iterable[Dict[str, Any]]:
         rows = self._cql_runner.origin_visit_get(origin, last_visit, limit)
-
-        yield from map(self._format_origin_visit_row, rows)
+        for row in rows:
+            visit = self._format_origin_visit_row(row)
+            yield self._origin_visit_apply_update(visit)
 
     def origin_visit_find_by_date(
             self, origin: str,
@@ -875,34 +963,51 @@
         # Iterator over all the visits of the origin
         # This should be ok for now, as there aren't too many visits
         # per origin.
-        visits = list(self._cql_runner.origin_visit_get_all(origin))
+        rows = list(self._cql_runner.origin_visit_get_all(origin))
 
         def key(visit):
             dt = visit.date.replace(tzinfo=datetime.timezone.utc) - visit_date
             return (abs(dt), -visit.visit)
 
-        if visits:
-            visit = min(visits, key=key)
-            return visit._asdict()
+        if rows:
+            row = min(rows, key=key)
+            visit = self._format_origin_visit_row(row)
+            return self._origin_visit_apply_update(visit)
         return None
 
     def origin_visit_get_by(
             self, origin: str, visit: int) -> Optional[Dict[str, Any]]:
-        visit = self._cql_runner.origin_visit_get_one(origin, visit)
-        if visit:
-            return self._format_origin_visit_row(visit)
+        row = self._cql_runner.origin_visit_get_one(origin, visit)
+        if row:
+            visit_ = self._format_origin_visit_row(row)
+            return self._origin_visit_apply_update(visit_)
         return None
 
     def origin_visit_get_latest(
             self, origin: str, allowed_statuses: Optional[List[str]] = None,
             require_snapshot: bool = False) -> Optional[Dict[str, Any]]:
-        visit = self._cql_runner.origin_visit_get_latest(
-            origin,
-            allowed_statuses=allowed_statuses,
-            require_snapshot=require_snapshot)
-        if visit:
-            return self._format_origin_visit_row(visit)
-        return None
+        # TODO: Do not fetch all visits
+        rows = self._cql_runner.origin_visit_get_all(origin)
+        latest_visit = None
+        for row in rows:
+            visit = self._format_origin_visit_row(row)
+            updated_visit = self._origin_visit_apply_update(visit)
+            if allowed_statuses and \
+                    updated_visit['status'] not in allowed_statuses:
+                continue
+            if require_snapshot and updated_visit['snapshot'] is None:
+                continue
+
+            # updated_visit is a candidate: keep the greatest (date, visit)
+            if latest_visit is not None:
+                new_key = (updated_visit['date'], updated_visit['visit'])
+                best_key = (latest_visit['date'], latest_visit['visit'])
+                if new_key < best_key:
+                    continue
+
+            latest_visit = updated_visit
+
+        return latest_visit
 
     def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]:
         back_in_the_day = now() - datetime.timedelta(weeks=12)  # 3 months back
@@ -914,9 +1019,10 @@
         rows = self._cql_runner.origin_visit_iter(start_token)
         for row in rows:
             visit = self._format_origin_visit_row(row)
-            if visit['date'] > back_in_the_day \
-                    and visit['status'] == 'full':
-                return visit
+            visit_update = self._origin_visit_apply_update(visit)
+            if visit_update['date'] > back_in_the_day \
+                    and visit_update['status'] == 'full':
+                return visit_update
         else:
             return None
 