diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -8,6 +8,7 @@
 import collections
 import copy
 import datetime
+import functools
 import itertools
 import random
 import re
@@ -46,19 +47,21 @@
     OriginVisit,
     OriginVisitStatus,
     Origin,
-    SHA1_SIZE,
     MetadataAuthority,
     MetadataAuthorityType,
     MetadataFetcher,
     MetadataTargetType,
     RawExtrinsicMetadata,
-    Sha1,
     Sha1Git,
 )
-from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex
+from swh.model.hashutil import DEFAULT_ALGORITHMS
 from swh.storage.cassandra import CassandraStorage
-from swh.storage.cassandra.model import BaseRow, ObjectCountRow
+from swh.storage.cassandra.model import (
+    BaseRow,
+    ContentRow,
+    ObjectCountRow,
+)
 from swh.storage.interface import (
     ListOrder,
     PagedResult,
@@ -69,8 +72,7 @@
 from swh.storage.utils import now
 
 from .converters import origin_url_to_sha1
-from .exc import StorageArgumentException, HashCollision
-from .utils import get_partition_bounds_bytes
+from .exc import StorageArgumentException
 from .writer import JournalWriter
 
 # Max block size of contents to return
@@ -212,6 +214,8 @@
 class InMemoryCqlRunner:
     def __init__(self):
+        self._contents = Table(ContentRow)
+        self._content_indexes = defaultdict(lambda: defaultdict(set))
         self._stat_counters = defaultdict(int)
 
     def increment_counter(self, object_type: str, nb: int):
@@ -221,12 +225,67 @@
         for (object_type, count) in self._stat_counters.items():
             yield ObjectCountRow(partition_key=0, object_type=object_type, count=count)
 
+    ##########################
+    # 'content' table
+    ##########################
+
+    def _content_add_finalize(self, content: ContentRow) -> None:
+        self._contents.insert(content)
+        self.increment_counter("content", 1)
+
+    def content_add_prepare(self, content: ContentRow):
+        finalizer = functools.partial(self._content_add_finalize, content)
+        return (self._contents.token(self._contents.partition_key(content)), finalizer)
+
+    def content_get_from_pk(
+        self, content_hashes: Dict[str, bytes]
+    ) -> Optional[ContentRow]:
+        primary_key = self._contents.primary_key_from_dict(content_hashes)
+        return self._contents.get_from_primary_key(primary_key)
+
+    def content_get_from_token(self, token: int) -> Iterable[ContentRow]:
+        return self._contents.get_from_token(token)
+
+    def content_get_random(self) -> Optional[ContentRow]:
+        return random.choice(
+            [
+                row
+                for partition in self._contents.data.values()
+                for row in partition.values()
+            ]
+        )
+
+    def content_get_token_range(
+        self, start: int, end: int, limit: int,
+    ) -> Iterable[Tuple[int, ContentRow]]:
+        matches = [
+            (token, row)
+            for (token, partition) in self._contents.data.items()
+            for (clustering_key, row) in partition.items()
+            if start <= token <= end
+        ]
+        matches.sort()
+        return matches[0:limit]
+
     ##########################
     # 'content_by_*' tables
     ##########################
 
     def content_missing_by_sha1_git(self, ids: List[bytes]) -> List[bytes]:
-        return ids
+        missing = []
+        for id_ in ids:
+            if id_ not in self._content_indexes["sha1_git"]:
+                missing.append(id_)
+
+        return missing
+
+    def content_index_add_one(self, algo: str, content: Content, token: int) -> None:
+        self._content_indexes[algo][content.get_hash(algo)].add(token)
+
+    def content_get_tokens_from_single_hash(
+        self, algo: str, hash_: bytes
+    ) -> Iterable[int]:
+        return self._content_indexes[algo][hash_]
 
     ##########################
     # 'directory' table
@@ -259,8 +318,6 @@
 
     def reset(self):
         self._cql_runner = InMemoryCqlRunner()
-        self._contents = {}
-        self._content_indexes = defaultdict(lambda: defaultdict(set))
        self._skipped_contents = {}
        self._skipped_content_indexes = defaultdict(lambda: defaultdict(set))
        self._directories = {}
@@ -308,184 +365,6 @@
     def check_config(self, *, check_write: bool) -> bool:
         return True
 
-    def _content_add(self, contents: List[Content], with_data: bool) -> Dict:
-        self.journal_writer.content_add(contents)
-
-        content_add = 0
-        if with_data:
-            summary = self.objstorage.content_add(
-                c for c in contents if c.status != "absent"
-            )
-            content_add_bytes = summary["content:add:bytes"]
-
-        for content in contents:
-            key = self._content_key(content)
-            if key in self._contents:
-                continue
-            for algorithm in DEFAULT_ALGORITHMS:
-                hash_ = content.get_hash(algorithm)
-                if hash_ in self._content_indexes[algorithm] and (
-                    algorithm not in {"blake2s256", "sha256"}
-                ):
-                    colliding_content_hashes = []
-                    # Add the already stored contents
-                    for content_hashes_set in self._content_indexes[algorithm][hash_]:
-                        hashes = dict(content_hashes_set)
-                        colliding_content_hashes.append(hashes)
-                    # Add the new colliding content
-                    colliding_content_hashes.append(content.hashes())
-                    raise HashCollision(algorithm, hash_, colliding_content_hashes)
-            for algorithm in DEFAULT_ALGORITHMS:
-                hash_ = content.get_hash(algorithm)
-                self._content_indexes[algorithm][hash_].add(key)
-            self._objects[content.sha1_git].append(("content", content.sha1))
-            self._contents[key] = content
-            self._sorted_sha1s.add(content.sha1)
-            self._contents[key] = attr.evolve(self._contents[key], data=None)
-            content_add += 1
-
-        self._cql_runner.increment_counter("content", content_add)
-
-        summary = {
-            "content:add": content_add,
-        }
-        if with_data:
-            summary["content:add:bytes"] = content_add_bytes
-
-        return summary
-
-    def content_add(self, content: List[Content]) -> Dict:
-        content = [attr.evolve(c, ctime=now()) for c in content]
-        return self._content_add(content, with_data=True)
-
-    def content_update(
-        self, contents: List[Dict[str, Any]], keys: List[str] = []
-    ) -> None:
-        self.journal_writer.content_update(contents)
-
-        for cont_update in contents:
-            cont_update = cont_update.copy()
-            sha1 = cont_update.pop("sha1")
-            for old_key in self._content_indexes["sha1"][sha1]:
-                old_cont = self._contents.pop(old_key)
-
-                for algorithm in DEFAULT_ALGORITHMS:
-                    hash_ = old_cont.get_hash(algorithm)
-                    self._content_indexes[algorithm][hash_].remove(old_key)
-
-                new_cont = attr.evolve(old_cont, **cont_update)
-                new_key = self._content_key(new_cont)
-
-                self._contents[new_key] = new_cont
-
-                for algorithm in DEFAULT_ALGORITHMS:
-                    hash_ = new_cont.get_hash(algorithm)
-                    self._content_indexes[algorithm][hash_].add(new_key)
-
-    def content_add_metadata(self, content: List[Content]) -> Dict:
-        return self._content_add(content, with_data=False)
-
-    def content_get_data(self, content: Sha1) -> Optional[bytes]:
-        # FIXME: Make this method support slicing the `data`
-        return self.objstorage.content_get(content)
-
-    def content_get_partition(
-        self,
-        partition_id: int,
-        nb_partitions: int,
-        page_token: Optional[str] = None,
-        limit: int = 1000,
-    ) -> PagedResult[Content]:
-        if limit is None:
-            raise StorageArgumentException("limit should not be None")
-        (start, end) = get_partition_bounds_bytes(
-            partition_id, nb_partitions, SHA1_SIZE
-        )
-        if page_token:
-            start = hash_to_bytes(page_token)
-        if end is None:
-            end = b"\xff" * SHA1_SIZE
-
-        next_page_token: Optional[str] = None
-        sha1s = (
-            (sha1, content_key)
-            for sha1 in self._sorted_sha1s.iter_from(start)
-            for content_key in self._content_indexes["sha1"][sha1]
-        )
-        contents: List[Content] = []
-        for counter, (sha1, key) in enumerate(sha1s):
-            if sha1 > end:
-                break
-            if counter >= limit:
-                next_page_token = hash_to_hex(sha1)
-                break
-            contents.append(self._contents[key])
-
-        assert len(contents) <= limit
-        return PagedResult(results=contents, next_page_token=next_page_token)
-
-    def content_get(self, contents: List[Sha1]) -> List[Optional[Content]]:
-        contents_by_sha1: Dict[Sha1, Optional[Content]] = {}
-        for sha1 in contents:
-            if sha1 in self._content_indexes["sha1"]:
-                objs = self._content_indexes["sha1"][sha1]
-                # only 1 element as content_add_metadata would have raised a
-                # hash collision otherwise
-                assert len(objs) == 1
-                for key in objs:
-                    content = attr.evolve(self._contents[key], data=None, ctime=None)
-                    contents_by_sha1[sha1] = content
-        return [contents_by_sha1.get(sha1) for sha1 in contents]
-
-    def content_find(self, content: Dict[str, Any]) -> List[Content]:
-        if not set(content).intersection(DEFAULT_ALGORITHMS):
-            raise StorageArgumentException(
-                "content keys must contain at least one "
-                f"of: {', '.join(sorted(DEFAULT_ALGORITHMS))}"
-            )
-        found = []
-        for algo in DEFAULT_ALGORITHMS:
-            hash = content.get(algo)
-            if hash and hash in self._content_indexes[algo]:
-                found.append(self._content_indexes[algo][hash])
-
-        if not found:
-            return []
-
-        keys = list(set.intersection(*found))
-        return [self._contents[key] for key in keys]
-
-    def content_missing(
-        self, contents: List[Dict[str, Any]], key_hash: str = "sha1"
-    ) -> Iterable[bytes]:
-        if key_hash not in DEFAULT_ALGORITHMS:
-            raise StorageArgumentException(
-                "key_hash should be one of {','.join(DEFAULT_ALGORITHMS)}"
-            )
-
-        for content in contents:
-            for (algo, hash_) in content.items():
-                if algo not in DEFAULT_ALGORITHMS:
-                    continue
-                if hash_ not in self._content_indexes.get(algo, []):
-                    yield content[key_hash]
-                    break
-
-    def content_missing_per_sha1(self, contents: List[bytes]) -> Iterable[bytes]:
-        for content in contents:
-            if content not in self._content_indexes["sha1"]:
-                yield content
-
-    def content_missing_per_sha1_git(
-        self, contents: List[Sha1Git]
-    ) -> Iterable[Sha1Git]:
-        for content in contents:
-            if content not in self._content_indexes["sha1_git"]:
-                yield content
-
-    def content_get_random(self) -> Sha1Git:
-        return random.choice(list(self._content_indexes["sha1_git"]))
-
     def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict:
         self.journal_writer.skipped_content_add(contents)
 
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
--- a/swh/storage/tests/test_in_memory.py
+++ b/swh/storage/tests/test_in_memory.py
@@ -9,7 +9,8 @@
 
 from swh.storage.cassandra.model import BaseRow
 from swh.storage.in_memory import SortedList, Table
-from swh.storage.tests.test_storage import TestStorage, TestStorageGeneratedData  # noqa
+from swh.storage.tests.test_storage import TestStorage as _TestStorage
+from swh.storage.tests.test_storage import TestStorageGeneratedData  # noqa
 
 
 # tests are executed using imported classes (TestStorage and
@@ -138,3 +139,9 @@
 
     # order matters
     assert list(table.get_from_token(table.token(partition_key))) == [row1, row3, row2]
+
+
+class TestInMemoryStorage(_TestStorage):
+    @pytest.mark.skip("content_update is not yet implemented for Cassandra")
+    def test_content_update(self):
+        pass
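
Note on the new prepare/finalize protocol: `content_add_prepare` only computes the row's token and hands back a finalizer callback; nothing is written until the finalizer runs. This mirrors the Cassandra backend, which needs to consult the `content_by_*` index tables between those two steps. Below is a minimal sketch of how a caller might drive the protocol against `InMemoryCqlRunner`; the `add_content` helper and its duplicate check are illustrative assumptions, not the actual `CassandraStorage.content_add` logic (which also deals with hash collisions, the journal writer, and the object storage):

    import datetime

    import attr

    from swh.model.hashutil import DEFAULT_ALGORITHMS
    from swh.model.model import Content
    from swh.storage.cassandra.model import ContentRow
    from swh.storage.in_memory import InMemoryCqlRunner


    def add_content(runner: InMemoryCqlRunner, content: Content) -> bool:
        """Insert ``content`` unless a row with the same hashes exists.

        Hypothetical helper; returns True if a new row was inserted.
        """
        row = ContentRow(
            sha1=content.sha1,
            sha1_git=content.sha1_git,
            sha256=content.sha256,
            blake2s256=content.blake2s256,
            length=content.length,
            ctime=content.ctime,
            status="visible",
        )

        # Phase 1: compute the row's token and obtain the finalizer;
        # nothing has been written yet.
        (token, finalizer) = runner.content_add_prepare(row)

        # Deduplicate on the full set of hashes (simplified check).
        if runner.content_get_from_pk(content.hashes()) is not None:
            return False

        # Phase 2: populate the content_by_* indexes, then commit the row.
        for algo in DEFAULT_ALGORITHMS:
            runner.content_index_add_one(algo, content, token)
        finalizer()
        return True


    runner = InMemoryCqlRunner()
    content = attr.evolve(
        Content.from_data(b"hello"),
        ctime=datetime.datetime.now(tz=datetime.timezone.utc),
    )
    assert add_content(runner, content)
    assert not add_content(runner, content)  # second insert is a no-op
    assert runner.content_missing_by_sha1_git([content.sha1_git]) == []

Splitting prepare from finalize keeps the index updates and the row insert in the caller's hands, which is why `content_missing_by_sha1_git` can now answer from `_content_indexes` instead of unconditionally returning its input.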