@storage.command()
@click.option(
    "--stop-after-objects",
    "-n",
    default=None,
    type=int,
    help="Stop after processing this many objects. Default is to run forever.",
)
@click.pass_context
def replay(ctx, stop_after_objects):
    """Fill a Storage by reading a Journal.

    There can be several 'replayers' filling a Storage as long as they use
    the same `group-id`.

    Args:
        ctx: click context; ``ctx.obj["config"]`` must contain a ``storage``
            entry, which is removed from the config and passed to
            ``get_storage``. The journal client is built from ``ctx`` by
            ``get_journal_client``.
        stop_after_objects: forwarded to the journal client; ``None`` (the
            default) means no limit.

    When ``systemd.daemon.notify`` could be imported, ``READY=1`` is sent
    before processing starts and ``STOPPING=1`` when it ends.
    """
    # Local import, as in the original: only this subcommand needs it.
    from swh.storage.replay import process_replay_objects

    conf = ctx.obj["config"]
    # Only the config lookup is guarded: a KeyError raised while the storage
    # is being instantiated must not be reported as a missing config entry.
    try:
        storage_config = conf.pop("storage")
    except KeyError:
        ctx.fail("You must have a storage configured in your config file.")
    storage = get_storage(**storage_config)

    client = get_journal_client(ctx, stop_after_objects=stop_after_objects)
    worker_fn = functools.partial(process_replay_objects, storage=storage)

    if notify:
        notify("READY=1")

    try:
        client.process(worker_fn)
    except KeyboardInterrupt:
        ctx.exit(0)
    else:
        print("Done.")
    finally:
        # Runs on success, on Ctrl-C and on errors alike.
        if notify:
            notify("STOPPING=1")
        client.close()


def main():
    """Configure logging, then call ``serve`` with the ``SWH_STORAGE``
    environment-variable prefix."""
    logging.basicConfig()
    return serve(auto_envvar_prefix="SWH_STORAGE")
