diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -90,10 +90,45 @@ class CassandraStorage:
-    def __init__(self, hosts, keyspace, objstorage, port=9042, journal_writer=None):
-        self._cql_runner: CqlRunner = CqlRunner(hosts, keyspace, port)
+    def __init__(
+        self,
+        hosts,
+        keyspace,
+        objstorage,
+        port=9042,
+        journal_writer=None,
+        allow_overwrite=False,
+    ):
+        """
+        A backend of swh-storage backed by Cassandra
+
+        Args:
+            hosts: Seed Cassandra nodes, to start connecting to the cluster
+            keyspace: Name of the Cassandra database to use
+            objstorage: Passed as argument to :class:`ObjStorage`
+            port: Cassandra port
+            journal_writer: Passed as argument to :class:`JournalWriter`
+            allow_overwrite: Whether ``*_add`` functions should skip checking
+                whether an object already exists before sending it in an INSERT.
+                ``False`` is the default, as it is more efficient when there is
+                a moderately high probability the object is already known;
+                but ``True`` can be useful to overwrite existing objects
+                (e.g. when applying a schema update),
+                or when the database is known to be mostly empty.
+                Note that a ``False`` value does not guarantee there won't be
+                any overwrite.
+        """
+        self._hosts = hosts
+        self._keyspace = keyspace
+        self._port = port
+        self._set_cql_runner()
         self.journal_writer: JournalWriter = JournalWriter(journal_writer)
         self.objstorage: ObjStorage = ObjStorage(objstorage)
+        self._allow_overwrite = allow_overwrite
+
+    def _set_cql_runner(self):
+        """Used by tests when they need to reset the CqlRunner"""
+        self._cql_runner: CqlRunner = CqlRunner(self._hosts, self._keyspace, self._port)
 
     def check_config(self, *, check_write: bool) -> bool:
         self._cql_runner.check_read()
@@ -120,9 +155,12 @@
     def _content_add(self, contents: List[Content], with_data: bool) -> Dict[str, int]:
         # Filter-out content already in the database.
-        contents = [
-            c for c in contents if not self._cql_runner.content_get_from_pk(c.to_dict())
-        ]
+        if not self._allow_overwrite:
+            contents = [
+                c
+                for c in contents
+                if not self._cql_runner.content_get_from_pk(c.to_dict())
+            ]
 
         if with_data:
             # First insert to the objstorage, if the endpoint is
@@ -151,27 +189,28 @@
             # The proper way to do it would probably be a BATCH, but this
             # would be inefficient because of the number of partitions we
             # need to affect (len(HASH_ALGORITHMS)+1, which is currently 5)
