diff --git a/requirements-swh-journal.txt b/requirements-swh-journal.txt --- a/requirements-swh-journal.txt +++ b/requirements-swh-journal.txt @@ -1 +1 @@ -swh.journal >= 0.6.2 +swh.journal >= 0.9 diff --git a/requirements-test.txt b/requirements-test.txt --- a/requirements-test.txt +++ b/requirements-test.txt @@ -6,8 +6,10 @@ # adding the [testing] extra. swh.model[testing] >= 0.0.50 pytz +pytest-redis pytest-xdist types-python-dateutil types-pytz types-pyyaml +types-redis types-requests diff --git a/requirements.txt b/requirements.txt --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,11 @@ +aiohttp +cassandra-driver >= 3.19.0, != 3.21.0 click +deprecated flask +iso8601 +mypy_extensions psycopg2 -aiohttp +redis tenacity -cassandra-driver >= 3.19.0, != 3.21.0 -deprecated typing-extensions -mypy_extensions -iso8601 diff --git a/swh/storage/backfill.py b/swh/storage/backfill.py --- a/swh/storage/backfill.py +++ b/swh/storage/backfill.py @@ -38,7 +38,7 @@ db_to_release, db_to_revision, ) -from swh.storage.replay import object_converter_fn +from swh.storage.replay import OBJECT_CONVERTERS from swh.storage.writer import JournalWriter logger = logging.getLogger(__name__) @@ -541,7 +541,7 @@ if converter: record = converter(db, record) else: - record = object_converter_fn[obj_type](record) + record = OBJECT_CONVERTERS[obj_type](record) logger.debug("record: %s", record) yield record diff --git a/swh/storage/cli.py b/swh/storage/cli.py --- a/swh/storage/cli.py +++ b/swh/storage/cli.py @@ -13,6 +13,8 @@ from swh.core.cli import CONTEXT_SETTINGS from swh.core.cli import swh as swh_cli_group +from swh.storage.replay import ModelObjectDeserializer + try: from systemd.daemon import notify @@ -204,11 +206,29 @@ conf = ctx.obj["config"] storage = get_storage(**conf.pop("storage")) + if "error_reporter" in conf: + from redis import Redis + + reporter = Redis(**conf["error_reporter"]).set + else: + reporter = None + validate = conf.get("privileged", False) + + if not validate and reporter: + ctx.fail( + "Invalid configuration: you cannot have 'error_reporter' set if " + "'privileged' is False; we cannot validate anonymized objects." 
+ ) + + deserializer = ModelObjectDeserializer(reporter=reporter, validate=validate) + client_cfg = conf.pop("journal_client") + client_cfg["value_deserializer"] = deserializer.convert if object_types: client_cfg["object_types"] = object_types if stop_after_objects: client_cfg["stop_after_objects"] = stop_after_objects + try: client = get_journal_client(**client_cfg) except ValueError as exc: diff --git a/swh/storage/replay.py b/swh/storage/replay.py --- a/swh/storage/replay.py +++ b/swh/storage/replay.py @@ -3,8 +3,12 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from collections import Counter +from functools import partial import logging -from typing import Any, Callable, Container, Dict, List +from typing import Any, Callable +from typing import Counter as CounterT +from typing import Dict, List, Optional, TypeVar, Union, cast try: from systemd.daemon import notify @@ -12,12 +16,15 @@ notify = None from swh.core.statsd import statsd +from swh.journal.serializers import kafka_to_value +from swh.model.hashutil import hash_to_hex from swh.model.model import ( BaseContent, BaseModel, Content, Directory, ExtID, + HashableObject, MetadataAuthority, MetadataFetcher, Origin, @@ -29,9 +36,9 @@ SkippedContent, Snapshot, ) -from swh.storage.exc import HashCollision -from swh.storage.fixer import fix_objects +from swh.storage.exc import HashCollision, StorageArgumentException from swh.storage.interface import StorageInterface +from swh.storage.utils import remove_keys logger = logging.getLogger(__name__) @@ -39,7 +46,7 @@ GRAPH_DURATION_METRIC = "swh_graph_replayer_duration_seconds" -object_converter_fn: Dict[str, Callable[[Dict], BaseModel]] = { +OBJECT_CONVERTERS: Dict[str, Callable[[Dict], BaseModel]] = { "origin": Origin.from_dict, "origin_visit": OriginVisit.from_dict, "origin_visit_status": OriginVisitStatus.from_dict, @@ -54,10 +61,91 @@ "raw_extrinsic_metadata": RawExtrinsicMetadata.from_dict, "extid": ExtID.from_dict, } +# Deprecated, for BW compat only. +object_converter_fn = OBJECT_CONVERTERS + + +OBJECT_FIXERS = { + "revision": partial(remove_keys, keys=("metadata",)), +} + + +class ModelObjectDeserializer: + + """A swh.journal object deserializer that checks object validity and reports + invalid objects + + The deserializer will directly produce BaseModel objects from journal + objects representations. + + If validation is activated and the object is hashable, it will check if the + computed hash matches the identifier of the object. + + If the object is invalid and a 'reporter' function is given, it will be + called with 2 arguments:: + + reporter(object_id, journal_msg) + + Where 'object_id' is a string representation of the object identifier (from + the journal message), and 'journal_msg' is the row message (bytes) + retrieved from the journal. + + If 'raise_on_error' is True, a 'StorageArgumentException' exception is + raised. 
+ + Typical usage:: + + deserializer = ModelObjectDeserializer(validate=True, reporter=reporter_cb) + client = get_journal_client( + cls="kafka", value_deserializer=deserializer, **cfg) + + """ + + def __init__( + self, + validate: bool = True, + raise_on_error: bool = False, + reporter: Optional[Callable[[str, bytes], None]] = None, + ): + self.validate = validate + self.reporter = reporter + self.raise_on_error = raise_on_error + + def convert(self, object_type: str, msg: bytes) -> Optional[BaseModel]: + dict_repr = kafka_to_value(msg) + if object_type in OBJECT_FIXERS: + dict_repr = OBJECT_FIXERS[object_type](dict_repr) + obj = OBJECT_CONVERTERS[object_type](dict_repr) + if self.validate: + if isinstance(obj, HashableObject): + cid = obj.compute_hash() + if obj.id != cid: + error_msg = ( + f"Object has id {hash_to_hex(obj.id)}, " + f"but it should be {hash_to_hex(cid)}: {obj}" + ) + logger.error(error_msg) + self.report_failure(msg, obj) + if self.raise_on_error: + raise StorageArgumentException(error_msg) + return None + return obj + + def report_failure(self, msg: bytes, obj: BaseModel): + if self.reporter: + oid: str = "" + if hasattr(obj, "swhid"): + swhid = obj.swhid() # type: ignore[attr-defined] + oid = str(swhid) + elif isinstance(obj, HashableObject): + uid = obj.compute_hash() + oid = f"{obj.object_type}:{uid.hex()}" # type: ignore[attr-defined] + if oid: + self.reporter(oid, msg) def process_replay_objects( - all_objects: Dict[str, List[Dict[str, Any]]], *, storage: StorageInterface + all_objects: Dict[str, List[BaseModel]], *, storage: StorageInterface ) -> None: for (object_type, objects) in all_objects.items(): logger.debug("Inserting %s %s objects", len(objects), object_type) @@ -70,9 +158,13 @@ notify("WATCHDOG=1") +ContentType = TypeVar("ContentType", bound=BaseContent) + + def collision_aware_content_add( - content_add_fn: Callable[[List[Any]], Dict[str, int]], contents: List[BaseContent], -) -> None: + contents: List[ContentType], + content_add_fn: Callable[[List[ContentType]], Dict[str, int]], +) -> Dict[str, int]: """Add contents to storage. If a hash collision is detected, an error is logged. Then this adds the other non colliding contents to the storage. @@ -82,11 +174,12 @@ """ if not contents: - return + return {} colliding_content_hashes: List[Dict[str, Any]] = [] + results: CounterT[str] = Counter() while True: try: - content_add_fn(contents) + results.update(content_add_fn(contents)) except HashCollision as e: colliding_content_hashes.append( { @@ -104,81 +197,35 @@ if colliding_content_hashes: for collision in colliding_content_hashes: logger.error("Collision detected: %(collision)s", {"collision": collision}) - - -def dict_key_dropper(d: Dict, keys_to_drop: Container) -> Dict: - """Returns a copy of the dict d without any key listed in keys_to_drop""" - return {k: v for (k, v) in d.items() if k not in keys_to_drop} + return dict(results) def _insert_objects( - object_type: str, objects: List[Dict], storage: StorageInterface + object_type: str, objects: List[BaseModel], storage: StorageInterface ) -> None: """Insert objects of type object_type in the storage. 
""" - objects = fix_objects(object_type, objects) - - if object_type == "content": - # for bw compat, skipped content should now be delivered in the skipped_content - # topic - contents: List[BaseContent] = [] - skipped_contents: List[BaseContent] = [] - for content in objects: - c = BaseContent.from_dict(content) - if isinstance(c, SkippedContent): - logger.warning( - "Received a series of skipped_content in the " - "content topic, this should not happen anymore" - ) - skipped_contents.append(c) - else: - contents.append(c) - collision_aware_content_add(storage.skipped_content_add, skipped_contents) - collision_aware_content_add(storage.content_add_metadata, contents) - elif object_type == "skipped_content": - skipped_contents = [SkippedContent.from_dict(obj) for obj in objects] - collision_aware_content_add(storage.skipped_content_add, skipped_contents) + if object_type not in OBJECT_CONVERTERS: + logger.warning("Received a series of %s, this should not happen", object_type) + return + + method = getattr(storage, f"{object_type}_add") + if object_type == "skipped_content": + method = partial(collision_aware_content_add, content_add_fn=method) + elif object_type == "content": + method = partial( + collision_aware_content_add, content_add_fn=storage.content_add_metadata + ) elif object_type in ("origin_visit", "origin_visit_status"): origins: List[Origin] = [] - converter_fn = object_converter_fn[object_type] - model_objs = [] - for obj in objects: - origins.append(Origin(url=obj["origin"])) - model_objs.append(converter_fn(obj)) + for obj in cast(List[Union[OriginVisit, OriginVisitStatus]], objects): + origins.append(Origin(url=obj.origin)) storage.origin_add(origins) - method = getattr(storage, f"{object_type}_add") - method(model_objs) elif object_type == "raw_extrinsic_metadata": - converted = [RawExtrinsicMetadata.from_dict(o) for o in objects] - authorities = {emd.authority for emd in converted} - fetchers = {emd.fetcher for emd in converted} + emds = cast(List[RawExtrinsicMetadata], objects) + authorities = {emd.authority for emd in emds} + fetchers = {emd.fetcher for emd in emds} storage.metadata_authority_add(list(authorities)) storage.metadata_fetcher_add(list(fetchers)) - storage.raw_extrinsic_metadata_add(converted) - elif object_type == "revision": - # drop the metadata field from the revision (is any); this field is - # about to be dropped from the data model (in favor of - # raw_extrinsic_metadata) and there can be bogus values in the existing - # journal (metadata with \0000 in it) - method = getattr(storage, object_type + "_add") - method( - [ - object_converter_fn[object_type](dict_key_dropper(o, ("metadata",))) - for o in objects - ] - ) - elif object_type in ( - "directory", - "extid", - "revision", - "release", - "snapshot", - "origin", - "metadata_fetcher", - "metadata_authority", - ): - method = getattr(storage, object_type + "_add") - method([object_converter_fn[object_type](o) for o in objects]) - else: - logger.warning("Received a series of %s, this should not happen", object_type) + method(objects) diff --git a/swh/storage/tests/test_backfill.py b/swh/storage/tests/test_backfill.py --- a/swh/storage/tests/test_backfill.py +++ b/swh/storage/tests/test_backfill.py @@ -20,7 +20,7 @@ raw_extrinsic_metadata_target_ranges, ) from swh.storage.in_memory import InMemoryStorage -from swh.storage.replay import process_replay_objects +from swh.storage.replay import ModelObjectDeserializer, process_replay_objects from swh.storage.tests.test_replay import check_replayed 
TEST_CONFIG = { @@ -239,7 +239,6 @@ } swh_storage_backend_config["journal_writer"] = journal1 storage = get_storage(**swh_storage_backend_config) - # fill the storage and the journal (under prefix1) for object_type, objects in TEST_OBJECTS.items(): method = getattr(storage, object_type + "_add") @@ -266,13 +265,16 @@ # now check journal content are the same under both topics # use the replayer scaffolding to fill storages to make is a bit easier # Replaying #1 + deserializer = ModelObjectDeserializer() sto1 = get_storage(cls="memory") replayer1 = JournalClient( brokers=kafka_server, group_id=f"{kafka_consumer_group}-1", prefix=prefix1, stop_on_eof=True, + value_deserializer=deserializer.convert, ) + worker_fn1 = functools.partial(process_replay_objects, storage=sto1) replayer1.process(worker_fn1) @@ -283,6 +285,7 @@ group_id=f"{kafka_consumer_group}-2", prefix=prefix2, stop_on_eof=True, + value_deserializer=deserializer.convert, ) worker_fn2 = functools.partial(process_replay_objects, storage=sto2) replayer2.process(worker_fn2) diff --git a/swh/storage/tests/test_cli.py b/swh/storage/tests/test_cli.py --- a/swh/storage/tests/test_cli.py +++ b/swh/storage/tests/test_cli.py @@ -18,7 +18,7 @@ from swh.model.model import Snapshot, SnapshotBranch, TargetType from swh.storage import get_storage from swh.storage.cli import storage as cli -from swh.storage.replay import object_converter_fn +from swh.storage.replay import OBJECT_CONVERTERS logger = logging.getLogger(__name__) @@ -119,7 +119,7 @@ assert len(types_in_help) == 1 types = types_in_help[0].split("|") - assert sorted(types) == sorted(list(object_converter_fn.keys())), ( + assert sorted(types) == sorted(list(OBJECT_CONVERTERS.keys())), ( "Make sure the list of accepted types in cli.py " "matches implementation in replay.py" ) diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py --- a/swh/storage/tests/test_replay.py +++ b/swh/storage/tests/test_replay.py @@ -7,13 +7,14 @@ import datetime import functools import logging -from typing import Any, Container, Dict, Optional +import re +from typing import Any, Container, Dict, Optional, cast import attr import pytest from swh.journal.client import JournalClient -from swh.journal.serializers import key_to_kafka, value_to_kafka +from swh.journal.serializers import kafka_to_value, key_to_kafka, value_to_kafka from swh.model.hashutil import DEFAULT_ALGORITHMS, MultiHash, hash_to_bytes, hash_to_hex from swh.model.model import Revision, RevisionType from swh.model.tests.swh_model_data import ( @@ -25,15 +26,17 @@ from swh.model.tests.swh_model_data import TEST_OBJECTS as _TEST_OBJECTS from swh.storage import get_storage from swh.storage.cassandra.model import ContentRow, SkippedContentRow +from swh.storage.exc import StorageArgumentException from swh.storage.in_memory import InMemoryStorage -from swh.storage.replay import process_replay_objects +from swh.storage.replay import ModelObjectDeserializer, process_replay_objects UTC = datetime.timezone.utc TEST_OBJECTS = _TEST_OBJECTS.copy() +# add a revision with metadata to check this later is dropped while being replayed TEST_OBJECTS["revision"] = list(_TEST_OBJECTS["revision"]) + [ Revision( - id=hash_to_bytes("a569b03ebe6e5f9f2f6077355c40d89bd6986d0c"), + id=hash_to_bytes("51d9d94ab08d3f75512e3a9fd15132e0a7ca7928"), message=b"hello again", date=DATES[1], committer=COMMITTERS[1], @@ -46,6 +49,9 @@ parents=(REVISIONS[0].id,), ), ] +WRONG_ID_REG = re.compile( + "Object has id [0-9a-f]{40}, but it should be [0-9a-f]{40}: .*" +) 
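
WRONG_ID_REG matches the error that `ModelObjectDeserializer.convert` logs when the recomputed intrinsic hash disagrees with the object's id. A short illustrative sketch of that path (the dict-backed lambda reporter is an assumption for the example; the tests below use a Redis instance instead):

```python
# Illustrative only: what the validation path does with a journal message whose
# id does not match its recomputed intrinsic hash (cf. the tests further below).
import attr

from swh.journal.serializers import value_to_kafka
from swh.storage.replay import ModelObjectDeserializer

bad_revision = attr.evolve(TEST_OBJECTS["revision"][0], id=b"\x00" * 20)

reported = {}
deserializer = ModelObjectDeserializer(
    validate=True, reporter=lambda oid, msg: reported.update({oid: msg})
)
# convert() logs an error matching WRONG_ID_REG, hands the object's SWHID and the
# raw journal message to the reporter, and drops the object by returning None.
assert deserializer.convert("revision", value_to_kafka(bad_revision.to_dict())) is None
assert list(reported) == [f"swh:1:rev:{'0' * 40}"]
```
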
def nullify_ctime(obj): @@ -70,11 +76,13 @@ "journal_writer": journal_writer_config, } storage = get_storage(**storage_config) + deserializer = ModelObjectDeserializer() replayer = JournalClient( brokers=kafka_server, group_id=kafka_consumer_group, prefix=kafka_prefix, stop_on_eof=True, + value_deserializer=deserializer.convert, ) yield storage, replayer @@ -122,7 +130,7 @@ assert collision == 0, "No collision should be detected" -def test_storage_play_with_collision(replayer_storage_and_client, caplog): +def test_storage_replay_with_collision(replayer_storage_and_client, caplog): """Another replayer scenario with collisions. This: @@ -207,12 +215,6 @@ _check_replay_skipped_content(src, replayer, "skipped_content") -def test_replay_skipped_content_bwcompat(replayer_storage_and_client): - """Test the 'content' topic can be used to replay SkippedContent objects.""" - src, replayer = replayer_storage_and_client - _check_replay_skipped_content(src, replayer, "content") - - # utility functions @@ -321,7 +323,7 @@ @pytest.mark.parametrize("privileged", [True, False]) -def test_storage_play_anonymized( +def test_storage_replay_anonymized( kafka_prefix: str, kafka_consumer_group: str, kafka_server: str, privileged: bool, ): """Optimal replayer scenario. @@ -355,12 +357,16 @@ # Fill a destination storage from Kafka, potentially using privileged topics dst_storage = get_storage(cls="memory") + deserializer = ModelObjectDeserializer( + validate=False + ) # we cannot validate an anonymized replay replayer = JournalClient( brokers=kafka_server, group_id=kafka_consumer_group, prefix=kafka_prefix, stop_after_objects=nb_sent, privileged=privileged, + value_deserializer=deserializer.convert, ) worker_fn = functools.partial(process_replay_objects, storage=dst_storage) @@ -373,3 +379,166 @@ assert isinstance(storage, InMemoryStorage) # needed to help mypy assert isinstance(dst_storage, InMemoryStorage) check_replayed(storage, dst_storage, expected_anonymized=not privileged) + + +def test_storage_replayer_with_validation_ok( + replayer_storage_and_client, caplog, redisdb +): + """Optimal replayer scenario + + with validation activated and reporter set to a redis db. 
+
+    - writes objects to a source storage
+    - replayer consumes objects from the topic and replays them
+    - a destination storage is filled from this
+    - nothing has been reported in the redis db
+    - both storages should have the same content
+    """
+    src, replayer = replayer_storage_and_client
+    replayer.value_deserializer = ModelObjectDeserializer(
+        validate=True, reporter=redisdb.set
+    ).convert
+
+    # Fill Kafka using a source storage
+    nb_sent = 0
+    for object_type, objects in TEST_OBJECTS.items():
+        method = getattr(src, object_type + "_add")
+        method(objects)
+        if object_type == "origin_visit":
+            nb_sent += len(objects)  # origin-visit-add adds origin-visit-status as well
+        nb_sent += len(objects)
+
+    # Fill the destination storage from Kafka
+    dst = get_storage(cls="memory")
+    worker_fn = functools.partial(process_replay_objects, storage=dst)
+    nb_inserted = replayer.process(worker_fn)
+    assert nb_sent == nb_inserted
+
+    # check we do not have invalid objects reported
+    invalid = 0
+    for record in caplog.records:
+        logtext = record.getMessage()
+        if WRONG_ID_REG.match(logtext):
+            invalid += 1
+    assert invalid == 0, "Invalid objects should not be detected"
+    assert not redisdb.keys()
+    # so the dst should be the same as src storage
+    check_replayed(cast(InMemoryStorage, src), cast(InMemoryStorage, dst))
+
+
+def test_storage_replayer_with_validation_nok(
+    replayer_storage_and_client, caplog, redisdb
+):
+    """Replayer scenario with invalid objects
+
+    with validation and reporter set to a redis db.
+
+    - writes objects to a source storage
+    - replayer consumes objects from the topic and replays them
+    - the destination storage is filled with only valid objects
+    - the redis db contains the invalid (raw Kafka message) objects
+    """
+    src, replayer = replayer_storage_and_client
+    replayer.value_deserializer = ModelObjectDeserializer(
+        validate=True, reporter=redisdb.set
+    ).convert
+
+    caplog.set_level(logging.ERROR, "swh.journal.replay")
+
+    # Fill Kafka using a source storage
+    nb_sent = 0
+    for object_type, objects in TEST_OBJECTS.items():
+        method = getattr(src, object_type + "_add")
+        method(objects)
+        if object_type == "origin_visit":
+            nb_sent += len(objects)  # origin-visit-add adds origin-visit-status as well
+        nb_sent += len(objects)
+
+    # insert invalid objects
+    for object_type in ("revision", "directory", "release", "snapshot"):
+        method = getattr(src, object_type + "_add")
+        method([attr.evolve(TEST_OBJECTS[object_type][0], id=b"\x00" * 20)])
+        nb_sent += 1
+
+    # Fill the destination storage from Kafka
+    dst = get_storage(cls="memory")
+    worker_fn = functools.partial(process_replay_objects, storage=dst)
+    nb_inserted = replayer.process(worker_fn)
+    assert nb_sent == nb_inserted
+
+    # check we do have invalid objects reported
+    invalid = 0
+    for record in caplog.records:
+        logtext = record.getMessage()
+        if WRONG_ID_REG.match(logtext):
+            invalid += 1
+    assert invalid == 4, "Invalid objects should be detected"
+    assert set(redisdb.keys()) == {
+        f"swh:1:{typ}:{'0'*40}".encode() for typ in ("rel", "rev", "snp", "dir")
+    }
+
+    for key in redisdb.keys():
+        # check the stored value looks right
+        rawvalue = redisdb.get(key)
+        value = kafka_to_value(rawvalue)
+        assert isinstance(value, dict)
+        assert "id" in value
+        assert value["id"] == b"\x00" * 20
+
+    # check that invalid objects did not reach the dst storage
+    for attr_ in (
+        "directories",
+        "revisions",
+        "releases",
+        "snapshots",
+    ):
+        for id, obj in sorted(getattr(dst._cql_runner, f"_{attr_}").iter_all()):
+            assert id != b"\x00" * 20
+
+
+def test_storage_replayer_with_validation_nok_raises(
+    replayer_storage_and_client, caplog, redisdb
+):
+    """Replayer scenario with invalid objects
+
+    with raise_on_error set to True
+
+    This:
+    - writes both valid & invalid objects to a source storage
+    - a StorageArgumentException should be raised while the replayer consumes
+      objects from the topic and replays them
+    """
+    src, replayer = replayer_storage_and_client
+    replayer.value_deserializer = ModelObjectDeserializer(
+        validate=True, reporter=redisdb.set, raise_on_error=True
+    ).convert
+
+    caplog.set_level(logging.ERROR, "swh.journal.replay")
+
+    # Fill Kafka using a source storage
+    nb_sent = 0
+    for object_type, objects in TEST_OBJECTS.items():
+        method = getattr(src, object_type + "_add")
+        method(objects)
+        if object_type == "origin_visit":
+            nb_sent += len(objects)  # origin-visit-add adds origin-visit-status as well
+        nb_sent += len(objects)
+
+    # insert invalid objects
+    for object_type in ("revision", "directory", "release", "snapshot"):
+        method = getattr(src, object_type + "_add")
+        method([attr.evolve(TEST_OBJECTS[object_type][0], id=b"\x00" * 20)])
+        nb_sent += 1
+
+    # Fill the destination storage from Kafka
+    dst = get_storage(cls="memory")
+    worker_fn = functools.partial(process_replay_objects, storage=dst)
+    with pytest.raises(StorageArgumentException):
+        replayer.process(worker_fn)
+
+    # check we do have invalid objects reported
+    invalid = 0
+    for record in caplog.records:
+        logtext = record.getMessage()
+        if WRONG_ID_REG.match(logtext):
+            invalid += 1
+    assert invalid == 1, "One invalid object should be detected"
+    assert len(redisdb.keys()) == 1
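
For reference, a sketch of the configuration shape the reworked `swh storage replay` command above reads, written as the dict obtained after loading the YAML file; host names, ports and the consumer group are illustrative assumptions, only the key layout comes from cli.py:

```python
# Sketch of the config consumed by the new cli.py code. Concrete values below
# are examples, not documented defaults.
replay_config = {
    "storage": {"cls": "memory"},
    "journal_client": {
        "cls": "kafka",
        "brokers": ["kafka1.example.org:9092"],
        "group_id": "swh-storage-replayer",
        "prefix": "swh.journal.objects",
    },
    # `privileged` doubles as the validation switch:
    # validate = conf.get("privileged", False)
    "privileged": True,
    # optional; kwargs are passed to redis.Redis(), and invalid objects are
    # reported there as (object_id, raw_journal_message) pairs
    "error_reporter": {"host": "redis.example.org", "port": 6379},
}
```

Setting `error_reporter` without `privileged: true` is rejected by the CLI, since anonymized objects cannot be validated.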