diff --git a/requirements-swh.txt b/requirements-swh.txt
--- a/requirements-swh.txt
+++ b/requirements-swh.txt
@@ -1,3 +1,4 @@
 swh.core[db,http] >= 0.14.0
+swh.counters >= 0.8.0
 swh.model >= 2.1.0
 swh.objstorage >= 0.2.2
diff --git a/swh/storage/__init__.py b/swh/storage/__init__.py
--- a/swh/storage/__init__.py
+++ b/swh/storage/__init__.py
@@ -19,11 +19,12 @@
     # deprecated
     "local": ".postgresql.storage.Storage",
     # proxy storages
-    "filter": ".proxies.filter.FilteringProxyStorage",
     "buffer": ".proxies.buffer.BufferingProxyStorage",
+    "counter": ".proxies.counter.CountingProxyStorage",
+    "filter": ".proxies.filter.FilteringProxyStorage",
     "retry": ".proxies.retry.RetryingProxyStorage",
-    "validate": ".proxies.validate.ValidatingProxyStorage",
     "tenacious": ".proxies.tenacious.TenaciousProxyStorage",
+    "validate": ".proxies.validate.ValidatingProxyStorage",
 }
 
 
diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -7,6 +7,7 @@
 import dataclasses
 import datetime
 import functools