-            for algo in {"sha1", "sha1_git"}:
-                collisions = []
-                # Get tokens of 'content' rows with the same value for
-                # sha1/sha1_git
-                rows = self._content_get_from_hash(algo, content.get_hash(algo))
-                for row in rows:
-                    if getattr(row, algo) != content.get_hash(algo):
-                        # collision of token(partition key), ignore this
-                        # row
-                        continue
-
-                    for other_algo in HASH_ALGORITHMS:
-                        if getattr(row, other_algo) != content.get_hash(other_algo):
-                            # This hash didn't match; discard the row.
-                            collisions.append(
-                                {k: getattr(row, k) for k in HASH_ALGORITHMS}
-                            )
-
-                if collisions:
-                    collisions.append(content.hashes())
-                    raise HashCollision(algo, content.get_hash(algo), collisions)
+            if not self._allow_overwrite:
+                for algo in {"sha1", "sha1_git"}:
+                    collisions = []
+                    # Get tokens of 'content' rows with the same value for
+                    # sha1/sha1_git
+                    rows = self._content_get_from_hash(algo, content.get_hash(algo))
+                    for row in rows:
+                        if getattr(row, algo) != content.get_hash(algo):
+                            # collision of token(partition key), ignore this
+                            # row
+                            continue
+
+                        for other_algo in HASH_ALGORITHMS:
+                            if getattr(row, other_algo) != content.get_hash(other_algo):
+                                # This hash didn't match; discard the row.
+                                collisions.append(
+                                    {k: getattr(row, k) for k in HASH_ALGORITHMS}
+                                )
+
+                    if collisions:
+                        collisions.append(content.hashes())
+                        raise HashCollision(algo, content.get_hash(algo), collisions)
 
             (token, insertion_finalizer) = self._cql_runner.content_add_prepare(
                 ContentRow(**remove_keys(content.to_dict(), ("data",)))
@@ -324,11 +363,12 @@
     def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict[str, int]:
         # Filter-out content already in the database.
-        contents = [
-            c
-            for c in contents
-            if not self._cql_runner.skipped_content_get_from_pk(c.to_dict())
-        ]
+        if not self._allow_overwrite:
+            contents = [
+                c
+                for c in contents
+                if not self._cql_runner.skipped_content_get_from_pk(c.to_dict())
+            ]
 
         self.journal_writer.skipped_content_add(contents)
@@ -360,9 +400,10 @@
     def directory_add(self, directories: List[Directory]) -> Dict[str, int]:
         to_add = {d.id: d for d in directories}.values()
-        # Filter out directories that are already inserted.
-        missing = self.directory_missing([dir_.id for dir_ in to_add])
-        directories = [dir_ for dir_ in directories if dir_.id in missing]
+        if not self._allow_overwrite:
+            # Filter out directories that are already inserted.
+            missing = self.directory_missing([dir_.id for dir_ in to_add])
+            directories = [dir_ for dir_ in directories if dir_.id in missing]
 
         self.journal_writer.directory_add(directories)
@@ -485,9 +526,10 @@
     def revision_add(self, revisions: List[Revision]) -> Dict[str, int]:
         # Filter-out revisions already in the database
-        to_add = {r.id: r for r in revisions}.values()
-        missing = self.revision_missing([rev.id for rev in to_add])
-        revisions = [rev for rev in revisions if rev.id in missing]
+        if not self._allow_overwrite:
+            to_add = {r.id: r for r in revisions}.values()
+            missing = self.revision_missing([rev.id for rev in to_add])
+            revisions = [rev for rev in revisions if rev.id in missing]
 
         self.journal_writer.revision_add(revisions)
         for revision in revisions:
@@ -596,9 +638,10 @@
         return revision.id
 
     def release_add(self, releases: List[Release]) -> Dict[str, int]:
-        to_add = {r.id: r for r in releases}.values()
-        missing = set(self.release_missing([rel.id for rel in to_add]))
-        releases = [rel for rel in to_add if rel.id in missing]
+        if not self._allow_overwrite:
+            to_add = {r.id: r for r in releases}.values()
+            missing = set(self.release_missing([rel.id for rel in to_add]))
+            releases = [rel for rel in to_add if rel.id in missing]
 
         self.journal_writer.release_add(releases)
         for release in releases:
@@ -625,9 +668,10 @@
         return release.id
 
     def snapshot_add(self, snapshots: List[Snapshot]) -> Dict[str, int]:
-        to_add = {s.id: s for s in snapshots}.values()
-        missing = self._cql_runner.snapshot_missing([snp.id for snp in to_add])
-        snapshots = [snp for snp in snapshots if snp.id in missing]
+        if not self._allow_overwrite:
+            to_add = {s.id: s for s in snapshots}.values()
+            missing = self._cql_runner.snapshot_missing([snp.id for snp in to_add])
+            snapshots = [snp for snp in snapshots if snp.id in missing]
 
         for snapshot in snapshots:
             self.journal_writer.snapshot_add([snapshot])
@@ -896,8 +940,9 @@
         )
 
     def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
-        to_add = {o.url: o for o in origins}.values()
-        origins = [ori for ori in to_add if self.origin_get_one(ori.url) is None]
+        if not self._allow_overwrite:
+            to_add = {o.url: o for o in origins}.values()
+            origins = [ori for ori in to_add if self.origin_get_one(ori.url) is None]
 
         self.journal_writer.origin_add(origins)
         for origin in origins:
