diff --git a/swh/storage/__init__.py b/swh/storage/__init__.py
--- a/swh/storage/__init__.py
+++ b/swh/storage/__init__.py
@@ -12,7 +12,7 @@
 
 STORAGE_IMPLEMENTATION = {
     'pipeline',
     'local', 'remote', 'memory', 'filter', 'buffer', 'retry',
-    'cassandra',
+    'validate', 'cassandra',
 }
 
@@ -60,6 +60,8 @@
         from .buffer import BufferingProxyStorage as Storage
     elif cls == 'retry':
         from .retry import RetryingProxyStorage as Storage
+    elif cls == 'validate':
+        from .validate import ValidatingProxyStorage as Storage
 
     return Storage(**kwargs)
 
diff --git a/swh/storage/cassandra/converters.py b/swh/storage/cassandra/converters.py
--- a/swh/storage/cassandra/converters.py
+++ b/swh/storage/cassandra/converters.py
@@ -3,9 +3,8 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-
+import copy
 import json
-from typing import Any, Dict
 
 import attr
 
@@ -17,27 +16,20 @@
 from ..converters import git_headers_to_db, db_to_git_headers
 
 
-def revision_to_db(revision: Dict[str, Any]) -> Revision:
-    metadata = revision.get('metadata')
+def revision_to_db(revision: Revision) -> Revision:
+    metadata = revision.metadata
     if metadata and 'extra_headers' in metadata:
-        extra_headers = git_headers_to_db(
+        metadata = copy.deepcopy(metadata)
+        metadata['extra_headers'] = git_headers_to_db(
             metadata['extra_headers'])
-        revision = {
-            **revision,
-            'metadata': {
-                **metadata,
-                'extra_headers': extra_headers
-            }
-        }
-
-    rev = Revision.from_dict(revision)
-    rev = attr.evolve(
-        rev,
-        type=rev.type.value,
-        metadata=json.dumps(rev.metadata),
+
+    revision = attr.evolve(
+        revision,
+        type=revision.type.value,
+        metadata=json.dumps(metadata),
     )
-    return rev
+    return revision
 
 
 def revision_from_db(revision) -> Revision:
@@ -55,13 +47,12 @@
     return rev
 
 
-def release_to_db(release: Dict[str, Any]) -> Release:
-    rel = Release.from_dict(release)
-    rel = attr.evolve(
-        rel,
-        target_type=rel.target_type.value,
+def release_to_db(release: Release) -> Release:
+    release = attr.evolve(
+        release,
+        target_type=release.target_type.value,
     )
-    return rel
+    return release
 
 
 def release_from_db(release: Release) -> Release:
diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -23,7 +23,7 @@
 
 from swh.model.model import (
     Sha1Git, TimestampWithTimezone, Timestamp, Person, Content,
-    SkippedContent, OriginVisit,
+    SkippedContent, OriginVisit, Origin
 )
 
 from .common import Row, TOKEN_BEGIN, TOKEN_END, hash_url
@@ -454,9 +454,9 @@
 
     @_prepared_statement('INSERT INTO origin (sha1, url, next_visit_id) '
                          'VALUES (?, ?, 1) IF NOT EXISTS')
-    def origin_add_one(self, origin: Dict[str, Any], *, statement) -> None:
+    def origin_add_one(self, origin: Origin, *, statement) -> None:
         self._execute_with_retries(
-            statement, [hash_url(origin['url']), origin['url']])
+            statement, [hash_url(origin.url), origin.url])
         self._increment_counter('origin', 1)
 
     @_prepared_statement('SELECT * FROM origin WHERE sha1 = ?')
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -7,7 +7,7 @@
 import json
 import random
 import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Iterable, Optional, Union
 import uuid
 
 import attr
@@ -15,7 +15,7 @@
 
 from swh.model.model import (
     Revision, Release, Directory, DirectoryEntry, Content, SkippedContent,
-    OriginVisit, Snapshot
+    OriginVisit, Snapshot, Origin
 )
 from swh.objstorage import get_objstorage
 from swh.objstorage.exc import ObjNotFoundError
@@ -61,22 +61,17 @@
 
         return True
 
-    def _content_add(self, contents, with_data):
-        try:
-            contents = [Content.from_dict(c) for c in contents]
-        except (KeyError, TypeError, ValueError) as e:
-            raise StorageArgumentException(*e.args)
-
+    def _content_add(self, contents: List[Content], with_data: bool) -> Dict:
        # Filter-out content already in the database.
        contents = [c for c in contents
                    if not self._cql_runner.content_get_from_pk(c.to_dict())]

        if self.journal_writer:
            for content in contents:
-                content = content.to_dict()
-                if 'data' in content:
-                    del content['data']
-                self.journal_writer.write_addition('content', content)
+                cont = content.to_dict()
+                if 'data' in cont:
+                    del cont['data']
+                self.journal_writer.write_addition('content', cont)
 
         count_contents = 0
         count_content_added = 0
@@ -94,6 +89,8 @@
                 count_content_added += 1
                 if with_data:
                     content_data = content.data
+                    if content_data is None:
+                        raise StorageArgumentException('Missing data')
                     count_content_bytes_added += len(content_data)
                     self.objstorage.add(content_data, content.sha1)
 
@@ -128,18 +125,15 @@
 
         return summary
 
-    def content_add(self, content):
-        content = [c.copy() for c in content]  # semi-shallow copy
-        for item in content:
-            item['ctime'] = now()
-        return self._content_add(content, with_data=True)
+    def content_add(self, content: Iterable[Content]) -> Dict:
+        return self._content_add(list(content), with_data=True)
 
     def content_update(self, content, keys=[]):
         raise NotImplementedError(
             'content_update is not supported by the Cassandra backend')
 
-    def content_add_metadata(self, content):
-        return self._content_add(content, with_data=False)
+    def content_add_metadata(self, content: Iterable[Content]) -> Dict:
+        return self._content_add(list(content), with_data=False)
 
     def content_get(self, content):
         if len(content) > BULK_BLOCK_CONTENT_LEN_MAX:
@@ -264,12 +258,7 @@
     def content_get_random(self):
         return self._cql_runner.content_get_random().sha1_git
 
-    def _skipped_content_add(self, contents):
-        try:
-            contents = [SkippedContent.from_dict(c) for c in contents]
-        except (KeyError, TypeError, ValueError) as e:
-            raise StorageArgumentException(*e.args)
-
+    def _skipped_content_add(self, contents: Iterable[SkippedContent]) -> Dict:
         # Filter-out content already in the database.
         contents = [
             c for c in contents
@@ -277,10 +266,10 @@
 
         if self.journal_writer:
             for content in contents:
-                content = content.to_dict()
-                if 'data' in content:
-                    del content['data']
-                self.journal_writer.write_addition('content', content)
+                cont = content.to_dict()
+                if 'data' in cont:
+                    del cont['data']
+                self.journal_writer.write_addition('content', cont)
 
         for content in contents:
             # Add to index tables
@@ -296,10 +285,7 @@
             'skipped_content:add': len(contents)
         }
 
-    def skipped_content_add(self, content):
-        content = [c.copy() for c in content]  # semi-shallow copy
-        for item in content:
-            item['ctime'] = now()
+    def skipped_content_add(self, content: Iterable[SkippedContent]) -> Dict:
         return self._skipped_content_add(content)
 
     def skipped_content_missing(self, contents):
@@ -307,27 +293,23 @@
             if not self._cql_runner.skipped_content_get_from_pk(content):
                 yield content
 
-    def directory_add(self, directories):
+    def directory_add(self, directories: Iterable[Directory]) -> Dict:
         directories = list(directories)
 
         # Filter out directories that are already inserted.
-        missing = self.directory_missing([dir_['id'] for dir_ in directories])
-        directories = [dir_ for dir_ in directories if dir_['id'] in missing]
+        missing = self.directory_missing([dir_.id for dir_ in directories])
+        directories = [dir_ for dir_ in directories if dir_.id in missing]
 
         if self.journal_writer:
             self.journal_writer.write_additions('directory', directories)
 
         for directory in directories:
-            try:
-                directory = Directory.from_dict(directory)
-            except (KeyError, TypeError, ValueError) as e:
-                raise StorageArgumentException(*e.args)
-
             # Add directory entries to the 'directory_entry' table
             for entry in directory.entries:
-                entry = entry.to_dict()
-                entry['directory_id'] = directory.id
-                self._cql_runner.directory_entry_add_one(entry)
+                self._cql_runner.directory_entry_add_one({
+                    **entry.to_dict(),
+                    'directory_id': directory.id
+                })
 
             # Add the directory *after* adding all the entries, so someone
             # calling snapshot_get_branch in the meantime won't end up
@@ -416,22 +398,18 @@
     def directory_get_random(self):
         return self._cql_runner.directory_get_random().id
 
-    def revision_add(self, revisions):
+    def revision_add(self, revisions: Iterable[Revision]) -> Dict:
         revisions = list(revisions)
 
         # Filter-out revisions already in the database
-        missing = self.revision_missing([rev['id'] for rev in revisions])
-        revisions = [rev for rev in revisions if rev['id'] in missing]
+        missing = self.revision_missing([rev.id for rev in revisions])
+        revisions = [rev for rev in revisions if rev.id in missing]
 
         if self.journal_writer:
             self.journal_writer.write_additions('revision', revisions)
 
         for revision in revisions:
-            try:
-                revision = revision_to_db(revision)
-            except (KeyError, TypeError, ValueError) as e:
-                raise StorageArgumentException(*e.args)
-
+            revision = revision_to_db(revision)
             if revision:
                 # Add parents first
                 for (rank, parent) in enumerate(revision.parents):
@@ -516,21 +494,17 @@
     def revision_get_random(self):
         return self._cql_runner.revision_get_random().id
 
-    def release_add(self, releases):
-        releases = list(releases)
-        missing = self.release_missing([rel['id'] for rel in releases])
-        releases = [rel for rel in releases if rel['id'] in missing]
+    def release_add(self, releases: Iterable[Release]) -> Dict:
+        releases = list(releases)
+        missing = self.release_missing([rel.id for rel in releases])
+        releases = [rel for rel in releases if rel.id in missing]
 
         if self.journal_writer:
             self.journal_writer.write_additions('release', releases)
 
         for release in releases:
-            try:
-                release = release_to_db(release)
-            except (KeyError, TypeError, ValueError) as e:
-                raise StorageArgumentException(*e.args)
-
             if release:
+                release = release_to_db(release)
                 self._cql_runner.release_add_one(release)
 
         return {'release:add': len(missing)}
@@ -552,11 +526,8 @@
     def release_get_random(self):
         return self._cql_runner.release_get_random().id
 
-    def snapshot_add(self, snapshots):
-        try:
-            snapshots = [Snapshot.from_dict(snap) for snap in snapshots]
-        except (KeyError, TypeError, ValueError) as e:
-            raise StorageArgumentException(*e.args)
+    def snapshot_add(self, snapshots: Iterable[Snapshot]) -> Dict:
+        snapshots = list(snapshots)
         missing = self._cql_runner.snapshot_missing(
             [snp.id for snp in snapshots])
         snapshots = [snp for snp in snapshots if snp.id in missing]
@@ -573,13 +544,12 @@
             else:
                 target_type = branch.target_type.value
                 target = branch.target
-            branch = {
+            self._cql_runner.snapshot_branch_add_one({
                 'snapshot_id': snapshot.id,
                 'name': branch_name,
                 'target_type': target_type,
                 'target': target,
-            }
-            self._cql_runner.snapshot_branch_add_one(branch)
+            })
 
         # Add the snapshot *after* adding all the branches, so someone
         # calling snapshot_get_branch in the meantime won't end up
@@ -797,19 +767,15 @@
             }
             for orig in origins[offset:offset+limit]]
 
-    def origin_add(self, origins):
-        origins = list(origins)
-        if any('id' in origin for origin in origins):
-            raise StorageArgumentException(
-                'Origins must not already have an id.')
+    def origin_add(self, origins: Iterable[Origin]) -> List[Dict]:
         results = []
         for origin in origins:
             self.origin_add_one(origin)
-            results.append(origin)
+            results.append(origin.to_dict())
         return results
 
-    def origin_add_one(self, origin):
-        known_origin = self.origin_get_one(origin)
+    def origin_add_one(self, origin: Origin) -> str:
+        known_origin = self.origin_get_one(origin.to_dict())
 
         if known_origin:
             origin_url = known_origin['url']
@@ -818,11 +784,12 @@
             self.journal_writer.write_addition('origin', origin)
 
         self._cql_runner.origin_add_one(origin)
-        origin_url = origin['url']
+        origin_url = origin.url
 
         return origin_url
 
-    def origin_visit_add(self, origin, date, type):
+    def origin_visit_add(
+            self, origin, date, type) -> Optional[Dict[str, Union[str, int]]]:
         origin_url = origin  # TODO: rename the argument
 
         if isinstance(date, str):
@@ -835,24 +802,22 @@
 
         visit_id = self._cql_runner.origin_generate_unique_visit_id(origin_url)
 
-        visit = {
-            'origin': origin_url,
-            'date': date,
-            'type': type,
-            'status': 'ongoing',
-            'snapshot': None,
-            'metadata': None,
-            'visit': visit_id
-        }
-
-        if self.journal_writer:
-            self.journal_writer.write_addition('origin_visit', visit)
-
         try:
-            visit = OriginVisit.from_dict(visit)
+            visit = OriginVisit.from_dict({
+                'origin': origin_url,
+                'date': date,
+                'type': type,
+                'status': 'ongoing',
+                'snapshot': None,
+                'metadata': None,
+                'visit': visit_id
+            })
         except (KeyError, TypeError, ValueError) as e:
             raise StorageArgumentException(*e.args)
 
+        if self.journal_writer:
+            self.journal_writer.write_addition('origin_visit', visit)
+
         self._cql_runner.origin_visit_add_one(visit)
 
         return {
@@ -860,8 +825,9 @@
             'visit': visit_id,
         }
 
-    def origin_visit_update(self, origin, visit_id, status=None,
-                            metadata=None, snapshot=None):
+    def origin_visit_update(
+            self, origin: str, visit_id: int, status: Optional[str] = None,
+            metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None):
         origin_url = origin  # TODO: rename the argument
 
         # Get the existing data of the visit
@@ -873,7 +839,7 @@
         except (KeyError, TypeError, ValueError) as e:
             raise StorageArgumentException(*e.args)
 
-        updates = {}
+        updates: Dict[str, Any] = {}
         if status:
             updates['status'] = status
         if metadata:
diff --git a/swh/storage/converters.py b/swh/storage/converters.py
--- a/swh/storage/converters.py
+++ b/swh/storage/converters.py
@@ -154,10 +154,11 @@
     }
 
 
-def revision_to_db(revision):
+def revision_to_db(rev):
     """Convert a swh-model revision to its database representation.
     """
 
+    revision = rev.to_dict()
     author = author_to_db(revision['author'])
     date = date_to_db(revision['date'])
     committer = author_to_db(revision['committer'])
@@ -257,10 +258,12 @@
     return ret
 
 
-def release_to_db(release):
+def release_to_db(rel):
     """Convert a swh-model release to its database representation.
     """
 
+    release = rel.to_dict()
+
     author = author_to_db(release['author'])
     date = date_to_db(release['date'])
 
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -14,7 +14,7 @@
 
 from collections import defaultdict
 from datetime import timedelta
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 import attr
 
@@ -76,21 +76,11 @@
     def check_config(self, *, check_write):
         return True
 
-    def _content_add(self, contents, with_data):
-        for content in contents:
-            if content.status is None:
-                content.status = 'visible'
-            if content.status == 'absent':
-                raise StorageArgumentException('content with status=absent')
-            if content.length is None:
-                raise StorageArgumentException('content with length=None')
-
+    def _content_add(
+            self, contents: Iterable[Content], with_data: bool) -> Dict:
         if self.journal_writer:
             for content in contents:
-                try:
-                    content = attr.evolve(content, data=None)
-                except (KeyError, TypeError, ValueError) as e:
-                    raise StorageArgumentException(*e.args)
+                content = attr.evolve(content, data=None)
                 self.journal_writer.write_addition('content', content)
 
         summary = {
@@ -119,24 +109,17 @@
                 summary['content:add'] += 1
                 if with_data:
                     content_data = self._contents[key].data
-                    try:
-                        self._contents[key] = attr.evolve(
-                            self._contents[key],
-                            data=None)
-                    except (KeyError, TypeError, ValueError) as e:
-                        raise StorageArgumentException(*e.args)
+                    self._contents[key] = attr.evolve(
+                        self._contents[key],
+                        data=None)
                     summary['content:add:bytes'] += len(content_data)
                     self.objstorage.add(content_data, content.sha1)
 
         return summary
 
-    def content_add(self, content):
+    def content_add(self, content: Iterable[Content]) -> Dict:
         now = datetime.datetime.now(tz=datetime.timezone.utc)
-        try:
-            content = [attr.evolve(Content.from_dict(c), ctime=now)
-                       for c in content]
-        except (KeyError, TypeError, ValueError) as e:
-            raise StorageArgumentException(*e.args)
+        content = [attr.evolve(c, ctime=now) for c in content]
         return self._content_add(content, with_data=True)
 
     def content_update(self, content, keys=[]):
@@ -154,10 +137,7 @@
                 hash_ = old_cont.get_hash(algorithm)
                 self._content_indexes[algorithm][hash_].remove(old_key)
 
-        try:
-            new_cont = attr.evolve(old_cont, **cont_update)
-        except (KeyError, TypeError, ValueError) as e:
-            raise StorageArgumentException(*e.args)
+        new_cont = attr.evolve(old_cont, **cont_update)
         new_key = self._content_key(new_cont)
 
         self._contents[new_key] = new_cont
@@ -166,8 +146,7 @@
             hash_ = new_cont.get_hash(algorithm)
             self._content_indexes[algorithm][hash_].add(new_key)
 
-    def content_add_metadata(self, content):
-        content = [Content.from_dict(c) for c in content]
+    def content_add_metadata(self, content: Iterable[Content]) -> Dict:
         return self._content_add(content, with_data=False)
 
     def content_get(self, content):
@@ -285,19 +264,10 @@
     def content_get_random(self):
         return random.choice(list(self._content_indexes['sha1_git']))
 
-    def _skipped_content_add(self, contents):
-        for content in contents:
-            if content.status is None:
-                content = attr.evolve(content, status='absent')
-            if content.length is None:
-                content = attr.evolve(content, length=-1)
-            if content.status != 'absent':
-                raise StorageArgumentException(
-                    f'Content with status={content.status}')
-
+    def _skipped_content_add(self, contents: Iterable[SkippedContent]) -> Dict:
         if self.journal_writer:
-            for content in contents:
-                self.journal_writer.write_addition('content', content)
+            for cont in contents:
+                self.journal_writer.write_addition('content', cont)
 
         summary = {
             'skipped_content:add': 0
@@ -308,9 +278,9 @@
         for content in skipped_content_missing:
             key = self._content_key(content, allow_missing=True)
             for algo in DEFAULT_ALGORITHMS:
-                if algo in content:
-                    self._skipped_content_indexes[algo][content[algo]] \
-                        .add(key)
+                if content.get(algo):
+                    self._skipped_content_indexes[algo][
+                        content.get(algo)].add(key)
             self._skipped_contents[key] = content
             summary['skipped_content:add'] += 1
 
@@ -329,28 +299,19 @@
                            if content[algo] is not None}
                 break
 
-    def skipped_content_add(self, content):
+    def skipped_content_add(self, content: Iterable[SkippedContent]) -> Dict:
         content = list(content)
         now = datetime.datetime.now(tz=datetime.timezone.utc)
-        try:
-            content = [attr.evolve(SkippedContent.from_dict(c), ctime=now)
-                       for c in content]
-        except (KeyError, TypeError, ValueError) as e:
-            raise StorageArgumentException(*e.args)
+        content = [attr.evolve(c, ctime=now) for c in content]
         return self._skipped_content_add(content)
 
-    def directory_add(self, directories):
+    def directory_add(self, directories: Iterable[Directory]) -> Dict:
         directories = list(directories)
         if self.journal_writer:
             self.journal_writer.write_additions(
                 'directory',
                 (dir_ for dir_ in directories
-                 if dir_['id'] not in self._directories))
-
-        try:
-            directories = [Directory.from_dict(d) for d in directories]
-        except (KeyError, TypeError, ValueError) as e:
-            raise StorageArgumentException(*e.args)
+                 if dir_.id not in self._directories))
 
         count = 0
         for directory in directories:
@@ -435,18 +396,13 @@
         return self._directory_entry_get_by_path(
             first_item['target'], paths[1:], prefix + paths[0] + b'/')
 
-    def revision_add(self, revisions):
+    def revision_add(self, revisions: Iterable[Revision]) -> Dict:
         revisions = list(revisions)
         if self.journal_writer:
             self.journal_writer.write_additions(
                 'revision',
                 (rev for rev in revisions
-                 if rev['id'] not in self._revisions))
-
-        try:
-            revisions = [Revision.from_dict(rev) for rev in revisions]
-        except (KeyError, TypeError, ValueError) as e:
-            raise StorageArgumentException(*e.args)
+                 if rev.id not in self._revisions))
 
         count = 0
         for revision in revisions:
@@ -496,18 +452,13 @@
     def revision_get_random(self):
         return random.choice(list(self._revisions))
 
-    def release_add(self, releases):
+    def release_add(self, releases: Iterable[Release]) -> Dict:
         releases = list(releases)
         if self.journal_writer:
             self.journal_writer.write_additions(
                 'release',
                 (rel for rel in releases
-                 if rel['id'] not in self._releases))
-
-        try:
-            releases = [Release.from_dict(rel) for rel in releases]
-        except (KeyError, TypeError, ValueError) as e:
-            raise StorageArgumentException(*e.args)
+                 if rel.id not in self._releases))
 
         count = 0
         for rel in releases:
@@ -534,12 +485,8 @@
     def release_get_random(self):
         return random.choice(list(self._releases))
 
-    def snapshot_add(self, snapshots):
+    def snapshot_add(self, snapshots: Iterable[Snapshot]) -> Dict:
         count = 0
-        try:
-            snapshots = [Snapshot.from_dict(d) for d in snapshots]
-        except (KeyError, TypeError, ValueError) as e:
-            raise StorageArgumentException(*e.args)
         snapshots = (snap for snap in snapshots
                      if snap.id not in self._snapshots)
         for snapshot in snapshots:
@@ -749,17 +696,13 @@
                 with_visit=with_visit,
                 limit=len(self._origins)))
 
-    def origin_add(self, origins):
+    def origin_add(self, origins: Iterable[Origin]) -> List[Dict]:
         origins = copy.deepcopy(list(origins))
         for origin in origins:
             self.origin_add_one(origin)
-        return origins
+        return [origin.to_dict() for origin in origins]
 
-    def origin_add_one(self, origin):
-        try:
-            origin = Origin.from_dict(origin)
-        except (KeyError, TypeError, ValueError) as e:
-            raise StorageArgumentException(*e.args)
+    def origin_add_one(self, origin: Origin) -> str:
         if origin.url not in self._origins:
             if self.journal_writer:
                 self.journal_writer.write_addition('origin', origin)
@@ -777,7 +720,8 @@
 
         return origin.url
 
-    def origin_visit_add(self, origin, date, type):
+    def origin_visit_add(
+            self, origin, date, type) -> Optional[Dict[str, Union[str, int]]]:
         origin_url = origin
         if origin_url is None:
             raise StorageArgumentException('Unknown origin.')
@@ -818,10 +762,9 @@
 
         return visit_ret
 
-    def origin_visit_update(self, origin, visit_id, status=None,
-                            metadata=None, snapshot=None):
-        if not isinstance(origin, str):
-            raise TypeError('origin must be a string, not %r' % (origin,))
+    def origin_visit_update(
+            self, origin: str, visit_id: int, status: Optional[str] = None,
+            metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None):
         origin_url = self._get_origin_url(origin)
         if origin_url is None:
             raise StorageArgumentException('Unknown origin.')
@@ -832,7 +775,7 @@
             raise StorageArgumentException(
                 'Unknown visit_id for this origin') from None
 
-        updates = {}
+        updates: Dict[str, Any] = {}
         if status:
             updates['status'] = status
         if metadata:
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -3,9 +3,13 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 from swh.core.api import remote_api_endpoint
+from swh.model.model import (
+    SkippedContent, Content, Directory, Revision, Release,
+    Snapshot, Origin
+)
 
 
 def deprecated(f):
@@ -20,7 +24,7 @@
         ...
 
     @remote_api_endpoint('content/add')
-    def content_add(self, content):
+    def content_add(self, content: Iterable[Content]) -> Dict:
         """Add content blobs to the storage
 
         Args:
@@ -78,7 +82,7 @@
         ...
 
     @remote_api_endpoint('content/add_metadata')
-    def content_add_metadata(self, content):
+    def content_add_metadata(self, content: Iterable[Content]) -> Dict:
         """Add content metadata to the storage (like `content_add`, but
         without inserting to the objstorage).
 
@@ -279,7 +283,7 @@
         ...
 
     @remote_api_endpoint('content/skipped/add')
-    def skipped_content_add(self, content):
+    def skipped_content_add(self, content: Iterable[SkippedContent]) -> Dict:
         """Add contents to the skipped_content list, which contains
         (partial) information about content missing from the archive.
 
@@ -330,7 +334,7 @@
         ...
 
     @remote_api_endpoint('directory/add')
-    def directory_add(self, directories):
+    def directory_add(self, directories: Iterable[Directory]) -> Dict:
         """Add directories to the storage
 
         Args:
@@ -412,7 +416,7 @@
         ...
 
     @remote_api_endpoint('revision/add')
-    def revision_add(self, revisions):
+    def revision_add(self, revisions: Iterable[Revision]) -> Dict:
         """Add revisions to the storage
 
         Args:
@@ -516,7 +520,7 @@
         ...
 
     @remote_api_endpoint('release/add')
-    def release_add(self, releases):
+    def release_add(self, releases: Iterable[Release]) -> Dict:
         """Add releases to the storage
 
         Args:
@@ -581,7 +585,7 @@
         ...
 
     @remote_api_endpoint('snapshot/add')
-    def snapshot_add(self, snapshots):
+    def snapshot_add(self, snapshots: Iterable[Snapshot]) -> Dict:
         """Add snapshots to the storage.
 
         Args:
@@ -763,7 +767,8 @@
         ...
 
     @remote_api_endpoint('origin/visit/add')
-    def origin_visit_add(self, origin, date, type):
+    def origin_visit_add(
+            self, origin, date, type) -> Optional[Dict[str, Union[str, int]]]:
         """Add an origin_visit for the origin at ts with status 'ongoing'.
 
         Args:
@@ -781,8 +786,9 @@
         ...
 
     @remote_api_endpoint('origin/visit/update')
-    def origin_visit_update(self, origin, visit_id, status=None,
-                            metadata=None, snapshot=None):
+    def origin_visit_update(
+            self, origin: str, visit_id: int, status: Optional[str] = None,
+            metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None):
         """Update an origin_visit's status.
 
         Args:
@@ -1047,7 +1053,7 @@
         ...
 
     @remote_api_endpoint('origin/add_multi')
-    def origin_add(self, origins):
+    def origin_add(self, origins: Iterable[Origin]) -> List[Dict]:
         """Add origins to the storage
 
         Args:
@@ -1064,7 +1070,7 @@
         ...
 
     @remote_api_endpoint('origin/add')
-    def origin_add_one(self, origin):
+    def origin_add_one(self, origin: Origin) -> str:
         """Add origin to the storage
 
         Args:
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -12,15 +12,19 @@
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Union
 
+import attr
 import dateutil.parser
 import psycopg2
 import psycopg2.pool
 import psycopg2.errors
 
-from swh.model.model import SHA1_SIZE
-from swh.model.hashutil import ALGORITHMS, hash_to_bytes, hash_to_hex
+from swh.model.model import (
+    SkippedContent, Content, Directory, Revision, Release,
+    Snapshot, Origin, SHA1_SIZE
+)
+from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex
 from swh.objstorage import get_objstorage
 from swh.objstorage.exc import ObjNotFoundError
 try:
@@ -151,78 +155,48 @@
             return hash
 
         return tuple([hash[k] for k in keys])
 
-    @staticmethod
-    def _content_normalize(d):
-        d = d.copy()
-
-        if 'status' not in d:
-            d['status'] = 'visible'
-
-        return d
-
-    @staticmethod
-    def _content_validate(d):
-        """Sanity checks on status / reason / length, that postgresql
-        doesn't enforce."""
-        if d['status'] not in ('visible', 'hidden'):
-            raise StorageArgumentException(
-                'Invalid content status: {}'.format(d['status']))
-
-        if d.get('reason') is not None:
-            raise StorageArgumentException(
-                'Must not provide a reason if content is present.')
-
-        if d['length'] is None or d['length'] < 0:
-            raise StorageArgumentException('Content length must be positive.')
-
     def _content_add_metadata(self, db, cur, content):
         """Add content to the postgresql database but not the object
         storage.
         """
         # create temporary table for metadata injection
         db.mktemp('content', cur)
 
-        with convert_validation_exceptions():
-            db.copy_to(content, 'tmp_content',
-                       db.content_add_keys, cur)
+        db.copy_to((c.to_dict() for c in content), 'tmp_content',
+                   db.content_add_keys, cur)
 
-            # move metadata in place
-            try:
-                db.content_add_from_temp(cur)
-            except psycopg2.IntegrityError as e:
-                if e.diag.sqlstate == '23505' and \
-                        e.diag.table_name == 'content':
-                    constraint_to_hash_name = {
-                        'content_pkey': 'sha1',
-                        'content_sha1_git_idx': 'sha1_git',
-                        'content_sha256_idx': 'sha256',
-                    }
-                    colliding_hash_name = constraint_to_hash_name \
-                        .get(e.diag.constraint_name)
-                    raise HashCollision(colliding_hash_name) from None
-                else:
-                    raise
+        # move metadata in place
+        try:
+            db.content_add_from_temp(cur)
+        except psycopg2.IntegrityError as e:
+            if e.diag.sqlstate == '23505' and \
+                    e.diag.table_name == 'content':
+                constraint_to_hash_name = {
+                    'content_pkey': 'sha1',
+                    'content_sha1_git_idx': 'sha1_git',
+                    'content_sha256_idx': 'sha256',
+                }
+                colliding_hash_name = constraint_to_hash_name \
+                    .get(e.diag.constraint_name)
+                raise HashCollision(colliding_hash_name) from None
+            else:
+                raise
 
     @timed
     @process_metrics
     @db_transaction()
-    def content_add(self, content, db=None, cur=None):
-        content = [dict(c.items()) for c in content]  # semi-shallow copy
+    def content_add(
+            self, content: Iterable[Content], db=None, cur=None) -> Dict:
         now = datetime.datetime.now(tz=datetime.timezone.utc)
-        for item in content:
-            item['ctime'] = now
-
-        content = [self._content_normalize(c) for c in content]
-        for c in content:
-            self._content_validate(c)
+        content = [attr.evolve(c, ctime=now) for c in content]
 
-        missing = list(self.content_missing(content, key_hash='sha1_git'))
-        content = [c for c in content if c['sha1_git'] in missing]
+        missing = list(self.content_missing(
+            map(Content.to_dict, content), key_hash='sha1_git'))
+        content = [c for c in content if c.sha1_git in missing]
 
         if self.journal_writer:
             for item in content:
-                if 'data' in item:
-                    item = item.copy()
-                    del item['data']
+                if item.data is not None:
+                    item = attr.evolve(item, data=None)
                 self.journal_writer.write_addition('content', item)
 
         def add_to_objstorage():
@@ -236,9 +210,9 @@
             content_bytes_added = 0
             data = {}
             for cont in content:
-                if cont['sha1'] not in data:
-                    data[cont['sha1']] = cont['data']
-                content_bytes_added += max(0, cont['length'])
+                if cont.sha1 not in data:
+                    data[cont.sha1] = cont.data
+                content_bytes_added += max(0, cont.length)
 
             # FIXME: Since we do the filtering anyway now, we might as
             # well make the objstorage's add_batch call return what we
@@ -280,17 +254,16 @@
     @timed
     @process_metrics
     @db_transaction()
-    def content_add_metadata(self, content, db=None, cur=None):
-        content = [self._content_normalize(c) for c in content]
-        for c in content:
-            self._content_validate(c)
-
-        missing = self.content_missing(content, key_hash='sha1_git')
-        content = [c for c in content if c['sha1_git'] in missing]
+    def content_add_metadata(self, content: Iterable[Content],
+                             db=None, cur=None) -> Dict:
+        content = list(content)
+        missing = self.content_missing(
+            (c.to_dict() for c in content), key_hash='sha1_git')
+        content = [c for c in content if c.sha1_git in missing]
 
         if self.journal_writer:
             for item in itertools.chain(content):
-                assert 'data' not in content
+                assert item.data is None
                 self.journal_writer.write_addition('content', item)
 
         self._content_add_metadata(db, cur, content)
@@ -400,7 +373,7 @@
     @timed
     @db_transaction()
     def content_find(self, content, db=None, cur=None):
-        if not set(content).intersection(ALGORITHMS):
+        if not set(content).intersection(DEFAULT_ALGORITHMS):
             raise StorageArgumentException(
                 'content keys must contain at least one of: '
                 'sha1, sha1_git, sha256, blake2s256')
@@ -446,40 +419,33 @@
             raise StorageArgumentException(
                 'Content length must be positive or -1.')
 
-    def _skipped_content_add_metadata(self, db, cur, content):
-        content = \
-            [cont.copy() for cont in content]
+    def _skipped_content_add_metadata(
+            self, db, cur, content: Iterable[SkippedContent]):
         origin_ids = db.origin_id_get_by_url(
-            [cont.get('origin') for cont in content],
+            [cont.origin for cont in content],
             cur=cur)
-        for (cont, origin_id) in zip(content, origin_ids):
-            if 'origin' in cont:
-                cont['origin'] = origin_id
+        content = [attr.evolve(c, origin=origin_id)
+                   for (c, origin_id) in zip(content, origin_ids)]
         db.mktemp('skipped_content', cur)
-        with convert_validation_exceptions():
-            db.copy_to(content, 'tmp_skipped_content',
-                       db.skipped_content_keys, cur)
+        db.copy_to([c.to_dict() for c in content], 'tmp_skipped_content',
+                   db.skipped_content_keys, cur)
 
-            # move metadata in place
-            db.skipped_content_add_from_temp(cur)
+        # move metadata in place
+        db.skipped_content_add_from_temp(cur)
 
     @timed
     @process_metrics
     @db_transaction()
-    def skipped_content_add(self, content, db=None, cur=None):
-        content = [dict(c.items()) for c in content]  # semi-shallow copy
+    def skipped_content_add(self, content: Iterable[SkippedContent],
+                            db=None, cur=None) -> Dict:
         now = datetime.datetime.now(tz=datetime.timezone.utc)
-        for item in content:
-            item['ctime'] = now
+        content = [attr.evolve(c, ctime=now) for c in content]
 
-        content = [self._skipped_content_normalize(c) for c in content]
-        for c in content:
-            self._skipped_content_validate(c)
-
-        missing_contents = self.skipped_content_missing(content)
+        missing_contents = self.skipped_content_missing(
+            c.to_dict() for c in content)
         content = [c for c in content
-                   if any(all(c.get(algo) == missing_content.get(algo)
-                              for algo in ALGORITHMS)
+                   if any(all(c.get_hash(algo) == missing_content.get(algo)
+                              for algo in DEFAULT_ALGORITHMS)
                           for missing_content in missing_contents)]
 
         if self.journal_writer:
@@ -495,33 +461,31 @@
     @timed
     @db_transaction_generator()
     def skipped_content_missing(self, contents, db=None, cur=None):
+        contents = list(contents)
         for content in db.skipped_content_missing(contents, cur):
             yield dict(zip(db.content_hash_keys, content))
 
     @timed
     @process_metrics
     @db_transaction()
-    def directory_add(self, directories, db=None, cur=None):
+    def directory_add(self, directories: Iterable[Directory],
+                      db=None, cur=None) -> Dict:
         directories = list(directories)
         summary = {'directory:add': 0}
 
         dirs = set()
-        dir_entries = {
+        dir_entries: Dict[str, defaultdict] = {
             'file': defaultdict(list),
             'dir': defaultdict(list),
             'rev': defaultdict(list),
         }
 
         for cur_dir in directories:
-            dir_id = cur_dir['id']
+            dir_id = cur_dir.id
             dirs.add(dir_id)
 
-            for src_entry in cur_dir['entries']:
-                entry = src_entry.copy()
+            for src_entry in cur_dir.entries:
+                entry = src_entry.to_dict()
                 entry['dir_id'] = dir_id
-                if entry['type'] not in ('file', 'dir', 'rev'):
-                    raise StorageArgumentException(
-                        'Entry type must be file, dir, or rev; not %s'
-                        % entry['type'])
                 dir_entries[entry['type']][dir_id].append(entry)
 
         dirs_missing = set(self.directory_missing(dirs, db=db, cur=cur))
@@ -532,33 +496,32 @@
             self.journal_writer.write_additions(
                 'directory',
                 (dir_ for dir_ in directories
-                 if dir_['id'] in dirs_missing))
+                 if dir_.id in dirs_missing))
 
         # Copy directory ids
         dirs_missing_dict = ({'id': dir} for dir in dirs_missing)
         db.mktemp('directory', cur)
-        with convert_validation_exceptions():
-            db.copy_to(dirs_missing_dict, 'tmp_directory', ['id'], cur)
-
-            # Copy entries
-            for entry_type, entry_list in dir_entries.items():
-                entries = itertools.chain.from_iterable(
-                    entries_for_dir
-                    for dir_id, entries_for_dir
-                    in entry_list.items()
-                    if dir_id in dirs_missing)
-
-                db.mktemp_dir_entry(entry_type)
-
-                db.copy_to(
-                    entries,
-                    'tmp_directory_entry_%s' % entry_type,
-                    ['target', 'name', 'perms', 'dir_id'],
-                    cur,
-                )
+        db.copy_to(dirs_missing_dict, 'tmp_directory', ['id'], cur)
+
+        # Copy entries
+        for entry_type, entry_list in dir_entries.items():
+            entries = itertools.chain.from_iterable(
+                entries_for_dir
+                for dir_id, entries_for_dir
+                in entry_list.items()
+                if dir_id in dirs_missing)
+
+            db.mktemp_dir_entry(entry_type)
+
+            db.copy_to(
+                entries,
+                'tmp_directory_entry_%s' % entry_type,
+                ['target', 'name', 'perms', 'dir_id'],
+                cur,
+            )
 
-            # Do the final copy
-            db.directory_add_from_temp(cur)
+        # Do the final copy
+        db.directory_add_from_temp(cur)
 
         summary['directory:add'] = len(dirs_missing)
         return summary
@@ -595,12 +558,13 @@
     @timed
     @process_metrics
     @db_transaction()
-    def revision_add(self, revisions, db=None, cur=None):
+    def revision_add(self, revisions: Iterable[Revision],
+                     db=None, cur=None) -> Dict:
         revisions = list(revisions)
         summary = {'revision:add': 0}
 
         revisions_missing = set(self.revision_missing(
-            set(revision['id'] for revision in revisions),
+            set(revision.id for revision in revisions),
             db=db, cur=cur))
 
         if not revisions_missing:
@@ -610,14 +574,15 @@
         revisions_filtered = [
             revision for revision in revisions
-            if revision['id'] in revisions_missing]
+            if revision.id in revisions_missing]
 
         if self.journal_writer:
             self.journal_writer.write_additions('revision',
                                                 revisions_filtered)
 
-        revisions_filtered = map(converters.revision_to_db, revisions_filtered)
+        revisions_filtered = \
+            list(map(converters.revision_to_db, revisions_filtered))
 
-        parents_filtered = []
+        parents_filtered: List[bytes] = []
 
         with convert_validation_exceptions():
             db.copy_to(
@@ -679,11 +644,12 @@
     @timed
     @process_metrics
     @db_transaction()
-    def release_add(self, releases, db=None, cur=None):
+    def release_add(
+            self, releases: Iterable[Release], db=None, cur=None) -> Dict:
         releases = list(releases)
         summary = {'release:add': 0}
 
-        release_ids = set(release['id'] for release in releases)
+        release_ids = set(release.id for release in releases)
         releases_missing = set(self.release_missing(release_ids,
                                                     db=db, cur=cur))
 
@@ -692,17 +658,16 @@
 
         db.mktemp_release(cur)
 
-        releases_missing = list(releases_missing)
-
         releases_filtered = [
             release for release in releases
-            if release['id'] in releases_missing
+            if release.id in releases_missing
         ]
 
         if self.journal_writer:
             self.journal_writer.write_additions('release',
                                                 releases_filtered)
 
-        releases_filtered = map(converters.release_to_db, releases_filtered)
+        releases_filtered = \
+            list(map(converters.release_to_db, releases_filtered))
 
         with convert_validation_exceptions():
             db.copy_to(releases_filtered, 'tmp_release', db.release_add_cols,
@@ -738,12 +703,13 @@
     @timed
     @process_metrics
     @db_transaction()
-    def snapshot_add(self, snapshots, db=None, cur=None):
+    def snapshot_add(
+            self, snapshots: Iterable[Snapshot], db=None, cur=None) -> Dict:
         created_temp_table = False
 
         count = 0
         for snapshot in snapshots:
-            if not db.snapshot_exists(snapshot['id'], cur):
+            if not db.snapshot_exists(snapshot.id, cur):
                 if not created_temp_table:
                     db.mktemp_snapshot_branch(cur)
                     created_temp_table = True
@@ -753,11 +719,11 @@
                     (
                         {
                             'name': name,
-                            'target': info['target'] if info else None,
-                            'target_type': (info['target_type']
+                            'target': info.target if info else None,
+                            'target_type': (info.target_type.value
                                             if info else None),
                         }
-                        for name, info in snapshot['branches'].items()
+                        for name, info in snapshot.branches.items()
                     ),
                     'tmp_snapshot_branch',
                     ['name', 'target', 'target_type'],
@@ -769,7 +735,7 @@
                 if self.journal_writer:
                     self.journal_writer.write_addition('snapshot', snapshot)
 
-                db.snapshot_add(snapshot['id'], cur)
+                db.snapshot_add(snapshot.id, cur)
                 count += 1
 
         return {'snapshot:add': count}
@@ -871,8 +837,9 @@
 
     @timed
     @db_transaction()
-    def origin_visit_add(self, origin, date, type,
-                         db=None, cur=None):
+    def origin_visit_add(
+            self, origin, date, type, db=None, cur=None
+    ) -> Optional[Dict[str, Union[str, int]]]:
         origin_url = origin
 
         if isinstance(date, str):
@@ -898,8 +865,10 @@
 
     @timed
     @db_transaction()
-    def origin_visit_update(self, origin, visit_id, status=None,
-                            metadata=None, snapshot=None,
+    def origin_visit_update(self, origin: str, visit_id: int,
+                            status: Optional[str] = None,
+                            metadata: Optional[Dict] = None,
+                            snapshot: Optional[bytes] = None,
                             db=None, cur=None):
         if not isinstance(origin, str):
             raise StorageArgumentException(
@@ -912,7 +881,7 @@
 
         visit = dict(zip(db.origin_visit_get_cols, visit))
 
-        updates = {}
+        updates: Dict[str, Any] = {}
         if status and status != visit['status']:
             updates['status'] = status
         if metadata and metadata != visit['metadata']:
@@ -1088,20 +1057,19 @@
 
     @timed
     @db_transaction()
-    def origin_add(self, origins, db=None, cur=None):
-        origins = copy.deepcopy(list(origins))
+    def origin_add(
+            self, origins: Iterable[Origin], db=None, cur=None) -> List[Dict]:
+        origins = list(origins)
         for origin in origins:
             self.origin_add_one(origin, db=db, cur=cur)
 
         send_metric('origin:add', count=len(origins), method_name='origin_add')
-        return origins
+        return [o.to_dict() for o in origins]
 
     @timed
     @db_transaction()
-    def origin_add_one(self, origin, db=None, cur=None):
-        if 'url' not in origin:
-            raise StorageArgumentException('Missing origin url')
-        origin_row = list(db.origin_get_by_url([origin['url']], cur))[0]
+    def origin_add_one(self, origin: Origin, db=None, cur=None) -> str:
+        origin_row = list(db.origin_get_by_url([origin.url], cur))[0]
         origin_url = dict(zip(db.origin_cols, origin_row))['url']
         if origin_url:
             return origin_url
@@ -1109,7 +1077,7 @@
         if self.journal_writer:
             self.journal_writer.write_addition('origin', origin)
 
-        origins = db.origin_add(origin['url'], cur)
+        origins = db.origin_add(origin.url, cur)
         send_metric('origin:add', count=len(origins), method_name='origin_add')
         return origins
 
diff --git a/swh/storage/tests/algos/test_origin.py b/swh/storage/tests/algos/test_origin.py
--- a/swh/storage/tests/algos/test_origin.py
+++ b/swh/storage/tests/algos/test_origin.py
@@ -5,7 +5,7 @@
 
 from unittest.mock import patch
 
-from swh.storage.in_memory import InMemoryStorage
+from swh.storage import get_storage
 from swh.storage.algos.origin import iter_origins
 
 
@@ -13,8 +13,16 @@
     assert list(left) == list(right), msg
 
 
+storage_config = {
+    'cls': 'validate',
+    'storage': {
+        'cls': 'memory',
+    }
+}
+
+
 def test_iter_origins():
-    storage = InMemoryStorage()
+    storage = get_storage(**storage_config)
     origins = storage.origin_add([
         {'url': 'bar'},
         {'url': 'qux'},
@@ -62,7 +70,7 @@
 
 @patch('swh.storage.in_memory.InMemoryStorage.origin_get_range')
 def test_iter_origins_batch_size(mock_origin_get_range):
-    storage = InMemoryStorage()
+    storage = get_storage(**storage_config)
     mock_origin_get_range.return_value = []
 
     list(iter_origins(storage))
diff --git a/swh/storage/tests/algos/test_snapshot.py b/swh/storage/tests/algos/test_snapshot.py
--- a/swh/storage/tests/algos/test_snapshot.py
+++ b/swh/storage/tests/algos/test_snapshot.py
@@ -10,7 +10,7 @@
     snapshots, branch_names, branch_targets
 from swh.storage.algos.snapshot import snapshot_get_all_branches
 
-from swh.storage.tests.test_in_memory import swh_storage  # noqa
+from swh.storage.tests.test_in_memory import swh_storage_backend_config  # noqa
 
 
 @given(snapshot=snapshots(min_size=0, max_size=10, only_objects=False))
diff --git a/swh/storage/tests/conftest.py b/swh/storage/tests/conftest.py
--- a/swh/storage/tests/conftest.py
+++ b/swh/storage/tests/conftest.py
@@ -35,8 +35,8 @@
 
 
 @pytest.fixture
-def swh_storage(postgresql_proc, swh_storage_postgresql):
-    storage_config = {
+def swh_storage_backend_config(postgresql_proc, swh_storage_postgresql):
+    yield {
         'cls': 'local',
         'db': 'postgresql://{user}@{host}:{port}/{dbname}'.format(
             host=postgresql_proc.host,
             port=postgresql_proc.port,
@@ -51,6 +51,15 @@
         'journal_writer': {
             'cls': 'memory',
         },
     }
+
+
+@pytest.fixture
+def swh_storage(swh_storage_backend_config):
+    storage_config = {
+        'cls': 'validate',
+        'storage': swh_storage_backend_config
+    }
+
     storage = swh.storage.get_storage(**storage_config)
     return storage
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -21,10 +21,13 @@
 @pytest.fixture
 def app_server():
     storage_config = {
-        'cls': 'memory',
-        'journal_writer': {
+        'cls': 'validate',
+        'storage': {
             'cls': 'memory',
-        },
+            'journal_writer': {
+                'cls': 'memory',
+            },
+        }
     }
     server.storage = swh.storage.get_storage(**storage_config)
     yield server
@@ -61,5 +64,5 @@
 class TestStorage(_TestStorage):
     def test_content_update(self, swh_storage, app_server):
         swh_storage.journal_writer = None  # TODO, journal_writer not supported
-        with patch.object(server.storage, 'journal_writer', None):
+        with patch.object(server.storage.storage, 'journal_writer', None):
             super().test_content_update(swh_storage)
diff --git a/swh/storage/tests/test_buffer.py b/swh/storage/tests/test_buffer.py
--- a/swh/storage/tests/test_buffer.py
+++ b/swh/storage/tests/test_buffer.py
@@ -6,10 +6,18 @@
 
 from swh.storage.buffer import BufferingProxyStorage
 
 
+storage_config = {
+    'cls': 'validate',
+    'storage': {
+        'cls': 'memory'
+    }
+}
+
+
 def test_buffering_proxy_storage_content_threshold_not_hit(sample_data):
     contents = sample_data['content']
     storage = BufferingProxyStorage(
-        storage={'cls': 'memory'},
+        storage=storage_config,
         min_batch_size={
             'content': 10,
         }
@@ -37,7 +45,7 @@
 def test_buffering_proxy_storage_content_threshold_nb_hit(sample_data):
     contents = sample_data['content']
     storage = BufferingProxyStorage(
-        storage={'cls': 'memory'},
+        storage=storage_config,
         min_batch_size={
             'content': 1,
         }
@@ -60,7 +68,7 @@
     contents = sample_data['content']
     content_bytes_min_batch_size = 2
     storage = BufferingProxyStorage(
-        storage={'cls': 'memory'},
+        storage=storage_config,
         min_batch_size={
             'content': 10,
             'content_bytes': content_bytes_min_batch_size,
@@ -86,7 +94,7 @@
                                                            sample_data):
     contents = sample_data['skipped_content']
     storage = BufferingProxyStorage(
-        storage={'cls': 'memory'},
+        storage=storage_config,
         min_batch_size={
             'skipped_content': 10,
         }
@@ -113,7 +121,7 @@
 def test_buffering_proxy_storage_skipped_content_threshold_nb_hit(sample_data):
     contents = sample_data['skipped_content']
     storage = BufferingProxyStorage(
-        storage={'cls': 'memory'},
+        storage=storage_config,
         min_batch_size={
             'skipped_content': 1,
         }
@@ -134,7 +142,7 @@
 def test_buffering_proxy_storage_directory_threshold_not_hit(sample_data):
     directories = sample_data['directory']
     storage = BufferingProxyStorage(
-        storage={'cls': 'memory'},
+        storage=storage_config,
         min_batch_size={
             'directory': 10,
         }
@@ -160,7 +168,7 @@
 def test_buffering_proxy_storage_directory_threshold_hit(sample_data):
     directories = sample_data['directory']
     storage = BufferingProxyStorage(
-        storage={'cls': 'memory'},
+        storage=storage_config,
         min_batch_size={
             'directory': 1,
         }
@@ -181,7 +189,7 @@
 def test_buffering_proxy_storage_revision_threshold_not_hit(sample_data):
     revisions = sample_data['revision']
     storage = BufferingProxyStorage(
-        storage={'cls': 'memory'},
+        storage=storage_config,
         min_batch_size={
             'revision': 10,
         }
@@ -207,7 +215,7 @@
 def test_buffering_proxy_storage_revision_threshold_hit(sample_data):
     revisions = sample_data['revision']
     storage = BufferingProxyStorage(
-        storage={'cls': 'memory'},
+        storage=storage_config,
         min_batch_size={
             'revision': 1,
         }
@@ -231,7 +239,7 @@
     assert len(releases) < threshold
 
     storage = BufferingProxyStorage(
-        storage={'cls': 'memory'},
+        storage=storage_config,
         min_batch_size={
             'release': threshold,  # configuration set
         }
@@ -258,7 +266,7 @@
     assert len(releases) > threshold
 
     storage = BufferingProxyStorage(
-        storage={'cls': 'memory'},
+        storage=storage_config,
         min_batch_size={
             'release': threshold,  # configuration set
         }
diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py
--- a/swh/storage/tests/test_cassandra.py
+++ b/swh/storage/tests/test_cassandra.py
@@ -149,12 +149,13 @@
 # below
 
 
 @pytest.fixture
-def swh_storage(cassandra_cluster, keyspace):
+def swh_storage_backend_config(cassandra_cluster, keyspace):
     (hosts, port) = cassandra_cluster
-    storage = get_storage(
-        'cassandra',
-        hosts=hosts, port=port,
+    storage_config = dict(
+        cls='cassandra',
+        hosts=hosts,
+        port=port,
         keyspace=keyspace,
         journal_writer={
             'cls': 'memory',
@@ -165,7 +166,9 @@
         },
     )
 
-    yield storage
+    yield storage_config
+
+    storage = get_storage(**storage_config)
 
     for table in TABLES:
         storage._cql_runner._session.execute('TRUNCATE TABLE "%s"' % table)
diff --git a/swh/storage/tests/test_filter.py b/swh/storage/tests/test_filter.py
--- a/swh/storage/tests/test_filter.py
+++ b/swh/storage/tests/test_filter.py
@@ -7,9 +7,17 @@
 
 from swh.storage.filter import FilteringProxyStorage
 
 
+storage_config = {
+    'cls': 'validate',
+    'storage': {
+        'cls': 'memory'
+    }
+}
+
+
 def test_filtering_proxy_storage_content(sample_data):
     sample_content = sample_data['content'][0]
-    storage = FilteringProxyStorage(storage={'cls': 'memory'})
+    storage = FilteringProxyStorage(storage=storage_config)
 
     content = next(storage.content_get([sample_content['sha1']]))
     assert not content
@@ -32,7 +40,7 @@
 
 def test_filtering_proxy_storage_skipped_content(sample_data):
     sample_content = sample_data['skipped_content'][0]
-    storage = FilteringProxyStorage(storage={'cls': 'memory'})
+    storage = FilteringProxyStorage(storage=storage_config)
 
     content = next(storage.skipped_content_missing([sample_content]))
     assert content['sha1'] == sample_content['sha1']
@@ -53,7 +61,7 @@
 
 def test_filtering_proxy_storage_revision(sample_data):
     sample_revision = sample_data['revision'][0]
-    storage = FilteringProxyStorage(storage={'cls': 'memory'})
+    storage = FilteringProxyStorage(storage=storage_config)
 
     revision = next(storage.revision_get([sample_revision['id']]))
     assert not revision
@@ -74,7 +82,7 @@
 
 def test_filtering_proxy_storage_directory(sample_data):
     sample_directory = sample_data['directory'][0]
-    storage = FilteringProxyStorage(storage={'cls': 'memory'})
+    storage = FilteringProxyStorage(storage=storage_config)
 
     directory = next(storage.directory_missing([sample_directory['id']]))
     assert directory
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
--- a/swh/storage/tests/test_in_memory.py
+++ b/swh/storage/tests/test_in_memory.py
@@ -1,11 +1,10 @@
-# Copyright (C) 2018 The Software Heritage developers
+# Copyright (C) 2018-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 import pytest
 
-from swh.storage import get_storage
 from swh.storage.tests.test_storage import (  # noqa
     TestStorage, TestStorageGeneratedData)
 
@@ -15,12 +14,10 @@
 # below
 
 
 @pytest.fixture
-def swh_storage():
-    storage_config = {
+def swh_storage_backend_config():
+    yield {
         'cls': 'memory',
         'journal_writer': {
             'cls': 'memory',
         },
     }
-    storage = get_storage(**storage_config)
-    return storage
diff --git a/swh/storage/tests/test_retry.py b/swh/storage/tests/test_retry.py
--- a/swh/storage/tests/test_retry.py
+++ b/swh/storage/tests/test_retry.py
@@ -17,7 +17,12 @@
 
 @pytest.fixture
 def swh_storage():
-    return RetryingProxyStorage(storage={'cls': 'memory'})
+    return RetryingProxyStorage(storage={
+        'cls': 'validate',
+        'storage': {
+            'cls': 'memory'
+        }
+    })
 
 
 def test_retrying_proxy_storage_content_add(swh_storage, sample_data):
@@ -223,7 +228,7 @@
     assert not origin
 
     with pytest.raises(StorageArgumentException, match='Refuse to add'):
-        swh_storage.origin_add_one([sample_origin])
+        swh_storage.origin_add_one(sample_origin)
 
     assert mock_memory.call_count == 1
 
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -26,8 +26,9 @@
 
 from swh.model import from_disk, identifiers
 from swh.model.hashutil import hash_to_bytes
+from swh.model.model import Release, Revision
 from swh.model.hypothesis_strategies import objects
-from swh.storage import HashCollision
+from swh.storage import HashCollision, get_storage
 from swh.storage.converters import origin_url_to_sha1 as sha1
 from swh.storage.exc import StorageArgumentException
 from swh.storage.interface import StorageInterface
@@ -98,12 +99,13 @@
     """
     maxDiff = None  # type: ClassVar[Optional[int]]
 
-    def test_types(self, swh_storage):
+    def test_types(self, swh_storage_backend_config):
         """Checks all methods of StorageInterface are implemented by this
         backend, and that they have the same signature."""
         # Create an instance of the protocol (which cannot be instantiated
         # directly, so this creates a subclass, then instantiates it)
         interface = type('_', (StorageInterface,), {})()
+        storage = get_storage(**swh_storage_backend_config)
 
         assert 'content_add' in dir(interface)
 
@@ -114,7 +116,7 @@
                 continue
             interface_meth = getattr(interface, meth_name)
             try:
-                concrete_meth = getattr(swh_storage, meth_name)
+                concrete_meth = getattr(storage, meth_name)
             except AttributeError:
                 if not getattr(interface_meth, 'deprecated_endpoint', False):
                     # The backend is missing a (non-deprecated) endpoint
@@ -266,7 +268,8 @@
             assert cm.value.args[0] in ['sha1', 'sha1_git', 'blake2s256']
 
     def test_content_update(self, swh_storage):
-        swh_storage.journal_writer = None  # TODO, not supported
+        if hasattr(swh_storage, 'storage'):
+            swh_storage.storage.journal_writer = None  # TODO, not supported
 
         cont = copy.deepcopy(data.cont)
 
@@ -762,8 +765,10 @@
         end_missing = swh_storage.revision_missing([data.revision['id']])
         assert list(end_missing) == []
 
+        normalized_revision = Revision.from_dict(data.revision).to_dict()
+
         assert list(swh_storage.journal_writer.objects) \
-            == [('revision', data.revision)]
+            == [('revision', normalized_revision)]
 
         # already there so nothing added
         actual_result = swh_storage.revision_add([data.revision])
@@ -817,16 +822,19 @@
         actual_result = swh_storage.revision_add([data.revision])
         assert actual_result == {'revision:add': 1}
 
+        normalized_revision = Revision.from_dict(data.revision).to_dict()
+        normalized_revision2 = Revision.from_dict(data.revision2).to_dict()
+
         assert list(swh_storage.journal_writer.objects) \
-            == [('revision', data.revision)]
+            == [('revision', normalized_revision)]
 
         actual_result = swh_storage.revision_add(
             [data.revision, data.revision2])
         assert actual_result == {'revision:add': 1}
 
         assert list(swh_storage.journal_writer.objects) \
-            == [('revision', data.revision),
-                ('revision', data.revision2)]
+            == [('revision', normalized_revision),
+                ('revision', normalized_revision2)]
 
     def test_revision_add_name_clash(self, swh_storage):
         revision1 = data.revision
@@ -866,9 +874,12 @@
         assert actual_results[0] == normalize_entity(data.revision4)
         assert actual_results[1] == normalize_entity(data.revision3)
 
+        normalized_revision3 = Revision.from_dict(data.revision3).to_dict()
+        normalized_revision4 = Revision.from_dict(data.revision4).to_dict()
+
         assert list(swh_storage.journal_writer.objects) == [
-            ('revision', data.revision3),
-            ('revision', data.revision4)]
+            ('revision', normalized_revision3),
+            ('revision', normalized_revision4)]
 
     def test_revision_log_with_limit(self, swh_storage):
         # given
@@ -949,6 +960,9 @@
             {data.revision['id'], data.revision2['id'], data.revision3['id']}
 
     def test_release_add(self, swh_storage):
+        normalized_release = Release.from_dict(data.release).to_dict()
+        normalized_release2 = Release.from_dict(data.release2).to_dict()
+
         init_missing = swh_storage.release_missing([data.release['id'],
                                                     data.release2['id']])
         assert [data.release['id'], data.release2['id']] == list(init_missing)
@@ -961,8 +975,8 @@
         assert list(end_missing) == []
 
         assert list(swh_storage.journal_writer.objects) == [
-            ('release', data.release),
-            ('release', data.release2)]
+            ('release', normalized_release),
+            ('release', normalized_release2)]
 
         # already present so nothing added
         actual_result = swh_storage.release_add([data.release, data.release2])
@@ -976,12 +990,15 @@
             yield data.release
             yield data.release2
 
+        normalized_release = Release.from_dict(data.release).to_dict()
+        normalized_release2 = Release.from_dict(data.release2).to_dict()
+
         actual_result = swh_storage.release_add(_rel_gen())
         assert actual_result == {'release:add': 2}
 
         assert list(swh_storage.journal_writer.objects) == [
-            ('release', data.release),
-            ('release', data.release2)]
+            ('release', normalized_release),
+            ('release', normalized_release2)]
 
         swh_storage.refresh_stat_counters()
         assert swh_storage.stat_counters()['release'] == 2
@@ -1025,15 +1042,18 @@
         actual_result = swh_storage.release_add([data.release])
         assert actual_result == {'release:add': 1}
 
+        normalized_release = Release.from_dict(data.release).to_dict()
+        normalized_release2 = Release.from_dict(data.release2).to_dict()
+
         assert list(swh_storage.journal_writer.objects) \
-            == [('release', data.release)]
+            == [('release', normalized_release)]
 
         actual_result = swh_storage.release_add([data.release, data.release2])
         assert actual_result == {'release:add': 1}
 
         assert list(swh_storage.journal_writer.objects) \
-            == [('release', data.release),
-                ('release', data.release2)]
+            == [('release', normalized_release),
+                ('release', normalized_release2)]
 
     def test_release_add_name_clash(self, swh_storage):
         release1 = data.release.copy()
@@ -3684,7 +3704,7 @@
     """
 
     def test_content_update_with_new_cols(self, swh_storage):
-        swh_storage.journal_writer = None  # TODO, not supported
+        swh_storage.storage.journal_writer = None  # TODO, not supported
 
         with db_transaction(swh_storage) as (_, cur):
             cur.execute("""alter table content
diff --git a/swh/storage/validate.py b/swh/storage/validate.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/validate.py
@@ -0,0 +1,105 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import contextlib
+import datetime
+from typing import Dict, Iterable, List, Union
+
+from swh.model.model import (
+    SkippedContent, Content, Directory, Revision, Release, Snapshot,
+    OriginVisit, Origin
+)
+
+from . import get_storage
+from .exc import StorageArgumentException
+
+
+VALIDATION_EXCEPTIONS = (
+    KeyError,
+    TypeError,
+    ValueError,
+)
+
+
+@contextlib.contextmanager
+def convert_validation_exceptions():
+    """Catches validation errors and re-raises them as a
+    StorageArgumentException."""
+    try:
+        yield
+    except VALIDATION_EXCEPTIONS as e:
+        raise StorageArgumentException(*e.args)
+
+
+def now():
+    return datetime.datetime.now(tz=datetime.timezone.utc)
+
+
+class ValidatingProxyStorage:
+    """Storage implementation that converts dictionaries to swh-model objects
+    before calling its backend, and back to dicts before returning results
+
+    """
+    def __init__(self, storage):
+        self.storage = get_storage(**storage)
+
+    def __getattr__(self, key):
+        return getattr(self.storage, key)
+
+    def content_add(self, content: Iterable[Dict]) -> Dict:
+        with convert_validation_exceptions():
+            contents = [Content.from_dict({**c, 'ctime': now()})
+                        for c in content]
+        return self.storage.content_add(contents)
+
+    def content_add_metadata(self, content: Iterable[Dict]) -> Dict:
+        with convert_validation_exceptions():
+            contents = [Content.from_dict(c)
+                        for c in content]
+        return self.storage.content_add_metadata(contents)
+
+    def skipped_content_add(self, content: Iterable[Dict]) -> Dict:
+        with convert_validation_exceptions():
+            contents = [SkippedContent.from_dict({**c, 'ctime': now()})
+                        for c in content]
+        return self.storage.skipped_content_add(contents)
+
+    def directory_add(self, directories: Iterable[Dict]) -> Dict:
+        with convert_validation_exceptions():
+            directories = [Directory.from_dict(d) for d in directories]
+        return self.storage.directory_add(directories)
+
+    def revision_add(self, revisions: Iterable[Dict]) -> Dict:
+        with convert_validation_exceptions():
+            revisions = [Revision.from_dict(r) for r in revisions]
+        return self.storage.revision_add(revisions)
+
+    def release_add(self, releases: Iterable[Dict]) -> Dict:
+        with convert_validation_exceptions():
+            releases = [Release.from_dict(r) for r in releases]
+        return self.storage.release_add(releases)
+
+    def snapshot_add(self, snapshots: Iterable[Dict]) -> Dict:
+        with convert_validation_exceptions():
+            snapshots = [Snapshot.from_dict(s) for s in snapshots]
+        return self.storage.snapshot_add(snapshots)
+
+    def origin_visit_add(
+            self, origin, date, type) -> Dict[str, Union[str, int]]:
+        with convert_validation_exceptions():
+            visit = OriginVisit(origin=origin, date=date, type=type,
+                                status='ongoing', snapshot=None)
+        return self.storage.origin_visit_add(
+            visit.origin, visit.date, visit.type)
+
+    def origin_add(self, origins: Iterable[Dict]) -> List[Dict]:
+        with convert_validation_exceptions():
+            origins = [Origin.from_dict(o) for o in origins]
+        return self.storage.origin_add(origins)
+
+    def origin_add_one(self, origin: Dict) -> str:
+        with convert_validation_exceptions():
+            origin = Origin.from_dict(origin)
+        return self.storage.origin_add_one(origin)
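
A minimal usage sketch of the new 'validate' proxy, assuming the in-memory backend; the configuration shape mirrors the updated test fixtures above, and the example origin URL is arbitrary:

    from swh.storage import get_storage

    # Layer the validating proxy over a concrete backend: get_storage
    # resolves cls='validate' to ValidatingProxyStorage, which in turn
    # instantiates the wrapped backend from the nested 'storage' config.
    storage = get_storage(cls='validate', storage={'cls': 'memory'})

    # Callers keep the dict-based API; the proxy converts the dicts to
    # swh.model.model objects (raising StorageArgumentException on
    # invalid input) before delegating to the backend.
    storage.origin_add([{'url': 'https://example.com/repo.git'}])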