D2628: Unify exception raised by invalid input to API endpoints.
D2628.id9397.diff (43 KB)
diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -11,14 +11,19 @@
Any, Callable, Dict, Generator, Iterable, List, Optional, TypeVar
)
+from cassandra import CoordinationFailure
from cassandra.cluster import (
Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile, ResultSet)
from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy
from cassandra.query import PreparedStatement
-from tenacity import retry, stop_after_attempt, wait_random_exponential
+from tenacity import (
+ retry, stop_after_attempt, wait_random_exponential,
+ retry_if_exception_type,
+)
from swh.model.model import (
Sha1Git, TimestampWithTimezone, Timestamp, Person, Content,
+ OriginVisit
)
from .common import Row, TOKEN_BEGIN, TOKEN_END, hash_url
@@ -120,7 +125,8 @@
MAX_RETRIES = 3
@retry(wait=wait_random_exponential(multiplier=1, max=10),
- stop=stop_after_attempt(MAX_RETRIES))
+ stop=stop_after_attempt(MAX_RETRIES),
+ retry=retry_if_exception_type(CoordinationFailure))
def _execute_with_retries(self, statement, args) -> ResultSet:
return self._session.execute(statement, args, timeout=1000.)
@@ -528,10 +534,9 @@
@_prepared_insert_statement('origin_visit', _origin_visit_keys)
def origin_visit_add_one(
- self, visit: Dict[str, Any], *, statement) -> None:
- self._execute_with_retries(
- statement, [visit[key] for key in self._origin_visit_keys])
- self._increment_counter('origin_visit', 1)
+ self, visit: OriginVisit, *, statement) -> None:
+ self._add_one(statement, 'origin_visit', visit,
+ self._origin_visit_keys)
@_prepared_statement(
'UPDATE origin_visit SET ' +
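The cql.py hunk above narrows the generic retry around _execute_with_retries so that only Cassandra coordination failures are retried; errors caused by bad arguments now propagate immediately. A minimal standalone sketch of that policy, separate from the diff itself and reusing the names and MAX_RETRIES value from the patch:

from cassandra import CoordinationFailure
from tenacity import (
    retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential,
)

MAX_RETRIES = 3

@retry(wait=wait_random_exponential(multiplier=1, max=10),
       stop=stop_after_attempt(MAX_RETRIES),
       retry=retry_if_exception_type(CoordinationFailure))
def execute_with_retries(session, statement, args):
    # Only CoordinationFailure triggers a new attempt; anything else
    # (e.g. an InvalidRequest raised for bad arguments) is raised as-is.
    return session.execute(statement, args, timeout=1000.)

Before this change the decorator had no retry= predicate, so tenacity retried every exception, including ones caused by invalid input.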
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -15,7 +15,7 @@
from swh.model.model import (
Revision, Release, Directory, DirectoryEntry, Content, SkippedContent,
- OriginVisit,
+ OriginVisit, Snapshot
)
from swh.objstorage import get_objstorage
from swh.objstorage.exc import ObjNotFoundError
@@ -26,6 +26,8 @@
# mypy limitation, see https://github.com/python/mypy/issues/1153
+from .. import HashCollision
+from ..exc import StorageArgumentException
from .common import TOKEN_BEGIN, TOKEN_END
from .converters import (
revision_to_db, revision_from_db, release_to_db, release_from_db,
@@ -60,7 +62,10 @@
return True
def _content_add(self, contents, with_data):
- contents = [Content.from_dict(c) for c in contents]
+ try:
+ contents = [Content.from_dict(c) for c in contents]
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
# Filter-out content already in the database.
contents = [c for c in contents
@@ -112,7 +117,6 @@
algo, content.get_hash(algo))
if len(pks) > 1:
# There are more than the one we just inserted.
- from .. import HashCollision
raise HashCollision(algo, content.get_hash(algo), pks)
summary = {
@@ -139,7 +143,7 @@
def content_get(self, content):
if len(content) > BULK_BLOCK_CONTENT_LEN_MAX:
- raise ValueError(
+ raise StorageArgumentException(
"Sending at most %s contents." % BULK_BLOCK_CONTENT_LEN_MAX)
for obj_id in content:
try:
@@ -154,7 +158,7 @@
self, partition_id: int, nb_partitions: int, limit: int = 1000,
page_token: str = None):
if limit is None:
- raise ValueError('Development error: limit should not be None')
+ raise StorageArgumentException('limit should not be None')
# Compute start and end of the range of tokens covered by the
# requested partition
@@ -165,7 +169,7 @@
# offset the range start according to the `page_token`.
if page_token is not None:
if not (range_start <= int(page_token) <= range_end):
- raise ValueError('Invalid page_token.')
+ raise StorageArgumentException('Invalid page_token.')
range_start = int(page_token)
# Get the first rows of the range
@@ -213,8 +217,9 @@
# It will be used to do an initial filtering efficiently.
filter_algos = list(set(content).intersection(HASH_ALGORITHMS))
if not filter_algos:
- raise ValueError('content keys must contain at least one of: '
- '%s' % ', '.join(sorted(HASH_ALGORITHMS)))
+ raise StorageArgumentException(
+ 'content keys must contain at least one of: '
+ '%s' % ', '.join(sorted(HASH_ALGORITHMS)))
common_algo = filter_algos[0]
# Find all contents whose common_algo matches at least one
@@ -260,7 +265,10 @@
return self._cql_runner.content_get_random().sha1_git
def _skipped_content_add(self, contents):
- contents = [SkippedContent.from_dict(c) for c in contents]
+ try:
+ contents = [SkippedContent.from_dict(c) for c in contents]
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
# Filter-out content already in the database.
contents = [c for c in contents
@@ -306,7 +314,10 @@
self.journal_writer.write_additions('directory', directories)
for directory in directories:
- directory = Directory.from_dict(directory)
+ try:
+ directory = Directory.from_dict(directory)
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
# Add directory entries to the 'directory_entry' table
for entry in directory.entries:
@@ -412,7 +423,10 @@
self.journal_writer.write_additions('revision', revisions)
for revision in revisions:
- revision = revision_to_db(revision)
+ try:
+ revision = revision_to_db(revision)
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
if revision:
# Add parents first
@@ -507,7 +521,10 @@
self.journal_writer.write_additions('release', releases)
for release in releases:
- release = release_to_db(release)
+ try:
+ release = release_to_db(release)
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
if release:
self._cql_runner.release_add_one(release)
@@ -532,30 +549,38 @@
return self._cql_runner.release_get_random().id
def snapshot_add(self, snapshots):
- snapshots = list(snapshots)
+ try:
+ snapshots = [Snapshot.from_dict(snap) for snap in snapshots]
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
missing = self._cql_runner.snapshot_missing(
- [snp['id'] for snp in snapshots])
- snapshots = [snp for snp in snapshots if snp['id'] in missing]
+ [snp.id for snp in snapshots])
+ snapshots = [snp for snp in snapshots if snp.id in missing]
for snapshot in snapshots:
if self.journal_writer:
self.journal_writer.write_addition('snapshot', snapshot)
# Add branches
- for (branch_name, branch) in snapshot['branches'].items():
+ for (branch_name, branch) in snapshot.branches.items():
if branch is None:
- branch = {'target_type': None, 'target': None}
- self._cql_runner.snapshot_branch_add_one({
- 'snapshot_id': snapshot['id'],
+ target_type = None
+ target = None
+ else:
+ target_type = branch.target_type.value
+ target = branch.target
+ branch = {
+ 'snapshot_id': snapshot.id,
'name': branch_name,
- 'target_type': branch['target_type'],
- 'target': branch['target'],
- })
+ 'target_type': target_type,
+ 'target': target,
+ }
+ self._cql_runner.snapshot_branch_add_one(branch)
# Add the snapshot *after* adding all the branches, so someone
# calling snapshot_get_branch in the meantime won't end up
# with half the branches.
- self._cql_runner.snapshot_add_one(snapshot['id'])
+ self._cql_runner.snapshot_add_one(snapshot.id)
return {'snapshot:add': len(snapshots)}
@@ -582,7 +607,8 @@
if visit:
assert visit['snapshot']
if self._cql_runner.snapshot_missing([visit['snapshot']]):
- raise ValueError('Visit references unknown snapshot')
+ raise StorageArgumentException(
+ 'Visit references unknown snapshot')
return self.snapshot_get_branches(visit['snapshot'])
def snapshot_count_branches(self, snapshot_id):
@@ -688,7 +714,7 @@
return_single = False
if any('id' in origin for origin in origins):
- raise ValueError('Origin ids are not supported.')
+ raise StorageArgumentException('Origin ids are not supported.')
results = [self.origin_get_one(origin) for origin in origins]
@@ -700,7 +726,9 @@
def origin_get_one(self, origin):
if 'id' in origin:
- raise ValueError('Origin ids are not supported.')
+ raise StorageArgumentException('Origin ids are not supported.')
+ if 'url' not in origin:
+ raise StorageArgumentException('Missing origin url')
rows = self._cql_runner.origin_get_by_url(origin['url'])
rows = list(rows)
@@ -730,7 +758,7 @@
if page_token:
start_token = int(page_token)
if not (TOKEN_BEGIN <= start_token <= TOKEN_END):
- raise ValueError('Invalid page_token.')
+ raise StorageArgumentException('Invalid page_token.')
rows = self._cql_runner.origin_list(start_token, limit)
rows = list(rows)
@@ -768,7 +796,8 @@
def origin_add(self, origins):
origins = list(origins)
if any('id' in origin for origin in origins):
- raise ValueError('Origins must not already have an id.')
+ raise StorageArgumentException(
+ 'Origins must not already have an id.')
results = []
for origin in origins:
self.origin_add_one(origin)
@@ -815,6 +844,11 @@
if self.journal_writer:
self.journal_writer.write_addition('origin_visit', visit)
+ try:
+ visit = OriginVisit.from_dict(visit)
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
+
self._cql_runner.origin_visit_add_one(visit)
return {
@@ -829,8 +863,11 @@
# Get the existing data of the visit
row = self._cql_runner.origin_visit_get_one(origin_url, visit_id)
if not row:
- raise ValueError('This origin visit does not exist.')
- visit = OriginVisit.from_dict(self._format_origin_visit_row(row))
+ raise StorageArgumentException('This origin visit does not exist.')
+ try:
+ visit = OriginVisit.from_dict(self._format_origin_visit_row(row))
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
updates = {}
if status:
@@ -840,7 +877,10 @@
if snapshot:
updates['snapshot'] = snapshot
- visit = attr.evolve(visit, **updates)
+ try:
+ visit = attr.evolve(visit, **updates)
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
if self.journal_writer:
self.journal_writer.write_update('origin_visit', visit)
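The rewritten snapshot_add above converts incoming dicts into swh.model Snapshot objects before writing one row per branch. A short sketch of the rows it builds, with hypothetical input data and assuming Snapshot.from_dict accepts the same dict shape the old API took:

from swh.model.model import Snapshot

snapshot = Snapshot.from_dict({
    'id': b'\x01' * 20,
    'branches': {
        b'HEAD': {'target_type': 'revision', 'target': b'\x02' * 20},
        b'dangling': None,  # dangling branches keep None target fields
    },
})

branch_rows = [
    {
        'snapshot_id': snapshot.id,
        'name': name,
        'target_type': branch.target_type.value if branch else None,
        'target': branch.target if branch else None,
    }
    for name, branch in snapshot.branches.items()
]
# Each dict mirrors what snapshot_add now passes to snapshot_branch_add_one().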
diff --git a/swh/storage/exc.py b/swh/storage/exc.py
--- a/swh/storage/exc.py
+++ b/swh/storage/exc.py
@@ -21,3 +21,8 @@
def __str__(self):
args = self.args
return 'An unexpected error occurred in the api backend: %s' % args
+
+
+class StorageArgumentException(Exception):
+ """Argument passed to a Storage endpoint is invalid."""
+ pass
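The new exception class above is the single type every backend now raises for invalid input. A condensed sketch of the wrap-and-reraise pattern the rest of this diff applies wherever client dicts are turned into model objects (the helper name is illustrative, not from the patch):

from swh.model.model import Content
from swh.storage.exc import StorageArgumentException

def contents_from_dicts(dicts):
    try:
        return [Content.from_dict(d) for d in dicts]
    except (KeyError, TypeError, ValueError) as e:
        # Validation failures from the model layer surface to callers as
        # StorageArgumentException, whatever the storage backend.
        raise StorageArgumentException(*e.args)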
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -25,6 +25,8 @@
from swh.objstorage import get_objstorage
from swh.objstorage.exc import ObjNotFoundError
+from . import HashCollision
+from .exc import StorageArgumentException
from .storage import get_journal_writer
from .converters import origin_url_to_sha1
from .utils import get_partition_bounds_bytes
@@ -79,13 +81,16 @@
if content.status is None:
content.status = 'visible'
if content.status == 'absent':
- raise ValueError('content with status=absent')
+ raise StorageArgumentException('content with status=absent')
if content.length is None:
- raise ValueError('content with length=None')
+ raise StorageArgumentException('content with length=None')
if self.journal_writer:
for content in contents:
- content = attr.evolve(content, data=None)
+ try:
+ content = attr.evolve(content, data=None)
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
self.journal_writer.write_addition('content', content)
summary = {
@@ -103,7 +108,6 @@
hash_ = content.get_hash(algorithm)
if hash_ in self._content_indexes[algorithm]\
and (algorithm not in {'blake2s256', 'sha256'}):
- from . import HashCollision
raise HashCollision(algorithm, hash_, key)
for algorithm in DEFAULT_ALGORITHMS:
hash_ = content.get_hash(algorithm)
@@ -115,9 +119,12 @@
summary['content:add'] += 1
if with_data:
content_data = self._contents[key].data
- self._contents[key] = attr.evolve(
- self._contents[key],
- data=None)
+ try:
+ self._contents[key] = attr.evolve(
+ self._contents[key],
+ data=None)
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
summary['content:add:bytes'] += len(content_data)
self.objstorage.add(content_data, content.sha1)
@@ -125,8 +132,11 @@
def content_add(self, content):
now = datetime.datetime.now(tz=datetime.timezone.utc)
- content = [attr.evolve(Content.from_dict(c), ctime=now)
- for c in content]
+ try:
+ content = [attr.evolve(Content.from_dict(c), ctime=now)
+ for c in content]
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
return self._content_add(content, with_data=True)
def content_update(self, content, keys=[]):
@@ -144,7 +154,10 @@
hash_ = old_cont.get_hash(algorithm)
self._content_indexes[algorithm][hash_].remove(old_key)
- new_cont = attr.evolve(old_cont, **cont_update)
+ try:
+ new_cont = attr.evolve(old_cont, **cont_update)
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
new_key = self._content_key(new_cont)
self._contents[new_key] = new_cont
@@ -160,7 +173,7 @@
def content_get(self, content):
# FIXME: Make this method support slicing the `data`.
if len(content) > BULK_BLOCK_CONTENT_LEN_MAX:
- raise ValueError(
+ raise StorageArgumentException(
"Sending at most %s contents." % BULK_BLOCK_CONTENT_LEN_MAX)
for obj_id in content:
try:
@@ -173,7 +186,7 @@
def content_get_range(self, start, end, limit=1000):
if limit is None:
- raise ValueError('Development error: limit should not be None')
+ raise StorageArgumentException('limit should not be None')
from_index = bisect.bisect_left(self._sorted_sha1s, start)
sha1s = itertools.islice(self._sorted_sha1s, from_index, None)
sha1s = ((sha1, content_key)
@@ -197,7 +210,7 @@
self, partition_id: int, nb_partitions: int, limit: int = 1000,
page_token: str = None):
if limit is None:
- raise ValueError('Development error: limit should not be None')
+ raise StorageArgumentException('limit should not be None')
(start, end) = get_partition_bounds_bytes(
partition_id, nb_partitions, SHA1_SIZE)
if page_token:
@@ -231,8 +244,9 @@
def content_find(self, content):
if not set(content).intersection(DEFAULT_ALGORITHMS):
- raise ValueError('content keys must contain at least one of: '
- '%s' % ', '.join(sorted(DEFAULT_ALGORITHMS)))
+ raise StorageArgumentException(
+ 'content keys must contain at least one of: %s'
+ % ', '.join(sorted(DEFAULT_ALGORITHMS)))
found = []
for algo in DEFAULT_ALGORITHMS:
hash = content.get(algo)
@@ -278,7 +292,8 @@
if content.length is None:
content = attr.evolve(content, length=-1)
if content.status != 'absent':
- raise ValueError(f'Content with status={content.status}')
+ raise StorageArgumentException(
+ f'Content with status={content.status}')
if self.journal_writer:
for content in contents:
@@ -318,8 +333,11 @@
def skipped_content_add(self, content):
content = list(content)
now = datetime.datetime.now(tz=datetime.timezone.utc)
- content = [attr.evolve(SkippedContent.from_dict(c), ctime=now)
- for c in content]
+ try:
+ content = [attr.evolve(SkippedContent.from_dict(c), ctime=now)
+ for c in content]
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
return self._skipped_content_add(content)
def directory_add(self, directories):
@@ -330,7 +348,10 @@
(dir_ for dir_ in directories
if dir_['id'] not in self._directories))
- directories = [Directory.from_dict(d) for d in directories]
+ try:
+ directories = [Directory.from_dict(d) for d in directories]
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
count = 0
for directory in directories:
@@ -423,7 +444,10 @@
(rev for rev in revisions
if rev['id'] not in self._revisions))
- revisions = [Revision.from_dict(rev) for rev in revisions]
+ try:
+ revisions = [Revision.from_dict(rev) for rev in revisions]
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
count = 0
for revision in revisions:
@@ -481,7 +505,10 @@
(rel for rel in releases
if rel['id'] not in self._releases))
- releases = [Release.from_dict(rel) for rel in releases]
+ try:
+ releases = [Release.from_dict(rel) for rel in releases]
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
count = 0
for rel in releases:
@@ -510,7 +537,10 @@
def snapshot_add(self, snapshots):
count = 0
- snapshots = (Snapshot.from_dict(d) for d in snapshots)
+ try:
+ snapshots = [Snapshot.from_dict(d) for d in snapshots]
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
snapshots = (snap for snap in snapshots
if snap.id not in self._snapshots)
for snapshot in snapshots:
@@ -558,7 +588,7 @@
if visit and visit['snapshot']:
snapshot = self.snapshot_get(visit['snapshot'])
if not snapshot:
- raise ValueError(
+ raise StorageArgumentException(
'last origin visit references an unknown snapshot')
return snapshot
@@ -636,11 +666,11 @@
# Sanity check to be error-compatible with the pgsql backend
if any('id' in origin for origin in origins) \
and not all('id' in origin for origin in origins):
- raise ValueError(
+ raise StorageArgumentException(
'Either all origins or none at all should have an "id".')
if any('url' in origin for origin in origins) \
and not all('url' in origin for origin in origins):
- raise ValueError(
+ raise StorageArgumentException(
'Either all origins or none at all should have '
'an "url" key.')
@@ -651,7 +681,7 @@
if origin['url'] in self._origins:
result = self._origins[origin['url']]
else:
- raise ValueError(
+ raise StorageArgumentException(
'Origin must have an url.')
results.append(self._convert_origin(result))
@@ -727,7 +757,10 @@
return origins
def origin_add_one(self, origin):
- origin = Origin.from_dict(origin)
+ try:
+ origin = Origin.from_dict(origin)
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
if origin.url not in self._origins:
if self.journal_writer:
self.journal_writer.write_addition('origin', origin)
@@ -748,13 +781,14 @@
def origin_visit_add(self, origin, date, type):
origin_url = origin
if origin_url is None:
- raise ValueError('Unknown origin.')
+ raise StorageArgumentException('Unknown origin.')
if isinstance(date, str):
# FIXME: Converge on iso8601 at some point
date = dateutil.parser.parse(date)
elif not isinstance(date, datetime.datetime):
- raise TypeError('date must be a datetime or a string.')
+ raise StorageArgumentException(
+ 'date must be a datetime or a string.')
visit_ret = None
if origin_url in self._origins:
@@ -791,13 +825,13 @@
raise TypeError('origin must be a string, not %r' % (origin,))
origin_url = self._get_origin_url(origin)
if origin_url is None:
- raise ValueError('Unknown origin.')
+ raise StorageArgumentException('Unknown origin.')
try:
visit = self._origin_visits[origin_url][visit_id-1]
except IndexError:
- raise ValueError('Unknown visit_id for this origin') \
- from None
+ raise StorageArgumentException(
+ 'Unknown visit_id for this origin') from None
updates = {}
if status:
@@ -807,7 +841,10 @@
if snapshot:
updates['snapshot'] = snapshot
- visit = attr.evolve(visit, **updates)
+ try:
+ visit = attr.evolve(visit, **updates)
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
if self.journal_writer:
self.journal_writer.write_update('origin_visit', visit)
@@ -819,7 +856,10 @@
if not isinstance(visit['origin'], str):
raise TypeError("visit['origin'] must be a string, not %r"
% (visit['origin'],))
- visits = [OriginVisit.from_dict(d) for d in visits]
+ try:
+ visits = [OriginVisit.from_dict(d) for d in visits]
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
if self.journal_writer:
for visit in visits:
@@ -829,7 +869,10 @@
visit_id = visit.visit
origin_url = visit.origin
- visit = attr.evolve(visit, origin=origin_url)
+ try:
+ visit = attr.evolve(visit, origin=origin_url)
+ except (KeyError, TypeError, ValueError) as e:
+ raise StorageArgumentException(*e.args)
self._objects[(origin_url, visit_id)].append(
('origin_visit', None))
diff --git a/swh/storage/retry.py b/swh/storage/retry.py
--- a/swh/storage/retry.py
+++ b/swh/storage/retry.py
@@ -4,46 +4,36 @@
# See top-level LICENSE file for more information
import logging
-import psycopg2
import traceback
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Union
-from requests.exceptions import ConnectionError
from tenacity import (
- retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type
+ retry, stop_after_attempt, wait_random_exponential,
)
-from swh.storage import get_storage, HashCollision
+from swh.storage import get_storage
+from swh.storage.exc import StorageArgumentException
logger = logging.getLogger(__name__)
-RETRY_EXCEPTIONS = [
- # raised when two parallel insertions insert the same data
- psycopg2.IntegrityError,
- HashCollision,
- # when the server is restarting
- ConnectionError,
-]
-
-
def should_retry_adding(error: Exception) -> bool:
- """Retry if the error/exception if one of the RETRY_EXCEPTIONS type.
+ """Retry if the error/exception is (probably) not about a caller error
"""
- for exc in RETRY_EXCEPTIONS:
- if retry_if_exception_type(exc)(error):
- error_name = error.__module__ + '.' + error.__class__.__name__
- logger.warning('Retry adding a batch', exc_info=False, extra={
- 'swh_type': 'storage_retry',
- 'swh_exception_type': error_name,
- 'swh_exception': traceback.format_exc(),
- })
- return True
- return False
+ if isinstance(error, StorageArgumentException):
+ return False
+ else:
+ error_name = error.__module__ + '.' + error.__class__.__name__
+ logger.warning('Retry adding a batch', exc_info=False, extra={
+ 'swh_type': 'storage_retry',
+ 'swh_exception_type': error_name,
+ 'swh_exception': traceback.format_exc(),
+ })
+ return True
swh_retry = retry(retry=should_retry_adding,
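With the unified exception in place, the retry proxy above no longer needs a whitelist of retryable exception types: it retries everything except StorageArgumentException. A simplified sketch of the resulting behaviour (logging omitted; add_with_retries is a hypothetical stand-in for the tenacity wiring in swh.storage.retry):

from swh.storage.exc import StorageArgumentException

def should_retry_adding(error: Exception) -> bool:
    # Caller errors are never retried; anything else is assumed transient.
    return not isinstance(error, StorageArgumentException)

def add_with_retries(add_fn, objects, max_attempts=3):
    for attempt in range(1, max_attempts + 1):
        try:
            return add_fn(objects)
        except Exception as error:
            if not should_retry_adding(error) or attempt == max_attempts:
                raise

A HashCollision or a dropped connection is still retried up to max_attempts times, while an invalid content dict now fails on the first call.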
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -27,10 +27,10 @@
get_journal_writer = None # type: ignore
# mypy limitation, see https://github.com/python/mypy/issues/1153
-from . import converters
+from . import converters, HashCollision
from .common import db_transaction_generator, db_transaction
from .db import Db
-from .exc import StorageDBError
+from .exc import StorageArgumentException, StorageDBError
from .algos import diff
from .metrics import timed, send_metric, process_metrics
from .utils import get_partition_bounds_bytes
@@ -141,14 +141,15 @@
"""Sanity checks on status / reason / length, that postgresql
doesn't enforce."""
if d['status'] not in ('visible', 'hidden'):
- raise ValueError('Invalid content status: {}'.format(d['status']))
+ raise StorageArgumentException(
+ 'Invalid content status: {}'.format(d['status']))
if d.get('reason') is not None:
- raise ValueError(
+ raise StorageArgumentException(
'Must not provide a reason if content is present.')
if d['length'] is None or d['length'] < 0:
- raise ValueError('Content length must be positive.')
+ raise StorageArgumentException('Content length must be positive.')
def _content_add_metadata(self, db, cur, content):
"""Add content to the postgresql database but not the object storage.
@@ -163,7 +164,6 @@
try:
db.content_add_from_temp(cur)
except psycopg2.IntegrityError as e:
- from . import HashCollision
if e.diag.sqlstate == '23505' and \
e.diag.table_name == 'content':
constraint_to_hash_name = {
@@ -277,7 +277,7 @@
def content_get(self, content):
# FIXME: Make this method support slicing the `data`.
if len(content) > BULK_BLOCK_CONTENT_LEN_MAX:
- raise ValueError(
+ raise StorageArgumentException(
"Send at maximum %s contents." % BULK_BLOCK_CONTENT_LEN_MAX)
for obj_id in content:
@@ -293,7 +293,7 @@
@db_transaction()
def content_get_range(self, start, end, limit=1000, db=None, cur=None):
if limit is None:
- raise ValueError('Development error: limit should not be None')
+ raise StorageArgumentException('limit should not be None')
contents = []
next_content = None
for counter, content_row in enumerate(
@@ -315,7 +315,7 @@
self, partition_id: int, nb_partitions: int, limit: int = 1000,
page_token: str = None, db=None, cur=None):
if limit is None:
- raise ValueError('Development error: limit should not be None')
+ raise StorageArgumentException('limit should not be None')
(start, end) = get_partition_bounds_bytes(
partition_id, nb_partitions, SHA1_SIZE)
if page_token:
@@ -348,7 +348,8 @@
keys = db.content_hash_keys
if key_hash not in keys:
- raise ValueError("key_hash should be one of %s" % keys)
+ raise StorageArgumentException(
+ "key_hash should be one of %s" % keys)
key_hash_idx = keys.index(key_hash)
@@ -374,8 +375,9 @@
@db_transaction()
def content_find(self, content, db=None, cur=None):
if not set(content).intersection(ALGORITHMS):
- raise ValueError('content keys must contain at least one of: '
- 'sha1, sha1_git, sha256, blake2s256')
+ raise StorageArgumentException(
+ 'content keys must contain at least one of: '
+ 'sha1, sha1_git, sha256, blake2s256')
contents = db.content_find(sha1=content.get('sha1'),
sha1_git=content.get('sha1_git'),
@@ -407,14 +409,16 @@
"""Sanity checks on status / reason / length, that postgresql
doesn't enforce."""
if d['status'] != 'absent':
- raise ValueError('Invalid content status: {}'.format(d['status']))
+ raise StorageArgumentException(
+ 'Invalid content status: {}'.format(d['status']))
if d.get('reason') is None:
- raise ValueError(
+ raise StorageArgumentException(
'Must provide a reason if content is absent.')
if d['length'] < -1:
- raise ValueError('Content length must be positive or -1.')
+ raise StorageArgumentException(
+ 'Content length must be positive or -1.')
def _skipped_content_add_metadata(self, db, cur, content):
content = \
@@ -488,7 +492,7 @@
entry = src_entry.copy()
entry['dir_id'] = dir_id
if entry['type'] not in ('file', 'dir', 'rev'):
- raise ValueError(
+ raise StorageArgumentException(
'Entry type must be file, dir, or rev; not %s'
% entry['type'])
dir_entries[entry['type']][dir_id].append(entry)
@@ -776,7 +780,7 @@
snapshot = self.snapshot_get(
origin_visit['snapshot'], db=db, cur=cur)
if not snapshot:
- raise ValueError(
+ raise StorageArgumentException(
'last origin visit references an unknown snapshot')
return snapshot
@@ -864,12 +868,13 @@
metadata=None, snapshot=None,
db=None, cur=None):
if not isinstance(origin, str):
- raise TypeError('origin must be a string, not %r' % (origin,))
+ raise StorageArgumentException(
+ 'origin must be a string, not %r' % (origin,))
origin_url = origin
visit = db.origin_visit_get(origin_url, visit_id, cur=cur)
if not visit:
- raise ValueError('Invalid visit_id for this origin.')
+ raise StorageArgumentException('Invalid visit_id for this origin.')
visit = dict(zip(db.origin_visit_get_cols, visit))
@@ -896,8 +901,9 @@
if isinstance(visit['date'], str):
visit['date'] = dateutil.parser.parse(visit['date'])
if not isinstance(visit['origin'], str):
- raise TypeError("visit['origin'] must be a string, not %r"
- % (visit['origin'],))
+ raise StorageArgumentException(
+ "visit['origin'] must be a string, not %r"
+ % (visit['origin'],))
if self.journal_writer:
for visit in visits:
@@ -1013,7 +1019,7 @@
*, db=None, cur=None) -> dict:
page_token = page_token or '0'
if not isinstance(page_token, str):
- raise TypeError('page_token must be a string.')
+ raise StorageArgumentException('page_token must be a string.')
origin_from = int(page_token)
result: Dict[str, Any] = {
'origins': [
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -29,6 +29,7 @@
from swh.model.hypothesis_strategies import objects
from swh.storage import HashCollision
from swh.storage.converters import origin_url_to_sha1 as sha1
+from swh.storage.exc import StorageArgumentException
from swh.storage.interface import StorageInterface
from .storage_data import data
@@ -175,29 +176,26 @@
def test_content_add_validation(self, swh_storage):
cont = data.cont
- with pytest.raises(ValueError, match='status'):
+ with pytest.raises(StorageArgumentException, match='status'):
swh_storage.content_add([{**cont, 'status': 'absent'}])
- with pytest.raises(ValueError, match='status'):
+ with pytest.raises(StorageArgumentException, match='status'):
swh_storage.content_add([{**cont, 'status': 'foobar'}])
- with pytest.raises(ValueError, match="(?i)length"):
+ with pytest.raises(StorageArgumentException, match="(?i)length"):
swh_storage.content_add([{**cont, 'length': -2}])
- with pytest.raises(
- (ValueError, TypeError),
- match="reason"):
+ with pytest.raises(StorageArgumentException, match="reason"):
swh_storage.content_add([{**cont, 'reason': 'foobar'}])
def test_skipped_content_add_validation(self, swh_storage):
cont = data.cont.copy()
del cont['data']
- with pytest.raises(ValueError, match='status'):
+ with pytest.raises(StorageArgumentException, match='status'):
swh_storage.skipped_content_add([{**cont, 'status': 'visible'}])
- with pytest.raises((ValueError, psycopg2.IntegrityError),
- match='reason') as cm:
+ with pytest.raises(StorageArgumentException, match='reason') as cm:
swh_storage.skipped_content_add([{**cont, 'status': 'absent'}])
if type(cm.value) == psycopg2.IntegrityError:
@@ -480,10 +478,10 @@
def test_content_get_partition_limit_none(self, swh_storage):
"""content_get_partition call with wrong limit input should fail"""
- with pytest.raises(ValueError) as e:
+ with pytest.raises(StorageArgumentException) as e:
swh_storage.content_get_partition(1, 16, limit=None)
- assert e.value.args == ('Development error: limit should not be None',)
+ assert e.value.args == ('limit should not be None',)
def test_generate_content_get_partition_pagination(
self, swh_storage, swh_contents):
@@ -582,14 +580,13 @@
dir_ = copy.deepcopy(data.dir)
dir_['entries'][0]['type'] = 'foobar'
- with pytest.raises(ValueError, match='type.*foobar'):
+ with pytest.raises(StorageArgumentException, match='type.*foobar'):
swh_storage.directory_add([dir_])
dir_ = copy.deepcopy(data.dir)
del dir_['entries'][0]['target']
- with pytest.raises((TypeError, psycopg2.IntegrityError),
- match='target') as cm:
+ with pytest.raises(StorageArgumentException, match='target') as cm:
swh_storage.directory_add([dir_])
if type(cm.value) == psycopg2.IntegrityError:
@@ -789,8 +786,7 @@
rev = copy.deepcopy(data.revision)
rev['date']['offset'] = 2**16
- with pytest.raises((ValueError, psycopg2.DataError),
- match='offset') as cm:
+ with pytest.raises(StorageArgumentException, match='offset') as cm:
swh_storage.revision_add([rev])
if type(cm.value) == psycopg2.DataError:
@@ -800,8 +796,7 @@
rev = copy.deepcopy(data.revision)
rev['committer_date']['offset'] = 2**16
- with pytest.raises((ValueError, psycopg2.DataError),
- match='offset') as cm:
+ with pytest.raises(StorageArgumentException, match='offset') as cm:
swh_storage.revision_add([rev])
if type(cm.value) == psycopg2.DataError:
@@ -811,8 +806,7 @@
rev = copy.deepcopy(data.revision)
rev['type'] = 'foobar'
- with pytest.raises((ValueError, psycopg2.DataError),
- match='(?i)type') as cm:
+ with pytest.raises(StorageArgumentException, match='(?i)type') as cm:
swh_storage.revision_add([rev])
if type(cm.value) == psycopg2.DataError:
@@ -1011,8 +1005,7 @@
rel = copy.deepcopy(data.release)
rel['date']['offset'] = 2**16
- with pytest.raises((ValueError, psycopg2.DataError),
- match='offset') as cm:
+ with pytest.raises(StorageArgumentException, match='offset') as cm:
swh_storage.release_add([rel])
if type(cm.value) == psycopg2.DataError:
@@ -1022,8 +1015,7 @@
rel = copy.deepcopy(data.release)
rel['author'] = None
- with pytest.raises((ValueError, psycopg2.IntegrityError),
- match='date') as cm:
+ with pytest.raises(StorageArgumentException, match='date') as cm:
swh_storage.release_add([rel])
if type(cm.value) == psycopg2.IntegrityError:
@@ -1170,7 +1162,7 @@
assert add1 == add2
def test_origin_add_validation(self, swh_storage):
- with pytest.raises((TypeError, KeyError), match='url'):
+ with pytest.raises(StorageArgumentException, match='url'):
swh_storage.origin_add([{'type': 'git'}])
def test_origin_get_legacy(self, swh_storage):
@@ -1491,8 +1483,8 @@
def test_origin_visit_add_validation(self, swh_storage):
origin_url = swh_storage.origin_add_one(data.origin2)
- with pytest.raises((TypeError, psycopg2.ProgrammingError)) as cm:
- swh_storage.origin_visit_add(origin_url, date=[b'foo'])
+ with pytest.raises(StorageArgumentException) as cm:
+ swh_storage.origin_visit_add(origin_url, date=[b'foo'], type='git')
if type(cm.value) == psycopg2.ProgrammingError:
assert cm.value.pgcode \
@@ -1673,8 +1665,7 @@
type=data.type_visit2,
)
- with pytest.raises((ValueError, psycopg2.DataError),
- match='status') as cm:
+ with pytest.raises(StorageArgumentException, match='status') as cm:
swh_storage.origin_visit_update(
origin_url, visit['visit'], status='foobar')
@@ -2199,13 +2190,13 @@
snap = copy.deepcopy(data.snapshot)
snap['branches'][b'foo'] = {'target_type': 'revision'}
- with pytest.raises(KeyError, match='target'):
+ with pytest.raises(StorageArgumentException, match='target'):
swh_storage.snapshot_add([snap])
snap = copy.deepcopy(data.snapshot)
snap['branches'][b'foo'] = {'target': b'\x42'*20}
- with pytest.raises(KeyError, match='target_type'):
+ with pytest.raises(StorageArgumentException, match='target_type'):
swh_storage.snapshot_add([snap])
def test_snapshot_add_count_branches(self, swh_storage):
@@ -2431,7 +2422,7 @@
swh_storage.snapshot_add([data.snapshot])
- with pytest.raises(ValueError):
+ with pytest.raises(StorageArgumentException):
swh_storage.origin_visit_update(
origin_url, visit_id, snapshot=data.snapshot['id'])
@@ -2610,9 +2601,9 @@
swh_storage.origin_visit_update(
origin_url,
visit1_id, snapshot=data.complete_snapshot['id'])
- with pytest.raises(ValueError):
- swh_storage.snapshot_get_latest(
- origin_url)
+ with pytest.raises(Exception):
+ # XXX: should the exception be more specific than this?
+ swh_storage.snapshot_get_latest(origin_url)
# Status filter: both visits are status=ongoing, so no snapshot
# returned
@@ -2624,7 +2615,8 @@
swh_storage.origin_visit_update(
origin_url,
visit1_id, status='full')
- with pytest.raises(ValueError):
+ with pytest.raises(Exception):
+ # XXX: should the exception be more specific than this?
swh_storage.snapshot_get_latest(
origin_url,
allowed_statuses=['full']),
@@ -2639,7 +2631,8 @@
swh_storage.origin_visit_update(
origin_url,
visit2_id, snapshot=data.snapshot['id'])
- with pytest.raises(ValueError):
+ with pytest.raises(Exception):
+ # XXX: should the exception be more specific than this?
swh_storage.snapshot_get_latest(
origin_url)
@@ -2946,11 +2939,11 @@
def test_content_find_bad_input(self, swh_storage):
# 1. with bad input
- with pytest.raises(ValueError):
+ with pytest.raises(StorageArgumentException):
swh_storage.content_find({}) # empty is bad
# 2. with bad input
- with pytest.raises(ValueError):
+ with pytest.raises(StorageArgumentException):
swh_storage.content_find(
{'unknown-sha1': 'something'}) # not the right key
@@ -3407,10 +3400,10 @@
def test_generate_content_get_range_limit_none(self, swh_storage):
"""content_get_range call with wrong limit input should fail"""
- with pytest.raises(ValueError) as e:
+ with pytest.raises(StorageArgumentException) as e:
swh_storage.content_get_range(start=None, end=None, limit=None)
- assert e.value.args == ('Development error: limit should not be None',)
+ assert e.value.args == ('limit should not be None',)
def test_generate_content_get_range_no_limit(
self, swh_storage, swh_contents):