@@ -1358,13 +1403,16 @@
     # ExtID tables
 
     def extid_add(self, ids: List[ExtID]) -> Dict[str, int]:
-        extids = [
-            extid
-            for extid in ids
-            if not self._cql_runner.extid_get_from_pk(
-                extid_type=extid.extid_type, extid=extid.extid, target=extid.target,
-            )
-        ]
+        if not self._allow_overwrite:
+            extids = [
+                extid
+                for extid in ids
+                if not self._cql_runner.extid_get_from_pk(
+                    extid_type=extid.extid_type, extid=extid.extid, target=extid.target,
+                )
+            ]
+        else:
+            extids = list(ids)
 
         self.journal_writer.extid_add(extids)
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -678,6 +678,7 @@
     def __init__(self, journal_writer=None):
         self.reset()
         self.journal_writer = JournalWriter(journal_writer)
+        self._allow_overwrite = False
 
     def reset(self):
         self._cql_runner = InMemoryCqlRunner()
diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py
--- a/swh/storage/tests/test_cassandra.py
+++ b/swh/storage/tests/test_cassandra.py
@@ -4,26 +4,29 @@
 # See top-level LICENSE file for more information
 
 import datetime
+import itertools
 import os
 import signal
 import socket
 import subprocess
 import time
-from typing import Dict
+from typing import Any, Dict
 
 import attr
 import pytest
 
 from swh.core.api.classes import stream_results
+from swh.model.model import Directory, DirectoryEntry, Snapshot
 from swh.storage import get_storage
 from swh.storage.cassandra import create_keyspace
 from swh.storage.cassandra.model import ContentRow, ExtIDRow
 from swh.storage.cassandra.schema import HASH_ALGORITHMS, TABLES
+from swh.storage.tests.storage_data import StorageData
 from swh.storage.tests.storage_tests import (
     TestStorageGeneratedData as _TestStorageGeneratedData,
 )
 from swh.storage.tests.storage_tests import TestStorage as _TestStorage
