D5582.id20144.diff
D5582: cassandra: Add 'allow_overwrite' option, to allow updating objects
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -89,10 +89,45 @@
class CassandraStorage:
- def __init__(self, hosts, keyspace, objstorage, port=9042, journal_writer=None):
- self._cql_runner: CqlRunner = CqlRunner(hosts, keyspace, port)
+ def __init__(
+ self,
+ hosts,
+ keyspace,
+ objstorage,
+ port=9042,
+ journal_writer=None,
+ allow_overwrite=False,
+ ):
+ """
+ A backend of swh-storage backed by Cassandra
+
+ Args:
+ hosts: Seed Cassandra nodes, to start connecting to the cluster
+ keyspace: Name of the Cassandra database to use
+ objstorage: Passed as argument to :class:`ObjStorage`
+ port: Cassandra port
+ journal_writer: Passed as argument to :class:`JournalWriter`
+ allow_overwrite: Whether ``*_add`` functions may skip checking if an object
+ already exists in the database before sending it in an INSERT.
+ ``False`` (the default) checks first, which is more efficient when there is
+ a moderately high probability the object is already known,
+ but ``True`` can be useful to overwrite existing objects
+ (e.g. when applying a schema update),
+ or when the database is known to be mostly empty.
+ Note that a ``False`` value does not guarantee there won't be
+ any overwrite.
+ """
+ self._hosts = hosts
+ self._keyspace = keyspace
+ self._port = port
+ self._set_cql_runner()
self.journal_writer: JournalWriter = JournalWriter(journal_writer)
self.objstorage: ObjStorage = ObjStorage(objstorage)
+ self._allow_overwrite = allow_overwrite
+
+ def _set_cql_runner(self):
+ """Used by tests when they need to reset the CqlRunner"""
+ self._cql_runner: CqlRunner = CqlRunner(self._hosts, self._keyspace, self._port)
def check_config(self, *, check_write: bool) -> bool:
self._cql_runner.check_read()
@@ -119,9 +154,12 @@
def _content_add(self, contents: List[Content], with_data: bool) -> Dict:
# Filter-out content already in the database.
- contents = [
- c for c in contents if not self._cql_runner.content_get_from_pk(c.to_dict())
- ]
+ if not self._allow_overwrite:
+ contents = [
+ c
+ for c in contents
+ if not self._cql_runner.content_get_from_pk(c.to_dict())
+ ]
if with_data:
# First insert to the objstorage, if the endpoint is
@@ -150,27 +188,28 @@
# The proper way to do it would probably be a BATCH, but this
# would be inefficient because of the number of partitions we
# need to affect (len(HASH_ALGORITHMS)+1, which is currently 5)
- for algo in {"sha1", "sha1_git"}:
- collisions = []
- # Get tokens of 'content' rows with the same value for
- # sha1/sha1_git
- rows = self._content_get_from_hash(algo, content.get_hash(algo))
- for row in rows:
- if getattr(row, algo) != content.get_hash(algo):
- # collision of token(partition key), ignore this
- # row
- continue
-
- for other_algo in HASH_ALGORITHMS:
- if getattr(row, other_algo) != content.get_hash(other_algo):
- # This hash didn't match; discard the row.
- collisions.append(
- {k: getattr(row, k) for k in HASH_ALGORITHMS}
- )
-
- if collisions:
- collisions.append(content.hashes())
- raise HashCollision(algo, content.get_hash(algo), collisions)
+ if not self._allow_overwrite:
+ for algo in {"sha1", "sha1_git"}:
+ collisions = []
+ # Get tokens of 'content' rows with the same value for
+ # sha1/sha1_git
+ rows = self._content_get_from_hash(algo, content.get_hash(algo))
+ for row in rows:
+ if getattr(row, algo) != content.get_hash(algo):
+ # collision of token(partition key), ignore this
+ # row
+ continue
+
+ for other_algo in HASH_ALGORITHMS:
+ if getattr(row, other_algo) != content.get_hash(other_algo):
+ # This hash didn't match; discard the row.
+ collisions.append(
+ {k: getattr(row, k) for k in HASH_ALGORITHMS}
+ )
+
+ if collisions:
+ collisions.append(content.hashes())
+ raise HashCollision(algo, content.get_hash(algo), collisions)
(token, insertion_finalizer) = self._cql_runner.content_add_prepare(
ContentRow(**remove_keys(content.to_dict(), ("data",)))
@@ -323,11 +362,12 @@
def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict:
# Filter-out content already in the database.
- contents = [
- c
- for c in contents
- if not self._cql_runner.skipped_content_get_from_pk(c.to_dict())
- ]
+ if not self._allow_overwrite:
+ contents = [
+ c
+ for c in contents
+ if not self._cql_runner.skipped_content_get_from_pk(c.to_dict())
+ ]
self.journal_writer.skipped_content_add(contents)
@@ -359,9 +399,10 @@
def directory_add(self, directories: List[Directory]) -> Dict:
to_add = {d.id: d for d in directories}.values()
- # Filter out directories that are already inserted.
- missing = self.directory_missing([dir_.id for dir_ in to_add])
- directories = [dir_ for dir_ in directories if dir_.id in missing]
+ if not self._allow_overwrite:
+ # Filter out directories that are already inserted.
+ missing = self.directory_missing([dir_.id for dir_ in to_add])
+ directories = [dir_ for dir_ in directories if dir_.id in missing]
self.journal_writer.directory_add(directories)
@@ -484,9 +525,10 @@
def revision_add(self, revisions: List[Revision]) -> Dict:
# Filter-out revisions already in the database
- to_add = {r.id: r for r in revisions}.values()
- missing = self.revision_missing([rev.id for rev in to_add])
- revisions = [rev for rev in revisions if rev.id in missing]
+ if not self._allow_overwrite:
+ to_add = {r.id: r for r in revisions}.values()
+ missing = self.revision_missing([rev.id for rev in to_add])
+ revisions = [rev for rev in revisions if rev.id in missing]
self.journal_writer.revision_add(revisions)
for revision in revisions:
@@ -595,9 +637,10 @@
return revision.id
def release_add(self, releases: List[Release]) -> Dict:
- to_add = {r.id: r for r in releases}.values()
- missing = set(self.release_missing([rel.id for rel in to_add]))
- releases = [rel for rel in to_add if rel.id in missing]
+ if not self._allow_overwrite:
+ to_add = {r.id: r for r in releases}.values()
+ missing = set(self.release_missing([rel.id for rel in to_add]))
+ releases = [rel for rel in to_add if rel.id in missing]
self.journal_writer.release_add(releases)
for release in releases:
@@ -624,9 +667,10 @@
return release.id
def snapshot_add(self, snapshots: List[Snapshot]) -> Dict:
- to_add = {s.id: s for s in snapshots}.values()
- missing = self._cql_runner.snapshot_missing([snp.id for snp in to_add])
- snapshots = [snp for snp in snapshots if snp.id in missing]
+ if not self._allow_overwrite:
+ to_add = {s.id: s for s in snapshots}.values()
+ missing = self._cql_runner.snapshot_missing([snp.id for snp in to_add])
+ snapshots = [snp for snp in snapshots if snp.id in missing]
for snapshot in snapshots:
self.journal_writer.snapshot_add([snapshot])
@@ -895,8 +939,9 @@
)
def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
- to_add = {o.url: o for o in origins}.values()
- origins = [ori for ori in to_add if self.origin_get_one(ori.url) is None]
+ if not self._allow_overwrite:
+ to_add = {o.url: o for o in origins}.values()
+ origins = [ori for ori in to_add if self.origin_get_one(ori.url) is None]
self.journal_writer.origin_add(origins)
for origin in origins:
@@ -1343,13 +1388,16 @@
# ExtID tables
def extid_add(self, ids: List[ExtID]) -> Dict[str, int]:
- extids = [
- extid
- for extid in ids
- if not self._cql_runner.extid_get_from_pk(
- extid_type=extid.extid_type, extid=extid.extid, target=extid.target,
- )
- ]
+ if not self._allow_overwrite:
+ extids = [
+ extid
+ for extid in ids
+ if not self._cql_runner.extid_get_from_pk(
+ extid_type=extid.extid_type, extid=extid.extid, target=extid.target,
+ )
+ ]
+ else:
+ extids = list(ids)
self.journal_writer.extid_add(extids)
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -678,6 +678,7 @@
def __init__(self, journal_writer=None):
self.reset()
self.journal_writer = JournalWriter(journal_writer)
+ self._allow_overwrite = False
def reset(self):
self._cql_runner = InMemoryCqlRunner()
diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py
--- a/swh/storage/tests/test_cassandra.py
+++ b/swh/storage/tests/test_cassandra.py
@@ -4,26 +4,29 @@
# See top-level LICENSE file for more information
import datetime
+import itertools
import os
import signal
import socket
import subprocess
import time
-from typing import Dict
+from typing import Any, Dict
import attr
import pytest
from swh.core.api.classes import stream_results
+from swh.model.model import Directory, DirectoryEntry, Snapshot
from swh.storage import get_storage
from swh.storage.cassandra import create_keyspace
from swh.storage.cassandra.model import ContentRow, ExtIDRow
from swh.storage.cassandra.schema import HASH_ALGORITHMS, TABLES
+from swh.storage.tests.storage_data import StorageData
from swh.storage.tests.storage_tests import (
TestStorageGeneratedData as _TestStorageGeneratedData,
)
from swh.storage.tests.storage_tests import TestStorage as _TestStorage
-from swh.storage.utils import now
+from swh.storage.utils import now, remove_keys
CONFIG_TEMPLATE = """
data_file_directories:
@@ -448,3 +451,105 @@
@pytest.mark.skip("Not supported by Cassandra")
def test_origin_count_with_visit_with_visits_no_snapshot(self):
pass
+
+
+@pytest.mark.parametrize(
+ "allow_overwrite,object_type",
+ itertools.product(
+ [False, True],
+ # Note the absence of "content", it's tested above.
+ ["directory", "revision", "release", "snapshot", "origin", "extid"],
+ ),
+)
+def test_allow_overwrite(
+ allow_overwrite: bool, object_type: str, swh_storage_backend_config
+):
+ if object_type in ("origin", "extid"):
+ pytest.skip(
+ f"test_allow_overwrite not implemented for {object_type} objects, "
+ f"because all their columns are in the primary key."
+ )
+ swh_storage = get_storage(
+ allow_overwrite=allow_overwrite, **swh_storage_backend_config
+ )
+
+ # directory_ls joins with the content and directory tables, and needs those to return
+ # non-None entries:
+ if object_type == "directory":
+ swh_storage.directory_add([StorageData.directory5])
+ swh_storage.content_add([StorageData.content, StorageData.content2])
+
+ obj1: Any
+ obj2: Any
+
+ # Get two test objects
+ if object_type == "directory":
+ (obj1, obj2, *_) = StorageData.directories
+ elif object_type == "snapshot":
+ # StorageData.snapshots[1] is the empty snapshot, which is the corner case
+ # that makes this test succeed for the wrong reasons
+ obj1 = StorageData.snapshot
+ obj2 = StorageData.complete_snapshot
+ else:
+ (obj1, obj2, *_) = getattr(StorageData, (object_type + "s"))
+
+ # Let's make both objects have the same hash, but different content
+ obj1 = attr.evolve(obj1, id=obj2.id)
+
+ # Get the methods used to add and get these objects
+ add = getattr(swh_storage, object_type + "_add")
+ if object_type == "directory":
+
+ def get(ids):
+ return [
+ Directory(
+ id=ids[0],
+ entries=tuple(
+ map(
+ lambda entry: DirectoryEntry(
+ name=entry["name"],
+ type=entry["type"],
+ target=entry["sha1_git"],
+ perms=entry["perms"],
+ ),
+ swh_storage.directory_ls(ids[0]),
+ )
+ ),
+ )
+ ]
+
+ elif object_type == "snapshot":
+
+ def get(ids):
+ return [
+ Snapshot.from_dict(
+ remove_keys(swh_storage.snapshot_get(ids[0]), ("next_branch",))
+ )
+ ]
+
+ else:
+ get = getattr(swh_storage, object_type + "_get")
+
+ # Add the first object
+ add([obj1])
+
+ # It should be returned as-is
+ assert get([obj1.id]) == [obj1]
+
+ # Add the second object
+ add([obj2])
+
+ if allow_overwrite:
+ # obj1 was overwritten by obj2
+ expected = obj2
+ else:
+ # obj2 was not written, because obj1 already exists and has the same hash
+ expected = obj1
+
+ if allow_overwrite and object_type in ("directory", "snapshot"):
+ # TODO
+ pytest.xfail(
+ "directory entries and snapshot branches are concatenated "
+ "instead of being replaced"
+ )
+ assert get([obj1.id]) == [expected]
diff --git a/swh/storage/tests/test_cassandra_migration.py b/swh/storage/tests/test_cassandra_migration.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/tests/test_cassandra_migration.py
@@ -0,0 +1,163 @@
+# Copyright (C) 2021 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+"""This module tests the migration capabilities of the Cassandra backend,
+by sending CQL commands (e.g. 'ALTER TABLE'), and
+by monkey-patching large parts of the implementations to simulate code updates."""
+
+import dataclasses
+import functools
+import operator
+from typing import Dict, Iterable, Optional
+
+import attr
+import pytest
+
+from swh.model.model import Content
+from swh.storage import get_storage
+from swh.storage.cassandra.cql import (
+ CqlRunner,
+ _prepared_insert_statement,
+ _prepared_select_statement,
+)
+from swh.storage.cassandra.model import ContentRow
+from swh.storage.cassandra.schema import CONTENT_INDEX_TEMPLATE, HASH_ALGORITHMS
+from swh.storage.cassandra.storage import CassandraStorage
+from swh.storage.exc import StorageArgumentException
+
+from .storage_data import StorageData
+from .test_cassandra import ( # noqa, needed for swh_storage fixture
+ cassandra_cluster,
+ keyspace,
+ swh_storage_backend_config,
+)
+
+
+def byte_xor_hash(data):
+ # Behold, a one-line hash function:
+ return bytes([functools.reduce(operator.xor, data)])
+
+
+@attr.s
+class ContentWithXor(Content):
+ """A hypothetical upgrade of Content with an extra "hash"."""
+
+ byte_xor = attr.ib(type=bytes, default=None)
+
+
+@dataclasses.dataclass
+class ContentRowWithXor(ContentRow):
+ """A hypothetical upgrade of ContentRow with an extra "hash"."""
+
+ byte_xor: bytes
+
+
+class CqlRunnerWithXor(CqlRunner):
+ @_prepared_select_statement(
+ ContentRowWithXor,
+ f"WHERE {' AND '.join(map('%s = ?'.__mod__, HASH_ALGORITHMS))}",
+ )
+ def content_get_from_pk(
+ self, content_hashes: Dict[str, bytes], *, statement
+ ) -> Optional[ContentRow]:
+ rows = list(
+ self._execute_with_retries(
+ statement, [content_hashes[algo] for algo in HASH_ALGORITHMS]
+ )
+ )
+ assert len(rows) <= 1
+ if rows:
+ return ContentRowWithXor(**rows[0])
+ else:
+ return None
+
+ @_prepared_select_statement(
+ ContentRowWithXor, f"WHERE token({', '.join(ContentRow.PARTITION_KEY)}) = ?"
+ )
+ def content_get_from_token(self, token, *, statement) -> Iterable[ContentRow]:
+ return map(
+ ContentRowWithXor.from_dict, self._execute_with_retries(statement, [token])
+ )
+
+ # Redecorate content_add_prepare with the new ContentRow class
+ content_add_prepare = _prepared_insert_statement(ContentRowWithXor)( # type: ignore
+ CqlRunner.content_add_prepare.__wrapped__ # type: ignore
+ )
+
+
+def test_add_content_column(
+ swh_storage: CassandraStorage, swh_storage_backend_config, mocker # noqa
+) -> None:
+ """Adds a column to the 'content' table and a new matching index.
+ This is a simple migration, as it does not require an update to the primary key.
+ """
+ content_xor_hash = byte_xor_hash(StorageData.content.data)
+
+ # First insert some existing data
+ swh_storage.content_add([StorageData.content, StorageData.content2])
+
+ # Then update the schema
+ swh_storage._cql_runner._session.execute("ALTER TABLE content ADD byte_xor blob")
+ for statement in CONTENT_INDEX_TEMPLATE.split("\n\n"):
+ swh_storage._cql_runner._session.execute(statement.format(main_algo="byte_xor"))
+
+ # Should not affect the running code at all:
+ assert swh_storage.content_get([StorageData.content.sha1]) == [
+ attr.evolve(StorageData.content, data=None)
+ ]
+ with pytest.raises(StorageArgumentException):
+ swh_storage.content_find({"byte_xor": content_xor_hash})
+
+ # Then update the running code:
+ new_hash_algos = HASH_ALGORITHMS + ["byte_xor"]
+ mocker.patch("swh.storage.cassandra.storage.HASH_ALGORITHMS", new_hash_algos)
+ mocker.patch("swh.storage.cassandra.cql.HASH_ALGORITHMS", new_hash_algos)
+ mocker.patch("swh.model.model.DEFAULT_ALGORITHMS", new_hash_algos)
+ mocker.patch("swh.storage.cassandra.storage.Content", ContentWithXor)
+ mocker.patch("swh.storage.cassandra.storage.ContentRow", ContentRowWithXor)
+ mocker.patch("swh.storage.cassandra.model.ContentRow", ContentRowWithXor)
+ mocker.patch("swh.storage.cassandra.storage.CqlRunner", CqlRunnerWithXor)
+
+ # Forge new objects with this extra hash:
+ new_content = ContentWithXor.from_dict(
+ {
+ "byte_xor": byte_xor_hash(StorageData.content.data),
+ **StorageData.content.to_dict(),
+ }
+ )
+ new_content2 = ContentWithXor.from_dict(
+ {
+ "byte_xor": byte_xor_hash(StorageData.content2.data),
+ **StorageData.content2.to_dict(),
+ }
+ )
+
+ # Simulates a restart:
+ swh_storage._set_cql_runner()
+
+ # Old algos still work, and return the new object type:
+ assert swh_storage.content_get([StorageData.content.sha1]) == [
+ attr.evolve(new_content, data=None, byte_xor=None)
+ ]
+
+ # The new algo does not work yet, since we did not backfill it:
+ assert swh_storage.content_find({"byte_xor": content_xor_hash}) == []
+
+ # A normal storage would not overwrite, because the object already exists,
+ # as it is not aware it is missing a field:
+ swh_storage.content_add([new_content, new_content2])
+ assert swh_storage.content_find({"byte_xor": content_xor_hash}) == []
+
+ # Backfill (in production this would be done with a replayer reading from
+ # the journal):
+ overwriting_swh_storage = get_storage(
+ allow_overwrite=True, **swh_storage_backend_config
+ )
+ overwriting_swh_storage.content_add([new_content, new_content2])
+
+ # Now, the object can be found:
+ assert swh_storage.content_find({"byte_xor": content_xor_hash}) == [
+ attr.evolve(new_content, data=None)
+ ]
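
For context, a minimal sketch of how the new option could be passed to the Cassandra backend through the swh.storage factory, much like the tests above do with their fixture config. The hosts, keyspace, and objstorage values below are placeholders for illustration, not part of this diff:

    from swh.storage import get_storage

    # Hypothetical configuration; point hosts/keyspace at a real cluster.
    storage = get_storage(
        cls="cassandra",
        hosts=["cassandra-seed.example.org"],
        keyspace="swh",
        objstorage={"cls": "memory"},
        allow_overwrite=True,  # skip existence checks so *_add re-INSERTs known objects
    )

    # With allow_overwrite=True, re-adding an object that shares a primary key with
    # an existing row overwrites that row instead of being filtered out beforehand.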