diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import datetime
 import functools
 import json
 import logging
@@ -38,6 +39,7 @@
     Content,
     SkippedContent,
     OriginVisit,
+    OriginVisitUpdate,
     Origin,
 )
 
@@ -687,31 +689,63 @@
         else:
             return self._origin_visit_get_limit(origin_url, last_visit, limit)
 
-    def origin_visit_update(
-        self, origin_url: str, visit_id: int, updates: Dict[str, Any]
-    ) -> None:
-        set_parts = []
-        args: List[Any] = []
-        for (column, value) in updates.items():
-            set_parts.append(f"{column} = %s")
-            if column == "metadata":
-                args.append(json.dumps(value))
-            else:
-                args.append(value)
+    @_prepared_insert_statement("origin_visit", _origin_visit_keys)
+    def origin_visit_add_one(self, visit: OriginVisit, *, statement) -> None:
+        self._add_one(statement, "origin_visit", visit, self._origin_visit_keys)
 
-        if not set_parts:
-            return
+    _origin_visit_update_table_keys = [
+        "origin",
+        "visit",
+        "date",
+        "status",
+        "snapshot",
+        "metadata",
+    ]
 
-        query = (
-            "UPDATE origin_visit SET "
-            + ", ".join(set_parts)
-            + " WHERE origin = %s AND visit = %s"
+    @_prepared_insert_statement("origin_visit_update", _origin_visit_update_table_keys)
+    def origin_visit_update_add_one(
+        self, visit_update: OriginVisitUpdate, *, statement
+    ) -> None:
+        assert self._origin_visit_update_table_keys[-1] == "metadata"
+        keys = self._origin_visit_update_table_keys
+
+        metadata = json.dumps(visit_update.metadata)
+        self._execute_with_retries(
+            statement, [getattr(visit_update, key) for key in keys[:-1]] + [metadata]
         )
-        self._execute_with_retries(query, args + [origin_url, visit_id])
 
-    @_prepared_insert_statement("origin_visit", _origin_visit_keys)
-    def origin_visit_add_one(self, visit: OriginVisit, *, statement) -> None:
-        self._add_one(statement, "origin_visit", visit, self._origin_visit_keys)
+    def _format_origin_visit_update_row(
+        self, visit_update: Row
+    ) -> Dict[str, Any]:
+        """Convert an origin_visit_update row to a dict: the date is marked
+        as UTC and the JSON metadata is decoded (None when unset).
+        """
+        return {
+            **visit_update._asdict(),
+            "origin": visit_update.origin,
+            "date": visit_update.date.replace(tzinfo=datetime.timezone.utc),
+            "metadata": (
+                json.loads(visit_update.metadata) if visit_update.metadata else None
+            ),
+        }
+
+    @_prepared_statement(
+        "SELECT * FROM origin_visit_update "
+        "WHERE origin = ? AND visit = ? "
+        "ORDER BY date DESC "
+        "LIMIT 1"
+    )
+    def origin_visit_update_get_latest(
+        self, origin: str, visit: int, *, statement
+    ) -> Optional[Dict[str, Any]]:
+        """Return the most recent origin_visit_update row (by date) of the
+        given origin visit as a dict, or None if no such row exists.
+        """
+        rows = list(self._execute_with_retries(statement, [origin, visit]))
+        if rows:
+            return self._format_origin_visit_update_row(rows[0])
+        else:
+            return None
 
     @_prepared_statement(
         "UPDATE origin_visit SET "
@@ -745,31 +779,6 @@
     def origin_visit_get_all(self, origin_url: str, *, statement) -> ResultSet:
         return self._execute_with_retries(statement, [origin_url])
 
-    @_prepared_statement("SELECT * FROM origin_visit WHERE origin = ?")
-    def origin_visit_get_latest(
-        self,
-        origin: str,
-        allowed_statuses: Optional[Iterable[str]],
-        require_snapshot: bool,
-        *,
-        statement,
-    ) -> Optional[Row]:
-        # TODO: do the ordering and filtering in Cassandra
-        rows = list(self._execute_with_retries(statement, [origin]))
-
-        rows.sort(key=lambda row: (row.date, row.visit), reverse=True)
-
-        for row in rows:
-            if require_snapshot and row.snapshot is None:
-                continue
-            if allowed_statuses is not None and row.status not in allowed_statuses:
-                continue
-            if row.snapshot is not None and self.snapshot_missing([row.snapshot]):
-                raise ValueError("visit references unknown snapshot")
-            return row
-        else:
-            return None
-
     @_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) >= ?")
     def _origin_visit_iter_from(self, min_token: int, *, statement) -> Iterator[Row]:
         yield from self._execute_with_retries(statement, [min_token])
diff --git a/swh/storage/cassandra/schema.py b/swh/storage/cassandra/schema.py
--- a/swh/storage/cassandra/schema.py
+++ b/swh/storage/cassandra/schema.py
@@ -154,6 +154,15 @@
     PRIMARY KEY ((origin), visit)
 );
 
