diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py index e233e555..7af48053 100644 --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -1,1758 +1,1763 @@ # Copyright (C) 2019-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 from collections import defaultdict import datetime import itertools import operator import random import re from typing import ( Any, Callable, Counter, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union, ) import attr from swh.core.api.classes import stream_results from swh.core.api.serializers import msgpack_dumps, msgpack_loads from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_hex from swh.model.model import ( Content, Directory, DirectoryEntry, ExtID, MetadataAuthority, MetadataAuthorityType, MetadataFetcher, Origin, OriginVisit, OriginVisitStatus, RawExtrinsicMetadata, Release, Revision, Sha1Git, SkippedContent, Snapshot, SnapshotBranch, TargetType, ) from swh.model.swhids import CoreSWHID, ExtendedObjectType, ExtendedSWHID from swh.model.swhids import ObjectType as SwhidObjectType from swh.storage.interface import ( VISIT_STATUSES, ListOrder, OriginVisitWithStatuses, PagedResult, PartialBranches, Sha1, ) from swh.storage.objstorage import ObjStorage from swh.storage.utils import map_optional, now from swh.storage.writer import JournalWriter from . import converters from ..exc import HashCollision, StorageArgumentException from ..utils import remove_keys from .common import TOKEN_BEGIN, TOKEN_END, hash_url from .cql import CqlRunner from .model import ( ContentRow, DirectoryEntryRow, DirectoryRow, ExtIDByTargetRow, ExtIDRow, MetadataAuthorityRow, MetadataFetcherRow, OriginRow, OriginVisitRow, OriginVisitStatusRow, RawExtrinsicMetadataByIdRow, RawExtrinsicMetadataRow, RevisionParentRow, SkippedContentRow, SnapshotBranchRow, SnapshotRow, ) from .schema import HASH_ALGORITHMS # Max block size of contents to return BULK_BLOCK_CONTENT_LEN_MAX = 10000 DIRECTORY_ENTRIES_INSERT_ALGOS = ["one-by-one", "concurrent", "batch"] class CassandraStorage: def __init__( self, hosts, keyspace, objstorage, port=9042, journal_writer=None, allow_overwrite=False, consistency_level="ONE", directory_entries_insert_algo="one-by-one", ): """ A backend of swh-storage backed by Cassandra Args: hosts: Seed Cassandra nodes, to start connecting to the cluster keyspace: Name of the Cassandra database to use objstorage: Passed as argument to :class:`ObjStorage` port: Cassandra port journal_writer: Passed as argument to :class:`JournalWriter` allow_overwrite: Whether ``*_add`` functions will check if an object already exists in the database before sending it in an INSERT. ``False`` is the default as it is more efficient when there is a moderately high probability the object is already known, but ``True`` can be useful to overwrite existing objects (eg. when applying a schema update), or when the database is known to be mostly empty. Note that a ``False`` value does not guarantee there won't be any overwrite. consistency_level: The default read/write consistency to use directory_entries_insert_algo: Must be one of: * one-by-one: naive, one INSERT per directory entry, serialized * concurrent: one INSERT per directory entry, concurrent * batch: using UNLOGGED BATCH to insert many entries in a few statements """ self._hosts = hosts self._keyspace = keyspace self._port = port self._consistency_level = consistency_level self._set_cql_runner() self.journal_writer: JournalWriter = JournalWriter(journal_writer) self.objstorage: ObjStorage = ObjStorage(objstorage) self._allow_overwrite = allow_overwrite if directory_entries_insert_algo not in DIRECTORY_ENTRIES_INSERT_ALGOS: raise ValueError( f"directory_entries_insert_algo must be one of: " f"{', '.join(DIRECTORY_ENTRIES_INSERT_ALGOS)}" ) self._directory_entries_insert_algo = directory_entries_insert_algo def _set_cql_runner(self): """Used by tests when they need to reset the CqlRunner""" self._cql_runner: CqlRunner = CqlRunner( self._hosts, self._keyspace, self._port, self._consistency_level ) def check_config(self, *, check_write: bool) -> bool: self._cql_runner.check_read() return True def _content_get_from_hashes(self, algo, hashes: List[bytes]) -> Iterable: """From the name of a hash algorithm and a value of that hash, looks up the "hash -> token" secondary table (content_by_{algo}) to get tokens. Then, looks up the main table (content) to get all contents with that token, and filters out contents whose hash doesn't match.""" found_tokens = list( self._cql_runner.content_get_tokens_from_single_algo(algo, hashes) ) assert all(isinstance(token, int) for token in found_tokens) # Query the main table ('content'). rows = self._cql_runner.content_get_from_tokens(found_tokens) for row in rows: # re-check the the hash (in case of murmur3 collision) if getattr(row, algo) in hashes: yield row def _content_add(self, contents: List[Content], with_data: bool) -> Dict[str, int]: # Filter-out content already in the database. if not self._allow_overwrite: contents = [ c for c in contents if not self._cql_runner.content_get_from_pk(c.to_dict()) ] if with_data: # First insert to the objstorage, if the endpoint is # `content_add` (as opposed to `content_add_metadata`). # Must add to the objstorage before the DB and journal. Otherwise: # 1. in case of a crash the DB may "believe" we have the content, but # we didn't have time to write to the objstorage before the crash # 2. the objstorage mirroring, which reads from the journal, may attempt to # read from the objstorage before we finished writing it summary = self.objstorage.content_add( c for c in contents if c.status != "absent" ) content_add_bytes = summary["content:add:bytes"] self.journal_writer.content_add(contents) content_add = 0 for content in contents: content_add += 1 # Check for sha1 or sha1_git collisions. This test is not atomic # with the insertion, so it won't detect a collision if both # contents are inserted at the same time, but it's good enough. # # The proper way to do it would probably be a BATCH, but this # would be inefficient because of the number of partitions we # need to affect (len(HASH_ALGORITHMS)+1, which is currently 5) if not self._allow_overwrite: for algo in {"sha1", "sha1_git"}: collisions = [] # Get tokens of 'content' rows with the same value for # sha1/sha1_git # TODO: batch these requests, instead of sending them one by one rows = self._content_get_from_hashes(algo, [content.get_hash(algo)]) for row in rows: if getattr(row, algo) != content.get_hash(algo): # collision of token(partition key), ignore this # row continue for other_algo in HASH_ALGORITHMS: if getattr(row, other_algo) != content.get_hash(other_algo): # This hash didn't match; discard the row. collisions.append( {k: getattr(row, k) for k in HASH_ALGORITHMS} ) if collisions: collisions.append(content.hashes()) raise HashCollision(algo, content.get_hash(algo), collisions) (token, insertion_finalizer) = self._cql_runner.content_add_prepare( ContentRow(**remove_keys(content.to_dict(), ("data",))) ) # Then add to index tables for algo in HASH_ALGORITHMS: self._cql_runner.content_index_add_one(algo, content, token) # Then to the main table insertion_finalizer() summary = { "content:add": content_add, } if with_data: summary["content:add:bytes"] = content_add_bytes return summary def content_add(self, content: List[Content]) -> Dict[str, int]: to_add = { (c.sha1, c.sha1_git, c.sha256, c.blake2s256): c for c in content }.values() contents = [attr.evolve(c, ctime=now()) for c in to_add] return self._content_add(list(contents), with_data=True) def content_update( self, contents: List[Dict[str, Any]], keys: List[str] = [] ) -> None: raise NotImplementedError( "content_update is not supported by the Cassandra backend" ) def content_add_metadata(self, content: List[Content]) -> Dict[str, int]: return self._content_add(content, with_data=False) def content_get_data(self, content: Sha1) -> Optional[bytes]: # FIXME: Make this method support slicing the `data` return self.objstorage.content_get(content) def content_get_partition( self, partition_id: int, nb_partitions: int, page_token: Optional[str] = None, limit: int = 1000, ) -> PagedResult[Content]: if limit is None: raise StorageArgumentException("limit should not be None") # Compute start and end of the range of tokens covered by the # requested partition partition_size = (TOKEN_END - TOKEN_BEGIN) // nb_partitions range_start = TOKEN_BEGIN + partition_id * partition_size range_end = TOKEN_BEGIN + (partition_id + 1) * partition_size # offset the range start according to the `page_token`. if page_token is not None: if not (range_start <= int(page_token) <= range_end): raise StorageArgumentException("Invalid page_token.") range_start = int(page_token) next_page_token: Optional[str] = None rows = self._cql_runner.content_get_token_range( range_start, range_end, limit + 1 ) contents = [] for counter, (tok, row) in enumerate(rows): if row.status == "absent": continue row_d = row.to_dict() if counter >= limit: next_page_token = str(tok) break row_d.pop("ctime") contents.append(Content(**row_d)) assert len(contents) <= limit return PagedResult(results=contents, next_page_token=next_page_token) def content_get( self, contents: List[bytes], algo: str = "sha1" ) -> List[Optional[Content]]: if algo not in DEFAULT_ALGORITHMS: raise StorageArgumentException( "algo should be one of {','.join(DEFAULT_ALGORITHMS)}" ) key = operator.attrgetter(algo) contents_by_hash: Dict[Sha1, Optional[Content]] = {} for row in self._content_get_from_hashes(algo, contents): # Get all (sha1, sha1_git, sha256, blake2s256) whose sha1/sha1_git # matches the argument, from the index table ('content_by_*') row_d = row.to_dict() row_d.pop("ctime") content = Content(**row_d) contents_by_hash[key(content)] = content return [contents_by_hash.get(hash_) for hash_ in contents] def content_find(self, content: Dict[str, Any]) -> List[Content]: return self._content_find_many([content]) def _content_find_many(self, contents: List[Dict[str, Any]]) -> List[Content]: # Find an algorithm that is common to all the requested contents. # It will be used to do an initial filtering efficiently. # TODO: prioritize sha256, we can do more efficient lookups from this hash. filter_algos = set(HASH_ALGORITHMS) for content in contents: filter_algos &= set(content) if not filter_algos: raise StorageArgumentException( "content keys must contain at least one " f"of: {', '.join(sorted(HASH_ALGORITHMS))}" ) common_algo = list(filter_algos)[0] results = [] rows = self._content_get_from_hashes( common_algo, [content[common_algo] for content in contents] ) for row in rows: # Re-check all the hashes, in case of collisions (either of the # hash of the partition key, or the hashes in it) for content in contents: for algo in HASH_ALGORITHMS: if content.get(algo) and getattr(row, algo) != content[algo]: # This hash didn't match; discard the row. break else: # All hashes match, keep this row. row_d = row.to_dict() row_d["ctime"] = row.ctime.replace(tzinfo=datetime.timezone.utc) results.append(Content(**row_d)) break else: # No content matched; skip it pass return results def content_missing( self, contents: List[Dict[str, Any]], key_hash: str = "sha1" ) -> Iterable[bytes]: if key_hash not in DEFAULT_ALGORITHMS: raise StorageArgumentException( "key_hash should be one of {','.join(DEFAULT_ALGORITHMS)}" ) contents_with_all_hashes = [] contents_with_missing_hashes = [] for content in contents: if DEFAULT_ALGORITHMS <= set(content): contents_with_all_hashes.append(content) else: contents_with_missing_hashes.append(content) # These contents can be queried efficiently directly in the main table for content in self._cql_runner.content_missing_from_all_hashes( contents_with_all_hashes ): yield content[key_hash] if contents_with_missing_hashes: # For these, we need the expensive index lookups + main table. # Get all contents in the database that match (at least) one of the # requested contents, concurrently. found_contents = self._content_find_many(contents_with_missing_hashes) # Bucket the known contents by hash found_contents_by_hash: Dict[str, Dict[str, list]] = { algo: defaultdict(list) for algo in DEFAULT_ALGORITHMS } for found_content in found_contents: for algo in DEFAULT_ALGORITHMS: found_contents_by_hash[algo][found_content.get_hash(algo)].append( found_content ) # For each of the requested contents, check if they are in the # 'found_contents' set (via 'found_contents_by_hash' for efficient access, # since we need to check using dict inclusion instead of hash+equality) for missing_content in contents_with_missing_hashes: # Pick any of the algorithms provided in missing_content algo = next(algo for (algo, hash_) in missing_content.items() if hash_) # Get the list of found_contents that match this hash in the # missing_content. (its length is at most 1, unless there is a # collision) found_contents_with_same_hash = found_contents_by_hash[algo][ missing_content[algo] ] # Check if there is a found_content that matches all hashes in the # missing_content. # This is functionally equivalent to 'for found_content in # found_contents', but runs almost in constant time (it is linear # in the number of hash collisions) instead of linear. # This allows this function to run in linear time overall instead of # quadratic. for found_content in found_contents_with_same_hash: # check if the found_content.hashes() dictionary contains a superset # of the (key, value) pairs in missing_content if missing_content.items() <= found_content.hashes().items(): # Found! break else: # Not found yield missing_content[key_hash] def content_missing_per_sha1(self, contents: List[bytes]) -> Iterable[bytes]: return self.content_missing([{"sha1": c} for c in contents]) def content_missing_per_sha1_git( self, contents: List[Sha1Git] ) -> Iterable[Sha1Git]: return self.content_missing( [{"sha1_git": c} for c in contents], key_hash="sha1_git" ) def content_get_random(self) -> Sha1Git: content = self._cql_runner.content_get_random() assert content, "Could not find any content" return content.sha1_git def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict[str, int]: # Filter-out content already in the database. if not self._allow_overwrite: contents = [ c for c in contents if not self._cql_runner.skipped_content_get_from_pk(c.to_dict()) ] self.journal_writer.skipped_content_add(contents) for content in contents: # Compute token of the row in the main table (token, insertion_finalizer) = self._cql_runner.skipped_content_add_prepare( SkippedContentRow.from_dict({"origin": None, **content.to_dict()}) ) # Then add to index tables for algo in HASH_ALGORITHMS: self._cql_runner.skipped_content_index_add_one(algo, content, token) # Then to the main table insertion_finalizer() return {"skipped_content:add": len(contents)} def skipped_content_add(self, content: List[SkippedContent]) -> Dict[str, int]: contents = [attr.evolve(c, ctime=now()) for c in content] return self._skipped_content_add(contents) def skipped_content_missing( self, contents: List[Dict[str, Any]] ) -> Iterable[Dict[str, Any]]: for content in contents: if not self._cql_runner.skipped_content_get_from_pk(content): yield {algo: content[algo] for algo in DEFAULT_ALGORITHMS} def directory_add(self, directories: List[Directory]) -> Dict[str, int]: to_add = {d.id: d for d in directories}.values() if not self._allow_overwrite: # Filter out directories that are already inserted. missing = self.directory_missing([dir_.id for dir_ in to_add]) directories = [dir_ for dir_ in directories if dir_.id in missing] self.journal_writer.directory_add(directories) for directory in directories: # Add directory entries to the 'directory_entry' table rows = [ DirectoryEntryRow(directory_id=directory.id, **entry.to_dict()) for entry in directory.entries ] if self._directory_entries_insert_algo == "one-by-one": for row in rows: self._cql_runner.directory_entry_add_one(row) elif self._directory_entries_insert_algo == "concurrent": self._cql_runner.directory_entry_add_concurrent(rows) elif self._directory_entries_insert_algo == "batch": self._cql_runner.directory_entry_add_batch(rows) else: raise ValueError( f"Unexpected value for directory_entries_insert_algo: " f"{self._directory_entries_insert_algo}" ) # Add the directory *after* adding all the entries, so someone # calling snapshot_get_branch in the meantime won't end up # with half the entries. self._cql_runner.directory_add_one( DirectoryRow(id=directory.id, raw_manifest=directory.raw_manifest) ) return {"directory:add": len(directories)} def directory_missing(self, directories: List[Sha1Git]) -> Iterable[Sha1Git]: return self._cql_runner.directory_missing(directories) def _join_dentry_to_content( self, dentry: DirectoryEntry, contents: List[Content] ) -> Dict[str, Any]: content: Union[None, Content, SkippedContentRow] keys = ( "status", "sha1", "sha1_git", "sha256", "length", ) ret = dict.fromkeys(keys) ret.update(dentry.to_dict()) if ret["type"] == "file": for content in contents: if dentry.target == content.sha1_git: break else: target = ret["target"] assert target is not None tokens = list( self._cql_runner.skipped_content_get_tokens_from_single_hash( "sha1_git", target ) ) if tokens: content = list( self._cql_runner.skipped_content_get_from_token(tokens[0]) )[0] else: content = None if content: for key in keys: ret[key] = getattr(content, key) return ret def _directory_ls( self, directory_id: Sha1Git, recursive: bool, prefix: bytes = b"" ) -> Iterable[Dict[str, Any]]: if self.directory_missing([directory_id]): return rows = list(self._cql_runner.directory_entry_get([directory_id])) # TODO: dedup to be fast in case the directory contains the same subdir/file # multiple times contents = self._content_find_many([{"sha1_git": row.target} for row in rows]) for row in rows: entry_d = row.to_dict() # Build and yield the directory entry dict del entry_d["directory_id"] entry = DirectoryEntry.from_dict(entry_d) ret = self._join_dentry_to_content(entry, contents) ret["name"] = prefix + ret["name"] ret["dir_id"] = directory_id yield ret if recursive and ret["type"] == "dir": yield from self._directory_ls( ret["target"], True, prefix + ret["name"] + b"/" ) def directory_entry_get_by_path( self, directory: Sha1Git, paths: List[bytes] ) -> Optional[Dict[str, Any]]: return self._directory_entry_get_by_path(directory, paths, b"") def _directory_entry_get_by_path( self, directory: Sha1Git, paths: List[bytes], prefix: bytes ) -> Optional[Dict[str, Any]]: if not paths: return None contents = list(self.directory_ls(directory)) if not contents: return None def _get_entry(entries, name): """Finds the entry with the requested name, prepends the prefix (to get its full path), and returns it. If no entry has that name, returns None.""" for entry in entries: if entry["name"] == name: entry = entry.copy() entry["name"] = prefix + entry["name"] return entry first_item = _get_entry(contents, paths[0]) if len(paths) == 1: return first_item if not first_item or first_item["type"] != "dir": return None return self._directory_entry_get_by_path( first_item["target"], paths[1:], prefix + paths[0] + b"/" ) def directory_ls( self, directory: Sha1Git, recursive: bool = False ) -> Iterable[Dict[str, Any]]: yield from self._directory_ls(directory, recursive) def directory_get_entries( self, directory_id: Sha1Git, page_token: Optional[bytes] = None, limit: int = 1000, ) -> Optional[PagedResult[DirectoryEntry]]: if self.directory_missing([directory_id]): return None entries_from: bytes = page_token or b"" rows = self._cql_runner.directory_entry_get_from_name( directory_id, entries_from, limit + 1 ) entries = [ DirectoryEntry.from_dict(remove_keys(row.to_dict(), ("directory_id",))) for row in rows ] if len(entries) > limit: last_entry = entries.pop() next_page_token = last_entry.name else: next_page_token = None return PagedResult(results=entries, next_page_token=next_page_token) def directory_get_raw_manifest( self, directory_ids: List[Sha1Git] ) -> Dict[Sha1Git, Optional[bytes]]: return { dir_.id: dir_.raw_manifest for dir_ in self._cql_runner.directory_get(directory_ids) } def directory_get_random(self) -> Sha1Git: directory = self._cql_runner.directory_get_random() assert directory, "Could not find any directory" return directory.id def revision_add(self, revisions: List[Revision]) -> Dict[str, int]: # Filter-out revisions already in the database if not self._allow_overwrite: to_add = {r.id: r for r in revisions}.values() missing = self.revision_missing([rev.id for rev in to_add]) revisions = [rev for rev in revisions if rev.id in missing] self.journal_writer.revision_add(revisions) for revision in revisions: revobject = converters.revision_to_db(revision) if revobject: # Add parents first for (rank, parent) in enumerate(revision.parents): self._cql_runner.revision_parent_add_one( RevisionParentRow( id=revobject.id, parent_rank=rank, parent_id=parent ) ) # Then write the main revision row. # Writing this after all parents were written ensures that # read endpoints don't return a partial view while writing # the parents self._cql_runner.revision_add_one(revobject) return {"revision:add": len(revisions)} def revision_missing(self, revisions: List[Sha1Git]) -> Iterable[Sha1Git]: return self._cql_runner.revision_missing(revisions) def revision_get( self, revision_ids: List[Sha1Git], ignore_displayname: bool = False ) -> List[Optional[Revision]]: rows = self._cql_runner.revision_get(revision_ids) revisions: Dict[Sha1Git, Revision] = {} for row in rows: # TODO: use a single query to get all parents? # (it might have lower latency, but requires more code and more # bandwidth, because revision id would be part of each returned # row) parents = tuple(self._cql_runner.revision_parent_get(row.id)) # parent_rank is the clustering key, so results are already # sorted by rank. rev = converters.revision_from_db(row, parents=parents) revisions[rev.id] = rev return [revisions.get(rev_id) for rev_id in revision_ids] def _get_parent_revs( self, rev_ids: Iterable[Sha1Git], seen: Set[Sha1Git], limit: Optional[int], short: bool, ) -> Union[ Iterable[Dict[str, Any]], Iterable[Tuple[Sha1Git, Tuple[Sha1Git, ...]]], ]: if limit and len(seen) >= limit: return rev_ids = [id_ for id_ in rev_ids if id_ not in seen] if not rev_ids: return seen |= set(rev_ids) # We need this query, even if short=True, to return consistent # results (ie. not return only a subset of a revision's parents # if it is being written) if short: ids = self._cql_runner.revision_get_ids(rev_ids) for id_ in ids: # TODO: use a single query to get all parents? # (it might have less latency, but requires less code and more # bandwidth (because revision id would be part of each returned # row) parents = tuple(self._cql_runner.revision_parent_get(id_)) # parent_rank is the clustering key, so results are already # sorted by rank. yield (id_, parents) yield from self._get_parent_revs(parents, seen, limit, short) else: rows = self._cql_runner.revision_get(rev_ids) for row in rows: # TODO: use a single query to get all parents? # (it might have less latency, but requires less code and more # bandwidth (because revision id would be part of each returned # row) parents = tuple(self._cql_runner.revision_parent_get(row.id)) # parent_rank is the clustering key, so results are already # sorted by rank. rev = converters.revision_from_db(row, parents=parents) yield rev.to_dict() yield from self._get_parent_revs(parents, seen, limit, short) def revision_log( self, revisions: List[Sha1Git], ignore_displayname: bool = False, limit: Optional[int] = None, ) -> Iterable[Optional[Dict[str, Any]]]: seen: Set[Sha1Git] = set() yield from self._get_parent_revs(revisions, seen, limit, False) def revision_shortlog( self, revisions: List[Sha1Git], limit: Optional[int] = None ) -> Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]]: seen: Set[Sha1Git] = set() yield from self._get_parent_revs(revisions, seen, limit, True) def revision_get_random(self) -> Sha1Git: revision = self._cql_runner.revision_get_random() assert revision, "Could not find any revision" return revision.id def release_add(self, releases: List[Release]) -> Dict[str, int]: if not self._allow_overwrite: to_add = {r.id: r for r in releases}.values() missing = set(self.release_missing([rel.id for rel in to_add])) releases = [rel for rel in to_add if rel.id in missing] self.journal_writer.release_add(releases) for release in releases: if release: self._cql_runner.release_add_one(converters.release_to_db(release)) return {"release:add": len(releases)} def release_missing(self, releases: List[Sha1Git]) -> Iterable[Sha1Git]: return self._cql_runner.release_missing(releases) def release_get( self, releases: List[Sha1Git], ignore_displayname: bool = False ) -> List[Optional[Release]]: rows = self._cql_runner.release_get(releases) rels: Dict[Sha1Git, Release] = {} for row in rows: release = converters.release_from_db(row) rels[row.id] = release return [rels.get(rel_id) for rel_id in releases] def release_get_random(self) -> Sha1Git: release = self._cql_runner.release_get_random() assert release, "Could not find any release" return release.id def snapshot_add(self, snapshots: List[Snapshot]) -> Dict[str, int]: if not self._allow_overwrite: to_add = {s.id: s for s in snapshots}.values() missing = self._cql_runner.snapshot_missing([snp.id for snp in to_add]) snapshots = [snp for snp in snapshots if snp.id in missing] for snapshot in snapshots: self.journal_writer.snapshot_add([snapshot]) # Add branches for (branch_name, branch) in snapshot.branches.items(): if branch is None: target_type: Optional[str] = None target: Optional[bytes] = None else: target_type = branch.target_type.value target = branch.target self._cql_runner.snapshot_branch_add_one( SnapshotBranchRow( snapshot_id=snapshot.id, name=branch_name, target_type=target_type, target=target, ) ) # Add the snapshot *after* adding all the branches, so someone # calling snapshot_get_branch in the meantime won't end up # with half the branches. self._cql_runner.snapshot_add_one(SnapshotRow(id=snapshot.id)) return {"snapshot:add": len(snapshots)} def snapshot_missing(self, snapshots: List[Sha1Git]) -> Iterable[Sha1Git]: return self._cql_runner.snapshot_missing(snapshots) def snapshot_get(self, snapshot_id: Sha1Git) -> Optional[Dict[str, Any]]: d = self.snapshot_get_branches(snapshot_id) if d is None: return None return { "id": d["id"], "branches": { name: branch.to_dict() if branch else None for (name, branch) in d["branches"].items() }, "next_branch": d["next_branch"], } def snapshot_count_branches( self, snapshot_id: Sha1Git, branch_name_exclude_prefix: Optional[bytes] = None, ) -> Optional[Dict[Optional[str], int]]: if self._cql_runner.snapshot_missing([snapshot_id]): # Makes sure we don't fetch branches for a snapshot that is # being added. return None return self._cql_runner.snapshot_count_branches( snapshot_id, branch_name_exclude_prefix ) def snapshot_get_branches( self, snapshot_id: Sha1Git, branches_from: bytes = b"", branches_count: int = 1000, target_types: Optional[List[str]] = None, branch_name_include_substring: Optional[bytes] = None, branch_name_exclude_prefix: Optional[bytes] = None, ) -> Optional[PartialBranches]: if self._cql_runner.snapshot_missing([snapshot_id]): # Makes sure we don't fetch branches for a snapshot that is # being added. return None branches: List = [] while len(branches) < branches_count + 1: new_branches = list( self._cql_runner.snapshot_branch_get( snapshot_id, branches_from, branches_count + 1, branch_name_exclude_prefix, ) ) if not new_branches: break branches_from = new_branches[-1].name new_branches_filtered = new_branches # Filter by target_type if target_types: new_branches_filtered = [ branch for branch in new_branches_filtered if branch.target is not None and branch.target_type in target_types ] # Filter by branches_name_pattern if branch_name_include_substring: new_branches_filtered = [ branch for branch in new_branches_filtered if branch.name is not None and ( branch_name_include_substring is None or branch_name_include_substring in branch.name ) ] branches.extend(new_branches_filtered) if len(new_branches) < branches_count + 1: break if len(branches) > branches_count: last_branch = branches.pop(-1).name else: last_branch = None return PartialBranches( id=snapshot_id, branches={ branch.name: None if branch.target is None else SnapshotBranch( target=branch.target, target_type=TargetType(branch.target_type) ) for branch in branches }, next_branch=last_branch, ) def snapshot_get_random(self) -> Sha1Git: snapshot = self._cql_runner.snapshot_get_random() assert snapshot, "Could not find any snapshot" return snapshot.id def object_find_by_sha1_git(self, ids: List[Sha1Git]) -> Dict[Sha1Git, List[Dict]]: results: Dict[Sha1Git, List[Dict]] = {id_: [] for id_ in ids} missing_ids = set(ids) # Mind the order, revision is the most likely one for a given ID, # so we check revisions first. queries: List[Tuple[str, Callable[[List[Sha1Git]], Iterable[Sha1Git]]]] = [ ("revision", self._cql_runner.revision_missing), ("release", self._cql_runner.release_missing), ("content", self.content_missing_per_sha1_git), ("directory", self._cql_runner.directory_missing), ] for (object_type, query_fn) in queries: found_ids = missing_ids - set(query_fn(list(missing_ids))) for sha1_git in found_ids: results[sha1_git].append( { "sha1_git": sha1_git, "type": object_type, } ) missing_ids.remove(sha1_git) if not missing_ids: # We found everything, skipping the next queries. break return results def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]: return [self.origin_get_one(origin) for origin in origins] def origin_get_one(self, origin_url: str) -> Optional[Origin]: """Given an origin url, return the origin if it exists, None otherwise""" rows = list(self._cql_runner.origin_get_by_url(origin_url)) if rows: assert len(rows) == 1 return Origin(url=rows[0].url) else: return None def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]: results = [] for sha1 in sha1s: rows = list(self._cql_runner.origin_get_by_sha1(sha1)) origin = {"url": rows[0].url} if rows else None results.append(origin) return results def origin_list( self, page_token: Optional[str] = None, limit: int = 100 ) -> PagedResult[Origin]: # Compute what token to begin the listing from start_token = TOKEN_BEGIN if page_token: start_token = int(page_token) if not (TOKEN_BEGIN <= start_token <= TOKEN_END): raise StorageArgumentException("Invalid page_token.") next_page_token = None origins = [] # Take one more origin so we can reuse it as the next page token if any for (tok, row) in self._cql_runner.origin_list(start_token, limit + 1): origins.append(Origin(url=row.url)) # keep reference of the last id for pagination purposes last_id = tok if len(origins) > limit: # last origin id is the next page token next_page_token = str(last_id) # excluding that origin from the result to respect the limit size origins = origins[:limit] assert len(origins) <= limit return PagedResult(results=origins, next_page_token=next_page_token) def origin_search( self, url_pattern: str, page_token: Optional[str] = None, limit: int = 50, regexp: bool = False, with_visit: bool = False, visit_types: Optional[List[str]] = None, ) -> PagedResult[Origin]: # TODO: remove this endpoint, swh-search should be used instead. next_page_token = None offset = int(page_token) if page_token else 0 origin_rows = [row for row in self._cql_runner.origin_iter_all()] if regexp: pat = re.compile(url_pattern) origin_rows = [row for row in origin_rows if pat.search(row.url)] else: origin_rows = [row for row in origin_rows if url_pattern in row.url] if with_visit: origin_rows = [row for row in origin_rows if row.next_visit_id > 1] if visit_types: def _has_visit_types(origin, visit_types): for origin_visit in stream_results(self.origin_visit_get, origin): if origin_visit.type in visit_types: return True return False origin_rows = [ row for row in origin_rows if _has_visit_types(row.url, visit_types) ] origins = [Origin(url=row.url) for row in origin_rows] origins = origins[offset : offset + limit + 1] if len(origins) > limit: # next offset next_page_token = str(offset + limit) # excluding that origin from the result to respect the limit size origins = origins[:limit] assert len(origins) <= limit return PagedResult(results=origins, next_page_token=next_page_token) def origin_count( self, url_pattern: str, regexp: bool = False, with_visit: bool = False ) -> int: raise NotImplementedError( "The Cassandra backend does not implement origin_count" ) def origin_snapshot_get_all(self, origin_url: str) -> List[Sha1Git]: return list(self._cql_runner.origin_snapshot_get_all(origin_url)) def origin_add(self, origins: List[Origin]) -> Dict[str, int]: if not self._allow_overwrite: to_add = {o.url: o for o in origins}.values() origins = [ori for ori in to_add if self.origin_get_one(ori.url) is None] self.journal_writer.origin_add(origins) for origin in origins: self._cql_runner.origin_add_one( OriginRow(sha1=hash_url(origin.url), url=origin.url, next_visit_id=1) ) return {"origin:add": len(origins)} def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]: for visit in visits: origin = self.origin_get_one(visit.origin) if not origin: # Cannot add a visit without an origin raise StorageArgumentException("Unknown origin %s", visit.origin) all_visits = [] for visit in visits: if visit.visit: # Set origin.next_visit_id = max(origin.next_visit_id, visit.visit+1) # so the next loader run does not reuse the id. self._cql_runner.origin_bump_next_visit_id(visit.origin, visit.visit) add_status = False else: visit_id = self._cql_runner.origin_generate_unique_visit_id( visit.origin ) visit = attr.evolve(visit, visit=visit_id) add_status = True self.journal_writer.origin_visit_add([visit]) self._cql_runner.origin_visit_add_one(OriginVisitRow(**visit.to_dict())) assert visit.visit is not None all_visits.append(visit) if add_status: self._origin_visit_status_add( OriginVisitStatus( origin=visit.origin, visit=visit.visit, date=visit.date, type=visit.type, status="created", snapshot=None, ) ) return all_visits def _origin_visit_status_add(self, visit_status: OriginVisitStatus) -> None: """Add an origin visit status""" if visit_status.type is None: visit_row = self._cql_runner.origin_visit_get_one( visit_status.origin, visit_status.visit ) if visit_row is None: raise StorageArgumentException( f"Unknown origin visit {visit_status.visit} " f"of origin {visit_status.origin}" ) visit_status = attr.evolve(visit_status, type=visit_row.type) self.journal_writer.origin_visit_status_add([visit_status]) self._cql_runner.origin_visit_status_add_one( converters.visit_status_to_row(visit_status) ) def origin_visit_status_add( self, visit_statuses: List[OriginVisitStatus] ) -> Dict[str, int]: # First round to check existence (fail early if any is ko) for visit_status in visit_statuses: origin_url = self.origin_get_one(visit_status.origin) if not origin_url: raise StorageArgumentException(f"Unknown origin {visit_status.origin}") for visit_status in visit_statuses: self._origin_visit_status_add(visit_status) return {"origin_visit_status:add": len(visit_statuses)} def _origin_visit_apply_status( self, visit: Dict[str, Any], visit_status: OriginVisitStatusRow ) -> Dict[str, Any]: """Retrieve the latest visit status information for the origin visit. Then merge it with the visit and return it. """ return { # default to the values in visit **visit, # override with the last update **visit_status.to_dict(), # visit['origin'] is the URL (via a join), while # visit_status['origin'] is only an id. "origin": visit["origin"], # but keep the date of the creation of the origin visit "date": visit["date"], # We use the visit type from origin visit # if it's not present on the origin visit status "type": visit_status.type or visit["type"], } def _origin_visit_get_latest_status(self, visit: OriginVisit) -> OriginVisitStatus: """Retrieve the latest visit status information for the origin visit object.""" assert visit.visit row = self._cql_runner.origin_visit_status_get_latest(visit.origin, visit.visit) assert row is not None visit_status = converters.row_to_visit_status(row) return attr.evolve(visit_status, origin=visit.origin) @staticmethod def _format_origin_visit_row(visit): return { **visit.to_dict(), "origin": visit.origin, "date": visit.date.replace(tzinfo=datetime.timezone.utc), } def origin_visit_get( self, origin: str, page_token: Optional[str] = None, order: ListOrder = ListOrder.ASC, limit: int = 10, ) -> PagedResult[OriginVisit]: if not isinstance(order, ListOrder): raise StorageArgumentException("order must be a ListOrder value") if page_token and not isinstance(page_token, str): raise StorageArgumentException("page_token must be a string.") next_page_token = None visit_from = None if page_token is None else int(page_token) visits: List[OriginVisit] = [] extra_limit = limit + 1 rows = self._cql_runner.origin_visit_get(origin, visit_from, extra_limit, order) for row in rows: visits.append(converters.row_to_visit(row)) assert len(visits) <= extra_limit if len(visits) == extra_limit: visits = visits[:limit] next_page_token = str(visits[-1].visit) return PagedResult(results=visits, next_page_token=next_page_token) def origin_visit_get_with_statuses( self, origin: str, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, page_token: Optional[str] = None, order: ListOrder = ListOrder.ASC, limit: int = 10, ) -> PagedResult[OriginVisitWithStatuses]: next_page_token = None visit_from = None if page_token is None else int(page_token) extra_limit = limit + 1 # First get visits (plus one so we can use it as the next page token if any) rows = self._cql_runner.origin_visit_get(origin, visit_from, extra_limit, order) visits: List[OriginVisit] = [converters.row_to_visit(row) for row in rows] if visits: assert visits[0].visit is not None assert visits[-1].visit is not None visit_from = min(visits[0].visit, visits[-1].visit) visit_to = max(visits[0].visit, visits[-1].visit) # Then, fetch all statuses associated to these visits statuses_rows = self._cql_runner.origin_visit_status_get_all_range( origin, visit_from, visit_to ) visit_statuses: Dict[int, List[OriginVisitStatus]] = defaultdict(list) for status_row in statuses_rows: if allowed_statuses and status_row.status not in allowed_statuses: continue if require_snapshot and status_row.snapshot is None: continue visit_status = converters.row_to_visit_status(status_row) visit_statuses[visit_status.visit].append(visit_status) # Add pagination if there are more visits assert len(visits) <= extra_limit if len(visits) == extra_limit: # excluding that visit from the result to respect the limit size visits = visits[:limit] # last visit id is the next page token next_page_token = str(visits[-1].visit) results = [ OriginVisitWithStatuses(visit=visit, statuses=visit_statuses[visit.visit]) for visit in visits if visit.visit is not None ] return PagedResult(results=results, next_page_token=next_page_token) def origin_visit_status_get( self, origin: str, visit: int, page_token: Optional[str] = None, order: ListOrder = ListOrder.ASC, limit: int = 10, ) -> PagedResult[OriginVisitStatus]: next_page_token = None date_from = None if page_token is not None: - date_from = datetime.datetime.fromisoformat(page_token) + try: + date_from = datetime.datetime.fromisoformat(page_token) + except ValueError: + raise StorageArgumentException( + "Invalid page_token argument to origin_visit_status_get." + ) from None # Take one more visit status so we can reuse it as the next page token if any rows = self._cql_runner.origin_visit_status_get_range( origin, visit, date_from, limit + 1, order ) visit_statuses = [converters.row_to_visit_status(row) for row in rows] if len(visit_statuses) > limit: # last visit status date is the next page token next_page_token = str(visit_statuses[-1].date) # excluding that visit status from the result to respect the limit size visit_statuses = visit_statuses[:limit] return PagedResult(results=visit_statuses, next_page_token=next_page_token) def origin_visit_find_by_date( self, origin: str, visit_date: datetime.datetime ) -> Optional[OriginVisit]: # Iterator over all the visits of the origin # This should be ok for now, as there aren't too many visits # per origin. rows = list(self._cql_runner.origin_visit_iter_all(origin)) def key(visit): dt = visit.date.replace(tzinfo=datetime.timezone.utc) - visit_date return (abs(dt), -visit.visit) if rows: return converters.row_to_visit(min(rows, key=key)) return None def origin_visit_get_by(self, origin: str, visit: int) -> Optional[OriginVisit]: row = self._cql_runner.origin_visit_get_one(origin, visit) if row: return converters.row_to_visit(row) return None def origin_visit_get_latest( self, origin: str, type: Optional[str] = None, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, ) -> Optional[OriginVisit]: if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): raise StorageArgumentException( f"Unknown allowed statuses {','.join(allowed_statuses)}, only " f"{','.join(VISIT_STATUSES)} authorized" ) rows = self._cql_runner.origin_visit_iter_all(origin) for row in rows: visit = self._format_origin_visit_row(row) for status_row in self._cql_runner.origin_visit_status_get( origin, visit["visit"] ): updated_visit = self._origin_visit_apply_status(visit, status_row) if type is not None and updated_visit["type"] != type: continue if allowed_statuses and updated_visit["status"] not in allowed_statuses: continue if require_snapshot and updated_visit["snapshot"] is None: continue return OriginVisit( origin=visit["origin"], visit=visit["visit"], date=visit["date"], type=visit["type"], ) return None def origin_visit_status_get_latest( self, origin_url: str, visit: int, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, ) -> Optional[OriginVisitStatus]: if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): raise StorageArgumentException( f"Unknown allowed statuses {','.join(allowed_statuses)}, only " f"{','.join(VISIT_STATUSES)} authorized" ) rows = list(self._cql_runner.origin_visit_status_get(origin_url, visit)) # filtering is done python side as we cannot do it server side if allowed_statuses: rows = [row for row in rows if row.status in allowed_statuses] if require_snapshot: rows = [row for row in rows if row.snapshot is not None] if not rows: return None return converters.row_to_visit_status(rows[0]) def origin_visit_status_get_random(self, type: str) -> Optional[OriginVisitStatus]: back_in_the_day = now() - datetime.timedelta(weeks=13) # 3 months back # Random position to start iteration at start_token = random.randint(TOKEN_BEGIN, TOKEN_END) # Iterator over all visits, ordered by token(origins) then visit_id rows = self._cql_runner.origin_visit_iter(start_token) for row in rows: visit = converters.row_to_visit(row) visit_status = self._origin_visit_get_latest_status(visit) if visit.date > back_in_the_day and visit_status.status == "full": return visit_status return None def stat_counters(self): rows = self._cql_runner.stat_counters() keys = ( "content", "directory", "origin", "origin_visit", "release", "revision", "skipped_content", "snapshot", ) stats = {key: 0 for key in keys} stats.update({row.object_type: row.count for row in rows}) return stats def refresh_stat_counters(self): pass def raw_extrinsic_metadata_add( self, metadata: List[RawExtrinsicMetadata] ) -> Dict[str, int]: self.journal_writer.raw_extrinsic_metadata_add(metadata) counter = Counter[ExtendedObjectType]() for metadata_entry in metadata: if not self._cql_runner.metadata_authority_get( metadata_entry.authority.type.value, metadata_entry.authority.url ): raise StorageArgumentException( f"Unknown authority {metadata_entry.authority}" ) if not self._cql_runner.metadata_fetcher_get( metadata_entry.fetcher.name, metadata_entry.fetcher.version ): raise StorageArgumentException( f"Unknown fetcher {metadata_entry.fetcher}" ) try: row = RawExtrinsicMetadataRow( id=metadata_entry.id, type=metadata_entry.target.object_type.name.lower(), target=str(metadata_entry.target), authority_type=metadata_entry.authority.type.value, authority_url=metadata_entry.authority.url, discovery_date=metadata_entry.discovery_date, fetcher_name=metadata_entry.fetcher.name, fetcher_version=metadata_entry.fetcher.version, format=metadata_entry.format, metadata=metadata_entry.metadata, origin=metadata_entry.origin, visit=metadata_entry.visit, snapshot=map_optional(str, metadata_entry.snapshot), release=map_optional(str, metadata_entry.release), revision=map_optional(str, metadata_entry.revision), path=metadata_entry.path, directory=map_optional(str, metadata_entry.directory), ) except TypeError as e: raise StorageArgumentException(*e.args) # Add to the index first self._cql_runner.raw_extrinsic_metadata_by_id_add( RawExtrinsicMetadataByIdRow( id=row.id, target=row.target, authority_type=row.authority_type, authority_url=row.authority_url, ) ) # Then to the main table self._cql_runner.raw_extrinsic_metadata_add(row) counter[metadata_entry.target.object_type] += 1 return { f"{type.value}_metadata:add": count for (type, count) in counter.items() } def raw_extrinsic_metadata_get( self, target: ExtendedSWHID, authority: MetadataAuthority, after: Optional[datetime.datetime] = None, page_token: Optional[bytes] = None, limit: int = 1000, ) -> PagedResult[RawExtrinsicMetadata]: if page_token is not None: (after_date, id_) = msgpack_loads(base64.b64decode(page_token)) if after and after_date < after: raise StorageArgumentException( "page_token is inconsistent with the value of 'after'." ) entries = self._cql_runner.raw_extrinsic_metadata_get_after_date_and_id( str(target), authority.type.value, authority.url, after_date, id_, ) elif after is not None: entries = self._cql_runner.raw_extrinsic_metadata_get_after_date( str(target), authority.type.value, authority.url, after ) else: entries = self._cql_runner.raw_extrinsic_metadata_get( str(target), authority.type.value, authority.url ) if limit: entries = itertools.islice(entries, 0, limit + 1) results = [] for entry in entries: assert str(target) == entry.target results.append(converters.row_to_raw_extrinsic_metadata(entry)) if len(results) > limit: results.pop() assert len(results) == limit last_result = results[-1] next_page_token: Optional[str] = base64.b64encode( msgpack_dumps( ( last_result.discovery_date, last_result.id, ) ) ).decode() else: next_page_token = None return PagedResult( next_page_token=next_page_token, results=results, ) def raw_extrinsic_metadata_get_by_ids( self, ids: List[Sha1Git] ) -> List[RawExtrinsicMetadata]: keys = self._cql_runner.raw_extrinsic_metadata_get_by_ids(ids) results: Set[RawExtrinsicMetadata] = set() for key in keys: candidates = self._cql_runner.raw_extrinsic_metadata_get( key.target, key.authority_type, key.authority_url ) candidates = [ candidate for candidate in candidates if candidate.id == key.id ] if len(candidates) > 1: raise Exception( "Found multiple RawExtrinsicMetadata objects with the same id: " + hash_to_hex(key.id) ) results.update(map(converters.row_to_raw_extrinsic_metadata, candidates)) return list(results) def raw_extrinsic_metadata_get_authorities( self, target: ExtendedSWHID ) -> List[MetadataAuthority]: return [ MetadataAuthority( type=MetadataAuthorityType(authority_type), url=authority_url ) for (authority_type, authority_url) in set( self._cql_runner.raw_extrinsic_metadata_get_authorities(str(target)) ) ] def metadata_fetcher_add(self, fetchers: List[MetadataFetcher]) -> Dict[str, int]: self.journal_writer.metadata_fetcher_add(fetchers) for fetcher in fetchers: self._cql_runner.metadata_fetcher_add( MetadataFetcherRow( name=fetcher.name, version=fetcher.version, ) ) return {"metadata_fetcher:add": len(fetchers)} def metadata_fetcher_get( self, name: str, version: str ) -> Optional[MetadataFetcher]: fetcher = self._cql_runner.metadata_fetcher_get(name, version) if fetcher: return MetadataFetcher( name=fetcher.name, version=fetcher.version, ) else: return None def metadata_authority_add( self, authorities: List[MetadataAuthority] ) -> Dict[str, int]: self.journal_writer.metadata_authority_add(authorities) for authority in authorities: self._cql_runner.metadata_authority_add( MetadataAuthorityRow( url=authority.url, type=authority.type.value, ) ) return {"metadata_authority:add": len(authorities)} def metadata_authority_get( self, type: MetadataAuthorityType, url: str ) -> Optional[MetadataAuthority]: authority = self._cql_runner.metadata_authority_get(type.value, url) if authority: return MetadataAuthority( type=MetadataAuthorityType(authority.type), url=authority.url, ) else: return None # ExtID tables def extid_add(self, ids: List[ExtID]) -> Dict[str, int]: if not self._allow_overwrite: extids = [ extid for extid in ids if not self._cql_runner.extid_get_from_pk( extid_type=extid.extid_type, extid_version=extid.extid_version, extid=extid.extid, target=extid.target, ) ] else: extids = list(ids) self.journal_writer.extid_add(extids) inserted = 0 for extid in extids: target_type = extid.target.object_type.value target = extid.target.object_id extid_version = extid.extid_version extid_type = extid.extid_type extidrow = ExtIDRow( extid_type=extid_type, extid_version=extid_version, extid=extid.extid, target_type=target_type, target=target, ) (token, insertion_finalizer) = self._cql_runner.extid_add_prepare(extidrow) indexrow = ExtIDByTargetRow( target_type=target_type, target=target, target_token=token, ) self._cql_runner.extid_index_add_one(indexrow) insertion_finalizer() inserted += 1 return {"extid:add": inserted} def extid_get_from_extid( self, id_type: str, ids: List[bytes], version: Optional[int] = None ) -> List[ExtID]: result: List[ExtID] = [] for extid in ids: if version is not None: extidrows = self._cql_runner.extid_get_from_extid_and_version( id_type, extid, version ) else: extidrows = self._cql_runner.extid_get_from_extid(id_type, extid) result.extend( ExtID( extid_type=extidrow.extid_type, extid_version=extidrow.extid_version, extid=extidrow.extid, target=CoreSWHID( object_type=extidrow.target_type, object_id=extidrow.target, ), ) for extidrow in extidrows ) return result def extid_get_from_target( self, target_type: SwhidObjectType, ids: List[Sha1Git], extid_type: Optional[str] = None, extid_version: Optional[int] = None, ) -> List[ExtID]: if (extid_version is not None and extid_type is None) or ( extid_version is None and extid_type is not None ): raise ValueError("You must provide both extid_type and extid_version") result: List[ExtID] = [] for target in ids: extidrows = self._cql_runner.extid_get_from_target( target_type.value, target, extid_type=extid_type, extid_version=extid_version, ) result.extend( ExtID( extid_type=extidrow.extid_type, extid_version=extidrow.extid_version, extid=extidrow.extid, target=CoreSWHID( object_type=SwhidObjectType(extidrow.target_type), object_id=extidrow.target, ), ) for extidrow in extidrows ) return result # Misc def clear_buffers(self, object_types: Sequence[str] = ()) -> None: """Do nothing""" return None def flush(self, object_types: Sequence[str] = ()) -> Dict[str, int]: return {} diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py index 6b1ce0d9..051aa5a3 100644 --- a/swh/storage/postgresql/storage.py +++ b/swh/storage/postgresql/storage.py @@ -1,1703 +1,1708 @@ # Copyright (C) 2015-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 from collections import defaultdict import contextlib from contextlib import contextmanager import datetime import itertools import logging import operator from typing import Any, Counter, Dict, Iterable, List, Optional, Sequence, Tuple import attr import psycopg2 import psycopg2.errors import psycopg2.pool from swh.core.api.serializers import msgpack_dumps, msgpack_loads from swh.core.db.common import db_transaction, db_transaction_generator from swh.core.db.db_utils import swh_db_flavor, swh_db_version from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex from swh.model.model import ( SHA1_SIZE, Content, Directory, DirectoryEntry, ExtID, MetadataAuthority, MetadataAuthorityType, MetadataFetcher, Origin, OriginVisit, OriginVisitStatus, RawExtrinsicMetadata, Release, Revision, Sha1, Sha1Git, SkippedContent, Snapshot, SnapshotBranch, TargetType, ) from swh.model.swhids import ExtendedObjectType, ExtendedSWHID, ObjectType from swh.storage.exc import HashCollision, StorageArgumentException, StorageDBError from swh.storage.interface import ( VISIT_STATUSES, ListOrder, OriginVisitWithStatuses, PagedResult, PartialBranches, ) from swh.storage.objstorage import ObjStorage from swh.storage.utils import ( extract_collision_hash, get_partition_bounds_bytes, map_optional, now, ) from swh.storage.writer import JournalWriter from . import converters from .db import Db logger = logging.getLogger(__name__) # Max block size of contents to return BULK_BLOCK_CONTENT_LEN_MAX = 10000 EMPTY_SNAPSHOT_ID = hash_to_bytes("1a8893e6a86f444e8be8e7bda6cb34fb1735a00e") """Identifier for the empty snapshot""" VALIDATION_EXCEPTIONS = ( KeyError, TypeError, ValueError, psycopg2.errors.CheckViolation, psycopg2.errors.IntegrityError, psycopg2.errors.InvalidTextRepresentation, psycopg2.errors.NotNullViolation, psycopg2.errors.NumericValueOutOfRange, psycopg2.errors.UndefinedFunction, # (raised on wrong argument typs) psycopg2.errors.ProgramLimitExceeded, # typically person_name_idx ) """Exceptions raised by postgresql when validation of the arguments failed.""" @contextlib.contextmanager def convert_validation_exceptions(): """Catches postgresql errors related to invalid arguments, and re-raises a StorageArgumentException.""" try: yield except psycopg2.errors.UniqueViolation: # This only happens because of concurrent insertions, but it is # a subclass of IntegrityError; so we need to catch and reraise it # before the next clause converts it to StorageArgumentException. raise except VALIDATION_EXCEPTIONS as e: raise StorageArgumentException(str(e)) class Storage: """SWH storage datastore proxy, encompassing DB and object storage""" current_version: int = 186 def __init__( self, db, objstorage, min_pool_conns=1, max_pool_conns=10, journal_writer=None, query_options=None, ): """Instantiate a storage instance backed by a PostgreSQL database and an objstorage. When ``db`` is passed as a connection string, then this module automatically manages a connection pool between ``min_pool_conns`` and ``max_pool_conns``. When ``db`` is an explicit psycopg2 connection, then ``min_pool_conns`` and ``max_pool_conns`` are ignored and the connection is used directly. Args: db: either a libpq connection string, or a psycopg2 connection objstorage: configuration for the backend :class:`ObjStorage` min_pool_conns: min number of connections in the psycopg2 pool max_pool_conns: max number of connections in the psycopg2 pool journal_writer: configuration for the :class:`JournalWriter` query_options: configuration for the sql connections; keys of the dict are the method names decorated with :func:`db_transaction` or :func:`db_transaction_generator` (eg. :func:`content_find`), and values are dicts (config_name, config_value) used to configure the sql connection for the method_name. For example, using:: {"content_get": {"statement_timeout": 5000}} will override the default statement timeout for the :func:`content_get` endpoint from 500ms to 5000ms. See :mod:`swh.core.db.common` for more details. """ try: if isinstance(db, psycopg2.extensions.connection): self._pool = None self._db = Db(db) # See comment below self._db.cursor().execute("SET TIME ZONE 'UTC'") else: self._pool = psycopg2.pool.ThreadedConnectionPool( min_pool_conns, max_pool_conns, db ) self._db = None except psycopg2.OperationalError as e: raise StorageDBError(e) self.journal_writer = JournalWriter(journal_writer) self.objstorage = ObjStorage(objstorage) self.query_options = query_options self._flavor: Optional[str] = None def get_db(self): if self._db: return self._db else: db = Db.from_pool(self._pool) # Workaround for psycopg2 < 2.9.0 not handling fractional timezones, # which may happen on old revision/release dates on systems configured # with non-UTC timezones. # https://www.psycopg.org/docs/usage.html#time-zones-handling db.cursor().execute("SET TIME ZONE 'UTC'") return db def put_db(self, db): if db is not self._db: db.put_conn() @contextmanager def db(self): db = None try: db = self.get_db() yield db finally: if db: self.put_db(db) @db_transaction() def get_flavor(self, *, db: Db, cur=None) -> str: flavor = swh_db_flavor(db.conn.dsn) assert flavor is not None return flavor @property def flavor(self) -> str: if self._flavor is None: self._flavor = self.get_flavor() assert self._flavor is not None return self._flavor @db_transaction() def check_config(self, *, check_write: bool, db: Db, cur=None) -> bool: if not self.objstorage.check_config(check_write=check_write): return False dbversion = swh_db_version(db.conn.dsn) if dbversion != self.current_version: logger.warning( "database dbversion (%s) != %s current_version (%s)", dbversion, __name__, self.current_version, ) return False # Check permissions on one of the tables check = "INSERT" if check_write else "SELECT" cur.execute("select has_table_privilege(current_user, 'content', %s)", (check,)) return cur.fetchone()[0] def _content_unique_key(self, hash, db): """Given a hash (tuple or dict), return a unique key from the aggregation of keys. """ keys = db.content_hash_keys if isinstance(hash, tuple): return hash return tuple([hash[k] for k in keys]) def _content_add_metadata(self, db, cur, content): """Add content to the postgresql database but not the object storage.""" # create temporary table for metadata injection db.mktemp("content", cur) db.copy_to( (c.to_dict() for c in content), "tmp_content", db.content_add_keys, cur ) # move metadata in place try: db.content_add_from_temp(cur) except psycopg2.IntegrityError as e: if e.diag.sqlstate == "23505" and e.diag.table_name == "content": message_detail = e.diag.message_detail if message_detail: hash_name, hash_id = extract_collision_hash(message_detail) collision_contents_hashes = [ c.hashes() for c in content if c.get_hash(hash_name) == hash_id ] else: constraint_to_hash_name = { "content_pkey": "sha1", "content_sha1_git_idx": "sha1_git", "content_sha256_idx": "sha256", } hash_name = constraint_to_hash_name.get(e.diag.constraint_name) hash_id = None collision_contents_hashes = None raise HashCollision( hash_name, hash_id, collision_contents_hashes ) from None else: raise def content_add(self, content: List[Content]) -> Dict[str, int]: ctime = now() contents = [attr.evolve(c, ctime=ctime) for c in content] # Must add to the objstorage before the DB and journal. Otherwise: # 1. in case of a crash the DB may "believe" we have the content, but # we didn't have time to write to the objstorage before the crash # 2. the objstorage mirroring, which reads from the journal, may attempt to # read from the objstorage before we finished writing it objstorage_summary = self.objstorage.content_add(contents) with self.db() as db: with db.transaction() as cur: missing = list( self.content_missing( map(Content.to_dict, contents), key_hash="sha1_git", db=db, cur=cur, ) ) contents = [c for c in contents if c.sha1_git in missing] self.journal_writer.content_add(contents) self._content_add_metadata(db, cur, contents) return { "content:add": len(contents), "content:add:bytes": objstorage_summary["content:add:bytes"], } @db_transaction() def content_update( self, contents: List[Dict[str, Any]], keys: List[str] = [], *, db: Db, cur=None ) -> None: # TODO: Add a check on input keys. How to properly implement # this? We don't know yet the new columns. self.journal_writer.content_update(contents) db.mktemp("content", cur) select_keys = list(set(db.content_get_metadata_keys).union(set(keys))) with convert_validation_exceptions(): db.copy_to(contents, "tmp_content", select_keys, cur) db.content_update_from_temp(keys_to_update=keys, cur=cur) @db_transaction() def content_add_metadata( self, content: List[Content], *, db: Db, cur=None ) -> Dict[str, int]: missing = self.content_missing( (c.to_dict() for c in content), key_hash="sha1_git", db=db, cur=cur, ) contents = [c for c in content if c.sha1_git in missing] self.journal_writer.content_add_metadata(contents) self._content_add_metadata(db, cur, contents) return { "content:add": len(contents), } def content_get_data(self, content: Sha1) -> Optional[bytes]: # FIXME: Make this method support slicing the `data` return self.objstorage.content_get(content) @db_transaction() def content_get_partition( self, partition_id: int, nb_partitions: int, page_token: Optional[str] = None, limit: int = 1000, *, db: Db, cur=None, ) -> PagedResult[Content]: if limit is None: raise StorageArgumentException("limit should not be None") (start, end) = get_partition_bounds_bytes( partition_id, nb_partitions, SHA1_SIZE ) if page_token: start = hash_to_bytes(page_token) if end is None: end = b"\xff" * SHA1_SIZE next_page_token: Optional[str] = None contents = [] for counter, row in enumerate(db.content_get_range(start, end, limit + 1, cur)): row_d = dict(zip(db.content_get_metadata_keys, row)) content = Content(**row_d) if counter >= limit: # take the last content for the next page starting from this next_page_token = hash_to_hex(content.sha1) break contents.append(content) assert len(contents) <= limit return PagedResult(results=contents, next_page_token=next_page_token) @db_transaction(statement_timeout=500) def content_get( self, contents: List[bytes], algo: str = "sha1", *, db: Db, cur=None ) -> List[Optional[Content]]: contents_by_hash: Dict[bytes, Optional[Content]] = {} if algo not in DEFAULT_ALGORITHMS: raise StorageArgumentException( "algo should be one of {','.join(DEFAULT_ALGORITHMS)}" ) rows = db.content_get_metadata_from_hashes(contents, algo, cur) key = operator.attrgetter(algo) for row in rows: row_d = dict(zip(db.content_get_metadata_keys, row)) content = Content(**row_d) contents_by_hash[key(content)] = content return [contents_by_hash.get(sha1) for sha1 in contents] @db_transaction_generator() def content_missing( self, contents: List[Dict[str, Any]], key_hash: str = "sha1", *, db: Db, cur=None, ) -> Iterable[bytes]: if key_hash not in DEFAULT_ALGORITHMS: raise StorageArgumentException( "key_hash should be one of {','.join(DEFAULT_ALGORITHMS)}" ) keys = db.content_hash_keys key_hash_idx = keys.index(key_hash) for obj in db.content_missing_from_list(contents, cur): yield obj[key_hash_idx] @db_transaction_generator() def content_missing_per_sha1( self, contents: List[bytes], *, db: Db, cur=None ) -> Iterable[bytes]: for obj in db.content_missing_per_sha1(contents, cur): yield obj[0] @db_transaction_generator() def content_missing_per_sha1_git( self, contents: List[bytes], *, db: Db, cur=None ) -> Iterable[Sha1Git]: for obj in db.content_missing_per_sha1_git(contents, cur): yield obj[0] @db_transaction() def content_find( self, content: Dict[str, Any], *, db: Db, cur=None ) -> List[Content]: if not set(content).intersection(DEFAULT_ALGORITHMS): raise StorageArgumentException( "content keys must contain at least one " f"of: {', '.join(sorted(DEFAULT_ALGORITHMS))}" ) rows = db.content_find( sha1=content.get("sha1"), sha1_git=content.get("sha1_git"), sha256=content.get("sha256"), blake2s256=content.get("blake2s256"), cur=cur, ) contents = [] for row in rows: row_d = dict(zip(db.content_find_cols, row)) contents.append(Content(**row_d)) return contents @db_transaction() def content_get_random(self, *, db: Db, cur=None) -> Sha1Git: return db.content_get_random(cur) @staticmethod def _skipped_content_normalize(d): d = d.copy() if d.get("status") is None: d["status"] = "absent" if d.get("length") is None: d["length"] = -1 return d def _skipped_content_add_metadata(self, db, cur, content: List[SkippedContent]): origin_ids = db.origin_id_get_by_url([cont.origin for cont in content], cur=cur) content = [ attr.evolve(c, origin=origin_id) for (c, origin_id) in zip(content, origin_ids) ] db.mktemp("skipped_content", cur) db.copy_to( [c.to_dict() for c in content], "tmp_skipped_content", db.skipped_content_keys, cur, ) # move metadata in place db.skipped_content_add_from_temp(cur) @db_transaction() def skipped_content_add( self, content: List[SkippedContent], *, db: Db, cur=None ) -> Dict[str, int]: ctime = now() content = [attr.evolve(c, ctime=ctime) for c in content] missing_contents = self.skipped_content_missing( (c.to_dict() for c in content), db=db, cur=cur, ) content = [ c for c in content if any( all( c.get_hash(algo) == missing_content.get(algo) for algo in DEFAULT_ALGORITHMS ) for missing_content in missing_contents ) ] self.journal_writer.skipped_content_add(content) self._skipped_content_add_metadata(db, cur, content) return { "skipped_content:add": len(content), } @db_transaction_generator() def skipped_content_missing( self, contents: List[Dict[str, Any]], *, db: Db, cur=None ) -> Iterable[Dict[str, Any]]: contents = list(contents) for content in db.skipped_content_missing(contents, cur): yield dict(zip(db.content_hash_keys, content)) @db_transaction() def directory_add( self, directories: List[Directory], *, db: Db, cur=None ) -> Dict[str, int]: summary = {"directory:add": 0} dirs = set() dir_entries: Dict[str, defaultdict] = { "file": defaultdict(list), "dir": defaultdict(list), "rev": defaultdict(list), } for cur_dir in directories: dir_id = cur_dir.id dirs.add(dir_id) for src_entry in cur_dir.entries: entry = src_entry.to_dict() entry["dir_id"] = dir_id dir_entries[entry["type"]][dir_id].append(entry) dirs_missing = set(self.directory_missing(dirs, db=db, cur=cur)) if not dirs_missing: return summary self.journal_writer.directory_add( dir_ for dir_ in directories if dir_.id in dirs_missing ) # Copy directory metadata dirs_missing_dict = ( {"id": dir_.id, "raw_manifest": dir_.raw_manifest} for dir_ in directories if dir_.id in dirs_missing ) db.mktemp("directory", cur) db.copy_to(dirs_missing_dict, "tmp_directory", ["id", "raw_manifest"], cur) # Copy entries for entry_type, entry_list in dir_entries.items(): entries = itertools.chain.from_iterable( entries_for_dir for dir_id, entries_for_dir in entry_list.items() if dir_id in dirs_missing ) db.mktemp_dir_entry(entry_type) db.copy_to( entries, "tmp_directory_entry_%s" % entry_type, ["target", "name", "perms", "dir_id"], cur, ) # Do the final copy db.directory_add_from_temp(cur) summary["directory:add"] = len(dirs_missing) return summary @db_transaction_generator() def directory_missing( self, directories: List[Sha1Git], *, db: Db, cur=None ) -> Iterable[Sha1Git]: for obj in db.directory_missing_from_list(directories, cur): yield obj[0] @db_transaction_generator(statement_timeout=20000) def directory_ls( self, directory: Sha1Git, recursive: bool = False, *, db: Db, cur=None ) -> Iterable[Dict[str, Any]]: if recursive: res_gen = db.directory_walk(directory, cur=cur) else: res_gen = db.directory_walk_one(directory, cur=cur) for line in res_gen: yield dict(zip(db.directory_ls_cols, line)) @db_transaction(statement_timeout=4000) def directory_entry_get_by_path( self, directory: Sha1Git, paths: List[bytes], *, db: Db, cur=None ) -> Optional[Dict[str, Any]]: res = db.directory_entry_get_by_path(directory, paths, cur) return dict(zip(db.directory_ls_cols, res)) if res else None @db_transaction() def directory_get_random(self, *, db: Db, cur=None) -> Sha1Git: return db.directory_get_random(cur) @db_transaction() def directory_get_entries( self, directory_id: Sha1Git, page_token: Optional[bytes] = None, limit: int = 1000, *, db: Db, cur=None, ) -> Optional[PagedResult[DirectoryEntry]]: if list(self.directory_missing([directory_id], db=db, cur=cur)): return None if page_token is not None: raise StorageArgumentException("Unsupported page token") # TODO: actually paginate rows = db.directory_get_entries(directory_id, cur=cur) return PagedResult( results=[ DirectoryEntry(**dict(zip(db.directory_get_entries_cols, row))) for row in rows ], next_page_token=None, ) @db_transaction() def directory_get_raw_manifest( self, directory_ids: List[Sha1Git], *, db: Db, cur=None ) -> Dict[Sha1Git, Optional[bytes]]: return dict(db.directory_get_raw_manifest(directory_ids, cur=cur)) @db_transaction() def revision_add( self, revisions: List[Revision], *, db: Db, cur=None ) -> Dict[str, int]: summary = {"revision:add": 0} revisions_missing = set( self.revision_missing( set(revision.id for revision in revisions), db=db, cur=cur ) ) if not revisions_missing: return summary db.mktemp_revision(cur) revisions_filtered = [ revision for revision in revisions if revision.id in revisions_missing ] self.journal_writer.revision_add(revisions_filtered) db_revisions_filtered = list(map(converters.revision_to_db, revisions_filtered)) parents_filtered: List[Dict[str, Any]] = [] with convert_validation_exceptions(): db.copy_to( db_revisions_filtered, "tmp_revision", db.revision_add_cols, cur, lambda rev: parents_filtered.extend(rev["parents"]), ) db.revision_add_from_temp(cur) db.copy_to( parents_filtered, "revision_history", ["id", "parent_id", "parent_rank"], cur, ) return {"revision:add": len(revisions_missing)} @db_transaction_generator() def revision_missing( self, revisions: List[Sha1Git], *, db: Db, cur=None ) -> Iterable[Sha1Git]: if not revisions: return None for obj in db.revision_missing_from_list(revisions, cur): yield obj[0] @db_transaction(statement_timeout=2000) def revision_get( self, revision_ids: List[Sha1Git], ignore_displayname: bool = False, *, db: Db, cur=None, ) -> List[Optional[Revision]]: revisions = [] for line in db.revision_get_from_list(revision_ids, ignore_displayname, cur): revision = converters.db_to_revision(dict(zip(db.revision_get_cols, line))) revisions.append(revision) return revisions @db_transaction_generator(statement_timeout=2000) def revision_log( self, revisions: List[Sha1Git], ignore_displayname: bool = False, limit: Optional[int] = None, *, db: Db, cur=None, ) -> Iterable[Optional[Dict[str, Any]]]: for line in db.revision_log( revisions, ignore_displayname=ignore_displayname, limit=limit, cur=cur ): data = converters.db_to_revision(dict(zip(db.revision_get_cols, line))) if not data: yield None continue yield data.to_dict() @db_transaction_generator(statement_timeout=2000) def revision_shortlog( self, revisions: List[Sha1Git], limit: Optional[int] = None, *, db: Db, cur=None ) -> Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]]: yield from db.revision_shortlog(revisions, limit, cur) @db_transaction() def revision_get_random(self, *, db: Db, cur=None) -> Sha1Git: return db.revision_get_random(cur) @db_transaction() def extid_get_from_extid( self, id_type: str, ids: List[bytes], version: Optional[int] = None, *, db: Db, cur=None, ) -> List[ExtID]: extids = [] for row in db.extid_get_from_extid_list(id_type, ids, version=version, cur=cur): if row[0] is not None: extids.append(converters.db_to_extid(dict(zip(db.extid_cols, row)))) return extids @db_transaction() def extid_get_from_target( self, target_type: ObjectType, ids: List[Sha1Git], extid_type: Optional[str] = None, extid_version: Optional[int] = None, *, db: Db, cur=None, ) -> List[ExtID]: extids = [] if (extid_version is not None and extid_type is None) or ( extid_version is None and extid_type is not None ): raise ValueError("You must provide both extid_type and extid_version") for row in db.extid_get_from_swhid_list( target_type.value, ids, extid_version=extid_version, extid_type=extid_type, cur=cur, ): if row[0] is not None: extids.append(converters.db_to_extid(dict(zip(db.extid_cols, row)))) return extids @db_transaction() def extid_add(self, ids: List[ExtID], *, db: Db, cur=None) -> Dict[str, int]: extid = [ { "extid": extid.extid, "extid_type": extid.extid_type, "extid_version": getattr(extid, "extid_version", 0), "target": extid.target.object_id, "target_type": extid.target.object_type.name.lower(), # arghh } for extid in ids ] db.mktemp("extid", cur) self.journal_writer.extid_add(ids) db.copy_to(extid, "tmp_extid", db.extid_cols, cur) # move metadata in place db.extid_add_from_temp(cur) return {"extid:add": len(extid)} @db_transaction() def release_add( self, releases: List[Release], *, db: Db, cur=None ) -> Dict[str, int]: summary = {"release:add": 0} release_ids = set(release.id for release in releases) releases_missing = set(self.release_missing(release_ids, db=db, cur=cur)) if not releases_missing: return summary db.mktemp_release(cur) releases_filtered = [ release for release in releases if release.id in releases_missing ] self.journal_writer.release_add(releases_filtered) db_releases_filtered = list(map(converters.release_to_db, releases_filtered)) with convert_validation_exceptions(): db.copy_to(db_releases_filtered, "tmp_release", db.release_add_cols, cur) db.release_add_from_temp(cur) return {"release:add": len(releases_missing)} @db_transaction_generator() def release_missing( self, releases: List[Sha1Git], *, db: Db, cur=None ) -> Iterable[Sha1Git]: if not releases: return for obj in db.release_missing_from_list(releases, cur): yield obj[0] @db_transaction(statement_timeout=1000) def release_get( self, releases: List[Sha1Git], ignore_displayname: bool = False, *, db: Db, cur=None, ) -> List[Optional[Release]]: rels = [] for release in db.release_get_from_list(releases, ignore_displayname, cur): data = converters.db_to_release(dict(zip(db.release_get_cols, release))) rels.append(data if data else None) return rels @db_transaction() def release_get_random(self, *, db: Db, cur=None) -> Sha1Git: return db.release_get_random(cur) @db_transaction() def snapshot_add( self, snapshots: List[Snapshot], *, db: Db, cur=None ) -> Dict[str, int]: created_temp_table = False count = 0 for snapshot in snapshots: if not db.snapshot_exists(snapshot.id, cur): if not created_temp_table: db.mktemp_snapshot_branch(cur) created_temp_table = True with convert_validation_exceptions(): db.copy_to( ( { "name": name, "target": info.target if info else None, "target_type": ( info.target_type.value if info else None ), } for name, info in snapshot.branches.items() ), "tmp_snapshot_branch", ["name", "target", "target_type"], cur, ) self.journal_writer.snapshot_add([snapshot]) db.snapshot_add(snapshot.id, cur) count += 1 return {"snapshot:add": count} @db_transaction_generator() def snapshot_missing( self, snapshots: List[Sha1Git], *, db: Db, cur=None ) -> Iterable[Sha1Git]: for obj in db.snapshot_missing_from_list(snapshots, cur): yield obj[0] @db_transaction(statement_timeout=2000) def snapshot_get( self, snapshot_id: Sha1Git, *, db: Db, cur=None ) -> Optional[Dict[str, Any]]: d = self.snapshot_get_branches(snapshot_id) if d is None: return d return { "id": d["id"], "branches": { name: branch.to_dict() if branch else None for (name, branch) in d["branches"].items() }, "next_branch": d["next_branch"], } @db_transaction(statement_timeout=2000) def snapshot_count_branches( self, snapshot_id: Sha1Git, branch_name_exclude_prefix: Optional[bytes] = None, *, db: Db, cur=None, ) -> Optional[Dict[Optional[str], int]]: return dict( [ bc for bc in db.snapshot_count_branches( snapshot_id, branch_name_exclude_prefix, cur, ) ] ) @db_transaction(statement_timeout=2000) def snapshot_get_branches( self, snapshot_id: Sha1Git, branches_from: bytes = b"", branches_count: int = 1000, target_types: Optional[List[str]] = None, branch_name_include_substring: Optional[bytes] = None, branch_name_exclude_prefix: Optional[bytes] = None, *, db: Db, cur=None, ) -> Optional[PartialBranches]: if snapshot_id == EMPTY_SNAPSHOT_ID: return PartialBranches( id=snapshot_id, branches={}, next_branch=None, ) if list(self.snapshot_missing([snapshot_id])): return None branches = {} next_branch = None fetched_branches = list( db.snapshot_get_by_id( snapshot_id, branches_from=branches_from, # the underlying SQL query can be quite expensive to execute for small # branches_count value, so we ensure a minimum branches limit of 10 for # optimal performances branches_count=max(branches_count + 1, 10), target_types=target_types, branch_name_include_substring=branch_name_include_substring, branch_name_exclude_prefix=branch_name_exclude_prefix, cur=cur, ) ) for row in fetched_branches[:branches_count]: branch_d = dict(zip(db.snapshot_get_cols, row)) del branch_d["snapshot_id"] name = branch_d.pop("name") if branch_d["target"] is None and branch_d["target_type"] is None: branch = None else: assert branch_d["target_type"] is not None branch = SnapshotBranch( target=branch_d["target"], target_type=TargetType(branch_d["target_type"]), ) branches[name] = branch if len(fetched_branches) > branches_count: next_branch = dict( zip(db.snapshot_get_cols, fetched_branches[branches_count]) )["name"] return PartialBranches( id=snapshot_id, branches=branches, next_branch=next_branch, ) @db_transaction() def snapshot_get_random(self, *, db: Db, cur=None) -> Sha1Git: return db.snapshot_get_random(cur) @db_transaction() def origin_visit_add( self, visits: List[OriginVisit], *, db: Db, cur=None ) -> Iterable[OriginVisit]: for visit in visits: origin = self.origin_get([visit.origin], db=db, cur=cur)[0] if not origin: # Cannot add a visit without an origin raise StorageArgumentException("Unknown origin %s", visit.origin) all_visits = [] for visit in visits: if visit.visit: self.journal_writer.origin_visit_add([visit]) db.origin_visit_add_with_id(visit, cur=cur) else: # visit_id is not given, it needs to be set by the db with convert_validation_exceptions(): visit_id = db.origin_visit_add( visit.origin, visit.date, visit.type, cur=cur ) visit = attr.evolve(visit, visit=visit_id) # Forced to write in the journal after the db (since its the db # call that set the visit id) self.journal_writer.origin_visit_add([visit]) # In this case, we also want to create the initial OVS object visit_status = OriginVisitStatus( origin=visit.origin, visit=visit_id, date=visit.date, type=visit.type, status="created", snapshot=None, ) self._origin_visit_status_add(visit_status, db=db, cur=cur) assert visit.visit is not None all_visits.append(visit) return all_visits def _origin_visit_status_add( self, visit_status: OriginVisitStatus, db, cur ) -> None: """Add an origin visit status""" self.journal_writer.origin_visit_status_add([visit_status]) db.origin_visit_status_add(visit_status, cur=cur) @db_transaction() def origin_visit_status_add( self, visit_statuses: List[OriginVisitStatus], *, db: Db, cur=None, ) -> Dict[str, int]: visit_statuses_ = [] # First round to check existence (fail early if any is ko) for visit_status in visit_statuses: origin_url = self.origin_get([visit_status.origin], db=db, cur=cur)[0] if not origin_url: raise StorageArgumentException(f"Unknown origin {visit_status.origin}") if visit_status.type is None: origin_visit = self.origin_visit_get_by( visit_status.origin, visit_status.visit, db=db, cur=cur ) if origin_visit is None: raise StorageArgumentException( f"Unknown origin visit {visit_status.visit} " f"of origin {visit_status.origin}" ) origin_visit_status = attr.evolve(visit_status, type=origin_visit.type) else: origin_visit_status = visit_status visit_statuses_.append(origin_visit_status) for visit_status in visit_statuses_: self._origin_visit_status_add(visit_status, db, cur) return {"origin_visit_status:add": len(visit_statuses_)} @db_transaction() def origin_visit_status_get_latest( self, origin_url: str, visit: int, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, *, db: Db, cur=None, ) -> Optional[OriginVisitStatus]: if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): raise StorageArgumentException( f"Unknown allowed statuses {','.join(allowed_statuses)}, only " f"{','.join(VISIT_STATUSES)} authorized" ) row_d = db.origin_visit_status_get_latest( origin_url, visit, allowed_statuses, require_snapshot, cur=cur ) if not row_d: return None return OriginVisitStatus(**row_d) @db_transaction(statement_timeout=500) def origin_visit_get( self, origin: str, page_token: Optional[str] = None, order: ListOrder = ListOrder.ASC, limit: int = 10, *, db: Db, cur=None, ) -> PagedResult[OriginVisit]: page_token = page_token or "0" if not isinstance(order, ListOrder): raise StorageArgumentException("order must be a ListOrder value") if not isinstance(page_token, str): raise StorageArgumentException("page_token must be a string.") next_page_token = None visit_from = int(page_token) visits: List[OriginVisit] = [] extra_limit = limit + 1 for row in db.origin_visit_get_range( origin, visit_from=visit_from, order=order, limit=extra_limit, cur=cur ): row_d = dict(zip(db.origin_visit_cols, row)) visits.append( OriginVisit( origin=row_d["origin"], visit=row_d["visit"], date=row_d["date"], type=row_d["type"], ) ) assert len(visits) <= extra_limit if len(visits) == extra_limit: visits = visits[:limit] next_page_token = str(visits[-1].visit) return PagedResult(results=visits, next_page_token=next_page_token) @db_transaction(statement_timeout=2000) def origin_visit_get_with_statuses( self, origin: str, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, page_token: Optional[str] = None, order: ListOrder = ListOrder.ASC, limit: int = 10, *, db: Db, cur=None, ) -> PagedResult[OriginVisitWithStatuses]: page_token = page_token or "0" if not isinstance(order, ListOrder): raise StorageArgumentException("order must be a ListOrder value") if not isinstance(page_token, str): raise StorageArgumentException("page_token must be a string.") # First get visits (plus one so we can use it as the next page token if any) visits_page = self.origin_visit_get( origin=origin, page_token=page_token, order=order, limit=limit, db=db, cur=cur, ) visits = visits_page.results next_page_token = visits_page.next_page_token if visits: visit_from = min(visits[0].visit, visits[-1].visit) visit_to = max(visits[0].visit, visits[-1].visit) # Then, fetch all statuses associated to these visits visit_statuses: Dict[int, List[OriginVisitStatus]] = defaultdict(list) for row in db.origin_visit_status_get_all_in_range( origin, allowed_statuses, require_snapshot, visit_from=visit_from, visit_to=visit_to, cur=cur, ): row_d = dict(zip(db.origin_visit_status_cols, row)) visit_statuses[row_d["visit"]].append(OriginVisitStatus(**row_d)) results = [ OriginVisitWithStatuses(visit=visit, statuses=visit_statuses[visit.visit]) for visit in visits ] return PagedResult(results=results, next_page_token=next_page_token) @db_transaction(statement_timeout=2000) def origin_visit_find_by_date( self, origin: str, visit_date: datetime.datetime, *, db: Db, cur=None ) -> Optional[OriginVisit]: row_d = db.origin_visit_find_by_date(origin, visit_date, cur=cur) if not row_d: return None return OriginVisit( origin=row_d["origin"], visit=row_d["visit"], date=row_d["date"], type=row_d["type"], ) @db_transaction(statement_timeout=500) def origin_visit_get_by( self, origin: str, visit: int, *, db: Db, cur=None ) -> Optional[OriginVisit]: row = db.origin_visit_get(origin, visit, cur) if row: row_d = dict(zip(db.origin_visit_get_cols, row)) return OriginVisit( origin=row_d["origin"], visit=row_d["visit"], date=row_d["date"], type=row_d["type"], ) return None @db_transaction(statement_timeout=4000) def origin_visit_get_latest( self, origin: str, type: Optional[str] = None, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, *, db: Db, cur=None, ) -> Optional[OriginVisit]: if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): raise StorageArgumentException( f"Unknown allowed statuses {','.join(allowed_statuses)}, only " f"{','.join(VISIT_STATUSES)} authorized" ) row = db.origin_visit_get_latest( origin, type=type, allowed_statuses=allowed_statuses, require_snapshot=require_snapshot, cur=cur, ) if row: row_d = dict(zip(db.origin_visit_get_cols, row)) visit = OriginVisit( origin=row_d["origin"], visit=row_d["visit"], date=row_d["date"], type=row_d["type"], ) return visit return None @db_transaction(statement_timeout=500) def origin_visit_status_get( self, origin: str, visit: int, page_token: Optional[str] = None, order: ListOrder = ListOrder.ASC, limit: int = 10, *, db: Db, cur=None, ) -> PagedResult[OriginVisitStatus]: next_page_token = None date_from = None if page_token is not None: - date_from = datetime.datetime.fromisoformat(page_token) + try: + date_from = datetime.datetime.fromisoformat(page_token) + except ValueError: + raise StorageArgumentException( + "Invalid page_token argument to origin_visit_status_get." + ) from None visit_statuses: List[OriginVisitStatus] = [] # Take one more visit status so we can reuse it as the next page token if any for row in db.origin_visit_status_get_range( origin, visit, date_from=date_from, order=order, limit=limit + 1, cur=cur, ): row_d = dict(zip(db.origin_visit_status_cols, row)) visit_statuses.append(OriginVisitStatus(**row_d)) if len(visit_statuses) > limit: # last visit status date is the next page token next_page_token = str(visit_statuses[-1].date) # excluding that visit status from the result to respect the limit size visit_statuses = visit_statuses[:limit] return PagedResult(results=visit_statuses, next_page_token=next_page_token) @db_transaction() def origin_visit_status_get_random( self, type: str, *, db: Db, cur=None ) -> Optional[OriginVisitStatus]: row = db.origin_visit_get_random(type, cur) if row is not None: row_d = dict(zip(db.origin_visit_status_cols, row)) return OriginVisitStatus(**row_d) return None @db_transaction(statement_timeout=2000) def object_find_by_sha1_git( self, ids: List[Sha1Git], *, db: Db, cur=None ) -> Dict[Sha1Git, List[Dict]]: ret: Dict[Sha1Git, List[Dict]] = {id: [] for id in ids} for retval in db.object_find_by_sha1_git(ids, cur=cur): if retval[1]: ret[retval[0]].append( dict(zip(db.object_find_by_sha1_git_cols, retval)) ) return ret @db_transaction(statement_timeout=1000) def origin_get( self, origins: List[str], *, db: Db, cur=None ) -> Iterable[Optional[Origin]]: rows = db.origin_get_by_url(origins, cur) result: List[Optional[Origin]] = [] for row in rows: origin_d = dict(zip(db.origin_cols, row)) url = origin_d["url"] result.append(None if url is None else Origin(url=url)) return result @db_transaction(statement_timeout=1000) def origin_get_by_sha1( self, sha1s: List[bytes], *, db: Db, cur=None ) -> List[Optional[Dict[str, Any]]]: return [ dict(zip(db.origin_cols, row)) if row[0] else None for row in db.origin_get_by_sha1(sha1s, cur) ] @db_transaction_generator() def origin_get_range(self, origin_from=1, origin_count=100, *, db: Db, cur=None): for origin in db.origin_get_range(origin_from, origin_count, cur): yield dict(zip(db.origin_get_range_cols, origin)) @db_transaction() def origin_list( self, page_token: Optional[str] = None, limit: int = 100, *, db: Db, cur=None ) -> PagedResult[Origin]: page_token = page_token or "0" if not isinstance(page_token, str): raise StorageArgumentException("page_token must be a string.") origin_from = int(page_token) next_page_token = None origins: List[Origin] = [] # Take one more origin so we can reuse it as the next page token if any for row_d in self.origin_get_range(origin_from, limit + 1, db=db, cur=cur): origins.append(Origin(url=row_d["url"])) # keep the last_id for the pagination if needed last_id = row_d["id"] if len(origins) > limit: # data left for subsequent call # last origin id is the next page token next_page_token = str(last_id) # excluding that origin from the result to respect the limit size origins = origins[:limit] assert len(origins) <= limit return PagedResult(results=origins, next_page_token=next_page_token) @db_transaction() def origin_search( self, url_pattern: str, page_token: Optional[str] = None, limit: int = 50, regexp: bool = False, with_visit: bool = False, visit_types: Optional[List[str]] = None, *, db: Db, cur=None, ) -> PagedResult[Origin]: next_page_token = None offset = int(page_token) if page_token else 0 origins = [] # Take one more origin so we can reuse it as the next page token if any for origin in db.origin_search( url_pattern, offset, limit + 1, regexp, with_visit, visit_types, cur ): row_d = dict(zip(db.origin_cols, origin)) origins.append(Origin(url=row_d["url"])) if len(origins) > limit: # next offset next_page_token = str(offset + limit) # excluding that origin from the result to respect the limit size origins = origins[:limit] assert len(origins) <= limit return PagedResult(results=origins, next_page_token=next_page_token) @db_transaction() def origin_count( self, url_pattern: str, regexp: bool = False, with_visit: bool = False, *, db: Db, cur=None, ) -> int: return db.origin_count(url_pattern, regexp, with_visit, cur) @db_transaction() def origin_snapshot_get_all( self, origin_url: str, *, db: Db, cur=None ) -> List[Sha1Git]: return list(db.origin_snapshot_get_all(origin_url, cur)) @db_transaction() def origin_add(self, origins: List[Origin], *, db: Db, cur=None) -> Dict[str, int]: urls = [o.url for o in origins] known_origins = set(url for (url,) in db.origin_get_by_url(urls, cur)) # keep only one occurrence of each given origin while keeping the list # sorted as originally given to_add = sorted(set(urls) - known_origins, key=urls.index) self.journal_writer.origin_add([Origin(url=url) for url in to_add]) added = 0 for url in to_add: if db.origin_add(url, cur): added += 1 return {"origin:add": added} @db_transaction(statement_timeout=500) def stat_counters(self, *, db: Db, cur=None): return {k: v for (k, v) in db.stat_counters()} @db_transaction() def refresh_stat_counters(self, *, db: Db, cur=None): keys = [ "content", "directory", "directory_entry_dir", "directory_entry_file", "directory_entry_rev", "origin", "origin_visit", "person", "release", "revision", "revision_history", "skipped_content", "snapshot", ] for key in keys: cur.execute("select * from swh_update_counter(%s)", (key,)) @db_transaction() def raw_extrinsic_metadata_add( self, metadata: List[RawExtrinsicMetadata], db, cur, ) -> Dict[str, int]: metadata = list(metadata) self.journal_writer.raw_extrinsic_metadata_add(metadata) counter = Counter[ExtendedObjectType]() for metadata_entry in metadata: authority_id = self._get_authority_id(metadata_entry.authority, db, cur) fetcher_id = self._get_fetcher_id(metadata_entry.fetcher, db, cur) db.raw_extrinsic_metadata_add( id=metadata_entry.id, type=metadata_entry.target.object_type.name.lower(), target=str(metadata_entry.target), discovery_date=metadata_entry.discovery_date, authority_id=authority_id, fetcher_id=fetcher_id, format=metadata_entry.format, metadata=metadata_entry.metadata, origin=metadata_entry.origin, visit=metadata_entry.visit, snapshot=map_optional(str, metadata_entry.snapshot), release=map_optional(str, metadata_entry.release), revision=map_optional(str, metadata_entry.revision), path=metadata_entry.path, directory=map_optional(str, metadata_entry.directory), cur=cur, ) counter[metadata_entry.target.object_type] += 1 return { f"{type.value}_metadata:add": count for (type, count) in counter.items() } @db_transaction() def raw_extrinsic_metadata_get( self, target: ExtendedSWHID, authority: MetadataAuthority, after: Optional[datetime.datetime] = None, page_token: Optional[bytes] = None, limit: int = 1000, *, db: Db, cur=None, ) -> PagedResult[RawExtrinsicMetadata]: if page_token: (after_time, after_fetcher) = msgpack_loads(base64.b64decode(page_token)) if after and after_time < after: raise StorageArgumentException( "page_token is inconsistent with the value of 'after'." ) else: after_time = after after_fetcher = None authority_id = self._get_authority_id(authority, db, cur) if not authority_id: return PagedResult( next_page_token=None, results=[], ) rows = db.raw_extrinsic_metadata_get( str(target), authority_id, after_time, after_fetcher, limit + 1, cur, ) rows = [dict(zip(db.raw_extrinsic_metadata_get_cols, row)) for row in rows] results = [] for row in rows: assert str(target) == row["raw_extrinsic_metadata.target"] results.append(converters.db_to_raw_extrinsic_metadata(row)) if len(results) > limit: results.pop() assert len(results) == limit last_returned_row = rows[-2] # rows[-1] corresponds to the popped result next_page_token: Optional[str] = base64.b64encode( msgpack_dumps( ( last_returned_row["discovery_date"], last_returned_row["metadata_fetcher.id"], ) ) ).decode() else: next_page_token = None return PagedResult( next_page_token=next_page_token, results=results, ) @db_transaction() def raw_extrinsic_metadata_get_by_ids( self, ids: List[Sha1Git], *, db: Db, cur=None, ) -> List[RawExtrinsicMetadata]: return [ converters.db_to_raw_extrinsic_metadata( dict(zip(db.raw_extrinsic_metadata_get_cols, row)) ) for row in db.raw_extrinsic_metadata_get_by_ids(ids) ] @db_transaction() def raw_extrinsic_metadata_get_authorities( self, target: ExtendedSWHID, *, db: Db, cur=None, ) -> List[MetadataAuthority]: return [ MetadataAuthority( type=MetadataAuthorityType(authority_type), url=authority_url ) for ( authority_type, authority_url, ) in db.raw_extrinsic_metadata_get_authorities(str(target), cur) ] @db_transaction() def metadata_fetcher_add( self, fetchers: List[MetadataFetcher], *, db: Db, cur=None ) -> Dict[str, int]: fetchers = list(fetchers) self.journal_writer.metadata_fetcher_add(fetchers) count = 0 for fetcher in fetchers: db.metadata_fetcher_add(fetcher.name, fetcher.version, cur=cur) count += 1 return {"metadata_fetcher:add": count} @db_transaction(statement_timeout=500) def metadata_fetcher_get( self, name: str, version: str, *, db: Db, cur=None ) -> Optional[MetadataFetcher]: row = db.metadata_fetcher_get(name, version, cur=cur) if not row: return None return MetadataFetcher.from_dict(dict(zip(db.metadata_fetcher_cols, row))) @db_transaction() def metadata_authority_add( self, authorities: List[MetadataAuthority], *, db: Db, cur=None ) -> Dict[str, int]: authorities = list(authorities) self.journal_writer.metadata_authority_add(authorities) count = 0 for authority in authorities: db.metadata_authority_add(authority.type.value, authority.url, cur=cur) count += 1 return {"metadata_authority:add": count} @db_transaction() def metadata_authority_get( self, type: MetadataAuthorityType, url: str, *, db: Db, cur=None ) -> Optional[MetadataAuthority]: row = db.metadata_authority_get(type.value, url, cur=cur) if not row: return None return MetadataAuthority.from_dict(dict(zip(db.metadata_authority_cols, row))) def clear_buffers(self, object_types: Sequence[str] = ()) -> None: """Do nothing""" return None def flush(self, object_types: Sequence[str] = ()) -> Dict[str, int]: return {} def _get_authority_id(self, authority: MetadataAuthority, db, cur): authority_id = db.metadata_authority_get_id( authority.type.value, authority.url, cur ) if not authority_id: raise StorageArgumentException(f"Unknown authority {authority}") return authority_id def _get_fetcher_id(self, fetcher: MetadataFetcher, db, cur): fetcher_id = db.metadata_fetcher_get_id(fetcher.name, fetcher.version, cur) if not fetcher_id: raise StorageArgumentException(f"Unknown fetcher {fetcher}") return fetcher_id