diff --git a/sql/upgrades/147.sql b/sql/upgrades/147.sql new file mode 100644 --- /dev/null +++ b/sql/upgrades/147.sql @@ -0,0 +1,58 @@ +-- SWH DB schema upgrade +-- from_version: 146 +-- to_version: 147 +-- description: Add origin_visit_update table and migrate origin_visit +-- to origin_visit_update + +-- latest schema version +insert into dbversion(version, release, description) + values(147, now(), 'Work In Progress'); + +-- schema change + +-- Crawling history of software origin visits by Software Heritage. Each +-- visit see its history change through new origin visit updates +create table origin_visit_update +( + id bigserial not null, -- TODO: Decide if we keep that or not + origin bigint not null, + visit bigint not null, + date timestamptz not null, + status origin_visit_status not null, + metadata jsonb, + snapshot sha1_git +); + +comment on column origin_visit_update.id is 'visit update id'; +comment on column origin_visit_update.origin is 'origin concerned by the visit update'; +comment on column origin_visit_update.visit is 'visit concerned by the visit update'; +comment on column origin_visit_update.date is 'Visit update timestamp'; +comment on column origin_visit_update.status is 'Visit update status'; +comment on column origin_visit_update.metadata is 'Origin metadata at visit update time'; +comment on column origin_visit_update.snapshot is 'Origin snapshot at visit update time'; + +-- origin_visit_update + +create unique index concurrently origin_visit_update_pkey on origin_visit_update(origin, visit, date); +alter table origin_visit_update add primary key using index origin_visit_update_pkey; + +alter table origin_visit_update + add constraint origin_visit_update_origin_visit_fkey + foreign key (origin, visit) + references origin_visit(origin, visit) not valid; +alter table origin_visit_update validate constraint origin_visit_update_origin_visit_fkey; + + +-- data change + +-- best approximation of the visit update date is the origin_visit's date 
+insert into origin_visit_update (origin, visit, date, status, metadata, snapshot) +select origin, visit, date, status, metadata, snapshot +from origin_visit +on conflict (origin, visit, date) +do +-- what policy? +nothing; +-- update set status = excluded.status, +-- metadata = excluded.metadata, +-- snapshot = excluded.snapshot; diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -3,12 +3,14 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import datetime import functools import json import logging import random from typing import ( - Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, TypeVar + Any, Callable, Dict, Iterable, Iterator, List, Optional, + Tuple, TypeVar ) from cassandra import CoordinationFailure @@ -23,7 +25,7 @@ from swh.model.model import ( Sha1Git, TimestampWithTimezone, Timestamp, Person, Content, - SkippedContent, OriginVisit, Origin + SkippedContent, OriginVisit, OriginVisitUpdate, Origin ) from .common import Row, TOKEN_BEGIN, TOKEN_END, hash_url @@ -632,6 +634,52 @@ self._add_one(statement, 'origin_visit', visit, self._origin_visit_keys) + _origin_visit_update_table_keys = [ + 'origin', 'visit', 'date', 'status', 'snapshot', 'metadata' + ] + + @_prepared_insert_statement( + 'origin_visit_update', _origin_visit_update_table_keys) + def origin_visit_update_add_one( + self, visit_update: OriginVisitUpdate, *, statement) -> None: + assert self._origin_visit_update_table_keys[-1] == 'metadata' + keys = self._origin_visit_update_table_keys + + metadata = json.dumps(visit_update.metadata) + self._execute_with_retries( + statement, + [getattr(visit_update, key) for key in keys[:-1]] + [metadata]) + + def _format_origin_visit_update_row( + self, visit_update: ResultSet) -> Dict[str, Any]: + """Format a row visit_update into an origin_visit_update 
dict + + """ + return { + **visit_update._asdict(), + 'origin': visit_update.origin, + 'date': visit_update.date.replace(tzinfo=datetime.timezone.utc), + 'metadata': (json.loads(visit_update.metadata) + if visit_update.metadata else None), + } + + @_prepared_statement('SELECT * FROM origin_visit_update ' + 'WHERE origin = ? AND visit = ? ' + 'ORDER BY date DESC ' + 'LIMIT 1') + def origin_visit_update_get_latest( + self, origin: str, visit: int, + *, statement) -> Optional[Dict[str, Any]]: + """Given an origin visit id, return its latest origin_visit_update + + """ + rows = list(self._execute_with_retries( + statement, [origin, visit])) + if rows: + return self._format_origin_visit_update_row(rows[0]) + else: + return None + @_prepared_statement( 'UPDATE origin_visit SET ' + ', '.join('%s = ?' % key for key in _origin_visit_update_keys) + @@ -666,43 +714,19 @@ @_prepared_statement('SELECT * FROM origin_visit ' 'WHERE origin = ?') def origin_visit_get_all(self, origin_url: str, *, statement) -> ResultSet: - return self._execute_with_retries( - statement, [origin_url]) - - @_prepared_statement('SELECT * FROM origin_visit WHERE origin = ?') - def origin_visit_get_latest( - self, origin: str, allowed_statuses: Optional[Iterable[str]], - require_snapshot: bool, *, statement) -> Optional[Row]: - # TODO: do the ordering and filtering in Cassandra - rows = list(self._execute_with_retries(statement, [origin])) - - rows.sort(key=lambda row: (row.date, row.visit), reverse=True) - - for row in rows: - if require_snapshot and row.snapshot is None: - continue - if allowed_statuses is not None \ - and row.status not in allowed_statuses: - continue - if row.snapshot is not None and \ - self.snapshot_missing([row.snapshot]): - raise ValueError('visit references unknown snapshot') - return row - else: - return None + return self._execute_with_retries(statement, [origin_url]) @_prepared_statement('SELECT * FROM origin_visit WHERE token(origin) >= ?') def _origin_visit_iter_from( - 
self, min_token: int, *, statement) -> Generator[Row, None, None]: + self, min_token: int, *, statement) -> Iterator[Row]: yield from self._execute_with_retries(statement, [min_token]) @_prepared_statement('SELECT * FROM origin_visit WHERE token(origin) < ?') def _origin_visit_iter_to( - self, max_token: int, *, statement) -> Generator[Row, None, None]: + self, max_token: int, *, statement) -> Iterator[Row]: yield from self._execute_with_retries(statement, [max_token]) - def origin_visit_iter( - self, start_token: int) -> Generator[Row, None, None]: + def origin_visit_iter(self, start_token: int) -> Iterator[Row]: """Returns all origin visits in order from this token, and wraps around the token space.""" yield from self._origin_visit_iter_from(start_token) diff --git a/swh/storage/cassandra/schema.py b/swh/storage/cassandra/schema.py --- a/swh/storage/cassandra/schema.py +++ b/swh/storage/cassandra/schema.py @@ -154,6 +154,15 @@ PRIMARY KEY ((origin), visit) ); +CREATE TABLE IF NOT EXISTS origin_visit_update ( + origin text, + visit bigint, + date timestamp, + status ascii, + metadata text, + snapshot blob, + PRIMARY KEY ((origin), visit, date) +); CREATE TABLE IF NOT EXISTS origin ( sha1 blob PRIMARY KEY, @@ -208,7 +217,8 @@ TABLES = ('skipped_content content revision revision_parent release ' 'directory directory_entry snapshot snapshot_branch ' - 'origin_visit origin tool_by_uuid tool object_count').split() + 'origin_visit origin tool_by_uuid tool object_count ' + 'origin_visit_update').split() HASH_ALGORITHMS = ['sha1', 'sha1_git', 'sha256', 'blake2s256'] diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -15,7 +15,7 @@ from swh.model.model import ( Revision, Release, Directory, DirectoryEntry, Content, - SkippedContent, OriginVisit, Snapshot, Origin + SkippedContent, OriginVisit, OriginVisitUpdate, Snapshot, Origin ) from swh.model.hashutil import 
DEFAULT_ALGORITHMS from swh.storage.objstorage import ObjStorage @@ -28,6 +28,7 @@ ) from .cql import CqlRunner from .schema import HASH_ALGORITHMS +from ..validate import convert_validation_exceptions # Max block size of contents to return @@ -571,11 +572,11 @@ def snapshot_get_by_origin_visit(self, origin, visit): try: - visit = self._cql_runner.origin_visit_get_one(origin, visit) + visit = self.origin_visit_get_by(origin, visit) except IndexError: return None - return self.snapshot_get(visit.snapshot) + return self.snapshot_get(visit['snapshot']) def snapshot_get_latest(self, origin, allowed_statuses=None): visit = self.origin_visit_get_latest( @@ -808,58 +809,145 @@ 'Unknown origin %s', origin_url) visit_id = self._cql_runner.origin_generate_unique_visit_id(origin_url) - + visit_status = 'ongoing' try: visit = OriginVisit.from_dict({ 'origin': origin_url, 'date': date, 'type': type, - 'status': 'ongoing', + 'status': visit_status, 'snapshot': None, 'metadata': None, 'visit': visit_id }) except (KeyError, TypeError, ValueError) as e: raise StorageArgumentException(*e.args) + self.journal_writer.origin_visit_add(visit) self._cql_runner.origin_visit_add_one(visit) + + with convert_validation_exceptions(): + visit_update = OriginVisitUpdate( + origin=origin_url, + visit=visit_id, + date=date, + status=visit_status, + snapshot=None, + metadata=None, + ) + self._origin_visit_update_add(visit_update) + return visit + def _origin_visit_update_add( + self, origin_visit_update: OriginVisitUpdate) -> None: + """Add an origin visit update""" + # Inject origin visit update in the schema + self._cql_runner.origin_visit_update_add_one( + origin_visit_update) + def origin_visit_update( self, origin: str, visit_id: int, status: str, - metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None): + metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None, + date: Optional[datetime.datetime] = None): origin_url = origin # TODO: rename the argument # Get the 
existing data of the visit - row = self._cql_runner.origin_visit_get_one(origin_url, visit_id) - if not row: + visit_ = self.origin_visit_get_by(origin_url, visit_id) + if not visit_: raise StorageArgumentException('This origin visit does not exist.') - try: - visit = OriginVisit.from_dict(self._format_origin_visit_row(row)) - except (KeyError, TypeError, ValueError) as e: - raise StorageArgumentException(*e.args) + with convert_validation_exceptions(): + visit = OriginVisit.from_dict(visit_) updates: Dict[str, Any] = { 'status': status } - if metadata: + if metadata and metadata != visit.metadata: updates['metadata'] = metadata - if snapshot: + if snapshot and snapshot != visit.snapshot: updates['snapshot'] = snapshot - try: + with convert_validation_exceptions(): visit = attr.evolve(visit, **updates) - except (KeyError, TypeError, ValueError) as e: - raise StorageArgumentException(*e.args) self.journal_writer.origin_visit_update(visit) - self._cql_runner.origin_visit_update(origin_url, visit_id, updates) + last_visit_update = self._origin_visit_get_updated( + visit.origin, visit.visit) + assert last_visit_update is not None + + with convert_validation_exceptions(): + visit_update = OriginVisitUpdate( + origin=origin_url, + visit=visit_id, + date=date or now(), + status=status, + snapshot=snapshot or last_visit_update['snapshot'], + metadata=metadata or last_visit_update['metadata'], + ) + self._origin_visit_update_add(visit_update) + + def _origin_visit_merge( + self, visit: Dict[str, Any], + visit_update: Dict[str, Any]) -> Dict[str, Any]: + """Merge origin_visit and origin_visit_update together. + + """ + return OriginVisit.from_dict({ + # default to the values in visit + **visit, + # override with the last update + **visit_update, + # visit['origin'] is the URL (via a join), while + # visit_update['origin'] is only an id. 
+ 'origin': visit['origin'], + # but keep the date of the creation of the origin visit + 'date': visit['date'] + }).to_dict() + + def _origin_visit_apply_update( + self, visit: Dict[str, Any]) -> Dict[str, Any]: + """Retrieve the latest visit update information for the origin visit. + Then merge it with the visit and return it. + + """ + visit_update = self._cql_runner.origin_visit_update_get_latest( + visit['origin'], visit['visit']) + assert visit_update is not None + return self._origin_visit_merge(visit, visit_update) + + def _origin_visit_get_updated( + self, origin: str, visit_id: int) -> Optional[Dict[str, Any]]: + """Retrieve origin visit and latest origin visit update and merge them + into an origin visit. + + """ + row_visit = self._cql_runner.origin_visit_get_one(origin, visit_id) + if row_visit is None: + return None + visit = self._format_origin_visit_row(row_visit) + return self._origin_visit_apply_update(visit) def origin_visit_upsert(self, visits: Iterable[OriginVisit]) -> None: + for visit in visits: + if visit.visit is None: + raise StorageArgumentException( + f'Missing visit id for visit {visit}') + self.journal_writer.origin_visit_upsert(visits) for visit in visits: + assert visit.visit is not None self._cql_runner.origin_visit_upsert(visit) + with convert_validation_exceptions(): + visit_update = OriginVisitUpdate( + origin=visit.origin, + visit=visit.visit, + date=now(), + status=visit.status, + snapshot=visit.snapshot, + metadata=visit.metadata, + ) + self._origin_visit_update_add(visit_update) @staticmethod def _format_origin_visit_row(visit): @@ -871,38 +959,65 @@ if visit.metadata else None), } - def origin_visit_get(self, origin, last_visit=None, limit=None): + def origin_visit_get( + self, origin: str, last_visit: Optional[int] = None, + limit: Optional[int] = None) -> Iterable[Dict[str, Any]]: rows = self._cql_runner.origin_visit_get(origin, last_visit, limit) + for row in rows: + visit = self._format_origin_visit_row(row) + yield 
self._origin_visit_apply_update(visit) - yield from map(self._format_origin_visit_row, rows) - - def origin_visit_find_by_date(self, origin, visit_date): + def origin_visit_find_by_date( + self, origin: str, + visit_date: datetime.datetime) -> Optional[Dict[str, Any]]: # Iterator over all the visits of the origin # This should be ok for now, as there aren't too many visits # per origin. - visits = list(self._cql_runner.origin_visit_get_all(origin)) + rows = list(self._cql_runner.origin_visit_get_all(origin)) def key(visit): dt = visit.date.replace(tzinfo=datetime.timezone.utc) - visit_date return (abs(dt), -visit.visit) - if visits: - visit = min(visits, key=key) - return visit._asdict() + if rows: + row = min(rows, key=key) + visit = self._format_origin_visit_row(row) + return self._origin_visit_apply_update(visit) + return None - def origin_visit_get_by(self, origin, visit): - visit = self._cql_runner.origin_visit_get_one(origin, visit) - if visit: - return self._format_origin_visit_row(visit) + def origin_visit_get_by( + self, origin: str, visit: int) -> Optional[Dict[str, Any]]: + row = self._cql_runner.origin_visit_get_one(origin, visit) + if row: + visit_ = self._format_origin_visit_row(row) + return self._origin_visit_apply_update(visit_) + return None def origin_visit_get_latest( - self, origin, allowed_statuses=None, require_snapshot=False): - visit = self._cql_runner.origin_visit_get_latest( - origin, - allowed_statuses=allowed_statuses, - require_snapshot=require_snapshot) - if visit: - return self._format_origin_visit_row(visit) + self, origin: str, allowed_statuses: Optional[List[str]] = None, + require_snapshot: bool = False) -> Optional[Dict[str, Any]]: + # TODO: Do not fetch all visits + rows = self._cql_runner.origin_visit_get_all(origin) + latest_visit = None + for row in rows: + visit = self._format_origin_visit_row(row) + updated_visit = self._origin_visit_apply_update(visit) + if allowed_statuses and \ + updated_visit['status'] not in 
allowed_statuses: + continue + if require_snapshot and updated_visit['snapshot'] is None: + continue + + # updated_visit is a candidate + if latest_visit is not None: + if updated_visit['date'] < latest_visit['date']: + continue + if updated_visit['visit'] < latest_visit['visit']: + continue + + latest_visit = updated_visit + + return latest_visit def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]: back_in_the_day = now() - datetime.timedelta(weeks=12) # 3 months back @@ -914,9 +1029,10 @@ rows = self._cql_runner.origin_visit_iter(start_token) for row in rows: visit = self._format_origin_visit_row(row) - if visit['date'] > back_in_the_day \ - and visit['status'] == 'full': - return visit + visit_update = self._origin_visit_apply_update(visit) + if visit_update['date'] > back_in_the_day \ + and visit_update['status'] == 'full': + return visit_update else: return None diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -6,10 +6,12 @@ import random import select +from typing import Any, Dict, Optional, Tuple + from swh.core.db import BaseDb from swh.core.db.db_utils import stored_procedure, jsonize from swh.core.db.db_utils import execute_values_generator -from swh.model.model import OriginVisit, SHA1_SIZE +from swh.model.model import OriginVisit, OriginVisitUpdate, SHA1_SIZE class Db(BaseDb): @@ -209,9 +211,13 @@ def snapshot_get_by_origin_visit(self, origin_url, visit_id, cur=None): cur = self._cursor(cur) query = """\ - SELECT snapshot FROM origin_visit - INNER JOIN origin ON origin.id = origin_visit.origin - WHERE origin.url=%s AND origin_visit.visit=%s; + SELECT ovu.snapshot + FROM origin_visit ov + INNER JOIN origin o ON o.id = ov.origin + INNER JOIN origin_visit_update ovu + ON ov.origin = ovu.origin AND ov.visit = ovu.visit + WHERE o.url=%s AND ov.visit=%s + ORDER BY ovu.date DESC LIMIT 1 """ cur.execute(query, (origin_url, visit_id)) @@ -351,6 +357,26 @@ (origin, ts, type)) return 
cur.fetchone()[0] + origin_visit_update_cols = [ + 'origin', 'visit', 'date', 'status', 'snapshot', 'metadata'] + + def origin_visit_update_add( + self, visit_update: OriginVisitUpdate, cur=None) -> None: + assert self.origin_visit_update_cols[0] == 'origin' + assert self.origin_visit_update_cols[-1] == 'metadata' + cols = self.origin_visit_update_cols[1:-1] + cur = self._cursor(cur) + cur.execute( + f"WITH origin_id as (select id from origin where url=%s) " + f"INSERT INTO origin_visit_update " + f"(origin, {', '.join(cols)}, metadata) " + f"VALUES ((select id from origin_id), " + f"{', '.join(['%s']*len(cols))}, %s) " + f"ON CONFLICT (origin, visit, date) do nothing", + [visit_update.origin] + + [getattr(visit_update, key) for key in cols] + + [jsonize(visit_update.metadata)]) + def origin_visit_update(self, origin_id, visit_id, updates, cur=None): """Update origin_visit's status.""" cur = self._cursor(cur) @@ -401,8 +427,34 @@ 'origin', 'visit', 'date', 'type', 'status', 'metadata', 'snapshot'] origin_visit_select_cols = [ - 'origin.url AS origin', 'visit', 'date', 'origin_visit.type AS type', - 'status', 'metadata', 'snapshot'] + 'o.url AS origin', 'ov.visit', 'ov.date', 'ov.type AS type', + 'ovu.status', 'ovu.metadata', 'ovu.snapshot'] + + def _make_origin_visit_update( + self, row: Tuple[Any]) -> Optional[Dict[str, Any]]: + """Make an origin_visit_update dict out of a row + + """ + if not row: + return None + return dict(zip(self.origin_visit_update_cols, row)) + + def origin_visit_update_get_latest( + self, origin: str, visit: int, + cur=None) -> Optional[Dict[str, Any]]: + """Given an origin visit id, return its latest origin_visit_update + + """ + cols = self.origin_visit_update_cols + cur = self._cursor(cur) + cur.execute(f"SELECT {', '.join(cols)} " + f"FROM origin_visit_update ovu " + f"INNER JOIN origin o on o.id=ovu.origin " + f"WHERE o.url=%s AND ovu.visit=%s " + f"ORDER BY ovu.date DESC LIMIT 1", + (origin, visit)) + row = cur.fetchone() + return 
self._make_origin_visit_update(row) def origin_visit_get_all(self, origin_id, last_visit=None, limit=None, cur=None): @@ -412,25 +464,27 @@ origin_id: The occurrence's origin Yields: - The occurrence's history visits + The visits for that origin """ cur = self._cursor(cur) if last_visit: - extra_condition = 'and visit > %s' + extra_condition = 'and ov.visit > %s' args = (origin_id, last_visit, limit) else: extra_condition = '' args = (origin_id, limit) query = """\ - SELECT %s - FROM origin_visit - INNER JOIN origin ON origin.id = origin_visit.origin - WHERE origin.url=%%s %s - order by visit asc - limit %%s""" % ( + SELECT DISTINCT ON (ov.visit) %s + FROM origin_visit ov + INNER JOIN origin o ON o.id = ov.origin + INNER JOIN origin_visit_update ovu + ON ov.origin = ovu.origin AND ov.visit = ovu.visit + WHERE o.url=%%s %s + ORDER BY ov.visit ASC, ovu.date DESC + LIMIT %%s""" % ( ', '.join(self.origin_visit_select_cols), extra_condition ) @@ -453,9 +507,13 @@ query = """\ SELECT %s - FROM origin_visit - INNER JOIN origin ON origin.id = origin_visit.origin - WHERE origin.url = %%s AND visit = %%s + FROM origin_visit ov + INNER JOIN origin o ON o.id = ov.origin + INNER JOIN origin_visit_update ovu + ON ov.origin = ovu.origin AND ov.visit = ovu.visit + WHERE o.url = %%s AND ov.visit = %%s + ORDER BY ovu.date DESC + LIMIT 1 """ % (', '.join(self.origin_visit_select_cols)) cur.execute(query, (origin_id, visit_id)) @@ -469,9 +527,11 @@ cur.execute( 'SELECT * FROM swh_visit_find_by_date(%s, %s)', (origin, visit_date)) - r = cur.fetchall() - if r: - return r[0] + rows = cur.fetchall() + if rows: + visit = dict(zip(self.origin_visit_get_cols, rows[0])) + visit['origin'] = origin + return visit def origin_visit_exists(self, origin_id, visit_id, cur=None): """Check whether an origin visit with the given ids exists""" @@ -502,20 +562,24 @@ query_parts = [ 'SELECT %s' % ', '.join(self.origin_visit_select_cols), - 'FROM origin_visit', - 'INNER JOIN origin ON origin.id = 
origin_visit.origin'] + 'FROM origin_visit ov ', + 'INNER JOIN origin o ON o.id = ov.origin', + 'INNER JOIN origin_visit_update ovu ', + 'ON o.id = ovu.origin AND ov.visit = ovu.visit ', + ] - query_parts.append('WHERE origin.url = %s') + query_parts.append('WHERE o.url = %s') if require_snapshot: - query_parts.append('AND snapshot is not null') + query_parts.append('AND ovu.snapshot is not null') if allowed_statuses: query_parts.append( - cur.mogrify('AND status IN %s', + cur.mogrify('AND ovu.status IN %s', (tuple(allowed_statuses),)).decode()) - query_parts.append('ORDER BY date DESC, visit DESC LIMIT 1') + query_parts.append( + 'ORDER BY ov.date DESC, ov.visit DESC, ovu.date DESC LIMIT 1') query = '\n'.join(query_parts) @@ -532,18 +596,15 @@ """ cur = self._cursor(cur) columns = ','.join(self.origin_visit_select_cols) - query = f"""with visits as ( - select * - from origin_visit - where origin_visit.status='full' and - origin_visit.type=%s and - origin_visit.date > now() - '3 months'::interval - ) - select {columns} - from visits as origin_visit - inner join origin - on origin_visit.origin=origin.id - where random() < 0.1 + query = f"""select {columns} + from origin_visit ov + inner join origin o on ov.origin=o.id + inner join origin_visit_update ovu + on ov.origin = ovu.origin and ov.visit = ovu.visit + where ovu.status='full' + and ov.type=%s + and ov.date > now() - '3 months'::interval + and random() < 0.1 limit 1 """ cur.execute(query, (type, )) @@ -753,15 +814,17 @@ origin_cols = ','.join(self.origin_cols) query = """SELECT %s - FROM origin + FROM origin o WHERE """ if with_visit: query += """ EXISTS ( SELECT 1 - FROM origin_visit - INNER JOIN snapshot ON snapshot=snapshot.id - WHERE origin=origin.id + FROM origin_visit ov + INNER JOIN origin_visit_update ovu + ON ov.origin = ovu.origin AND ov.visit = ovu.visit + INNER JOIN snapshot ON ovu.snapshot=snapshot.id + WHERE ov.origin=o.id ) AND """ query += 'url %s %%s ' diff --git a/swh/storage/in_memory.py 
b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -14,13 +14,13 @@ from collections import defaultdict from datetime import timedelta -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import attr from swh.model.model import ( BaseContent, Content, SkippedContent, Directory, Revision, - Release, Snapshot, OriginVisit, Origin, SHA1_SIZE + Release, Snapshot, OriginVisit, OriginVisitUpdate, Origin, SHA1_SIZE ) from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex from swh.storage.objstorage import ObjStorage @@ -29,6 +29,7 @@ from .converters import origin_url_to_sha1 from .utils import get_partition_bounds_bytes +from .validate import convert_validation_exceptions from .writer import JournalWriter # Max block size of contents to return @@ -58,6 +59,8 @@ self._origins_by_id = [] self._origins_by_sha1 = {} self._origin_visits = {} + self._origin_visit_updates: Dict[ + Tuple[str, int], List[OriginVisitUpdate]] = {} self._persons = [] self._origin_metadata = defaultdict(list) self._tools = {} @@ -500,7 +503,9 @@ if origin_url not in self._origins or \ visit > len(self._origin_visits[origin_url]): return None - snapshot_id = self._origin_visits[origin_url][visit-1].snapshot + + visit = self._origin_visit_get_updated(origin_url, visit) + snapshot_id = visit.snapshot if snapshot_id: return self.snapshot_get(snapshot_id) else: @@ -665,15 +670,18 @@ else: origins = [orig for orig in origins if url_pattern in orig['url']] if with_visit: - origins = [ - orig for orig in origins - if len(self._origin_visits[orig['url']]) > 0 and - set(ov.snapshot - for ov in self._origin_visits[orig['url']] - if ov.snapshot) & - set(self._snapshots)] + filtered_origins = [] + for orig in origins: + visits = (self._origin_visit_get_updated(ov.origin, ov.visit) + for ov in self._origin_visits[orig['url']]) + for ov in visits: + if ov.snapshot and 
ov.snapshot in self._snapshots: + filtered_origins.append(orig) + break + else: + filtered_origins = origins - return origins[offset:offset+limit] + return filtered_origins[offset:offset+limit] def origin_count(self, url_pattern, regexp=False, with_visit=False): return len(self.origin_search(url_pattern, regexp=regexp, @@ -722,19 +730,33 @@ # visit ids are in the range [1, +inf[ visit_id = len(self._origin_visits[origin_url]) + 1 status = 'ongoing' - visit = OriginVisit( - origin=origin_url, - date=date, - type=type, - status=status, - snapshot=None, - metadata=None, - visit=visit_id, - ) + with convert_validation_exceptions(): + visit = OriginVisit( + origin=origin_url, + date=date, + type=type, + # TODO: Remove when we remove those fields from the model + status=status, + snapshot=None, + metadata=None, + visit=visit_id, + ) self._origin_visits[origin_url].append(visit) - visit = visit - - self._objects[(origin_url, visit.visit)].append( + assert visit.visit is not None + visit_key = (origin_url, visit.visit) + + with convert_validation_exceptions(): + visit_update = OriginVisitUpdate( + origin=origin_url, + visit=visit_id, + date=date, + status=status, + snapshot=None, + metadata=None, + ) + self._origin_visit_updates[visit_key] = [visit_update] + + self._objects[visit_key].append( ('origin_visit', None)) self.journal_writer.origin_visit_add(visit) @@ -744,7 +766,8 @@ def origin_visit_update( self, origin: str, visit_id: int, status: str, - metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None): + metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None, + date: Optional[datetime.datetime] = None): origin_url = self._get_origin_url(origin) if origin_url is None: raise StorageArgumentException('Unknown origin.') @@ -755,43 +778,72 @@ raise StorageArgumentException( 'Unknown visit_id for this origin') from None - updates: Dict[str, Any] = { - 'status': status - } - if metadata: - updates['metadata'] = metadata - if snapshot: - 
updates['snapshot'] = snapshot + # Retrieve the previous visit update + assert visit.visit is not None + visit_key = (origin_url, visit.visit) - try: - visit = attr.evolve(visit, **updates) - except (KeyError, TypeError, ValueError) as e: - raise StorageArgumentException(*e.args) + last_visit_update = max( + self._origin_visit_updates[visit_key], key=lambda v: v.date) + + with convert_validation_exceptions(): + visit_update = OriginVisitUpdate( + origin=origin_url, + visit=visit_id, + date=date or now(), + status=status, + snapshot=snapshot or last_visit_update.snapshot, + metadata=metadata or last_visit_update.metadata, + ) + self._origin_visit_updates[visit_key].append(visit_update) - self.journal_writer.origin_visit_update(visit) + self.journal_writer.origin_visit_update( + self._origin_visit_get_updated(origin_url, visit_id)) self._origin_visits[origin_url][visit_id-1] = visit def origin_visit_upsert(self, visits: Iterable[OriginVisit]) -> None: + for visit in visits: + if visit.visit is None: + raise StorageArgumentException( + f'Missing visit id for visit {visit}') + self.journal_writer.origin_visit_upsert(visits) + date = now() + for visit in visits: - visit_id = visit.visit + assert visit.visit is not None origin_url = visit.origin + origin = self.origin_get({'url': origin_url}) - try: - visit = attr.evolve(visit, origin=origin_url) - except (KeyError, TypeError, ValueError) as e: - raise StorageArgumentException(*e.args) - - self._objects[(origin_url, visit_id)].append( - ('origin_visit', None)) - - if visit_id: - while len(self._origin_visits[origin_url]) <= visit_id: + if not origin: # Cannot add a visit without an origin + raise StorageArgumentException( + 'Unknown origin %s', origin_url) + + if origin_url in self._origins: + origin = self._origins[origin_url] + # visit ids are in the range [1, +inf[ + visit_key = (origin_url, visit.visit) + + with convert_validation_exceptions(): + visit_update = OriginVisitUpdate( + origin=origin_url, + 
visit=visit.visit, + date=date, + status=visit.status, + snapshot=visit.snapshot, + metadata=visit.metadata, + ) + + self._origin_visit_updates.setdefault(visit_key, []) + while len(self._origin_visits[origin_url]) <= visit.visit: self._origin_visits[origin_url].append(None) - self._origin_visits[origin_url][visit_id-1] = visit + self._origin_visits[origin_url][visit.visit-1] = visit + self._origin_visit_updates[visit_key].append(visit_update) + + self._objects[visit_key].append( + ('origin_visit', None)) def _convert_visit(self, visit): if visit is None: @@ -801,7 +853,32 @@ return visit - def origin_visit_get(self, origin, last_visit=None, limit=None): + def _origin_visit_get_updated( + self, origin: str, visit_id: int) -> Optional[OriginVisit]: + """Merge origin visit and latest origin visit update + + """ + assert visit_id >= 1 + visit = self._origin_visits[origin][visit_id-1] + if visit is None: + return None + visit_key = (origin, visit_id) + + visit_update = max( + self._origin_visit_updates[visit_key], key=lambda v: v.date) + + return OriginVisit.from_dict({ + # default to the values in visit + **visit.to_dict(), + # override with the last update + **visit_update.to_dict(), + # but keep the date of the creation of the origin visit + 'date': visit.date + }) + + def origin_visit_get( + self, origin: str, last_visit: Optional[int] = None, + limit: Optional[int] = None) -> Iterable[Dict[str, Any]]: origin_url = self._get_origin_url(origin) if origin_url in self._origin_visits: visits = self._origin_visits[origin_url] @@ -814,31 +891,47 @@ continue visit_id = visit.visit - yield self._convert_visit( - self._origin_visits[origin_url][visit_id-1]) + visit_update = self._origin_visit_get_updated( + origin_url, visit_id) + assert visit_update is not None + yield visit_update.to_dict() - def origin_visit_find_by_date(self, origin, visit_date): + def origin_visit_find_by_date( + self, origin: str, + visit_date: datetime.datetime) -> Optional[Dict[str, Any]]: 
origin_url = self._get_origin_url(origin) if origin_url in self._origin_visits: visits = self._origin_visits[origin_url] visit = min( visits, key=lambda v: (abs(v.date - visit_date), -v.visit)) - return self._convert_visit(visit) + visit_update = self._origin_visit_get_updated( + origin, visit.visit) + assert visit_update is not None + return visit_update.to_dict() + return None - def origin_visit_get_by(self, origin, visit): + def origin_visit_get_by( + self, origin: str, visit: int) -> Optional[Dict[str, Any]]: origin_url = self._get_origin_url(origin) if origin_url in self._origin_visits and \ visit <= len(self._origin_visits[origin_url]): - return self._convert_visit( - self._origin_visits[origin_url][visit-1]) + visit_update = self._origin_visit_get_updated( + origin_url, visit) + assert visit_update is not None + return visit_update.to_dict() + return None def origin_visit_get_latest( - self, origin, allowed_statuses=None, require_snapshot=False): - origin = self._origins.get(origin) - if not origin: - return - visits = self._origin_visits[origin.url] + self, origin: str, allowed_statuses: Optional[List[str]] = None, + require_snapshot: bool = False) -> Optional[Dict[str, Any]]: + ori = self._origins.get(origin) + if not ori: + return None + visits = self._origin_visits[ori.url] + visits = [self._origin_visit_get_updated(visit.origin, visit.visit) + for visit in visits + if visit is not None] if allowed_statuses is not None: visits = [visit for visit in visits if visit.status in allowed_statuses] @@ -848,7 +941,9 @@ visit = max( visits, key=lambda v: (v.date, v.visit), default=None) - return self._convert_visit(visit) + if visit is None: + return None + return visit.to_dict() def _select_random_origin_visit_by_type(self, type: str) -> str: while True: @@ -864,8 +959,12 @@ back_in_the_day = now() - timedelta(weeks=12) # 3 months back # This should be enough for tests for visit in random_origin_visits: - if visit.date > back_in_the_day and visit.status == 
'full': - return visit.to_dict() + updated_visit = self._origin_visit_get_updated( + url, visit.visit) + assert updated_visit is not None + if updated_visit.date > back_in_the_day \ + and updated_visit.status == 'full': + return updated_visit.to_dict() else: return None diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -795,7 +795,8 @@ @remote_api_endpoint('origin/visit/update') def origin_visit_update( self, origin: str, visit_id: int, status: str, - metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None): + metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None, + date: Optional[datetime.datetime] = None): """Update an origin_visit's status. Args: @@ -805,6 +806,7 @@ metadata: Data associated to the visit snapshot (sha1_git): identifier of the snapshot to add to the visit + date: Update date Returns: None @@ -832,14 +834,16 @@ ... @remote_api_endpoint('origin/visit/get') - def origin_visit_get(self, origin, last_visit=None, limit=None): + def origin_visit_get( + self, origin: str, last_visit: Optional[int] = None, + limit: Optional[int] = None) -> Iterable[Dict[str, Any]]: """Retrieve all the origin's visit's information. Args: - origin (str): The visited origin + origin: The visited origin last_visit: Starting point from which listing the next visits Default to None - limit (int): Number of results to return from the last visit. + limit: Number of results to return from the last visit. Default to None Yields: @@ -849,27 +853,31 @@ ... @remote_api_endpoint('origin/visit/find_by_date') - def origin_visit_find_by_date(self, origin, visit_date): + def origin_visit_find_by_date( + self, origin: str, + visit_date: datetime.datetime) -> Optional[Dict[str, Any]]: """Retrieves the origin visit whose date is closest to the provided timestamp. In case of a tie, the visit with largest id is selected. Args: - origin (str): The occurrence's origin (URL). 
- target (datetime): target timestamp + origin: origin (URL) + visit_date: expected visit date Returns: - A visit. + A visit """ ... @remote_api_endpoint('origin/visit/getby') - def origin_visit_get_by(self, origin, visit): + def origin_visit_get_by( + self, origin: str, visit: int) -> Optional[Dict[str, Any]]: """Retrieve origin visit's information. Args: - origin: The occurrence's origin (identifier). + origin: origin (URL) + visit: visit id Returns: The information on that particular (origin, visit) or None if @@ -880,18 +888,19 @@ @remote_api_endpoint('origin/visit/get_latest') def origin_visit_get_latest( - self, origin, allowed_statuses=None, require_snapshot=False): + self, origin: str, allowed_statuses: Optional[List[str]] = None, + require_snapshot: bool = False) -> Optional[Dict[str, Any]]: """Get the latest origin visit for the given origin, optionally looking only for those with one of the given allowed_statuses or for those with a known snapshot. Args: - origin (str): the origin's URL - allowed_statuses (list of str): list of visit statuses considered + origin: origin URL + allowed_statuses: list of visit statuses considered to find the latest visit. For instance, ``allowed_statuses=['full']`` will only consider visits that have successfully run to completion. - require_snapshot (bool): If True, only a visit with a snapshot + require_snapshot: If True, only a visit with a snapshot will be returned. 
    Returns: diff --git a/swh/storage/sql/30-swh-schema.sql b/swh/storage/sql/30-swh-schema.sql --- a/swh/storage/sql/30-swh-schema.sql +++ b/swh/storage/sql/30-swh-schema.sql @@ -17,7 +17,7 @@ -- latest schema version insert into dbversion(version, release, description) - values(146, now(), 'Work In Progress'); + values(147, now(), 'Work In Progress'); -- a SHA1 checksum create domain sha1 as bytea check (length(value) = 20); @@ -282,6 +282,7 @@ visit bigint not null, date timestamptz not null, type text not null, + -- remove those when done migrating the schema status origin_visit_status not null, metadata jsonb, snapshot sha1_git @@ -296,6 +297,28 @@ comment on column origin_visit.snapshot is 'Origin snapshot at visit time'; +-- Crawling history of software origin visits by Software Heritage. Each +-- visit sees its history change through new origin visit updates +create table origin_visit_update +( + id bigserial not null, -- TODO: Decide if we keep that or not + origin bigint not null, + visit bigint not null, + date timestamptz not null, + status origin_visit_status not null, + metadata jsonb, + snapshot sha1_git +); + +comment on column origin_visit_update.id is 'visit update id'; +comment on column origin_visit_update.origin is 'origin concerned by the visit update'; +comment on column origin_visit_update.visit is 'visit concerned by the visit update'; +comment on column origin_visit_update.date is 'Visit update timestamp'; +comment on column origin_visit_update.status is 'Visit update status'; +comment on column origin_visit_update.metadata is 'Origin metadata at visit update time'; +comment on column origin_visit_update.snapshot is 'Origin snapshot at visit update time'; + + -- A snapshot represents the entire state of a software origin as crawled by -- Software Heritage. This table is a simple mapping between (public) intrinsic -- snapshot identifiers and (private) numeric sequential identifiers. 
diff --git a/swh/storage/sql/60-swh-indexes.sql b/swh/storage/sql/60-swh-indexes.sql --- a/swh/storage/sql/60-swh-indexes.sql +++ b/swh/storage/sql/60-swh-indexes.sql @@ -130,6 +130,17 @@ alter table origin_visit add constraint origin_visit_origin_fkey foreign key (origin) references origin(id) not valid; alter table origin_visit validate constraint origin_visit_origin_fkey; +-- origin_visit_update + +create unique index concurrently origin_visit_update_pkey on origin_visit_update(origin, visit, date); +alter table origin_visit_update add primary key using index origin_visit_update_pkey; + +alter table origin_visit_update + add constraint origin_visit_update_origin_visit_fkey + foreign key (origin, visit) + references origin_visit(origin, visit) not valid; +alter table origin_visit_update validate constraint origin_visit_update_origin_visit_fkey; + -- release create unique index concurrently release_pkey on release(id); alter table release add primary key using index release_pkey; diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -19,7 +19,7 @@ import psycopg2.errors from swh.model.model import ( - Content, Directory, Origin, OriginVisit, + Content, Directory, Origin, OriginVisit, OriginVisitUpdate, Revision, Release, SkippedContent, Snapshot, SHA1_SIZE ) from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex @@ -34,9 +34,14 @@ from .utils import ( get_partition_bounds_bytes, extract_collision_hash ) +from .validate import VALIDATION_EXCEPTIONS from .writer import JournalWriter +def now(): + return datetime.datetime.now(tz=datetime.timezone.utc) + + # Max block size of contents to return BULK_BLOCK_CONTENT_LEN_MAX = 10000 @@ -44,14 +49,14 @@ """Identifier for the empty snapshot""" -VALIDATION_EXCEPTIONS = ( +VALIDATION_EXCEPTIONS = VALIDATION_EXCEPTIONS + [ psycopg2.errors.CheckViolation, psycopg2.errors.IntegrityError, psycopg2.errors.InvalidTextRepresentation, 
psycopg2.errors.NotNullViolation, psycopg2.errors.NumericValueOutOfRange, psycopg2.errors.UndefinedFunction, # (raised on wrong argument typs) -) +] """Exceptions raised by postgresql when validation of the arguments failed.""" @@ -62,7 +67,7 @@ re-raises a StorageArgumentException.""" try: yield - except VALIDATION_EXCEPTIONS as e: + except tuple(VALIDATION_EXCEPTIONS) as e: raise StorageArgumentException(*e.args) @@ -186,8 +191,7 @@ @process_metrics def content_add( self, content: Iterable[Content]) -> Dict: - now = datetime.datetime.now(tz=datetime.timezone.utc) - contents = [attr.evolve(c, ctime=now) for c in content] + contents = [attr.evolve(c, ctime=now()) for c in content] objstorage_summary = self.objstorage.content_add(contents) @@ -397,8 +401,7 @@ @db_transaction() def skipped_content_add(self, content: Iterable[SkippedContent], db=None, cur=None) -> Dict: - now = datetime.datetime.now(tz=datetime.timezone.utc) - content = [attr.evolve(c, ctime=now) for c in content] + content = [attr.evolve(c, ctime=now()) for c in content] missing_contents = self.skipped_content_missing( (c.to_dict() for c in content), @@ -669,7 +672,7 @@ db.mktemp_snapshot_branch(cur) created_temp_table = True - try: + with convert_validation_exceptions(): db.copy_to( ( { @@ -684,8 +687,6 @@ ['name', 'target', 'target_type'], cur, ) - except VALIDATION_EXCEPTIONS + (KeyError,) as e: - raise StorageArgumentException(*e.args) self.journal_writer.snapshot_add(snapshot) @@ -809,6 +810,7 @@ with convert_validation_exceptions(): visit_id = db.origin_visit_add(origin_url, date, type, cur=cur) + status = 'ongoing' # We can write to the journal only after inserting to the # DB, because we want the id of the visit visit = OriginVisit.from_dict({ @@ -816,21 +818,46 @@ 'date': date, 'type': type, 'visit': visit_id, - 'status': 'ongoing', + # TODO: Remove when we remove those fields from the model + 'status': status, 'metadata': None, 'snapshot': None }) + + with convert_validation_exceptions(): + 
visit_update = OriginVisitUpdate( + origin=origin_url, + visit=visit_id, + date=date, + status=status, + snapshot=None, + metadata=None, + ) + self._origin_visit_update_add(visit_update, db=db, cur=cur) + self.journal_writer.origin_visit_add(visit) send_metric('origin_visit:add', count=1, method_name='origin_visit') return visit + def _origin_visit_update_add(self, origin_visit_update: OriginVisitUpdate, + db, cur) -> None: + """Add an origin visit update""" + # Inject origin visit update in the schema + db.origin_visit_update_add(origin_visit_update, cur=cur) + + # write the origin visit update to the journal + + send_metric('origin_visit_update:add', + count=1, method_name='origin_visit_update') + @timed @db_transaction() def origin_visit_update(self, origin: str, visit_id: int, status: str, metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None, + date: Optional[datetime.datetime] = None, db=None, cur=None): if not isinstance(origin, str): raise StorageArgumentException( @@ -855,64 +882,144 @@ updated_visit = {**visit, **updates} self.journal_writer.origin_visit_update(updated_visit) + last_visit_update = self._origin_visit_get_updated( + origin, visit_id, db=db, cur=cur) + assert last_visit_update is not None + with convert_validation_exceptions(): - db.origin_visit_update(origin_url, visit_id, updates, cur) + visit_update = OriginVisitUpdate( + origin=origin_url, + visit=visit_id, + date=date or now(), + status=status, + snapshot=snapshot or last_visit_update['snapshot'], + metadata=metadata or last_visit_update['metadata'], + ) + self._origin_visit_update_add(visit_update, db=db, cur=cur) + + def _origin_visit_get_updated( + self, origin: str, visit_id: int, + db, cur) -> Optional[Dict[str, Any]]: + """Retrieve origin visit and latest origin visit update and merge them + into an origin visit. 
+ + """ + row_visit = db.origin_visit_get(origin, visit_id) + if row_visit is None: + return None + visit = dict(zip(db.origin_visit_get_cols, row_visit)) + return self._origin_visit_apply_update(visit, db=db, cur=cur) + + def _origin_visit_apply_update( + self, visit: Dict[str, Any], db, cur=None) -> Dict[str, Any]: + """Retrieve the latest visit update information for the origin visit. + Then merge it with the visit and return it. + + """ + visit_update = db.origin_visit_update_get_latest( + visit['origin'], visit['visit']) + return self._origin_visit_merge(visit, visit_update) + + def _origin_visit_merge( + self, visit: Dict[str, Any], + visit_update: Dict[str, Any]) -> Dict[str, Any]: + """Merge origin_visit and origin_visit_update together. + + """ + return OriginVisit.from_dict({ + # default to the values in visit + **visit, + # override with the last update + **visit_update, + # visit['origin'] is the URL (via a join), while + # visit_update['origin'] is only an id. + 'origin': visit['origin'], + # but keep the date of the creation of the origin visit + 'date': visit['date'] + }).to_dict() @timed @db_transaction() def origin_visit_upsert(self, visits: Iterable[OriginVisit], db=None, cur=None) -> None: + for visit in visits: + if visit.visit is None: + raise StorageArgumentException( + f'Missing visit id for visit {visit}') + self.journal_writer.origin_visit_upsert(visits) for visit in visits: # TODO: upsert them all in a single query + assert visit.visit is not None db.origin_visit_upsert(visit, cur=cur) + with convert_validation_exceptions(): + visit_update = OriginVisitUpdate( + origin=visit.origin, + visit=visit.visit, + date=now(), + status=visit.status, + snapshot=visit.snapshot, + metadata=visit.metadata, + ) + db.origin_visit_update_add(visit_update, cur=cur) @timed @db_transaction_generator(statement_timeout=500) - def origin_visit_get(self, origin, last_visit=None, limit=None, db=None, - cur=None): - for line in db.origin_visit_get_all( - origin, 
last_visit=last_visit, limit=limit, cur=cur): - data = dict(zip(db.origin_visit_get_cols, line)) - yield data + def origin_visit_get( + self, origin: str, last_visit: Optional[int] = None, + limit: Optional[int] = None, + db=None, cur=None) -> Iterable[Dict[str, Any]]: + lines = db.origin_visit_get_all( + origin, last_visit=last_visit, limit=limit, cur=cur) + for line in lines: + visit = dict(zip(db.origin_visit_get_cols, line)) + yield self._origin_visit_apply_update(visit, db) @timed @db_transaction(statement_timeout=500) - def origin_visit_find_by_date(self, origin, visit_date, db=None, cur=None): - line = db.origin_visit_find_by_date(origin, visit_date, cur=cur) - if line: - return dict(zip(db.origin_visit_get_cols, line)) + def origin_visit_find_by_date( + self, origin: str, visit_date: datetime.datetime, + db=None, cur=None) -> Optional[Dict[str, Any]]: + visit = db.origin_visit_find_by_date(origin, visit_date, cur=cur) + if visit: + return self._origin_visit_apply_update(visit, db) + return None @timed @db_transaction(statement_timeout=500) - def origin_visit_get_by(self, origin, visit, db=None, cur=None): - ori_visit = db.origin_visit_get(origin, visit, cur) - if not ori_visit: - return None - - return dict(zip(db.origin_visit_get_cols, ori_visit)) + def origin_visit_get_by( + self, origin: str, + visit: int, db=None, cur=None) -> Optional[Dict[str, Any]]: + row = db.origin_visit_get(origin, visit, cur) + if row: + visit_dict = dict(zip(db.origin_visit_get_cols, row)) + return self._origin_visit_apply_update(visit_dict, db) + return None @timed @db_transaction(statement_timeout=4000) def origin_visit_get_latest( - self, origin, allowed_statuses=None, require_snapshot=False, - db=None, cur=None): - origin_visit = db.origin_visit_get_latest( + self, origin: str, allowed_statuses: Optional[List[str]] = None, + require_snapshot: bool = False, + db=None, cur=None) -> Optional[Dict[str, Any]]: + row = db.origin_visit_get_latest( origin, 
allowed_statuses=allowed_statuses, require_snapshot=require_snapshot, cur=cur) - if origin_visit: - return dict(zip(db.origin_visit_get_cols, origin_visit)) + if row: + visit = dict(zip(db.origin_visit_get_cols, row)) + return self._origin_visit_apply_update(visit, db) + return None @timed @db_transaction() def origin_visit_get_random( self, type: str, db=None, cur=None) -> Optional[Dict[str, Any]]: - result = db.origin_visit_get_random(type, cur) - if result: - return dict(zip(db.origin_visit_get_cols, result)) - else: - return None + row = db.origin_visit_get_random(type, cur) + if row: + visit = dict(zip(db.origin_visit_get_cols, row)) + return self._origin_visit_apply_update(visit, db) + return None @timed @db_transaction(statement_timeout=2000) diff --git a/swh/storage/tests/test_retry.py b/swh/storage/tests/test_retry.py --- a/swh/storage/tests/test_retry.py +++ b/swh/storage/tests/test_retry.py @@ -16,6 +16,8 @@ from swh.storage import get_storage from swh.storage.exc import HashCollision, StorageArgumentException +from .storage_data import date_visit1 + @pytest.fixture def fake_hash_collision(sample_data): @@ -266,7 +268,7 @@ origin = list(swh_storage.origin_visit_get(origin_url)) assert not origin - origin_visit = swh_storage.origin_visit_add(origin_url, '2020-01-01', 'hg') + origin_visit = swh_storage.origin_visit_add(origin_url, date_visit1, 'hg') assert origin_visit.origin == origin_url assert isinstance(origin_visit.visit, int) @@ -300,13 +302,13 @@ origin = list(swh_storage.origin_visit_get(origin_url)) assert not origin - r = swh_storage.origin_visit_add(origin_url, '2020-01-01', 'git') + r = swh_storage.origin_visit_add(origin_url, date_visit1, 'git') assert r == {'origin': origin_url, 'visit': 1} mock_memory.assert_has_calls([ - call(origin_url, '2020-01-01', 'git'), - call(origin_url, '2020-01-01', 'git'), - call(origin_url, '2020-01-01', 'git') + call(origin_url, date_visit1, 'git'), + call(origin_url, date_visit1, 'git'), + call(origin_url, 
date_visit1, 'git') ]) assert mock_sleep.call_count == 2 @@ -327,10 +329,10 @@ assert not origin with pytest.raises(StorageArgumentException, match='Refuse to add'): - swh_storage.origin_visit_add(origin_url, '2020-01-31', 'svn') + swh_storage.origin_visit_add(origin_url, date_visit1, 'svn') mock_memory.assert_has_calls([ - call(origin_url, '2020-01-31', 'svn'), + call(origin_url, date_visit1, 'svn'), ]) @@ -599,7 +601,7 @@ """ sample_origin = sample_data['origin'][0] origin_url = swh_storage.origin_add_one(sample_origin) - origin_visit = swh_storage.origin_visit_add(origin_url, '2020-01-01', 'hg') + origin_visit = swh_storage.origin_visit_add(origin_url, date_visit1, 'hg') ov = next(swh_storage.origin_visit_get(origin_url)) assert ov['origin'] == origin_url diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1658,6 +1658,7 @@ 'snapshot': None, }, ] + assert len(expected_visits) == len(actual_origin_visits) for visit in expected_visits: assert visit in actual_origin_visits @@ -1680,8 +1681,8 @@ # given origin_url = swh_storage.origin_add_one(data.origin) origin_url2 = swh_storage.origin_add_one(data.origin2) - date_visit = datetime.datetime.now(datetime.timezone.utc) - date_visit2 = date_visit + datetime.timedelta(minutes=1) + date_visit = data.date_visit1 + date_visit2 = data.date_visit2 # Round to milliseconds before insertion, so equality doesn't fail # after a round-trip through a DB (eg. 
Cassandra) @@ -2064,6 +2065,26 @@ ('origin_visit', data1), ('origin_visit', data2)] + def test_origin_visit_upsert_missing_visit_id(self, swh_storage): + # given + origin_url = swh_storage.origin_add_one(data.origin2) + + # then + with pytest.raises(StorageArgumentException, match='Missing visit id'): + swh_storage.origin_visit_upsert([OriginVisit.from_dict({ + 'origin': origin_url, + 'date': data.date_visit2, + 'visit': None, # <- make the test raise + 'type': data.type_visit1, + 'status': 'full', + 'metadata': None, + 'snapshot': None, + })]) + + assert list(swh_storage.journal_writer.journal.objects) == [ + ('origin', data.origin2) + ] + def test_origin_visit_get_by_no_result(self, swh_storage): swh_storage.origin_add([data.origin]) actual_origin_visit = swh_storage.origin_visit_get_by( diff --git a/swh/storage/validate.py b/swh/storage/validate.py --- a/swh/storage/validate.py +++ b/swh/storage/validate.py @@ -16,11 +16,11 @@ from .exc import StorageArgumentException -VALIDATION_EXCEPTIONS = ( +VALIDATION_EXCEPTIONS = [ KeyError, TypeError, ValueError, -) +] @contextlib.contextmanager @@ -29,7 +29,7 @@ StorageArgumentException.""" try: yield - except VALIDATION_EXCEPTIONS as e: + except tuple(VALIDATION_EXCEPTIONS) as e: raise StorageArgumentException(*e.args)