diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py index 5ae8b392..676eb830 100644 --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -1,1399 +1,1278 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import bisect import collections import copy import datetime +import functools import itertools import random import re from collections import defaultdict from datetime import timedelta from typing import ( Any, Callable, Dict, Generic, Hashable, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union, ) import attr from swh.core.api.serializers import msgpack_loads, msgpack_dumps from swh.model.identifiers import SWHID from swh.model.model import ( BaseContent, Content, SkippedContent, Directory, Revision, Release, Snapshot, OriginVisit, OriginVisitStatus, Origin, - SHA1_SIZE, MetadataAuthority, MetadataAuthorityType, MetadataFetcher, MetadataTargetType, RawExtrinsicMetadata, - Sha1, Sha1Git, ) -from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex +from swh.model.hashutil import DEFAULT_ALGORITHMS from swh.storage.cassandra import CassandraStorage -from swh.storage.cassandra.model import BaseRow, ObjectCountRow +from swh.storage.cassandra.model import ( + BaseRow, + ContentRow, + ObjectCountRow, +) from swh.storage.interface import ( ListOrder, PagedResult, PartialBranches, VISIT_STATUSES, ) from swh.storage.objstorage import ObjStorage from swh.storage.utils import now from .converters import origin_url_to_sha1 -from .exc import StorageArgumentException, HashCollision -from .utils import get_partition_bounds_bytes +from .exc import StorageArgumentException from .writer import JournalWriter # Max block size of contents to return BULK_BLOCK_CONTENT_LEN_MAX = 10000 SortedListItem = TypeVar("SortedListItem") SortedListKey = TypeVar("SortedListKey") FetcherKey = Tuple[str, str] class SortedList(collections.UserList, Generic[SortedListKey, SortedListItem]): data: List[Tuple[SortedListKey, SortedListItem]] # https://github.com/python/mypy/issues/708 # key: Callable[[SortedListItem], SortedListKey] def __init__( self, data: List[SortedListItem] = None, key: Optional[Callable[[SortedListItem], SortedListKey]] = None, ): if key is None: def key(item): return item assert key is not None # for mypy super().__init__(sorted((key(x), x) for x in data or [])) self.key: Callable[[SortedListItem], SortedListKey] = key def add(self, item: SortedListItem): k = self.key(item) bisect.insort(self.data, (k, item)) def __iter__(self) -> Iterator[SortedListItem]: for (k, item) in self.data: yield item def iter_from(self, start_key: Any) -> Iterator[SortedListItem]: """Returns an iterator over all the elements whose key is greater or equal to `start_key`. 
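For example (an illustrative sketch; `L` is a hypothetical instance):
>>> L = SortedList([1, 3, 2])
>>> list(L.iter_from(2))
[2, 3]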
(This is an efficient equivalent to: `(x for x in L if key(x) >= start_key)`) """ from_index = bisect.bisect_left(self.data, (start_key,)) for (k, item) in itertools.islice(self.data, from_index, None): yield item def iter_after(self, start_key: Any) -> Iterator[SortedListItem]: """Same as iter_from, but using a strict inequality.""" it = self.iter_from(start_key) for item in it: if self.key(item) > start_key: # type: ignore yield item break yield from it TRow = TypeVar("TRow", bound=BaseRow) class Table(Generic[TRow]): def __init__(self, row_class: Type[TRow]): self.row_class = row_class self.primary_key_cols = row_class.PARTITION_KEY + row_class.CLUSTERING_KEY # Map from tokens to clustering keys to rows # These are not actually partitions (or rather, there is one partition # for each token) and they aren't sorted. # But it is good enough if we don't care about performance; # and it makes the code a lot simpler. self.data: Dict[int, Dict[Tuple, TRow]] = defaultdict(dict) def __repr__(self): return f"<{self.__module__}.Table[{self.row_class.__name__}] object>" def partition_key(self, row: Union[TRow, Dict[str, Any]]) -> Tuple: """Returns the partition key of a row (ie. the cells which get hashed into the token).""" if isinstance(row, dict): row_d = row else: row_d = row.to_dict() return tuple(row_d[col] for col in self.row_class.PARTITION_KEY) def clustering_key(self, row: Union[TRow, Dict[str, Any]]) -> Tuple: """Returns the clustering key of a row (ie. the cells which are used for sorting rows within a partition).""" if isinstance(row, dict): row_d = row else: row_d = row.to_dict() return tuple(row_d[col] for col in self.row_class.CLUSTERING_KEY) def primary_key(self, row): return self.partition_key(row) + self.clustering_key(row) def primary_key_from_dict(self, d: Dict[str, Any]) -> Tuple: """Returns the primary key (ie. concatenation of partition key and clustering key) of the given dictionary interpreted as a row.""" return tuple(d[col] for col in self.primary_key_cols) def token(self, key: Tuple): """Returns the token of a partition key (ie. its hash).""" return hash(key) def get_partition(self, token: int) -> Dict[Tuple, TRow]: """Returns the partition that contains this token.""" return self.data[token] def insert(self, row: TRow): partition = self.data[self.token(self.partition_key(row))] partition[self.clustering_key(row)] = row def split_primary_key(self, key: Tuple) -> Tuple[Tuple, Tuple]: """Returns (partition_key, clustering_key) from a primary key.""" assert len(key) == len(self.primary_key_cols) partition_key = key[0 : len(self.row_class.PARTITION_KEY)] clustering_key = key[len(self.row_class.PARTITION_KEY) :] return (partition_key, clustering_key) def get_from_primary_key(self, primary_key: Tuple) -> Optional[TRow]: """Returns at most one row, from its primary key.""" (partition_key, clustering_key) = self.split_primary_key(primary_key) token = self.token(partition_key) partition = self.get_partition(token) return partition.get(clustering_key) def get_from_token(self, token: int) -> Iterable[TRow]: """Returns all rows whose token (ie.
non-cryptographic hash of the partition key) is the one passed as argument.""" return (v for (k, v) in sorted(self.get_partition(token).items())) def iter_all(self) -> Iterator[Tuple[Tuple, TRow]]: return ( (self.primary_key(row), row) for (token, partition) in self.data.items() for (clustering_key, row) in partition.items() ) class InMemoryCqlRunner: def __init__(self): + self._contents = Table(ContentRow) + self._content_indexes = defaultdict(lambda: defaultdict(set)) self._stat_counters = defaultdict(int) def increment_counter(self, object_type: str, nb: int): self._stat_counters[object_type] += nb def stat_counters(self) -> Iterable[ObjectCountRow]: for (object_type, count) in self._stat_counters.items(): yield ObjectCountRow(partition_key=0, object_type=object_type, count=count) + ########################## + # 'content' table + ########################## + + def _content_add_finalize(self, content: ContentRow) -> None: + self._contents.insert(content) + self.increment_counter("content", 1) + + def content_add_prepare(self, content: ContentRow): + finalizer = functools.partial(self._content_add_finalize, content) + return (self._contents.token(self._contents.partition_key(content)), finalizer) + + def content_get_from_pk( + self, content_hashes: Dict[str, bytes] + ) -> Optional[ContentRow]: + primary_key = self._contents.primary_key_from_dict(content_hashes) + return self._contents.get_from_primary_key(primary_key) + + def content_get_from_token(self, token: int) -> Iterable[ContentRow]: + return self._contents.get_from_token(token) + + def content_get_random(self) -> Optional[ContentRow]: + rows = [ + row + for partition in self._contents.data.values() + for row in partition.values() + ] + return random.choice(rows) if rows else None + + def content_get_token_range( + self, start: int, end: int, limit: int, + ) -> Iterable[Tuple[int, ContentRow]]: + matches = [ + (token, row) + for (token, partition) in self._contents.data.items() + for (clustering_key, row) in partition.items() + if start <= token <= end + ] + matches.sort() + return matches[0:limit] + ########################## # 'content_by_*' tables ########################## def content_missing_by_sha1_git(self, ids: List[bytes]) -> List[bytes]: - return ids + missing = [] + for id_ in ids: + if id_ not in self._content_indexes["sha1_git"]: + missing.append(id_) + + return missing + + def content_index_add_one(self, algo: str, content: Content, token: int) -> None: + self._content_indexes[algo][content.get_hash(algo)].add(token) + + def content_get_tokens_from_single_hash( + self, algo: str, hash_: bytes + ) -> Iterable[int]: + return self._content_indexes[algo][hash_] ########################## # 'directory' table ########################## def directory_missing(self, ids: List[bytes]) -> List[bytes]: return ids ########################## # 'revision' table ########################## def revision_missing(self, ids: List[bytes]) -> List[bytes]: return ids ########################## # 'release' table ########################## def release_missing(self, ids: List[bytes]) -> List[bytes]: return ids class InMemoryStorage(CassandraStorage): _cql_runner: InMemoryCqlRunner # type: ignore def __init__(self, journal_writer=None): self.reset() self.journal_writer = JournalWriter(journal_writer) def reset(self): self._cql_runner = InMemoryCqlRunner() - self._contents = {} - self._content_indexes = defaultdict(lambda: defaultdict(set)) self._skipped_contents = {} self._skipped_content_indexes = defaultdict(lambda: defaultdict(set)) self._directories = {}
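# An illustrative sketch (not part of the patch) of the two-step content
# insertion protocol InMemoryCqlRunner mirrors from the Cassandra backend:
# content_add_prepare() returns the row's token plus a finalizer, so the
# caller can run collision checks before committing. `row` is a
# hypothetical ContentRow:
#
#   (token, finalizer) = cql_runner.content_add_prepare(row)
#   existing = list(cql_runner.content_get_from_token(token))
#   if not existing:
#       finalizer()  # inserts the row and increments the 'content' counter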
self._revisions = {} self._releases = {} self._snapshots = {} self._origins = {} self._origins_by_sha1 = {} self._origin_visits = {} self._origin_visit_statuses: Dict[Tuple[str, int], List[OriginVisitStatus]] = {} self._persons = {} # {object_type: {id: {authority: [metadata]}}} self._raw_extrinsic_metadata: Dict[ MetadataTargetType, Dict[ Union[str, SWHID], Dict[ Hashable, SortedList[ Tuple[datetime.datetime, FetcherKey], RawExtrinsicMetadata ], ], ], ] = defaultdict( lambda: defaultdict( lambda: defaultdict( lambda: SortedList( key=lambda x: ( x.discovery_date, self._metadata_fetcher_key(x.fetcher), ) ) ) ) ) # noqa self._metadata_fetchers: Dict[FetcherKey, MetadataFetcher] = {} self._metadata_authorities: Dict[Hashable, MetadataAuthority] = {} self._objects = defaultdict(list) self._sorted_sha1s = SortedList[bytes, bytes]() self.objstorage = ObjStorage({"cls": "memory", "args": {}}) def check_config(self, *, check_write: bool) -> bool: return True - def _content_add(self, contents: List[Content], with_data: bool) -> Dict: - self.journal_writer.content_add(contents) - - content_add = 0 - if with_data: - summary = self.objstorage.content_add( - c for c in contents if c.status != "absent" - ) - content_add_bytes = summary["content:add:bytes"] - - for content in contents: - key = self._content_key(content) - if key in self._contents: - continue - for algorithm in DEFAULT_ALGORITHMS: - hash_ = content.get_hash(algorithm) - if hash_ in self._content_indexes[algorithm] and ( - algorithm not in {"blake2s256", "sha256"} - ): - colliding_content_hashes = [] - # Add the already stored contents - for content_hashes_set in self._content_indexes[algorithm][hash_]: - hashes = dict(content_hashes_set) - colliding_content_hashes.append(hashes) - # Add the new colliding content - colliding_content_hashes.append(content.hashes()) - raise HashCollision(algorithm, hash_, colliding_content_hashes) - for algorithm in DEFAULT_ALGORITHMS: - hash_ = content.get_hash(algorithm) - self._content_indexes[algorithm][hash_].add(key) - self._objects[content.sha1_git].append(("content", content.sha1)) - self._contents[key] = content - self._sorted_sha1s.add(content.sha1) - self._contents[key] = attr.evolve(self._contents[key], data=None) - content_add += 1 - - self._cql_runner.increment_counter("content", content_add) - - summary = { - "content:add": content_add, - } - if with_data: - summary["content:add:bytes"] = content_add_bytes - - return summary - - def content_add(self, content: List[Content]) -> Dict: - content = [attr.evolve(c, ctime=now()) for c in content] - return self._content_add(content, with_data=True) - - def content_update( - self, contents: List[Dict[str, Any]], keys: List[str] = [] - ) -> None: - self.journal_writer.content_update(contents) - - for cont_update in contents: - cont_update = cont_update.copy() - sha1 = cont_update.pop("sha1") - for old_key in self._content_indexes["sha1"][sha1]: - old_cont = self._contents.pop(old_key) - - for algorithm in DEFAULT_ALGORITHMS: - hash_ = old_cont.get_hash(algorithm) - self._content_indexes[algorithm][hash_].remove(old_key) - - new_cont = attr.evolve(old_cont, **cont_update) - new_key = self._content_key(new_cont) - - self._contents[new_key] = new_cont - - for algorithm in DEFAULT_ALGORITHMS: - hash_ = new_cont.get_hash(algorithm) - self._content_indexes[algorithm][hash_].add(new_key) - - def content_add_metadata(self, content: List[Content]) -> Dict: - return self._content_add(content, with_data=False) - - def content_get_data(self, content: Sha1) -> 
Optional[bytes]: - # FIXME: Make this method support slicing the `data` - return self.objstorage.content_get(content) - - def content_get_partition( - self, - partition_id: int, - nb_partitions: int, - page_token: Optional[str] = None, - limit: int = 1000, - ) -> PagedResult[Content]: - if limit is None: - raise StorageArgumentException("limit should not be None") - (start, end) = get_partition_bounds_bytes( - partition_id, nb_partitions, SHA1_SIZE - ) - if page_token: - start = hash_to_bytes(page_token) - if end is None: - end = b"\xff" * SHA1_SIZE - - next_page_token: Optional[str] = None - sha1s = ( - (sha1, content_key) - for sha1 in self._sorted_sha1s.iter_from(start) - for content_key in self._content_indexes["sha1"][sha1] - ) - contents: List[Content] = [] - for counter, (sha1, key) in enumerate(sha1s): - if sha1 > end: - break - if counter >= limit: - next_page_token = hash_to_hex(sha1) - break - contents.append(self._contents[key]) - - assert len(contents) <= limit - return PagedResult(results=contents, next_page_token=next_page_token) - - def content_get(self, contents: List[Sha1]) -> List[Optional[Content]]: - contents_by_sha1: Dict[Sha1, Optional[Content]] = {} - for sha1 in contents: - if sha1 in self._content_indexes["sha1"]: - objs = self._content_indexes["sha1"][sha1] - # only 1 element as content_add_metadata would have raised a - # hash collision otherwise - assert len(objs) == 1 - for key in objs: - content = attr.evolve(self._contents[key], data=None, ctime=None) - contents_by_sha1[sha1] = content - return [contents_by_sha1.get(sha1) for sha1 in contents] - - def content_find(self, content: Dict[str, Any]) -> List[Content]: - if not set(content).intersection(DEFAULT_ALGORITHMS): - raise StorageArgumentException( - "content keys must contain at least one " - f"of: {', '.join(sorted(DEFAULT_ALGORITHMS))}" - ) - found = [] - for algo in DEFAULT_ALGORITHMS: - hash = content.get(algo) - if hash and hash in self._content_indexes[algo]: - found.append(self._content_indexes[algo][hash]) - - if not found: - return [] - - keys = list(set.intersection(*found)) - return [self._contents[key] for key in keys] - - def content_missing( - self, contents: List[Dict[str, Any]], key_hash: str = "sha1" - ) -> Iterable[bytes]: - if key_hash not in DEFAULT_ALGORITHMS: - raise StorageArgumentException( - "key_hash should be one of {','.join(DEFAULT_ALGORITHMS)}" - ) - - for content in contents: - for (algo, hash_) in content.items(): - if algo not in DEFAULT_ALGORITHMS: - continue - if hash_ not in self._content_indexes.get(algo, []): - yield content[key_hash] - break - - def content_missing_per_sha1(self, contents: List[bytes]) -> Iterable[bytes]: - for content in contents: - if content not in self._content_indexes["sha1"]: - yield content - - def content_missing_per_sha1_git( - self, contents: List[Sha1Git] - ) -> Iterable[Sha1Git]: - for content in contents: - if content not in self._content_indexes["sha1_git"]: - yield content - - def content_get_random(self) -> Sha1Git: - return random.choice(list(self._content_indexes["sha1_git"])) - def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict: self.journal_writer.skipped_content_add(contents) summary = {"skipped_content:add": 0} missing_contents = self.skipped_content_missing([c.hashes() for c in contents]) missing = {self._content_key(c) for c in missing_contents} contents = [c for c in contents if self._content_key(c) in missing] for content in contents: key = self._content_key(content) for algo in DEFAULT_ALGORITHMS: if 
content.get_hash(algo): self._skipped_content_indexes[algo][content.get_hash(algo)].add(key) self._skipped_contents[key] = content summary["skipped_content:add"] += 1 self._cql_runner.increment_counter("skipped_content", len(contents)) return summary def skipped_content_add(self, content: List[SkippedContent]) -> Dict: content = [attr.evolve(c, ctime=now()) for c in content] return self._skipped_content_add(content) def skipped_content_missing( self, contents: List[Dict[str, Any]] ) -> Iterable[Dict[str, Any]]: for content in contents: matches = list(self._skipped_contents.values()) for (algorithm, key) in self._content_key(content): if algorithm == "blake2s256": continue # Filter out skipped contents with the same hash matches = [ match for match in matches if match.get_hash(algorithm) == key ] # if none of the contents match if not matches: yield {algo: content[algo] for algo in DEFAULT_ALGORITHMS} def directory_add(self, directories: List[Directory]) -> Dict: directories = [dir_ for dir_ in directories if dir_.id not in self._directories] self.journal_writer.directory_add(directories) count = 0 for directory in directories: count += 1 self._directories[directory.id] = directory self._objects[directory.id].append(("directory", directory.id)) self._cql_runner.increment_counter("directory", len(directories)) return {"directory:add": count} def directory_missing(self, directories: List[Sha1Git]) -> Iterable[Sha1Git]: for id in directories: if id not in self._directories: yield id def _directory_ls(self, directory_id, recursive, prefix=b""): if directory_id in self._directories: for entry in self._directories[directory_id].entries: ret = self._join_dentry_to_content(entry) ret["name"] = prefix + ret["name"] ret["dir_id"] = directory_id yield ret if recursive and ret["type"] == "dir": yield from self._directory_ls( ret["target"], True, prefix + ret["name"] + b"/" ) def directory_ls( self, directory: Sha1Git, recursive: bool = False ) -> Iterable[Dict[str, Any]]: yield from self._directory_ls(directory, recursive) def directory_entry_get_by_path( self, directory: Sha1Git, paths: List[bytes] ) -> Optional[Dict[str, Any]]: return self._directory_entry_get_by_path(directory, paths, b"") def directory_get_random(self) -> Sha1Git: return random.choice(list(self._directories)) def _directory_entry_get_by_path( self, directory: Sha1Git, paths: List[bytes], prefix: bytes ) -> Optional[Dict[str, Any]]: if not paths: return None contents = list(self.directory_ls(directory)) if not contents: return None def _get_entry(entries, name): for entry in entries: if entry["name"] == name: entry = entry.copy() entry["name"] = prefix + entry["name"] return entry first_item = _get_entry(contents, paths[0]) if len(paths) == 1: return first_item if not first_item or first_item["type"] != "dir": return None return self._directory_entry_get_by_path( first_item["target"], paths[1:], prefix + paths[0] + b"/" ) def revision_add(self, revisions: List[Revision]) -> Dict: revisions = [rev for rev in revisions if rev.id not in self._revisions] self.journal_writer.revision_add(revisions) count = 0 for revision in revisions: revision = attr.evolve( revision, committer=self._person_add(revision.committer), author=self._person_add(revision.author), ) self._revisions[revision.id] = revision self._objects[revision.id].append(("revision", revision.id)) count += 1 self._cql_runner.increment_counter("revision", len(revisions)) return {"revision:add": count} def revision_missing(self, revisions: List[Sha1Git]) -> Iterable[Sha1Git]: for 
id in revisions: if id not in self._revisions: yield id def revision_get( self, revisions: List[Sha1Git] ) -> Iterable[Optional[Dict[str, Any]]]: for id in revisions: if id in self._revisions: yield self._revisions.get(id).to_dict() else: yield None def __get_parent_revs( self, rev_id: Sha1Git, seen: Set[Sha1Git], limit: Optional[int] ) -> Iterable[Dict[str, Any]]: if limit and len(seen) >= limit: return if rev_id in seen or rev_id not in self._revisions: return seen.add(rev_id) yield self._revisions[rev_id].to_dict() for parent in self._revisions[rev_id].parents: yield from self.__get_parent_revs(parent, seen, limit) def revision_log( self, revisions: List[Sha1Git], limit: Optional[int] = None ) -> Iterable[Optional[Dict[str, Any]]]: seen: Set[Sha1Git] = set() for rev_id in revisions: yield from self.__get_parent_revs(rev_id, seen, limit) def revision_shortlog( self, revisions: List[Sha1Git], limit: Optional[int] = None ) -> Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]]: yield from ( (rev["id"], rev["parents"]) if rev else None for rev in self.revision_log(revisions, limit) ) def revision_get_random(self) -> Sha1Git: return random.choice(list(self._revisions)) def release_add(self, releases: List[Release]) -> Dict: to_add = [] for rel in releases: if rel.id not in self._releases and rel not in to_add: to_add.append(rel) self.journal_writer.release_add(to_add) for rel in to_add: if rel.author: self._person_add(rel.author) self._objects[rel.id].append(("release", rel.id)) self._releases[rel.id] = rel self._cql_runner.increment_counter("release", len(to_add)) return {"release:add": len(to_add)} def release_missing(self, releases: List[Sha1Git]) -> Iterable[Sha1Git]: yield from (rel for rel in releases if rel not in self._releases) def release_get( self, releases: List[Sha1Git] ) -> Iterable[Optional[Dict[str, Any]]]: for rel_id in releases: if rel_id in self._releases: yield self._releases[rel_id].to_dict() else: yield None def release_get_random(self) -> Sha1Git: return random.choice(list(self._releases)) def snapshot_add(self, snapshots: List[Snapshot]) -> Dict: count = 0 snapshots = [snap for snap in snapshots if snap.id not in self._snapshots] for snapshot in snapshots: self.journal_writer.snapshot_add([snapshot]) self._snapshots[snapshot.id] = snapshot self._objects[snapshot.id].append(("snapshot", snapshot.id)) count += 1 self._cql_runner.increment_counter("snapshot", len(snapshots)) return {"snapshot:add": count} def snapshot_missing(self, snapshots: List[Sha1Git]) -> Iterable[Sha1Git]: for id in snapshots: if id not in self._snapshots: yield id def snapshot_get(self, snapshot_id: Sha1Git) -> Optional[Dict[str, Any]]: d = self.snapshot_get_branches(snapshot_id) if d is None: return None return { "id": d["id"], "branches": { name: branch.to_dict() if branch else None for (name, branch) in d["branches"].items() }, "next_branch": d["next_branch"], } def snapshot_get_by_origin_visit( self, origin: str, visit: int ) -> Optional[Dict[str, Any]]: origin_url = self._get_origin_url(origin) if not origin_url: return None if origin_url not in self._origins or visit > len( self._origin_visits[origin_url] ): return None visit_d = self._origin_visit_get_updated(origin_url, visit) snapshot_id = visit_d["snapshot"] if snapshot_id: return self.snapshot_get(snapshot_id) else: return None def snapshot_count_branches( self, snapshot_id: Sha1Git ) -> Optional[Dict[Optional[str], int]]: snapshot = self._snapshots[snapshot_id] return collections.Counter( branch.target_type.value if branch else 
None for branch in snapshot.branches.values() ) def snapshot_get_branches( self, snapshot_id: Sha1Git, branches_from: bytes = b"", branches_count: int = 1000, target_types: Optional[List[str]] = None, ) -> Optional[PartialBranches]: snapshot = self._snapshots.get(snapshot_id) if snapshot is None: return None sorted_branches = sorted(snapshot.branches.items()) sorted_branch_names = [k for (k, v) in sorted_branches] from_index = bisect.bisect_left(sorted_branch_names, branches_from) if target_types: next_branch = None branches: Dict = {} for (branch_name, branch) in sorted_branches: if branch_name in sorted_branch_names[from_index:]: if branch and branch.target_type.value in target_types: if len(branches) < branches_count: branches[branch_name] = branch else: next_branch = branch_name break else: # As there is no 'target_types', we can do that much faster to_index = from_index + branches_count returned_branch_names = frozenset(sorted_branch_names[from_index:to_index]) branches = dict( (branch_name, branch) for (branch_name, branch) in snapshot.branches.items() if branch_name in returned_branch_names ) if to_index >= len(sorted_branch_names): next_branch = None else: next_branch = sorted_branch_names[to_index] return PartialBranches( id=snapshot_id, branches=branches, next_branch=next_branch, ) def snapshot_get_random(self) -> Sha1Git: return random.choice(list(self._snapshots)) def object_find_by_sha1_git(self, ids: List[Sha1Git]) -> Dict[Sha1Git, List[Dict]]: ret = super().object_find_by_sha1_git(ids) for id_ in ids: objs = self._objects.get(id_, []) ret[id_].extend([{"sha1_git": id_, "type": obj[0],} for obj in objs]) return ret def _convert_origin(self, t): if t is None: return None return t.to_dict() def origin_get_one(self, origin_url: str) -> Optional[Origin]: return self._origins.get(origin_url) def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]: return [self.origin_get_one(origin_url) for origin_url in origins] def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]: return [self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s] def origin_list( self, page_token: Optional[str] = None, limit: int = 100 ) -> PagedResult[Origin]: origin_urls = sorted(self._origins) from_ = bisect.bisect_left(origin_urls, page_token) if page_token else 0 next_page_token = None # Take one more origin so we can reuse it as the next page token if any origins = [Origin(url=url) for url in origin_urls[from_ : from_ + limit + 1]] if len(origins) > limit: # last origin id is the next page token next_page_token = str(origins[-1].url) # excluding that origin from the result to respect the limit size origins = origins[:limit] assert len(origins) <= limit return PagedResult(results=origins, next_page_token=next_page_token) def origin_search( self, url_pattern: str, page_token: Optional[str] = None, limit: int = 50, regexp: bool = False, with_visit: bool = False, ) -> PagedResult[Origin]: next_page_token = None offset = int(page_token) if page_token else 0 origins = self._origins.values() if regexp: pat = re.compile(url_pattern) origins = [orig for orig in origins if pat.search(orig.url)] else: origins = [orig for orig in origins if url_pattern in orig.url] if with_visit: filtered_origins = [] for orig in origins: visits = ( self._origin_visit_get_updated(ov.origin, ov.visit) for ov in self._origin_visits[orig.url] ) for ov in visits: snapshot = ov["snapshot"] if snapshot and snapshot in self._snapshots: filtered_origins.append(orig) break else: 
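# with_visit filtering was not requested: every URL-matched origin qualifies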
filtered_origins = origins # Take one more origin so we can reuse it as the next page token if any origins = filtered_origins[offset : offset + limit + 1] if len(origins) > limit: # next offset next_page_token = str(offset + limit) # excluding that origin from the result to respect the limit size origins = origins[:limit] assert len(origins) <= limit return PagedResult(results=origins, next_page_token=next_page_token) def origin_count( self, url_pattern: str, regexp: bool = False, with_visit: bool = False ) -> int: actual_page = self.origin_search( url_pattern, regexp=regexp, with_visit=with_visit, limit=len(self._origins), ) assert actual_page.next_page_token is None return len(actual_page.results) def origin_add(self, origins: List[Origin]) -> Dict[str, int]: added = 0 for origin in origins: if origin.url not in self._origins: self.origin_add_one(origin) added += 1 self._cql_runner.increment_counter("origin", added) return {"origin:add": added} def origin_add_one(self, origin: Origin) -> str: if origin.url not in self._origins: self.journal_writer.origin_add([origin]) self._origins[origin.url] = origin self._origins_by_sha1[origin_url_to_sha1(origin.url)] = origin self._origin_visits[origin.url] = [] self._objects[origin.url].append(("origin", origin.url)) return origin.url def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]: for visit in visits: origin = self.origin_get_one(visit.origin) if not origin: # Cannot add a visit without an origin raise StorageArgumentException("Unknown origin %s", visit.origin) all_visits = [] for visit in visits: origin_url = visit.origin if origin_url in self._origins: origin = self._origins[origin_url] if visit.visit: self.journal_writer.origin_visit_add([visit]) while len(self._origin_visits[origin_url]) < visit.visit: self._origin_visits[origin_url].append(None) self._origin_visits[origin_url][visit.visit - 1] = visit else: # visit ids are in the range [1, +inf[ visit_id = len(self._origin_visits[origin_url]) + 1 visit = attr.evolve(visit, visit=visit_id) self.journal_writer.origin_visit_add([visit]) self._origin_visits[origin_url].append(visit) visit_key = (origin_url, visit.visit) self._objects[visit_key].append(("origin_visit", None)) assert visit.visit is not None self._origin_visit_status_add_one( OriginVisitStatus( origin=visit.origin, visit=visit.visit, date=visit.date, status="created", snapshot=None, ) ) all_visits.append(visit) self._cql_runner.increment_counter("origin_visit", len(all_visits)) return all_visits def _origin_visit_status_add_one(self, visit_status: OriginVisitStatus) -> None: """Add an origin visit status without checks. If already present, do nothing. 
""" self.journal_writer.origin_visit_status_add([visit_status]) visit_key = (visit_status.origin, visit_status.visit) self._origin_visit_statuses.setdefault(visit_key, []) visit_statuses = self._origin_visit_statuses[visit_key] if visit_status not in visit_statuses: visit_statuses.append(visit_status) def origin_visit_status_add(self, visit_statuses: List[OriginVisitStatus],) -> None: # First round to check existence (fail early if any is ko) for visit_status in visit_statuses: origin_url = self.origin_get_one(visit_status.origin) if not origin_url: raise StorageArgumentException(f"Unknown origin {visit_status.origin}") for visit_status in visit_statuses: self._origin_visit_status_add_one(visit_status) def _origin_visit_status_get_latest( self, origin: str, visit_id: int ) -> Tuple[OriginVisit, OriginVisitStatus]: """Return a tuple of OriginVisit, latest associated OriginVisitStatus. """ assert visit_id >= 1 visit = self._origin_visits[origin][visit_id - 1] assert visit is not None visit_key = (origin, visit_id) visit_update = max(self._origin_visit_statuses[visit_key], key=lambda v: v.date) return visit, visit_update def _origin_visit_get_updated(self, origin: str, visit_id: int) -> Dict[str, Any]: """Merge origin visit and latest origin visit status """ visit, visit_update = self._origin_visit_status_get_latest(origin, visit_id) assert visit is not None and visit_update is not None return { # default to the values in visit **visit.to_dict(), # override with the last update **visit_update.to_dict(), # but keep the date of the creation of the origin visit "date": visit.date, } def origin_visit_get( self, origin: str, page_token: Optional[str] = None, order: ListOrder = ListOrder.ASC, limit: int = 10, ) -> PagedResult[OriginVisit]: next_page_token = None page_token = page_token or "0" if not isinstance(order, ListOrder): raise StorageArgumentException("order must be a ListOrder value") if not isinstance(page_token, str): raise StorageArgumentException("page_token must be a string.") visit_from = int(page_token) origin_url = self._get_origin_url(origin) extra_limit = limit + 1 visits = sorted( self._origin_visits.get(origin_url, []), key=lambda v: v.visit, reverse=(order == ListOrder.DESC), ) if visit_from > 0 and order == ListOrder.ASC: visits = [v for v in visits if v.visit > visit_from] elif visit_from > 0 and order == ListOrder.DESC: visits = [v for v in visits if v.visit < visit_from] visits = visits[:extra_limit] assert len(visits) <= extra_limit if len(visits) == extra_limit: visits = visits[:limit] next_page_token = str(visits[-1].visit) return PagedResult(results=visits, next_page_token=next_page_token) def origin_visit_find_by_date( self, origin: str, visit_date: datetime.datetime ) -> Optional[OriginVisit]: origin_url = self._get_origin_url(origin) if origin_url in self._origin_visits: visits = self._origin_visits[origin_url] visit = min(visits, key=lambda v: (abs(v.date - visit_date), -v.visit)) return visit return None def origin_visit_get_by(self, origin: str, visit: int) -> Optional[OriginVisit]: origin_url = self._get_origin_url(origin) if origin_url in self._origin_visits and visit <= len( self._origin_visits[origin_url] ): found_visit, _ = self._origin_visit_status_get_latest(origin, visit) return found_visit return None def origin_visit_get_latest( self, origin: str, type: Optional[str] = None, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, ) -> Optional[OriginVisit]: if allowed_statuses and not 
set(allowed_statuses).intersection(VISIT_STATUSES): raise StorageArgumentException( f"Unknown allowed statuses {','.join(allowed_statuses)}, only " f"{','.join(VISIT_STATUSES)} authorized" ) ori = self._origins.get(origin) if not ori: return None visits = sorted( self._origin_visits[ori.url], key=lambda v: (v.date, v.visit), reverse=True, ) for visit in visits: if type is not None and visit.type != type: continue visit_statuses = self._origin_visit_statuses[origin, visit.visit] if allowed_statuses is not None: visit_statuses = [ vs for vs in visit_statuses if vs.status in allowed_statuses ] if require_snapshot: visit_statuses = [vs for vs in visit_statuses if vs.snapshot] if visit_statuses: # we found visit statuses matching criteria visit_status = max(visit_statuses, key=lambda vs: (vs.date, vs.visit)) assert visit.origin == visit_status.origin assert visit.visit == visit_status.visit return visit return None def origin_visit_status_get( self, origin: str, visit: int, page_token: Optional[str] = None, order: ListOrder = ListOrder.ASC, limit: int = 10, ) -> PagedResult[OriginVisitStatus]: next_page_token = None date_from = None if page_token is not None: date_from = datetime.datetime.fromisoformat(page_token) visit_statuses = sorted( self._origin_visit_statuses.get((origin, visit), []), key=lambda v: v.date, reverse=(order == ListOrder.DESC), ) if date_from is not None: if order == ListOrder.ASC: visit_statuses = [v for v in visit_statuses if v.date >= date_from] elif order == ListOrder.DESC: visit_statuses = [v for v in visit_statuses if v.date <= date_from] # Take one more visit status so we can reuse it as the next page token if any visit_statuses = visit_statuses[: limit + 1] if len(visit_statuses) > limit: # last visit status date is the next page token next_page_token = str(visit_statuses[-1].date) # excluding that visit status from the result to respect the limit size visit_statuses = visit_statuses[:limit] return PagedResult(results=visit_statuses, next_page_token=next_page_token) def origin_visit_status_get_latest( self, origin_url: str, visit: int, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, ) -> Optional[OriginVisitStatus]: if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): raise StorageArgumentException( f"Unknown allowed statuses {','.join(allowed_statuses)}, only " f"{','.join(VISIT_STATUSES)} authorized" ) ori = self._origins.get(origin_url) if not ori: return None visit_key = (origin_url, visit) visits = self._origin_visit_statuses.get(visit_key) if not visits: return None if allowed_statuses is not None: visits = [visit for visit in visits if visit.status in allowed_statuses] if require_snapshot: visits = [visit for visit in visits if visit.snapshot] visit_status = max(visits, key=lambda v: (v.date, v.visit), default=None) return visit_status def _select_random_origin_visit_by_type(self, type: str) -> str: while True: url = random.choice(list(self._origin_visits.keys())) random_origin_visits = self._origin_visits[url] if random_origin_visits[0].type == type: return url def origin_visit_status_get_random( self, type: str ) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]: url = self._select_random_origin_visit_by_type(type) random_origin_visits = copy.deepcopy(self._origin_visits[url]) random_origin_visits.reverse() back_in_the_day = now() - timedelta(weeks=12) # 3 months back # This should be enough for tests for visit in random_origin_visits: origin_visit, latest_visit_status = 
self._origin_visit_status_get_latest( url, visit.visit ) assert latest_visit_status is not None if ( origin_visit.date > back_in_the_day and latest_visit_status.status == "full" ): return origin_visit, latest_visit_status else: return None def raw_extrinsic_metadata_add(self, metadata: List[RawExtrinsicMetadata],) -> None: self.journal_writer.raw_extrinsic_metadata_add(metadata) for metadata_entry in metadata: authority_key = self._metadata_authority_key(metadata_entry.authority) if authority_key not in self._metadata_authorities: raise StorageArgumentException( f"Unknown authority {metadata_entry.authority}" ) fetcher_key = self._metadata_fetcher_key(metadata_entry.fetcher) if fetcher_key not in self._metadata_fetchers: raise StorageArgumentException( f"Unknown fetcher {metadata_entry.fetcher}" ) raw_extrinsic_metadata_list = self._raw_extrinsic_metadata[ metadata_entry.type ][metadata_entry.id][authority_key] for existing_raw_extrinsic_metadata in raw_extrinsic_metadata_list: if ( self._metadata_fetcher_key(existing_raw_extrinsic_metadata.fetcher) == fetcher_key and existing_raw_extrinsic_metadata.discovery_date == metadata_entry.discovery_date ): # Duplicate of an existing one; ignore it. break else: raw_extrinsic_metadata_list.add(metadata_entry) def raw_extrinsic_metadata_get( self, type: MetadataTargetType, id: Union[str, SWHID], authority: MetadataAuthority, after: Optional[datetime.datetime] = None, page_token: Optional[bytes] = None, limit: int = 1000, ) -> PagedResult[RawExtrinsicMetadata]: authority_key = self._metadata_authority_key(authority) if type == MetadataTargetType.ORIGIN: if isinstance(id, SWHID): raise StorageArgumentException( f"raw_extrinsic_metadata_get called with type='origin', " f"but provided id is an SWHID: {id!r}" ) else: if not isinstance(id, SWHID): raise StorageArgumentException( f"raw_extrinsic_metadata_get called with type!='origin', " f"but provided id is not an SWHID: {id!r}" ) if page_token is not None: (after_time, after_fetcher) = msgpack_loads(base64.b64decode(page_token)) after_fetcher = tuple(after_fetcher) if after is not None and after > after_time: raise StorageArgumentException( "page_token is inconsistent with the value of 'after'." 
) entries = self._raw_extrinsic_metadata[type][id][authority_key].iter_after( (after_time, after_fetcher) ) elif after is not None: entries = self._raw_extrinsic_metadata[type][id][authority_key].iter_from( (after,) ) entries = (entry for entry in entries if entry.discovery_date > after) else: entries = iter(self._raw_extrinsic_metadata[type][id][authority_key]) if limit: entries = itertools.islice(entries, 0, limit + 1) results = [] for entry in entries: entry_authority = self._metadata_authorities[ self._metadata_authority_key(entry.authority) ] entry_fetcher = self._metadata_fetchers[ self._metadata_fetcher_key(entry.fetcher) ] if after: assert entry.discovery_date > after results.append( attr.evolve( entry, authority=attr.evolve(entry_authority, metadata=None), fetcher=attr.evolve(entry_fetcher, metadata=None), ) ) if len(results) > limit: results.pop() assert len(results) == limit last_result = results[-1] next_page_token: Optional[str] = base64.b64encode( msgpack_dumps( ( last_result.discovery_date, self._metadata_fetcher_key(last_result.fetcher), ) ) ).decode() else: next_page_token = None return PagedResult(next_page_token=next_page_token, results=results,) def metadata_fetcher_add(self, fetchers: List[MetadataFetcher]) -> None: self.journal_writer.metadata_fetcher_add(fetchers) for fetcher in fetchers: if fetcher.metadata is None: raise StorageArgumentException( "MetadataFetcher.metadata may not be None in metadata_fetcher_add." ) key = self._metadata_fetcher_key(fetcher) if key not in self._metadata_fetchers: self._metadata_fetchers[key] = fetcher def metadata_fetcher_get( self, name: str, version: str ) -> Optional[MetadataFetcher]: return self._metadata_fetchers.get( self._metadata_fetcher_key(MetadataFetcher(name=name, version=version)) ) def metadata_authority_add(self, authorities: List[MetadataAuthority]) -> None: self.journal_writer.metadata_authority_add(authorities) for authority in authorities: if authority.metadata is None: raise StorageArgumentException( "MetadataAuthority.metadata may not be None in " "metadata_authority_add." 
) key = self._metadata_authority_key(authority) self._metadata_authorities[key] = authority def metadata_authority_get( self, type: MetadataAuthorityType, url: str ) -> Optional[MetadataAuthority]: return self._metadata_authorities.get( self._metadata_authority_key(MetadataAuthority(type=type, url=url)) ) def _get_origin_url(self, origin): if isinstance(origin, str): return origin else: raise TypeError("origin must be a string.") def _person_add(self, person): key = ("person", person.fullname) if key not in self._objects: self._persons[person.fullname] = person self._objects[key].append(key) return self._persons[person.fullname] @staticmethod def _content_key(content): """ A stable key and the algorithm for a content""" if isinstance(content, BaseContent): content = content.to_dict() return tuple((key, content.get(key)) for key in sorted(DEFAULT_ALGORITHMS)) @staticmethod def _metadata_fetcher_key(fetcher: MetadataFetcher) -> FetcherKey: return (fetcher.name, fetcher.version) @staticmethod def _metadata_authority_key(authority: MetadataAuthority) -> Hashable: return (authority.type, authority.url) def diff_directories(self, from_dir, to_dir, track_renaming=False): raise NotImplementedError("InMemoryStorage.diff_directories") def diff_revisions(self, from_rev, to_rev, track_renaming=False): raise NotImplementedError("InMemoryStorage.diff_revisions") def diff_revision(self, revision, track_renaming=False): raise NotImplementedError("InMemoryStorage.diff_revision") def clear_buffers(self, object_types: Optional[List[str]] = None) -> None: """Do nothing """ return None def flush(self, object_types: Optional[List[str]] = None) -> Dict: return {} diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py index 075ad0b7..40a858a7 100644 --- a/swh/storage/tests/test_api_client.py +++ b/swh/storage/tests/test_api_client.py @@ -1,70 +1,66 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from unittest.mock import patch - import pytest import swh.storage.api.server as server import swh.storage.storage from swh.storage import get_storage from swh.storage.tests.test_storage import TestStorageGeneratedData # noqa from swh.storage.tests.test_storage import TestStorage as _TestStorage # tests are executed using imported classes (TestStorage and # TestStorageGeneratedData) using overloaded swh_storage fixture # below @pytest.fixture def app_server(): server.storage = swh.storage.get_storage( cls="memory", journal_writer={"cls": "memory"} ) yield server @pytest.fixture def app(app_server): return app_server.app @pytest.fixture def swh_rpc_client_class(): def storage_factory(**kwargs): storage_config = { "cls": "remote", **kwargs, } return get_storage(**storage_config) return storage_factory @pytest.fixture def swh_storage(swh_rpc_client, app_server): # This version of the swh_storage fixture uses the swh_rpc_client fixture # to instantiate a RemoteStorage (see swh_rpc_client_class above) that # proxies, via the swh.core RPC mechanism, the local (in memory) storage # configured in the app_server fixture above. # # Also note that, for the sake of # making it easier to write tests, the in-memory journal writer of the # in-memory backend storage is attached to the RemoteStorage as its # journal_writer attribute. 
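# The in-memory backend's journal writer temporarily replaces the
# RemoteStorage's own attribute below; the original value is saved first
# and restored once the test is done with the fixture.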
storage = swh_rpc_client journal_writer = getattr(storage, "journal_writer", None) storage.journal_writer = app_server.storage.journal_writer yield storage storage.journal_writer = journal_writer class TestStorage(_TestStorage): - def test_content_update(self, swh_storage, app_server, sample_data): - # TODO, journal_writer not supported - swh_storage.journal_writer.journal = None - with patch.object(server.storage.journal_writer, "journal", None): - super().test_content_update(swh_storage, sample_data) + @pytest.mark.skip("content_update is not yet implemented for Cassandra") + def test_content_update(self): + pass diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py index bd3d63de..91637edc 100644 --- a/swh/storage/tests/test_in_memory.py +++ b/swh/storage/tests/test_in_memory.py @@ -1,146 +1,153 @@ # Copyright (C) 2018-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import dataclasses import pytest from swh.storage.cassandra.model import BaseRow from swh.storage.in_memory import SortedList, Table -from swh.storage.tests.test_storage import TestStorage, TestStorageGeneratedData # noqa +from swh.storage.tests.test_storage import TestStorage as _TestStorage +from swh.storage.tests.test_storage import TestStorageGeneratedData # noqa # tests are executed using imported classes (TestStorage and # TestStorageGeneratedData) using overloaded swh_storage fixture # below @pytest.fixture def swh_storage_backend_config(): yield { "cls": "memory", "journal_writer": {"cls": "memory",}, } parametrize = pytest.mark.parametrize( "items", [ [1, 2, 3, 4, 5, 6, 10, 100], [10, 100, 6, 5, 4, 3, 2, 1], [10, 4, 5, 6, 1, 2, 3, 100], ], ) @parametrize def test_sorted_list_iter(items): list1 = SortedList() for item in items: list1.add(item) assert list(list1) == sorted(items) list2 = SortedList(items) assert list(list2) == sorted(items) @parametrize def test_sorted_list_iter__key(items): list1 = SortedList(key=lambda item: -item) for item in items: list1.add(item) assert list(list1) == list(reversed(sorted(items))) list2 = SortedList(items, key=lambda item: -item) assert list(list2) == list(reversed(sorted(items))) @parametrize def test_sorted_list_iter_from(items): list_ = SortedList(items) for split in items: expected = sorted(item for item in items if item >= split) assert list(list_.iter_from(split)) == expected, f"split: {split}" @parametrize def test_sorted_list_iter_from__key(items): list_ = SortedList(items, key=lambda item: -item) for split in items: expected = reversed(sorted(item for item in items if item <= split)) assert list(list_.iter_from(-split)) == list(expected), f"split: {split}" @parametrize def test_sorted_list_iter_after(items): list_ = SortedList(items) for split in items: expected = sorted(item for item in items if item > split) assert list(list_.iter_after(split)) == expected, f"split: {split}" @parametrize def test_sorted_list_iter_after__key(items): list_ = SortedList(items, key=lambda item: -item) for split in items: expected = reversed(sorted(item for item in items if item < split)) assert list(list_.iter_after(-split)) == list(expected), f"split: {split}" @dataclasses.dataclass class Row(BaseRow): PARTITION_KEY = ("col1", "col2") CLUSTERING_KEY = ("col3", "col4") col1: str col2: str col3: str col4: str col5: str col6: int def test_table_keys(): table = Table(Row) 
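# Row (defined above) declares PARTITION_KEY = ("col1", "col2") and
# CLUSTERING_KEY = ("col3", "col4"), so the primary key is their
# concatenation: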
primary_key = ("foo", "bar", "baz", "qux") partition_key = ("foo", "bar") clustering_key = ("baz", "qux") row = Row(col1="foo", col2="bar", col3="baz", col4="qux", col5="quux", col6=4) assert table.partition_key(row) == partition_key assert table.clustering_key(row) == clustering_key assert table.primary_key(row) == primary_key assert table.primary_key_from_dict(row.to_dict()) == primary_key assert table.split_primary_key(primary_key) == (partition_key, clustering_key) def test_table(): table = Table(Row) row1 = Row(col1="foo", col2="bar", col3="baz", col4="qux", col5="quux", col6=4) row2 = Row(col1="foo", col2="bar", col3="baz", col4="qux2", col5="quux", col6=4) row3 = Row(col1="foo", col2="bar", col3="baz", col4="qux1", col5="quux", col6=4) partition_key = ("foo", "bar") primary_key1 = ("foo", "bar", "baz", "qux") primary_key2 = ("foo", "bar", "baz", "qux2") primary_key3 = ("foo", "bar", "baz", "qux1") table.insert(row1) table.insert(row2) table.insert(row3) assert table.get_from_primary_key(primary_key1) == row1 assert table.get_from_primary_key(primary_key2) == row2 assert table.get_from_primary_key(primary_key3) == row3 # order matters assert list(table.get_from_token(table.token(partition_key))) == [row1, row3, row2] all_rows = list(table.iter_all()) assert len(all_rows) == 3 for row in (row1, row2, row3): assert (table.primary_key(row), row) in all_rows + + +class TestInMemoryStorage(_TestStorage): + @pytest.mark.skip("content_update is not yet implemented for Cassandra") + def test_content_update(self): + pass diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py index d7f43cd7..2bdb3547 100644 --- a/swh/storage/tests/test_replay.py +++ b/swh/storage/tests/test_replay.py @@ -1,358 +1,407 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import dataclasses import datetime import functools import logging from typing import Any, Container, Dict, Optional +import attr import pytest from swh.model.hashutil import hash_to_hex, MultiHash, DEFAULT_ALGORITHMS from swh.storage import get_storage +from swh.storage.cassandra.model import ContentRow, SkippedContentRow from swh.storage.in_memory import InMemoryStorage from swh.storage.replay import process_replay_objects from swh.journal.serializers import key_to_kafka, value_to_kafka from swh.journal.client import JournalClient from swh.journal.tests.journal_data import ( TEST_OBJECTS, DUPLICATE_CONTENTS, ) UTC = datetime.timezone.utc +def nullify_ctime(obj): + if isinstance(obj, (ContentRow, SkippedContentRow)): + return dataclasses.replace(obj, ctime=None) + else: + return obj + + @pytest.fixture() def replayer_storage_and_client( kafka_prefix: str, kafka_consumer_group: str, kafka_server: str ): journal_writer_config = { "cls": "kafka", "brokers": [kafka_server], "client_id": "kafka_writer", "prefix": kafka_prefix, } storage_config: Dict[str, Any] = { "cls": "memory", "journal_writer": journal_writer_config, } storage = get_storage(**storage_config) replayer = JournalClient( brokers=kafka_server, group_id=kafka_consumer_group, prefix=kafka_prefix, stop_on_eof=True, ) yield storage, replayer def test_storage_replayer(replayer_storage_and_client, caplog): """Optimal replayer scenario. 
This: - writes objects to a source storage - replayer consumes objects from the topic and replays them - a destination storage is filled from this In the end, both storages should have the same content. """ src, replayer = replayer_storage_and_client # Fill Kafka using a source storage nb_sent = 0 for object_type, objects in TEST_OBJECTS.items(): method = getattr(src, object_type + "_add") method(objects) if object_type == "origin_visit": nb_sent += len(objects) # origin-visit-add adds origin-visit-status as well nb_sent += len(objects) caplog.set_level(logging.ERROR, "swh.journal.replay") # Fill the destination storage from Kafka dst = get_storage(cls="memory") worker_fn = functools.partial(process_replay_objects, storage=dst) nb_inserted = replayer.process(worker_fn) assert nb_sent == nb_inserted _check_replayed(src, dst) collision = 0 for record in caplog.records: logtext = record.getMessage() if "Colliding contents:" in logtext: collision += 1 assert collision == 0, "No collision should be detected" def test_storage_play_with_collision(replayer_storage_and_client, caplog): """Another replayer scenario with collisions. This: - writes objects to the topic, including colliding contents - replayer consumes objects from the topic and replays them - the colliding contents are dropped from the replay when detected """ src, replayer = replayer_storage_and_client # Fill Kafka using a source storage nb_sent = 0 for object_type, objects in TEST_OBJECTS.items(): method = getattr(src, object_type + "_add") method(objects) if object_type == "origin_visit": nb_sent += len(objects) # origin-visit-add adds origin-visit-status as well nb_sent += len(objects) # Create collision in input data # These should not be written in the destination producer = src.journal_writer.journal.producer prefix = src.journal_writer.journal._prefix for content in DUPLICATE_CONTENTS: topic = f"{prefix}.content" key = content.sha1 + now = datetime.datetime.now(tz=UTC) + content = attr.evolve(content, ctime=now) producer.produce( topic=topic, key=key_to_kafka(key), value=value_to_kafka(content.to_dict()), ) nb_sent += 1 producer.flush() caplog.set_level(logging.ERROR, "swh.journal.replay") # Fill the destination storage from Kafka dst = get_storage(cls="memory") worker_fn = functools.partial(process_replay_objects, storage=dst) nb_inserted = replayer.process(worker_fn) assert nb_sent == nb_inserted # check the logs for the collision being properly detected nb_collisions = 0 actual_collision: Dict for record in caplog.records: logtext = record.getMessage() if "Collision detected:" in logtext: nb_collisions += 1 actual_collision = record.args["collision"] assert nb_collisions == 1, "1 collision should be detected" algo = "sha1" assert actual_collision["algo"] == algo expected_colliding_hash = hash_to_hex(DUPLICATE_CONTENTS[0].get_hash(algo)) assert actual_collision["hash"] == expected_colliding_hash actual_colliding_hashes = actual_collision["objects"] assert len(actual_colliding_hashes) == len(DUPLICATE_CONTENTS) for content in DUPLICATE_CONTENTS: expected_content_hashes = { k: hash_to_hex(v) for k, v in content.hashes().items() } assert expected_content_hashes in actual_colliding_hashes # all objects from the src should exist in the dst storage _check_replayed(src, dst, exclude=["contents"]) # but the dst has one content more (one of the 2 colliding ones) - assert len(src._contents) == len(dst._contents) - 1 + assert ( + len(list(src._cql_runner._contents.iter_all())) + == len(list(dst._cql_runner._contents.iter_all())) - 1 + )
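# A minimal sketch of the comparison idiom used above (and in _check_replayed
# below), assuming two already-populated InMemoryStorage instances `src` and
# `dst`: ctime is set at insertion time, so rows only compare equal once it
# has been nullified.
#
#   src_rows = [(pk, nullify_ctime(row)) for pk, row in sorted(src._cql_runner._contents.iter_all())]
#   dst_rows = [(pk, nullify_ctime(row)) for pk, row in sorted(dst._cql_runner._contents.iter_all())]
#   assert src_rows == dst_rows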
def test_replay_skipped_content(replayer_storage_and_client): """Test the 'skipped_content' topic is properly replayed.""" src, replayer = replayer_storage_and_client _check_replay_skipped_content(src, replayer, "skipped_content") def test_replay_skipped_content_bwcompat(replayer_storage_and_client): """Test the 'content' topic can be used to replay SkippedContent objects.""" src, replayer = replayer_storage_and_client _check_replay_skipped_content(src, replayer, "content") # utility functions def _check_replayed( src: InMemoryStorage, dst: InMemoryStorage, exclude: Optional[Container] = None ): """Simple utility function to compare the content of 2 in_memory storages """ expected_persons = set(src._persons.values()) got_persons = set(dst._persons.values()) assert got_persons == expected_persons - for attr in ( - "contents", + for attr_ in ( "skipped_contents", "directories", "revisions", "releases", "snapshots", "origins", "origin_visits", "origin_visit_statuses", ): - if exclude and attr in exclude: + if exclude and attr_ in exclude: continue - expected_objects = sorted(getattr(src, f"_{attr}").items()) - got_objects = sorted(getattr(dst, f"_{attr}").items()) - assert got_objects == expected_objects, f"Mismatch object list for {attr}" + expected_objects = sorted(getattr(src, f"_{attr_}").items()) + got_objects = sorted(getattr(dst, f"_{attr_}").items()) + assert got_objects == expected_objects, f"Mismatch object list for {attr_}" + + for attr_ in ("contents",): + if exclude and attr_ in exclude: + continue + expected_objects = [ + (id, nullify_ctime(obj)) + for id, obj in sorted(getattr(src._cql_runner, f"_{attr_}").iter_all()) + ] + got_objects = [ + (id, nullify_ctime(obj)) + for id, obj in sorted(getattr(dst._cql_runner, f"_{attr_}").iter_all()) + ] + assert got_objects == expected_objects, f"Mismatch object list for {attr_}" def _check_replay_skipped_content(storage, replayer, topic): skipped_contents = _gen_skipped_contents(100) nb_sent = len(skipped_contents) producer = storage.journal_writer.journal.producer prefix = storage.journal_writer.journal._prefix for i, obj in enumerate(skipped_contents): producer.produce( topic=f"{prefix}.{topic}", key=key_to_kafka({"sha1": obj["sha1"]}), value=value_to_kafka(obj), ) producer.flush() dst_storage = get_storage(cls="memory") worker_fn = functools.partial(process_replay_objects, storage=dst_storage) nb_inserted = replayer.process(worker_fn) assert nb_sent == nb_inserted for content in skipped_contents: assert not storage.content_find({"sha1": content["sha1"]}) # no skipped_content_find API endpoint, so use this instead assert not list(dst_storage.skipped_content_missing(skipped_contents)) def _updated(d1, d2): d1.update(d2) d1.pop("data", None) return d1 def _gen_skipped_contents(n=10): # we do not use the hypothesis strategy here because this does not play well with # pytest fixtures, and it makes test execution very slow algos = DEFAULT_ALGORITHMS | {"length"} now = datetime.datetime.now(tz=UTC) return [ _updated( MultiHash.from_data(data=f"foo{i}".encode(), hash_names=algos).digest(), { "status": "absent", "reason": "why not", "origin": f"https://somewhere/{i}", "ctime": now, }, ) for i in range(n) ] def test_storage_play_anonymized( kafka_prefix: str, kafka_consumer_group: str, kafka_server: str ): """Optimal replayer scenario. 
This: - writes objects to the topic - replayer consumes objects from the topic and replays them """ writer_config = { "cls": "kafka", "brokers": [kafka_server], "client_id": "kafka_writer", "prefix": kafka_prefix, "anonymize": True, } src_config: Dict[str, Any] = {"cls": "memory", "journal_writer": writer_config} storage = get_storage(**src_config) # Fill the src storage nb_sent = 0 for obj_type, objs in TEST_OBJECTS.items(): if obj_type in ("origin_visit", "origin_visit_status"): # these are unrelated to what we want to test here continue method = getattr(storage, obj_type + "_add") method(objs) nb_sent += len(objs) # Fill a destination storage from Kafka **using anonymized topics** dst_storage = get_storage(cls="memory") replayer = JournalClient( brokers=kafka_server, group_id=kafka_consumer_group, prefix=kafka_prefix, stop_after_objects=nb_sent, privileged=False, ) worker_fn = functools.partial(process_replay_objects, storage=dst_storage) nb_inserted = replayer.process(worker_fn) assert nb_sent == nb_inserted check_replayed(storage, dst_storage, expected_anonymized=True) # Fill a destination storage from Kafka **with stock (non-anonymized) topics** dst_storage = get_storage(cls="memory") replayer = JournalClient( brokers=kafka_server, group_id=kafka_consumer_group, prefix=kafka_prefix, stop_after_objects=nb_sent, privileged=True, ) worker_fn = functools.partial(process_replay_objects, storage=dst_storage) nb_inserted = replayer.process(worker_fn) assert nb_sent == nb_inserted check_replayed(storage, dst_storage, expected_anonymized=False) def check_replayed(src, dst, expected_anonymized=False): """Simple utility function to compare the content of two in-memory storages. If expected_anonymized is True, objects from the source storage are anonymized before comparing with the destination storage.
""" - def maybe_anonymize(obj): + def maybe_anonymize(attr_, row): if expected_anonymized: - return obj.anonymize() or obj - return obj - - expected_persons = {maybe_anonymize(person) for person in src._persons.values()} + if hasattr(row, "anonymize"): + # for model objects; cases below are for BaseRow objects + row = row.anonymize() or row + elif attr_ == "releases": + row = dataclasses.replace(row, author=row.author.anonymize()) + elif attr_ == "revisions": + row = dataclasses.replace( + row, + author=row.author.anonymize(), + committer=row.committer.anonymize(), + ) + return row + + expected_persons = { + maybe_anonymize("persons", person) for person in src._persons.values() + } got_persons = set(dst._persons.values()) assert got_persons == expected_persons - for attr in ( - "contents", + for attr_ in ( "skipped_contents", "directories", "revisions", "releases", "snapshots", "origins", "origin_visit_statuses", ): expected_objects = [ - (id, maybe_anonymize(obj)) - for id, obj in sorted(getattr(src, f"_{attr}").items()) + (id, maybe_anonymize(attr_, obj)) + for id, obj in sorted(getattr(src, f"_{attr_}").items()) + ] + got_objects = [ + (id, obj) for id, obj in sorted(getattr(dst, f"_{attr_}").items()) + ] + assert got_objects == expected_objects, f"Mismatch object list for {attr_}" + + for attr_ in ("contents",): + expected_objects = [ + (id, nullify_ctime(maybe_anonymize(attr_, obj))) + for id, obj in sorted(getattr(src._cql_runner, f"_{attr_}").iter_all()) ] got_objects = [ - (id, obj) for id, obj in sorted(getattr(dst, f"_{attr}").items()) + (id, nullify_ctime(obj)) + for id, obj in sorted(getattr(dst._cql_runner, f"_{attr_}").iter_all()) ] - assert got_objects == expected_objects, f"Mismatch object list for {attr}" + assert got_objects == expected_objects, f"Mismatch object list for {attr_}" diff --git a/swh/storage/tests/test_retry.py b/swh/storage/tests/test_retry.py index 958b6d58..058cf987 100644 --- a/swh/storage/tests/test_retry.py +++ b/swh/storage/tests/test_retry.py @@ -1,820 +1,821 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import attr from unittest.mock import call import psycopg2 import pytest from swh.model.model import MetadataTargetType from swh.storage.exc import HashCollision, StorageArgumentException +from swh.storage.utils import now @pytest.fixture def monkeypatch_sleep(monkeypatch, swh_storage): """In test context, we don't want to wait, make test faster """ from swh.storage.retry import RetryingProxyStorage for method_name, method in RetryingProxyStorage.__dict__.items(): if "_add" in method_name or "_update" in method_name: monkeypatch.setattr(method.retry, "sleep", lambda x: None) return monkeypatch @pytest.fixture def fake_hash_collision(sample_data): return HashCollision("sha1", "38762cf7f55934b34d179ae6a4c80cadccbb7f0a", []) @pytest.fixture def swh_storage_backend_config(): yield { "cls": "pipeline", "steps": [{"cls": "retry"}, {"cls": "memory"},], } def test_retrying_proxy_storage_content_add(swh_storage, sample_data): """Standard content_add works as before """ sample_content = sample_data.content content = swh_storage.content_get_data(sample_content.sha1) assert content is None s = swh_storage.content_add([sample_content]) assert s == { "content:add": 1, "content:add:bytes": sample_content.length, } content = 
swh_storage.content_get_data(sample_content.sha1) assert content == sample_content.data def test_retrying_proxy_storage_content_add_with_retry( monkeypatch_sleep, swh_storage, sample_data, mocker, fake_hash_collision, ): """Multiple retries for hash collision and psycopg2 error but finally ok """ mock_memory = mocker.patch("swh.storage.in_memory.InMemoryStorage.content_add") mock_memory.side_effect = [ # first try goes ko fake_hash_collision, # second try goes ko psycopg2.IntegrityError("content already inserted"), # ok then! {"content:add": 1}, ] sample_content = sample_data.content content = swh_storage.content_get_data(sample_content.sha1) assert content is None s = swh_storage.content_add([sample_content]) assert s == {"content:add": 1} mock_memory.assert_has_calls( [call([sample_content]), call([sample_content]), call([sample_content]),] ) def test_retrying_proxy_swh_storage_content_add_failure( swh_storage, sample_data, mocker ): """Unfiltered errors are raising without retry """ mock_memory = mocker.patch("swh.storage.in_memory.InMemoryStorage.content_add") mock_memory.side_effect = StorageArgumentException("Refuse to add content always!") sample_content = sample_data.content content = swh_storage.content_get_data(sample_content.sha1) assert content is None with pytest.raises(StorageArgumentException, match="Refuse to add"): swh_storage.content_add([sample_content]) assert mock_memory.call_count == 1 def test_retrying_proxy_storage_content_add_metadata(swh_storage, sample_data): """Standard content_add_metadata works as before """ sample_content = sample_data.content content = attr.evolve(sample_content, data=None) pk = content.sha1 content_metadata = swh_storage.content_get([pk]) assert content_metadata == [None] - s = swh_storage.content_add_metadata([content]) + s = swh_storage.content_add_metadata([attr.evolve(content, ctime=now())]) assert s == { "content:add": 1, } content_metadata = swh_storage.content_get([pk]) assert len(content_metadata) == 1 assert content_metadata[0].sha1 == pk def test_retrying_proxy_storage_content_add_metadata_with_retry( monkeypatch_sleep, swh_storage, sample_data, mocker, fake_hash_collision ): """Multiple retries for hash collision and psycopg2 error but finally ok """ mock_memory = mocker.patch( "swh.storage.in_memory.InMemoryStorage.content_add_metadata" ) mock_memory.side_effect = [ # first try goes ko fake_hash_collision, # second try goes ko psycopg2.IntegrityError("content_metadata already inserted"), # ok then! {"content:add": 1}, ] sample_content = sample_data.content content = attr.evolve(sample_content, data=None) s = swh_storage.content_add_metadata([content]) assert s == {"content:add": 1} mock_memory.assert_has_calls( [call([content]), call([content]), call([content]),] ) def test_retrying_proxy_swh_storage_content_add_metadata_failure( swh_storage, sample_data, mocker ): """Unfiltered errors are raising without retry """ mock_memory = mocker.patch( "swh.storage.in_memory.InMemoryStorage.content_add_metadata" ) mock_memory.side_effect = StorageArgumentException( "Refuse to add content_metadata!" 
) sample_content = sample_data.content content = attr.evolve(sample_content, data=None) with pytest.raises(StorageArgumentException, match="Refuse to add"): swh_storage.content_add_metadata([content]) assert mock_memory.call_count == 1 def test_retrying_proxy_storage_skipped_content_add(swh_storage, sample_data): """Standard skipped_content_add works as before """ sample_content = sample_data.skipped_content sample_content_dict = sample_content.to_dict() skipped_contents = list(swh_storage.skipped_content_missing([sample_content_dict])) assert len(skipped_contents) == 1 s = swh_storage.skipped_content_add([sample_content]) assert s == { "skipped_content:add": 1, } skipped_content = list(swh_storage.skipped_content_missing([sample_content_dict])) assert len(skipped_content) == 0 def test_retrying_proxy_storage_skipped_content_add_with_retry( monkeypatch_sleep, swh_storage, sample_data, mocker, fake_hash_collision ): """Multiple retries for hash collision and psycopg2 error but finally ok """ mock_memory = mocker.patch( "swh.storage.in_memory.InMemoryStorage.skipped_content_add" ) mock_memory.side_effect = [ # 1st & 2nd try goes ko fake_hash_collision, psycopg2.IntegrityError("skipped_content already inserted"), # ok then! {"skipped_content:add": 1}, ] sample_content = sample_data.skipped_content s = swh_storage.skipped_content_add([sample_content]) assert s == {"skipped_content:add": 1} mock_memory.assert_has_calls( [call([sample_content]), call([sample_content]), call([sample_content]),] ) def test_retrying_proxy_swh_storage_skipped_content_add_failure( swh_storage, sample_data, mocker ): """Unfiltered errors are raising without retry """ mock_memory = mocker.patch( "swh.storage.in_memory.InMemoryStorage.skipped_content_add" ) mock_memory.side_effect = StorageArgumentException( "Refuse to add content_metadata!" ) sample_content = sample_data.skipped_content sample_content_dict = sample_content.to_dict() skipped_contents = list(swh_storage.skipped_content_missing([sample_content_dict])) assert len(skipped_contents) == 1 with pytest.raises(StorageArgumentException, match="Refuse to add"): swh_storage.skipped_content_add([sample_content]) skipped_contents = list(swh_storage.skipped_content_missing([sample_content_dict])) assert len(skipped_contents) == 1 assert mock_memory.call_count == 1 def test_retrying_proxy_swh_storage_origin_visit_add(swh_storage, sample_data): """Standard origin_visit_add works as before """ origin = sample_data.origin visit = sample_data.origin_visit assert visit.origin == origin.url swh_storage.origin_add([origin]) origins = swh_storage.origin_visit_get(origin.url).results assert not origins origin_visit = swh_storage.origin_visit_add([visit])[0] assert origin_visit.origin == origin.url assert isinstance(origin_visit.visit, int) actual_visit = swh_storage.origin_visit_get(origin.url).results[0] assert actual_visit == visit def test_retrying_proxy_swh_storage_origin_visit_add_retry( monkeypatch_sleep, swh_storage, sample_data, mocker, fake_hash_collision ): """Multiple retries for hash collision and psycopg2 error but finally ok """ origin = sample_data.origin visit = sample_data.origin_visit assert visit.origin == origin.url swh_storage.origin_add([origin]) mock_memory = mocker.patch("swh.storage.in_memory.InMemoryStorage.origin_visit_add") mock_memory.side_effect = [ # first try goes ko fake_hash_collision, # second try goes ko psycopg2.IntegrityError("origin already inserted"), # ok then! 
[visit], ] origins = swh_storage.origin_visit_get(origin.url).results assert not origins r = swh_storage.origin_visit_add([visit]) assert r == [visit] mock_memory.assert_has_calls( [call([visit]), call([visit]), call([visit]),] ) def test_retrying_proxy_swh_storage_origin_visit_add_failure( swh_storage, sample_data, mocker ): """Unfiltered errors are raising without retry """ mock_memory = mocker.patch("swh.storage.in_memory.InMemoryStorage.origin_visit_add") mock_memory.side_effect = StorageArgumentException("Refuse to add origin always!") origin = sample_data.origin visit = sample_data.origin_visit assert visit.origin == origin.url origins = swh_storage.origin_visit_get(origin.url).results assert not origins with pytest.raises(StorageArgumentException, match="Refuse to add"): swh_storage.origin_visit_add([visit]) mock_memory.assert_has_calls( [call([visit]),] ) def test_retrying_proxy_storage_metadata_fetcher_add(swh_storage, sample_data): """Standard metadata_fetcher_add works as before """ fetcher = sample_data.metadata_fetcher metadata_fetcher = swh_storage.metadata_fetcher_get(fetcher.name, fetcher.version) assert not metadata_fetcher swh_storage.metadata_fetcher_add([fetcher]) actual_fetcher = swh_storage.metadata_fetcher_get(fetcher.name, fetcher.version) assert actual_fetcher == fetcher def test_retrying_proxy_storage_metadata_fetcher_add_with_retry( monkeypatch_sleep, swh_storage, sample_data, mocker, fake_hash_collision, ): """Multiple retries for hash collision and psycopg2 error but finally ok """ fetcher = sample_data.metadata_fetcher mock_memory = mocker.patch( "swh.storage.in_memory.InMemoryStorage.metadata_fetcher_add" ) mock_memory.side_effect = [ # first try goes ko fake_hash_collision, # second try goes ko psycopg2.IntegrityError("metadata_fetcher already inserted"), # ok then! [fetcher], ] actual_fetcher = swh_storage.metadata_fetcher_get(fetcher.name, fetcher.version) assert not actual_fetcher swh_storage.metadata_fetcher_add([fetcher]) mock_memory.assert_has_calls( [call([fetcher]), call([fetcher]), call([fetcher]),] ) def test_retrying_proxy_swh_storage_metadata_fetcher_add_failure( swh_storage, sample_data, mocker ): """Unfiltered errors are raising without retry """ mock_memory = mocker.patch( "swh.storage.in_memory.InMemoryStorage.metadata_fetcher_add" ) mock_memory.side_effect = StorageArgumentException( "Refuse to add metadata_fetcher always!" 
) fetcher = sample_data.metadata_fetcher actual_fetcher = swh_storage.metadata_fetcher_get(fetcher.name, fetcher.version) assert not actual_fetcher with pytest.raises(StorageArgumentException, match="Refuse to add"): swh_storage.metadata_fetcher_add([fetcher]) assert mock_memory.call_count == 1 def test_retrying_proxy_storage_metadata_authority_add(swh_storage, sample_data): """Standard metadata_authority_add works as before """ authority = sample_data.metadata_authority assert not swh_storage.metadata_authority_get(authority.type, authority.url) swh_storage.metadata_authority_add([authority]) actual_authority = swh_storage.metadata_authority_get(authority.type, authority.url) assert actual_authority == authority def test_retrying_proxy_storage_metadata_authority_add_with_retry( monkeypatch_sleep, swh_storage, sample_data, mocker, fake_hash_collision, ): """Multiple retries for hash collision and psycopg2 error but finally ok """ authority = sample_data.metadata_authority mock_memory = mocker.patch( "swh.storage.in_memory.InMemoryStorage.metadata_authority_add" ) mock_memory.side_effect = [ # first try goes ko fake_hash_collision, # second try goes ko psycopg2.IntegrityError("foo bar"), # ok then! None, ] assert not swh_storage.metadata_authority_get(authority.type, authority.url) swh_storage.metadata_authority_add([authority]) mock_memory.assert_has_calls( [call([authority]), call([authority]), call([authority])] ) def test_retrying_proxy_swh_storage_metadata_authority_add_failure( swh_storage, sample_data, mocker ): """Unfiltered errors are raising without retry """ mock_memory = mocker.patch( "swh.storage.in_memory.InMemoryStorage.metadata_authority_add" ) mock_memory.side_effect = StorageArgumentException( "Refuse to add authority_id always!" ) authority = sample_data.metadata_authority swh_storage.metadata_authority_get(authority.type, authority.url) with pytest.raises(StorageArgumentException, match="Refuse to add"): swh_storage.metadata_authority_add([authority]) assert mock_memory.call_count == 1 def test_retrying_proxy_storage_raw_extrinsic_metadata_add(swh_storage, sample_data): """Standard raw_extrinsic_metadata_add works as before """ origin = sample_data.origin ori_meta = sample_data.origin_metadata1 assert origin.url == ori_meta.id swh_storage.origin_add([origin]) swh_storage.metadata_authority_add([sample_data.metadata_authority]) swh_storage.metadata_fetcher_add([sample_data.metadata_fetcher]) origin_metadata = swh_storage.raw_extrinsic_metadata_get( MetadataTargetType.ORIGIN, ori_meta.id, ori_meta.authority ) assert origin_metadata.next_page_token is None assert not origin_metadata.results swh_storage.raw_extrinsic_metadata_add([ori_meta]) origin_metadata = swh_storage.raw_extrinsic_metadata_get( MetadataTargetType.ORIGIN, ori_meta.id, ori_meta.authority ) assert origin_metadata def test_retrying_proxy_storage_raw_extrinsic_metadata_add_with_retry( monkeypatch_sleep, swh_storage, sample_data, mocker, fake_hash_collision, ): """Multiple retries for hash collision and psycopg2 error but finally ok """ origin = sample_data.origin ori_meta = sample_data.origin_metadata1 assert origin.url == ori_meta.id swh_storage.origin_add([origin]) swh_storage.metadata_authority_add([sample_data.metadata_authority]) swh_storage.metadata_fetcher_add([sample_data.metadata_fetcher]) mock_memory = mocker.patch( "swh.storage.in_memory.InMemoryStorage.raw_extrinsic_metadata_add" ) mock_memory.side_effect = [ # first try goes ko fake_hash_collision, # second try goes ko 
psycopg2.IntegrityError("foo bar"), # ok then! None, ] # No exception raised as insertion finally came through swh_storage.raw_extrinsic_metadata_add([ori_meta]) mock_memory.assert_has_calls( [ # 3 calls, as long as error raised call([ori_meta]), call([ori_meta]), call([ori_meta]), ] ) def test_retrying_proxy_swh_storage_raw_extrinsic_metadata_add_failure( swh_storage, sample_data, mocker ): """Unfiltered errors are raising without retry """ mock_memory = mocker.patch( "swh.storage.in_memory.InMemoryStorage.raw_extrinsic_metadata_add" ) mock_memory.side_effect = StorageArgumentException("Refuse to add always!") origin = sample_data.origin ori_meta = sample_data.origin_metadata1 assert origin.url == ori_meta.id swh_storage.origin_add([origin]) with pytest.raises(StorageArgumentException, match="Refuse to add"): swh_storage.raw_extrinsic_metadata_add([ori_meta]) assert mock_memory.call_count == 1 def test_retrying_proxy_storage_directory_add(swh_storage, sample_data): """Standard directory_add works as before """ sample_dir = sample_data.directory s = swh_storage.directory_add([sample_dir]) assert s == { "directory:add": 1, } directory_id = swh_storage.directory_get_random() # only 1 assert directory_id == sample_dir.id def test_retrying_proxy_storage_directory_add_with_retry( monkeypatch_sleep, swh_storage, sample_data, mocker, fake_hash_collision ): """Multiple retries for hash collision and psycopg2 error but finally ok """ mock_memory = mocker.patch("swh.storage.in_memory.InMemoryStorage.directory_add") mock_memory.side_effect = [ # first try goes ko fake_hash_collision, # second try goes ko psycopg2.IntegrityError("directory already inserted"), # ok then! {"directory:add": 1}, ] sample_dir = sample_data.directories[1] s = swh_storage.directory_add([sample_dir]) assert s == { "directory:add": 1, } mock_memory.assert_has_calls( [call([sample_dir]), call([sample_dir]), call([sample_dir]),] ) def test_retrying_proxy_swh_storage_directory_add_failure( swh_storage, sample_data, mocker ): """Unfiltered errors are raising without retry """ mock_memory = mocker.patch("swh.storage.in_memory.InMemoryStorage.directory_add") mock_memory.side_effect = StorageArgumentException( "Refuse to add directory always!" ) sample_dir = sample_data.directory with pytest.raises(StorageArgumentException, match="Refuse to add"): swh_storage.directory_add([sample_dir]) assert mock_memory.call_count == 1 def test_retrying_proxy_storage_revision_add(swh_storage, sample_data): """Standard revision_add works as before """ sample_rev = sample_data.revision revision = next(swh_storage.revision_get([sample_rev.id])) assert not revision s = swh_storage.revision_add([sample_rev]) assert s == { "revision:add": 1, } revision = next(swh_storage.revision_get([sample_rev.id])) assert revision["id"] == sample_rev.id def test_retrying_proxy_storage_revision_add_with_retry( monkeypatch_sleep, swh_storage, sample_data, mocker, fake_hash_collision ): """Multiple retries for hash collision and psycopg2 error but finally ok """ mock_memory = mocker.patch("swh.storage.in_memory.InMemoryStorage.revision_add") mock_memory.side_effect = [ # first try goes ko fake_hash_collision, # second try goes ko psycopg2.IntegrityError("revision already inserted"), # ok then! 
{"revision:add": 1}, ] sample_rev = sample_data.revision revision = next(swh_storage.revision_get([sample_rev.id])) assert not revision s = swh_storage.revision_add([sample_rev]) assert s == { "revision:add": 1, } mock_memory.assert_has_calls( [call([sample_rev]), call([sample_rev]), call([sample_rev]),] ) def test_retrying_proxy_swh_storage_revision_add_failure( swh_storage, sample_data, mocker ): """Unfiltered errors are raising without retry """ mock_memory = mocker.patch("swh.storage.in_memory.InMemoryStorage.revision_add") mock_memory.side_effect = StorageArgumentException("Refuse to add revision always!") sample_rev = sample_data.revision revision = next(swh_storage.revision_get([sample_rev.id])) assert not revision with pytest.raises(StorageArgumentException, match="Refuse to add"): swh_storage.revision_add([sample_rev]) assert mock_memory.call_count == 1 def test_retrying_proxy_storage_release_add(swh_storage, sample_data): """Standard release_add works as before """ sample_rel = sample_data.release release = next(swh_storage.release_get([sample_rel.id])) assert not release s = swh_storage.release_add([sample_rel]) assert s == { "release:add": 1, } release = next(swh_storage.release_get([sample_rel.id])) assert release["id"] == sample_rel.id def test_retrying_proxy_storage_release_add_with_retry( monkeypatch_sleep, swh_storage, sample_data, mocker, fake_hash_collision ): """Multiple retries for hash collision and psycopg2 error but finally ok """ mock_memory = mocker.patch("swh.storage.in_memory.InMemoryStorage.release_add") mock_memory.side_effect = [ # first try goes ko fake_hash_collision, # second try goes ko psycopg2.IntegrityError("release already inserted"), # ok then! {"release:add": 1}, ] sample_rel = sample_data.release release = next(swh_storage.release_get([sample_rel.id])) assert not release s = swh_storage.release_add([sample_rel]) assert s == { "release:add": 1, } mock_memory.assert_has_calls( [call([sample_rel]), call([sample_rel]), call([sample_rel]),] ) def test_retrying_proxy_swh_storage_release_add_failure( swh_storage, sample_data, mocker ): """Unfiltered errors are raising without retry """ mock_memory = mocker.patch("swh.storage.in_memory.InMemoryStorage.release_add") mock_memory.side_effect = StorageArgumentException("Refuse to add release always!") sample_rel = sample_data.release release = next(swh_storage.release_get([sample_rel.id])) assert not release with pytest.raises(StorageArgumentException, match="Refuse to add"): swh_storage.release_add([sample_rel]) assert mock_memory.call_count == 1 def test_retrying_proxy_storage_snapshot_add(swh_storage, sample_data): """Standard snapshot_add works as before """ sample_snap = sample_data.snapshot snapshot = swh_storage.snapshot_get(sample_snap.id) assert not snapshot s = swh_storage.snapshot_add([sample_snap]) assert s == { "snapshot:add": 1, } snapshot = swh_storage.snapshot_get(sample_snap.id) assert snapshot["id"] == sample_snap.id def test_retrying_proxy_storage_snapshot_add_with_retry( monkeypatch_sleep, swh_storage, sample_data, mocker, fake_hash_collision ): """Multiple retries for hash collision and psycopg2 error but finally ok """ mock_memory = mocker.patch("swh.storage.in_memory.InMemoryStorage.snapshot_add") mock_memory.side_effect = [ # first try goes ko fake_hash_collision, # second try goes ko psycopg2.IntegrityError("snapshot already inserted"), # ok then! 
{"snapshot:add": 1}, ] sample_snap = sample_data.snapshot snapshot = swh_storage.snapshot_get(sample_snap.id) assert not snapshot s = swh_storage.snapshot_add([sample_snap]) assert s == { "snapshot:add": 1, } mock_memory.assert_has_calls( [call([sample_snap]), call([sample_snap]), call([sample_snap]),] ) def test_retrying_proxy_swh_storage_snapshot_add_failure( swh_storage, sample_data, mocker ): """Unfiltered errors are raising without retry """ mock_memory = mocker.patch("swh.storage.in_memory.InMemoryStorage.snapshot_add") mock_memory.side_effect = StorageArgumentException("Refuse to add snapshot always!") sample_snap = sample_data.snapshot snapshot = swh_storage.snapshot_get(sample_snap.id) assert not snapshot with pytest.raises(StorageArgumentException, match="Refuse to add"): swh_storage.snapshot_add([sample_snap]) assert mock_memory.call_count == 1