diff --git a/mypy.ini b/mypy.ini index 0e12c6de..bb4da196 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,42 +1,48 @@ [mypy] namespace_packages = True # due to the conditional import logic on swh.journal, in some cases a specific # type: ignore is needed, in other it isn't... warn_unused_ignores = False # support for sqlalchemy magic: see https://github.com/dropbox/sqlalchemy-stubs plugins = sqlmypy # 3rd party libraries without stubs (yet) [mypy-cassandra.*] ignore_missing_imports = True +[mypy-confluent_kafka.*] +ignore_missing_imports = True + # only shipped indirectly via hypothesis [mypy-django.*] ignore_missing_imports = True [mypy-pkg_resources.*] ignore_missing_imports = True [mypy-psycopg2.*] ignore_missing_imports = True [mypy-pytest.*] ignore_missing_imports = True +[mypy-pytest_kafka.*] +ignore_missing_imports = True + [mypy-systemd.daemon.*] ignore_missing_imports = True [mypy-tenacity.*] ignore_missing_imports = True # temporary work-around for landing typing support in spite of the current # journal<->storage dependency loop [mypy-swh.journal.*] ignore_missing_imports = True [mypy-pytest_postgresql.*] ignore_missing_imports = True diff --git a/requirements-swh-journal.txt b/requirements-swh-journal.txt index e43a37f4..6460fde2 100644 --- a/requirements-swh-journal.txt +++ b/requirements-swh-journal.txt @@ -1 +1 @@ -swh.journal >= 0.0.30 +swh.journal >= 0.0.31 diff --git a/swh/storage/cli.py b/swh/storage/cli.py index 1dc7b656..36ed65fc 100644 --- a/swh/storage/cli.py +++ b/swh/storage/cli.py @@ -1,144 +1,193 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import functools import logging import os import warnings import click from swh.core import config from swh.core.cli import CONTEXT_SETTINGS +from swh.journal.cli import get_journal_client +from swh.storage import get_storage from swh.storage.api.server import load_and_check_config, app +try: + from systemd.daemon import notify +except ImportError: + notify = None + @click.group(name="storage", context_settings=CONTEXT_SETTINGS) @click.option( "--config-file", "-C", default=None, type=click.Path(exists=True, dir_okay=False,), help="Configuration file.", ) @click.pass_context def storage(ctx, config_file): """Software Heritage Storage tools.""" if not config_file: config_file = os.environ.get("SWH_CONFIG_FILENAME") if config_file: if not os.path.exists(config_file): raise ValueError("%s does not exist" % config_file) conf = config.read(config_file) else: conf = {} ctx.ensure_object(dict) ctx.obj["config"] = conf @storage.command(name="rpc-serve") @click.argument("config-path", default=None, required=False) @click.option( "--host", default="0.0.0.0", metavar="IP", show_default=True, help="Host ip address to bind the server on", ) @click.option( "--port", default=5002, type=click.INT, metavar="PORT", show_default=True, help="Binding port of the server", ) @click.option( "--debug/--no-debug", default=True, help="Indicates if the server should run in debug mode", ) @click.pass_context def serve(ctx, config_path, host, port, debug): """Software Heritage Storage RPC server. Do NOT use this in a production environment. 
""" if "log_level" in ctx.obj: logging.getLogger("werkzeug").setLevel(ctx.obj["log_level"]) if config_path: # for bw compat warnings.warn( "The `config_path` argument of the `swh storage rpc-server` is now " "deprecated. Please use the --config option of `swh storage` instead.", DeprecationWarning, ) api_cfg = load_and_check_config(config_path, type="any") app.config.update(api_cfg) else: app.config.update(ctx.obj["config"]) app.run(host, port=int(port), debug=bool(debug)) @storage.command() @click.argument("object_type") @click.option("--start-object", default=None) @click.option("--end-object", default=None) @click.option("--dry-run", is_flag=True, default=False) @click.pass_context def backfill(ctx, object_type, start_object, end_object, dry_run): """Run the backfiller The backfiller list objects from a Storage and produce journal entries from there. Typically used to rebuild a journal or compensate for missing objects in a journal (eg. due to a downtime of this later). The configuration file requires the following entries: - brokers: a list of kafka endpoints (the journal) in which entries will be added. - storage_dbconn: URL to connect to the storage DB. - prefix: the prefix of the topics (topics will be .). - client_id: the kafka client ID. """ # for "lazy" loading from swh.storage.backfill import JournalBackfiller try: from systemd.daemon import notify except ImportError: notify = None conf = ctx.obj["config"] backfiller = JournalBackfiller(conf) if notify: notify("READY=1") try: backfiller.run( object_type=object_type, start_object=start_object, end_object=end_object, dry_run=dry_run, ) except KeyboardInterrupt: if notify: notify("STOPPING=1") ctx.exit(0) +@storage.command() +@click.option( + "--stop-after-objects", + "-n", + default=None, + type=int, + help="Stop after processing this many objects. Default is to " "run forever.", +) +@click.pass_context +def replay(ctx, stop_after_objects): + """Fill a Storage by reading a Journal. + + There can be several 'replayers' filling a Storage as long as they use + the same `group-id`. + """ + from swh.storage.replay import process_replay_objects + + conf = ctx.obj["config"] + try: + storage = get_storage(**conf.pop("storage")) + except KeyError: + ctx.fail("You must have a storage configured in your config file.") + + client = get_journal_client(ctx, stop_after_objects=stop_after_objects) + worker_fn = functools.partial(process_replay_objects, storage=storage) + + if notify: + notify("READY=1") + + try: + client.process(worker_fn) + except KeyboardInterrupt: + ctx.exit(0) + else: + print("Done.") + finally: + if notify: + notify("STOPPING=1") + client.close() + + def main(): logging.basicConfig() return serve(auto_envvar_prefix="SWH_STORAGE") if __name__ == "__main__": main() diff --git a/swh/storage/fixer.py b/swh/storage/fixer.py new file mode 100644 index 00000000..7322b352 --- /dev/null +++ b/swh/storage/fixer.py @@ -0,0 +1,293 @@ +import copy +import logging +from typing import Any, Dict, List, Optional +from swh.model.identifiers import normalize_timestamp + +logger = logging.getLogger(__name__) + + +def _fix_content(content: Dict[str, Any]) -> Dict[str, Any]: + """Filters-out invalid 'perms' key that leaked from swh.model.from_disk + to the journal. 
+ + >>> _fix_content({'perms': 0o100644, 'sha1_git': b'foo'}) + {'sha1_git': b'foo'} + + >>> _fix_content({'sha1_git': b'bar'}) + {'sha1_git': b'bar'} + + """ + content = content.copy() + content.pop("perms", None) + return content + + +def _fix_revision_pypi_empty_string(rev): + """PyPI loader failed to encode empty strings as bytes, see: + swh:1:rev:8f0095ee0664867055d03de9bcc8f95b91d8a2b9 + or https://forge.softwareheritage.org/D1772 + """ + rev = { + **rev, + "author": rev["author"].copy(), + "committer": rev["committer"].copy(), + } + if rev["author"].get("email") == "": + rev["author"]["email"] = b"" + if rev["author"].get("name") == "": + rev["author"]["name"] = b"" + if rev["committer"].get("email") == "": + rev["committer"]["email"] = b"" + if rev["committer"].get("name") == "": + rev["committer"]["name"] = b"" + return rev + + +def _fix_revision_transplant_source(rev): + if rev.get("metadata") and rev["metadata"].get("extra_headers"): + rev = copy.deepcopy(rev) + rev["metadata"]["extra_headers"] = [ + [key, value.encode("ascii")] + if key == "transplant_source" and isinstance(value, str) + else [key, value] + for (key, value) in rev["metadata"]["extra_headers"] + ] + return rev + + +def _check_date(date): + """Returns whether the date can be represented in backends with sane + limits on timestamps and timezones (resp. signed 64-bits and + signed 16 bits), and that microseconds is valid (ie. between 0 and 10^6). + """ + if date is None: + return True + date = normalize_timestamp(date) + return ( + (-(2 ** 63) <= date["timestamp"]["seconds"] < 2 ** 63) + and (0 <= date["timestamp"]["microseconds"] < 10 ** 6) + and (-(2 ** 15) <= date["offset"] < 2 ** 15) + ) + + +def _check_revision_date(rev): + """Exclude revisions with invalid dates. + See https://forge.softwareheritage.org/T1339""" + return _check_date(rev["date"]) and _check_date(rev["committer_date"]) + + +def _fix_revision(revision: Dict[str, Any]) -> Optional[Dict]: + """Fix various legacy revision issues. + + Fix author/committer person: + + >>> from pprint import pprint + >>> date = { + ... 'timestamp': { + ... 'seconds': 1565096932, + ... 'microseconds': 0, + ... }, + ... 'offset': 0, + ... } + >>> rev0 = _fix_revision({ + ... 'id': b'rev-id', + ... 'author': {'fullname': b'', 'name': '', 'email': ''}, + ... 'committer': {'fullname': b'', 'name': '', 'email': ''}, + ... 'date': date, + ... 'committer_date': date, + ... 'type': 'git', + ... 'message': '', + ... 'directory': b'dir-id', + ... 'synthetic': False, + ... }) + >>> rev0['author'] + {'fullname': b'', 'name': b'', 'email': b''} + >>> rev0['committer'] + {'fullname': b'', 'name': b'', 'email': b''} + + Fix type of 'transplant_source' extra headers: + + >>> rev1 = _fix_revision({ + ... 'id': b'rev-id', + ... 'author': {'fullname': b'', 'name': '', 'email': ''}, + ... 'committer': {'fullname': b'', 'name': '', 'email': ''}, + ... 'date': date, + ... 'committer_date': date, + ... 'metadata': { + ... 'extra_headers': [ + ... ['time_offset_seconds', b'-3600'], + ... ['transplant_source', '29c154a012a70f49df983625090434587622b39e'] + ... ]}, + ... 'type': 'git', + ... 'message': '', + ... 'directory': b'dir-id', + ... 'synthetic': False, + ... 
}) + >>> pprint(rev1['metadata']['extra_headers']) + [['time_offset_seconds', b'-3600'], + ['transplant_source', b'29c154a012a70f49df983625090434587622b39e']] + + Revision with invalid date are filtered: + + >>> from copy import deepcopy + >>> invalid_date1 = deepcopy(date) + >>> invalid_date1['timestamp']['microseconds'] = 1000000000 # > 10^6 + >>> rev = _fix_revision({ + ... 'author': {'fullname': b'', 'name': '', 'email': ''}, + ... 'committer': {'fullname': b'', 'name': '', 'email': ''}, + ... 'date': invalid_date1, + ... 'committer_date': date, + ... }) + >>> rev is None + True + + >>> invalid_date2 = deepcopy(date) + >>> invalid_date2['timestamp']['seconds'] = 2**70 # > 10^63 + >>> rev = _fix_revision({ + ... 'author': {'fullname': b'', 'name': '', 'email': ''}, + ... 'committer': {'fullname': b'', 'name': '', 'email': ''}, + ... 'date': invalid_date2, + ... 'committer_date': date, + ... }) + >>> rev is None + True + + >>> invalid_date3 = deepcopy(date) + >>> invalid_date3['offset'] = 2**20 # > 10^15 + >>> rev = _fix_revision({ + ... 'author': {'fullname': b'', 'name': '', 'email': ''}, + ... 'committer': {'fullname': b'', 'name': '', 'email': ''}, + ... 'date': date, + ... 'committer_date': invalid_date3, + ... }) + >>> rev is None + True + + """ # noqa + rev = _fix_revision_pypi_empty_string(revision) + rev = _fix_revision_transplant_source(rev) + if not _check_revision_date(rev): + logger.warning( + "Invalid revision date detected: %(revision)s", {"revision": rev} + ) + return None + return rev + + +def _fix_origin(origin: Dict) -> Dict: + """Fix legacy origin with type which is no longer part of the model. + + >>> from pprint import pprint + >>> pprint(_fix_origin({ + ... 'url': 'http://foo', + ... })) + {'url': 'http://foo'} + >>> pprint(_fix_origin({ + ... 'url': 'http://bar', + ... 'type': 'foo', + ... })) + {'url': 'http://bar'} + + """ + o = origin.copy() + o.pop("type", None) + return o + + +def _fix_origin_visit(visit: Dict) -> Dict: + """Fix various legacy origin visit issues. + + `visit['origin']` is a dict instead of an URL: + + >>> from datetime import datetime, timezone + >>> from pprint import pprint + >>> date = datetime(2020, 2, 27, 14, 39, 19, tzinfo=timezone.utc) + >>> pprint(_fix_origin_visit({ + ... 'origin': {'url': 'http://foo'}, + ... 'date': date, + ... 'type': 'git', + ... 'status': 'ongoing', + ... 'snapshot': None, + ... })) + {'date': datetime.datetime(2020, 2, 27, 14, 39, 19, tzinfo=datetime.timezone.utc), + 'metadata': None, + 'origin': 'http://foo', + 'snapshot': None, + 'status': 'ongoing', + 'type': 'git'} + + `visit['type']` is missing , but `origin['visit']['type']` exists: + + >>> pprint(_fix_origin_visit( + ... {'origin': {'type': 'hg', 'url': 'http://foo'}, + ... 'date': date, + ... 'status': 'ongoing', + ... 'snapshot': None, + ... })) + {'date': datetime.datetime(2020, 2, 27, 14, 39, 19, tzinfo=datetime.timezone.utc), + 'metadata': None, + 'origin': 'http://foo', + 'snapshot': None, + 'status': 'ongoing', + 'type': 'hg'} + + Old visit format (origin_visit with no type) raises: + + >>> _fix_origin_visit({ + ... 'origin': {'url': 'http://foo'}, + ... 'date': date, + ... 'status': 'ongoing', + ... 'snapshot': None + ... }) + Traceback (most recent call last): + ... + ValueError: Old origin visit format detected... + + >>> _fix_origin_visit({ + ... 'origin': 'http://foo', + ... 'date': date, + ... 'status': 'ongoing', + ... 'snapshot': None + ... }) + Traceback (most recent call last): + ... + ValueError: Old origin visit format detected... 
+ + """ # noqa + visit = visit.copy() + if "type" not in visit: + if isinstance(visit["origin"], dict) and "type" in visit["origin"]: + # Very old version of the schema: visits did not have a type, + # but their 'origin' field was a dict with a 'type' key. + visit["type"] = visit["origin"]["type"] + else: + # Very old schema version: 'type' is missing, stop early + + # We expect the journal's origin_visit topic to no longer reference + # such visits. If it does, the replayer must crash so we can fix + # the journal's topic. + raise ValueError(f"Old origin visit format detected: {visit}") + if isinstance(visit["origin"], dict): + # Old version of the schema: visit['origin'] was a dict. + visit["origin"] = visit["origin"]["url"] + if "metadata" not in visit: + visit["metadata"] = None + return visit + + +def fix_objects(object_type: str, objects: List[Dict]) -> List[Dict]: + """ + Fix legacy objects from the journal to bring them up to date with the + latest storage schema. + """ + if object_type == "content": + return [_fix_content(v) for v in objects] + elif object_type == "revision": + revisions = [_fix_revision(v) for v in objects] + return [rev for rev in revisions if rev is not None] + elif object_type == "origin": + return [_fix_origin(v) for v in objects] + elif object_type == "origin_visit": + return [_fix_origin_visit(v) for v in objects] + else: + return objects diff --git a/swh/storage/replay.py b/swh/storage/replay.py new file mode 100644 index 00000000..0a15d08d --- /dev/null +++ b/swh/storage/replay.py @@ -0,0 +1,128 @@ +# Copyright (C) 2019-2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import logging +from typing import Any, Callable, Dict, Iterable, List + +try: + from systemd.daemon import notify +except ImportError: + notify = None + +from swh.core.statsd import statsd +from swh.storage.fixer import fix_objects + +from swh.model.model import ( + BaseContent, + BaseModel, + Content, + Directory, + Origin, + OriginVisit, + Revision, + SkippedContent, + Snapshot, + Release, +) +from swh.storage.exc import HashCollision + +logger = logging.getLogger(__name__) + +GRAPH_OPERATIONS_METRIC = "swh_graph_replayer_operations_total" +GRAPH_DURATION_METRIC = "swh_graph_replayer_duration_seconds" + + +object_converter_fn: Dict[str, Callable[[Dict], BaseModel]] = { + "origin": Origin.from_dict, + "origin_visit": OriginVisit.from_dict, + "snapshot": Snapshot.from_dict, + "revision": Revision.from_dict, + "release": Release.from_dict, + "directory": Directory.from_dict, + "content": Content.from_dict, + "skipped_content": SkippedContent.from_dict, +} + + +def process_replay_objects(all_objects, *, storage): + for (object_type, objects) in all_objects.items(): + logger.debug("Inserting %s %s objects", len(objects), object_type) + with statsd.timed(GRAPH_DURATION_METRIC, tags={"object_type": object_type}): + _insert_objects(object_type, objects, storage) + statsd.increment( + GRAPH_OPERATIONS_METRIC, len(objects), tags={"object_type": object_type} + ) + if notify: + notify("WATCHDOG=1") + + +def collision_aware_content_add( + content_add_fn: Callable[[Iterable[Any]], None], contents: List[BaseContent] +) -> None: + """Add contents to storage. If a hash collision is detected, an error is + logged. Then this adds the other non colliding contents to the storage. 
+ + Args: + content_add_fn: Storage content callable + contents: List of contents or skipped contents to add to storage + + """ + if not contents: + return + colliding_content_hashes: List[Dict[str, Any]] = [] + while True: + try: + content_add_fn(contents) + except HashCollision as e: + colliding_content_hashes.append( + { + "algo": e.algo, + "hash": e.hash_id, # hex hash id + "objects": e.colliding_contents, # hex hashes + } + ) + colliding_hashes = e.colliding_content_hashes() + # Drop the colliding contents from the transaction + contents = [c for c in contents if c.hashes() not in colliding_hashes] + else: + # Successfully added contents, we are done + break + if colliding_content_hashes: + for collision in colliding_content_hashes: + logger.error("Collision detected: %(collision)s", {"collision": collision}) + + +def _insert_objects(object_type: str, objects: List[Dict], storage) -> None: + """Insert objects of type object_type in the storage. + + """ + objects = fix_objects(object_type, objects) + + if object_type == "content": + contents: List[BaseContent] = [] + skipped_contents: List[BaseContent] = [] + for content in objects: + c = BaseContent.from_dict(content) + if isinstance(c, SkippedContent): + skipped_contents.append(c) + else: + contents.append(c) + + collision_aware_content_add(storage.skipped_content_add, skipped_contents) + collision_aware_content_add(storage.content_add_metadata, contents) + elif object_type == "origin_visit": + visits: List[OriginVisit] = [] + origins: List[Origin] = [] + for obj in objects: + visit = OriginVisit.from_dict(obj) + visits.append(visit) + origins.append(Origin(url=visit.origin)) + storage.origin_add(origins) + storage.origin_visit_upsert(visits) + elif object_type in ("directory", "revision", "release", "snapshot", "origin"): + method = getattr(storage, object_type + "_add") + method(object_converter_fn[object_type](o) for o in objects) + else: + logger.warning("Received a series of %s, this should not happen", object_type) diff --git a/swh/storage/tests/conftest.py b/swh/storage/tests/conftest.py index 6cc476f0..6e095d13 100644 --- a/swh/storage/tests/conftest.py +++ b/swh/storage/tests/conftest.py @@ -1,254 +1,252 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import glob import pytest from typing import Union from pytest_postgresql import factories from pytest_postgresql.janitor import DatabaseJanitor, psycopg2, Version from os import path, environ from hypothesis import settings from typing import Dict import swh.storage from swh.core.utils import numfile_sortkey as sortkey - from swh.model.tests.generate_testdata import gen_contents, gen_origins from swh.model.model import ( Content, Directory, Origin, OriginVisit, Release, Revision, SkippedContent, Snapshot, ) OBJECT_FACTORY = { "content": Content.from_dict, "directory": Directory.from_dict, "origin": Origin.from_dict, "origin_visit": OriginVisit.from_dict, "release": Release.from_dict, "revision": Revision.from_dict, "skipped_content": SkippedContent.from_dict, "snapshot": Snapshot.from_dict, } - SQL_DIR = path.join(path.dirname(swh.storage.__file__), "sql") environ["LC_ALL"] = "C.UTF-8" DUMP_FILES = path.join(SQL_DIR, "*.sql") # define tests profile. 
Full documentation is at: # https://hypothesis.readthedocs.io/en/latest/settings.html#settings-profiles settings.register_profile("fast", max_examples=5, deadline=5000) settings.register_profile("slow", max_examples=20, deadline=5000) @pytest.fixture def swh_storage_backend_config(postgresql_proc, swh_storage_postgresql): yield { "cls": "local", "db": "postgresql://{user}@{host}:{port}/{dbname}".format( host=postgresql_proc.host, port=postgresql_proc.port, user="postgres", dbname="tests", ), "objstorage": {"cls": "memory", "args": {}}, "journal_writer": {"cls": "memory",}, } @pytest.fixture def swh_storage(swh_storage_backend_config): return swh.storage.get_storage(cls="validate", storage=swh_storage_backend_config) @pytest.fixture def swh_contents(swh_storage): contents = gen_contents(n=20) swh_storage.content_add([c for c in contents if c["status"] != "absent"]) swh_storage.skipped_content_add([c for c in contents if c["status"] == "absent"]) return contents @pytest.fixture def swh_origins(swh_storage): origins = gen_origins(n=100) swh_storage.origin_add(origins) return origins # the postgres_fact factory fixture below is mostly a copy of the code # from pytest-postgresql. We need a custom version here to be able to # specify our version of the DBJanitor we use. def postgresql_fact(process_fixture_name, db_name=None, dump_files=DUMP_FILES): @pytest.fixture def postgresql_factory(request): """ Fixture factory for PostgreSQL. :param FixtureRequest request: fixture request object :rtype: psycopg2.connection :returns: postgresql client """ config = factories.get_config(request) if not psycopg2: raise ImportError("No module named psycopg2. Please install it.") proc_fixture = request.getfixturevalue(process_fixture_name) # _, config = try_import('psycopg2', request) pg_host = proc_fixture.host pg_port = proc_fixture.port pg_user = proc_fixture.user pg_options = proc_fixture.options pg_db = db_name or config["dbname"] with SwhDatabaseJanitor( pg_user, pg_host, pg_port, pg_db, proc_fixture.version, dump_files=dump_files, ): connection = psycopg2.connect( dbname=pg_db, user=pg_user, host=pg_host, port=pg_port, options=pg_options, ) yield connection connection.close() return postgresql_factory swh_storage_postgresql = postgresql_fact("postgresql_proc") # This version of the DatabaseJanitor implement a different setup/teardown # behavior than than the stock one: instead of dropping, creating and # initializing the database for each test, it create and initialize the db only # once, then it truncate the tables. This is needed to have acceptable test # performances. 
class SwhDatabaseJanitor(DatabaseJanitor): def __init__( self, user: str, host: str, port: str, db_name: str, version: Union[str, float, Version], dump_files: str = DUMP_FILES, ) -> None: super().__init__(user, host, port, db_name, version) self.dump_files = sorted(glob.glob(dump_files), key=sortkey) def db_setup(self): with psycopg2.connect( dbname=self.db_name, user=self.user, host=self.host, port=self.port, ) as cnx: with cnx.cursor() as cur: for fname in self.dump_files: with open(fname) as fobj: sql = fobj.read().replace("concurrently", "").strip() if sql: cur.execute(sql) cnx.commit() def db_reset(self): with psycopg2.connect( dbname=self.db_name, user=self.user, host=self.host, port=self.port, ) as cnx: with cnx.cursor() as cur: cur.execute( "SELECT table_name FROM information_schema.tables " "WHERE table_schema = %s", ("public",), ) tables = set(table for (table,) in cur.fetchall()) for table in tables: cur.execute("truncate table %s cascade" % table) cur.execute( "SELECT sequence_name FROM information_schema.sequences " "WHERE sequence_schema = %s", ("public",), ) seqs = set(seq for (seq,) in cur.fetchall()) for seq in seqs: cur.execute("ALTER SEQUENCE %s RESTART;" % seq) cnx.commit() def init(self): with self.cursor() as cur: cur.execute( "SELECT COUNT(1) FROM pg_database WHERE datname=%s;", (self.db_name,) ) db_exists = cur.fetchone()[0] == 1 if db_exists: cur.execute( "UPDATE pg_database SET datallowconn=true " "WHERE datname = %s;", (self.db_name,), ) if db_exists: self.db_reset() else: with self.cursor() as cur: cur.execute('CREATE DATABASE "{}";'.format(self.db_name)) self.db_setup() def drop(self): pid_column = "pid" with self.cursor() as cur: cur.execute( "UPDATE pg_database SET datallowconn=false " "WHERE datname = %s;", (self.db_name,), ) cur.execute( "SELECT pg_terminate_backend(pg_stat_activity.{})" "FROM pg_stat_activity " "WHERE pg_stat_activity.datname = %s;".format(pid_column), (self.db_name,), ) @pytest.fixture def sample_data() -> Dict: """Pre-defined sample storage object data to manipulate Returns: Dict of data (keys: content, directory, revision, release, person, origin) """ from .storage_data import data return { "content": [data.cont, data.cont2], "content_metadata": [data.cont3], "skipped_content": [data.skipped_cont, data.skipped_cont2], "person": [data.person], "directory": [data.dir2, data.dir], "revision": [data.revision, data.revision2, data.revision3], "release": [data.release, data.release2, data.release3], "snapshot": [data.snapshot], "origin": [data.origin, data.origin2], "tool": [data.metadata_tool], "provider": [data.provider], "origin_metadata": [data.origin_metadata, data.origin_metadata2], } diff --git a/swh/storage/tests/test_cli.py b/swh/storage/tests/test_cli.py index 138ab57c..046bb4bd 100644 --- a/swh/storage/tests/test_cli.py +++ b/swh/storage/tests/test_cli.py @@ -1,119 +1,192 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import copy import logging -import tempfile +import re import requests +import tempfile import threading import time +import yaml from contextlib import contextmanager -from aiohttp.test_utils import unused_port +from typing import Any, Dict +from unittest.mock import patch -import yaml +import pytest +from aiohttp.test_utils import unused_port from click.testing import CliRunner +from confluent_kafka import Producer 
+from swh.journal.serializers import key_to_kafka, value_to_kafka +from swh.storage import get_storage from swh.storage.cli import storage as cli logger = logging.getLogger(__name__) CLI_CONFIG = { "storage": {"cls": "memory",}, } -def invoke(*args, env=None): +@pytest.fixture +def storage(): + """An swh-storage object that gets injected into the CLI functions.""" + storage_config = {"cls": "pipeline", "steps": [{"cls": "memory"},]} + storage = get_storage(**storage_config) + with patch("swh.storage.cli.get_storage") as get_storage_mock: + get_storage_mock.return_value = storage + yield storage + + +@pytest.fixture +def monkeypatch_retry_sleep(monkeypatch): + from swh.journal.replay import copy_object, obj_in_objstorage + + monkeypatch.setattr(copy_object.retry, "sleep", lambda x: None) + monkeypatch.setattr(obj_in_objstorage.retry, "sleep", lambda x: None) + + +def invoke(*args, env=None, journal_config=None): config = copy.deepcopy(CLI_CONFIG) + if journal_config: + config["journal_client"] = journal_config.copy() + config["journal_client"]["cls"] = "kafka" + runner = CliRunner() with tempfile.NamedTemporaryFile("a", suffix=".yml") as config_fd: yaml.dump(config, config_fd) config_fd.seek(0) args = ["-C" + config_fd.name] + list(args) return runner.invoke(cli, args, obj={"log_level": logging.DEBUG}, env=env,) @contextmanager def check_rpc_serve(): # this context manager adds, if needed, a /quit route to the flask application that # uses the werzeuk tric to exit the test server started by app.run() # # The /testing/ gathering code is executed in a thread while the main thread runs # the test server. Results of the tests consist in a list of requests Response # objects stored in the `response` shared list. # # This convoluted execution code is needed because the flask app needs to run in the # main thread. # # The context manager will yield the port on which tests (GET queries) will be done, # so the test function should start the RPC server on this port in the body of the # context manager. 
from swh.storage.api.server import app from flask import request if "/quit" not in [r.rule for r in app.url_map.iter_rules()]: @app.route("/quit") def quit_app(): request.environ["werkzeug.server.shutdown"]() return "Bye" port = unused_port() responses = [] def run_tests(): # we do run the "test" part in the thread because flask does not like the # app.run() to be executed in a (non-main) thread def get(path): for i in range(5): try: resp = requests.get(f"http://127.0.0.1:{port}{path}") break except requests.exceptions.ConnectionError: time.sleep(0.2) responses.append(resp) get("/") # ensure the server starts and can reply the '/' path get("/quit") # ask the test server to quit gracefully t = threading.Thread(target=run_tests) t.start() yield port # this is where the caller should start the server listening on "port" # we expect to reach this point because the /quit endpoint should have been called, # thus the server, executed in the caller's context manager's body, should now # return t.join() # check the GET requests we made in the thread have expected results assert len(responses) == 2 assert responses[0].status_code == 200 assert "Software Heritage storage server" in responses[0].text assert responses[1].status_code == 200 assert responses[1].text == "Bye" def test_rpc_serve(): with check_rpc_serve() as port: invoke("rpc-serve", "--host", "127.0.0.1", "--port", port) def test_rpc_serve_bwcompat(): def invoke(*args, env=None): config = copy.deepcopy(CLI_CONFIG) runner = CliRunner() with tempfile.NamedTemporaryFile("a", suffix=".yml") as config_fd: yaml.dump(config, config_fd) config_fd.seek(0) args = list(args) + [config_fd.name] return runner.invoke(cli, args, obj={"log_level": logging.DEBUG}, env=env,) with check_rpc_serve() as port: invoke("rpc-serve", "--host", "127.0.0.1", "--port", port) + + +def test_replay( + storage, kafka_prefix: str, kafka_consumer_group: str, kafka_server: str, +): + kafka_prefix += ".swh.journal.objects" + + producer = Producer( + { + "bootstrap.servers": kafka_server, + "client.id": "test-producer", + "acks": "all", + } + ) + + snapshot = { + "id": b"foo", + "branches": {b"HEAD": {"target_type": "revision", "target": b"\x01" * 20,}}, + } # type: Dict[str, Any] + producer.produce( + topic=kafka_prefix + ".snapshot", + key=key_to_kafka(snapshot["id"]), + value=value_to_kafka(snapshot), + ) + producer.flush() + + logger.debug("Flushed producer") + + result = invoke( + "replay", + "--stop-after-objects", + "1", + journal_config={ + "brokers": [kafka_server], + "group_id": kafka_consumer_group, + "prefix": kafka_prefix, + }, + ) + + expected = r"Done.\n" + assert result.exit_code == 0, result.output + assert re.fullmatch(expected, result.output, re.MULTILINE), result.output + + assert storage.snapshot_get(snapshot["id"]) == {**snapshot, "next_branch": None} diff --git a/swh/storage/tests/test_kafka_writer.py b/swh/storage/tests/test_kafka_writer.py new file mode 100644 index 00000000..e48ab32e --- /dev/null +++ b/swh/storage/tests/test_kafka_writer.py @@ -0,0 +1,60 @@ +# Copyright (C) 2018-2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from confluent_kafka import Consumer + +from swh.storage import get_storage +from swh.model.model import Origin, OriginVisit +from swh.journal.pytest_plugin import consume_messages, assert_all_objects_consumed +from 
swh.journal.tests.journal_data import TEST_OBJECTS + + +def test_storage_direct_writer(kafka_prefix: str, kafka_server, consumer: Consumer): + kafka_prefix += ".swh.journal.objects" + + writer_config = { + "cls": "kafka", + "brokers": [kafka_server], + "client_id": "kafka_writer", + "prefix": kafka_prefix, + } + storage_config = { + "cls": "pipeline", + "steps": [{"cls": "memory", "journal_writer": writer_config},], + } + + storage = get_storage(**storage_config) + + expected_messages = 0 + + for object_type, objects in TEST_OBJECTS.items(): + method = getattr(storage, object_type + "_add") + if object_type in ( + "content", + "directory", + "revision", + "release", + "snapshot", + "origin", + ): + method(objects) + expected_messages += len(objects) + elif object_type in ("origin_visit",): + for obj in objects: + assert isinstance(obj, OriginVisit) + storage.origin_add_one(Origin(url=obj.origin)) + visit = method(obj.origin, date=obj.date, type=obj.type) + expected_messages += 1 + + obj_d = obj.to_dict() + for k in ("visit", "origin", "date", "type"): + del obj_d[k] + storage.origin_visit_update(obj.origin, visit.visit, **obj_d) + expected_messages += 1 + else: + assert False, object_type + + consumed_messages = consume_messages(consumer, kafka_prefix, expected_messages) + assert_all_objects_consumed(consumed_messages) diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py new file mode 100644 index 00000000..48be2397 --- /dev/null +++ b/swh/storage/tests/test_replay.py @@ -0,0 +1,388 @@ +# Copyright (C) 2019-2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import datetime +import functools +import random +import logging +import dateutil + +from typing import Dict, List + +from confluent_kafka import Producer + +import pytest + +from swh.model.hashutil import hash_to_hex +from swh.model.model import Content + +from swh.storage import get_storage +from swh.storage.replay import process_replay_objects + +from swh.journal.serializers import key_to_kafka, value_to_kafka +from swh.journal.client import JournalClient + +from swh.journal.tests.utils import MockedJournalClient, MockedKafkaWriter +from swh.journal.tests.conftest import ( + TEST_OBJECT_DICTS, + DUPLICATE_CONTENTS, +) + + +storage_config = {"cls": "pipeline", "steps": [{"cls": "memory"},]} + + +def test_storage_play( + kafka_prefix: str, kafka_consumer_group: str, kafka_server: str, caplog, +): + """Optimal replayer scenario. 
+ + This: + - writes objects to the topic + - replayer consumes objects from the topic and replay them + + """ + kafka_prefix += ".swh.journal.objects" + + storage = get_storage(**storage_config) + + producer = Producer( + { + "bootstrap.servers": kafka_server, + "client.id": "test producer", + "acks": "all", + } + ) + + now = datetime.datetime.now(tz=datetime.timezone.utc) + + # Fill Kafka + nb_sent = 0 + nb_visits = 0 + for object_type, objects in TEST_OBJECT_DICTS.items(): + topic = f"{kafka_prefix}.{object_type}" + for object_ in objects: + key = bytes(random.randint(0, 255) for _ in range(40)) + object_ = object_.copy() + if object_type == "content": + object_["ctime"] = now + elif object_type == "origin_visit": + nb_visits += 1 + object_["visit"] = nb_visits + producer.produce( + topic=topic, key=key_to_kafka(key), value=value_to_kafka(object_), + ) + nb_sent += 1 + + producer.flush() + + caplog.set_level(logging.ERROR, "swh.journal.replay") + # Fill the storage from Kafka + replayer = JournalClient( + brokers=kafka_server, + group_id=kafka_consumer_group, + prefix=kafka_prefix, + stop_after_objects=nb_sent, + ) + worker_fn = functools.partial(process_replay_objects, storage=storage) + nb_inserted = 0 + while nb_inserted < nb_sent: + nb_inserted += replayer.process(worker_fn) + assert nb_sent == nb_inserted + + # Check the objects were actually inserted in the storage + assert TEST_OBJECT_DICTS["revision"] == list( + storage.revision_get([rev["id"] for rev in TEST_OBJECT_DICTS["revision"]]) + ) + assert TEST_OBJECT_DICTS["release"] == list( + storage.release_get([rel["id"] for rel in TEST_OBJECT_DICTS["release"]]) + ) + + origins = list(storage.origin_get([orig for orig in TEST_OBJECT_DICTS["origin"]])) + assert TEST_OBJECT_DICTS["origin"] == [{"url": orig["url"]} for orig in origins] + for origin in origins: + origin_url = origin["url"] + expected_visits = [ + { + **visit, + "origin": origin_url, + "date": dateutil.parser.parse(visit["date"]), + } + for visit in TEST_OBJECT_DICTS["origin_visit"] + if visit["origin"] == origin["url"] + ] + actual_visits = list(storage.origin_visit_get(origin_url)) + for visit in actual_visits: + del visit["visit"] # opaque identifier + assert expected_visits == actual_visits + + input_contents = TEST_OBJECT_DICTS["content"] + contents = storage.content_get_metadata([cont["sha1"] for cont in input_contents]) + assert len(contents) == len(input_contents) + assert contents == {cont["sha1"]: [cont] for cont in input_contents} + + collision = 0 + for record in caplog.records: + logtext = record.getMessage() + if "Colliding contents:" in logtext: + collision += 1 + + assert collision == 0, "No collision should be detected" + + +def test_storage_play_with_collision( + kafka_prefix: str, kafka_consumer_group: str, kafka_server: str, caplog, +): + """Another replayer scenario with collisions. 
+ + This: + - writes objects to the topic, including colliding contents + - replayer consumes objects from the topic and replay them + - This drops the colliding contents from the replay when detected + + """ + kafka_prefix += ".swh.journal.objects" + + storage = get_storage(**storage_config) + + producer = Producer( + { + "bootstrap.servers": kafka_server, + "client.id": "test producer", + "enable.idempotence": "true", + } + ) + + now = datetime.datetime.now(tz=datetime.timezone.utc) + + # Fill Kafka + nb_sent = 0 + nb_visits = 0 + for object_type, objects in TEST_OBJECT_DICTS.items(): + topic = f"{kafka_prefix}.{object_type}" + for object_ in objects: + key = bytes(random.randint(0, 255) for _ in range(40)) + object_ = object_.copy() + if object_type == "content": + object_["ctime"] = now + elif object_type == "origin_visit": + nb_visits += 1 + object_["visit"] = nb_visits + producer.produce( + topic=topic, key=key_to_kafka(key), value=value_to_kafka(object_), + ) + nb_sent += 1 + + # Create collision in input data + # They are not written in the destination + for content in DUPLICATE_CONTENTS: + topic = f"{kafka_prefix}.content" + producer.produce( + topic=topic, key=key_to_kafka(key), value=value_to_kafka(content), + ) + + nb_sent += 1 + + producer.flush() + + caplog.set_level(logging.ERROR, "swh.journal.replay") + # Fill the storage from Kafka + replayer = JournalClient( + brokers=kafka_server, + group_id=kafka_consumer_group, + prefix=kafka_prefix, + stop_after_objects=nb_sent, + ) + worker_fn = functools.partial(process_replay_objects, storage=storage) + nb_inserted = 0 + while nb_inserted < nb_sent: + nb_inserted += replayer.process(worker_fn) + assert nb_sent == nb_inserted + + # Check the objects were actually inserted in the storage + assert TEST_OBJECT_DICTS["revision"] == list( + storage.revision_get([rev["id"] for rev in TEST_OBJECT_DICTS["revision"]]) + ) + assert TEST_OBJECT_DICTS["release"] == list( + storage.release_get([rel["id"] for rel in TEST_OBJECT_DICTS["release"]]) + ) + + origins = list(storage.origin_get([orig for orig in TEST_OBJECT_DICTS["origin"]])) + assert TEST_OBJECT_DICTS["origin"] == [{"url": orig["url"]} for orig in origins] + for origin in origins: + origin_url = origin["url"] + expected_visits = [ + { + **visit, + "origin": origin_url, + "date": dateutil.parser.parse(visit["date"]), + } + for visit in TEST_OBJECT_DICTS["origin_visit"] + if visit["origin"] == origin["url"] + ] + actual_visits = list(storage.origin_visit_get(origin_url)) + for visit in actual_visits: + del visit["visit"] # opaque identifier + assert expected_visits == actual_visits + + input_contents = TEST_OBJECT_DICTS["content"] + contents = storage.content_get_metadata([cont["sha1"] for cont in input_contents]) + assert len(contents) == len(input_contents) + assert contents == {cont["sha1"]: [cont] for cont in input_contents} + + nb_collisions = 0 + + actual_collision: Dict + for record in caplog.records: + logtext = record.getMessage() + if "Collision detected:" in logtext: + nb_collisions += 1 + actual_collision = record.args["collision"] + + assert nb_collisions == 1, "1 collision should be detected" + + algo = "sha1" + assert actual_collision["algo"] == algo + expected_colliding_hash = hash_to_hex(DUPLICATE_CONTENTS[0][algo]) + assert actual_collision["hash"] == expected_colliding_hash + + actual_colliding_hashes = actual_collision["objects"] + assert len(actual_colliding_hashes) == len(DUPLICATE_CONTENTS) + for content in DUPLICATE_CONTENTS: + expected_content_hashes = { + k: 
hash_to_hex(v) for k, v in Content.from_dict(content).hashes().items() + } + assert expected_content_hashes in actual_colliding_hashes + + +def _test_write_replay_origin_visit(visits: List[Dict]): + """Helper function to write tests for origin_visit. + + Each visit (a dict) given in the 'visits' argument will be sent to + a (mocked) kafka queue, which a in-memory-storage backed replayer is + listening to. + + Check that corresponding origin visits entities are present in the storage + and have correct values if they are not skipped. + + """ + queue: List = [] + replayer = MockedJournalClient(queue) + writer = MockedKafkaWriter(queue) + + # Note that flipping the order of these two insertions will crash + # the test, because the legacy origin_format does not allow to create + # the origin when needed (type is missing) + writer.send( + "origin", + "foo", + { + "url": "http://example.com/", + "type": "git", # test the legacy origin format is accepted + }, + ) + for visit in visits: + writer.send("origin_visit", "foo", visit) + + queue_size = len(queue) + assert replayer.stop_after_objects is None + replayer.stop_after_objects = queue_size + + storage = get_storage(**storage_config) + worker_fn = functools.partial(process_replay_objects, storage=storage) + + replayer.process(worker_fn) + + actual_visits = list(storage.origin_visit_get("http://example.com/")) + + assert len(actual_visits) == len(visits), actual_visits + + for vin, vout in zip(visits, actual_visits): + vin = vin.copy() + vout = vout.copy() + assert vout.pop("origin") == "http://example.com/" + vin.pop("origin") + vin.setdefault("type", "git") + vin.setdefault("metadata", None) + assert vin == vout + + +def test_write_replay_origin_visit(): + """Test origin_visit when the 'origin' is just a string.""" + now = datetime.datetime.now() + visits = [ + { + "visit": 1, + "origin": "http://example.com/", + "date": now, + "type": "git", + "status": "partial", + "snapshot": None, + } + ] + _test_write_replay_origin_visit(visits) + + +def test_write_replay_legacy_origin_visit1(): + """Origin_visit with no types should make the replayer crash + + We expect the journal's origin_visit topic to no longer reference such + visits. If it does, the replayer must crash so we can fix the journal's + topic. 
+ + """ + now = datetime.datetime.now() + visit = { + "visit": 1, + "origin": "http://example.com/", + "date": now, + "status": "partial", + "snapshot": None, + } + now2 = datetime.datetime.now() + visit2 = { + "visit": 2, + "origin": {"url": "http://example.com/"}, + "date": now2, + "status": "partial", + "snapshot": None, + } + + for origin_visit in [visit, visit2]: + with pytest.raises(ValueError, match="Old origin visit format"): + _test_write_replay_origin_visit([origin_visit]) + + +def test_write_replay_legacy_origin_visit2(): + """Test origin_visit when 'type' is missing from the visit, but not + from the origin.""" + now = datetime.datetime.now() + visits = [ + { + "visit": 1, + "origin": {"url": "http://example.com/", "type": "git",}, + "date": now, + "type": "git", + "status": "partial", + "snapshot": None, + } + ] + _test_write_replay_origin_visit(visits) + + +def test_write_replay_legacy_origin_visit3(): + """Test origin_visit when the origin is a dict""" + now = datetime.datetime.now() + visits = [ + { + "visit": 1, + "origin": {"url": "http://example.com/",}, + "date": now, + "type": "git", + "status": "partial", + "snapshot": None, + } + ] + _test_write_replay_origin_visit(visits) diff --git a/swh/storage/tests/test_write_replay.py b/swh/storage/tests/test_write_replay.py new file mode 100644 index 00000000..c78efe28 --- /dev/null +++ b/swh/storage/tests/test_write_replay.py @@ -0,0 +1,112 @@ +# Copyright (C) 2019-2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import functools +from unittest.mock import patch + +import attr +from hypothesis import given, settings, HealthCheck +from hypothesis.strategies import lists + +from swh.model.hypothesis_strategies import objects +from swh.model.model import Origin +from swh.storage import get_storage +from swh.storage.exc import HashCollision + +from swh.storage.replay import process_replay_objects + +from swh.journal.tests.utils import MockedJournalClient, MockedKafkaWriter + + +storage_config = { + "cls": "memory", + "journal_writer": {"cls": "memory"}, +} + + +def empty_person_name_email(rev_or_rel): + """Empties the 'name' and 'email' fields of the author/committer fields + of a revision or release; leaving only the fullname.""" + if getattr(rev_or_rel, "author", None): + rev_or_rel = attr.evolve( + rev_or_rel, author=attr.evolve(rev_or_rel.author, name=b"", email=b"",) + ) + + if getattr(rev_or_rel, "committer", None): + rev_or_rel = attr.evolve( + rev_or_rel, + committer=attr.evolve(rev_or_rel.committer, name=b"", email=b"",), + ) + + return rev_or_rel + + +@given(lists(objects(blacklist_types=("origin_visit_status",)), min_size=1)) +@settings(suppress_health_check=[HealthCheck.too_slow]) +def test_write_replay_same_order_batches(objects): + queue = [] + replayer = MockedJournalClient(queue) + + with patch( + "swh.journal.writer.inmemory.InMemoryJournalWriter", + return_value=MockedKafkaWriter(queue), + ): + storage1 = get_storage(**storage_config) + + # Write objects to storage1 + for (obj_type, obj) in objects: + if obj_type == "content" and obj.status == "absent": + obj_type = "skipped_content" + + if obj_type == "origin_visit": + storage1.origin_add_one(Origin(url=obj.origin)) + storage1.origin_visit_upsert([obj]) + else: + method = getattr(storage1, obj_type + "_add") + try: + method([obj]) + except HashCollision: + pass + + # Bail out 
early if we didn't insert any relevant objects... + queue_size = len(queue) + assert queue_size != 0, "No test objects found; hypothesis strategy bug?" + + assert replayer.stop_after_objects is None + replayer.stop_after_objects = queue_size + + storage2 = get_storage(**storage_config) + worker_fn = functools.partial(process_replay_objects, storage=storage2) + + replayer.process(worker_fn) + + assert replayer.consumer.committed + + for attr_name in ( + "_contents", + "_directories", + "_snapshots", + "_origin_visits", + "_origins", + ): + assert getattr(storage1, attr_name) == getattr(storage2, attr_name), attr_name + + # When hypothesis generates a revision and a release with same + # author (or committer) fullname but different name or email, then + # the storage will use the first name/email it sees. + # This first one will be either the one from the revision or the release, + # and since there is no order guarantees, storage2 has 1/2 chance of + # not seeing the same order as storage1, therefore we need to strip + # them out before comparing. + for attr_name in ("_revisions", "_releases"): + items1 = { + k: empty_person_name_email(v) + for (k, v) in getattr(storage1, attr_name).items() + } + items2 = { + k: empty_person_name_email(v) + for (k, v) in getattr(storage2, attr_name).items() + } + assert items1 == items2, attr_name
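
For reference, a minimal sketch (not part of the diff above, all names and values below are placeholders) of how the pieces added here fit together programmatically, mirroring what the new `swh storage replay` command and `swh/storage/tests/test_replay.py` do:

```python
# Sketch: wire a JournalClient to the new process_replay_objects() worker,
# the same way the `swh storage replay` CLI command in this diff does.
import functools

from swh.journal.client import JournalClient
from swh.storage import get_storage
from swh.storage.replay import process_replay_objects

# In-memory storage for illustration; the CLI builds this from the
# `storage` section of its configuration file.
storage = get_storage(cls="memory")

# Several replayers sharing the same group_id can fill the same storage
# (see the `replay` command docstring in swh/storage/cli.py).
client = JournalClient(
    brokers=["localhost:9092"],       # placeholder Kafka broker
    group_id="swh-storage-replayer",  # placeholder consumer group
    prefix="swh.journal.objects",     # topic prefix, e.g. <prefix>.snapshot
    stop_after_objects=1000,          # or None to run forever, as in the CLI
)

# process_replay_objects() runs fix_objects() on legacy journal entries,
# converts them to model objects and inserts them into the storage,
# logging any content hash collisions it detects.
worker_fn = functools.partial(process_replay_objects, storage=storage)
client.process(worker_fn)
```

The equivalent CLI invocation is `swh storage replay --stop-after-objects 1000 -C config.yml`, with a configuration file providing both a `storage` section and a `journal_client` section (`cls: kafka`, `brokers`, `group_id`, `prefix`), as exercised by `test_replay` in `swh/storage/tests/test_cli.py`.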