diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py index 4fb7f3ef..ffd0dc19 100644 --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -1,979 +1,975 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json import random import re from typing import Any, Dict, List, Iterable, Optional, Union import uuid import attr import dateutil from swh.model.model import ( Revision, Release, Directory, DirectoryEntry, Content, SkippedContent, OriginVisit, Snapshot, Origin ) from swh.model.hashutil import DEFAULT_ALGORITHMS from swh.storage.objstorage import ObjStorage from swh.storage.writer import JournalWriter +from swh.storage.validate import convert_validation_exceptions from swh.storage.utils import now from ..exc import StorageArgumentException, HashCollision from .common import TOKEN_BEGIN, TOKEN_END from .converters import ( revision_to_db, revision_from_db, release_to_db, release_from_db, ) from .cql import CqlRunner from .schema import HASH_ALGORITHMS # Max block size of contents to return BULK_BLOCK_CONTENT_LEN_MAX = 10000 class CassandraStorage: def __init__(self, hosts, keyspace, objstorage, port=9042, journal_writer=None): self._cql_runner = CqlRunner(hosts, keyspace, port) self.journal_writer = JournalWriter(journal_writer) self.objstorage = ObjStorage(objstorage) def check_config(self, *, check_write): self._cql_runner.check_read() return True def _content_get_from_hash(self, algo, hash_) -> Iterable: """From the name of a hash algorithm and a value of that hash, looks up the "hash -> token" secondary table (content_by_{algo}) to get tokens. Then, looks up the main table (content) to get all contents with that token, and filters out contents whose hash doesn't match.""" found_tokens = self._cql_runner.content_get_tokens_from_single_hash( algo, hash_) for token in found_tokens: # Query the main table ('content'). res = self._cql_runner.content_get_from_token(token) for row in res: # re-check the the hash (in case of murmur3 collision) if getattr(row, algo) == hash_: yield row def _content_add(self, contents: List[Content], with_data: bool) -> Dict: # Filter-out content already in the database. contents = [c for c in contents if not self._cql_runner.content_get_from_pk(c.to_dict())] self.journal_writer.content_add(contents) if with_data: # First insert to the objstorage, if the endpoint is # `content_add` (as opposed to `content_add_metadata`). # TODO: this should probably be done in concurrently to inserting # in index tables (but still before the main table; so an entry is # only added to the main table after everything else was # successfully inserted. summary = self.objstorage.content_add( c for c in contents if c.status != 'absent') content_add_bytes = summary['content:add:bytes'] content_add = 0 for content in contents: content_add += 1 # Check for sha1 or sha1_git collisions. This test is not atomic # with the insertion, so it won't detect a collision if both # contents are inserted at the same time, but it's good enough. # # The proper way to do it would probably be a BATCH, but this # would be inefficient because of the number of partitions we # need to affect (len(HASH_ALGORITHMS)+1, which is currently 5) for algo in {'sha1', 'sha1_git'}: collisions = [] # Get tokens of 'content' rows with the same value for # sha1/sha1_git rows = self._content_get_from_hash( algo, content.get_hash(algo)) for row in rows: if getattr(row, algo) != content.get_hash(algo): # collision of token(partition key), ignore this # row continue for algo in HASH_ALGORITHMS: if getattr(row, algo) != content.get_hash(algo): # This hash didn't match; discard the row. collisions.append({ algo: getattr(row, algo) for algo in HASH_ALGORITHMS}) if collisions: collisions.append(content.hashes()) raise HashCollision( algo, content.get_hash(algo), collisions) (token, insertion_finalizer) = \ self._cql_runner.content_add_prepare(content) # Then add to index tables for algo in HASH_ALGORITHMS: self._cql_runner.content_index_add_one(algo, content, token) # Then to the main table insertion_finalizer() summary = { 'content:add': content_add, } if with_data: summary['content:add:bytes'] = content_add_bytes return summary def content_add(self, content: Iterable[Content]) -> Dict: contents = [attr.evolve(c, ctime=now()) for c in content] return self._content_add(list(contents), with_data=True) def content_update(self, content, keys=[]): raise NotImplementedError( 'content_update is not supported by the Cassandra backend') def content_add_metadata(self, content: Iterable[Content]) -> Dict: return self._content_add(list(content), with_data=False) def content_get(self, content): if len(content) > BULK_BLOCK_CONTENT_LEN_MAX: raise StorageArgumentException( "Sending at most %s contents." % BULK_BLOCK_CONTENT_LEN_MAX) yield from self.objstorage.content_get(content) def content_get_partition( self, partition_id: int, nb_partitions: int, limit: int = 1000, page_token: str = None): if limit is None: raise StorageArgumentException('limit should not be None') # Compute start and end of the range of tokens covered by the # requested partition partition_size = (TOKEN_END-TOKEN_BEGIN)//nb_partitions range_start = TOKEN_BEGIN + partition_id*partition_size range_end = TOKEN_BEGIN + (partition_id+1)*partition_size # offset the range start according to the `page_token`. if page_token is not None: if not (range_start <= int(page_token) <= range_end): raise StorageArgumentException('Invalid page_token.') range_start = int(page_token) # Get the first rows of the range rows = self._cql_runner.content_get_token_range( range_start, range_end, limit) rows = list(rows) if len(rows) == limit: next_page_token: Optional[str] = str(rows[-1].tok+1) else: next_page_token = None return { 'contents': [row._asdict() for row in rows if row.status != 'absent'], 'next_page_token': next_page_token, } def content_get_metadata( self, contents: List[bytes]) -> Dict[bytes, List[Dict]]: result: Dict[bytes, List[Dict]] = {sha1: [] for sha1 in contents} for sha1 in contents: # Get all (sha1, sha1_git, sha256, blake2s256) whose sha1 # matches the argument, from the index table ('content_by_sha1') for row in self._content_get_from_hash('sha1', sha1): content_metadata = row._asdict() content_metadata.pop('ctime') result[content_metadata['sha1']].append(content_metadata) return result def content_find(self, content): # Find an algorithm that is common to all the requested contents. # It will be used to do an initial filtering efficiently. filter_algos = list(set(content).intersection(HASH_ALGORITHMS)) if not filter_algos: raise StorageArgumentException( 'content keys must contain at least one of: ' '%s' % ', '.join(sorted(HASH_ALGORITHMS))) common_algo = filter_algos[0] results = [] rows = self._content_get_from_hash( common_algo, content[common_algo]) for row in rows: # Re-check all the hashes, in case of collisions (either of the # hash of the partition key, or the hashes in it) for algo in HASH_ALGORITHMS: if content.get(algo) and getattr(row, algo) != content[algo]: # This hash didn't match; discard the row. break else: # All hashes match, keep this row. results.append({ **row._asdict(), 'ctime': row.ctime.replace(tzinfo=datetime.timezone.utc) }) return results def content_missing(self, content, key_hash='sha1'): for cont in content: res = self.content_find(cont) if not res: yield cont[key_hash] if any(c['status'] == 'missing' for c in res): yield cont[key_hash] def content_missing_per_sha1(self, contents): return self.content_missing([{'sha1': c for c in contents}]) def content_missing_per_sha1_git(self, contents): return self.content_missing([{'sha1_git': c for c in contents}], key_hash='sha1_git') def content_get_random(self): return self._cql_runner.content_get_random().sha1_git def _skipped_content_get_from_hash(self, algo, hash_) -> Iterable: """From the name of a hash algorithm and a value of that hash, looks up the "hash -> token" secondary table (skipped_content_by_{algo}) to get tokens. Then, looks up the main table (content) to get all contents with that token, and filters out contents whose hash doesn't match.""" found_tokens = \ self._cql_runner.skipped_content_get_tokens_from_single_hash( algo, hash_) for token in found_tokens: # Query the main table ('content'). res = self._cql_runner.skipped_content_get_from_token(token) for row in res: # re-check the the hash (in case of murmur3 collision) if getattr(row, algo) == hash_: yield row def _skipped_content_add(self, contents: Iterable[SkippedContent]) -> Dict: # Filter-out content already in the database. contents = [ c for c in contents if not self._cql_runner.skipped_content_get_from_pk(c.to_dict())] self.journal_writer.skipped_content_add(contents) for content in contents: # Compute token of the row in the main table (token, insertion_finalizer) = \ self._cql_runner.skipped_content_add_prepare(content) # Then add to index tables for algo in HASH_ALGORITHMS: self._cql_runner.skipped_content_index_add_one( algo, content, token) # Then to the main table insertion_finalizer() return { 'skipped_content:add': len(contents) } def skipped_content_add(self, content: Iterable[SkippedContent]) -> Dict: contents = [attr.evolve(c, ctime=now()) for c in content] return self._skipped_content_add(contents) def skipped_content_missing(self, contents): for content in contents: if not self._cql_runner.skipped_content_get_from_pk(content): yield {algo: content[algo] for algo in DEFAULT_ALGORITHMS} def directory_add(self, directories: Iterable[Directory]) -> Dict: directories = list(directories) # Filter out directories that are already inserted. missing = self.directory_missing([dir_.id for dir_ in directories]) directories = [dir_ for dir_ in directories if dir_.id in missing] self.journal_writer.directory_add(directories) for directory in directories: # Add directory entries to the 'directory_entry' table for entry in directory.entries: self._cql_runner.directory_entry_add_one({ **entry.to_dict(), 'directory_id': directory.id }) # Add the directory *after* adding all the entries, so someone # calling snapshot_get_branch in the meantime won't end up # with half the entries. self._cql_runner.directory_add_one(directory.id) return {'directory:add': len(missing)} def directory_missing(self, directories): return self._cql_runner.directory_missing(directories) def _join_dentry_to_content(self, dentry): keys = ( 'status', 'sha1', 'sha1_git', 'sha256', 'length', ) ret = dict.fromkeys(keys) ret.update(dentry.to_dict()) if ret['type'] == 'file': content = self.content_find({'sha1_git': ret['target']}) if content: content = content[0] for key in keys: ret[key] = content[key] return ret def _directory_ls(self, directory_id, recursive, prefix=b''): if self.directory_missing([directory_id]): return rows = list(self._cql_runner.directory_entry_get([directory_id])) for row in rows: # Build and yield the directory entry dict entry = row._asdict() del entry['directory_id'] entry = DirectoryEntry.from_dict(entry) ret = self._join_dentry_to_content(entry) ret['name'] = prefix + ret['name'] ret['dir_id'] = directory_id yield ret if recursive and ret['type'] == 'dir': yield from self._directory_ls( ret['target'], True, prefix + ret['name'] + b'/') def directory_entry_get_by_path(self, directory, paths): return self._directory_entry_get_by_path(directory, paths, b'') def _directory_entry_get_by_path(self, directory, paths, prefix): if not paths: return contents = list(self.directory_ls(directory)) if not contents: return def _get_entry(entries, name): """Finds the entry with the requested name, prepends the prefix (to get its full path), and returns it. If no entry has that name, returns None.""" for entry in entries: if entry['name'] == name: entry = entry.copy() entry['name'] = prefix + entry['name'] return entry first_item = _get_entry(contents, paths[0]) if len(paths) == 1: return first_item if not first_item or first_item['type'] != 'dir': return return self._directory_entry_get_by_path( first_item['target'], paths[1:], prefix + paths[0] + b'/') def directory_ls(self, directory, recursive=False): yield from self._directory_ls(directory, recursive) def directory_get_random(self): return self._cql_runner.directory_get_random().id def revision_add(self, revisions: Iterable[Revision]) -> Dict: revisions = list(revisions) # Filter-out revisions already in the database missing = self.revision_missing([rev.id for rev in revisions]) revisions = [rev for rev in revisions if rev.id in missing] self.journal_writer.revision_add(revisions) for revision in revisions: revision = revision_to_db(revision) if revision: # Add parents first for (rank, parent) in enumerate(revision.parents): self._cql_runner.revision_parent_add_one( revision.id, rank, parent) # Then write the main revision row. # Writing this after all parents were written ensures that # read endpoints don't return a partial view while writing # the parents self._cql_runner.revision_add_one(revision) return {'revision:add': len(revisions)} def revision_missing(self, revisions): return self._cql_runner.revision_missing(revisions) def revision_get(self, revisions): rows = self._cql_runner.revision_get(revisions) revs = {} for row in rows: # TODO: use a single query to get all parents? # (it might have lower latency, but requires more code and more # bandwidth, because revision id would be part of each returned # row) parent_rows = self._cql_runner.revision_parent_get(row.id) # parent_rank is the clustering key, so results are already # sorted by rank. parents = [row.parent_id for row in parent_rows] rev = Revision(**row._asdict(), parents=parents) rev = revision_from_db(rev) revs[rev.id] = rev.to_dict() for rev_id in revisions: yield revs.get(rev_id) def _get_parent_revs(self, rev_ids, seen, limit, short): if limit and len(seen) >= limit: return rev_ids = [id_ for id_ in rev_ids if id_ not in seen] if not rev_ids: return seen |= set(rev_ids) # We need this query, even if short=True, to return consistent # results (ie. not return only a subset of a revision's parents # if it is being written) if short: rows = self._cql_runner.revision_get_ids(rev_ids) else: rows = self._cql_runner.revision_get(rev_ids) for row in rows: # TODO: use a single query to get all parents? # (it might have less latency, but requires less code and more # bandwidth (because revision id would be part of each returned # row) parent_rows = self._cql_runner.revision_parent_get(row.id) # parent_rank is the clustering key, so results are already # sorted by rank. parents = [row.parent_id for row in parent_rows] if short: yield (row.id, parents) else: rev = revision_from_db(Revision( **row._asdict(), parents=parents)) yield rev.to_dict() yield from self._get_parent_revs(parents, seen, limit, short) def revision_log(self, revisions, limit=None): seen = set() yield from self._get_parent_revs(revisions, seen, limit, False) def revision_shortlog(self, revisions, limit=None): seen = set() yield from self._get_parent_revs(revisions, seen, limit, True) def revision_get_random(self): return self._cql_runner.revision_get_random().id def release_add(self, releases: Iterable[Release]) -> Dict: missing = self.release_missing([rel.id for rel in releases]) releases = [rel for rel in releases if rel.id in missing] self.journal_writer.release_add(releases) for release in releases: if release: release = release_to_db(release) self._cql_runner.release_add_one(release) return {'release:add': len(missing)} def release_missing(self, releases): return self._cql_runner.release_missing(releases) def release_get(self, releases): rows = self._cql_runner.release_get(releases) rels = {} for row in rows: release = Release(**row._asdict()) release = release_from_db(release) rels[row.id] = release.to_dict() for rel_id in releases: yield rels.get(rel_id) def release_get_random(self): return self._cql_runner.release_get_random().id def snapshot_add(self, snapshots: Iterable[Snapshot]) -> Dict: missing = self._cql_runner.snapshot_missing( [snp.id for snp in snapshots]) snapshots = [snp for snp in snapshots if snp.id in missing] for snapshot in snapshots: self.journal_writer.snapshot_add(snapshot) # Add branches for (branch_name, branch) in snapshot.branches.items(): if branch is None: target_type = None target = None else: target_type = branch.target_type.value target = branch.target self._cql_runner.snapshot_branch_add_one({ 'snapshot_id': snapshot.id, 'name': branch_name, 'target_type': target_type, 'target': target, }) # Add the snapshot *after* adding all the branches, so someone # calling snapshot_get_branch in the meantime won't end up # with half the branches. self._cql_runner.snapshot_add_one(snapshot.id) return {'snapshot:add': len(snapshots)} def snapshot_missing(self, snapshots): return self._cql_runner.snapshot_missing(snapshots) def snapshot_get(self, snapshot_id): return self.snapshot_get_branches(snapshot_id) def snapshot_get_by_origin_visit(self, origin, visit): try: visit = self._cql_runner.origin_visit_get_one(origin, visit) except IndexError: return None return self.snapshot_get(visit.snapshot) def snapshot_get_latest(self, origin, allowed_statuses=None): visit = self.origin_visit_get_latest( origin, allowed_statuses=allowed_statuses, require_snapshot=True) if visit: assert visit['snapshot'] if self._cql_runner.snapshot_missing([visit['snapshot']]): raise StorageArgumentException( 'Visit references unknown snapshot') return self.snapshot_get_branches(visit['snapshot']) def snapshot_count_branches(self, snapshot_id): if self._cql_runner.snapshot_missing([snapshot_id]): # Makes sure we don't fetch branches for a snapshot that is # being added. return None rows = list(self._cql_runner.snapshot_count_branches(snapshot_id)) assert len(rows) == 1 (nb_none, counts) = rows[0].counts counts = dict(counts) if nb_none: counts[None] = nb_none return counts def snapshot_get_branches(self, snapshot_id, branches_from=b'', branches_count=1000, target_types=None): if self._cql_runner.snapshot_missing([snapshot_id]): # Makes sure we don't fetch branches for a snapshot that is # being added. return None branches = [] while len(branches) < branches_count+1: new_branches = list(self._cql_runner.snapshot_branch_get( snapshot_id, branches_from, branches_count+1)) if not new_branches: break branches_from = new_branches[-1].name new_branches_filtered = new_branches # Filter by target_type if target_types: new_branches_filtered = [ branch for branch in new_branches_filtered if branch.target is not None and branch.target_type in target_types] branches.extend(new_branches_filtered) if len(new_branches) < branches_count+1: break if len(branches) > branches_count: last_branch = branches.pop(-1).name else: last_branch = None branches = { branch.name: { 'target': branch.target, 'target_type': branch.target_type, } if branch.target else None for branch in branches } return { 'id': snapshot_id, 'branches': branches, 'next_branch': last_branch, } def snapshot_get_random(self): return self._cql_runner.snapshot_get_random().id def object_find_by_sha1_git(self, ids): results = {id_: [] for id_ in ids} missing_ids = set(ids) # Mind the order, revision is the most likely one for a given ID, # so we check revisions first. queries = [ ('revision', self._cql_runner.revision_missing), ('release', self._cql_runner.release_missing), ('content', self._cql_runner.content_missing_by_sha1_git), ('directory', self._cql_runner.directory_missing), ] for (object_type, query_fn) in queries: found_ids = missing_ids - set(query_fn(missing_ids)) for sha1_git in found_ids: results[sha1_git].append({ 'sha1_git': sha1_git, 'type': object_type, }) missing_ids.remove(sha1_git) if not missing_ids: # We found everything, skipping the next queries. break return results def origin_get(self, origins): if isinstance(origins, dict): # Old API return_single = True origins = [origins] else: return_single = False if any('id' in origin for origin in origins): raise StorageArgumentException('Origin ids are not supported.') results = [self.origin_get_one(origin) for origin in origins] if return_single: assert len(results) == 1 return results[0] else: return results def origin_get_one(self, origin: Dict[str, Any]) -> Optional[ Dict[str, Any]]: if 'id' in origin: raise StorageArgumentException('Origin ids are not supported.') if 'url' not in origin: raise StorageArgumentException('Missing origin url') rows = self._cql_runner.origin_get_by_url(origin['url']) rows = list(rows) if rows: assert len(rows) == 1 result = rows[0]._asdict() return { 'url': result['url'], } else: return None def origin_get_by_sha1(self, sha1s): results = [] for sha1 in sha1s: rows = self._cql_runner.origin_get_by_sha1(sha1) if rows: results.append({'url': rows.one().url}) else: results.append(None) return results def origin_list(self, page_token: Optional[str] = None, limit: int = 100 ) -> dict: # Compute what token to begin the listing from start_token = TOKEN_BEGIN if page_token: start_token = int(page_token) if not (TOKEN_BEGIN <= start_token <= TOKEN_END): raise StorageArgumentException('Invalid page_token.') rows = self._cql_runner.origin_list(start_token, limit) rows = list(rows) if len(rows) == limit: next_page_token: Optional[str] = str(rows[-1].tok+1) else: next_page_token = None return { 'origins': [{'url': row.url} for row in rows], 'next_page_token': next_page_token, } def origin_search(self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False): # TODO: remove this endpoint, swh-search should be used instead. origins = self._cql_runner.origin_iter_all() if regexp: pat = re.compile(url_pattern) origins = [orig for orig in origins if pat.search(orig.url)] else: origins = [orig for orig in origins if url_pattern in orig.url] if with_visit: origins = [orig for orig in origins if orig.next_visit_id > 1] return [ { 'url': orig.url, } for orig in origins[offset:offset+limit]] def origin_add(self, origins: Iterable[Origin]) -> List[Dict]: results = [] for origin in origins: self.origin_add_one(origin) results.append(origin.to_dict()) return results def origin_add_one(self, origin: Origin) -> str: known_origin = self.origin_get_one(origin.to_dict()) if known_origin: origin_url = known_origin['url'] else: self.journal_writer.origin_add_one(origin) self._cql_runner.origin_add_one(origin) origin_url = origin.url return origin_url def origin_visit_add(self, origin_url: str, date: Union[str, datetime.datetime], type: str) -> OriginVisit: if isinstance(date, str): # FIXME: Converge on iso8601 at some point date = dateutil.parser.parse(date) elif not isinstance(date, datetime.datetime): raise StorageArgumentException( 'Date must be a datetime or a string') if not self.origin_get_one({'url': origin_url}): raise StorageArgumentException( 'Unknown origin %s', origin_url) visit_id = self._cql_runner.origin_generate_unique_visit_id(origin_url) - try: + with convert_validation_exceptions(): visit = OriginVisit.from_dict({ 'origin': origin_url, 'date': date, 'type': type, 'status': 'ongoing', 'snapshot': None, 'metadata': None, 'visit': visit_id }) - except (KeyError, TypeError, ValueError) as e: - raise StorageArgumentException(*e.args) + self.journal_writer.origin_visit_add(visit) self._cql_runner.origin_visit_add_one(visit) return visit def origin_visit_update( self, origin: str, visit_id: int, status: str, metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None, date: Optional[datetime.datetime] = None): origin_url = origin # TODO: rename the argument # Get the existing data of the visit row = self._cql_runner.origin_visit_get_one(origin_url, visit_id) if not row: raise StorageArgumentException('This origin visit does not exist.') - try: + with convert_validation_exceptions(): visit = OriginVisit.from_dict(self._format_origin_visit_row(row)) - except (KeyError, TypeError, ValueError) as e: - raise StorageArgumentException(*e.args) updates: Dict[str, Any] = { 'status': status } if metadata: updates['metadata'] = metadata if snapshot: updates['snapshot'] = snapshot - try: + with convert_validation_exceptions(): visit = attr.evolve(visit, **updates) - except (KeyError, TypeError, ValueError) as e: - raise StorageArgumentException(*e.args) self.journal_writer.origin_visit_update(visit) self._cql_runner.origin_visit_update(origin_url, visit_id, updates) def origin_visit_upsert(self, visits: Iterable[OriginVisit]) -> None: self.journal_writer.origin_visit_upsert(visits) for visit in visits: self._cql_runner.origin_visit_upsert(visit) @staticmethod def _format_origin_visit_row(visit): return { **visit._asdict(), 'origin': visit.origin, 'date': visit.date.replace(tzinfo=datetime.timezone.utc), 'metadata': (json.loads(visit.metadata) if visit.metadata else None), } def origin_visit_get(self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None): rows = self._cql_runner.origin_visit_get(origin, last_visit, limit) yield from map(self._format_origin_visit_row, rows) def origin_visit_find_by_date(self, origin, visit_date): # Iterator over all the visits of the origin # This should be ok for now, as there aren't too many visits # per origin. visits = list(self._cql_runner.origin_visit_get_all(origin)) def key(visit): dt = visit.date.replace(tzinfo=datetime.timezone.utc) - visit_date return (abs(dt), -visit.visit) if visits: visit = min(visits, key=key) return visit._asdict() def origin_visit_get_by(self, origin, visit): visit = self._cql_runner.origin_visit_get_one(origin, visit) if visit: return self._format_origin_visit_row(visit) def origin_visit_get_latest( self, origin, allowed_statuses=None, require_snapshot=False): visit = self._cql_runner.origin_visit_get_latest( origin, allowed_statuses=allowed_statuses, require_snapshot=require_snapshot) if visit: return self._format_origin_visit_row(visit) def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]: back_in_the_day = now() - datetime.timedelta(weeks=12) # 3 months back # Random position to start iteration at start_token = random.randint(TOKEN_BEGIN, TOKEN_END) # Iterator over all visits, ordered by token(origins) then visit_id rows = self._cql_runner.origin_visit_iter(start_token) for row in rows: visit = self._format_origin_visit_row(row) if visit['date'] > back_in_the_day \ and visit['status'] == 'full': return visit else: return None def tool_add(self, tools): inserted = [] for tool in tools: tool = tool.copy() tool_json = tool.copy() tool_json['configuration'] = json.dumps( tool['configuration'], sort_keys=True).encode() id_ = self._cql_runner.tool_get_one_uuid(**tool_json) if not id_: id_ = uuid.uuid1() tool_json['id'] = id_ self._cql_runner.tool_by_uuid_add_one(tool_json) self._cql_runner.tool_add_one(tool_json) tool['id'] = id_ inserted.append(tool) return inserted def tool_get(self, tool): id_ = self._cql_runner.tool_get_one_uuid( tool['name'], tool['version'], json.dumps(tool['configuration'], sort_keys=True).encode()) if id_: tool = tool.copy() tool['id'] = id_ return tool else: return None def stat_counters(self): rows = self._cql_runner.stat_counters() keys = ( 'content', 'directory', 'origin', 'origin_visit', 'release', 'revision', 'skipped_content', 'snapshot') stats = {key: 0 for key in keys} stats.update({row.object_type: row.count for row in rows}) return stats def refresh_stat_counters(self): pass def origin_metadata_add(self, origin_url, ts, provider, tool, metadata): # TODO raise NotImplementedError('not yet supported for Cassandra') def origin_metadata_get_by(self, origin_url, provider_type=None): # TODO raise NotImplementedError('not yet supported for Cassandra') def metadata_provider_add(self, provider_name, provider_type, provider_url, metadata): # TODO raise NotImplementedError('not yet supported for Cassandra') def metadata_provider_get(self, provider_id): # TODO raise NotImplementedError('not yet supported for Cassandra') def metadata_provider_get_by(self, provider): # TODO raise NotImplementedError('not yet supported for Cassandra') diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py index c4857dbc..ac496c18 100644 --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -1,998 +1,995 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import re import bisect import dateutil import collections import copy import datetime import itertools import random from collections import defaultdict from datetime import timedelta from typing import Any, Dict, Iterable, List, Optional, Union import attr from swh.model.model import ( BaseContent, Content, SkippedContent, Directory, Revision, Release, Snapshot, OriginVisit, Origin, SHA1_SIZE ) from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex from swh.storage.objstorage import ObjStorage +from swh.storage.validate import convert_validation_exceptions from swh.storage.utils import now from .exc import StorageArgumentException, HashCollision from .converters import origin_url_to_sha1 from .utils import get_partition_bounds_bytes from .writer import JournalWriter # Max block size of contents to return BULK_BLOCK_CONTENT_LEN_MAX = 10000 class InMemoryStorage: def __init__(self, journal_writer=None): self.reset() self.journal_writer = JournalWriter(journal_writer) def reset(self): self._contents = {} self._content_indexes = defaultdict(lambda: defaultdict(set)) self._skipped_contents = {} self._skipped_content_indexes = defaultdict(lambda: defaultdict(set)) self._directories = {} self._revisions = {} self._releases = {} self._snapshots = {} self._origins = {} self._origins_by_id = [] self._origins_by_sha1 = {} self._origin_visits = {} self._persons = [] self._origin_metadata = defaultdict(list) self._tools = {} self._metadata_providers = {} self._objects = defaultdict(list) # ideally we would want a skip list for both fast inserts and searches self._sorted_sha1s = [] self.objstorage = ObjStorage({'cls': 'memory', 'args': {}}) def check_config(self, *, check_write): return True def _content_add( self, contents: Iterable[Content], with_data: bool) -> Dict: self.journal_writer.content_add(contents) content_add = 0 content_add_bytes = 0 if with_data: summary = self.objstorage.content_add( c for c in contents if c.status != 'absent') content_add_bytes = summary['content:add:bytes'] for content in contents: key = self._content_key(content) if key in self._contents: continue for algorithm in DEFAULT_ALGORITHMS: hash_ = content.get_hash(algorithm) if hash_ in self._content_indexes[algorithm]\ and (algorithm not in {'blake2s256', 'sha256'}): colliding_content_hashes = [] # Add the already stored contents for content_hashes_set in self._content_indexes[ algorithm][hash_]: hashes = dict(content_hashes_set) colliding_content_hashes.append(hashes) # Add the new colliding content colliding_content_hashes.append(content.hashes()) raise HashCollision( algorithm, hash_, colliding_content_hashes) for algorithm in DEFAULT_ALGORITHMS: hash_ = content.get_hash(algorithm) self._content_indexes[algorithm][hash_].add(key) self._objects[content.sha1_git].append( ('content', content.sha1)) self._contents[key] = content bisect.insort(self._sorted_sha1s, content.sha1) self._contents[key] = attr.evolve( self._contents[key], data=None) content_add += 1 summary = { 'content:add': content_add, } if with_data: summary['content:add:bytes'] = content_add_bytes return summary def content_add(self, content: Iterable[Content]) -> Dict: content = [attr.evolve(c, ctime=now()) for c in content] return self._content_add(content, with_data=True) def content_update(self, content, keys=[]): self.journal_writer.content_update(content) for cont_update in content: cont_update = cont_update.copy() sha1 = cont_update.pop('sha1') for old_key in self._content_indexes['sha1'][sha1]: old_cont = self._contents.pop(old_key) for algorithm in DEFAULT_ALGORITHMS: hash_ = old_cont.get_hash(algorithm) self._content_indexes[algorithm][hash_].remove(old_key) new_cont = attr.evolve(old_cont, **cont_update) new_key = self._content_key(new_cont) self._contents[new_key] = new_cont for algorithm in DEFAULT_ALGORITHMS: hash_ = new_cont.get_hash(algorithm) self._content_indexes[algorithm][hash_].add(new_key) def content_add_metadata(self, content: Iterable[Content]) -> Dict: return self._content_add(content, with_data=False) def content_get(self, content): # FIXME: Make this method support slicing the `data`. if len(content) > BULK_BLOCK_CONTENT_LEN_MAX: raise StorageArgumentException( "Sending at most %s contents." % BULK_BLOCK_CONTENT_LEN_MAX) yield from self.objstorage.content_get(content) def content_get_range(self, start, end, limit=1000): if limit is None: raise StorageArgumentException('limit should not be None') from_index = bisect.bisect_left(self._sorted_sha1s, start) sha1s = itertools.islice(self._sorted_sha1s, from_index, None) sha1s = ((sha1, content_key) for sha1 in sha1s for content_key in self._content_indexes['sha1'][sha1]) matched = [] next_content = None for sha1, key in sha1s: if sha1 > end: break if len(matched) >= limit: next_content = sha1 break matched.append(self._contents[key].to_dict()) return { 'contents': matched, 'next': next_content, } def content_get_partition( self, partition_id: int, nb_partitions: int, limit: int = 1000, page_token: str = None): if limit is None: raise StorageArgumentException('limit should not be None') (start, end) = get_partition_bounds_bytes( partition_id, nb_partitions, SHA1_SIZE) if page_token: start = hash_to_bytes(page_token) if end is None: end = b'\xff'*SHA1_SIZE result = self.content_get_range(start, end, limit) result2 = { 'contents': result['contents'], 'next_page_token': None, } if result['next']: result2['next_page_token'] = hash_to_hex(result['next']) return result2 def content_get_metadata( self, contents: List[bytes]) -> Dict[bytes, List[Dict]]: result: Dict = {sha1: [] for sha1 in contents} for sha1 in contents: if sha1 in self._content_indexes['sha1']: objs = self._content_indexes['sha1'][sha1] # only 1 element as content_add_metadata would have raised a # hash collision otherwise for key in objs: d = self._contents[key].to_dict() del d['ctime'] if 'data' in d: del d['data'] result[sha1].append(d) return result def content_find(self, content): if not set(content).intersection(DEFAULT_ALGORITHMS): raise StorageArgumentException( 'content keys must contain at least one of: %s' % ', '.join(sorted(DEFAULT_ALGORITHMS))) found = [] for algo in DEFAULT_ALGORITHMS: hash = content.get(algo) if hash and hash in self._content_indexes[algo]: found.append(self._content_indexes[algo][hash]) if not found: return [] keys = list(set.intersection(*found)) return [self._contents[key].to_dict() for key in keys] def content_missing(self, content, key_hash='sha1'): for cont in content: for (algo, hash_) in cont.items(): if algo not in DEFAULT_ALGORITHMS: continue if hash_ not in self._content_indexes.get(algo, []): yield cont[key_hash] break else: for result in self.content_find(cont): if result['status'] == 'missing': yield cont[key_hash] def content_missing_per_sha1(self, contents): for content in contents: if content not in self._content_indexes['sha1']: yield content def content_missing_per_sha1_git(self, contents): for content in contents: if content not in self._content_indexes['sha1_git']: yield content def content_get_random(self): return random.choice(list(self._content_indexes['sha1_git'])) def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict: self.journal_writer.skipped_content_add(contents) summary = { 'skipped_content:add': 0 } missing_contents = self.skipped_content_missing( [c.hashes() for c in contents]) missing = {self._content_key(c) for c in missing_contents} contents = [c for c in contents if self._content_key(c) in missing] for content in contents: key = self._content_key(content) for algo in DEFAULT_ALGORITHMS: if content.get_hash(algo): self._skipped_content_indexes[algo][ content.get_hash(algo)].add(key) self._skipped_contents[key] = content summary['skipped_content:add'] += 1 return summary def skipped_content_add(self, content: Iterable[SkippedContent]) -> Dict: content = [attr.evolve(c, ctime=now()) for c in content] return self._skipped_content_add(content) def skipped_content_missing(self, contents): for content in contents: matches = list(self._skipped_contents.values()) for (algorithm, key) in self._content_key(content): if algorithm == 'blake2s256': continue # Filter out skipped contents with the same hash matches = [ match for match in matches if match.get_hash(algorithm) == key] # if none of the contents match if not matches: yield {algo: content[algo] for algo in DEFAULT_ALGORITHMS} def directory_add(self, directories: Iterable[Directory]) -> Dict: directories = [dir_ for dir_ in directories if dir_.id not in self._directories] self.journal_writer.directory_add(directories) count = 0 for directory in directories: count += 1 self._directories[directory.id] = directory self._objects[directory.id].append( ('directory', directory.id)) return {'directory:add': count} def directory_missing(self, directories): for id in directories: if id not in self._directories: yield id def _join_dentry_to_content(self, dentry): keys = ( 'status', 'sha1', 'sha1_git', 'sha256', 'length', ) ret = dict.fromkeys(keys) ret.update(dentry) if ret['type'] == 'file': # TODO: Make it able to handle more than one content content = self.content_find({'sha1_git': ret['target']}) if content: content = content[0] for key in keys: ret[key] = content[key] return ret def _directory_ls(self, directory_id, recursive, prefix=b''): if directory_id in self._directories: for entry in self._directories[directory_id].entries: ret = self._join_dentry_to_content(entry.to_dict()) ret['name'] = prefix + ret['name'] ret['dir_id'] = directory_id yield ret if recursive and ret['type'] == 'dir': yield from self._directory_ls( ret['target'], True, prefix + ret['name'] + b'/') def directory_ls(self, directory, recursive=False): yield from self._directory_ls(directory, recursive) def directory_entry_get_by_path(self, directory, paths): return self._directory_entry_get_by_path(directory, paths, b'') def directory_get_random(self): if not self._directories: return None return random.choice(list(self._directories)) def _directory_entry_get_by_path(self, directory, paths, prefix): if not paths: return contents = list(self.directory_ls(directory)) if not contents: return def _get_entry(entries, name): for entry in entries: if entry['name'] == name: entry = entry.copy() entry['name'] = prefix + entry['name'] return entry first_item = _get_entry(contents, paths[0]) if len(paths) == 1: return first_item if not first_item or first_item['type'] != 'dir': return return self._directory_entry_get_by_path( first_item['target'], paths[1:], prefix + paths[0] + b'/') def revision_add(self, revisions: Iterable[Revision]) -> Dict: revisions = [rev for rev in revisions if rev.id not in self._revisions] self.journal_writer.revision_add(revisions) count = 0 for revision in revisions: revision = attr.evolve( revision, committer=self._person_add(revision.committer), author=self._person_add(revision.author)) self._revisions[revision.id] = revision self._objects[revision.id].append( ('revision', revision.id)) count += 1 return {'revision:add': count} def revision_missing(self, revisions): for id in revisions: if id not in self._revisions: yield id def revision_get(self, revisions): for id in revisions: if id in self._revisions: yield self._revisions.get(id).to_dict() else: yield None def _get_parent_revs(self, rev_id, seen, limit): if limit and len(seen) >= limit: return if rev_id in seen or rev_id not in self._revisions: return seen.add(rev_id) yield self._revisions[rev_id].to_dict() for parent in self._revisions[rev_id].parents: yield from self._get_parent_revs(parent, seen, limit) def revision_log(self, revisions, limit=None): seen = set() for rev_id in revisions: yield from self._get_parent_revs(rev_id, seen, limit) def revision_shortlog(self, revisions, limit=None): yield from ((rev['id'], rev['parents']) for rev in self.revision_log(revisions, limit)) def revision_get_random(self): return random.choice(list(self._revisions)) def release_add(self, releases: Iterable[Release]) -> Dict: releases = [rel for rel in releases if rel.id not in self._releases] self.journal_writer.release_add(releases) count = 0 for rel in releases: if rel.author: self._person_add(rel.author) self._objects[rel.id].append( ('release', rel.id)) self._releases[rel.id] = rel count += 1 return {'release:add': count} def release_missing(self, releases): yield from (rel for rel in releases if rel not in self._releases) def release_get(self, releases): for rel_id in releases: if rel_id in self._releases: yield self._releases[rel_id].to_dict() else: yield None def release_get_random(self): return random.choice(list(self._releases)) def snapshot_add(self, snapshots: Iterable[Snapshot]) -> Dict: count = 0 snapshots = (snap for snap in snapshots if snap.id not in self._snapshots) for snapshot in snapshots: self.journal_writer.snapshot_add(snapshot) sorted_branch_names = sorted(snapshot.branches) self._snapshots[snapshot.id] = (snapshot, sorted_branch_names) self._objects[snapshot.id].append(('snapshot', snapshot.id)) count += 1 return {'snapshot:add': count} def snapshot_missing(self, snapshots): for id in snapshots: if id not in self._snapshots: yield id def snapshot_get(self, snapshot_id): return self.snapshot_get_branches(snapshot_id) def snapshot_get_by_origin_visit(self, origin, visit): origin_url = self._get_origin_url(origin) if not origin_url: return if origin_url not in self._origins or \ visit > len(self._origin_visits[origin_url]): return None snapshot_id = self._origin_visits[origin_url][visit-1].snapshot if snapshot_id: return self.snapshot_get(snapshot_id) else: return None def snapshot_get_latest(self, origin, allowed_statuses=None): origin_url = self._get_origin_url(origin) if not origin_url: return visit = self.origin_visit_get_latest( origin_url, allowed_statuses=allowed_statuses, require_snapshot=True) if visit and visit['snapshot']: snapshot = self.snapshot_get(visit['snapshot']) if not snapshot: raise StorageArgumentException( 'last origin visit references an unknown snapshot') return snapshot def snapshot_count_branches(self, snapshot_id): (snapshot, _) = self._snapshots[snapshot_id] return collections.Counter(branch.target_type.value if branch else None for branch in snapshot.branches.values()) def snapshot_get_branches(self, snapshot_id, branches_from=b'', branches_count=1000, target_types=None): res = self._snapshots.get(snapshot_id) if res is None: return None (snapshot, sorted_branch_names) = res from_index = bisect.bisect_left( sorted_branch_names, branches_from) if target_types: next_branch = None branches = {} for branch_name in sorted_branch_names[from_index:]: branch = snapshot.branches[branch_name] if branch and branch.target_type.value in target_types: if len(branches) < branches_count: branches[branch_name] = branch else: next_branch = branch_name break else: # As there is no 'target_types', we can do that much faster to_index = from_index + branches_count returned_branch_names = sorted_branch_names[from_index:to_index] branches = {branch_name: snapshot.branches[branch_name] for branch_name in returned_branch_names} if to_index >= len(sorted_branch_names): next_branch = None else: next_branch = sorted_branch_names[to_index] branches = {name: branch.to_dict() if branch else None for (name, branch) in branches.items()} return { 'id': snapshot_id, 'branches': branches, 'next_branch': next_branch, } def snapshot_get_random(self): return random.choice(list(self._snapshots)) def object_find_by_sha1_git(self, ids): ret = {} for id_ in ids: objs = self._objects.get(id_, []) ret[id_] = [{ 'sha1_git': id_, 'type': obj[0], } for obj in objs] return ret def _convert_origin(self, t): if t is None: return None return t.to_dict() def origin_get(self, origins): if isinstance(origins, dict): # Old API return_single = True origins = [origins] else: return_single = False # Sanity check to be error-compatible with the pgsql backend if any('id' in origin for origin in origins) \ and not all('id' in origin for origin in origins): raise StorageArgumentException( 'Either all origins or none at all should have an "id".') if any('url' in origin for origin in origins) \ and not all('url' in origin for origin in origins): raise StorageArgumentException( 'Either all origins or none at all should have ' 'an "url" key.') results = [] for origin in origins: result = None if 'url' in origin: if origin['url'] in self._origins: result = self._origins[origin['url']] else: raise StorageArgumentException( 'Origin must have an url.') results.append(self._convert_origin(result)) if return_single: assert len(results) == 1 return results[0] else: return results def origin_get_by_sha1(self, sha1s): return [ self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s ] def origin_get_range(self, origin_from=1, origin_count=100): origin_from = max(origin_from, 1) if origin_from <= len(self._origins_by_id): max_idx = origin_from + origin_count - 1 if max_idx > len(self._origins_by_id): max_idx = len(self._origins_by_id) for idx in range(origin_from-1, max_idx): origin = self._convert_origin( self._origins[self._origins_by_id[idx]]) yield {'id': idx+1, **origin} def origin_list(self, page_token: Optional[str] = None, limit: int = 100 ) -> dict: origin_urls = sorted(self._origins) if page_token: from_ = bisect.bisect_left(origin_urls, page_token) else: from_ = 0 result = { 'origins': [{'url': origin_url} for origin_url in origin_urls[from_:from_+limit]] } if from_+limit < len(origin_urls): result['next_page_token'] = origin_urls[from_+limit] return result def origin_search(self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False): origins = map(self._convert_origin, self._origins.values()) if regexp: pat = re.compile(url_pattern) origins = [orig for orig in origins if pat.search(orig['url'])] else: origins = [orig for orig in origins if url_pattern in orig['url']] if with_visit: origins = [ orig for orig in origins if len(self._origin_visits[orig['url']]) > 0 and set(ov.snapshot for ov in self._origin_visits[orig['url']] if ov.snapshot) & set(self._snapshots)] return origins[offset:offset+limit] def origin_count(self, url_pattern, regexp=False, with_visit=False): return len(self.origin_search(url_pattern, regexp=regexp, with_visit=with_visit, limit=len(self._origins))) def origin_add(self, origins: Iterable[Origin]) -> List[Dict]: origins = copy.deepcopy(list(origins)) for origin in origins: self.origin_add_one(origin) return [origin.to_dict() for origin in origins] def origin_add_one(self, origin: Origin) -> str: if origin.url not in self._origins: self.journal_writer.origin_add_one(origin) # generate an origin_id because it is needed by origin_get_range. # TODO: remove this when we remove origin_get_range origin_id = len(self._origins) + 1 self._origins_by_id.append(origin.url) assert len(self._origins_by_id) == origin_id self._origins[origin.url] = origin self._origins_by_sha1[origin_url_to_sha1(origin.url)] = origin self._origin_visits[origin.url] = [] self._objects[origin.url].append(('origin', origin.url)) return origin.url def origin_visit_add(self, origin_url: str, date: Union[str, datetime.datetime], type: str) -> OriginVisit: if isinstance(date, str): # FIXME: Converge on iso8601 at some point date = dateutil.parser.parse(date) elif not isinstance(date, datetime.datetime): raise StorageArgumentException( 'Date must be a datetime or a string') origin = self.origin_get({'url': origin_url}) if not origin: # Cannot add a visit without an origin raise StorageArgumentException( 'Unknown origin %s', origin_url) if origin_url in self._origins: origin = self._origins[origin_url] # visit ids are in the range [1, +inf[ visit_id = len(self._origin_visits[origin_url]) + 1 status = 'ongoing' visit = OriginVisit( origin=origin_url, date=date, type=type, status=status, snapshot=None, metadata=None, visit=visit_id, ) self._origin_visits[origin_url].append(visit) visit = visit self._objects[(origin_url, visit.visit)].append( ('origin_visit', None)) self.journal_writer.origin_visit_add(visit) # return last visit return visit def origin_visit_update( self, origin: str, visit_id: int, status: str, metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None, date: Optional[datetime.datetime] = None): origin_url = self._get_origin_url(origin) if origin_url is None: raise StorageArgumentException('Unknown origin.') try: visit = self._origin_visits[origin_url][visit_id-1] except IndexError: raise StorageArgumentException( 'Unknown visit_id for this origin') from None updates: Dict[str, Any] = { 'status': status } if metadata: updates['metadata'] = metadata if snapshot: updates['snapshot'] = snapshot - try: + with convert_validation_exceptions(): visit = attr.evolve(visit, **updates) - except (KeyError, TypeError, ValueError) as e: - raise StorageArgumentException(*e.args) self.journal_writer.origin_visit_update(visit) self._origin_visits[origin_url][visit_id-1] = visit def origin_visit_upsert(self, visits: Iterable[OriginVisit]) -> None: self.journal_writer.origin_visit_upsert(visits) for visit in visits: visit_id = visit.visit origin_url = visit.origin - try: + with convert_validation_exceptions(): visit = attr.evolve(visit, origin=origin_url) - except (KeyError, TypeError, ValueError) as e: - raise StorageArgumentException(*e.args) self._objects[(origin_url, visit_id)].append( ('origin_visit', None)) if visit_id: while len(self._origin_visits[origin_url]) <= visit_id: self._origin_visits[origin_url].append(None) self._origin_visits[origin_url][visit_id-1] = visit def _convert_visit(self, visit): if visit is None: return visit = visit.to_dict() return visit def origin_visit_get(self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None): origin_url = self._get_origin_url(origin) if origin_url in self._origin_visits: visits = self._origin_visits[origin_url] if last_visit is not None: visits = visits[last_visit:] if limit is not None: visits = visits[:limit] for visit in visits: if not visit: continue visit_id = visit.visit yield self._convert_visit( self._origin_visits[origin_url][visit_id-1]) def origin_visit_find_by_date(self, origin, visit_date): origin_url = self._get_origin_url(origin) if origin_url in self._origin_visits: visits = self._origin_visits[origin_url] visit = min( visits, key=lambda v: (abs(v.date - visit_date), -v.visit)) return self._convert_visit(visit) def origin_visit_get_by(self, origin, visit): origin_url = self._get_origin_url(origin) if origin_url in self._origin_visits and \ visit <= len(self._origin_visits[origin_url]): return self._convert_visit( self._origin_visits[origin_url][visit-1]) def origin_visit_get_latest( self, origin, allowed_statuses=None, require_snapshot=False): origin = self._origins.get(origin) if not origin: return visits = self._origin_visits[origin.url] if allowed_statuses is not None: visits = [visit for visit in visits if visit.status in allowed_statuses] if require_snapshot: visits = [visit for visit in visits if visit.snapshot] visit = max( visits, key=lambda v: (v.date, v.visit), default=None) return self._convert_visit(visit) def _select_random_origin_visit_by_type(self, type: str) -> str: while True: url = random.choice(list(self._origin_visits.keys())) random_origin_visits = self._origin_visits[url] if random_origin_visits[0].type == type: return url def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]: url = self._select_random_origin_visit_by_type(type) random_origin_visits = copy.deepcopy(self._origin_visits[url]) random_origin_visits.reverse() back_in_the_day = now() - timedelta(weeks=12) # 3 months back # This should be enough for tests for visit in random_origin_visits: if visit.date > back_in_the_day and visit.status == 'full': return visit.to_dict() else: return None def stat_counters(self): keys = ( 'content', 'directory', 'origin', 'origin_visit', 'person', 'release', 'revision', 'skipped_content', 'snapshot' ) stats = {key: 0 for key in keys} stats.update(collections.Counter( obj_type for (obj_type, obj_id) in itertools.chain(*self._objects.values()))) return stats def refresh_stat_counters(self): pass def origin_metadata_add(self, origin_url, ts, provider, tool, metadata): if not isinstance(origin_url, str): raise TypeError('origin_id must be str, not %r' % (origin_url,)) if isinstance(ts, str): ts = dateutil.parser.parse(ts) origin_metadata = { 'origin_url': origin_url, 'discovery_date': ts, 'tool_id': tool, 'metadata': metadata, 'provider_id': provider, } self._origin_metadata[origin_url].append(origin_metadata) return None def origin_metadata_get_by(self, origin_url, provider_type=None): if not isinstance(origin_url, str): raise TypeError('origin_url must be str, not %r' % (origin_url,)) metadata = [] for item in self._origin_metadata[origin_url]: item = copy.deepcopy(item) provider = self.metadata_provider_get(item['provider_id']) for attr_name in ('name', 'type', 'url'): item['provider_' + attr_name] = \ provider['provider_' + attr_name] metadata.append(item) return metadata def tool_add(self, tools): inserted = [] for tool in tools: key = self._tool_key(tool) assert 'id' not in tool record = copy.deepcopy(tool) record['id'] = key # TODO: remove this if key not in self._tools: self._tools[key] = record inserted.append(copy.deepcopy(self._tools[key])) return inserted def tool_get(self, tool): return self._tools.get(self._tool_key(tool)) def metadata_provider_add(self, provider_name, provider_type, provider_url, metadata): provider = { 'provider_name': provider_name, 'provider_type': provider_type, 'provider_url': provider_url, 'metadata': metadata, } key = self._metadata_provider_key(provider) provider['id'] = key self._metadata_providers[key] = provider return key def metadata_provider_get(self, provider_id): return self._metadata_providers.get(provider_id) def metadata_provider_get_by(self, provider): key = self._metadata_provider_key(provider) return self._metadata_providers.get(key) def _get_origin_url(self, origin): if isinstance(origin, str): return origin else: raise TypeError('origin must be a string.') def _person_add(self, person): key = ('person', person.fullname) if key not in self._objects: person_id = len(self._persons) + 1 self._persons.append(person) self._objects[key].append(('person', person_id)) else: person_id = self._objects[key][0][1] person = self._persons[person_id-1] return person @staticmethod def _content_key(content): """ A stable key and the algorithm for a content""" if isinstance(content, BaseContent): content = content.to_dict() return tuple((key, content.get(key)) for key in sorted(DEFAULT_ALGORITHMS)) @staticmethod def _tool_key(tool): return '%r %r %r' % (tool['name'], tool['version'], tuple(sorted(tool['configuration'].items()))) @staticmethod def _metadata_provider_key(provider): return '%r %r' % (provider['provider_name'], provider['provider_url']) def diff_directories(self, from_dir, to_dir, track_renaming=False): raise NotImplementedError('InMemoryStorage.diff_directories') def diff_revisions(self, from_rev, to_rev, track_renaming=False): raise NotImplementedError('InMemoryStorage.diff_revisions') def diff_revision(self, revision, track_renaming=False): raise NotImplementedError('InMemoryStorage.diff_revision') diff --git a/swh/storage/storage.py b/swh/storage/storage.py index 78d0d5b0..e4884964 100644 --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -1,1141 +1,1140 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import contextlib import datetime import itertools import json from collections import defaultdict from contextlib import contextmanager from typing import Any, Dict, Iterable, List, Optional, Union import attr import dateutil.parser import psycopg2 import psycopg2.pool import psycopg2.errors from swh.model.model import ( Content, Directory, Origin, OriginVisit, Revision, Release, SkippedContent, Snapshot, SHA1_SIZE ) from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex from swh.storage.objstorage import ObjStorage +from swh.storage.validate import VALIDATION_EXCEPTIONS from swh.storage.utils import now from . import converters from .common import db_transaction_generator, db_transaction from .db import Db from .exc import StorageArgumentException, StorageDBError, HashCollision from .algos import diff from .metrics import timed, send_metric, process_metrics from .utils import ( get_partition_bounds_bytes, extract_collision_hash ) from .writer import JournalWriter # Max block size of contents to return BULK_BLOCK_CONTENT_LEN_MAX = 10000 EMPTY_SNAPSHOT_ID = hash_to_bytes('1a8893e6a86f444e8be8e7bda6cb34fb1735a00e') """Identifier for the empty snapshot""" -VALIDATION_EXCEPTIONS = ( +VALIDATION_EXCEPTIONS = VALIDATION_EXCEPTIONS + [ psycopg2.errors.CheckViolation, psycopg2.errors.IntegrityError, psycopg2.errors.InvalidTextRepresentation, psycopg2.errors.NotNullViolation, psycopg2.errors.NumericValueOutOfRange, psycopg2.errors.UndefinedFunction, # (raised on wrong argument typs) -) +] """Exceptions raised by postgresql when validation of the arguments failed.""" @contextlib.contextmanager def convert_validation_exceptions(): """Catches postgresql errors related to invalid arguments, and re-raises a StorageArgumentException.""" try: yield - except VALIDATION_EXCEPTIONS as e: - raise StorageArgumentException(*e.args) + except tuple(VALIDATION_EXCEPTIONS) as e: + raise StorageArgumentException(str(e)) class Storage(): """SWH storage proxy, encompassing DB and object storage """ def __init__(self, db, objstorage, min_pool_conns=1, max_pool_conns=10, journal_writer=None): """ Args: db_conn: either a libpq connection string, or a psycopg2 connection obj_root: path to the root of the object storage """ try: if isinstance(db, psycopg2.extensions.connection): self._pool = None self._db = Db(db) else: self._pool = psycopg2.pool.ThreadedConnectionPool( min_pool_conns, max_pool_conns, db ) self._db = None except psycopg2.OperationalError as e: raise StorageDBError(e) self.journal_writer = JournalWriter(journal_writer) self.objstorage = ObjStorage(objstorage) def get_db(self): if self._db: return self._db else: return Db.from_pool(self._pool) def put_db(self, db): if db is not self._db: db.put_conn() @contextmanager def db(self): db = None try: db = self.get_db() yield db finally: if db: self.put_db(db) @timed @db_transaction() def check_config(self, *, check_write, db=None, cur=None): if not self.objstorage.check_config(check_write=check_write): return False # Check permissions on one of the tables if check_write: check = 'INSERT' else: check = 'SELECT' cur.execute( "select has_table_privilege(current_user, 'content', %s)", (check,) ) return cur.fetchone()[0] def _content_unique_key(self, hash, db): """Given a hash (tuple or dict), return a unique key from the aggregation of keys. """ keys = db.content_hash_keys if isinstance(hash, tuple): return hash return tuple([hash[k] for k in keys]) def _content_add_metadata(self, db, cur, content): """Add content to the postgresql database but not the object storage. """ # create temporary table for metadata injection db.mktemp('content', cur) db.copy_to((c.to_dict() for c in content), 'tmp_content', db.content_add_keys, cur) # move metadata in place try: db.content_add_from_temp(cur) except psycopg2.IntegrityError as e: if e.diag.sqlstate == '23505' and \ e.diag.table_name == 'content': message_detail = e.diag.message_detail if message_detail: hash_name, hash_id = extract_collision_hash(message_detail) collision_contents_hashes = [ c.hashes() for c in content if c.get_hash(hash_name) == hash_id ] else: constraint_to_hash_name = { 'content_pkey': 'sha1', 'content_sha1_git_idx': 'sha1_git', 'content_sha256_idx': 'sha256', } hash_name = constraint_to_hash_name \ .get(e.diag.constraint_name) hash_id = None collision_contents_hashes = None raise HashCollision( hash_name, hash_id, collision_contents_hashes ) from None else: raise @timed @process_metrics def content_add( self, content: Iterable[Content]) -> Dict: contents = [attr.evolve(c, ctime=now()) for c in content] objstorage_summary = self.objstorage.content_add(contents) with self.db() as db: with db.transaction() as cur: missing = list(self.content_missing( map(Content.to_dict, contents), key_hash='sha1_git', db=db, cur=cur, )) contents = [c for c in contents if c.sha1_git in missing] self.journal_writer.content_add(contents) self._content_add_metadata(db, cur, contents) return { 'content:add': len(contents), 'content:add:bytes': objstorage_summary['content:add:bytes'], } @timed @db_transaction() def content_update(self, content, keys=[], db=None, cur=None): # TODO: Add a check on input keys. How to properly implement # this? We don't know yet the new columns. self.journal_writer.content_update(content) db.mktemp('content', cur) select_keys = list(set(db.content_get_metadata_keys).union(set(keys))) with convert_validation_exceptions(): db.copy_to(content, 'tmp_content', select_keys, cur) db.content_update_from_temp(keys_to_update=keys, cur=cur) @timed @process_metrics @db_transaction() def content_add_metadata(self, content: Iterable[Content], db=None, cur=None) -> Dict: contents = list(content) missing = self.content_missing( (c.to_dict() for c in contents), key_hash='sha1_git', db=db, cur=cur, ) contents = [c for c in contents if c.sha1_git in missing] self.journal_writer.content_add_metadata(contents) self._content_add_metadata(db, cur, contents) return { 'content:add': len(contents), } @timed def content_get(self, content): # FIXME: Make this method support slicing the `data`. if len(content) > BULK_BLOCK_CONTENT_LEN_MAX: raise StorageArgumentException( "Send at maximum %s contents." % BULK_BLOCK_CONTENT_LEN_MAX) yield from self.objstorage.content_get(content) @timed @db_transaction() def content_get_range(self, start, end, limit=1000, db=None, cur=None): if limit is None: raise StorageArgumentException('limit should not be None') contents = [] next_content = None for counter, content_row in enumerate( db.content_get_range(start, end, limit+1, cur)): content = dict(zip(db.content_get_metadata_keys, content_row)) if counter >= limit: # take the last commit for the next page starting from this next_content = content['sha1'] break contents.append(content) return { 'contents': contents, 'next': next_content, } @timed def content_get_partition( self, partition_id: int, nb_partitions: int, limit: int = 1000, page_token: str = None): if limit is None: raise StorageArgumentException('limit should not be None') (start, end) = get_partition_bounds_bytes( partition_id, nb_partitions, SHA1_SIZE) if page_token: start = hash_to_bytes(page_token) if end is None: end = b'\xff'*SHA1_SIZE result = self.content_get_range(start, end, limit) result2 = { 'contents': result['contents'], 'next_page_token': None, } if result['next']: result2['next_page_token'] = hash_to_hex(result['next']) return result2 @timed @db_transaction(statement_timeout=500) def content_get_metadata( self, contents: List[bytes], db=None, cur=None) -> Dict[bytes, List[Dict]]: result: Dict[bytes, List[Dict]] = {sha1: [] for sha1 in contents} for row in db.content_get_metadata_from_sha1s(contents, cur): content_meta = dict(zip(db.content_get_metadata_keys, row)) result[content_meta['sha1']].append(content_meta) return result @timed @db_transaction_generator() def content_missing(self, content, key_hash='sha1', db=None, cur=None): keys = db.content_hash_keys if key_hash not in keys: raise StorageArgumentException( "key_hash should be one of %s" % keys) key_hash_idx = keys.index(key_hash) if not content: return for obj in db.content_missing_from_list(content, cur): yield obj[key_hash_idx] @timed @db_transaction_generator() def content_missing_per_sha1(self, contents, db=None, cur=None): for obj in db.content_missing_per_sha1(contents, cur): yield obj[0] @timed @db_transaction_generator() def content_missing_per_sha1_git(self, contents, db=None, cur=None): for obj in db.content_missing_per_sha1_git(contents, cur): yield obj[0] @timed @db_transaction() def content_find(self, content, db=None, cur=None): if not set(content).intersection(DEFAULT_ALGORITHMS): raise StorageArgumentException( 'content keys must contain at least one of: ' 'sha1, sha1_git, sha256, blake2s256') contents = db.content_find(sha1=content.get('sha1'), sha1_git=content.get('sha1_git'), sha256=content.get('sha256'), blake2s256=content.get('blake2s256'), cur=cur) return [dict(zip(db.content_find_cols, content)) for content in contents] @timed @db_transaction() def content_get_random(self, db=None, cur=None): return db.content_get_random(cur) @staticmethod def _skipped_content_normalize(d): d = d.copy() if d.get('status') is None: d['status'] = 'absent' if d.get('length') is None: d['length'] = -1 return d @staticmethod def _skipped_content_validate(d): """Sanity checks on status / reason / length, that postgresql doesn't enforce.""" if d['status'] != 'absent': raise StorageArgumentException( 'Invalid content status: {}'.format(d['status'])) if d.get('reason') is None: raise StorageArgumentException( 'Must provide a reason if content is absent.') if d['length'] < -1: raise StorageArgumentException( 'Content length must be positive or -1.') def _skipped_content_add_metadata( self, db, cur, content: Iterable[SkippedContent]): origin_ids = db.origin_id_get_by_url( [cont.origin for cont in content], cur=cur) content = [attr.evolve(c, origin=origin_id) for (c, origin_id) in zip(content, origin_ids)] db.mktemp('skipped_content', cur) db.copy_to([c.to_dict() for c in content], 'tmp_skipped_content', db.skipped_content_keys, cur) # move metadata in place db.skipped_content_add_from_temp(cur) @timed @process_metrics @db_transaction() def skipped_content_add(self, content: Iterable[SkippedContent], db=None, cur=None) -> Dict: content = [attr.evolve(c, ctime=now()) for c in content] missing_contents = self.skipped_content_missing( (c.to_dict() for c in content), db=db, cur=cur, ) content = [c for c in content if any(all(c.get_hash(algo) == missing_content.get(algo) for algo in DEFAULT_ALGORITHMS) for missing_content in missing_contents)] self.journal_writer.skipped_content_add(content) self._skipped_content_add_metadata(db, cur, content) return { 'skipped_content:add': len(content), } @timed @db_transaction_generator() def skipped_content_missing(self, contents, db=None, cur=None): contents = list(contents) for content in db.skipped_content_missing(contents, cur): yield dict(zip(db.content_hash_keys, content)) @timed @process_metrics @db_transaction() def directory_add(self, directories: Iterable[Directory], db=None, cur=None) -> Dict: directories = list(directories) summary = {'directory:add': 0} dirs = set() dir_entries: Dict[str, defaultdict] = { 'file': defaultdict(list), 'dir': defaultdict(list), 'rev': defaultdict(list), } for cur_dir in directories: dir_id = cur_dir.id dirs.add(dir_id) for src_entry in cur_dir.entries: entry = src_entry.to_dict() entry['dir_id'] = dir_id dir_entries[entry['type']][dir_id].append(entry) dirs_missing = set(self.directory_missing(dirs, db=db, cur=cur)) if not dirs_missing: return summary self.journal_writer.directory_add( dir_ for dir_ in directories if dir_.id in dirs_missing ) # Copy directory ids dirs_missing_dict = ({'id': dir} for dir in dirs_missing) db.mktemp('directory', cur) db.copy_to(dirs_missing_dict, 'tmp_directory', ['id'], cur) # Copy entries for entry_type, entry_list in dir_entries.items(): entries = itertools.chain.from_iterable( entries_for_dir for dir_id, entries_for_dir in entry_list.items() if dir_id in dirs_missing) db.mktemp_dir_entry(entry_type) db.copy_to( entries, 'tmp_directory_entry_%s' % entry_type, ['target', 'name', 'perms', 'dir_id'], cur, ) # Do the final copy db.directory_add_from_temp(cur) summary['directory:add'] = len(dirs_missing) return summary @timed @db_transaction_generator() def directory_missing(self, directories, db=None, cur=None): for obj in db.directory_missing_from_list(directories, cur): yield obj[0] @timed @db_transaction_generator(statement_timeout=20000) def directory_ls(self, directory, recursive=False, db=None, cur=None): if recursive: res_gen = db.directory_walk(directory, cur=cur) else: res_gen = db.directory_walk_one(directory, cur=cur) for line in res_gen: yield dict(zip(db.directory_ls_cols, line)) @timed @db_transaction(statement_timeout=2000) def directory_entry_get_by_path(self, directory, paths, db=None, cur=None): res = db.directory_entry_get_by_path(directory, paths, cur) if res: return dict(zip(db.directory_ls_cols, res)) @timed @db_transaction() def directory_get_random(self, db=None, cur=None): return db.directory_get_random(cur) @timed @process_metrics @db_transaction() def revision_add(self, revisions: Iterable[Revision], db=None, cur=None) -> Dict: revisions = list(revisions) summary = {'revision:add': 0} revisions_missing = set(self.revision_missing( set(revision.id for revision in revisions), db=db, cur=cur)) if not revisions_missing: return summary db.mktemp_revision(cur) revisions_filtered = [ revision for revision in revisions if revision.id in revisions_missing] self.journal_writer.revision_add(revisions_filtered) revisions_filtered = \ list(map(converters.revision_to_db, revisions_filtered)) parents_filtered: List[bytes] = [] with convert_validation_exceptions(): db.copy_to( revisions_filtered, 'tmp_revision', db.revision_add_cols, cur, lambda rev: parents_filtered.extend(rev['parents'])) db.revision_add_from_temp(cur) db.copy_to(parents_filtered, 'revision_history', ['id', 'parent_id', 'parent_rank'], cur) return {'revision:add': len(revisions_missing)} @timed @db_transaction_generator() def revision_missing(self, revisions, db=None, cur=None): if not revisions: return for obj in db.revision_missing_from_list(revisions, cur): yield obj[0] @timed @db_transaction_generator(statement_timeout=1000) def revision_get(self, revisions, db=None, cur=None): for line in db.revision_get_from_list(revisions, cur): data = converters.db_to_revision( dict(zip(db.revision_get_cols, line)) ) if not data['type']: yield None continue yield data @timed @db_transaction_generator(statement_timeout=2000) def revision_log(self, revisions, limit=None, db=None, cur=None): for line in db.revision_log(revisions, limit, cur): data = converters.db_to_revision( dict(zip(db.revision_get_cols, line)) ) if not data['type']: yield None continue yield data @timed @db_transaction_generator(statement_timeout=2000) def revision_shortlog(self, revisions, limit=None, db=None, cur=None): yield from db.revision_shortlog(revisions, limit, cur) @timed @db_transaction() def revision_get_random(self, db=None, cur=None): return db.revision_get_random(cur) @timed @process_metrics @db_transaction() def release_add( self, releases: Iterable[Release], db=None, cur=None) -> Dict: releases = list(releases) summary = {'release:add': 0} release_ids = set(release.id for release in releases) releases_missing = set(self.release_missing(release_ids, db=db, cur=cur)) if not releases_missing: return summary db.mktemp_release(cur) releases_filtered = [ release for release in releases if release.id in releases_missing ] self.journal_writer.release_add(releases_filtered) releases_filtered = \ list(map(converters.release_to_db, releases_filtered)) with convert_validation_exceptions(): db.copy_to(releases_filtered, 'tmp_release', db.release_add_cols, cur) db.release_add_from_temp(cur) return {'release:add': len(releases_missing)} @timed @db_transaction_generator() def release_missing(self, releases, db=None, cur=None): if not releases: return for obj in db.release_missing_from_list(releases, cur): yield obj[0] @timed @db_transaction_generator(statement_timeout=500) def release_get(self, releases, db=None, cur=None): for release in db.release_get_from_list(releases, cur): data = converters.db_to_release( dict(zip(db.release_get_cols, release)) ) yield data if data['target_type'] else None @timed @db_transaction() def release_get_random(self, db=None, cur=None): return db.release_get_random(cur) @timed @process_metrics @db_transaction() def snapshot_add( self, snapshots: Iterable[Snapshot], db=None, cur=None) -> Dict: created_temp_table = False count = 0 for snapshot in snapshots: if not db.snapshot_exists(snapshot.id, cur): if not created_temp_table: db.mktemp_snapshot_branch(cur) created_temp_table = True - try: + with convert_validation_exceptions(): db.copy_to( ( { 'name': name, 'target': info.target if info else None, 'target_type': (info.target_type.value if info else None), } for name, info in snapshot.branches.items() ), 'tmp_snapshot_branch', ['name', 'target', 'target_type'], cur, ) - except VALIDATION_EXCEPTIONS + (KeyError,) as e: - raise StorageArgumentException(*e.args) self.journal_writer.snapshot_add(snapshot) db.snapshot_add(snapshot.id, cur) count += 1 return {'snapshot:add': count} @timed @db_transaction_generator() def snapshot_missing(self, snapshots, db=None, cur=None): for obj in db.snapshot_missing_from_list(snapshots, cur): yield obj[0] @timed @db_transaction(statement_timeout=2000) def snapshot_get(self, snapshot_id, db=None, cur=None): return self.snapshot_get_branches(snapshot_id, db=db, cur=cur) @timed @db_transaction(statement_timeout=2000) def snapshot_get_by_origin_visit(self, origin, visit, db=None, cur=None): snapshot_id = db.snapshot_get_by_origin_visit(origin, visit, cur) if snapshot_id: return self.snapshot_get(snapshot_id, db=db, cur=cur) return None @timed @db_transaction(statement_timeout=4000) def snapshot_get_latest(self, origin, allowed_statuses=None, db=None, cur=None): if isinstance(origin, int): origin = self.origin_get({'id': origin}, db=db, cur=cur) if not origin: return origin = origin['url'] origin_visit = self.origin_visit_get_latest( origin, allowed_statuses=allowed_statuses, require_snapshot=True, db=db, cur=cur) if origin_visit and origin_visit['snapshot']: snapshot = self.snapshot_get( origin_visit['snapshot'], db=db, cur=cur) if not snapshot: raise StorageArgumentException( 'last origin visit references an unknown snapshot') return snapshot @timed @db_transaction(statement_timeout=2000) def snapshot_count_branches(self, snapshot_id, db=None, cur=None): return dict([bc for bc in db.snapshot_count_branches(snapshot_id, cur)]) @timed @db_transaction(statement_timeout=2000) def snapshot_get_branches(self, snapshot_id, branches_from=b'', branches_count=1000, target_types=None, db=None, cur=None): if snapshot_id == EMPTY_SNAPSHOT_ID: return { 'id': snapshot_id, 'branches': {}, 'next_branch': None, } branches = {} next_branch = None fetched_branches = list(db.snapshot_get_by_id( snapshot_id, branches_from=branches_from, branches_count=branches_count+1, target_types=target_types, cur=cur, )) for branch in fetched_branches[:branches_count]: branch = dict(zip(db.snapshot_get_cols, branch)) del branch['snapshot_id'] name = branch.pop('name') if branch == {'target': None, 'target_type': None}: branch = None branches[name] = branch if len(fetched_branches) > branches_count: branch = dict(zip(db.snapshot_get_cols, fetched_branches[-1])) next_branch = branch['name'] if branches: return { 'id': snapshot_id, 'branches': branches, 'next_branch': next_branch, } return None @timed @db_transaction() def snapshot_get_random(self, db=None, cur=None): return db.snapshot_get_random(cur) @timed @db_transaction() def origin_visit_add( self, origin_url: str, date: Union[str, datetime.datetime], type: str, db=None, cur=None) -> OriginVisit: if isinstance(date, str): # FIXME: Converge on iso8601 at some point date = dateutil.parser.parse(date) elif not isinstance(date, datetime.datetime): raise StorageArgumentException( 'Date must be a datetime or a string') origin = self.origin_get({'url': origin_url}, db=db, cur=cur) if not origin: # Cannot add a visit without an origin raise StorageArgumentException( 'Unknown origin %s', origin_url) with convert_validation_exceptions(): visit_id = db.origin_visit_add(origin_url, date, type, cur=cur) # We can write to the journal only after inserting to the # DB, because we want the id of the visit visit = OriginVisit.from_dict({ 'origin': origin_url, 'date': date, 'type': type, 'visit': visit_id, 'status': 'ongoing', 'metadata': None, 'snapshot': None }) self.journal_writer.origin_visit_add(visit) send_metric('origin_visit:add', count=1, method_name='origin_visit') return visit @timed @db_transaction() def origin_visit_update(self, origin: str, visit_id: int, status: str, metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None, date: Optional[datetime.datetime] = None, db=None, cur=None): if not isinstance(origin, str): raise StorageArgumentException( 'origin must be a string, not %r' % (origin,)) origin_url = origin visit = db.origin_visit_get(origin_url, visit_id, cur=cur) if not visit: raise StorageArgumentException('Invalid visit_id for this origin.') visit = dict(zip(db.origin_visit_get_cols, visit)) updates: Dict[str, Any] = { 'status': status, } if metadata and metadata != visit['metadata']: updates['metadata'] = metadata if snapshot and snapshot != visit['snapshot']: updates['snapshot'] = snapshot if updates: updated_visit = {**visit, **updates} self.journal_writer.origin_visit_update(updated_visit) with convert_validation_exceptions(): db.origin_visit_update(origin_url, visit_id, updates, cur) @timed @db_transaction() def origin_visit_upsert(self, visits: Iterable[OriginVisit], db=None, cur=None) -> None: self.journal_writer.origin_visit_upsert(visits) for visit in visits: # TODO: upsert them all in a single query db.origin_visit_upsert(visit, cur=cur) @timed @db_transaction_generator(statement_timeout=500) def origin_visit_get(self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None, db=None, cur=None): for line in db.origin_visit_get_all( origin, last_visit=last_visit, limit=limit, cur=cur): data = dict(zip(db.origin_visit_get_cols, line)) yield data @timed @db_transaction(statement_timeout=500) def origin_visit_find_by_date(self, origin, visit_date, db=None, cur=None): line = db.origin_visit_find_by_date(origin, visit_date, cur=cur) if line: return dict(zip(db.origin_visit_get_cols, line)) @timed @db_transaction(statement_timeout=500) def origin_visit_get_by(self, origin, visit, db=None, cur=None): ori_visit = db.origin_visit_get(origin, visit, cur) if not ori_visit: return None return dict(zip(db.origin_visit_get_cols, ori_visit)) @timed @db_transaction(statement_timeout=4000) def origin_visit_get_latest( self, origin, allowed_statuses=None, require_snapshot=False, db=None, cur=None): origin_visit = db.origin_visit_get_latest( origin, allowed_statuses=allowed_statuses, require_snapshot=require_snapshot, cur=cur) if origin_visit: return dict(zip(db.origin_visit_get_cols, origin_visit)) @timed @db_transaction() def origin_visit_get_random( self, type: str, db=None, cur=None) -> Optional[Dict[str, Any]]: result = db.origin_visit_get_random(type, cur) if result: return dict(zip(db.origin_visit_get_cols, result)) else: return None @timed @db_transaction(statement_timeout=2000) def object_find_by_sha1_git(self, ids, db=None, cur=None): ret = {id: [] for id in ids} for retval in db.object_find_by_sha1_git(ids, cur=cur): if retval[1]: ret[retval[0]].append(dict(zip(db.object_find_by_sha1_git_cols, retval))) return ret @timed @db_transaction(statement_timeout=500) def origin_get(self, origins, db=None, cur=None): if isinstance(origins, dict): # Old API return_single = True origins = [origins] elif len(origins) == 0: return [] else: return_single = False origin_urls = [origin['url'] for origin in origins] results = db.origin_get_by_url(origin_urls, cur) results = [dict(zip(db.origin_cols, result)) for result in results] if return_single: assert len(results) == 1 if results[0]['url'] is not None: return results[0] else: return None else: return [None if res['url'] is None else res for res in results] @timed @db_transaction_generator(statement_timeout=500) def origin_get_by_sha1(self, sha1s, db=None, cur=None): for line in db.origin_get_by_sha1(sha1s, cur): if line[0] is not None: yield dict(zip(db.origin_cols, line)) else: yield None @timed @db_transaction_generator() def origin_get_range(self, origin_from=1, origin_count=100, db=None, cur=None): for origin in db.origin_get_range(origin_from, origin_count, cur): yield dict(zip(db.origin_get_range_cols, origin)) @timed @db_transaction() def origin_list(self, page_token: Optional[str] = None, limit: int = 100, *, db=None, cur=None) -> dict: page_token = page_token or '0' if not isinstance(page_token, str): raise StorageArgumentException('page_token must be a string.') origin_from = int(page_token) result: Dict[str, Any] = { 'origins': [ dict(zip(db.origin_get_range_cols, origin)) for origin in db.origin_get_range(origin_from, limit, cur) ], } assert len(result['origins']) <= limit if len(result['origins']) == limit: result['next_page_token'] = str(result['origins'][limit-1]['id']+1) for origin in result['origins']: del origin['id'] return result @timed @db_transaction_generator() def origin_search(self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False, db=None, cur=None): for origin in db.origin_search(url_pattern, offset, limit, regexp, with_visit, cur): yield dict(zip(db.origin_cols, origin)) @timed @db_transaction() def origin_count(self, url_pattern, regexp=False, with_visit=False, db=None, cur=None): return db.origin_count(url_pattern, regexp, with_visit, cur) @timed @db_transaction() def origin_add( self, origins: Iterable[Origin], db=None, cur=None) -> List[Dict]: origins = list(origins) for origin in origins: self.origin_add_one(origin, db=db, cur=cur) return [o.to_dict() for o in origins] @timed @db_transaction() def origin_add_one(self, origin: Origin, db=None, cur=None) -> str: origin_row = list(db.origin_get_by_url([origin.url], cur))[0] origin_url = dict(zip(db.origin_cols, origin_row))['url'] if origin_url: return origin_url self.journal_writer.origin_add_one(origin) url = db.origin_add(origin.url, cur) send_metric('origin:add', count=1, method_name='origin_add_one') return url @db_transaction(statement_timeout=500) def stat_counters(self, db=None, cur=None): return {k: v for (k, v) in db.stat_counters()} @db_transaction() def refresh_stat_counters(self, db=None, cur=None): keys = [ 'content', 'directory', 'directory_entry_dir', 'directory_entry_file', 'directory_entry_rev', 'origin', 'origin_visit', 'person', 'release', 'revision', 'revision_history', 'skipped_content', 'snapshot'] for key in keys: cur.execute('select * from swh_update_counter(%s)', (key,)) @timed @db_transaction() def origin_metadata_add(self, origin_url, ts, provider, tool, metadata, db=None, cur=None): if isinstance(ts, str): ts = dateutil.parser.parse(ts) db.origin_metadata_add(origin_url, ts, provider, tool, metadata, cur) send_metric( 'origin_metadata:add', count=1, method_name='origin_metadata_add') @timed @db_transaction_generator(statement_timeout=500) def origin_metadata_get_by(self, origin_url, provider_type=None, db=None, cur=None): for line in db.origin_metadata_get_by(origin_url, provider_type, cur): yield dict(zip(db.origin_metadata_get_cols, line)) @timed @db_transaction() def tool_add(self, tools, db=None, cur=None): db.mktemp_tool(cur) with convert_validation_exceptions(): db.copy_to(tools, 'tmp_tool', ['name', 'version', 'configuration'], cur) tools = db.tool_add_from_temp(cur) results = [dict(zip(db.tool_cols, line)) for line in tools] send_metric('tool:add', count=len(results), method_name='tool_add') return results @timed @db_transaction(statement_timeout=500) def tool_get(self, tool, db=None, cur=None): tool_conf = tool['configuration'] if isinstance(tool_conf, dict): tool_conf = json.dumps(tool_conf) idx = db.tool_get(tool['name'], tool['version'], tool_conf) if not idx: return None return dict(zip(db.tool_cols, idx)) @timed @db_transaction() def metadata_provider_add(self, provider_name, provider_type, provider_url, metadata, db=None, cur=None): result = db.metadata_provider_add(provider_name, provider_type, provider_url, metadata, cur) send_metric( 'metadata_provider:add', count=1, method_name='metadata_provider') return result @timed @db_transaction() def metadata_provider_get(self, provider_id, db=None, cur=None): result = db.metadata_provider_get(provider_id) if not result: return None return dict(zip(db.metadata_provider_cols, result)) @timed @db_transaction() def metadata_provider_get_by(self, provider, db=None, cur=None): result = db.metadata_provider_get_by(provider['provider_name'], provider['provider_url']) if not result: return None return dict(zip(db.metadata_provider_cols, result)) @timed def diff_directories(self, from_dir, to_dir, track_renaming=False): return diff.diff_directories(self, from_dir, to_dir, track_renaming) @timed def diff_revisions(self, from_rev, to_rev, track_renaming=False): return diff.diff_revisions(self, from_rev, to_rev, track_renaming) @timed def diff_revision(self, revision, track_renaming=False): return diff.diff_revision(self, revision, track_renaming) diff --git a/swh/storage/validate.py b/swh/storage/validate.py index 82b1792f..8f2ed1cd 100644 --- a/swh/storage/validate.py +++ b/swh/storage/validate.py @@ -1,102 +1,102 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import contextlib from typing import Dict, Iterable, List from swh.model.model import ( BaseModel, SkippedContent, Content, Directory, Revision, Release, Snapshot, OriginVisit, Origin ) from . import get_storage from .exc import StorageArgumentException -VALIDATION_EXCEPTIONS = ( +VALIDATION_EXCEPTIONS = [ KeyError, TypeError, ValueError, -) +] @contextlib.contextmanager def convert_validation_exceptions(): """Catches validation errors arguments, and re-raises a StorageArgumentException.""" try: yield - except VALIDATION_EXCEPTIONS as e: + except tuple(VALIDATION_EXCEPTIONS) as e: raise StorageArgumentException(str(e)) class ValidatingProxyStorage: """Storage implementation converts dictionaries to swh-model objects before calling its backend, and back to dicts before returning results """ def __init__(self, storage): self.storage = get_storage(**storage) def __getattr__(self, key): if key == 'storage': raise AttributeError(key) return getattr(self.storage, key) def content_add(self, content: Iterable[Dict]) -> Dict: with convert_validation_exceptions(): contents = [Content.from_dict(c) for c in content] return self.storage.content_add(contents) def content_add_metadata(self, content: Iterable[Dict]) -> Dict: with convert_validation_exceptions(): contents = [Content.from_dict(c) for c in content] return self.storage.content_add_metadata(contents) def skipped_content_add(self, content: Iterable[Dict]) -> Dict: with convert_validation_exceptions(): contents = [SkippedContent.from_dict(c) for c in content] return self.storage.skipped_content_add(contents) def directory_add(self, directories: Iterable[Dict]) -> Dict: with convert_validation_exceptions(): directories = [Directory.from_dict(d) for d in directories] return self.storage.directory_add(directories) def revision_add(self, revisions: Iterable[Dict]) -> Dict: with convert_validation_exceptions(): revisions = [Revision.from_dict(r) for r in revisions] return self.storage.revision_add(revisions) def release_add(self, releases: Iterable[Dict]) -> Dict: with convert_validation_exceptions(): releases = [Release.from_dict(r) for r in releases] return self.storage.release_add(releases) def snapshot_add(self, snapshots: Iterable[Dict]) -> Dict: with convert_validation_exceptions(): snapshots = [Snapshot.from_dict(s) for s in snapshots] return self.storage.snapshot_add(snapshots) def origin_visit_add( self, origin_url: str, date: datetime.datetime, type: str) -> Dict[str, BaseModel]: with convert_validation_exceptions(): visit = OriginVisit(origin=origin_url, date=date, type=type, status='ongoing', snapshot=None) return self.storage.origin_visit_add( visit.origin, visit.date, visit.type) def origin_add(self, origins: Iterable[Dict]) -> List[Dict]: with convert_validation_exceptions(): origins = [Origin.from_dict(o) for o in origins] return self.storage.origin_add(origins) def origin_add_one(self, origin: Dict) -> int: with convert_validation_exceptions(): origin = Origin.from_dict(origin) return self.storage.origin_add_one(origin)