+CREATE TABLE IF NOT EXISTS origin_visit_update (
+    origin text,
+    visit bigint,
+    date timestamp,
+    status ascii,
+    metadata text,
+    snapshot blob,
+    PRIMARY KEY ((origin), visit, date)
+);
 
 CREATE TABLE IF NOT EXISTS origin (
     sha1 blob PRIMARY KEY,
@@ -211,7 +220,8 @@
 TABLES = (
     "skipped_content content revision revision_parent release "
     "directory directory_entry snapshot snapshot_branch "
-    "origin_visit origin tool_by_uuid tool object_count"
+    "origin_visit origin tool_by_uuid tool object_count "
+    "origin_visit_update"
 ).split()
 
 HASH_ALGORITHMS = ["sha1", "sha1_git", "sha256", "blake2s256"]
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -21,6 +21,7 @@
     Content,
     SkippedContent,
     OriginVisit,
+    OriginVisitUpdate,
     Snapshot,
     Origin,
 )
@@ -576,11 +577,11 @@
 
     def snapshot_get_by_origin_visit(self, origin, visit):
         try:
-            visit = self._cql_runner.origin_visit_get_one(origin, visit)
+            visit = self.origin_visit_get_by(origin, visit)
         except IndexError:
             return None
 
-        return self.snapshot_get(visit.snapshot)
+        return self.snapshot_get(visit["snapshot"]) if visit else None
 
     def snapshot_get_latest(self, origin, allowed_statuses=None):
         visit = self.origin_visit_get_latest(
@@ -806,14 +807,14 @@
             raise StorageArgumentException("Unknown origin %s", origin_url)
 
         visit_id = self._cql_runner.origin_generate_unique_visit_id(origin_url)
-
+        visit_status = "ongoing"
         with convert_validation_exceptions():
             visit = OriginVisit.from_dict(
                 {
                     "origin": origin_url,
                     "date": date,
                     "type": type,
-                    "status": "ongoing",
+                    "status": visit_status,
                     "snapshot": None,
                     "metadata": None,
                     "visit": visit_id,
@@ -822,8 +823,25 @@
 
         self.journal_writer.origin_visit_add(visit)
         self._cql_runner.origin_visit_add_one(visit)
+
+        with convert_validation_exceptions():
+            visit_update = OriginVisitUpdate(
+                origin=origin_url,
+                visit=visit_id,
+                date=date,
+                status=visit_status,
+                snapshot=None,
+                metadata=None,
+            )
+        self._origin_visit_update_add(visit_update)
+
         return visit
 
+    def _origin_visit_update_add(self, origin_visit_update: OriginVisitUpdate) -> None:
+        """Add an origin visit update"""
+        # Inject origin visit update in the schema
+        self._cql_runner.origin_visit_update_add_one(origin_visit_update)
+
     def origin_visit_update(
         self,
         origin: str,
@@ -836,16 +854,16 @@
         origin_url = origin  # TODO: rename the argument
 
         # Get the existing data of the visit
-        row = self._cql_runner.origin_visit_get_one(origin_url, visit_id)
-        if not row:
+        visit_ = self.origin_visit_get_by(origin_url, visit_id)
+        if not visit_:
             raise StorageArgumentException("This origin visit does not exist.")
         with convert_validation_exceptions():
-            visit = OriginVisit.from_dict(self._format_origin_visit_row(row))
+            visit = OriginVisit.from_dict(visit_)
 
         updates: Dict[str, Any] = {"status": status}
-        if metadata:
+        if metadata and metadata != visit.metadata:
             updates["metadata"] = metadata
-        if snapshot:
+        if snapshot and snapshot != visit.snapshot:
             updates["snapshot"] = snapshot
 
         with convert_validation_exceptions():
@@ -853,7 +871,63 @@
 
         self.journal_writer.origin_visit_update(visit)
 
-        self._cql_runner.origin_visit_update(origin_url, visit_id, updates)
+        last_visit_update = self._origin_visit_get_updated(visit.origin, visit.visit)
+        assert last_visit_update is not None
+
+        with convert_validation_exceptions():
+            visit_update = OriginVisitUpdate(
+                origin=origin_url,
+                visit=visit_id,
+                date=date or now(),
+                status=status,
+                snapshot=snapshot or last_visit_update["snapshot"],
+                metadata=metadata or last_visit_update["metadata"],
+            )
+        self._origin_visit_update_add(visit_update)
+
+    def _origin_visit_merge(
+        self, visit: Dict[str, Any], visit_update: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Merge origin_visit and origin_visit_update together.
+
+        """
+        return OriginVisit.from_dict(
+            {
+                # default to the values in visit
+                **visit,
+                # override with the last update
+                **visit_update,
+                # both rows store the origin URL (Cassandra has no join);
+                # take it from the visit row for consistency
+                "origin": visit["origin"],
+                # but keep the date of the creation of the origin visit
+                "date": visit["date"],
+            }
+        ).to_dict()
+
+    def _origin_visit_apply_update(self, visit: Dict[str, Any]) -> Dict[str, Any]:
+        """Retrieve the latest visit update information for the origin visit.
+        Then merge it with the visit and return it.
+
+        """
+        visit_update = self._cql_runner.origin_visit_update_get_latest(
+            visit["origin"], visit["visit"]
+        )
+        assert visit_update is not None
+        return self._origin_visit_merge(visit, visit_update)
+
+    def _origin_visit_get_updated(
+        self, origin: str, visit_id: int
+    ) -> Optional[Dict[str, Any]]:
+        """Retrieve origin visit and latest origin visit update and merge them
+        into an origin visit.
+
+        """
+        row_visit = self._cql_runner.origin_visit_get_one(origin, visit_id)
+        if row_visit is None:
+            return None
+        visit = self._format_origin_visit_row(row_visit)
+        return self._origin_visit_apply_update(visit)
 
     def origin_visit_upsert(self, visits: Iterable[OriginVisit]) -> None:
         for visit in visits:
@@ -862,7 +936,18 @@
         self.journal_writer.origin_visit_upsert(visits)
 
         for visit in visits:
+            assert visit.visit is not None
             self._cql_runner.origin_visit_upsert(visit)
+            with convert_validation_exceptions():
+                visit_update = OriginVisitUpdate(
+                    origin=visit.origin,
+                    visit=visit.visit,
+                    date=now(),
+                    status=visit.status,
+                    snapshot=visit.snapshot,
+                    metadata=visit.metadata,
+                )
+            self._origin_visit_update_add(visit_update)
 
     @staticmethod
     def _format_origin_visit_row(visit):
@@ -877,8 +962,9 @@
         self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None
     ) -> Iterable[Dict[str, Any]]:
         rows = self._cql_runner.origin_visit_get(origin, last_visit, limit)
-
-        yield from map(self._format_origin_visit_row, rows)
+        for row in rows:
+            visit = self._format_origin_visit_row(row)
+            yield self._origin_visit_apply_update(visit)
 
     def origin_visit_find_by_date(
         self, origin: str, visit_date: datetime.datetime
@@ -886,21 +972,23 @@
         # Iterator over all the visits of the origin
         # This should be ok for now, as there aren't too many visits
         # per origin.
-        visits = list(self._cql_runner.origin_visit_get_all(origin))
+        rows = list(self._cql_runner.origin_visit_get_all(origin))
 
         def key(visit):
             dt = visit.date.replace(tzinfo=datetime.timezone.utc) - visit_date
             return (abs(dt), -visit.visit)
 
-        if visits:
-            visit = min(visits, key=key)
-            return visit._asdict()
+        if rows:
+            row = min(rows, key=key)
+            visit = self._format_origin_visit_row(row)
+            return self._origin_visit_apply_update(visit)
         return None
 
     def origin_visit_get_by(self, origin: str, visit: int) -> Optional[Dict[str, Any]]:
-        visit = self._cql_runner.origin_visit_get_one(origin, visit)
-        if visit:
-            return self._format_origin_visit_row(visit)
+        row = self._cql_runner.origin_visit_get_one(origin, visit)
+        if row:
+            visit_ = self._format_origin_visit_row(row)
+            return self._origin_visit_apply_update(visit_)
         return None
 
     def origin_visit_get_latest(
@@ -909,12 +997,27 @@
         allowed_statuses: Optional[List[str]] = None,
         require_snapshot: bool = False,
     ) -> Optional[Dict[str, Any]]:
-        visit = self._cql_runner.origin_visit_get_latest(
-            origin, allowed_statuses=allowed_statuses, require_snapshot=require_snapshot
-        )
-        if visit:
-            return self._format_origin_visit_row(visit)
-        return None
+        # TODO: Do not fetch all visits
+        rows = self._cql_runner.origin_visit_get_all(origin)
+        latest_visit = None
+        for row in rows:
+            visit = self._format_origin_visit_row(row)
+            updated_visit = self._origin_visit_apply_update(visit)
+            if allowed_statuses and updated_visit["status"] not in allowed_statuses:
+                continue
+            if require_snapshot and updated_visit["snapshot"] is None:
+                continue
+
+            # updated_visit is a candidate: keep the greatest (date, visit)
+            key = (updated_visit["date"], updated_visit["visit"])
+            if latest_visit is not None:
+                latest_key = (latest_visit["date"], latest_visit["visit"])
+                if key < latest_key:
+                    continue
+
+            latest_visit = updated_visit
+
+        return latest_visit
 
     def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]:
         back_in_the_day = now() - datetime.timedelta(weeks=12)  # 3 months back
@@ -926,8 +1029,12 @@
         rows = self._cql_runner.origin_visit_iter(start_token)
         for row in rows:
             visit = self._format_origin_visit_row(row)
-            if visit["date"] > back_in_the_day and visit["status"] == "full":
-                return visit
+            visit_update = self._origin_visit_apply_update(visit)
+            if (
+                visit_update["date"] > back_in_the_day
+                and visit_update["status"] == "full"
+            ):
+                return visit_update
         else:
             return None
 