-from swh.storage.utils import now
+from swh.storage.utils import now, remove_keys
 
 CONFIG_TEMPLATE = """
 data_file_directories:
@@ -448,3 +451,105 @@
     @pytest.mark.skip("Not supported by Cassandra")
     def test_origin_count_with_visit_with_visits_no_snapshot(self):
         pass
+
+
+@pytest.mark.parametrize(
+    "allow_overwrite,object_type",
+    itertools.product(
+        [False, True],
+        # Note the absence of "content", it's tested above.
+        ["directory", "revision", "release", "snapshot", "origin", "extid"],
+    ),
+)
+def test_allow_overwrite(
+    allow_overwrite: bool, object_type: str, swh_storage_backend_config
+):
+    if object_type in ("origin", "extid"):
+        pytest.skip(
+            f"test_allow_overwrite not implemented for {object_type} objects, "
+            f"because all their columns are in the primary key."
+        )
+    swh_storage = get_storage(
+        allow_overwrite=allow_overwrite, **swh_storage_backend_config
+    )
+
+    # directory_ls joins with content and directory table, and needs those to return
+    # non-None entries:
+    if object_type == "directory":
+        swh_storage.directory_add([StorageData.directory5])
+        swh_storage.content_add([StorageData.content, StorageData.content2])
+
+    obj1: Any
+    obj2: Any
+
+    # Get two test objects
+    if object_type == "directory":
+        (obj1, obj2, *_) = StorageData.directories
+    elif object_type == "snapshot":
+        # StorageData.snapshots[1] is the empty snapshot, which is the corner case
+        # that makes this test succeed for the wrong reasons
+        obj1 = StorageData.snapshot
+        obj2 = StorageData.complete_snapshot
+    else:
+        (obj1, obj2, *_) = getattr(StorageData, (object_type + "s"))
+
+    # Let's make both objects have the same hash, but different content
+    obj1 = attr.evolve(obj1, id=obj2.id)
+
+    # Get the methods used to add and get these objects
+    add = getattr(swh_storage, object_type + "_add")
+    if object_type == "directory":
+
+        def get(ids):
+            return [
+                Directory(
+                    id=ids[0],
+                    entries=tuple(
+                        map(
+                            lambda entry: DirectoryEntry(
+                                name=entry["name"],
+                                type=entry["type"],
+                                target=entry["sha1_git"],
+                                perms=entry["perms"],
+                            ),
+                            swh_storage.directory_ls(ids[0]),
+                        )
+                    ),
+                )
+            ]
+
+    elif object_type == "snapshot":
+
+        def get(ids):
+            return [
+                Snapshot.from_dict(
+                    remove_keys(swh_storage.snapshot_get(ids[0]), ("next_branch",))
+                )
+            ]
+
+    else:
+        get = getattr(swh_storage, object_type + "_get")
+
+    # Add the first object
+    add([obj1])
+
+    # It should be returned as-is
+    assert get([obj1.id]) == [obj1]
+
+    # Add the second object
+    add([obj2])
+
+    if allow_overwrite:
+        # obj1 was overwritten by obj2
+        expected = obj2
+    else:
+        # obj2 was not written, because obj1 already exists and has the same hash
+        expected = obj1
+
+    if allow_overwrite and object_type in ("directory", "snapshot"):
+        # TODO
+        pytest.xfail(
+            "directory entries and snapshot branches are concatenated "
+            "instead of being replaced"
+        )
+    assert get([obj1.id]) == [expected]
diff --git a/swh/storage/tests/test_cassandra_migration.py b/swh/storage/tests/test_cassandra_migration.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/tests/test_cassandra_migration.py
@@ -0,0 +1,163 @@
+# Copyright (C) 2021 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+"""This module tests the migration capabilities of the Cassandra backend,
+by sending CQL commands (e.g. 'ALTER TABLE'), and
+by monkey-patching large parts of the implementations to simulate code updates."""
+
+import dataclasses
+import functools
+import operator
+from typing import Dict, Iterable, Optional
+
+import attr
+import pytest
+
+from swh.model.model import Content
+from swh.storage import get_storage
+from swh.storage.cassandra.cql import (
+    CqlRunner,
+    _prepared_insert_statement,
+    _prepared_select_statement,
+)
+from swh.storage.cassandra.model import ContentRow
+from swh.storage.cassandra.schema import CONTENT_INDEX_TEMPLATE, HASH_ALGORITHMS
+from swh.storage.cassandra.storage import CassandraStorage
+from swh.storage.exc import StorageArgumentException
+
+from .storage_data import StorageData
+from .test_cassandra import (  # noqa, needed for swh_storage fixture
+    cassandra_cluster,
+    keyspace,
+    swh_storage_backend_config,
+)
+
+
+def byte_xor_hash(data):
+    # Behold, a one-line hash function:
+    return bytes([functools.reduce(operator.xor, data)])
+
+
+@attr.s
+class ContentWithXor(Content):
+    """A hypothetical upgrade of Content with an extra "hash"."""
+
+    byte_xor = attr.ib(type=bytes, default=None)
+
+
+@dataclasses.dataclass
+class ContentRowWithXor(ContentRow):
+    """A hypothetical upgrade of ContentRow with an extra "hash"."""
+
+    byte_xor: bytes
+
+
+class CqlRunnerWithXor(CqlRunner):
+    @_prepared_select_statement(
+        ContentRowWithXor,
+        f"WHERE {' AND '.join(map('%s = ?'.__mod__, HASH_ALGORITHMS))}",
+    )
+    def content_get_from_pk(
+        self, content_hashes: Dict[str, bytes], *, statement
+    ) -> Optional[ContentRow]:
+        rows = list(
+            self._execute_with_retries(
+                statement, [content_hashes[algo] for algo in HASH_ALGORITHMS]
+            )
+        )
+        assert len(rows) <= 1
+        if rows:
+            return ContentRowWithXor(**rows[0])
+        else:
+            return None
+
+    @_prepared_select_statement(
+        ContentRowWithXor, f"WHERE token({', '.join(ContentRow.PARTITION_KEY)}) = ?"
+    )
+    def content_get_from_token(self, token, *, statement) -> Iterable[ContentRow]:
+        return map(
+            ContentRowWithXor.from_dict, self._execute_with_retries(statement, [token])
+        )
+
+    # Redecorate content_add_prepare with the new ContentRow class
+    content_add_prepare = _prepared_insert_statement(ContentRowWithXor)(  # type: ignore
+        CqlRunner.content_add_prepare.__wrapped__  # type: ignore
+    )
+
+
+def test_add_content_column(
+    swh_storage: CassandraStorage, swh_storage_backend_config, mocker  # noqa
+) -> None:
+    """Adds a column to the 'content' table and a new matching index.
+    This is a simple migration, as it does not require an update to the primary key.
+    """
+    content_xor_hash = byte_xor_hash(StorageData.content.data)
+
+    # First insert some existing data
+    swh_storage.content_add([StorageData.content, StorageData.content2])
+
+    # Then update the schema
+    swh_storage._cql_runner._session.execute("ALTER TABLE content ADD byte_xor blob")
+    for statement in CONTENT_INDEX_TEMPLATE.split("\n\n"):
+        swh_storage._cql_runner._session.execute(statement.format(main_algo="byte_xor"))
+
+    # Should not affect the running code at all:
+    assert swh_storage.content_get([StorageData.content.sha1]) == [
+        attr.evolve(StorageData.content, data=None)
+    ]
+    with pytest.raises(StorageArgumentException):
+        swh_storage.content_find({"byte_xor": content_xor_hash})
+
+    # Then update the running code:
+    new_hash_algos = HASH_ALGORITHMS + ["byte_xor"]
+    mocker.patch("swh.storage.cassandra.storage.HASH_ALGORITHMS", new_hash_algos)
+    mocker.patch("swh.storage.cassandra.cql.HASH_ALGORITHMS", new_hash_algos)
+    mocker.patch("swh.model.model.DEFAULT_ALGORITHMS", new_hash_algos)
+    mocker.patch("swh.storage.cassandra.storage.Content", ContentWithXor)
+    mocker.patch("swh.storage.cassandra.storage.ContentRow", ContentRowWithXor)
+    mocker.patch("swh.storage.cassandra.model.ContentRow", ContentRowWithXor)
+    mocker.patch("swh.storage.cassandra.storage.CqlRunner", CqlRunnerWithXor)
+
+    # Forge new objects with this extra hash:
+    new_content = ContentWithXor.from_dict(
+        {
+            "byte_xor": byte_xor_hash(StorageData.content.data),
+            **StorageData.content.to_dict(),
+        }
+    )
+    new_content2 = ContentWithXor.from_dict(
+        {
+            "byte_xor": byte_xor_hash(StorageData.content2.data),
+            **StorageData.content2.to_dict(),
+        }
+    )
+
+    # Simulates a restart:
+    swh_storage._set_cql_runner()
+
+    # Old algos still work, and return the new object type:
+    assert swh_storage.content_get([StorageData.content.sha1]) == [
+        attr.evolve(new_content, data=None, byte_xor=None)
+    ]
+
+    # The new algo does not work, we did not backfill it yet:
+    assert swh_storage.content_find({"byte_xor": content_xor_hash}) == []
+
+    # A normal storage would not overwrite, because the object already exists,
+    # as it is not aware it is missing a field:
+    swh_storage.content_add([new_content, new_content2])
+    assert swh_storage.content_find({"byte_xor": content_xor_hash}) == []
+
+    # Backfill (in production this would be done with a replayer reading from
+    # the journal):
+    overwriting_swh_storage = get_storage(
+        allow_overwrite=True, **swh_storage_backend_config
+    )
+    overwriting_swh_storage.content_add([new_content, new_content2])
+
+    # Now, the object can be found:
+    assert swh_storage.content_find({"byte_xor": content_xor_hash}) == [
+        attr.evolve(new_content, data=None)
+    ]