+import itertools
 import logging
 import random
 from typing import (
@@ -17,6 +18,7 @@
     Iterator,
     List,
     Optional,
+    Sequence,
     Tuple,
     Type,
     TypeVar,
@@ -25,6 +27,7 @@
 
 from cassandra import ConsistencyLevel, CoordinationFailure
 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile, ResultSet
+from cassandra.concurrent import execute_concurrent_with_args
 from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy
 from cassandra.query import BoundStatement, PreparedStatement, dict_factory
 from mypy_extensions import NamedArg
@@ -59,7 +62,6 @@
     ExtIDRow,
     MetadataAuthorityRow,
     MetadataFetcherRow,
-    ObjectCountRow,
     OriginRow,
     OriginVisitRow,
     OriginVisitStatusRow,
@@ -85,6 +87,11 @@
 See for details and rationale.
 """
 
+# Maximum number of INSERT statements grouped into a single UNLOGGED BATCH by
+# directory_entry_add_batch. NOTE: Cassandra also rejects batches larger than its
+# batch_size_fail_threshold_in_kb setting; check that setting before raising this.
+BATCH_INSERT_MAX_SIZE = 1000
+
 logger = logging.getLogger(__name__)
 
 
@@ -166,6 +173,14 @@
 
 
 TSelf = TypeVar("TSelf")
 
 
+def _insert_query(row_class):
+    columns = row_class.cols()
+    return (
+        f"INSERT INTO {row_class.TABLE} ({', '.join(columns)}) "
+        f"VALUES ({', '.join('?' for _ in columns)})"
+    )
+
+
 def _prepared_insert_statement(
@@ -174,11 +189,7 @@
 ]:
     """Shorthand for using `_prepared_statement` for `INSERT INTO`
     statements."""
-    columns = row_class.cols()
-    return _prepared_statement(
-        "INSERT INTO %s (%s) VALUES (%s)"
-        % (row_class.TABLE, ", ".join(columns), ", ".join("?" for _ in columns),)
-    )
+    return _prepared_statement(_insert_query(row_class))
 
 
 def _prepared_exists_statement(
@@ -280,22 +291,28 @@
         stop=stop_after_attempt(MAX_RETRIES),
         retry=retry_if_exception_type(CoordinationFailure),
     )
-    def _execute_with_retries(self, statement, args) -> ResultSet:
+    def _execute_with_retries(self, statement, args: Optional[Sequence]) -> ResultSet:
         return self._session.execute(statement, args, timeout=1000.0)
 
-    @_prepared_statement(
-        "UPDATE object_count SET count = count + ? "
-        "WHERE partition_key = 0 AND object_type = ?"
+    @retry(
+        wait=wait_random_exponential(multiplier=1, max=10),
+        stop=stop_after_attempt(MAX_RETRIES),
+        retry=retry_if_exception_type(CoordinationFailure),
     )
-    def _increment_counter(
-        self, object_type: str, nb: int, *, statement: PreparedStatement
-    ) -> None:
-        self._execute_with_retries(statement, [nb, object_type])
+    def _execute_many_with_retries(
+        self, statement, args_list: List[Tuple]
+    ) -> List[Tuple]:
+        return execute_concurrent_with_args(self._session, statement, args_list)
 
     def _add_one(self, statement, obj: BaseRow) -> None:
-        self._increment_counter(obj.TABLE, 1)
         self._execute_with_retries(statement, dataclasses.astuple(obj))
 
+    def _add_many(self, statement, objs: Sequence[BaseRow]) -> None:
+        tables = {obj.TABLE for obj in objs}
+        assert len(tables) == 1, f"Cannot insert to multiple tables: {tables}"
+        (table,) = tables
+        self._execute_many_with_retries(statement, list(map(dataclasses.astuple, objs)))
+
     _T = TypeVar("_T", bound=BaseRow)
 
     def _get_random_row(self, row_class: Type[_T], statement) -> Optional[_T]:  # noqa
@@ -328,7 +345,6 @@
         """Returned currified by content_add_prepare, to be called when the
         content row should be added to the primary table."""
         self._execute_with_retries(statement, None)
-        self._increment_counter("content", 1)
 
     @_prepared_insert_statement(ContentRow)
     def content_add_prepare(
@@ -482,7 +498,6 @@
         """Returned currified by skipped_content_add_prepare, to be called when
         the content row should be added to the primary table."""
         self._execute_with_retries(statement, None)
-        self._increment_counter("skipped_content", 1)
 
     @_prepared_insert_statement(SkippedContentRow)
     def skipped_content_add_prepare(
@@ -677,6 +692,61 @@
     def directory_entry_add_one(self, entry: DirectoryEntryRow, *, statement) -> None:
         self._add_one(statement, entry)
 
+    @_prepared_insert_statement(DirectoryEntryRow)
+    def directory_entry_add_concurrent(
+        self, entries: List[DirectoryEntryRow], *, statement
+    ) -> None:
+        if len(entries) == 0:
+            # nothing to do
+            return
+        assert (
+            len({entry.directory_id for entry in entries}) == 1
+        ), "directory_entry_add_concurrent must be called with entries for a single dir"
+        self._add_many(statement, entries)
+
+    def directory_entry_add_batch(self, entries: List[DirectoryEntryRow],) -> None:
+        if len(entries) == 0:
+            # nothing to do
+            return
+        assert (
+            len({entry.directory_id for entry in entries}) == 1
+        ), "directory_entry_add_batch must be called with entries for a single dir"
+
+        # query to INSERT one row
+        insert_query = _insert_query(DirectoryEntryRow) + ";\n"
+
+        # In "steady state", we insert batches of the maximum allowed size.
+        # Then, the last one has however many entries remain.
+        last_batch_size = len(entries) % BATCH_INSERT_MAX_SIZE
+        if len(entries) >= BATCH_INSERT_MAX_SIZE:
+            # TODO: the main_statement's size is statically known, so we could avoid
+            # re-preparing it on every call
+            main_statement = self._session.prepare(
+                "BEGIN UNLOGGED BATCH\n"
+                + insert_query * BATCH_INSERT_MAX_SIZE
+                + "APPLY BATCH"
+            )
+        if last_batch_size > 0:
+            # When len(entries) is a multiple of BATCH_INSERT_MAX_SIZE there is no
+            # trailing partial batch; don't prepare an empty BATCH statement.
+            last_statement = self._session.prepare(
+                "BEGIN UNLOGGED BATCH\n"
+                + insert_query * last_batch_size
+                + "APPLY BATCH"
+            )
+
+        for entry_group in grouper(entries, BATCH_INSERT_MAX_SIZE):
+            entry_group = list(map(dataclasses.astuple, entry_group))
+            if len(entry_group) == BATCH_INSERT_MAX_SIZE:
+                self._execute_with_retries(
+                    main_statement, list(itertools.chain.from_iterable(entry_group))
+                )
+            else:
+                assert len(entry_group) == last_batch_size
+                self._execute_with_retries(
+                    last_statement, list(itertools.chain.from_iterable(entry_group))
+                )
+
     @_prepared_select_statement(DirectoryEntryRow, "WHERE directory_id IN ?")
     def directory_entry_get(
         self, directory_ids, *, statement
@@ -1219,7 +1289,6 @@
         """Returned currified by extid_add_prepare, to be called when the
         extid row should be added to the primary table."""
         self._execute_with_retries(statement, None)
-        self._increment_counter("extid", 1)
 
     @_prepared_insert_statement(ExtIDRow)
     def extid_add_prepare(
@@ -1331,7 +1400,3 @@
     @_prepared_statement("SELECT uuid() FROM revision LIMIT 1;")
     def check_read(self, *, statement):
         self._execute_with_retries(statement, [])
-
-    @_prepared_select_statement(ObjectCountRow, "WHERE partition_key=0")
-    def stat_counters(self, *, statement) -> Iterable[ObjectCountRow]:
-        return map(ObjectCountRow.from_dict, self._execute_with_retries(statement, []))
diff --git a/swh/storage/cassandra/model.py b/swh/storage/cassandra/model.py
--- a/swh/storage/cassandra/model.py
+++ b/swh/storage/cassandra/model.py
@@ -300,17 +300,6 @@
     authority_url: str
 
 
-@dataclasses.dataclass
-class ObjectCountRow(BaseRow):
-    TABLE = "object_count"
-    PARTITION_KEY = ("partition_key",)
-    CLUSTERING_KEY = ("object_type",)
-
-    partition_key: int
-    object_type: str
-    count: int
-
-
 @dataclasses.dataclass
 class ExtIDRow(BaseRow):
     TABLE = "extid"
diff --git a/swh/storage/cassandra/schema.py b/swh/storage/cassandra/schema.py
--- a/swh/storage/cassandra/schema.py
+++ b/swh/storage/cassandra/schema.py
@@ -267,13 +267,6 @@
     PRIMARY KEY ((id))
 );""",
     """
-CREATE TABLE IF NOT EXISTS object_count (
-    partition_key smallint, -- Constant, must always be 0
-    object_type ascii,
-    count counter,
-    PRIMARY KEY ((partition_key), object_type)
-);""",
-    """
 CREATE TABLE IF NOT EXISTS extid (
     extid_type ascii,
     extid blob,
@@ -319,7 +312,6 @@
     "origin_visit",
     "origin",
     "raw_extrinsic_metadata",
-    "object_count",
     "origin_visit_status",
     "metadata_authority",
     "metadata_fetcher",
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -100,6 +100,7 @@
         journal_writer=None,
         allow_overwrite=False,
         consistency_level="ONE",
+        directory_entries_insert_algo="one-by-one",
     ):
         """
         A backend of swh-storage backed by Cassandra
@@ -120,6 +121,15 @@
                 Note that a ``False`` value does not guarantee there won't be any
                 overwrite.
             consistency_level: The default read/write consistency to use
+            directory_entries_insert_algo: Must be one of:
+                * one-by-one: naive, one INSERT per directory entry, serialized
+                * concurrent: one INSERT per directory entry, concurrent
+                * batch: using UNLOGGED BATCH to insert many entries in a few statements
         """
+        if directory_entries_insert_algo not in ("one-by-one", "concurrent", "batch"):
+            raise ValueError(
+                "Unexpected value for directory_entries_insert_algo: "
+                f"{directory_entries_insert_algo}"
+            )
         self._hosts = hosts
         self._keyspace = keyspace
@@ -129,6 +139,7 @@
         self.journal_writer: JournalWriter = JournalWriter(journal_writer)
         self.objstorage: ObjStorage = ObjStorage(objstorage)
         self._allow_overwrite = allow_overwrite
+        self._directory_entries_insert_algo = directory_entries_insert_algo
 
     def _set_cql_runner(self):
         """Used by tests when they need to reset the CqlRunner"""
@@ -438,9 +449,21 @@
 
         for directory in directories:
             # Add directory entries to the 'directory_entry' table
-            for entry in directory.entries:
-                self._cql_runner.directory_entry_add_one(
-                    DirectoryEntryRow(directory_id=directory.id, **entry.to_dict())
+            rows = [
+                DirectoryEntryRow(directory_id=directory.id, **entry.to_dict())
+                for entry in directory.entries
+            ]
+            if self._directory_entries_insert_algo == "one-by-one":
+                for row in rows:
+                    self._cql_runner.directory_entry_add_one(row)
+            elif self._directory_entries_insert_algo == "concurrent":
+                self._cql_runner.directory_entry_add_concurrent(rows)
+            elif self._directory_entries_insert_algo == "batch":
+                self._cql_runner.directory_entry_add_batch(rows)
+            else:
+                raise ValueError(
+                    f"Unexpected value for directory_entries_insert_algo: "
+                    f"{self._directory_entries_insert_algo}"
                 )
 
             # Add the directory *after* adding all the entries, so someone
@@ -1272,20 +1295,7 @@
         return None
 
     def stat_counters(self):
-        rows = self._cql_runner.stat_counters()
-        keys = (
-            "content",
-            "directory",
-            "origin",
-            "origin_visit",
-            "release",
-            "revision",
-            "skipped_content",
-            "snapshot",
-        )
-        stats = {key: 0 for key in keys}
-        stats.update({row.object_type: row.count for row in rows})
-        return stats
+        raise NotImplementedError("stat_counters is not supported by this backend")
 
     def refresh_stat_counters(self):
         pass
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -34,7 +34,6 @@
     ExtIDRow,
     MetadataAuthorityRow,
     MetadataFetcherRow,
-    ObjectCountRow,
     OriginRow,
     OriginVisitRow,
     OriginVisitStatusRow,
@@ -171,14 +170,6 @@
         self._raw_extrinsic_metadata = Table(RawExtrinsicMetadataRow)
         self._raw_extrinsic_metadata_by_id = Table(RawExtrinsicMetadataByIdRow)
         self._extid = Table(ExtIDRow)
-        self._stat_counters = defaultdict(int)
-
-    def increment_counter(self, object_type: str, nb: int):
-        self._stat_counters[object_type] += nb
-
-    def stat_counters(self) -> Iterable[ObjectCountRow]:
-        for (object_type, count) in self._stat_counters.items():
-            yield ObjectCountRow(partition_key=0, object_type=object_type, count=count)
 
     ##########################
     # 'content' table
@@ -186,7 +177,6 @@
 
     def _content_add_finalize(self, content: ContentRow) -> None:
         self._contents.insert(content)
-        self.increment_counter("content", 1)
 
     def content_add_prepare(self, content: ContentRow):
         finalizer = functools.partial(self._content_add_finalize, content)
@@ -248,7 +238,6 @@
 
     def _skipped_content_add_finalize(self, content: SkippedContentRow) -> None:
         self._skipped_contents.insert(content)
-        self.increment_counter("skipped_content", 1)
 
     def skipped_content_add_prepare(self, content: SkippedContentRow):
         finalizer = functools.partial(self._skipped_content_add_finalize, content)
@@ -293,7 +282,6 @@
 
     def directory_add_one(self, directory: DirectoryRow) -> None:
         self._directories.insert(directory)
-        self.increment_counter("directory", 1)
 
     def directory_get_random(self) -> Optional[DirectoryRow]:
         return self._directories.get_random()
@@ -334,7 +322,6 @@
 
     def revision_add_one(self, revision: RevisionRow) -> None:
         self._revisions.insert(revision)
-        self.increment_counter("revision", 1)
 
     def revision_get_ids(self, revision_ids) -> Iterable[int]:
         for id_ in revision_ids:
@@ -374,7 +361,6 @@
 
     def release_add_one(self, release: ReleaseRow) -> None:
         self._releases.insert(release)
-        self.increment_counter("release", 1)
 
     def release_get(self, release_ids: List[str]) -> Iterable[ReleaseRow]:
         for id_ in release_ids:
@@ -398,7 +384,6 @@
 
     def snapshot_add_one(self, snapshot: SnapshotRow) -> None:
         self._snapshots.insert(snapshot)
-        self.increment_counter("snapshot", 1)
 
     def snapshot_get_random(self) -> Optional[SnapshotRow]:
         return self._snapshots.get_random()
@@ -452,7 +437,6 @@
 
     def origin_add_one(self, origin: OriginRow) -> None:
         self._origins.insert(origin)
-        self.increment_counter("origin", 1)
 
     def origin_get_by_sha1(self, sha1: bytes) -> Iterable[OriginRow]:
         return self._origins.get_from_partition_key((sha1,))
@@ -513,7 +497,6 @@
 
     def origin_visit_add_one(self, visit: OriginVisitRow) -> None:
         self._origin_visits.insert(visit)
-        self.increment_counter("origin_visit", 1)
 
     def origin_visit_get_one(
         self, origin_url: str, visit_id: int
@@ -558,7 +541,6 @@
 
     def origin_visit_status_add_one(self, visit_update: OriginVisitStatusRow) -> None:
         self._origin_visit_statuses.insert(visit_update)
-        self.increment_counter("origin_visit_status", 1)
 
     def origin_visit_status_get_latest(
         self, origin: str, visit: int,
@@ -588,7 +570,6 @@
 
     def metadata_authority_add(self, authority: MetadataAuthorityRow):
         self._metadata_authorities.insert(authority)
-        self.increment_counter("metadata_authority", 1)
 
     def metadata_authority_get(self, type, url) -> Optional[MetadataAuthorityRow]:
         return self._metadata_authorities.get_from_primary_key((url, type))
@@ -599,7 +580,6 @@
 
     def metadata_fetcher_add(self, fetcher: MetadataFetcherRow):
         self._metadata_fetchers.insert(fetcher)
-        self.increment_counter("metadata_fetcher", 1)
 
     def metadata_fetcher_get(self, name, version) -> Optional[MetadataAuthorityRow]:
         return self._metadata_fetchers.get_from_primary_key((name, version))
@@ -629,7 +609,6 @@
 
     def raw_extrinsic_metadata_add(self, raw_extrinsic_metadata):
         self._raw_extrinsic_metadata.insert(raw_extrinsic_metadata)
-        self.increment_counter("raw_extrinsic_metadata", 1)
 
     def raw_extrinsic_metadata_get_after_date(
         self,
@@ -682,7 +661,6 @@
     #########################
     def _extid_add_finalize(self, extid: ExtIDRow) -> None:
         self._extid.insert(extid)
-        self.increment_counter("extid", 1)
 
     def extid_add_prepare(self, extid: ExtIDRow):
         finalizer = functools.partial(self._extid_add_finalize, extid)
@@ -729,6 +707,7 @@
         self.reset()
         self.journal_writer = JournalWriter(journal_writer)
         self._allow_overwrite = False
+        self._directory_entries_insert_algo = "one-by-one"
 
     def reset(self):
         self._cql_runner = InMemoryCqlRunner()
diff --git a/swh/storage/proxies/counter.py b/swh/storage/proxies/counter.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/proxies/counter.py
@@ -0,0 +1,74 @@
+# Copyright (C) 2021 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+
+from typing import Callable
+
+from swh.counters import get_counters
+from swh.counters.interface import CountersInterface
+from swh.storage import get_storage
+from swh.storage.interface import StorageInterface
+
+OBJECT_TYPES = [
+    "content",
+    "directory",
+    "snapshot",
+    "origin_visit_status",
+    "origin_visit",
+    "origin",
+]
+
+
+class CountingProxyStorage:
+    """Counting Storage Proxy.
+
+    This is in charge of adding objects directly to swh-counters, without
+    going through Kafka/swh-journal.
+    This is meant as a simple way to set up counters for experiments; production
+    should use swh-journal to reduce load/latency of the storage server.
+
+    Additionally, unlike the journal-based counting, it does not count persons
+    or the number of origins per netloc.
+
+    Every ``<collection>_add`` method of the backend storage is wrapped: once the
+    backend call returns, the unique keys of the added objects are sent to
+    swh-counters under ``<collection>``. Other attributes are forwarded as-is.
+
+    Sample configuration use case for counting storage:
+
+    .. code-block:: yaml
+
+        storage:
+          cls: counter
+          counters:
+            cls: remote
+            url: http://counters.internal.staging.swh.network:5011/
+          storage:
+            cls: remote
+            url: http://storage.internal.staging.swh.network:5002/
+
+    """
+
+    def __init__(self, counters, storage):
+        self.counters: CountersInterface = get_counters(**counters)
+        self.storage: StorageInterface = get_storage(**storage)
+
+    def __getattr__(self, key):
+        if key == "storage":
+            raise AttributeError(key)
+        if key.endswith("_add"):
+            return self._adder(key[0:-4], getattr(self.storage, key))
+        return getattr(self.storage, key)
+
+    def _adder(self, collection: str, backend_function: Callable):
+        def f(objs):
+            objs = list(objs)  # may be a one-shot iterable; it is consumed twice
+            result = backend_function(objs)
+            # Count only after the backend accepted the objects, so a failed
+            # insertion does not inflate the counters.
+            self.counters.add(collection, [obj.unique_key() for obj in objs])
+            return result
+
+        return f
diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py
--- a/swh/storage/tests/storage_tests.py
+++ b/swh/storage/tests/storage_tests.py
@@ -38,6 +38,8 @@
     TargetType,
 )
 from swh.storage import get_storage
+from swh.storage.api.client import RemoteStorage
+from swh.storage.cassandra.storage import CassandraStorage
 from swh.storage.common import origin_url_to_sha1 as sha1
 from swh.storage.exc import HashCollision, StorageArgumentException
 from swh.storage.interface import ListOrder, PagedResult, StorageInterface
@@ -187,8 +189,9 @@
         assert obj.ctime <= insertion_end_time
         assert obj == expected_cont
 
-        swh_storage.refresh_stat_counters()
-        assert swh_storage.stat_counters()["content"] == 1
+        if not isinstance(swh_storage, (CassandraStorage, RemoteStorage)):
+            swh_storage.refresh_stat_counters()
+            assert swh_storage.stat_counters()["content"] == 1
 
     def test_content_add_from_lazy_content(self, swh_storage, sample_data):
         cont = sample_data.content
@@ -221,8 +224,9 @@
         assert obj.ctime <= insertion_end_time
         assert attr.evolve(obj, ctime=None).to_dict() == expected_cont.to_dict()
 
-        swh_storage.refresh_stat_counters()
-        assert swh_storage.stat_counters()["content"] == 1
+        if not isinstance(swh_storage, (CassandraStorage, RemoteStorage)):
+            swh_storage.refresh_stat_counters()
+            assert swh_storage.stat_counters()["content"] == 1
 
     def test_content_get_data_missing(self, swh_storage, sample_data):
         cont, cont2 = sample_data.contents[:2]
@@ -705,8 +709,9 @@
         after_missing = list(swh_storage.directory_missing([directory.id]))
         assert after_missing == []
 
-        swh_storage.refresh_stat_counters()
-        assert swh_storage.stat_counters()["directory"] == 1
+        if not isinstance(swh_storage, (CassandraStorage, RemoteStorage)):
+            swh_storage.refresh_stat_counters()
+            assert swh_storage.stat_counters()["directory"] == 1
 
     def test_directory_add_twice(self, swh_storage, sample_data):
         directory = sample_data.directories[1]
@@ -975,8 +980,9 @@
         actual_result = swh_storage.revision_add([revision])
         assert actual_result == {"revision:add": 0}
 
-        swh_storage.refresh_stat_counters()
-        assert swh_storage.stat_counters()["revision"] == 1
+        if not isinstance(swh_storage, (CassandraStorage, RemoteStorage)):
+            swh_storage.refresh_stat_counters()
+            assert swh_storage.stat_counters()["revision"] == 1
 
     def test_revision_add_twice(self, swh_storage, sample_data):
         revision, revision2 = sample_data.revisions[:2]
@@ -1376,8 +1382,9 @@
         actual_result = swh_storage.release_add([release, release2])
         assert actual_result == {"release:add": 0}
 
-        swh_storage.refresh_stat_counters()
-        assert swh_storage.stat_counters()["release"] == 2
+        if not isinstance(swh_storage, (CassandraStorage, RemoteStorage)):
+            swh_storage.refresh_stat_counters()
+            assert swh_storage.stat_counters()["release"] == 2
 
     def test_release_add_no_author_date(self, swh_storage, sample_data):
         full_release = sample_data.release
@@ -1482,8 +1489,9 @@
             [("origin", origin) for origin in origins]
         )
 
-        swh_storage.refresh_stat_counters()
-        assert swh_storage.stat_counters()["origin"] == len(origins)
+        if not isinstance(swh_storage, (CassandraStorage, RemoteStorage)):
+            swh_storage.refresh_stat_counters()
+            assert swh_storage.stat_counters()["origin"] == len(origins)
 
     def test_origin_add_twice(self, swh_storage, sample_data):
         origin, origin2 = sample_data.origins[:2]
@@ -1921,11 +1929,11 @@
             ]
         )
 
-        swh_storage.refresh_stat_counters()
-
-        stats = swh_storage.stat_counters()
-        assert stats["origin"] == len(origins)
-        assert stats["origin_visit"] == len(origins) * len(visits)
+        if not isinstance(swh_storage, (CassandraStorage, RemoteStorage)):
+            swh_storage.refresh_stat_counters()
+            stats = swh_storage.stat_counters()
+            assert stats["origin"] == len(origins)
+            assert stats["origin_visit"] == len(origins) * len(visits)
 
         random_ovs = swh_storage.origin_visit_status_get_random(visit_type)
         assert random_ovs
@@ -3122,8 +3130,9 @@
             "next_branch": None,
         }
 
-        swh_storage.refresh_stat_counters()
-        assert swh_storage.stat_counters()["snapshot"] == 2
+        if not isinstance(swh_storage, (CassandraStorage, RemoteStorage)):
+            swh_storage.refresh_stat_counters()
+            assert swh_storage.stat_counters()["snapshot"] == 2
 
     def test_snapshot_add_many_incremental(self, swh_storage, sample_data):
         snapshot, _, complete_snapshot = sample_data.snapshots[:3]
@@ -3623,6 +3632,8 @@
         assert list(missing_snapshots) == [missing_snapshot.id]
 
     def test_stat_counters(self, swh_storage, sample_data):
+        if isinstance(swh_storage, (CassandraStorage, RemoteStorage)):
+            pytest.skip("Cassandra backend does not support stat counters")
         origin = sample_data.origin
         snapshot = sample_data.snapshot
         revision = sample_data.revision
diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py
--- a/swh/storage/tests/test_cassandra.py
+++ b/swh/storage/tests/test_cassandra.py
@@ -494,21 +494,18 @@
         # early for this test to be relevant
         swh_storage.journal_writer.journal = None
 
-        class MyException(Exception):
-            pass
-
         class CrashyEntry(DirectoryEntry):
             def __init__(self):
                 pass
 
             def to_dict(self):
-                raise MyException()
+                return {**directory.entries[0].to_dict(), "perms": "abcde"}
 
         directory = sample_data.directory3
         entries = directory.entries
         directory = attr.evolve(directory, entries=entries + (CrashyEntry(),))
 
-        with pytest.raises(MyException):
+        with pytest.raises(TypeError):
             swh_storage.directory_add([directory])
 
         # This should have written some of the entries to the database:
diff --git a/swh/storage/tests/test_counter.py b/swh/storage/tests/test_counter.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/tests/test_counter.py
@@ -0,0 +1,63 @@
+# Copyright (C) 2021 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import attr
+import pytest
+
+from swh.storage import get_storage
+
+
+@pytest.fixture
+def swh_storage():
+    storage_config = {
+        "cls": "pipeline",
+        "steps": [
+            {"cls": "counter", "counters": {"cls": "memory"}},
+            {"cls": "memory"},
+        ],
+    }
+
+    return get_storage(**storage_config)
+
+
+def test_counting_proxy_storage_content(swh_storage, sample_data):
+    assert swh_storage.counters.counters["content"] == set()
+
+    swh_storage.content_add([sample_data.content])
+
+    assert swh_storage.counters.counters["content"] == {sample_data.content.sha1}
+
+    swh_storage.content_add([sample_data.content2, sample_data.content3])
+
+    assert swh_storage.counters.counters["content"] == {
+        sample_data.content.sha1,
+        sample_data.content2.sha1,
+        sample_data.content3.sha1,
+    }
+
+    assert [
+        attr.evolve(cnt, ctime=None)
+        for cnt in swh_storage.content_find({"sha256": sample_data.content2.sha256})
+    ] == [attr.evolve(sample_data.content2, data=None)]
+
+
+def test_counting_proxy_storage_revision(swh_storage, sample_data):
+    assert swh_storage.counters.counters["revision"] == set()
+
+    swh_storage.revision_add([sample_data.revision])
+
+    assert swh_storage.counters.counters["revision"] == {sample_data.revision.id}
+
+    swh_storage.revision_add([sample_data.revision2, sample_data.revision3])
+
+    assert swh_storage.counters.counters["revision"] == {
+        sample_data.revision.id,
+        sample_data.revision2.id,
+        sample_data.revision3.id,
+    }
+
+    assert swh_storage.revision_get([sample_data.revision2.id]) == [
+        sample_data.revision2
+    ]