diff --git a/README.md b/README.md index 25b65f1b..5bcc7a42 100644 --- a/README.md +++ b/README.md @@ -1,194 +1,189 @@ swh-storage =========== Abstraction layer over the archive, allowing to access all stored source code artifacts as well as their metadata. See the [documentation](https://docs.softwareheritage.org/devel/swh-storage/index.html) for more details. ## Quick start ### Dependencies Python tests for this module include tests that cannot be run without a local Postgresql database, so you need the Postgresql server executable on your machine (no need to have a running Postgresql server). They also expect a cassandra server. #### Debian-like host ``` $ sudo apt install libpq-dev postgresql-11 cassandra ``` #### Non Debian-like host The tests expects the path to `cassandra` to either be unspecified, it is then looked up at `/usr/sbin/cassandra`, either specified through the environment variable `SWH_CASSANDRA_BIN`. Optionally, you can avoid running the cassandra tests. ``` (swh) :~/swh-storage$ tox -- -m 'not cassandra' ``` ### Installation It is strongly recommended to use a virtualenv. In the following, we consider you work in a virtualenv named `swh`. See the [developer setup guide](https://docs.softwareheritage.org/devel/developer-setup.html#developer-setup) for a more details on how to setup a working environment. You can install the package directly from [pypi](https://pypi.org/p/swh.storage): ``` (swh) :~$ pip install swh.storage [...] ``` Or from sources: ``` (swh) :~$ git clone https://forge.softwareheritage.org/source/swh-storage.git [...] (swh) :~$ cd swh-storage (swh) :~/swh-storage$ pip install . [...] ``` Then you can check it's properly installed: ``` (swh) :~$ swh storage --help Usage: swh storage [OPTIONS] COMMAND [ARGS]... Software Heritage Storage tools. Options: -h, --help Show this message and exit. Commands: rpc-serve Software Heritage Storage RPC server. ``` ## Tests The best way of running Python tests for this module is to use [tox](https://tox.readthedocs.io/). ``` (swh) :~$ pip install tox ``` ### tox From the sources directory, simply use tox: ``` (swh) :~/swh-storage$ tox [...] ========= 315 passed, 6 skipped, 15 warnings in 40.86 seconds ========== _______________________________ summary ________________________________ flake8: commands succeeded py3: commands succeeded congratulations :) ``` ## Development The storage server can be locally started. It requires a configuration file and a running Postgresql database. ### Sample configuration A typical configuration `storage.yml` file is: ``` storage: cls: local - args: - db: "dbname=softwareheritage-dev user= password=" - objstorage: - cls: pathslicing - args: - root: /tmp/swh-storage/ - slicing: 0:2/2:4/4:6 + db: "dbname=softwareheritage-dev user= password=" + objstorage: + cls: pathslicing + root: /tmp/swh-storage/ + slicing: 0:2/2:4/4:6 ``` which means, this uses: - a local storage instance whose db connection is to `softwareheritage-dev` local instance, - the objstorage uses a local objstorage instance whose: - `root` path is /tmp/swh-storage, - slicing scheme is `0:2/2:4/4:6`. This means that the identifier of the content (sha1) which will be stored on disk at first level with the first 2 hex characters, the second level with the next 2 hex characters and the third level with the next 2 hex characters. And finally the complete hash file holding the raw content. For example: 00062f8bd330715c4f819373653d97b3cd34394c will be stored at 00/06/2f/00062f8bd330715c4f819373653d97b3cd34394c Note that the `root` path should exist on disk before starting the server. ### Starting the storage server If the python package has been properly installed (e.g. in a virtual env), you should be able to use the command: ``` (swh) :~/swh-storage$ swh storage rpc-serve storage.yml ``` This runs a local swh-storage api at 5002 port. ``` (swh) :~/swh-storage$ curl http://127.0.0.1:5002 Software Heritage storage server

You have reached the Software Heritage storage server.
See its documentation and API for more information

``` ### And then what? In your upper layer ([loader-git](https://forge.softwareheritage.org/source/swh-loader-git/), [loader-svn](https://forge.softwareheritage.org/source/swh-loader-svn/), etc...), you can define a remote storage with this snippet of yaml configuration. ``` storage: cls: remote - args: - url: http://localhost:5002/ + url: http://localhost:5002/ ``` You could directly define a local storage with the following snippet: ``` storage: cls: local - args: - db: service=swh-dev - objstorage: - cls: pathslicing - args: - root: /home/storage/swh-storage/ - slicing: 0:2/2:4/4:6 + db: service=swh-dev + objstorage: + cls: pathslicing + root: /home/storage/swh-storage/ + slicing: 0:2/2:4/4:6 ``` diff --git a/requirements-swh.txt b/requirements-swh.txt index d0d31bfe..6d8a7806 100644 --- a/requirements-swh.txt +++ b/requirements-swh.txt @@ -1,3 +1,3 @@ swh.core[db,http] >= 0.3 swh.model >= 0.6.6 -swh.objstorage >= 0.0.40 +swh.objstorage >= 0.2.2 diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py index d78473fd..51373eb8 100644 --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -1,626 +1,626 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import defaultdict import datetime import functools import random from typing import ( Any, Dict, Generic, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, ) from swh.model.model import Content, Sha1Git, SkippedContent from swh.storage.cassandra import CassandraStorage from swh.storage.cassandra.model import ( BaseRow, ContentRow, DirectoryEntryRow, DirectoryRow, MetadataAuthorityRow, MetadataFetcherRow, ObjectCountRow, OriginRow, OriginVisitRow, OriginVisitStatusRow, RawExtrinsicMetadataRow, ReleaseRow, RevisionParentRow, RevisionRow, SkippedContentRow, SnapshotBranchRow, SnapshotRow, ) from swh.storage.interface import ListOrder from swh.storage.objstorage import ObjStorage from .common import origin_url_to_sha1 from .writer import JournalWriter TRow = TypeVar("TRow", bound=BaseRow) class Table(Generic[TRow]): def __init__(self, row_class: Type[TRow]): self.row_class = row_class self.primary_key_cols = row_class.PARTITION_KEY + row_class.CLUSTERING_KEY # Map from tokens to clustering keys to rows # These are not actually partitions (or rather, there is one partition # for each token) and they aren't sorted. # But it is good enough if we don't care about performance; # and makes the code a lot simpler. self.data: Dict[int, Dict[Tuple, TRow]] = defaultdict(dict) def __repr__(self): return f"<__module__.Table[{self.row_class.__name__}] object>" def partition_key(self, row: Union[TRow, Dict[str, Any]]) -> Tuple: """Returns the partition key of a row (ie. the cells which get hashed into the token.""" if isinstance(row, dict): row_d = row else: row_d = row.to_dict() return tuple(row_d[col] for col in self.row_class.PARTITION_KEY) def clustering_key(self, row: Union[TRow, Dict[str, Any]]) -> Tuple: """Returns the clustering key of a row (ie. the cells which are used for sorting rows within a partition.""" if isinstance(row, dict): row_d = row else: row_d = row.to_dict() return tuple(row_d[col] for col in self.row_class.CLUSTERING_KEY) def primary_key(self, row): return self.partition_key(row) + self.clustering_key(row) def primary_key_from_dict(self, d: Dict[str, Any]) -> Tuple: """Returns the primary key (ie. concatenation of partition key and clustering key) of the given dictionary interpreted as a row.""" return tuple(d[col] for col in self.primary_key_cols) def token(self, key: Tuple): """Returns the token of a row (ie. the hash of its partition key).""" return hash(key) def get_partition(self, token: int) -> Dict[Tuple, TRow]: """Returns the partition that contains this token.""" return self.data[token] def insert(self, row: TRow): partition = self.data[self.token(self.partition_key(row))] partition[self.clustering_key(row)] = row def split_primary_key(self, key: Tuple) -> Tuple[Tuple, Tuple]: """Returns (partition_key, clustering_key) from a partition key""" assert len(key) == len(self.primary_key_cols) partition_key = key[0 : len(self.row_class.PARTITION_KEY)] clustering_key = key[len(self.row_class.PARTITION_KEY) :] return (partition_key, clustering_key) def get_from_partition_key(self, partition_key: Tuple) -> Iterable[TRow]: """Returns at most one row, from its partition key.""" token = self.token(partition_key) for row in self.get_from_token(token): if self.partition_key(row) == partition_key: yield row def get_from_primary_key(self, primary_key: Tuple) -> Optional[TRow]: """Returns at most one row, from its primary key.""" (partition_key, clustering_key) = self.split_primary_key(primary_key) token = self.token(partition_key) partition = self.get_partition(token) return partition.get(clustering_key) def get_from_token(self, token: int) -> Iterable[TRow]: """Returns all rows whose token (ie. non-cryptographic hash of the partition key) is the one passed as argument.""" return (v for (k, v) in sorted(self.get_partition(token).items())) def iter_all(self) -> Iterator[Tuple[Tuple, TRow]]: return ( (self.primary_key(row), row) for (token, partition) in self.data.items() for (clustering_key, row) in partition.items() ) def get_random(self) -> Optional[TRow]: return random.choice([row for (pk, row) in self.iter_all()]) class InMemoryCqlRunner: def __init__(self): self._contents = Table(ContentRow) self._content_indexes = defaultdict(lambda: defaultdict(set)) self._skipped_contents = Table(ContentRow) self._skipped_content_indexes = defaultdict(lambda: defaultdict(set)) self._directories = Table(DirectoryRow) self._directory_entries = Table(DirectoryEntryRow) self._revisions = Table(RevisionRow) self._revision_parents = Table(RevisionParentRow) self._releases = Table(ReleaseRow) self._snapshots = Table(SnapshotRow) self._snapshot_branches = Table(SnapshotBranchRow) self._origins = Table(OriginRow) self._origin_visits = Table(OriginVisitRow) self._origin_visit_statuses = Table(OriginVisitStatusRow) self._metadata_authorities = Table(MetadataAuthorityRow) self._metadata_fetchers = Table(MetadataFetcherRow) self._raw_extrinsic_metadata = Table(RawExtrinsicMetadataRow) self._stat_counters = defaultdict(int) def increment_counter(self, object_type: str, nb: int): self._stat_counters[object_type] += nb def stat_counters(self) -> Iterable[ObjectCountRow]: for (object_type, count) in self._stat_counters.items(): yield ObjectCountRow(partition_key=0, object_type=object_type, count=count) ########################## # 'content' table ########################## def _content_add_finalize(self, content: ContentRow) -> None: self._contents.insert(content) self.increment_counter("content", 1) def content_add_prepare(self, content: ContentRow): finalizer = functools.partial(self._content_add_finalize, content) return (self._contents.token(self._contents.partition_key(content)), finalizer) def content_get_from_pk( self, content_hashes: Dict[str, bytes] ) -> Optional[ContentRow]: primary_key = self._contents.primary_key_from_dict(content_hashes) return self._contents.get_from_primary_key(primary_key) def content_get_from_token(self, token: int) -> Iterable[ContentRow]: return self._contents.get_from_token(token) def content_get_random(self) -> Optional[ContentRow]: return self._contents.get_random() def content_get_token_range( self, start: int, end: int, limit: int, ) -> Iterable[Tuple[int, ContentRow]]: matches = [ (token, row) for (token, partition) in self._contents.data.items() for (clustering_key, row) in partition.items() if start <= token <= end ] matches.sort() return matches[0:limit] ########################## # 'content_by_*' tables ########################## def content_missing_by_sha1_git(self, ids: List[bytes]) -> List[bytes]: missing = [] for id_ in ids: if id_ not in self._content_indexes["sha1_git"]: missing.append(id_) return missing def content_index_add_one(self, algo: str, content: Content, token: int) -> None: self._content_indexes[algo][content.get_hash(algo)].add(token) def content_get_tokens_from_single_hash( self, algo: str, hash_: bytes ) -> Iterable[int]: return self._content_indexes[algo][hash_] ########################## # 'skipped_content' table ########################## def _skipped_content_add_finalize(self, content: SkippedContentRow) -> None: self._skipped_contents.insert(content) self.increment_counter("skipped_content", 1) def skipped_content_add_prepare(self, content: SkippedContentRow): finalizer = functools.partial(self._skipped_content_add_finalize, content) return ( self._skipped_contents.token(self._contents.partition_key(content)), finalizer, ) def skipped_content_get_from_pk( self, content_hashes: Dict[str, bytes] ) -> Optional[SkippedContentRow]: primary_key = self._skipped_contents.primary_key_from_dict(content_hashes) return self._skipped_contents.get_from_primary_key(primary_key) def skipped_content_get_from_token(self, token: int) -> Iterable[SkippedContentRow]: return self._skipped_contents.get_from_token(token) ########################## # 'skipped_content_by_*' tables ########################## def skipped_content_index_add_one( self, algo: str, content: SkippedContent, token: int ) -> None: self._skipped_content_indexes[algo][content.get_hash(algo)].add(token) def skipped_content_get_tokens_from_single_hash( self, algo: str, hash_: bytes ) -> Iterable[int]: return self._skipped_content_indexes[algo][hash_] ########################## # 'directory' table ########################## def directory_missing(self, ids: List[bytes]) -> List[bytes]: missing = [] for id_ in ids: if self._directories.get_from_primary_key((id_,)) is None: missing.append(id_) return missing def directory_add_one(self, directory: DirectoryRow) -> None: self._directories.insert(directory) self.increment_counter("directory", 1) def directory_get_random(self) -> Optional[DirectoryRow]: return self._directories.get_random() ########################## # 'directory_entry' table ########################## def directory_entry_add_one(self, entry: DirectoryEntryRow) -> None: self._directory_entries.insert(entry) def directory_entry_get( self, directory_ids: List[Sha1Git] ) -> Iterable[DirectoryEntryRow]: for id_ in directory_ids: yield from self._directory_entries.get_from_partition_key((id_,)) ########################## # 'revision' table ########################## def revision_missing(self, ids: List[bytes]) -> Iterable[bytes]: missing = [] for id_ in ids: if self._revisions.get_from_primary_key((id_,)) is None: missing.append(id_) return missing def revision_add_one(self, revision: RevisionRow) -> None: self._revisions.insert(revision) self.increment_counter("revision", 1) def revision_get_ids(self, revision_ids) -> Iterable[int]: for id_ in revision_ids: if self._revisions.get_from_primary_key((id_,)) is not None: yield id_ def revision_get(self, revision_ids: List[Sha1Git]) -> Iterable[RevisionRow]: for id_ in revision_ids: row = self._revisions.get_from_primary_key((id_,)) if row: yield row def revision_get_random(self) -> Optional[RevisionRow]: return self._revisions.get_random() ########################## # 'revision_parent' table ########################## def revision_parent_add_one(self, revision_parent: RevisionParentRow) -> None: self._revision_parents.insert(revision_parent) def revision_parent_get(self, revision_id: Sha1Git) -> Iterable[bytes]: for parent in self._revision_parents.get_from_partition_key((revision_id,)): yield parent.parent_id ########################## # 'release' table ########################## def release_missing(self, ids: List[bytes]) -> List[bytes]: missing = [] for id_ in ids: if self._releases.get_from_primary_key((id_,)) is None: missing.append(id_) return missing def release_add_one(self, release: ReleaseRow) -> None: self._releases.insert(release) self.increment_counter("release", 1) def release_get(self, release_ids: List[str]) -> Iterable[ReleaseRow]: for id_ in release_ids: row = self._releases.get_from_primary_key((id_,)) if row: yield row def release_get_random(self) -> Optional[ReleaseRow]: return self._releases.get_random() ########################## # 'snapshot' table ########################## def snapshot_missing(self, ids: List[bytes]) -> List[bytes]: missing = [] for id_ in ids: if self._snapshots.get_from_primary_key((id_,)) is None: missing.append(id_) return missing def snapshot_add_one(self, snapshot: SnapshotRow) -> None: self._snapshots.insert(snapshot) self.increment_counter("snapshot", 1) def snapshot_get_random(self) -> Optional[SnapshotRow]: return self._snapshots.get_random() ########################## # 'snapshot_branch' table ########################## def snapshot_branch_add_one(self, branch: SnapshotBranchRow) -> None: self._snapshot_branches.insert(branch) def snapshot_count_branches(self, snapshot_id: Sha1Git) -> Dict[Optional[str], int]: """Returns a dictionary from type names to the number of branches of that type.""" counts: Dict[Optional[str], int] = defaultdict(int) for branch in self._snapshot_branches.get_from_partition_key((snapshot_id,)): if branch.target_type is None: target_type = None else: target_type = branch.target_type counts[target_type] += 1 return counts def snapshot_branch_get( self, snapshot_id: Sha1Git, from_: bytes, limit: int ) -> Iterable[SnapshotBranchRow]: count = 0 for branch in self._snapshot_branches.get_from_partition_key((snapshot_id,)): if branch.name >= from_: count += 1 yield branch if count >= limit: break ########################## # 'origin' table ########################## def origin_add_one(self, origin: OriginRow) -> None: self._origins.insert(origin) self.increment_counter("origin", 1) def origin_get_by_sha1(self, sha1: bytes) -> Iterable[OriginRow]: return self._origins.get_from_partition_key((sha1,)) def origin_get_by_url(self, url: str) -> Iterable[OriginRow]: return self.origin_get_by_sha1(origin_url_to_sha1(url)) def origin_list( self, start_token: int, limit: int ) -> Iterable[Tuple[int, OriginRow]]: """Returns an iterable of (token, origin)""" matches = [ (token, row) for (token, partition) in self._origins.data.items() for (clustering_key, row) in partition.items() if token >= start_token ] matches.sort() return matches[0:limit] def origin_iter_all(self) -> Iterable[OriginRow]: return ( row for (token, partition) in self._origins.data.items() for (clustering_key, row) in partition.items() ) def origin_generate_unique_visit_id(self, origin_url: str) -> int: origin = list(self.origin_get_by_url(origin_url))[0] visit_id = origin.next_visit_id origin.next_visit_id += 1 return visit_id ########################## # 'origin_visit' table ########################## def origin_visit_get( self, origin_url: str, last_visit: Optional[int], limit: int, order: ListOrder, ) -> Iterable[OriginVisitRow]: visits = list(self._origin_visits.get_from_partition_key((origin_url,))) if last_visit is not None: if order == ListOrder.ASC: visits = [v for v in visits if v.visit > last_visit] else: visits = [v for v in visits if v.visit < last_visit] visits.sort(key=lambda v: v.visit, reverse=order == ListOrder.DESC) visits = visits[0:limit] return visits def origin_visit_add_one(self, visit: OriginVisitRow) -> None: self._origin_visits.insert(visit) self.increment_counter("origin_visit", 1) def origin_visit_get_one( self, origin_url: str, visit_id: int ) -> Optional[OriginVisitRow]: return self._origin_visits.get_from_primary_key((origin_url, visit_id)) def origin_visit_get_all(self, origin_url: str) -> Iterable[OriginVisitRow]: return self._origin_visits.get_from_partition_key((origin_url,)) def origin_visit_iter(self, start_token: int) -> Iterator[OriginVisitRow]: """Returns all origin visits in order from this token, and wraps around the token space.""" return ( row for (token, partition) in self._origin_visits.data.items() for (clustering_key, row) in partition.items() ) ########################## # 'origin_visit_status' table ########################## def origin_visit_status_get_range( self, origin: str, visit: int, date_from: Optional[datetime.datetime], limit: int, order: ListOrder, ) -> Iterable[OriginVisitStatusRow]: statuses = list(self.origin_visit_status_get(origin, visit)) if date_from is not None: if order == ListOrder.ASC: statuses = [s for s in statuses if s.date >= date_from] else: statuses = [s for s in statuses if s.date <= date_from] statuses.sort(key=lambda s: s.date, reverse=order == ListOrder.DESC) return statuses[0:limit] def origin_visit_status_add_one(self, visit_update: OriginVisitStatusRow) -> None: self._origin_visit_statuses.insert(visit_update) self.increment_counter("origin_visit_status", 1) def origin_visit_status_get_latest( self, origin: str, visit: int, ) -> Optional[OriginVisitStatusRow]: """Given an origin visit id, return its latest origin_visit_status """ return next(self.origin_visit_status_get(origin, visit), None) def origin_visit_status_get( self, origin: str, visit: int, ) -> Iterator[OriginVisitStatusRow]: """Return all origin visit statuses for a given visit """ statuses = [ s for s in self._origin_visit_statuses.get_from_partition_key((origin,)) if s.visit == visit ] statuses.sort(key=lambda s: s.date, reverse=True) return iter(statuses) ########################## # 'metadata_authority' table ########################## def metadata_authority_add(self, authority: MetadataAuthorityRow): self._metadata_authorities.insert(authority) self.increment_counter("metadata_authority", 1) def metadata_authority_get(self, type, url) -> Optional[MetadataAuthorityRow]: return self._metadata_authorities.get_from_primary_key((url, type)) ########################## # 'metadata_fetcher' table ########################## def metadata_fetcher_add(self, fetcher: MetadataFetcherRow): self._metadata_fetchers.insert(fetcher) self.increment_counter("metadata_fetcher", 1) def metadata_fetcher_get(self, name, version) -> Optional[MetadataAuthorityRow]: return self._metadata_fetchers.get_from_primary_key((name, version)) ######################### # 'raw_extrinsic_metadata' table ######################### def raw_extrinsic_metadata_add(self, raw_extrinsic_metadata): self._raw_extrinsic_metadata.insert(raw_extrinsic_metadata) self.increment_counter("raw_extrinsic_metadata", 1) def raw_extrinsic_metadata_get_after_date( self, id: str, authority_type: str, authority_url: str, after: datetime.datetime, ) -> Iterable[RawExtrinsicMetadataRow]: metadata = self.raw_extrinsic_metadata_get(id, authority_type, authority_url) return (m for m in metadata if m.discovery_date > after) def raw_extrinsic_metadata_get_after_date_and_fetcher( self, id: str, authority_type: str, authority_url: str, after_date: datetime.datetime, after_fetcher_name: str, after_fetcher_version: str, ) -> Iterable[RawExtrinsicMetadataRow]: metadata = self._raw_extrinsic_metadata.get_from_partition_key((id,)) after_tuple = (after_date, after_fetcher_name, after_fetcher_version) return ( m for m in metadata if m.authority_type == authority_type and m.authority_url == authority_url and (m.discovery_date, m.fetcher_name, m.fetcher_version) > after_tuple ) def raw_extrinsic_metadata_get( self, id: str, authority_type: str, authority_url: str ) -> Iterable[RawExtrinsicMetadataRow]: metadata = self._raw_extrinsic_metadata.get_from_partition_key((id,)) return ( m for m in metadata if m.authority_type == authority_type and m.authority_url == authority_url ) class InMemoryStorage(CassandraStorage): _cql_runner: InMemoryCqlRunner # type: ignore def __init__(self, journal_writer=None): self.reset() self.journal_writer = JournalWriter(journal_writer) def reset(self): self._cql_runner = InMemoryCqlRunner() - self.objstorage = ObjStorage({"cls": "memory", "args": {}}) + self.objstorage = ObjStorage({"cls": "memory"}) def check_config(self, *, check_write: bool) -> bool: return True diff --git a/swh/storage/pytest_plugin.py b/swh/storage/pytest_plugin.py index 70e7ac47..01565700 100644 --- a/swh/storage/pytest_plugin.py +++ b/swh/storage/pytest_plugin.py @@ -1,200 +1,200 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import glob from os import environ, path import subprocess from typing import Union import pytest from pytest_postgresql import factories from pytest_postgresql.janitor import DatabaseJanitor, Version, psycopg2 from swh.core.utils import numfile_sortkey as sortkey import swh.storage from swh.storage import get_storage from swh.storage.tests.storage_data import StorageData SQL_DIR = path.join(path.dirname(swh.storage.__file__), "sql") environ["LC_ALL"] = "C.UTF-8" DUMP_FILES = path.join(SQL_DIR, "*.sql") # the postgres_fact factory fixture below is mostly a copy of the code # from pytest-postgresql. We need a custom version here to be able to # specify our version of the DBJanitor we use. def postgresql_fact(process_fixture_name, db_name=None, dump_files=DUMP_FILES): @pytest.fixture def postgresql_factory(request): """ Fixture factory for PostgreSQL. :param FixtureRequest request: fixture request object :rtype: psycopg2.connection :returns: postgresql client """ config = factories.get_config(request) if not psycopg2: raise ImportError("No module named psycopg2. Please install it.") proc_fixture = request.getfixturevalue(process_fixture_name) # _, config = try_import('psycopg2', request) pg_host = proc_fixture.host pg_port = proc_fixture.port pg_user = proc_fixture.user pg_options = proc_fixture.options pg_db = db_name or config["dbname"] with SwhDatabaseJanitor( pg_user, pg_host, pg_port, pg_db, proc_fixture.version, dump_files=dump_files, ): connection = psycopg2.connect( dbname=pg_db, user=pg_user, host=pg_host, port=pg_port, options=pg_options, ) yield connection connection.close() return postgresql_factory swh_storage_postgresql = postgresql_fact("postgresql_proc", db_name="storage") @pytest.fixture def swh_storage_backend_config(swh_storage_postgresql): """Basic pg storage configuration with no journal collaborator (to avoid pulling optional dependency on clients of this fixture) """ yield { "cls": "local", "db": swh_storage_postgresql.dsn, - "objstorage": {"cls": "memory", "args": {}}, + "objstorage": {"cls": "memory"}, "check_config": {"check_write": True}, } @pytest.fixture def swh_storage(swh_storage_backend_config): return get_storage(**swh_storage_backend_config) # This version of the DatabaseJanitor implement a different setup/teardown # behavior than than the stock one: instead of dropping, creating and # initializing the database for each test, it create and initialize the db only # once, then it truncate the tables. This is needed to have acceptable test # performances. class SwhDatabaseJanitor(DatabaseJanitor): def __init__( self, user: str, host: str, port: str, db_name: str, version: Union[str, float, Version], dump_files: str = DUMP_FILES, ) -> None: super().__init__(user, host, port, db_name, version) self.dump_files = sorted(glob.glob(dump_files), key=sortkey) def db_setup(self): conninfo = ( f"host={self.host} user={self.user} port={self.port} dbname={self.db_name}" ) for fname in self.dump_files: subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, "-f", fname, ] ) def db_reset(self): with psycopg2.connect( dbname=self.db_name, user=self.user, host=self.host, port=self.port, ) as cnx: with cnx.cursor() as cur: cur.execute( "SELECT table_name FROM information_schema.tables " "WHERE table_schema = %s", ("public",), ) tables = set(table for (table,) in cur.fetchall()) - {"dbversion"} for table in tables: cur.execute("truncate table %s cascade" % table) cur.execute( "SELECT sequence_name FROM information_schema.sequences " "WHERE sequence_schema = %s", ("public",), ) seqs = set(seq for (seq,) in cur.fetchall()) for seq in seqs: cur.execute("ALTER SEQUENCE %s RESTART;" % seq) cnx.commit() def init(self): with self.cursor() as cur: cur.execute( "SELECT COUNT(1) FROM pg_database WHERE datname=%s;", (self.db_name,) ) db_exists = cur.fetchone()[0] == 1 if db_exists: cur.execute( "UPDATE pg_database SET datallowconn=true " "WHERE datname = %s;", (self.db_name,), ) if db_exists: self.db_reset() else: with self.cursor() as cur: cur.execute('CREATE DATABASE "{}";'.format(self.db_name)) self.db_setup() def drop(self): pid_column = "pid" with self.cursor() as cur: cur.execute( "UPDATE pg_database SET datallowconn=false " "WHERE datname = %s;", (self.db_name,), ) cur.execute( "SELECT pg_terminate_backend(pg_stat_activity.{})" "FROM pg_stat_activity " "WHERE pg_stat_activity.datname = %s;".format(pid_column), (self.db_name,), ) @pytest.fixture def sample_data() -> StorageData: """Pre-defined sample storage object data to manipulate Returns: StorageData whose attribute keys are data model objects. Either multiple objects: contents, directories, revisions, releases, ... or simple ones: content, directory, revision, release, ... """ return StorageData() diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py index 00bebb3b..4ec1ffa3 100644 --- a/swh/storage/tests/test_cassandra.py +++ b/swh/storage/tests/test_cassandra.py @@ -1,415 +1,415 @@ # Copyright (C) 2018-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import os import signal import socket import subprocess import time from typing import Dict import attr import pytest from swh.core.api.classes import stream_results from swh.storage import get_storage from swh.storage.cassandra import create_keyspace from swh.storage.cassandra.model import ContentRow from swh.storage.cassandra.schema import HASH_ALGORITHMS, TABLES from swh.storage.tests.storage_tests import ( TestStorageGeneratedData as _TestStorageGeneratedData, ) from swh.storage.tests.storage_tests import TestStorage as _TestStorage from swh.storage.utils import now CONFIG_TEMPLATE = """ data_file_directories: - {data_dir}/data commitlog_directory: {data_dir}/commitlog hints_directory: {data_dir}/hints saved_caches_directory: {data_dir}/saved_caches commitlog_sync: periodic commitlog_sync_period_in_ms: 1000000 partitioner: org.apache.cassandra.dht.Murmur3Partitioner endpoint_snitch: SimpleSnitch seed_provider: - class_name: org.apache.cassandra.locator.SimpleSeedProvider parameters: - seeds: "127.0.0.1" storage_port: {storage_port} native_transport_port: {native_transport_port} start_native_transport: true listen_address: 127.0.0.1 enable_user_defined_functions: true # speed-up by disabling period saving to disk key_cache_save_period: 0 row_cache_save_period: 0 trickle_fsync: false commitlog_sync_period_in_ms: 100000 """ def free_port(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.bind(("127.0.0.1", 0)) port = sock.getsockname()[1] sock.close() return port def wait_for_peer(addr, port): wait_until = time.time() + 20 while time.time() < wait_until: try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect((addr, port)) except ConnectionRefusedError: time.sleep(0.1) else: sock.close() return True return False @pytest.fixture(scope="session") def cassandra_cluster(tmpdir_factory): cassandra_conf = tmpdir_factory.mktemp("cassandra_conf") cassandra_data = tmpdir_factory.mktemp("cassandra_data") cassandra_log = tmpdir_factory.mktemp("cassandra_log") native_transport_port = free_port() storage_port = free_port() jmx_port = free_port() with open(str(cassandra_conf.join("cassandra.yaml")), "w") as fd: fd.write( CONFIG_TEMPLATE.format( data_dir=str(cassandra_data), storage_port=storage_port, native_transport_port=native_transport_port, ) ) if os.environ.get("SWH_CASSANDRA_LOG"): stdout = stderr = None else: stdout = stderr = subprocess.DEVNULL cassandra_bin = os.environ.get("SWH_CASSANDRA_BIN", "/usr/sbin/cassandra") proc = subprocess.Popen( [ cassandra_bin, "-Dcassandra.config=file://%s/cassandra.yaml" % cassandra_conf, "-Dcassandra.logdir=%s" % cassandra_log, "-Dcassandra.jmx.local.port=%d" % jmx_port, "-Dcassandra-foreground=yes", ], start_new_session=True, env={ "MAX_HEAP_SIZE": "300M", "HEAP_NEWSIZE": "50M", "JVM_OPTS": "-Xlog:gc=error:file=%s/gc.log" % cassandra_log, }, stdout=stdout, stderr=stderr, ) running = wait_for_peer("127.0.0.1", native_transport_port) if running: yield (["127.0.0.1"], native_transport_port) if not running or os.environ.get("SWH_CASSANDRA_LOG"): debug_log_path = str(cassandra_log.join("debug.log")) if os.path.exists(debug_log_path): with open(debug_log_path) as fd: print(fd.read()) if not running: raise Exception("cassandra process stopped unexpectedly.") pgrp = os.getpgid(proc.pid) os.killpg(pgrp, signal.SIGKILL) class RequestHandler: def on_request(self, rf): if hasattr(rf.message, "query"): print() print(rf.message.query) @pytest.fixture(scope="session") def keyspace(cassandra_cluster): (hosts, port) = cassandra_cluster keyspace = os.urandom(10).hex() create_keyspace(hosts, keyspace, port) return keyspace # tests are executed using imported classes (TestStorage and # TestStorageGeneratedData) using overloaded swh_storage fixture # below @pytest.fixture def swh_storage_backend_config(cassandra_cluster, keyspace): (hosts, port) = cassandra_cluster storage_config = dict( cls="cassandra", hosts=hosts, port=port, keyspace=keyspace, - journal_writer={"cls": "memory",}, - objstorage={"cls": "memory", "args": {},}, + journal_writer={"cls": "memory"}, + objstorage={"cls": "memory"}, ) yield storage_config storage = get_storage(**storage_config) for table in TABLES: storage._cql_runner._session.execute('TRUNCATE TABLE "%s"' % table) storage._cql_runner._cluster.shutdown() @pytest.mark.cassandra class TestCassandraStorage(_TestStorage): def test_content_add_murmur3_collision(self, swh_storage, mocker, sample_data): """The Murmur3 token is used as link from index tables to the main table; and non-matching contents with colliding murmur3-hash are filtered-out when reading the main table. This test checks the content methods do filter out these collision. """ called = 0 cont, cont2 = sample_data.contents[:2] # always return a token def mock_cgtfsh(algo, hash_): nonlocal called called += 1 assert algo in ("sha1", "sha1_git") return [123456] mocker.patch.object( swh_storage._cql_runner, "content_get_tokens_from_single_hash", mock_cgtfsh, ) # For all tokens, always return cont def mock_cgft(token): nonlocal called called += 1 return [ ContentRow( length=10, ctime=datetime.datetime.now(), status="present", **{algo: getattr(cont, algo) for algo in HASH_ALGORITHMS}, ) ] mocker.patch.object( swh_storage._cql_runner, "content_get_from_token", mock_cgft ) actual_result = swh_storage.content_add([cont2]) assert called == 4 assert actual_result == { "content:add": 1, "content:add:bytes": cont2.length, } def test_content_get_metadata_murmur3_collision( self, swh_storage, mocker, sample_data ): """The Murmur3 token is used as link from index tables to the main table; and non-matching contents with colliding murmur3-hash are filtered-out when reading the main table. This test checks the content methods do filter out these collisions. """ called = 0 cont, cont2 = [attr.evolve(c, ctime=now()) for c in sample_data.contents[:2]] # always return a token def mock_cgtfsh(algo, hash_): nonlocal called called += 1 assert algo in ("sha1", "sha1_git") return [123456] mocker.patch.object( swh_storage._cql_runner, "content_get_tokens_from_single_hash", mock_cgtfsh, ) # For all tokens, always return cont and cont2 cols = list(set(cont.to_dict()) - {"data"}) def mock_cgft(token): nonlocal called called += 1 return [ ContentRow(**{col: getattr(cont, col) for col in cols},) for cont in [cont, cont2] ] mocker.patch.object( swh_storage._cql_runner, "content_get_from_token", mock_cgft ) actual_result = swh_storage.content_get([cont.sha1]) assert called == 2 # dropping extra column not returned expected_cont = attr.evolve(cont, data=None) # but cont2 should be filtered out assert actual_result == [expected_cont] def test_content_find_murmur3_collision(self, swh_storage, mocker, sample_data): """The Murmur3 token is used as link from index tables to the main table; and non-matching contents with colliding murmur3-hash are filtered-out when reading the main table. This test checks the content methods do filter out these collisions. """ called = 0 cont, cont2 = [attr.evolve(c, ctime=now()) for c in sample_data.contents[:2]] # always return a token def mock_cgtfsh(algo, hash_): nonlocal called called += 1 assert algo in ("sha1", "sha1_git") return [123456] mocker.patch.object( swh_storage._cql_runner, "content_get_tokens_from_single_hash", mock_cgtfsh, ) # For all tokens, always return cont and cont2 cols = list(set(cont.to_dict()) - {"data"}) def mock_cgft(token): nonlocal called called += 1 return [ ContentRow(**{col: getattr(cont, col) for col in cols}) for cont in [cont, cont2] ] mocker.patch.object( swh_storage._cql_runner, "content_get_from_token", mock_cgft ) expected_content = attr.evolve(cont, data=None) actual_result = swh_storage.content_find({"sha1": cont.sha1}) assert called == 2 # but cont2 should be filtered out assert actual_result == [expected_content] def test_content_get_partition_murmur3_collision( self, swh_storage, mocker, sample_data ): """The Murmur3 token is used as link from index tables to the main table; and non-matching contents with colliding murmur3-hash are filtered-out when reading the main table. This test checks the content_get_partition endpoints return all contents, even the collisions. """ called = 0 rows: Dict[int, Dict] = {} for tok, content in enumerate(sample_data.contents): cont = attr.evolve(content, data=None, ctime=now()) row_d = {**cont.to_dict(), "tok": tok} rows[tok] = row_d # For all tokens, always return cont def mock_content_get_token_range(range_start, range_end, limit): nonlocal called called += 1 for tok in list(rows.keys()) * 3: # yield multiple times the same tok row_d = dict(rows[tok].items()) row_d.pop("tok") yield (tok, ContentRow(**row_d)) mocker.patch.object( swh_storage._cql_runner, "content_get_token_range", mock_content_get_token_range, ) actual_results = list( stream_results( swh_storage.content_get_partition, partition_id=0, nb_partitions=1 ) ) assert called > 0 # everything is listed, even collisions assert len(actual_results) == 3 * len(sample_data.contents) # as we duplicated the returned results, dropping duplicate should yield # the original length assert len(set(actual_results)) == len(sample_data.contents) @pytest.mark.skip("content_update is not yet implemented for Cassandra") def test_content_update(self): pass @pytest.mark.skip( 'The "person" table of the pgsql is a legacy thing, and not ' "supported by the cassandra backend." ) def test_person_fullname_unicity(self): pass @pytest.mark.skip( 'The "person" table of the pgsql is a legacy thing, and not ' "supported by the cassandra backend." ) def test_person_get(self): pass @pytest.mark.skip("Not supported by Cassandra") def test_origin_count(self): pass @pytest.mark.cassandra class TestCassandraStorageGeneratedData(_TestStorageGeneratedData): @pytest.mark.skip("Not supported by Cassandra") def test_origin_count(self): pass @pytest.mark.skip("Not supported by Cassandra") def test_origin_count_with_visit_no_visits(self): pass @pytest.mark.skip("Not supported by Cassandra") def test_origin_count_with_visit_with_visits_and_snapshot(self): pass @pytest.mark.skip("Not supported by Cassandra") def test_origin_count_with_visit_with_visits_no_snapshot(self): pass diff --git a/swh/storage/tests/test_init.py b/swh/storage/tests/test_init.py index d7ed4fd9..e448b278 100644 --- a/swh/storage/tests/test_init.py +++ b/swh/storage/tests/test_init.py @@ -1,236 +1,232 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from unittest.mock import patch import pytest from swh.core.pytest_plugin import RPCTestAdapter from swh.storage import get_storage from swh.storage.api import client, server from swh.storage.buffer import BufferingProxyStorage from swh.storage.filter import FilteringProxyStorage from swh.storage.in_memory import InMemoryStorage from swh.storage.postgresql.storage import Storage as DbStorage from swh.storage.retry import RetryingProxyStorage STORAGES = [ pytest.param(cls, real_class, kwargs, id=cls) for (cls, real_class, kwargs) in [ ("remote", client.RemoteStorage, {"url": "url"}), ("memory", InMemoryStorage, {}), ( "local", DbStorage, - {"db": "postgresql://db", "objstorage": {"cls": "memory", "args": {}}}, + {"db": "postgresql://db", "objstorage": {"cls": "memory"}}, ), ("filter", FilteringProxyStorage, {"storage": {"cls": "memory"}}), ("buffer", BufferingProxyStorage, {"storage": {"cls": "memory"}}), ("retry", RetryingProxyStorage, {"storage": {"cls": "memory"}}), ] ] @pytest.mark.parametrize("cls,real_class,args", STORAGES) @patch("swh.storage.postgresql.storage.psycopg2.pool") def test_get_storage(mock_pool, cls, real_class, args): """Instantiating an existing storage should be ok """ mock_pool.ThreadedConnectionPool.return_value = None actual_storage = get_storage(cls, **args) assert actual_storage is not None assert isinstance(actual_storage, real_class) @pytest.mark.parametrize("cls,real_class,args", STORAGES) @patch("swh.storage.postgresql.storage.psycopg2.pool") def test_get_storage_legacy_args(mock_pool, cls, real_class, args): """Instantiating an existing storage should be ok even with the legacy explicit 'args' keys """ mock_pool.ThreadedConnectionPool.return_value = None with pytest.warns(DeprecationWarning): actual_storage = get_storage(cls, args=args) assert actual_storage is not None assert isinstance(actual_storage, real_class) def test_get_storage_failure(): """Instantiating an unknown storage should raise """ with pytest.raises(ValueError, match="Unknown storage class `unknown`"): - get_storage("unknown", args=[]) + get_storage("unknown") def test_get_storage_pipeline(): config = { "cls": "pipeline", "steps": [ {"cls": "filter",}, {"cls": "buffer", "min_batch_size": {"content": 10,},}, {"cls": "memory",}, ], } storage = get_storage(**config) assert isinstance(storage, FilteringProxyStorage) assert isinstance(storage.storage, BufferingProxyStorage) assert isinstance(storage.storage.storage, InMemoryStorage) def test_get_storage_pipeline_legacy_args(): config = { "cls": "pipeline", "steps": [ {"cls": "filter",}, {"cls": "buffer", "args": {"min_batch_size": {"content": 10,},}}, {"cls": "memory",}, ], } with pytest.warns(DeprecationWarning): storage = get_storage(**config) assert isinstance(storage, FilteringProxyStorage) assert isinstance(storage.storage, BufferingProxyStorage) assert isinstance(storage.storage.storage, InMemoryStorage) # get_storage's check_config argument tests # the "remote" and "pipeline" cases are tested in dedicated test functions below @pytest.mark.parametrize( "cls,real_class,kwargs", [x for x in STORAGES if x.id not in ("remote", "local")] ) def test_get_storage_check_config(cls, real_class, kwargs, monkeypatch): """Instantiating an existing storage with check_config should be ok """ check_backend_check_config(monkeypatch, dict(cls=cls, **kwargs)) @patch("swh.storage.postgresql.storage.psycopg2.pool") def test_get_storage_local_check_config(mock_pool, monkeypatch): """Instantiating a local storage with check_config should be ok """ mock_pool.ThreadedConnectionPool.return_value = None check_backend_check_config( monkeypatch, - { - "cls": "local", - "db": "postgresql://db", - "objstorage": {"cls": "memory", "args": {}}, - }, + {"cls": "local", "db": "postgresql://db", "objstorage": {"cls": "memory"}}, backend_storage_cls=DbStorage, ) def test_get_storage_pipeline_check_config(monkeypatch): """Test that the check_config option works as intended for a pipelined storage""" config = { "cls": "pipeline", "steps": [ {"cls": "filter",}, {"cls": "buffer", "min_batch_size": {"content": 10,},}, {"cls": "memory",}, ], } check_backend_check_config( monkeypatch, config, ) def test_get_storage_remote_check_config(monkeypatch): """Test that the check_config option works as intended for a remote storage""" monkeypatch.setattr( server, "storage", get_storage(cls="memory", journal_writer={"cls": "memory"}) ) test_client = server.app.test_client() class MockedRemoteStorage(client.RemoteStorage): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.session.adapters.clear() self.session.mount("mock://", RPCTestAdapter(test_client)) monkeypatch.setattr(client, "RemoteStorage", MockedRemoteStorage) config = { "cls": "remote", "url": "mock://example.com", } check_backend_check_config( monkeypatch, config, ) def check_backend_check_config( monkeypatch, config, backend_storage_cls=InMemoryStorage ): """Check the staged/indirect storage (pipeline or remote) works as desired with regard to the check_config option of the get_storage() factory function. If set, the check_config argument is used to call the Storage.check_config() at instantiation time in the get_storage() factory function. This is supposed to be passed through each step of the Storage pipeline until it reached the actual backend's (typically in memory or local) check_config() method which will perform the verification for read/write access to the backend storage. monkeypatch is supposed to be the monkeypatch pytest fixture to be used from the calling test_ function. config is the config dict passed to get_storage() backend_storage_cls is the class of the backend storage to be mocked to simulate the check_config behavior; it should then be the class of the actual backend storage defined in the `config`. """ access = None def mockcheck(self, check_write=False): if access == "none": return False if access == "read": return check_write is False if access == "write": return True monkeypatch.setattr(backend_storage_cls, "check_config", mockcheck) # simulate no read nor write access to the underlying (memory) storage access = "none" # by default, no check, so no complain assert get_storage(**config) # if asked to check, complain with pytest.raises(EnvironmentError): get_storage(check_config={"check_write": False}, **config) with pytest.raises(EnvironmentError): get_storage(check_config={"check_write": True}, **config) # simulate no write access to the underlying (memory) storage access = "read" # by default, no check so no complain assert get_storage(**config) # if asked to check for read access, no complain get_storage(check_config={"check_write": False}, **config) # if asked to check for write access, complain with pytest.raises(EnvironmentError): get_storage(check_config={"check_write": True}, **config) # simulate read & write access to the underlying (memory) storage access = "write" # by default, no check so no complain assert get_storage(**config) # if asked to check for read access, no complain get_storage(check_config={"check_write": False}, **config) # if asked to check for write access, no complain get_storage(check_config={"check_write": True}, **config)