diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -11,14 +11,19 @@
     Any, Callable, Dict, Generator, Iterable, List, Optional, TypeVar
 )

+from cassandra import CoordinationFailure
 from cassandra.cluster import (
     Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile, ResultSet)
 from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy
 from cassandra.query import PreparedStatement
-from tenacity import retry, stop_after_attempt, wait_random_exponential
+from tenacity import (
+    retry, stop_after_attempt, wait_random_exponential,
+    retry_if_exception_type,
+)

 from swh.model.model import (
     Sha1Git, TimestampWithTimezone, Timestamp, Person, Content,
+    OriginVisit
 )

 from .common import Row, TOKEN_BEGIN, TOKEN_END, hash_url
@@ -120,7 +125,8 @@
     MAX_RETRIES = 3

     @retry(wait=wait_random_exponential(multiplier=1, max=10),
-           stop=stop_after_attempt(MAX_RETRIES))
+           stop=stop_after_attempt(MAX_RETRIES),
+           retry=retry_if_exception_type(CoordinationFailure))
     def _execute_with_retries(self, statement, args) -> ResultSet:
         return self._session.execute(statement, args, timeout=1000.)

@@ -528,10 +534,9 @@
     @_prepared_insert_statement('origin_visit', _origin_visit_keys)
     def origin_visit_add_one(
-            self, visit: Dict[str, Any], *, statement) -> None:
-        self._execute_with_retries(
-            statement, [visit[key] for key in self._origin_visit_keys])
-        self._increment_counter('origin_visit', 1)
+            self, visit: OriginVisit, *, statement) -> None:
+        self._add_one(statement, 'origin_visit', visit,
+                      self._origin_visit_keys)

     @_prepared_statement(
         'UPDATE origin_visit SET ' +
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -15,7 +15,7 @@
 from swh.model.model import (
     Revision, Release, Directory, DirectoryEntry, Content, SkippedContent,
-    OriginVisit,
+    OriginVisit, Snapshot
 )
 from swh.objstorage import get_objstorage
 from swh.objstorage.exc import ObjNotFoundError
@@ -26,6 +26,8 @@
     # mypy limitation, see https://github.com/python/mypy/issues/1153

+from .. import HashCollision
+from ..exc import StorageArgumentException
 from .common import TOKEN_BEGIN, TOKEN_END
 from .converters import (
     revision_to_db, revision_from_db, release_to_db, release_from_db,
@@ -60,7 +62,10 @@
         return True

     def _content_add(self, contents, with_data):
-        contents = [Content.from_dict(c) for c in contents]
+        try:
+            contents = [Content.from_dict(c) for c in contents]
+        except (KeyError, TypeError, ValueError) as e:
+            raise StorageArgumentException(*e.args)

         # Filter-out content already in the database.
         contents = [c for c in contents
@@ -112,7 +117,6 @@
                     algo, content.get_hash(algo))
                 if len(pks) > 1:
                     # There are more than the one we just inserted.
-                    from .. import HashCollision
                     raise HashCollision(algo, content.get_hash(algo), pks)

         summary = {
@@ -139,7 +143,7 @@
     def content_get(self, content):
         if len(content) > BULK_BLOCK_CONTENT_LEN_MAX:
-            raise ValueError(
+            raise StorageArgumentException(
                 "Sending at most %s contents." % BULK_BLOCK_CONTENT_LEN_MAX)
         for obj_id in content:
             try:
@@ -154,7 +158,7 @@
             self, partition_id: int, nb_partitions: int, limit: int = 1000,
             page_token: str = None):
         if limit is None:
-            raise ValueError('Development error: limit should not be None')
+            raise StorageArgumentException('limit should not be None')

         # Compute start and end of the range of tokens covered by the
         # requested partition
@@ -165,7 +169,7 @@
         # offset the range start according to the `page_token`.
         if page_token is not None:
             if not (range_start <= int(page_token) <= range_end):
-                raise ValueError('Invalid page_token.')
+                raise StorageArgumentException('Invalid page_token.')
             range_start = int(page_token)

         # Get the first rows of the range
@@ -213,8 +217,9 @@
         # It will be used to do an initial filtering efficiently.
         filter_algos = list(set(content).intersection(HASH_ALGORITHMS))
         if not filter_algos:
-            raise ValueError('content keys must contain at least one of: '
-                             '%s' % ', '.join(sorted(HASH_ALGORITHMS)))
+            raise StorageArgumentException(
+                'content keys must contain at least one of: '
+                '%s' % ', '.join(sorted(HASH_ALGORITHMS)))
         common_algo = filter_algos[0]

         # Find all contents whose common_algo matches at least one
@@ -260,7 +265,10 @@
         return self._cql_runner.content_get_random().sha1_git

     def _skipped_content_add(self, contents):
-        contents = [SkippedContent.from_dict(c) for c in contents]
+        try:
+            contents = [SkippedContent.from_dict(c) for c in contents]
+        except (KeyError, TypeError, ValueError) as e:
+            raise StorageArgumentException(*e.args)

         # Filter-out content already in the database.
         contents = [c for c in contents
@@ -306,7 +314,10 @@
             self.journal_writer.write_additions('directory', directories)

         for directory in directories:
-            directory = Directory.from_dict(directory)
+            try:
+                directory = Directory.from_dict(directory)
+            except (KeyError, TypeError, ValueError) as e:
+                raise StorageArgumentException(*e.args)

             # Add directory entries to the 'directory_entry' table
             for entry in directory.entries:
@@ -412,7 +423,10 @@
             self.journal_writer.write_additions('revision', revisions)

         for revision in revisions:
-            revision = revision_to_db(revision)
+            try:
+                revision = revision_to_db(revision)
+            except (KeyError, TypeError, ValueError) as e:
+                raise StorageArgumentException(*e.args)

             if revision:
                 # Add parents first
@@ -507,7 +521,10 @@
             self.journal_writer.write_additions('release', releases)

         for release in releases:
-            release = release_to_db(release)
+            try:
+                release = release_to_db(release)
+            except (KeyError, TypeError, ValueError) as e:
+                raise StorageArgumentException(*e.args)

             if release:
                 self._cql_runner.release_add_one(release)
@@ -532,30 +549,38 @@
         return self._cql_runner.release_get_random().id

     def snapshot_add(self, snapshots):
-        snapshots = list(snapshots)
+        try:
+            snapshots = [Snapshot.from_dict(snap) for snap in snapshots]
+        except (KeyError, TypeError, ValueError) as e:
+            raise StorageArgumentException(*e.args)
         missing = self._cql_runner.snapshot_missing(
-            [snp['id'] for snp in snapshots])
-        snapshots = [snp for snp in snapshots if snp['id'] in missing]
+            [snp.id for snp in snapshots])
+        snapshots = [snp for snp in snapshots if snp.id in missing]

         for snapshot in snapshots:
             if self.journal_writer:
                 self.journal_writer.write_addition('snapshot', snapshot)

             # Add branches
-            for (branch_name, branch) in snapshot['branches'].items():
+            for (branch_name, branch) in snapshot.branches.items():
                 if branch is None:
-                    branch = {'target_type': None, 'target': None}
-                self._cql_runner.snapshot_branch_add_one({
-                    'snapshot_id': snapshot['id'],
+                    target_type = None
+                    target = None
+                else:
+                    target_type = branch.target_type.value
+                    target = branch.target
+                branch = {
+                    'snapshot_id': snapshot.id,
                     'name': branch_name,
-                    'target_type': branch['target_type'],
-                    'target': branch['target'],
-                })
+                    'target_type': target_type,
+                    'target': target,
+                }
+                self._cql_runner.snapshot_branch_add_one(branch)

             # Add the snapshot *after* adding all the branches, so someone
             # calling snapshot_get_branch in the meantime won't end up
             # with half the branches.
-            self._cql_runner.snapshot_add_one(snapshot['id'])
+            self._cql_runner.snapshot_add_one(snapshot.id)

         return {'snapshot:add': len(snapshots)}
@@ -582,7 +607,8 @@
         if visit:
             assert visit['snapshot']
             if self._cql_runner.snapshot_missing([visit['snapshot']]):
-                raise ValueError('Visit references unknown snapshot')
+                raise StorageArgumentException(
+                    'Visit references unknown snapshot')
         return self.snapshot_get_branches(visit['snapshot'])

     def snapshot_count_branches(self, snapshot_id):
@@ -688,7 +714,7 @@
         return_single = False

         if any('id' in origin for origin in origins):
-            raise ValueError('Origin ids are not supported.')
+            raise StorageArgumentException('Origin ids are not supported.')

         results = [self.origin_get_one(origin) for origin in origins]

@@ -700,7 +726,9 @@
     def origin_get_one(self, origin):
         if 'id' in origin:
-            raise ValueError('Origin ids are not supported.')
+            raise StorageArgumentException('Origin ids are not supported.')
+        if 'url' not in origin:
+            raise StorageArgumentException('Missing origin url')
         rows = self._cql_runner.origin_get_by_url(origin['url'])

         rows = list(rows)
@@ -730,7 +758,7 @@
         if page_token:
             start_token = int(page_token)
             if not (TOKEN_BEGIN <= start_token <= TOKEN_END):
-                raise ValueError('Invalid page_token.')
+                raise StorageArgumentException('Invalid page_token.')

         rows = self._cql_runner.origin_list(start_token, limit)
         rows = list(rows)
@@ -768,7 +796,8 @@
     def origin_add(self, origins):
         origins = list(origins)
         if any('id' in origin for origin in origins):
-            raise ValueError('Origins must not already have an id.')
+            raise StorageArgumentException(
+                'Origins must not already have an id.')
         results = []
         for origin in origins:
             self.origin_add_one(origin)
@@ -815,6 +844,11 @@
         if self.journal_writer:
             self.journal_writer.write_addition('origin_visit', visit)

+        try:
+            visit = OriginVisit.from_dict(visit)
+        except (KeyError, TypeError, ValueError) as e:
+            raise StorageArgumentException(*e.args)
+
         self._cql_runner.origin_visit_add_one(visit)

         return {
@@ -829,8 +863,11 @@
         # Get the existing data of the visit
         row = self._cql_runner.origin_visit_get_one(origin_url, visit_id)
         if not row:
-            raise ValueError('This origin visit does not exist.')
-        visit = OriginVisit.from_dict(self._format_origin_visit_row(row))
+            raise StorageArgumentException('This origin visit does not exist.')
+        try:
+            visit = OriginVisit.from_dict(self._format_origin_visit_row(row))
+        except (KeyError, TypeError, ValueError) as e:
+            raise StorageArgumentException(*e.args)

         updates = {}
         if status:
@@ -840,7 +877,10 @@
         if snapshot:
             updates['snapshot'] = snapshot

-        visit = attr.evolve(visit, **updates)
+        try:
+            visit = attr.evolve(visit, **updates)
+        except (KeyError, TypeError, ValueError) as e:
+            raise StorageArgumentException(*e.args)

         if self.journal_writer:
             self.journal_writer.write_update('origin_visit', visit)
diff --git a/swh/storage/exc.py b/swh/storage/exc.py
--- a/swh/storage/exc.py
+++ b/swh/storage/exc.py
@@ -21,3 +21,8 @@
     def __str__(self):
         args = self.args
         return 'An unexpected error occurred in the api backend: %s' % args
+
+
+class StorageArgumentException(Exception):
+    """Argument passed to a Storage endpoint is invalid."""
+    pass
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -25,6 +25,8 @@
 from swh.objstorage import get_objstorage
 from swh.objstorage.exc import ObjNotFoundError

+from . import HashCollision
+from .exc import StorageArgumentException
 from .storage import get_journal_writer
 from .converters import origin_url_to_sha1
 from .utils import get_partition_bounds_bytes
@@ -79,13 +81,16 @@
             if content.status is None:
                 content.status = 'visible'
             if content.status == 'absent':
-                raise ValueError('content with status=absent')
+                raise StorageArgumentException('content with status=absent')
             if content.length is None:
-                raise ValueError('content with length=None')
+                raise StorageArgumentException('content with length=None')

         if self.journal_writer:
             for content in contents:
-                content = attr.evolve(content, data=None)
+                try:
+                    content = attr.evolve(content, data=None)
+                except (KeyError, TypeError, ValueError) as e:
+                    raise StorageArgumentException(*e.args)
                 self.journal_writer.write_addition('content', content)

         summary = {
@@ -103,7 +108,6 @@
                 hash_ = content.get_hash(algorithm)
                 if hash_ in self._content_indexes[algorithm]\
                         and (algorithm not in {'blake2s256', 'sha256'}):
-                    from . import HashCollision
                     raise HashCollision(algorithm, hash_, key)
             for algorithm in DEFAULT_ALGORITHMS:
                 hash_ = content.get_hash(algorithm)
@@ -115,9 +119,12 @@
                 summary['content:add'] += 1
                 if with_data:
                     content_data = self._contents[key].data
-                    self._contents[key] = attr.evolve(
-                        self._contents[key],
-                        data=None)
+                    try:
+                        self._contents[key] = attr.evolve(
+                            self._contents[key],
+                            data=None)
+                    except (KeyError, TypeError, ValueError) as e:
+                        raise StorageArgumentException(*e.args)
                     summary['content:add:bytes'] += len(content_data)
                     self.objstorage.add(content_data, content.sha1)

@@ -125,8 +132,11 @@
     def content_add(self, content):
         now = datetime.datetime.now(tz=datetime.timezone.utc)
-        content = [attr.evolve(Content.from_dict(c), ctime=now)
-                   for c in content]
+        try:
+            content = [attr.evolve(Content.from_dict(c), ctime=now)
+                       for c in content]
+        except (KeyError, TypeError, ValueError) as e:
+            raise StorageArgumentException(*e.args)
         return self._content_add(content, with_data=True)

     def content_update(self, content, keys=[]):
@@ -144,7 +154,10 @@
                 hash_ = old_cont.get_hash(algorithm)
                 self._content_indexes[algorithm][hash_].remove(old_key)

-            new_cont = attr.evolve(old_cont, **cont_update)
+            try:
+                new_cont = attr.evolve(old_cont, **cont_update)
+            except (KeyError, TypeError, ValueError) as e:
+                raise StorageArgumentException(*e.args)
             new_key = self._content_key(new_cont)

             self._contents[new_key] = new_cont
@@ -160,7 +173,7 @@
     def content_get(self, content):
         # FIXME: Make this method support slicing the `data`.
         if len(content) > BULK_BLOCK_CONTENT_LEN_MAX:
-            raise ValueError(
+            raise StorageArgumentException(
                 "Sending at most %s contents." % BULK_BLOCK_CONTENT_LEN_MAX)
         for obj_id in content:
             try:
@@ -173,7 +186,7 @@
     def content_get_range(self, start, end, limit=1000):
         if limit is None:
-            raise ValueError('Development error: limit should not be None')
+            raise StorageArgumentException('limit should not be None')
         from_index = bisect.bisect_left(self._sorted_sha1s, start)
         sha1s = itertools.islice(self._sorted_sha1s, from_index, None)
         sha1s = ((sha1, content_key)
@@ -197,7 +210,7 @@
             self, partition_id: int, nb_partitions: int, limit: int = 1000,
             page_token: str = None):
         if limit is None:
-            raise ValueError('Development error: limit should not be None')
+            raise StorageArgumentException('limit should not be None')
         (start, end) = get_partition_bounds_bytes(
             partition_id, nb_partitions, SHA1_SIZE)
         if page_token:
@@ -231,8 +244,9 @@
     def content_find(self, content):
         if not set(content).intersection(DEFAULT_ALGORITHMS):
-            raise ValueError('content keys must contain at least one of: '
-                             '%s' % ', '.join(sorted(DEFAULT_ALGORITHMS)))
+            raise StorageArgumentException(
+                'content keys must contain at least one of: %s'
+                % ', '.join(sorted(DEFAULT_ALGORITHMS)))
         found = []
         for algo in DEFAULT_ALGORITHMS:
             hash = content.get(algo)
@@ -278,7 +292,8 @@
             if content.length is None:
                 content = attr.evolve(content, length=-1)
             if content.status != 'absent':
-                raise ValueError(f'Content with status={content.status}')
+                raise StorageArgumentException(
+                    f'Content with status={content.status}')

         if self.journal_writer:
             for content in contents:
@@ -318,8 +333,11 @@
     def skipped_content_add(self, content):
         content = list(content)
         now = datetime.datetime.now(tz=datetime.timezone.utc)
-        content = [attr.evolve(SkippedContent.from_dict(c), ctime=now)
-                   for c in content]
+        try:
+            content = [attr.evolve(SkippedContent.from_dict(c), ctime=now)
+                       for c in content]
+        except (KeyError, TypeError, ValueError) as e:
+            raise StorageArgumentException(*e.args)
         return self._skipped_content_add(content)

     def directory_add(self, directories):
@@ -330,7 +348,10 @@
                 (dir_ for dir_ in directories
                  if dir_['id'] not in self._directories))

-        directories = [Directory.from_dict(d) for d in directories]
+        try:
+            directories = [Directory.from_dict(d) for d in directories]
+        except (KeyError, TypeError, ValueError) as e:
+            raise StorageArgumentException(*e.args)

         count = 0
         for directory in directories:
@@ -423,7 +444,10 @@
                 (rev for rev in revisions
                  if rev['id'] not in self._revisions))

-        revisions = [Revision.from_dict(rev) for rev in revisions]
+        try:
+            revisions = [Revision.from_dict(rev) for rev in revisions]
+        except (KeyError, TypeError, ValueError) as e:
+            raise StorageArgumentException(*e.args)

         count = 0
         for revision in revisions:
@@ -481,7 +505,10 @@
                 (rel for rel in releases
                  if rel['id'] not in self._releases))

-        releases = [Release.from_dict(rel) for rel in releases]
+        try:
+            releases = [Release.from_dict(rel) for rel in releases]
+        except (KeyError, TypeError, ValueError) as e:
+            raise StorageArgumentException(*e.args)

         count = 0
         for rel in releases:
@@ -510,7 +537,10 @@
     def snapshot_add(self, snapshots):
         count = 0
-        snapshots = (Snapshot.from_dict(d) for d in snapshots)
+        try:
+            snapshots = [Snapshot.from_dict(d) for d in snapshots]
+        except (KeyError, TypeError, ValueError) as e:
+            raise StorageArgumentException(*e.args)
         snapshots = (snap for snap in snapshots
                      if snap.id not in self._snapshots)
         for snapshot in snapshots:
@@ -558,7 +588,7 @@
         if visit and visit['snapshot']:
             snapshot = self.snapshot_get(visit['snapshot'])
             if not snapshot:
-                raise ValueError(
+                raise StorageArgumentException(
                     'last origin visit references an unknown snapshot')
         return snapshot
@@ -636,11 +666,11 @@
         # Sanity check to be error-compatible with the pgsql backend
         if any('id' in origin for origin in origins) \
                 and not all('id' in origin for origin in origins):
-            raise ValueError(
+            raise StorageArgumentException(
                 'Either all origins or none at all should have an "id".')
         if any('url' in origin for origin in origins) \
                 and not all('url' in origin for origin in origins):
-            raise ValueError(
+            raise StorageArgumentException(
                 'Either all origins or none at all should have '
                 'an "url" key.')

@@ -651,7 +681,7 @@
                 if origin['url'] in self._origins:
                     result = self._origins[origin['url']]
             else:
-                raise ValueError(
+                raise StorageArgumentException(
                     'Origin must have an url.')
             results.append(self._convert_origin(result))

@@ -727,7 +757,10 @@
         return origins

     def origin_add_one(self, origin):
-        origin = Origin.from_dict(origin)
+        try:
+            origin = Origin.from_dict(origin)
+        except (KeyError, TypeError, ValueError) as e:
+            raise StorageArgumentException(*e.args)
         if origin.url not in self._origins:
             if self.journal_writer:
                 self.journal_writer.write_addition('origin', origin)
@@ -748,13 +781,14 @@
     def origin_visit_add(self, origin, date, type):
         origin_url = origin
         if origin_url is None:
-            raise ValueError('Unknown origin.')
+            raise StorageArgumentException('Unknown origin.')

         if isinstance(date, str):
             # FIXME: Converge on iso8601 at some point
             date = dateutil.parser.parse(date)
         elif not isinstance(date, datetime.datetime):
-            raise TypeError('date must be a datetime or a string.')
+            raise StorageArgumentException(
+                'date must be a datetime or a string.')

         visit_ret = None
         if origin_url in self._origins:
@@ -791,13 +825,13 @@
             raise TypeError('origin must be a string, not %r' % (origin,))
         origin_url = self._get_origin_url(origin)
         if origin_url is None:
-            raise ValueError('Unknown origin.')
+            raise StorageArgumentException('Unknown origin.')

         try:
             visit = self._origin_visits[origin_url][visit_id-1]
         except IndexError:
-            raise ValueError('Unknown visit_id for this origin') \
-                from None
+            raise StorageArgumentException(
+                'Unknown visit_id for this origin') from None

         updates = {}
         if status:
@@ -807,7 +841,10 @@
         if snapshot:
             updates['snapshot'] = snapshot

-        visit = attr.evolve(visit, **updates)
+        try:
+            visit = attr.evolve(visit, **updates)
+        except (KeyError, TypeError, ValueError) as e:
+            raise StorageArgumentException(*e.args)

         if self.journal_writer:
             self.journal_writer.write_update('origin_visit', visit)
@@ -819,7 +856,10 @@
             if not isinstance(visit['origin'], str):
                 raise TypeError("visit['origin'] must be a string, not %r"
                                 % (visit['origin'],))
-        visits = [OriginVisit.from_dict(d) for d in visits]
+        try:
+            visits = [OriginVisit.from_dict(d) for d in visits]
+        except (KeyError, TypeError, ValueError) as e:
+            raise StorageArgumentException(*e.args)

         if self.journal_writer:
             for visit in visits:
@@ -829,7 +869,10 @@
             visit_id = visit.visit
             origin_url = visit.origin

-            visit = attr.evolve(visit, origin=origin_url)
+            try:
+                visit = attr.evolve(visit, origin=origin_url)
+            except (KeyError, TypeError, ValueError) as e:
+                raise StorageArgumentException(*e.args)

             self._objects[(origin_url, visit_id)].append(
                 ('origin_visit', None))
diff --git a/swh/storage/retry.py b/swh/storage/retry.py
--- a/swh/storage/retry.py
+++ b/swh/storage/retry.py
@@ -4,46 +4,47 @@
 # See top-level LICENSE file for more information

 import logging
-import psycopg2
 import traceback

 from datetime import datetime
 from typing import Dict, Iterable, List, Optional, Union

-from requests.exceptions import ConnectionError
 from tenacity import (
-    retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type
+    retry, stop_after_attempt, wait_random_exponential,
 )

-from swh.storage import get_storage, HashCollision
+from swh.storage import get_storage
+from swh.storage.exc import StorageArgumentException


 logger = logging.getLogger(__name__)


-RETRY_EXCEPTIONS = [
-    # raised when two parallel insertions insert the same data
-    psycopg2.IntegrityError,
-    HashCollision,
-    # when the server is restarting
-    ConnectionError,
-]
-
-
-def should_retry_adding(error: Exception) -> bool:
-    """Retry if the error/exception if one of the RETRY_EXCEPTIONS type.
+def should_retry_adding(retry_state) -> bool:
+    """Retry if the error/exception is (probably) not about a caller error

     """
-    for exc in RETRY_EXCEPTIONS:
-        if retry_if_exception_type(exc)(error):
-            error_name = error.__module__ + '.' + error.__class__.__name__
+    if retry_state.outcome.failed:
+        error = retry_state.outcome.exception()
+        if isinstance(error, StorageArgumentException):
+            # Exception is due to an invalid argument
+            return False
+        else:
+            # Other exception
+            module = getattr(error, '__module__', None)
+            if module:
+                error_name = error.__module__ + '.' + error.__class__.__name__
+            else:
+                error_name = error.__class__.__name__
             logger.warning('Retry adding a batch', exc_info=False, extra={
                 'swh_type': 'storage_retry',
                 'swh_exception_type': error_name,
                 'swh_exception': traceback.format_exc(),
             })
             return True
-    return False
+    else:
+        # No exception
+        return False


 swh_retry = retry(retry=should_retry_adding,
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

+import contextlib
 import copy
 import datetime
 import itertools
@@ -16,6 +17,7 @@
 import dateutil.parser
 import psycopg2
 import psycopg2.pool
+import psycopg2.errors

 from swh.model.model import SHA1_SIZE
 from swh.model.hashutil import ALGORITHMS, hash_to_bytes, hash_to_hex
@@ -27,10 +29,10 @@
     get_journal_writer = None  # type: ignore
     # mypy limitation, see https://github.com/python/mypy/issues/1153

-from . import converters
+from . import converters, HashCollision
 from .common import db_transaction_generator, db_transaction
 from .db import Db
-from .exc import StorageDBError
+from .exc import StorageArgumentException, StorageDBError
 from .algos import diff
 from .metrics import timed, send_metric, process_metrics
 from .utils import get_partition_bounds_bytes
@@ -43,6 +45,28 @@
 """Identifier for the empty snapshot"""


+VALIDATION_EXCEPTIONS = (
+    psycopg2.errors.CheckViolation,
+    psycopg2.errors.IntegrityError,
+    psycopg2.errors.InvalidTextRepresentation,
+    psycopg2.errors.NotNullViolation,
+    psycopg2.errors.NumericValueOutOfRange,
+    psycopg2.errors.UndefinedFunction,  # (raised on wrong argument types)
+)
+"""Exceptions raised by postgresql when validation of the arguments
+failed."""
+
+
+@contextlib.contextmanager
+def convert_validation_exceptions():
+    """Catches postgresql errors related to invalid arguments, and
+    re-raises a StorageArgumentException."""
+    try:
+        yield
+    except VALIDATION_EXCEPTIONS as e:
+        raise StorageArgumentException(*e.args)
+
+
 class Storage():
     """SWH storage proxy, encompassing DB and object storage

@@ -141,14 +165,15 @@
         """Sanity checks on status / reason / length, that postgresql
         doesn't enforce."""
         if d['status'] not in ('visible', 'hidden'):
-            raise ValueError('Invalid content status: {}'.format(d['status']))
+            raise StorageArgumentException(
+                'Invalid content status: {}'.format(d['status']))

         if d.get('reason') is not None:
-            raise ValueError(
+            raise StorageArgumentException(
                 'Must not provide a reason if content is present.')

         if d['length'] is None or d['length'] < 0:
-            raise ValueError('Content length must be positive.')
+            raise StorageArgumentException('Content length must be positive.')

     def _content_add_metadata(self, db, cur, content):
         """Add content to the postgresql database but not the object
         storage.
@@ -156,26 +181,26 @@
         # create temporary table for metadata injection
         db.mktemp('content', cur)

-        db.copy_to(content, 'tmp_content',
-                   db.content_add_keys, cur)
+        with convert_validation_exceptions():
+            db.copy_to(content, 'tmp_content',
+                       db.content_add_keys, cur)

-        # move metadata in place
-        try:
-            db.content_add_from_temp(cur)
-        except psycopg2.IntegrityError as e:
-            from . import HashCollision
-            if e.diag.sqlstate == '23505' and \
-                    e.diag.table_name == 'content':
-                constraint_to_hash_name = {
-                    'content_pkey': 'sha1',
-                    'content_sha1_git_idx': 'sha1_git',
-                    'content_sha256_idx': 'sha256',
-                }
-                colliding_hash_name = constraint_to_hash_name \
-                    .get(e.diag.constraint_name)
-                raise HashCollision(colliding_hash_name) from None
-            else:
-                raise
+            # move metadata in place
+            try:
+                db.content_add_from_temp(cur)
+            except psycopg2.IntegrityError as e:
+                if e.diag.sqlstate == '23505' and \
+                        e.diag.table_name == 'content':
+                    constraint_to_hash_name = {
+                        'content_pkey': 'sha1',
+                        'content_sha1_git_idx': 'sha1_git',
+                        'content_sha256_idx': 'sha256',
+                    }
+                    colliding_hash_name = constraint_to_hash_name \
+                        .get(e.diag.constraint_name)
+                    raise HashCollision(colliding_hash_name) from None
+                else:
+                    raise

     @timed
     @process_metrics
@@ -247,9 +272,10 @@
         db.mktemp('content', cur)
         select_keys = list(set(db.content_get_metadata_keys).union(set(keys)))
-        db.copy_to(content, 'tmp_content', select_keys, cur)
-        db.content_update_from_temp(keys_to_update=keys,
-                                    cur=cur)
+        with convert_validation_exceptions():
+            db.copy_to(content, 'tmp_content', select_keys, cur)
+            db.content_update_from_temp(keys_to_update=keys,
+                                        cur=cur)

     @timed
     @process_metrics
@@ -277,7 +303,7 @@
     def content_get(self, content):
         # FIXME: Make this method support slicing the `data`.
         if len(content) > BULK_BLOCK_CONTENT_LEN_MAX:
-            raise ValueError(
+            raise StorageArgumentException(
                 "Send at maximum %s contents." % BULK_BLOCK_CONTENT_LEN_MAX)

         for obj_id in content:
@@ -293,7 +319,7 @@
     @db_transaction()
     def content_get_range(self, start, end, limit=1000, db=None, cur=None):
         if limit is None:
-            raise ValueError('Development error: limit should not be None')
+            raise StorageArgumentException('limit should not be None')
         contents = []
         next_content = None
         for counter, content_row in enumerate(
@@ -315,7 +341,7 @@
             self, partition_id: int, nb_partitions: int, limit: int = 1000,
             page_token: str = None, db=None, cur=None):
         if limit is None:
-            raise ValueError('Development error: limit should not be None')
+            raise StorageArgumentException('limit should not be None')
         (start, end) = get_partition_bounds_bytes(
             partition_id, nb_partitions, SHA1_SIZE)
         if page_token:
@@ -348,7 +374,8 @@
         keys = db.content_hash_keys

         if key_hash not in keys:
-            raise ValueError("key_hash should be one of %s" % keys)
+            raise StorageArgumentException(
+                "key_hash should be one of %s" % keys)

         key_hash_idx = keys.index(key_hash)

@@ -374,8 +401,9 @@
     @db_transaction()
     def content_find(self, content, db=None, cur=None):
         if not set(content).intersection(ALGORITHMS):
-            raise ValueError('content keys must contain at least one of: '
-                             'sha1, sha1_git, sha256, blake2s256')
+            raise StorageArgumentException(
+                'content keys must contain at least one of: '
+                'sha1, sha1_git, sha256, blake2s256')

         contents = db.content_find(sha1=content.get('sha1'),
                                    sha1_git=content.get('sha1_git'),
@@ -407,14 +435,16 @@
         """Sanity checks on status / reason / length, that postgresql
         doesn't enforce."""
         if d['status'] != 'absent':
-            raise ValueError('Invalid content status: {}'.format(d['status']))
+            raise StorageArgumentException(
+                'Invalid content status: {}'.format(d['status']))

         if d.get('reason') is None:
-            raise ValueError(
+            raise StorageArgumentException(
                 'Must provide a reason if content is absent.')

         if d['length'] < -1:
-            raise ValueError('Content length must be positive or -1.')
+            raise StorageArgumentException(
+                'Content length must be positive or -1.')

     def _skipped_content_add_metadata(self, db, cur, content):
         content = \
@@ -426,11 +456,12 @@
             if 'origin' in cont:
                 cont['origin'] = origin_id
         db.mktemp('skipped_content', cur)
-        db.copy_to(content, 'tmp_skipped_content',
-                   db.skipped_content_keys, cur)
+        with convert_validation_exceptions():
+            db.copy_to(content, 'tmp_skipped_content',
+                       db.skipped_content_keys, cur)

-        # move metadata in place
-        db.skipped_content_add_from_temp(cur)
+            # move metadata in place
+            db.skipped_content_add_from_temp(cur)

     @timed
     @process_metrics
@@ -488,7 +519,7 @@
                 entry = src_entry.copy()
                 entry['dir_id'] = dir_id
                 if entry['type'] not in ('file', 'dir', 'rev'):
-                    raise ValueError(
+                    raise StorageArgumentException(
                         'Entry type must be file, dir, or rev; not %s'
                         % entry['type'])
                 dir_entries[entry['type']][dir_id].append(entry)
@@ -506,27 +537,28 @@
             # Copy directory ids
             dirs_missing_dict = ({'id': dir} for dir in dirs_missing)
             db.mktemp('directory', cur)
-            db.copy_to(dirs_missing_dict, 'tmp_directory', ['id'], cur)
+            with convert_validation_exceptions():
+                db.copy_to(dirs_missing_dict, 'tmp_directory', ['id'], cur)

-            # Copy entries
-            for entry_type, entry_list in dir_entries.items():
-                entries = itertools.chain.from_iterable(
-                    entries_for_dir
-                    for dir_id, entries_for_dir
-                    in entry_list.items()
-                    if dir_id in dirs_missing)
+                # Copy entries
+                for entry_type, entry_list in dir_entries.items():
+                    entries = itertools.chain.from_iterable(
+                        entries_for_dir
+                        for dir_id, entries_for_dir
+                        in entry_list.items()
+                        if dir_id in dirs_missing)

-                db.mktemp_dir_entry(entry_type)
+                    db.mktemp_dir_entry(entry_type)

-                db.copy_to(
-                    entries,
-                    'tmp_directory_entry_%s' % entry_type,
-                    ['target', 'name', 'perms', 'dir_id'],
-                    cur,
-                )
+                    db.copy_to(
+                        entries,
+                        'tmp_directory_entry_%s' % entry_type,
+                        ['target', 'name', 'perms', 'dir_id'],
+                        cur,
+                    )

-            # Do the final copy
-            db.directory_add_from_temp(cur)
+                # Do the final copy
+                db.directory_add_from_temp(cur)

         summary['directory:add'] = len(dirs_missing)
         return summary
@@ -587,15 +619,16 @@
             parents_filtered = []

-            db.copy_to(
-                revisions_filtered, 'tmp_revision', db.revision_add_cols,
-                cur,
-                lambda rev: parents_filtered.extend(rev['parents']))
+            with convert_validation_exceptions():
+                db.copy_to(
+                    revisions_filtered, 'tmp_revision', db.revision_add_cols,
+                    cur,
+                    lambda rev: parents_filtered.extend(rev['parents']))

-            db.revision_add_from_temp(cur)
+                db.revision_add_from_temp(cur)

-            db.copy_to(parents_filtered, 'revision_history',
-                       ['id', 'parent_id', 'parent_rank'], cur)
+                db.copy_to(parents_filtered, 'revision_history',
+                           ['id', 'parent_id', 'parent_rank'], cur)

         return {'revision:add': len(revisions_missing)}

@@ -671,10 +704,11 @@
             releases_filtered = map(converters.release_to_db,
                                     releases_filtered)
-            db.copy_to(releases_filtered, 'tmp_release', db.release_add_cols,
-                       cur)
+            with convert_validation_exceptions():
+                db.copy_to(releases_filtered, 'tmp_release', db.release_add_cols,
+                           cur)

-            db.release_add_from_temp(cur)
+                db.release_add_from_temp(cur)

         return {'release:add': len(releases_missing)}
@@ -714,20 +748,23 @@
                     db.mktemp_snapshot_branch(cur)
                     created_temp_table = True

-                db.copy_to(
-                    (
-                        {
-                            'name': name,
-                            'target': info['target'] if info else None,
-                            'target_type': (info['target_type']
-                                            if info else None),
-                        }
-                        for name, info in snapshot['branches'].items()
-                    ),
-                    'tmp_snapshot_branch',
-                    ['name', 'target', 'target_type'],
-                    cur,
-                )
+                try:
+                    db.copy_to(
+                        (
+                            {
+                                'name': name,
+                                'target': info['target'] if info else None,
+                                'target_type': (info['target_type']
+                                                if info else None),
+                            }
+                            for name, info in snapshot['branches'].items()
+                        ),
+                        'tmp_snapshot_branch',
+                        ['name', 'target', 'target_type'],
+                        cur,
+                    )
+                except VALIDATION_EXCEPTIONS + (KeyError,) as e:
+                    raise StorageArgumentException(*e.args)

                 if self.journal_writer:
                     self.journal_writer.write_addition('snapshot', snapshot)
@@ -776,7 +813,7 @@
             snapshot = self.snapshot_get(
                 origin_visit['snapshot'], db=db, cur=cur)
             if not snapshot:
-                raise ValueError(
+                raise StorageArgumentException(
                     'last origin visit references an unknown snapshot')
             return snapshot
@@ -842,7 +879,8 @@
             # FIXME: Converge on iso8601 at some point
             date = dateutil.parser.parse(date)

-        visit_id = db.origin_visit_add(origin_url, date, type, cur)
+        with convert_validation_exceptions():
+            visit_id = db.origin_visit_add(origin_url, date, type, cur)

         if self.journal_writer:
             # We can write to the journal only after inserting to the
@@ -864,12 +902,13 @@
                             metadata=None, snapshot=None,
                             db=None, cur=None):
         if not isinstance(origin, str):
-            raise TypeError('origin must be a string, not %r' % (origin,))
+            raise StorageArgumentException(
+                'origin must be a string, not %r' % (origin,))
         origin_url = origin
         visit = db.origin_visit_get(origin_url, visit_id, cur=cur)

         if not visit:
-            raise ValueError('Invalid visit_id for this origin.')
+            raise StorageArgumentException('Invalid visit_id for this origin.')

         visit = dict(zip(db.origin_visit_get_cols, visit))

@@ -886,7 +925,8 @@
             self.journal_writer.write_update('origin_visit', {
                 **visit, **updates})

-        db.origin_visit_update(origin_url, visit_id, updates, cur)
+        with convert_validation_exceptions():
+            db.origin_visit_update(origin_url, visit_id, updates, cur)

     @timed
     @db_transaction()
@@ -896,8 +936,9 @@
             if isinstance(visit['date'], str):
                 visit['date'] = dateutil.parser.parse(visit['date'])
             if not isinstance(visit['origin'], str):
-                raise TypeError("visit['origin'] must be a string, not %r"
-                                % (visit['origin'],))
+                raise StorageArgumentException(
+                    "visit['origin'] must be a string, not %r"
+                    % (visit['origin'],))

         if self.journal_writer:
             for visit in visits:
@@ -1013,7 +1054,7 @@
             *, db=None, cur=None) -> dict:
         page_token = page_token or '0'
         if not isinstance(page_token, str):
-            raise TypeError('page_token must be a string.')
+            raise StorageArgumentException('page_token must be a string.')
         origin_from = int(page_token)
         result: Dict[str, Any] = {
             'origins': [
@@ -1058,6 +1099,8 @@
     @timed
     @db_transaction()
     def origin_add_one(self, origin, db=None, cur=None):
+        if 'url' not in origin:
+            raise StorageArgumentException('Missing origin url')
         origin_row = list(db.origin_get_by_url([origin['url']], cur))[0]
         origin_url = dict(zip(db.origin_cols, origin_row))['url']
         if origin_url:
@@ -1117,11 +1160,12 @@
     @db_transaction()
     def tool_add(self, tools, db=None, cur=None):
         db.mktemp_tool(cur)
-        db.copy_to(tools, 'tmp_tool',
-                   ['name', 'version', 'configuration'],
-                   cur)
+        with convert_validation_exceptions():
+            db.copy_to(tools, 'tmp_tool',
+                       ['name', 'version', 'configuration'],
+                       cur)
+            tools = db.tool_add_from_temp(cur)

-        tools = db.tool_add_from_temp(cur)
         results = [dict(zip(db.tool_cols, line)) for line in tools]
         send_metric('tool:add', count=len(results), method_name='tool_add')
         return results
diff --git a/swh/storage/tests/test_retry.py b/swh/storage/tests/test_retry.py
--- a/swh/storage/tests/test_retry.py
+++ b/swh/storage/tests/test_retry.py
@@ -11,6 +11,7 @@
 from unittest.mock import call

 from swh.storage import HashCollision
+from swh.storage.exc import StorageArgumentException
 from swh.storage.retry import RetryingProxyStorage
@@ -76,14 +77,15 @@
     """
     mock_memory = mocker.patch(
         'swh.storage.in_memory.InMemoryStorage.content_add')
-    mock_memory.side_effect = ValueError('Refuse to add content always!')
+    mock_memory.side_effect = StorageArgumentException(
+        'Refuse to add content always!')

     sample_content = sample_data['content'][0]

     content = next(swh_storage.content_get([sample_content['sha1']]))
     assert not content

-    with pytest.raises(ValueError, match='Refuse to add'):
+    with pytest.raises(StorageArgumentException, match='Refuse to add'):
         swh_storage.content_add([sample_content])

     assert mock_memory.call_count == 1
@@ -144,7 +146,8 @@
     """
     mock_memory = mocker.patch(
         'swh.storage.in_memory.InMemoryStorage.content_add_metadata')
-    mock_memory.side_effect = ValueError('Refuse to add content_metadata!')
+    mock_memory.side_effect = StorageArgumentException(
+        'Refuse to add content_metadata!')

     sample_content = sample_data['content_metadata'][0]
     pk = sample_content['sha1']
@@ -152,7 +155,7 @@
     content_metadata = swh_storage.content_get_metadata([pk])
     assert not content_metadata[pk]

-    with pytest.raises(ValueError, match='Refuse to add'):
+    with pytest.raises(StorageArgumentException, match='Refuse to add'):
         swh_storage.content_add_metadata([sample_content])

     assert mock_memory.call_count == 1
@@ -211,14 +214,15 @@
     """
     mock_memory = mocker.patch(
         'swh.storage.in_memory.InMemoryStorage.origin_add_one')
-    mock_memory.side_effect = ValueError('Refuse to add origin always!')
+    mock_memory.side_effect = StorageArgumentException(
+        'Refuse to add origin always!')

     sample_origin = sample_data['origin'][0]

     origin = swh_storage.origin_get(sample_origin)
     assert not origin

-    with pytest.raises(ValueError, match='Refuse to add'):
+    with pytest.raises(StorageArgumentException, match='Refuse to add'):
         swh_storage.origin_add_one([sample_origin])

     assert mock_memory.call_count == 1
@@ -285,14 +289,15 @@
     """
    mock_memory = mocker.patch(
         'swh.storage.in_memory.InMemoryStorage.origin_visit_add')
-    mock_memory.side_effect = ValueError('Refuse to add origin always!')
+    mock_memory.side_effect = StorageArgumentException(
+        'Refuse to add origin always!')

     origin_url = sample_data['origin'][0]['url']

     origin = list(swh_storage.origin_visit_get(origin_url))
     assert not origin

-    with pytest.raises(ValueError, match='Refuse to add'):
+    with pytest.raises(StorageArgumentException, match='Refuse to add'):
         swh_storage.origin_visit_add(origin_url, '2020-01-31', 'svn')

     assert mock_memory.has_calls([
@@ -357,14 +362,15 @@
     """
     mock_memory = mocker.patch(
         'swh.storage.in_memory.InMemoryStorage.tool_add')
-    mock_memory.side_effect = ValueError('Refuse to add tool always!')
+    mock_memory.side_effect = StorageArgumentException(
+        'Refuse to add tool always!')

     sample_tool = sample_data['tool'][0]

     tool = swh_storage.tool_get(sample_tool)
     assert not tool

-    with pytest.raises(ValueError, match='Refuse to add'):
+    with pytest.raises(StorageArgumentException, match='Refuse to add'):
         swh_storage.tool_add([sample_tool])

     assert mock_memory.call_count == 1
@@ -439,7 +445,8 @@
     """
     mock_memory = mocker.patch(
         'swh.storage.in_memory.InMemoryStorage.metadata_provider_add')
-    mock_memory.side_effect = ValueError('Refuse to add provider_id always!')
+    mock_memory.side_effect = StorageArgumentException(
+        'Refuse to add provider_id always!')

     provider = sample_data['provider'][0]
     provider_get = to_provider(provider)
@@ -447,7 +454,7 @@
     provider_id = swh_storage.metadata_provider_get_by(provider_get)
     assert not provider_id

-    with pytest.raises(ValueError, match='Refuse to add'):
+    with pytest.raises(StorageArgumentException, match='Refuse to add'):
         swh_storage.metadata_provider_add(**provider_get)

     assert mock_memory.call_count == 1
@@ -520,7 +527,8 @@
     """
     mock_memory = mocker.patch(
         'swh.storage.in_memory.InMemoryStorage.origin_metadata_add')
-    mock_memory.side_effect = ValueError('Refuse to add always!')
+    mock_memory.side_effect = StorageArgumentException(
+        'Refuse to add always!')

     ori_meta = sample_data['origin_metadata'][0]
     origin = ori_meta['origin']
@@ -532,7 +540,7 @@
     tool_id = ori_meta['tool']
     metadata = ori_meta['metadata']

-    with pytest.raises(ValueError, match='Refuse to add'):
+    with pytest.raises(StorageArgumentException, match='Refuse to add'):
         swh_storage.origin_metadata_add(url, ts, provider_id, tool_id,
                                         metadata)

@@ -603,11 +611,12 @@
     """
     mock_memory = mocker.patch(
         'swh.storage.in_memory.InMemoryStorage.origin_visit_update')
-    mock_memory.side_effect = ValueError('Refuse to add origin always!')
+    mock_memory.side_effect = StorageArgumentException(
+        'Refuse to add origin always!')

     origin_url = sample_data['origin'][0]['url']
     visit_id = 9

-    with pytest.raises(ValueError, match='Refuse to add'):
+    with pytest.raises(StorageArgumentException, match='Refuse to add'):
         swh_storage.origin_visit_update(origin_url, visit_id, 'partial')

     assert mock_memory.call_count == 1
@@ -671,14 +680,15 @@
     """
     mock_memory = mocker.patch(
         'swh.storage.in_memory.InMemoryStorage.directory_add')
-    mock_memory.side_effect = ValueError('Refuse to add directory always!')
+    mock_memory.side_effect = StorageArgumentException(
+        'Refuse to add directory always!')

     sample_dir = sample_data['directory'][0]

     directory_id = swh_storage.directory_get_random()  # no directory
     assert not directory_id

-    with pytest.raises(ValueError, match='Refuse to add'):
+    with pytest.raises(StorageArgumentException, match='Refuse to add'):
         swh_storage.directory_add([sample_dir])

     assert mock_memory.call_count == 1
@@ -742,14 +752,15 @@
     """
     mock_memory = mocker.patch(
         'swh.storage.in_memory.InMemoryStorage.revision_add')
-    mock_memory.side_effect = ValueError('Refuse to add revision always!')
+    mock_memory.side_effect = StorageArgumentException(
+        'Refuse to add revision always!')

     sample_rev = sample_data['revision'][0]

     revision = next(swh_storage.revision_get([sample_rev['id']]))
     assert not revision

-    with pytest.raises(ValueError, match='Refuse to add'):
+    with pytest.raises(StorageArgumentException, match='Refuse to add'):
         swh_storage.revision_add([sample_rev])

     assert mock_memory.call_count == 1
@@ -813,14 +824,15 @@
     """
     mock_memory = mocker.patch(
         'swh.storage.in_memory.InMemoryStorage.release_add')
-    mock_memory.side_effect = ValueError('Refuse to add release always!')
+    mock_memory.side_effect = StorageArgumentException(
+        'Refuse to add release always!')

     sample_rel = sample_data['release'][0]

     release = next(swh_storage.release_get([sample_rel['id']]))
     assert not release

-    with pytest.raises(ValueError, match='Refuse to add'):
+    with pytest.raises(StorageArgumentException, match='Refuse to add'):
         swh_storage.release_add([sample_rel])

     assert mock_memory.call_count == 1
@@ -884,14 +896,15 @@
     """
     mock_memory = mocker.patch(
         'swh.storage.in_memory.InMemoryStorage.snapshot_add')
-    mock_memory.side_effect = ValueError('Refuse to add snapshot always!')
+    mock_memory.side_effect = StorageArgumentException(
+        'Refuse to add snapshot always!')

     sample_snap = sample_data['snapshot'][0]

     snapshot = swh_storage.snapshot_get(sample_snap['id'])
     assert not snapshot

-    with pytest.raises(ValueError, match='Refuse to add'):
+    with pytest.raises(StorageArgumentException, match='Refuse to add'):
         swh_storage.snapshot_add([sample_snap])

     assert mock_memory.call_count == 1
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -29,6 +29,7 @@
 from swh.model.hypothesis_strategies import objects
 from swh.storage import HashCollision
 from swh.storage.converters import origin_url_to_sha1 as sha1
+from swh.storage.exc import StorageArgumentException
 from swh.storage.interface import StorageInterface

 from .storage_data import data
@@ -175,29 +176,26 @@
     def test_content_add_validation(self, swh_storage):
         cont = data.cont

-        with pytest.raises(ValueError, match='status'):
+        with pytest.raises(StorageArgumentException, match='status'):
             swh_storage.content_add([{**cont, 'status': 'absent'}])

-        with pytest.raises(ValueError, match='status'):
+        with pytest.raises(StorageArgumentException, match='status'):
             swh_storage.content_add([{**cont, 'status': 'foobar'}])

-        with pytest.raises(ValueError, match="(?i)length"):
+        with pytest.raises(StorageArgumentException, match="(?i)length"):
             swh_storage.content_add([{**cont, 'length': -2}])

-        with pytest.raises(
-                (ValueError, TypeError),
-                match="reason"):
+        with pytest.raises(StorageArgumentException, match="reason"):
             swh_storage.content_add([{**cont, 'reason': 'foobar'}])

     def test_skipped_content_add_validation(self, swh_storage):
         cont = data.cont.copy()
         del cont['data']

-        with pytest.raises(ValueError, match='status'):
+        with pytest.raises(StorageArgumentException, match='status'):
             swh_storage.skipped_content_add([{**cont, 'status': 'visible'}])

-        with pytest.raises((ValueError, psycopg2.IntegrityError),
-                           match='reason') as cm:
+        with pytest.raises(StorageArgumentException, match='reason') as cm:
             swh_storage.skipped_content_add([{**cont, 'status': 'absent'}])

         if type(cm.value) == psycopg2.IntegrityError:
@@ -480,10 +478,10 @@
     def test_content_get_partition_limit_none(self, swh_storage):
         """content_get_partition call with wrong limit input should fail"""
-        with pytest.raises(ValueError) as e:
+        with pytest.raises(StorageArgumentException) as e:
             swh_storage.content_get_partition(1, 16, limit=None)

-        assert e.value.args == ('Development error: limit should not be None',)
+        assert e.value.args == ('limit should not be None',)

     def test_generate_content_get_partition_pagination(
             self, swh_storage, swh_contents):
@@ -582,14 +580,13 @@
         dir_ = copy.deepcopy(data.dir)
         dir_['entries'][0]['type'] = 'foobar'

-        with pytest.raises(ValueError, match='type.*foobar'):
+        with pytest.raises(StorageArgumentException, match='type.*foobar'):
             swh_storage.directory_add([dir_])

         dir_ = copy.deepcopy(data.dir)
         del dir_['entries'][0]['target']

-        with pytest.raises((TypeError, psycopg2.IntegrityError),
-                           match='target') as cm:
+        with pytest.raises(StorageArgumentException, match='target') as cm:
             swh_storage.directory_add([dir_])

         if type(cm.value) == psycopg2.IntegrityError:
@@ -789,8 +786,7 @@
         rev = copy.deepcopy(data.revision)
         rev['date']['offset'] = 2**16

-        with pytest.raises((ValueError, psycopg2.DataError),
-                           match='offset') as cm:
+        with pytest.raises(StorageArgumentException, match='offset') as cm:
             swh_storage.revision_add([rev])

         if type(cm.value) == psycopg2.DataError:
@@ -800,8 +796,7 @@
         rev = copy.deepcopy(data.revision)
         rev['committer_date']['offset'] = 2**16

-        with pytest.raises((ValueError, psycopg2.DataError),
-                           match='offset') as cm:
+        with pytest.raises(StorageArgumentException, match='offset') as cm:
             swh_storage.revision_add([rev])

         if type(cm.value) == psycopg2.DataError:
@@ -811,8 +806,7 @@
         rev = copy.deepcopy(data.revision)
         rev['type'] = 'foobar'

-        with pytest.raises((ValueError, psycopg2.DataError),
-                           match='(?i)type') as cm:
+        with pytest.raises(StorageArgumentException, match='(?i)type') as cm:
             swh_storage.revision_add([rev])

         if type(cm.value) == psycopg2.DataError:
@@ -1011,8 +1005,7 @@
         rel = copy.deepcopy(data.release)
         rel['date']['offset'] = 2**16

-        with pytest.raises((ValueError, psycopg2.DataError),
-                           match='offset') as cm:
+        with pytest.raises(StorageArgumentException, match='offset') as cm:
             swh_storage.release_add([rel])

         if type(cm.value) == psycopg2.DataError:
@@ -1022,8 +1015,7 @@
         rel = copy.deepcopy(data.release)
         rel['author'] = None

-        with pytest.raises((ValueError, psycopg2.IntegrityError),
-                           match='date') as cm:
+        with pytest.raises(StorageArgumentException, match='date') as cm:
             swh_storage.release_add([rel])

         if type(cm.value) == psycopg2.IntegrityError:
@@ -1170,7 +1162,7 @@
         assert add1 == add2

     def test_origin_add_validation(self, swh_storage):
-        with pytest.raises((TypeError, KeyError), match='url'):
+        with pytest.raises(StorageArgumentException, match='url'):
             swh_storage.origin_add([{'type': 'git'}])

     def test_origin_get_legacy(self, swh_storage):
@@ -1491,8 +1483,8 @@
     def test_origin_visit_add_validation(self, swh_storage):
         origin_url = swh_storage.origin_add_one(data.origin2)

-        with pytest.raises((TypeError, psycopg2.ProgrammingError)) as cm:
-            swh_storage.origin_visit_add(origin_url, date=[b'foo'])
+        with pytest.raises(StorageArgumentException) as cm:
+            swh_storage.origin_visit_add(origin_url, date=[b'foo'], type='git')

         if type(cm.value) == psycopg2.ProgrammingError:
             assert cm.value.pgcode \
@@ -1673,8 +1665,7 @@
             type=data.type_visit2,
         )

-        with pytest.raises((ValueError, psycopg2.DataError),
-                           match='status') as cm:
+        with pytest.raises(StorageArgumentException, match='status') as cm:
             swh_storage.origin_visit_update(
                 origin_url, visit['visit'], status='foobar')

@@ -2199,13 +2190,13 @@
         snap = copy.deepcopy(data.snapshot)
         snap['branches'][b'foo'] = {'target_type': 'revision'}

-        with pytest.raises(KeyError, match='target'):
+        with pytest.raises(StorageArgumentException, match='target'):
             swh_storage.snapshot_add([snap])

         snap = copy.deepcopy(data.snapshot)
         snap['branches'][b'foo'] = {'target': b'\x42'*20}

-        with pytest.raises(KeyError, match='target_type'):
+        with pytest.raises(StorageArgumentException, match='target_type'):
             swh_storage.snapshot_add([snap])

     def test_snapshot_add_count_branches(self, swh_storage):
@@ -2431,7 +2422,7 @@
         swh_storage.snapshot_add([data.snapshot])

-        with pytest.raises(ValueError):
+        with pytest.raises(StorageArgumentException):
             swh_storage.origin_visit_update(
                 origin_url, visit_id, snapshot=data.snapshot['id'])

@@ -2610,9 +2601,9 @@
         swh_storage.origin_visit_update(
             origin_url, visit1_id, snapshot=data.complete_snapshot['id'])

-        with pytest.raises(ValueError):
-            swh_storage.snapshot_get_latest(
-                origin_url)
+        with pytest.raises(Exception):
+            # XXX: should the exception be more specific than this?
+            swh_storage.snapshot_get_latest(origin_url)

         # Status filter: both visits are status=ongoing, so no snapshot
         # returned
@@ -2624,7 +2615,8 @@
         swh_storage.origin_visit_update(
             origin_url, visit1_id, status='full')

-        with pytest.raises(ValueError):
+        with pytest.raises(Exception):
+            # XXX: should the exception be more specific than this?
             swh_storage.snapshot_get_latest(
                 origin_url,
                 allowed_statuses=['full']),

@@ -2639,7 +2631,8 @@
         swh_storage.origin_visit_update(
             origin_url, visit2_id, snapshot=data.snapshot['id'])

-        with pytest.raises(ValueError):
+        with pytest.raises(Exception):
+            # XXX: should the exception be more specific than this?
             swh_storage.snapshot_get_latest(
                 origin_url)

@@ -2946,11 +2939,11 @@
     def test_content_find_bad_input(self, swh_storage):
         # 1. with bad input
-        with pytest.raises(ValueError):
+        with pytest.raises(StorageArgumentException):
             swh_storage.content_find({})  # empty is bad

         # 2. with bad input
-        with pytest.raises(ValueError):
+        with pytest.raises(StorageArgumentException):
             swh_storage.content_find(
                 {'unknown-sha1': 'something'})  # not the right key

@@ -3407,10 +3400,10 @@
     def test_generate_content_get_range_limit_none(self, swh_storage):
         """content_get_range call with wrong limit input should fail"""
-        with pytest.raises(ValueError) as e:
+        with pytest.raises(StorageArgumentException) as e:
             swh_storage.content_get_range(start=None, end=None, limit=None)

-        assert e.value.args == ('Development error: limit should not be None',)
+        assert e.value.args == ('limit should not be None',)

     def test_generate_content_get_range_no_limit(
             self, swh_storage, swh_contents):
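
Reviewer note, not part of the patch: a minimal sketch of the retry behaviour this diff introduces, for trying it out in isolation. It assumes tenacity >= 5.0, where a retry predicate receives a RetryCallState whose `outcome` is a future-like object exposing `.failed` and `.exception()` (the same API used in swh/storage/retry.py above); `flaky_add` is a hypothetical stand-in for an endpoint such as content_add.

    # Illustrative sketch only -- mirrors the predicate added in retry.py.
    from tenacity import retry, stop_after_attempt, wait_random_exponential

    from swh.storage.exc import StorageArgumentException

    def should_retry_adding(retry_state) -> bool:
        """Retry unless the call failed with a caller (argument) error."""
        if not retry_state.outcome.failed:
            return False  # the call succeeded: nothing to retry
        error = retry_state.outcome.exception()
        # StorageArgumentException means the caller sent invalid data;
        # retrying with the same arguments can never succeed.
        return not isinstance(error, StorageArgumentException)

    @retry(retry=should_retry_adding,
           wait=wait_random_exponential(multiplier=1, max=10),
           stop=stop_after_attempt(3))
    def flaky_add(batch):  # hypothetical endpoint wrapped like content_add
        ...

The point of moving from a RETRY_EXCEPTIONS allow-list to this deny-list predicate is that any transient error (server restart, integrity violation from a parallel insert, hash collision) stays retryable by default, while argument errors fail fast.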