+ + >>> _fix_content({'perms': 0o100644, 'sha1_git': b'foo'}) + {'sha1_git': b'foo'} + + >>> _fix_content({'sha1_git': b'bar'}) + {'sha1_git': b'bar'} + + """ + content = content.copy() + content.pop("perms", None) + return content + + +def _fix_revision_pypi_empty_string(rev): + """PyPI loader failed to encode empty strings as bytes, see: + swh:1:rev:8f0095ee0664867055d03de9bcc8f95b91d8a2b9 + or https://forge.softwareheritage.org/D1772 + """ + rev = { + **rev, + "author": rev["author"].copy(), + "committer": rev["committer"].copy(), + } + if rev["author"].get("email") == "": + rev["author"]["email"] = b"" + if rev["author"].get("name") == "": + rev["author"]["name"] = b"" + if rev["committer"].get("email") == "": + rev["committer"]["email"] = b"" + if rev["committer"].get("name") == "": + rev["committer"]["name"] = b"" + return rev + + +def _fix_revision_transplant_source(rev): + if rev.get("metadata") and rev["metadata"].get("extra_headers"): + rev = copy.deepcopy(rev) + rev["metadata"]["extra_headers"] = [ + [key, value.encode("ascii")] + if key == "transplant_source" and isinstance(value, str) + else [key, value] + for (key, value) in rev["metadata"]["extra_headers"] + ] + return rev + + +def _check_date(date): + """Returns whether the date can be represented in backends with sane + limits on timestamps and timezones (resp. signed 64-bits and + signed 16 bits), and that microseconds is valid (ie. between 0 and 10^6). + """ + if date is None: + return True + date = normalize_timestamp(date) + return ( + (-(2 ** 63) <= date["timestamp"]["seconds"] < 2 ** 63) + and (0 <= date["timestamp"]["microseconds"] < 10 ** 6) + and (-(2 ** 15) <= date["offset"] < 2 ** 15) + ) + + +def _check_revision_date(rev): + """Exclude revisions with invalid dates. 
+ See https://forge.softwareheritage.org/T1339""" + return _check_date(rev["date"]) and _check_date(rev["committer_date"]) + + +def _fix_revision(revision: Dict[str, Any]) -> Optional[Dict]: + """Fix various legacy revision issues. + + Fix author/committer person: + + >>> from pprint import pprint + >>> date = { + ... 'timestamp': { + ... 'seconds': 1565096932, + ... 'microseconds': 0, + ... }, + ... 'offset': 0, + ... } + >>> rev0 = _fix_revision({ + ... 'id': b'rev-id', + ... 'author': {'fullname': b'', 'name': '', 'email': ''}, + ... 'committer': {'fullname': b'', 'name': '', 'email': ''}, + ... 'date': date, + ... 'committer_date': date, + ... 'type': 'git', + ... 'message': '', + ... 'directory': b'dir-id', + ... 'synthetic': False, + ... }) + >>> rev0['author'] + {'fullname': b'', 'name': b'', 'email': b''} + >>> rev0['committer'] + {'fullname': b'', 'name': b'', 'email': b''} + + Fix type of 'transplant_source' extra headers: + + >>> rev1 = _fix_revision({ + ... 'id': b'rev-id', + ... 'author': {'fullname': b'', 'name': '', 'email': ''}, + ... 'committer': {'fullname': b'', 'name': '', 'email': ''}, + ... 'date': date, + ... 'committer_date': date, + ... 'metadata': { + ... 'extra_headers': [ + ... ['time_offset_seconds', b'-3600'], + ... ['transplant_source', '29c154a012a70f49df983625090434587622b39e'] + ... ]}, + ... 'type': 'git', + ... 'message': '', + ... 'directory': b'dir-id', + ... 'synthetic': False, + ... }) + >>> pprint(rev1['metadata']['extra_headers']) + [['time_offset_seconds', b'-3600'], + ['transplant_source', b'29c154a012a70f49df983625090434587622b39e']] + + Revision with invalid date are filtered: + + >>> from copy import deepcopy + >>> invalid_date1 = deepcopy(date) + >>> invalid_date1['timestamp']['microseconds'] = 1000000000 # > 10^6 + >>> rev = _fix_revision({ + ... 'author': {'fullname': b'', 'name': '', 'email': ''}, + ... 'committer': {'fullname': b'', 'name': '', 'email': ''}, + ... 'date': invalid_date1, + ... 
'committer_date': date, + ... }) + >>> rev is None + True + + >>> invalid_date2 = deepcopy(date) + >>> invalid_date2['timestamp']['seconds'] = 2**70 # > 10^63 + >>> rev = _fix_revision({ + ... 'author': {'fullname': b'', 'name': '', 'email': ''}, + ... 'committer': {'fullname': b'', 'name': '', 'email': ''}, + ... 'date': invalid_date2, + ... 'committer_date': date, + ... }) + >>> rev is None + True + + >>> invalid_date3 = deepcopy(date) + >>> invalid_date3['offset'] = 2**20 # > 10^15 + >>> rev = _fix_revision({ + ... 'author': {'fullname': b'', 'name': '', 'email': ''}, + ... 'committer': {'fullname': b'', 'name': '', 'email': ''}, + ... 'date': date, + ... 'committer_date': invalid_date3, + ... }) + >>> rev is None + True + + """ # noqa + rev = _fix_revision_pypi_empty_string(revision) + rev = _fix_revision_transplant_source(rev) + if not _check_revision_date(rev): + logger.warning( + "Invalid revision date detected: %(revision)s", {"revision": rev} + ) + return None + return rev + + +def _fix_origin(origin: Dict) -> Dict: + """Fix legacy origin with type which is no longer part of the model. + + >>> from pprint import pprint + >>> pprint(_fix_origin({ + ... 'url': 'http://foo', + ... })) + {'url': 'http://foo'} + >>> pprint(_fix_origin({ + ... 'url': 'http://bar', + ... 'type': 'foo', + ... })) + {'url': 'http://bar'} + + """ + o = origin.copy() + o.pop("type", None) + return o + + +def _fix_origin_visit(visit: Dict) -> Dict: + """Fix various legacy origin visit issues. + + `visit['origin']` is a dict instead of an URL: + + >>> from datetime import datetime, timezone + >>> from pprint import pprint + >>> date = datetime(2020, 2, 27, 14, 39, 19, tzinfo=timezone.utc) + >>> pprint(_fix_origin_visit({ + ... 'origin': {'url': 'http://foo'}, + ... 'date': date, + ... 'type': 'git', + ... 'status': 'ongoing', + ... 'snapshot': None, + ... 
})) + {'date': datetime.datetime(2020, 2, 27, 14, 39, 19, tzinfo=datetime.timezone.utc), + 'metadata': None, + 'origin': 'http://foo', + 'snapshot': None, + 'status': 'ongoing', + 'type': 'git'} + + `visit['type']` is missing , but `origin['visit']['type']` exists: + + >>> pprint(_fix_origin_visit( + ... {'origin': {'type': 'hg', 'url': 'http://foo'}, + ... 'date': date, + ... 'status': 'ongoing', + ... 'snapshot': None, + ... })) + {'date': datetime.datetime(2020, 2, 27, 14, 39, 19, tzinfo=datetime.timezone.utc), + 'metadata': None, + 'origin': 'http://foo', + 'snapshot': None, + 'status': 'ongoing', + 'type': 'hg'} + + Old visit format (origin_visit with no type) raises: + + >>> _fix_origin_visit({ + ... 'origin': {'url': 'http://foo'}, + ... 'date': date, + ... 'status': 'ongoing', + ... 'snapshot': None + ... }) + Traceback (most recent call last): + ... + ValueError: Old origin visit format detected... + + >>> _fix_origin_visit({ + ... 'origin': 'http://foo', + ... 'date': date, + ... 'status': 'ongoing', + ... 'snapshot': None + ... }) + Traceback (most recent call last): + ... + ValueError: Old origin visit format detected... + + """ # noqa + visit = visit.copy() + if "type" not in visit: + if isinstance(visit["origin"], dict) and "type" in visit["origin"]: + # Very old version of the schema: visits did not have a type, + # but their 'origin' field was a dict with a 'type' key. + visit["type"] = visit["origin"]["type"] + else: + # Very old schema version: 'type' is missing, stop early + + # We expect the journal's origin_visit topic to no longer reference + # such visits. If it does, the replayer must crash so we can fix + # the journal's topic. + raise ValueError(f"Old origin visit format detected: {visit}") + if isinstance(visit["origin"], dict): + # Old version of the schema: visit['origin'] was a dict. 
+ visit["origin"] = visit["origin"]["url"] + if "metadata" not in visit: + visit["metadata"] = None + return visit + + +def fix_objects(object_type: str, objects: List[Dict]) -> List[Dict]: + """ + Fix legacy objects from the journal to bring them up to date with the + latest storage schema. + """ + if object_type == "content": + return [_fix_content(v) for v in objects] + elif object_type == "revision": + revisions = [_fix_revision(v) for v in objects] + return [rev for rev in revisions if rev is not None] + elif object_type == "origin": + return [_fix_origin(v) for v in objects] + elif object_type == "origin_visit": + return [_fix_origin_visit(v) for v in objects] + else: + return objects diff --git a/swh/storage/replay.py b/swh/storage/replay.py new file mode 100644 --- /dev/null +++ b/swh/storage/replay.py @@ -0,0 +1,128 @@ +# Copyright (C) 2019-2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import logging +from typing import Any, Callable, Dict, Iterable, List + +try: + from systemd.daemon import notify +except ImportError: + notify = None + +from swh.core.statsd import statsd +from swh.storage.fixer import fix_objects + +from swh.model.model import ( + BaseContent, + BaseModel, + Content, + Directory, + Origin, + OriginVisit, + Revision, + SkippedContent, + Snapshot, + Release, +) +from swh.storage.exc import HashCollision + +logger = logging.getLogger(__name__) + +GRAPH_OPERATIONS_METRIC = "swh_graph_replayer_operations_total" +GRAPH_DURATION_METRIC = "swh_graph_replayer_duration_seconds" + + +object_converter_fn: Dict[str, Callable[[Dict], BaseModel]] = { + "origin": Origin.from_dict, + "origin_visit": OriginVisit.from_dict, + "snapshot": Snapshot.from_dict, + "revision": Revision.from_dict, + "release": Release.from_dict, + "directory": Directory.from_dict, + "content": 
def process_replay_objects(all_objects, *, storage):
    """Insert a batch of journal objects into ``storage``.

    Args:
        all_objects: mapping from object type to the list of object dicts
            of that type
        storage: the storage to fill

    For each object type the insertion is timed and counted through statsd.
    Once the whole batch is done, ``WATCHDOG=1`` is sent to systemd when
    ``systemd.daemon.notify`` could be imported.
    """
    for (object_type, objects) in all_objects.items():
        logger.debug("Inserting %s %s objects", len(objects), object_type)
        with statsd.timed(GRAPH_DURATION_METRIC, tags={"object_type": object_type}):
            _insert_objects(object_type, objects, storage)
        statsd.increment(
            GRAPH_OPERATIONS_METRIC, len(objects), tags={"object_type": object_type}
        )
    if notify:
        notify("WATCHDOG=1")


def collision_aware_content_add(
    content_add_fn: Callable[[Iterable[Any]], None], contents: List[BaseContent]
) -> None:
    """Add contents to storage. If a hash collision is detected, an error is
    logged. Then this adds the other non colliding contents to the storage.

    Args:
        content_add_fn: Storage content callable
        contents: List of contents or skipped contents to add to storage

    Raises:
        HashCollision: if a collision is reported but none of the remaining
            contents matches the colliding hashes; retrying would then fail
            in exactly the same way forever.

    """
    if not contents:
        return
    colliding_content_hashes: List[Dict[str, Any]] = []
    try:
        while True:
            try:
                content_add_fn(contents)
            except HashCollision as e:
                colliding_content_hashes.append(
                    {
                        "algo": e.algo,
                        "hash": e.hash_id,  # hex hash id
                        "objects": e.colliding_contents,  # hex hashes
                    }
                )
                colliding_hashes = e.colliding_content_hashes()
                # Drop the colliding contents from the transaction
                remaining = [c for c in contents if c.hashes() not in colliding_hashes]
                if len(remaining) == len(contents):
                    # Nothing was dropped: the next attempt would be
                    # identical, so give up instead of looping forever.
                    raise
                contents = remaining
            else:
                # Successfully added contents, we are done
                break
    finally:
        # Report every collision seen, even if a later attempt raised.
        for collision in colliding_content_hashes:
            logger.error("Collision detected: %(collision)s", {"collision": collision})


def _insert_objects(object_type: str, objects: List[Dict], storage) -> None:
    """Insert objects of type object_type in the storage.

    The dicts are first passed through ``fix_objects``, then:

    - "content": split into skipped and regular contents, each added with
      collision handling;
    - "origin_visit": the origin of every visit is added before the visits
      are upserted;
    - "directory", "revision", "release", "snapshot", "origin": converted
      with ``object_converter_fn`` and passed to ``storage.<type>_add``;
    - anything else: a warning is logged and nothing is inserted.

    """
    objects = fix_objects(object_type, objects)

    if object_type == "content":
        contents: List[BaseContent] = []
        skipped_contents: List[BaseContent] = []
        for content in objects:
            c = BaseContent.from_dict(content)
            if isinstance(c, SkippedContent):
                skipped_contents.append(c)
            else:
                contents.append(c)

        collision_aware_content_add(storage.skipped_content_add, skipped_contents)
        collision_aware_content_add(storage.content_add_metadata, contents)
    elif object_type == "origin_visit":
        visits: List[OriginVisit] = []
        origins: List[Origin] = []
        for obj in objects:
            visit = OriginVisit.from_dict(obj)
            visits.append(visit)
            origins.append(Origin(url=visit.origin))
        storage.origin_add(origins)
        storage.origin_visit_upsert(visits)
    elif object_type in ("directory", "revision", "release", "snapshot", "origin"):
        method = getattr(storage, object_type + "_add")
        method(object_converter_fn[object_type](o) for o in objects)
    else:
        logger.warning("Received a series of %s, this should not happen", object_type)
@pytest.fixture
def storage():
    """An swh-storage object that gets injected into the CLI functions."""
    storage_config = {"cls": "pipeline", "steps": [{"cls": "memory"}]}
    storage = get_storage(**storage_config)
    # The CLI looks the factory up in its own module, so patch it there.
    with patch("swh.storage.cli.get_storage") as get_storage_mock:
        get_storage_mock.return_value = storage
        yield storage


@pytest.fixture
def monkeypatch_retry_sleep(monkeypatch):
    """Replace the ``sleep`` of the retry wrappers of ``copy_object`` and
    ``obj_in_objstorage`` with a no-op."""
    from swh.journal.replay import copy_object, obj_in_objstorage

    monkeypatch.setattr(copy_object.retry, "sleep", lambda x: None)
    monkeypatch.setattr(obj_in_objstorage.retry, "sleep", lambda x: None)


def test_replay(
    storage, kafka_prefix: str, kafka_consumer_group: str, kafka_server: str,
):
    """Produce one snapshot to Kafka, run the ``replay`` subcommand for a
    single object and check that the snapshot ended up in the storage."""
    kafka_prefix += ".swh.journal.objects"

    producer = Producer(
        {
            "bootstrap.servers": kafka_server,
            "client.id": "test-producer",
            "acks": "all",
        }
    )

    snapshot: Dict[str, Any] = {
        "id": b"foo",
        "branches": {b"HEAD": {"target_type": "revision", "target": b"\x01" * 20}},
    }
    producer.produce(
        topic=kafka_prefix + ".snapshot",
        key=key_to_kafka(snapshot["id"]),
        value=value_to_kafka(snapshot),
    )
    producer.flush()

    logger.debug("Flushed producer")

    result = invoke(
        "replay",
        "--stop-after-objects",
        "1",
        journal_config={
            "brokers": [kafka_server],
            "group_id": kafka_consumer_group,
            "prefix": kafka_prefix,
        },
    )

    # The dot is escaped so that only the literal "Done." line is accepted.
    expected = r"Done\.\n"
    assert result.exit_code == 0, result.output
    assert re.fullmatch(expected, result.output), result.output

    assert storage.snapshot_get(snapshot["id"]) == {**snapshot, "next_branch": None}


# --- swh/storage/tests/test_kafka_writer.py ---


def test_storage_direct_writer(kafka_prefix: str, kafka_server, consumer: Consumer):
    """Write every test object through a storage configured with a Kafka
    journal writer and check that all expected messages can be consumed."""
    kafka_prefix += ".swh.journal.objects"

    writer_config = {
        "cls": "kafka",
        "brokers": [kafka_server],
        "client_id": "kafka_writer",
        "prefix": kafka_prefix,
    }
    storage_config = {
        "cls": "pipeline",
        "steps": [{"cls": "memory", "journal_writer": writer_config}],
    }

    storage = get_storage(**storage_config)

    expected_messages = 0

    for object_type, objects in TEST_OBJECTS.items():
        method = getattr(storage, object_type + "_add")
        if object_type in (
            "content",
            "directory",
            "revision",
            "release",
            "snapshot",
            "origin",
        ):
            method(objects)
            expected_messages += len(objects)
        elif object_type in ("origin_visit",):
            for obj in objects:
                assert isinstance(obj, OriginVisit)
                storage.origin_add_one(Origin(url=obj.origin))
                visit = method(obj.origin, date=obj.date, type=obj.type)
                # one message for the visit creation...
                expected_messages += 1

                obj_d = obj.to_dict()
                for k in ("visit", "origin", "date", "type"):
                    del obj_d[k]
                storage.origin_visit_update(obj.origin, visit.visit, **obj_d)
                # ...and one more for its update
                expected_messages += 1
        else:
            assert False, object_type

    consumed_messages = consume_messages(consumer, kafka_prefix, expected_messages)
    assert_all_objects_consumed(consumed_messages)
storage_config = {"cls": "pipeline", "steps": [{"cls": "memory"}]}


def _random_key() -> bytes:
    """Return 40 random bytes, used as a throw-away Kafka message key."""
    return bytes(random.randint(0, 255) for _ in range(40))


def _produce_test_objects(producer, kafka_prefix: str) -> int:
    """Send every object of ``TEST_OBJECT_DICTS`` to its Kafka topic.

    Contents get a ``ctime`` and origin visits a sequential ``visit`` id
    before being sent. The producer is not flushed here.

    Returns:
        the number of messages produced
    """
    now = datetime.datetime.now(tz=datetime.timezone.utc)
    nb_sent = 0
    nb_visits = 0
    for object_type, objects in TEST_OBJECT_DICTS.items():
        topic = f"{kafka_prefix}.{object_type}"
        for object_ in objects:
            object_ = object_.copy()
            if object_type == "content":
                object_["ctime"] = now
            elif object_type == "origin_visit":
                nb_visits += 1
                object_["visit"] = nb_visits
            producer.produce(
                topic=topic,
                key=key_to_kafka(_random_key()),
                value=value_to_kafka(object_),
            )
            nb_sent += 1
    return nb_sent


def _replay_into_storage(
    storage,
    kafka_server: str,
    kafka_consumer_group: str,
    kafka_prefix: str,
    nb_sent: int,
) -> None:
    """Run a journal client until ``nb_sent`` objects went through
    ``process_replay_objects`` for ``storage``."""
    replayer = JournalClient(
        brokers=kafka_server,
        group_id=kafka_consumer_group,
        prefix=kafka_prefix,
        stop_after_objects=nb_sent,
    )
    worker_fn = functools.partial(process_replay_objects, storage=storage)
    nb_inserted = 0
    while nb_inserted < nb_sent:
        nb_inserted += replayer.process(worker_fn)
    assert nb_sent == nb_inserted


def _check_replayed_objects(storage) -> None:
    """Check that the objects of ``TEST_OBJECT_DICTS`` are in ``storage``."""
    # `import dateutil` alone does not guarantee the `parser` submodule is
    # loaded, so import it explicitly where it is used.
    import dateutil.parser

    assert TEST_OBJECT_DICTS["revision"] == list(
        storage.revision_get([rev["id"] for rev in TEST_OBJECT_DICTS["revision"]])
    )
    assert TEST_OBJECT_DICTS["release"] == list(
        storage.release_get([rel["id"] for rel in TEST_OBJECT_DICTS["release"]])
    )

    origins = list(storage.origin_get(list(TEST_OBJECT_DICTS["origin"])))
    assert TEST_OBJECT_DICTS["origin"] == [{"url": orig["url"]} for orig in origins]
    for origin in origins:
        origin_url = origin["url"]
        expected_visits = [
            {
                **visit,
                "origin": origin_url,
                "date": dateutil.parser.parse(visit["date"]),
            }
            for visit in TEST_OBJECT_DICTS["origin_visit"]
            if visit["origin"] == origin["url"]
        ]
        actual_visits = list(storage.origin_visit_get(origin_url))
        for visit in actual_visits:
            del visit["visit"]  # opaque identifier
        assert expected_visits == actual_visits

    input_contents = TEST_OBJECT_DICTS["content"]
    contents = storage.content_get_metadata([cont["sha1"] for cont in input_contents])
    assert len(contents) == len(input_contents)
    assert contents == {cont["sha1"]: [cont] for cont in input_contents}


def _logged_collisions(caplog) -> List[Dict]:
    """Return the ``collision`` argument of every captured
    "Collision detected:" log record."""
    return [
        record.args["collision"]
        for record in caplog.records
        if "Collision detected:" in record.getMessage()
    ]


def test_storage_play(
    kafka_prefix: str, kafka_consumer_group: str, kafka_server: str, caplog,
):
    """Optimal replayer scenario.

    This:
    - writes objects to the topic
    - replayer consumes objects from the topic and replays them

    """
    kafka_prefix += ".swh.journal.objects"

    storage = get_storage(**storage_config)

    producer = Producer(
        {
            "bootstrap.servers": kafka_server,
            "client.id": "test producer",
            "acks": "all",
        }
    )

    # Fill Kafka
    nb_sent = _produce_test_objects(producer, kafka_prefix)
    producer.flush()

    # Collisions are logged by the replayer module of this package.
    caplog.set_level(logging.ERROR, "swh.storage.replay")
    # Fill the storage from Kafka
    _replay_into_storage(
        storage, kafka_server, kafka_consumer_group, kafka_prefix, nb_sent
    )

    # Check the objects were actually inserted in the storage
    _check_replayed_objects(storage)

    assert _logged_collisions(caplog) == [], "No collision should be detected"


def test_storage_play_with_collision(
    kafka_prefix: str, kafka_consumer_group: str, kafka_server: str, caplog,
):
    """Another replayer scenario with collisions.

    This:
    - writes objects to the topic, including colliding contents
    - replayer consumes objects from the topic and replays them
    - This drops the colliding contents from the replay when detected

    """
    kafka_prefix += ".swh.journal.objects"

    storage = get_storage(**storage_config)

    producer = Producer(
        {
            "bootstrap.servers": kafka_server,
            "client.id": "test producer",
            "enable.idempotence": "true",
        }
    )

    # Fill Kafka
    nb_sent = _produce_test_objects(producer, kafka_prefix)

    # Create collision in input data
    # They are not written in the destination
    for content in DUPLICATE_CONTENTS:
        producer.produce(
            topic=f"{kafka_prefix}.content",
            key=key_to_kafka(_random_key()),
            value=value_to_kafka(content),
        )
        nb_sent += 1

    producer.flush()

    caplog.set_level(logging.ERROR, "swh.storage.replay")
    # Fill the storage from Kafka
    _replay_into_storage(
        storage, kafka_server, kafka_consumer_group, kafka_prefix, nb_sent
    )

    # Check the objects were actually inserted in the storage
    _check_replayed_objects(storage)

    collisions = _logged_collisions(caplog)
    assert len(collisions) == 1, "1 collision should be detected"
    actual_collision = collisions[0]

    algo = "sha1"
    assert actual_collision["algo"] == algo
    expected_colliding_hash = hash_to_hex(DUPLICATE_CONTENTS[0][algo])
    assert actual_collision["hash"] == expected_colliding_hash

    actual_colliding_hashes = actual_collision["objects"]
    assert len(actual_colliding_hashes) == len(DUPLICATE_CONTENTS)
    for content in DUPLICATE_CONTENTS:
        expected_content_hashes = {
            k: hash_to_hex(v) for k, v in Content.from_dict(content).hashes().items()
        }
        assert expected_content_hashes in actual_colliding_hashes


def _test_write_replay_origin_visit(visits: List[Dict]):
    """Helper function to write tests for origin_visit.

    Each visit (a dict) given in the 'visits' argument will be sent to
    a (mocked) kafka queue, which a in-memory-storage backed replayer is
    listening to.

    Check that corresponding origin visits entities are present in the storage
    and have correct values if they are not skipped.

    """
    queue: List = []
    replayer = MockedJournalClient(queue)
    writer = MockedKafkaWriter(queue)

    # Note that flipping the order of these two insertions will crash
    # the test, because the legacy origin_format does not allow to create
    # the origin when needed (type is missing)
    writer.send(
        "origin",
        "foo",
        {
            "url": "http://example.com/",
            "type": "git",  # test the legacy origin format is accepted
        },
    )
    for visit in visits:
        writer.send("origin_visit", "foo", visit)

    queue_size = len(queue)
    assert replayer.stop_after_objects is None
    replayer.stop_after_objects = queue_size

    storage = get_storage(**storage_config)
    worker_fn = functools.partial(process_replay_objects, storage=storage)

    replayer.process(worker_fn)

    actual_visits = list(storage.origin_visit_get("http://example.com/"))

    assert len(actual_visits) == len(visits), actual_visits

    for vin, vout in zip(visits, actual_visits):
        vin = vin.copy()
        vout = vout.copy()
        assert vout.pop("origin") == "http://example.com/"
        vin.pop("origin")
        # Legacy inputs may lack these keys; the replayer fills them in.
        vin.setdefault("type", "git")
        vin.setdefault("metadata", None)
        assert vin == vout


def test_write_replay_origin_visit():
    """Test origin_visit when the 'origin' is just a string."""
    now = datetime.datetime.now()
    visits = [
        {
            "visit": 1,
            "origin": "http://example.com/",
            "date": now,
            "type": "git",
            "status": "partial",
            "snapshot": None,
        }
    ]
    _test_write_replay_origin_visit(visits)


def test_write_replay_legacy_origin_visit1():
    """Origin_visit with no types should make the replayer crash

    We expect the journal's origin_visit topic to no longer reference such
    visits. If it does, the replayer must crash so we can fix the journal's
    topic.

    """
    now = datetime.datetime.now()
    visit = {
        "visit": 1,
        "origin": "http://example.com/",
        "date": now,
        "status": "partial",
        "snapshot": None,
    }
    now2 = datetime.datetime.now()
    visit2 = {
        "visit": 2,
        "origin": {"url": "http://example.com/"},
        "date": now2,
        "status": "partial",
        "snapshot": None,
    }

    for origin_visit in [visit, visit2]:
        with pytest.raises(ValueError, match="Old origin visit format"):
            _test_write_replay_origin_visit([origin_visit])


def test_write_replay_legacy_origin_visit2():
    """Test origin_visit when 'type' is missing from the visit, but not
    from the origin."""
    now = datetime.datetime.now()
    visits = [
        {
            "visit": 1,
            "origin": {"url": "http://example.com/", "type": "git"},
            "date": now,
            # no "type" here: it must be recovered from the origin dict
            "status": "partial",
            "snapshot": None,
        }
    ]
    _test_write_replay_origin_visit(visits)


def test_write_replay_legacy_origin_visit3():
    """Test origin_visit when the origin is a dict"""
    now = datetime.datetime.now()
    visits = [
        {
            "visit": 1,
            "origin": {"url": "http://example.com/"},
            "date": now,
            "type": "git",
            "status": "partial",
            "snapshot": None,
        }
    ]
    _test_write_replay_origin_visit(visits)
# In-memory storage whose journal writer is also in-memory; the test below
# patches the writer class so that messages land in a plain Python list.
storage_config = {
    "cls": "memory",
    "journal_writer": {"cls": "memory"},
}


def empty_person_name_email(rev_or_rel):
    """Empties the 'name' and 'email' fields of the author/committer fields
    of a revision or release; leaving only the fullname.

    Objects without a (truthy) ``author``/``committer`` attribute are
    returned unchanged; otherwise a copy built with ``attr.evolve`` is
    returned."""
    if getattr(rev_or_rel, "author", None):
        rev_or_rel = attr.evolve(
            rev_or_rel, author=attr.evolve(rev_or_rel.author, name=b"", email=b"",)
        )

    if getattr(rev_or_rel, "committer", None):
        rev_or_rel = attr.evolve(
            rev_or_rel,
            committer=attr.evolve(rev_or_rel.committer, name=b"", email=b"",),
        )

    return rev_or_rel


@given(lists(objects(blacklist_types=("origin_visit_status",)), min_size=1))
@settings(suppress_health_check=[HealthCheck.too_slow])
def test_write_replay_same_order_batches(objects):
    """Write generated objects to a first storage whose journal writer
    feeds a mocked queue, replay that queue into a second storage, and
    check that both storages end up with the same internal state."""
    queue = []
    replayer = MockedJournalClient(queue)

    # NOTE(review): the extent of this `with` block was reconstructed from
    # whitespace-collapsed text; it is assumed to cover only the storage
    # construction -- confirm against the upstream file.
    with patch(
        "swh.journal.writer.inmemory.InMemoryJournalWriter",
        return_value=MockedKafkaWriter(queue),
    ):
        storage1 = get_storage(**storage_config)

    # Write objects to storage1
    for (obj_type, obj) in objects:
        # Absent contents are stored through the skipped_content endpoint.
        if obj_type == "content" and obj.status == "absent":
            obj_type = "skipped_content"

        if obj_type == "origin_visit":
            # The origin is added before the visit that references it.
            storage1.origin_add_one(Origin(url=obj.origin))
            storage1.origin_visit_upsert([obj])
        else:
            method = getattr(storage1, obj_type + "_add")
            try:
                method([obj])
            except HashCollision:
                pass

    # Bail out early if we didn't insert any relevant objects...
    queue_size = len(queue)
    assert queue_size != 0, "No test objects found; hypothesis strategy bug?"

    assert replayer.stop_after_objects is None
    replayer.stop_after_objects = queue_size

    storage2 = get_storage(**storage_config)
    worker_fn = functools.partial(process_replay_objects, storage=storage2)

    replayer.process(worker_fn)

    assert replayer.consumer.committed

    # These private collections are compared as-is between both storages.
    for attr_name in (
        "_contents",
        "_directories",
        "_snapshots",
        "_origin_visits",
        "_origins",
    ):
        assert getattr(storage1, attr_name) == getattr(storage2, attr_name), attr_name

    # When hypothesis generates a revision and a release with same
    # author (or committer) fullname but different name or email, then
    # the storage will use the first name/email it sees.
    # This first one will be either the one from the revision or the release,
    # and since there is no order guarantees, storage2 has 1/2 chance of
    # not seeing the same order as storage1, therefore we need to strip
    # them out before comparing.
    for attr_name in ("_revisions", "_releases"):
        items1 = {
            k: empty_person_name_email(v)
            for (k, v) in getattr(storage1, attr_name).items()
        }
        items2 = {
            k: empty_person_name_email(v)
            for (k, v) in getattr(storage2, attr_name).items()
        }
        assert items1 == items2, attr_name