D5582.id20144.diff

diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -89,10 +89,45 @@
class CassandraStorage:
- def __init__(self, hosts, keyspace, objstorage, port=9042, journal_writer=None):
- self._cql_runner: CqlRunner = CqlRunner(hosts, keyspace, port)
+ def __init__(
+ self,
+ hosts,
+ keyspace,
+ objstorage,
+ port=9042,
+ journal_writer=None,
+ allow_overwrite=False,
+ ):
+ """
+ A backend of swh-storage backed by Cassandra
+
+ Args:
+ hosts: Seed Cassandra nodes, to start connecting to the cluster
+ keyspace: Name of the Cassandra database to use
+ objstorage: Passed as argument to :class:`ObjStorage`
+ port: Cassandra port
+ journal_writer: Passed as argument to :class:`JournalWriter`
+ allow_overwrite: Whether ``*_add`` functions should skip the check for
+ objects already in the database, and send INSERTs unconditionally.
+ ``False`` is the default, as it is more efficient when there is
+ a moderately high probability that the object is already known,
+ but ``True`` can be useful to overwrite existing objects
+ (e.g. when applying a schema update),
+ or when the database is known to be mostly empty.
+ Note that a ``False`` value does not guarantee there won't be
+ any overwrite.
+ """
+ self._hosts = hosts
+ self._keyspace = keyspace
+ self._port = port
+ self._set_cql_runner()
self.journal_writer: JournalWriter = JournalWriter(journal_writer)
self.objstorage: ObjStorage = ObjStorage(objstorage)
+ self._allow_overwrite = allow_overwrite
+
+ def _set_cql_runner(self):
+ """Used by tests when they need to reset the CqlRunner"""
+ self._cql_runner: CqlRunner = CqlRunner(self._hosts, self._keyspace, self._port)
def check_config(self, *, check_write: bool) -> bool:
self._cql_runner.check_read()
@@ -119,9 +154,12 @@
def _content_add(self, contents: List[Content], with_data: bool) -> Dict:
# Filter-out content already in the database.
- contents = [
- c for c in contents if not self._cql_runner.content_get_from_pk(c.to_dict())
- ]
+ if not self._allow_overwrite:
+ contents = [
+ c
+ for c in contents
+ if not self._cql_runner.content_get_from_pk(c.to_dict())
+ ]
if with_data:
# First insert to the objstorage, if the endpoint is
@@ -150,27 +188,28 @@
# The proper way to do it would probably be a BATCH, but this
# would be inefficient because of the number of partitions we
# need to affect (len(HASH_ALGORITHMS)+1, which is currently 5)
- for algo in {"sha1", "sha1_git"}:
- collisions = []
- # Get tokens of 'content' rows with the same value for
- # sha1/sha1_git
- rows = self._content_get_from_hash(algo, content.get_hash(algo))
- for row in rows:
- if getattr(row, algo) != content.get_hash(algo):
- # collision of token(partition key), ignore this
- # row
- continue
-
- for other_algo in HASH_ALGORITHMS:
- if getattr(row, other_algo) != content.get_hash(other_algo):
- # This hash didn't match; discard the row.
- collisions.append(
- {k: getattr(row, k) for k in HASH_ALGORITHMS}
- )
-
- if collisions:
- collisions.append(content.hashes())
- raise HashCollision(algo, content.get_hash(algo), collisions)
+ if not self._allow_overwrite:
+ for algo in {"sha1", "sha1_git"}:
+ collisions = []
+ # Get tokens of 'content' rows with the same value for
+ # sha1/sha1_git
+ rows = self._content_get_from_hash(algo, content.get_hash(algo))
+ for row in rows:
+ if getattr(row, algo) != content.get_hash(algo):
+ # collision of token(partition key), ignore this
+ # row
+ continue
+
+ for other_algo in HASH_ALGORITHMS:
+ if getattr(row, other_algo) != content.get_hash(other_algo):
+ # This hash didn't match; discard the row.
+ collisions.append(
+ {k: getattr(row, k) for k in HASH_ALGORITHMS}
+ )
+
+ if collisions:
+ collisions.append(content.hashes())
+ raise HashCollision(algo, content.get_hash(algo), collisions)
(token, insertion_finalizer) = self._cql_runner.content_add_prepare(
ContentRow(**remove_keys(content.to_dict(), ("data",)))
@@ -323,11 +362,12 @@
def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict:
# Filter-out content already in the database.
- contents = [
- c
- for c in contents
- if not self._cql_runner.skipped_content_get_from_pk(c.to_dict())
- ]
+ if not self._allow_overwrite:
+ contents = [
+ c
+ for c in contents
+ if not self._cql_runner.skipped_content_get_from_pk(c.to_dict())
+ ]
self.journal_writer.skipped_content_add(contents)
@@ -359,9 +399,10 @@
def directory_add(self, directories: List[Directory]) -> Dict:
to_add = {d.id: d for d in directories}.values()
- # Filter out directories that are already inserted.
- missing = self.directory_missing([dir_.id for dir_ in to_add])
- directories = [dir_ for dir_ in directories if dir_.id in missing]
+ if not self._allow_overwrite:
+ # Filter out directories that are already inserted.
+ missing = self.directory_missing([dir_.id for dir_ in to_add])
+ directories = [dir_ for dir_ in directories if dir_.id in missing]
self.journal_writer.directory_add(directories)
@@ -484,9 +525,10 @@
def revision_add(self, revisions: List[Revision]) -> Dict:
# Filter-out revisions already in the database
- to_add = {r.id: r for r in revisions}.values()
- missing = self.revision_missing([rev.id for rev in to_add])
- revisions = [rev for rev in revisions if rev.id in missing]
+ if not self._allow_overwrite:
+ to_add = {r.id: r for r in revisions}.values()
+ missing = self.revision_missing([rev.id for rev in to_add])
+ revisions = [rev for rev in revisions if rev.id in missing]
self.journal_writer.revision_add(revisions)
for revision in revisions:
@@ -595,9 +637,10 @@
return revision.id
def release_add(self, releases: List[Release]) -> Dict:
- to_add = {r.id: r for r in releases}.values()
- missing = set(self.release_missing([rel.id for rel in to_add]))
- releases = [rel for rel in to_add if rel.id in missing]
+ if not self._allow_overwrite:
+ to_add = {r.id: r for r in releases}.values()
+ missing = set(self.release_missing([rel.id for rel in to_add]))
+ releases = [rel for rel in to_add if rel.id in missing]
self.journal_writer.release_add(releases)
for release in releases:
@@ -624,9 +667,10 @@
return release.id
def snapshot_add(self, snapshots: List[Snapshot]) -> Dict:
- to_add = {s.id: s for s in snapshots}.values()
- missing = self._cql_runner.snapshot_missing([snp.id for snp in to_add])
- snapshots = [snp for snp in snapshots if snp.id in missing]
+ if not self._allow_overwrite:
+ to_add = {s.id: s for s in snapshots}.values()
+ missing = self._cql_runner.snapshot_missing([snp.id for snp in to_add])
+ snapshots = [snp for snp in snapshots if snp.id in missing]
for snapshot in snapshots:
self.journal_writer.snapshot_add([snapshot])
@@ -895,8 +939,9 @@
)
def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
- to_add = {o.url: o for o in origins}.values()
- origins = [ori for ori in to_add if self.origin_get_one(ori.url) is None]
+ if not self._allow_overwrite:
+ to_add = {o.url: o for o in origins}.values()
+ origins = [ori for ori in to_add if self.origin_get_one(ori.url) is None]
self.journal_writer.origin_add(origins)
for origin in origins:
@@ -1343,13 +1388,16 @@
# ExtID tables
def extid_add(self, ids: List[ExtID]) -> Dict[str, int]:
- extids = [
- extid
- for extid in ids
- if not self._cql_runner.extid_get_from_pk(
- extid_type=extid.extid_type, extid=extid.extid, target=extid.target,
- )
- ]
+ if not self._allow_overwrite:
+ extids = [
+ extid
+ for extid in ids
+ if not self._cql_runner.extid_get_from_pk(
+ extid_type=extid.extid_type, extid=extid.extid, target=extid.target,
+ )
+ ]
+ else:
+ extids = list(ids)
self.journal_writer.extid_add(extids)
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -678,6 +678,7 @@
def __init__(self, journal_writer=None):
self.reset()
self.journal_writer = JournalWriter(journal_writer)
+ self._allow_overwrite = False
def reset(self):
self._cql_runner = InMemoryCqlRunner()
diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py
--- a/swh/storage/tests/test_cassandra.py
+++ b/swh/storage/tests/test_cassandra.py
@@ -4,26 +4,29 @@
# See top-level LICENSE file for more information
import datetime
+import itertools
import os
import signal
import socket
import subprocess
import time
-from typing import Dict
+from typing import Any, Dict
import attr
import pytest
from swh.core.api.classes import stream_results
+from swh.model.model import Directory, DirectoryEntry, Snapshot
from swh.storage import get_storage
from swh.storage.cassandra import create_keyspace
from swh.storage.cassandra.model import ContentRow, ExtIDRow
from swh.storage.cassandra.schema import HASH_ALGORITHMS, TABLES
+from swh.storage.tests.storage_data import StorageData
from swh.storage.tests.storage_tests import (
TestStorageGeneratedData as _TestStorageGeneratedData,
)
from swh.storage.tests.storage_tests import TestStorage as _TestStorage
-from swh.storage.utils import now
+from swh.storage.utils import now, remove_keys
CONFIG_TEMPLATE = """
data_file_directories:
@@ -448,3 +451,105 @@
@pytest.mark.skip("Not supported by Cassandra")
def test_origin_count_with_visit_with_visits_no_snapshot(self):
pass
+
+
+@pytest.mark.parametrize(
+ "allow_overwrite,object_type",
+ itertools.product(
+ [False, True],
+ # Note the absence of "content", it's tested above.
+ ["directory", "revision", "release", "snapshot", "origin", "extid"],
+ ),
+)
+def test_allow_overwrite(
+ allow_overwrite: bool, object_type: str, swh_storage_backend_config
+):
+ if object_type in ("origin", "extid"):
+ pytest.skip(
+ f"test_disallow_overwrite not implemented for {object_type} objects, "
+ f"because all their columns are in the primary key."
+ )
+ swh_storage = get_storage(
+ allow_overwrite=allow_overwrite, **swh_storage_backend_config
+ )
+
+ # directory_ls joins with content and directory table, and needs those to return
+ # non-None entries:
+ if object_type == "directory":
+ swh_storage.directory_add([StorageData.directory5])
+ swh_storage.content_add([StorageData.content, StorageData.content2])
+
+ obj1: Any
+ obj2: Any
+
+ # Get two test objects
+ if object_type == "directory":
+ (obj1, obj2, *_) = StorageData.directories
+ elif object_type == "snapshot":
+ # StorageData.snapshots[1] is the empty snapshot, which is the corner case
+ # that makes this test succeed for the wrong reasons
+ obj1 = StorageData.snapshot
+ obj2 = StorageData.complete_snapshot
+ else:
+ (obj1, obj2, *_) = getattr(StorageData, (object_type + "s"))
+
+ # Let's make both objects have the same hash, but different content
+ obj1 = attr.evolve(obj1, id=obj2.id)
+
+ # Get the methods used to add and get these objects
+ add = getattr(swh_storage, object_type + "_add")
+ if object_type == "directory":
+
+ def get(ids):
+ return [
+ Directory(
+ id=ids[0],
+ entries=tuple(
+ map(
+ lambda entry: DirectoryEntry(
+ name=entry["name"],
+ type=entry["type"],
+ target=entry["sha1_git"],
+ perms=entry["perms"],
+ ),
+ swh_storage.directory_ls(ids[0]),
+ )
+ ),
+ )
+ ]
+
+ elif object_type == "snapshot":
+
+ def get(ids):
+ return [
+ Snapshot.from_dict(
+ remove_keys(swh_storage.snapshot_get(ids[0]), ("next_branch",))
+ )
+ ]
+
+ else:
+ get = getattr(swh_storage, object_type + "_get")
+
+ # Add the first object
+ add([obj1])
+
+ # It should be returned as-is
+ assert get([obj1.id]) == [obj1]
+
+ # Add the second object
+ add([obj2])
+
+ if allow_overwrite:
+ # obj1 was overwritten by obj2
+ expected = obj2
+ else:
+ # obj2 was not written, because obj1 already exists and has the same hash
+ expected = obj1
+
+ if allow_overwrite and object_type in ("directory", "snapshot"):
+ # TODO
+ pytest.xfail(
+ "directory entries and snapshot branches are concatenated "
+ "instead of being replaced"
+ )
+ assert get([obj1.id]) == [expected]
diff --git a/swh/storage/tests/test_cassandra_migration.py b/swh/storage/tests/test_cassandra_migration.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/tests/test_cassandra_migration.py
@@ -0,0 +1,163 @@
+# Copyright (C) 2021 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+"""This module tests the migration capabilities of the Cassandra backend,
+by sending CQL commands (eg. 'ALTER TABLE'), and
+by monkey-patching large parts of the implementations to simulate code updates,."""
+
+import dataclasses
+import functools
+import operator
+from typing import Dict, Iterable, Optional
+
+import attr
+import pytest
+
+from swh.model.model import Content
+from swh.storage import get_storage
+from swh.storage.cassandra.cql import (
+ CqlRunner,
+ _prepared_insert_statement,
+ _prepared_select_statement,
+)
+from swh.storage.cassandra.model import ContentRow
+from swh.storage.cassandra.schema import CONTENT_INDEX_TEMPLATE, HASH_ALGORITHMS
+from swh.storage.cassandra.storage import CassandraStorage
+from swh.storage.exc import StorageArgumentException
+
+from .storage_data import StorageData
+from .test_cassandra import ( # noqa, needed for swh_storage fixture
+ cassandra_cluster,
+ keyspace,
+ swh_storage_backend_config,
+)
+
+
+def byte_xor_hash(data):
+ # Behold, a one-line hash function:
+ return bytes([functools.reduce(operator.xor, data)])
+
+
+@attr.s
+class ContentWithXor(Content):
+ """An hypothetical upgrade of Content with an extra "hash"."""
+
+ byte_xor = attr.ib(type=bytes, default=None)
+
+
+@dataclasses.dataclass
+class ContentRowWithXor(ContentRow):
+ """An hypothetical upgrade of ContentRow with an extra "hash"."""
+
+ byte_xor: bytes
+
+
+class CqlRunnerWithXor(CqlRunner):
+ @_prepared_select_statement(
+ ContentRowWithXor,
+ f"WHERE {' AND '.join(map('%s = ?'.__mod__, HASH_ALGORITHMS))}",
+ )
+ def content_get_from_pk(
+ self, content_hashes: Dict[str, bytes], *, statement
+ ) -> Optional[ContentRow]:
+ rows = list(
+ self._execute_with_retries(
+ statement, [content_hashes[algo] for algo in HASH_ALGORITHMS]
+ )
+ )
+ assert len(rows) <= 1
+ if rows:
+ return ContentRowWithXor(**rows[0])
+ else:
+ return None
+
+ @_prepared_select_statement(
+ ContentRowWithXor, f"WHERE token({', '.join(ContentRow.PARTITION_KEY)}) = ?"
+ )
+ def content_get_from_token(self, token, *, statement) -> Iterable[ContentRow]:
+ return map(
+ ContentRowWithXor.from_dict, self._execute_with_retries(statement, [token])
+ )
+
+ # Redecorate content_add_prepare with the new ContentRow class
+ content_add_prepare = _prepared_insert_statement(ContentRowWithXor)( # type: ignore
+ CqlRunner.content_add_prepare.__wrapped__ # type: ignore
+ )
+
+
+def test_add_content_column(
+ swh_storage: CassandraStorage, swh_storage_backend_config, mocker # noqa
+) -> None:
+ """Adds a column to the 'content' table and a new matching index.
+ This is a simple migration, as it does not require an update to the primary key.
+ """
+ content_xor_hash = byte_xor_hash(StorageData.content.data)
+
+ # First insert some existing data
+ swh_storage.content_add([StorageData.content, StorageData.content2])
+
+ # Then update the schema
+ swh_storage._cql_runner._session.execute("ALTER TABLE content ADD byte_xor blob")
+ for statement in CONTENT_INDEX_TEMPLATE.split("\n\n"):
+ swh_storage._cql_runner._session.execute(statement.format(main_algo="byte_xor"))
+
+ # Should not affect the running code at all:
+ assert swh_storage.content_get([StorageData.content.sha1]) == [
+ attr.evolve(StorageData.content, data=None)
+ ]
+ with pytest.raises(StorageArgumentException):
+ swh_storage.content_find({"byte_xor": content_xor_hash})
+
+ # Then update the running code:
+ new_hash_algos = HASH_ALGORITHMS + ["byte_xor"]
+ mocker.patch("swh.storage.cassandra.storage.HASH_ALGORITHMS", new_hash_algos)
+ mocker.patch("swh.storage.cassandra.cql.HASH_ALGORITHMS", new_hash_algos)
+ mocker.patch("swh.model.model.DEFAULT_ALGORITHMS", new_hash_algos)
+ mocker.patch("swh.storage.cassandra.storage.Content", ContentWithXor)
+ mocker.patch("swh.storage.cassandra.storage.ContentRow", ContentRowWithXor)
+ mocker.patch("swh.storage.cassandra.model.ContentRow", ContentRowWithXor)
+ mocker.patch("swh.storage.cassandra.storage.CqlRunner", CqlRunnerWithXor)
+
+ # Forge new objects with this extra hash:
+ new_content = ContentWithXor.from_dict(
+ {
+ "byte_xor": byte_xor_hash(StorageData.content.data),
+ **StorageData.content.to_dict(),
+ }
+ )
+ new_content2 = ContentWithXor.from_dict(
+ {
+ "byte_xor": byte_xor_hash(StorageData.content2.data),
+ **StorageData.content2.to_dict(),
+ }
+ )
+
+ # Simulates a restart:
+ swh_storage._set_cql_runner()
+
+ # Old algos still work, and return the new object type:
+ assert swh_storage.content_get([StorageData.content.sha1]) == [
+ attr.evolve(new_content, data=None, byte_xor=None)
+ ]
+
+ # The new algo does not work yet, as we did not backfill it:
+ assert swh_storage.content_find({"byte_xor": content_xor_hash}) == []
+
+ # A normal storage would not overwrite, because the object already exists,
+ # as it is not aware it is missing a field:
+ swh_storage.content_add([new_content, new_content2])
+ assert swh_storage.content_find({"byte_xor": content_xor_hash}) == []
+
+ # Backfill (in production this would be done with a replayer reading from
+ # the journal):
+ overwriting_swh_storage = get_storage(
+ allow_overwrite=True, **swh_storage_backend_config
+ )
+ overwriting_swh_storage.content_add([new_content, new_content2])
+
+ # Now, the object can be found:
+ assert swh_storage.content_find({"byte_xor": content_xor_hash}) == [
+ attr.evolve(new_content, data=None)
+ ]

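Usage note (not part of the diff): a minimal sketch of how the new ``allow_overwrite`` flag could be enabled through ``get_storage``, mirroring the backfill pattern exercised in test_cassandra_migration.py. The cluster address, keyspace, and objstorage configuration below are placeholder assumptions, not values from this patch.

    # Hypothetical backfill sketch; host, keyspace, and objstorage values are assumptions.
    from swh.storage import get_storage

    # Regular instance: *_add methods check for existing objects before INSERTing.
    storage = get_storage(
        cls="cassandra",
        hosts=["cassandra-seed.example.org"],  # placeholder seed node
        keyspace="swh",                        # placeholder keyspace
        objstorage={"cls": "memory"},          # placeholder objstorage config
    )

    # Overwriting instance: skips those checks and INSERTs unconditionally,
    # e.g. to re-populate rows after a schema update (ALTER TABLE + new index).
    backfill_storage = get_storage(
        cls="cassandra",
        hosts=["cassandra-seed.example.org"],
        keyspace="swh",
        objstorage={"cls": "memory"},
        allow_overwrite=True,
    )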