diff --git a/mypy.ini b/mypy.ini
--- a/mypy.ini
+++ b/mypy.ini
@@ -11,6 +11,9 @@
# 3rd party libraries without stubs (yet)
+[mypy-cassandra.*]
+ignore_missing_imports = True
+
# only shipped indirectly via hypothesis
[mypy-django.*]
ignore_missing_imports = True
diff --git a/pytest.ini b/pytest.ini
--- a/pytest.ini
+++ b/pytest.ini
@@ -4,5 +4,6 @@
ignore:.*the imp module.*:PendingDeprecationWarning
markers =
db: execute tests using a postgresql database
+ cassandra: execute tests using a cassandra database (which are slow)
property_based: execute tests generating data with hypothesis (potentially long run time)
network: execute tests using a socket between two threads
diff --git a/requirements-test.txt b/requirements-test.txt
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -8,3 +8,4 @@
# adding the [testing] extra.
swh.model[testing] >= 0.0.50
pytz
+pytest-xdist
diff --git a/requirements.txt b/requirements.txt
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@
vcversioner
aiohttp
tenacity
+cassandra-driver >= 3.19.0, < 3.21.0
diff --git a/swh/storage/__init__.py b/swh/storage/__init__.py
--- a/swh/storage/__init__.py
+++ b/swh/storage/__init__.py
@@ -5,10 +5,6 @@
import warnings
-from . import storage
-
-Storage = storage.Storage
-
class HashCollision(Exception):
pass
@@ -16,6 +12,7 @@
STORAGE_IMPLEMENTATION = {
'pipeline', 'local', 'remote', 'memory', 'filter', 'buffer', 'retry',
+ 'cassandra',
}
@@ -53,6 +50,8 @@
from .api.client import RemoteStorage as Storage
elif cls == 'local':
from .storage import Storage
+ elif cls == 'cassandra':
+ from .cassandra import CassandraStorage as Storage
elif cls == 'memory':
from .in_memory import InMemoryStorage as Storage
elif cls == 'filter':
diff --git a/swh/storage/cassandra/__init__.py b/swh/storage/cassandra/__init__.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/cassandra/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (C) 2019-2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+
+from .cql import create_keyspace
+from .storage import CassandraStorage
+
+
+__all__ = ['create_keyspace', 'CassandraStorage']
diff --git a/swh/storage/cassandra/common.py b/swh/storage/cassandra/common.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/cassandra/common.py
@@ -0,0 +1,20 @@
+# Copyright (C) 2019-2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+
+import hashlib
+
+
+Row = tuple
+
+
+TOKEN_BEGIN = -(2**63)
+'''Minimum value returned by the CQL function token()'''
+TOKEN_END = 2**63-1
+'''Maximum value returned by the CQL function token()'''
+
+
+def hash_url(url: str) -> bytes:
+ return hashlib.sha1(url.encode('ascii')).digest()
diff --git a/swh/storage/cassandra/converters.py b/swh/storage/cassandra/converters.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/cassandra/converters.py
@@ -0,0 +1,72 @@
+# Copyright (C) 2019-2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+
+import json
+from typing import Any, Dict
+
+import attr
+
+from swh.model.model import (
+ RevisionType, ObjectType, Revision, Release,
+)
+
+
+from ..converters import git_headers_to_db, db_to_git_headers
+
+
+def revision_to_db(revision: Dict[str, Any]) -> Revision:
+ metadata = revision.get('metadata')
+ if metadata and 'extra_headers' in metadata:
+ extra_headers = git_headers_to_db(
+ metadata['extra_headers'])
+ revision = {
+ **revision,
+ 'metadata': {
+ **metadata,
+ 'extra_headers': extra_headers
+ }
+ }
+
+ rev = Revision.from_dict(revision)
+ rev = attr.evolve(
+ rev,
+ type=rev.type.value,
+ metadata=json.dumps(rev.metadata),
+ )
+
+ return rev
+
+
+def revision_from_db(revision) -> Revision:
+ metadata = json.loads(revision.metadata)
+ if metadata and 'extra_headers' in metadata:
+ extra_headers = db_to_git_headers(
+ metadata['extra_headers'])
+ metadata['extra_headers'] = extra_headers
+ rev = attr.evolve(
+ revision,
+ type=RevisionType(revision.type),
+ metadata=metadata,
+ )
+
+ return rev
+
+
+def release_to_db(release: Dict[str, Any]) -> Release:
+ rel = Release.from_dict(release)
+ rel = attr.evolve(
+ rel,
+ target_type=rel.target_type.value,
+ )
+ return rel
+
+
+def release_from_db(release: Release) -> Release:
+ release = attr.evolve(
+ release,
+ target_type=ObjectType(release.target_type),
+ )
+ return release
diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/cassandra/cql.py
@@ -0,0 +1,619 @@
+# Copyright (C) 2019-2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import functools
+import json
+import logging
+import random
+from typing import (
+ Any, Callable, Dict, Generator, Iterable, List, Optional, TypeVar
+)
+
+from cassandra import WriteFailure, WriteTimeout, ReadFailure, ReadTimeout
+from cassandra.cluster import (
+ Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile, ResultSet)
+from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy
+from cassandra.query import PreparedStatement
+
+from swh.model.model import (
+ Sha1Git, TimestampWithTimezone, Timestamp, Person, Content,
+)
+
+from .common import Row, TOKEN_BEGIN, TOKEN_END, hash_url
+from .schema import CREATE_TABLES_QUERIES, HASH_ALGORITHMS
+
+
+logger = logging.getLogger(__name__)
+
+
+_execution_profiles = {
+ EXEC_PROFILE_DEFAULT: ExecutionProfile(
+ load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy())),
+}
+
+
+def create_keyspace(hosts: List[str], keyspace: str, port: int = 9042):
+ cluster = Cluster(
+ hosts, port=port, execution_profiles=_execution_profiles)
+ session = cluster.connect()
+ session.execute('''CREATE KEYSPACE IF NOT EXISTS "%s"
+ WITH REPLICATION = {
+ 'class' : 'SimpleStrategy',
+ 'replication_factor' : 1
+ };
+ ''' % keyspace)
+ session.execute('USE "%s"' % keyspace)
+ for query in CREATE_TABLES_QUERIES:
+ session.execute(query)
+
+
+T = TypeVar('T')
+
+
+def _prepared_statement(
+ query: str) -> Callable[[Callable[..., T]], Callable[..., T]]:
+ """Returns a decorator usable on methods of CqlRunner, to
+ inject them with a 'statement' argument, that is a prepared
+ statement corresponding to the query.
+
+ This only works on methods of CqlRunner, as preparing a
+ statement requires a connection to a Cassandra server."""
+ def decorator(f):
+ @functools.wraps(f)
+ def newf(self, *args, **kwargs) -> T:
+ if f.__name__ not in self._prepared_statements:
+ statement: PreparedStatement = self._session.prepare(query)
+ self._prepared_statements[f.__name__] = statement
+ return f(self, *args, **kwargs,
+ statement=self._prepared_statements[f.__name__])
+ return newf
+ return decorator
+
+
+def _prepared_insert_statement(table_name: str, columns: List[str]):
+ """Shorthand for using `_prepared_statement` for `INSERT INTO`
+ statements."""
+ return _prepared_statement(
+ 'INSERT INTO %s (%s) VALUES (%s)' % (
+ table_name,
+ ', '.join(columns), ', '.join('?' for _ in columns),
+ )
+ )
+
+
+def _prepared_exists_statement(table_name: str):
+ """Shorthand for using `_prepared_statement` for queries that only
+ check which ids in a list exist in the table."""
+ return _prepared_statement(f'SELECT id FROM {table_name} WHERE id IN ?')
+
+
+class CqlRunner:
+ def __init__(self, hosts: List[str], keyspace: str, port: int):
+ self._cluster = Cluster(
+ hosts, port=port, execution_profiles=_execution_profiles)
+ self._session = self._cluster.connect(keyspace)
+ self._cluster.register_user_type(
+ keyspace, 'microtimestamp_with_timezone', TimestampWithTimezone)
+ self._cluster.register_user_type(
+ keyspace, 'microtimestamp', Timestamp)
+ self._cluster.register_user_type(
+ keyspace, 'person', Person)
+
+ self._prepared_statements: Dict[str, PreparedStatement] = {}
+
+ ##########################
+ # Common utility functions
+ ##########################
+
+ MAX_RETRIES = 3
+
+ def _execute_and_retry(self, statement, args) -> ResultSet:
+ for nb_retries in range(self.MAX_RETRIES):
+ try:
+ return self._session.execute(statement, args, timeout=100.)
+ except (WriteFailure, WriteTimeout) as e:
+ logger.error('Failed to write object to cassandra: %r', e)
+ if nb_retries == self.MAX_RETRIES-1:
+ raise e
+ except (ReadFailure, ReadTimeout) as e:
+ logger.error('Failed to read object(s) to cassandra: %r', e)
+ if nb_retries == self.MAX_RETRIES-1:
+ raise e
+
+ @_prepared_statement('UPDATE object_count SET count = count + ? '
+ 'WHERE partition_key = 0 AND object_type = ?')
+ def _increment_counter(
+ self, object_type: str, nb: int, *, statement: PreparedStatement
+ ) -> None:
+ self._execute_and_retry(statement, [nb, object_type])
+
+ def _add_one(
+ self, statement, object_type: str, obj, keys: List[str]
+ ) -> None:
+ self._increment_counter(object_type, 1)
+ self._execute_and_retry(
+ statement, [getattr(obj, key) for key in keys])
+
+ def _get_random_row(self, statement) -> Optional[Row]:
+ '''Takes a prepared stement of the form
+ "SELECT * FROM
WHERE token() > ? LIMIT 1"
+ and uses it to return a random row'''
+ token = random.randint(TOKEN_BEGIN, TOKEN_END)
+ rows = self._execute_and_retry(statement, [token])
+ if not rows:
+ # There are no row with a greater token; wrap around to get
+ # the row with the smallest token
+ rows = self._execute_and_retry(statement, [TOKEN_BEGIN])
+ if rows:
+ return rows.one()
+ else:
+ return None
+
+ def _missing(self, statement, ids):
+ res = self._execute_and_retry(statement, [ids])
+ found_ids = {id_ for (id_,) in res}
+ return [id_ for id_ in ids if id_ not in found_ids]
+
+ ##########################
+ # 'content' table
+ ##########################
+
+ _content_pk = ['sha1', 'sha1_git', 'sha256', 'blake2s256']
+ _content_keys = [
+ 'sha1', 'sha1_git', 'sha256', 'blake2s256', 'length',
+ 'ctime', 'status']
+
+ @_prepared_statement('SELECT * FROM content WHERE ' +
+ ' AND '.join(map('%s = ?'.__mod__, HASH_ALGORITHMS)))
+ def content_get_from_pk(
+ self, content_hashes: Dict[str, bytes], *, statement
+ ) -> Optional[Row]:
+ rows = list(self._execute_and_retry(
+ statement, [content_hashes[algo] for algo in HASH_ALGORITHMS]))
+ assert len(rows) <= 1
+ if rows:
+ return rows[0]
+ else:
+ return None
+
+ @_prepared_statement('SELECT * FROM content WHERE token(%s) > ? LIMIT 1'
+ % ', '.join(_content_pk))
+ def content_get_random(self, *, statement) -> Optional[Row]:
+ return self._get_random_row(statement)
+
+ @_prepared_statement(('SELECT token({0}) AS tok, {1} FROM content '
+ 'WHERE token({0}) >= ? AND token({0}) <= ? LIMIT ?')
+ .format(', '.join(_content_pk),
+ ', '.join(_content_keys)))
+ def content_get_token_range(
+ self, start: int, end: int, limit: int, *, statement) -> Row:
+ return self._execute_and_retry(statement, [start, end, limit])
+
+ ##########################
+ # 'content_by_*' tables
+ ##########################
+
+ @_prepared_statement('SELECT sha1_git FROM content_by_sha1_git '
+ 'WHERE sha1_git IN ?')
+ def content_missing_by_sha1_git(
+ self, ids: List[bytes], *, statement) -> List[bytes]:
+ return self._missing(statement, ids)
+
+ @_prepared_insert_statement('content', _content_keys)
+ def content_add_one(self, content, *, statement) -> None:
+ self._add_one(statement, 'content', content, self._content_keys)
+
+ def content_index_add_one(self, main_algo: str, content: Content) -> None:
+ query = 'INSERT INTO content_by_{algo} ({cols}) VALUES ({values})' \
+ .format(algo=main_algo, cols=', '.join(self._content_pk),
+ values=', '.join('%s' for _ in self._content_pk))
+ self._execute_and_retry(
+ query, [content.get_hash(algo) for algo in self._content_pk])
+
+ def content_get_pks_from_single_hash(
+ self, algo: str, hash_: bytes) -> List[Row]:
+ assert algo in HASH_ALGORITHMS
+ query = 'SELECT * FROM content_by_{algo} WHERE {algo} = %s'.format(
+ algo=algo)
+ return list(self._execute_and_retry(query, [hash_]))
+
+ ##########################
+ # 'revision' table
+ ##########################
+
+ _revision_keys = [
+ 'id', 'date', 'committer_date', 'type', 'directory', 'message',
+ 'author', 'committer',
+ 'synthetic', 'metadata']
+
+ @_prepared_exists_statement('revision')
+ def revision_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
+ return self._missing(statement, ids)
+
+ @_prepared_insert_statement('revision', _revision_keys)
+ def revision_add_one(self, revision: Dict[str, Any], *, statement) -> None:
+ self._add_one(statement, 'revision', revision, self._revision_keys)
+
+ @_prepared_statement('SELECT id FROM revision WHERE id IN ?')
+ def revision_get_ids(self, revision_ids, *, statement) -> ResultSet:
+ return self._execute_and_retry(
+ statement, [revision_ids])
+
+ @_prepared_statement('SELECT * FROM revision WHERE id IN ?')
+ def revision_get(self, revision_ids, *, statement) -> ResultSet:
+ return self._execute_and_retry(
+ statement, [revision_ids])
+
+ @_prepared_statement('SELECT * FROM revision WHERE token(id) > ? LIMIT 1')
+ def revision_get_random(self, *, statement) -> Optional[Row]:
+ return self._get_random_row(statement)
+
+ ##########################
+ # 'revision_parent' table
+ ##########################
+
+ _revision_parent_keys = ['id', 'parent_rank', 'parent_id']
+
+ @_prepared_insert_statement('revision_parent', _revision_parent_keys)
+ def revision_parent_add_one(
+ self, id_: Sha1Git, parent_rank: int, parent_id: Sha1Git, *,
+ statement) -> None:
+ self._execute_and_retry(
+ statement, [id_, parent_rank, parent_id])
+
+ @_prepared_statement('SELECT parent_id FROM revision_parent WHERE id = ?')
+ def revision_parent_get(
+ self, revision_id: Sha1Git, *, statement) -> ResultSet:
+ return self._execute_and_retry(
+ statement, [revision_id])
+
+ ##########################
+ # 'release' table
+ ##########################
+
+ _release_keys = [
+ 'id', 'target', 'target_type', 'date', 'name', 'message', 'author',
+ 'synthetic']
+
+ @_prepared_exists_statement('release')
+ def release_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
+ return self._missing(statement, ids)
+
+ @_prepared_insert_statement('release', _release_keys)
+ def release_add_one(self, release: Dict[str, Any], *, statement) -> None:
+ self._add_one(statement, 'release', release, self._release_keys)
+
+ @_prepared_statement('SELECT * FROM release WHERE id in ?')
+ def release_get(self, release_ids: List[str], *, statement) -> None:
+ return self._execute_and_retry(statement, [release_ids])
+
+ @_prepared_statement('SELECT * FROM release WHERE token(id) > ? LIMIT 1')
+ def release_get_random(self, *, statement) -> Optional[Row]:
+ return self._get_random_row(statement)
+
+ ##########################
+ # 'directory' table
+ ##########################
+
+ _directory_keys = ['id']
+
+ @_prepared_exists_statement('directory')
+ def directory_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
+ return self._missing(statement, ids)
+
+ @_prepared_insert_statement('directory', _directory_keys)
+ def directory_add_one(self, directory_id: Sha1Git, *, statement) -> None:
+ """Called after all calls to directory_entry_add_one, to
+ commit/finalize the directory."""
+ self._execute_and_retry(statement, [directory_id])
+ self._increment_counter('directory', 1)
+
+ @_prepared_statement('SELECT * FROM directory WHERE token(id) > ? LIMIT 1')
+ def directory_get_random(self, *, statement) -> Optional[Row]:
+ return self._get_random_row(statement)
+
+ ##########################
+ # 'directory_entry' table
+ ##########################
+
+ _directory_entry_keys = ['directory_id', 'name', 'type', 'target', 'perms']
+
+ @_prepared_insert_statement('directory_entry', _directory_entry_keys)
+ def directory_entry_add_one(
+ self, entry: Dict[str, Any], *, statement) -> None:
+ self._execute_and_retry(
+ statement, [entry[key] for key in self._directory_entry_keys])
+
+ @_prepared_statement('SELECT * FROM directory_entry '
+ 'WHERE directory_id IN ?')
+ def directory_entry_get(self, directory_ids, *, statement) -> ResultSet:
+ return self._execute_and_retry(
+ statement, [directory_ids])
+
+ ##########################
+ # 'snapshot' table
+ ##########################
+
+ _snapshot_keys = ['id']
+
+ @_prepared_exists_statement('snapshot')
+ def snapshot_missing(self, ids: List[bytes], *, statement) -> List[bytes]:
+ return self._missing(statement, ids)
+
+ @_prepared_insert_statement('snapshot', _snapshot_keys)
+ def snapshot_add_one(self, snapshot_id: Sha1Git, *, statement) -> None:
+ self._execute_and_retry(statement, [snapshot_id])
+ self._increment_counter('snapshot', 1)
+
+ @_prepared_statement('SELECT * FROM snapshot '
+ 'WHERE id = ?')
+ def snapshot_get(self, snapshot_id: Sha1Git, *, statement) -> ResultSet:
+ return self._execute_and_retry(statement, [snapshot_id])
+
+ @_prepared_statement('SELECT * FROM snapshot WHERE token(id) > ? LIMIT 1')
+ def snapshot_get_random(self, *, statement) -> Optional[Row]:
+ return self._get_random_row(statement)
+
+ ##########################
+ # 'snapshot_branch' table
+ ##########################
+
+ _snapshot_branch_keys = ['snapshot_id', 'name', 'target_type', 'target']
+
+ @_prepared_insert_statement('snapshot_branch', _snapshot_branch_keys)
+ def snapshot_branch_add_one(
+ self, branch: Dict[str, Any], *, statement) -> None:
+ self._execute_and_retry(
+ statement, [branch[key] for key in self._snapshot_branch_keys])
+
+ @_prepared_statement('SELECT ascii_bins_count(target_type) AS counts '
+ 'FROM snapshot_branch '
+ 'WHERE snapshot_id = ? ')
+ def snapshot_count_branches(
+ self, snapshot_id: Sha1Git, *, statement) -> ResultSet:
+ return self._execute_and_retry(statement, [snapshot_id])
+
+ @_prepared_statement('SELECT * FROM snapshot_branch '
+ 'WHERE snapshot_id = ? AND name >= ?'
+ 'LIMIT ?')
+ def snapshot_branch_get(
+ self, snapshot_id: Sha1Git, from_: bytes, limit: int, *,
+ statement) -> None:
+ return self._execute_and_retry(statement, [snapshot_id, from_, limit])
+
+ ##########################
+ # 'origin' table
+ ##########################
+
+ origin_keys = ['sha1', 'url', 'type', 'next_visit_id']
+
+ @_prepared_statement('INSERT INTO origin (sha1, url, next_visit_id) '
+ 'VALUES (?, ?, 1) IF NOT EXISTS')
+ def origin_add_one(self, origin: Dict[str, Any], *, statement) -> None:
+ self._execute_and_retry(
+ statement, [hash_url(origin['url']), origin['url']])
+ self._increment_counter('origin', 1)
+
+ @_prepared_statement('SELECT * FROM origin WHERE sha1 = ?')
+ def origin_get_by_sha1(self, sha1: bytes, *, statement) -> ResultSet:
+ return self._execute_and_retry(statement, [sha1])
+
+ def origin_get_by_url(self, url: str) -> ResultSet:
+ return self.origin_get_by_sha1(hash_url(url))
+
+ @_prepared_statement(
+ f'SELECT token(sha1) AS tok, {", ".join(origin_keys)} '
+ f'FROM origin WHERE token(sha1) >= ? LIMIT ?')
+ def origin_list(
+ self, start_token: int, limit: int, *, statement) -> ResultSet:
+ return self._execute_and_retry(
+ statement, [start_token, limit])
+
+ @_prepared_statement('SELECT * FROM origin')
+ def origin_iter_all(self, *, statement) -> ResultSet:
+ return self._execute_and_retry(statement, [])
+
+ @_prepared_statement('SELECT next_visit_id FROM origin WHERE sha1 = ?')
+ def _origin_get_next_visit_id(
+ self, origin_sha1: bytes, *, statement) -> int:
+ rows = list(self._execute_and_retry(statement, [origin_sha1]))
+ assert len(rows) == 1 # TODO: error handling
+ return rows[0].next_visit_id
+
+ @_prepared_statement('UPDATE origin SET next_visit_id=? '
+ 'WHERE sha1 = ? IF next_visit_id=?')
+ def origin_generate_unique_visit_id(
+ self, origin_url: str, *, statement) -> int:
+ origin_sha1 = hash_url(origin_url)
+ next_id = self._origin_get_next_visit_id(origin_sha1)
+ while True:
+ res = list(self._execute_and_retry(
+ statement, [next_id+1, origin_sha1, next_id]))
+ assert len(res) == 1
+ if res[0].applied:
+ # No data race
+ return next_id
+ else:
+ # Someone else updated it before we did, let's try again
+ next_id = res[0].next_visit_id
+ # TODO: abort after too many attempts
+
+ return next_id
+
+ ##########################
+ # 'origin_visit' table
+ ##########################
+
+ _origin_visit_keys = [
+ 'origin', 'visit', 'type', 'date', 'status', 'metadata', 'snapshot']
+ _origin_visit_update_keys = [
+ 'type', 'date', 'status', 'metadata', 'snapshot']
+
+ @_prepared_statement('SELECT * FROM origin_visit '
+ 'WHERE origin = ? AND visit > ?')
+ def _origin_visit_get_no_limit(
+ self, origin_url: str, last_visit: int, *, statement) -> ResultSet:
+ return self._execute_and_retry(statement, [origin_url, last_visit])
+
+ @_prepared_statement('SELECT * FROM origin_visit '
+ 'WHERE origin = ? AND visit > ? LIMIT ?')
+ def _origin_visit_get_limit(
+ self, origin_url: str, last_visit: int, limit: int, *, statement
+ ) -> ResultSet:
+ return self._execute_and_retry(
+ statement, [origin_url, last_visit, limit])
+
+ def origin_visit_get(
+ self, origin_url: str, last_visit: Optional[int],
+ limit: Optional[int]) -> ResultSet:
+ if last_visit is None:
+ last_visit = -1
+
+ if limit is None:
+ return self._origin_visit_get_no_limit(origin_url, last_visit)
+ else:
+ return self._origin_visit_get_limit(origin_url, last_visit, limit)
+
+ def origin_visit_update(
+ self, origin_url: str, visit_id: int, updates: Dict[str, Any]
+ ) -> None:
+ set_parts = []
+ args: List[Any] = []
+ for (column, value) in updates.items():
+ set_parts.append(f'{column} = %s')
+ if column == 'metadata':
+ args.append(json.dumps(value))
+ else:
+ args.append(value)
+
+ if not set_parts:
+ return
+
+ query = ('UPDATE origin_visit SET ' + ', '.join(set_parts) +
+ ' WHERE origin = %s AND visit = %s')
+ self._execute_and_retry(
+ query, args + [origin_url, visit_id])
+
+ @_prepared_insert_statement('origin_visit', _origin_visit_keys)
+ def origin_visit_add_one(
+ self, visit: Dict[str, Any], *, statement) -> None:
+ self._execute_and_retry(
+ statement, [visit[key] for key in self._origin_visit_keys])
+ self._increment_counter('origin_visit', 1)
+
+ @_prepared_statement(
+ 'UPDATE origin_visit SET ' +
+ ', '.join('%s = ?' % key for key in _origin_visit_update_keys) +
+ ' WHERE origin = ? AND visit = ?')
+ def origin_visit_upsert(
+ self, visit: Dict[str, Any], *, statement) -> None:
+ self._execute_and_retry(
+ statement,
+ [visit.get(key) for key in self._origin_visit_update_keys]
+ + [visit['origin'], visit['visit']])
+ # TODO: check if there is already one
+ self._increment_counter('origin_visit', 1)
+
+ @_prepared_statement('SELECT * FROM origin_visit '
+ 'WHERE origin = ? AND visit = ?')
+ def origin_visit_get_one(
+ self, origin_url: str, visit_id: int, *,
+ statement) -> Optional[Row]:
+ # TODO: error handling
+ rows = list(self._execute_and_retry(statement, [origin_url, visit_id]))
+ if rows:
+ return rows[0]
+ else:
+ return None
+
+ @_prepared_statement('SELECT * FROM origin_visit '
+ 'WHERE origin = ?')
+ def origin_visit_get_all(self, origin_url: str, *, statement) -> ResultSet:
+ return self._execute_and_retry(
+ statement, [origin_url])
+
+ @_prepared_statement('SELECT * FROM origin_visit WHERE origin = ?')
+ def origin_visit_get_latest(
+ self, origin: str, allowed_statuses: Optional[Iterable[str]],
+ require_snapshot: bool, *, statement) -> Optional[Row]:
+ # TODO: do the ordering and filtering in Cassandra
+ rows = list(self._execute_and_retry(statement, [origin]))
+
+ rows.sort(key=lambda row: (row.date, row.visit), reverse=True)
+
+ for row in rows:
+ if require_snapshot and row.snapshot is None:
+ continue
+ if allowed_statuses is not None \
+ and row.status not in allowed_statuses:
+ continue
+ if row.snapshot is not None and \
+ self.snapshot_missing([row.snapshot]):
+ raise ValueError('visit references unknown snapshot')
+ return row
+ else:
+ return None
+
+ @_prepared_statement('SELECT * FROM origin_visit WHERE token(origin) >= ?')
+ def _origin_visit_iter_from(
+ self, min_token: int, *, statement) -> Generator[Row, None, None]:
+ yield from self._execute_and_retry(statement, [min_token])
+
+ @_prepared_statement('SELECT * FROM origin_visit WHERE token(origin) < ?')
+ def _origin_visit_iter_to(
+ self, max_token: int, *, statement) -> Generator[Row, None, None]:
+ yield from self._execute_and_retry(statement, [max_token])
+
+ def origin_visit_iter(
+ self, start_token: int) -> Generator[Row, None, None]:
+ """Returns all origin visits in order from this token,
+ and wraps around the token space."""
+ yield from self._origin_visit_iter_from(start_token)
+ yield from self._origin_visit_iter_to(start_token)
+
+ ##########################
+ # 'tool' table
+ ##########################
+
+ _tool_keys = ['id', 'name', 'version', 'configuration']
+
+ @_prepared_insert_statement('tool_by_uuid', _tool_keys)
+ def tool_by_uuid_add_one(self, tool: Dict[str, Any], *, statement) -> None:
+ self._execute_and_retry(
+ statement, [tool[key] for key in self._tool_keys])
+
+ @_prepared_insert_statement('tool', _tool_keys)
+ def tool_add_one(self, tool: Dict[str, Any], *, statement) -> None:
+ self._execute_and_retry(
+ statement, [tool[key] for key in self._tool_keys])
+ self._increment_counter('tool', 1)
+
+ @_prepared_statement('SELECT id FROM tool '
+ 'WHERE name = ? AND version = ? '
+ 'AND configuration = ?')
+ def tool_get_one_uuid(
+ self, name: str, version: str, configuration: Dict[str, Any], *,
+ statement) -> Optional[str]:
+ rows = list(self._execute_and_retry(
+ statement, [name, version, configuration]))
+ if rows:
+ assert len(rows) == 1
+ return rows[0].id
+ else:
+ return None
+
+ ##########################
+ # Miscellaneous
+ ##########################
+
+ @_prepared_statement('SELECT uuid() FROM revision LIMIT 1;')
+ def check_read(self, *, statement):
+ self._execute_and_retry(statement, [])
+
+ @_prepared_statement('SELECT object_type, count FROM object_count '
+ 'WHERE partition_key=0')
+ def stat_counters(self, *, statement) -> ResultSet:
+ return self._execute_and_retry(
+ statement, [])
diff --git a/swh/storage/cassandra/schema.py b/swh/storage/cassandra/schema.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/cassandra/schema.py
@@ -0,0 +1,196 @@
+# Copyright (C) 2019-2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+
+CREATE_TABLES_QUERIES = '''
+CREATE OR REPLACE FUNCTION ascii_bins_count_sfunc (
+ state tuple>, -- (nb_none, map)
+ bin_name ascii
+)
+CALLED ON NULL INPUT
+RETURNS tuple>
+LANGUAGE java AS
+$$
+ if (bin_name == null) {
+ state.setInt(0, state.getInt(0) + 1);
+ }
+ else {
+ Map counters = state.getMap(
+ 1, String.class, Integer.class);
+ Integer nb = counters.get(bin_name);
+ if (nb == null) {
+ nb = 0;
+ }
+ counters.put(bin_name, nb + 1);
+ state.setMap(1, counters, String.class, Integer.class);
+ }
+ return state;
+$$
+;
+
+CREATE OR REPLACE AGGREGATE ascii_bins_count ( ascii )
+SFUNC ascii_bins_count_sfunc
+STYPE tuple>
+INITCOND (0, {})
+;
+
+CREATE TYPE IF NOT EXISTS microtimestamp (
+ seconds bigint,
+ microseconds int
+);
+
+CREATE TYPE IF NOT EXISTS microtimestamp_with_timezone (
+ timestamp frozen,
+ offset smallint,
+ negative_utc boolean
+);
+
+CREATE TYPE IF NOT EXISTS person (
+ fullname blob,
+ name blob,
+ email blob
+);
+
+CREATE TABLE IF NOT EXISTS content (
+ sha1 blob,
+ sha1_git blob,
+ sha256 blob,
+ blake2s256 blob,
+ length bigint,
+ ctime timestamp,
+ -- creation time, i.e. time of (first) injection into the storage
+ status ascii,
+ PRIMARY KEY ((sha1, sha1_git, sha256, blake2s256))
+);
+
+CREATE TABLE IF NOT EXISTS revision (
+ id blob PRIMARY KEY,
+ date microtimestamp_with_timezone,
+ committer_date microtimestamp_with_timezone,
+ type ascii,
+ directory blob, -- source code "root" directory
+ message blob,
+ author person,
+ committer person,
+ synthetic boolean,
+ -- true iff revision has been created by Software Heritage
+ metadata text
+ -- extra metadata as JSON(tarball checksums,
+ -- extra commit information, etc...)
+);
+
+CREATE TABLE IF NOT EXISTS revision_parent (
+ id blob,
+ parent_rank int,
+ -- parent position in merge commits, 0-based
+ parent_id blob,
+ PRIMARY KEY ((id), parent_rank)
+);
+
+CREATE TABLE IF NOT EXISTS release
+(
+ id blob PRIMARY KEY,
+ target_type ascii,
+ target blob,
+ date microtimestamp_with_timezone,
+ name blob,
+ message blob,
+ author person,
+ synthetic boolean,
+ -- true iff release has been created by Software Heritage
+);
+
+CREATE TABLE IF NOT EXISTS directory (
+ id blob PRIMARY KEY,
+);
+
+CREATE TABLE IF NOT EXISTS directory_entry (
+ directory_id blob,
+ name blob, -- path name, relative to containing dir
+ target blob,
+ perms int, -- unix-like permissions
+ type ascii, -- target type
+ PRIMARY KEY ((directory_id), name)
+);
+
+CREATE TABLE IF NOT EXISTS snapshot (
+ id blob PRIMARY KEY,
+);
+
+-- For a given snapshot_id, branches are sorted by their name,
+-- allowing easy pagination.
+CREATE TABLE IF NOT EXISTS snapshot_branch (
+ snapshot_id blob,
+ name blob,
+ target_type ascii,
+ target blob,
+ PRIMARY KEY ((snapshot_id), name)
+);
+
+CREATE TABLE IF NOT EXISTS origin_visit (
+ origin text,
+ visit bigint,
+ date timestamp,
+ type text,
+ status ascii,
+ metadata text,
+ snapshot blob,
+ PRIMARY KEY ((origin), visit)
+);
+
+
+CREATE TABLE IF NOT EXISTS origin (
+ sha1 blob PRIMARY KEY,
+ url text,
+ type text,
+ next_visit_id int,
+ -- We need integer visit ids for compatibility with the pgsql
+ -- storage, so we're using lightweight transactions with this trick:
+ -- https://stackoverflow.com/a/29391877/539465
+);
+
+
+CREATE TABLE IF NOT EXISTS tool_by_uuid (
+ id timeuuid PRIMARY KEY,
+ name ascii,
+ version ascii,
+ configuration blob,
+);
+
+
+CREATE TABLE IF NOT EXISTS tool (
+ id timeuuid,
+ name ascii,
+ version ascii,
+ configuration blob,
+ PRIMARY KEY ((name, version, configuration))
+)
+
+
+CREATE TABLE IF NOT EXISTS object_count (
+ partition_key smallint, -- Constant, must always be 0
+ object_type ascii,
+ count counter,
+ PRIMARY KEY ((partition_key), object_type)
+);
+'''.split('\n\n')
+
+CONTENT_INDEX_TEMPLATE = '''
+CREATE TABLE IF NOT EXISTS content_by_{main_algo} (
+ sha1 blob,
+ sha1_git blob,
+ sha256 blob,
+ blake2s256 blob,
+ PRIMARY KEY (({main_algo}), {other_algos})
+);'''
+
+HASH_ALGORITHMS = ['sha1', 'sha1_git', 'sha256', 'blake2s256']
+
+for main_algo in HASH_ALGORITHMS:
+ CREATE_TABLES_QUERIES.append(CONTENT_INDEX_TEMPLATE.format(
+ main_algo=main_algo,
+ other_algos=', '.join(
+ [algo for algo in HASH_ALGORITHMS if algo != main_algo])
+ ))
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/cassandra/storage.py
@@ -0,0 +1,972 @@
+# Copyright (C) 2019-2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import datetime
+import json
+import random
+import re
+from typing import Any, Dict, List, Optional
+import uuid
+
+import attr
+import dateutil
+
+from swh.model.model import (
+ Revision, Release, Directory, DirectoryEntry, Content, OriginVisit,
+)
+from swh.objstorage import get_objstorage
+from swh.objstorage.exc import ObjNotFoundError
+try:
+ from swh.journal.writer import get_journal_writer
+except ImportError:
+ get_journal_writer = None # type: ignore
+ # mypy limitation, see https://github.com/python/mypy/issues/1153
+
+
+from .common import TOKEN_BEGIN, TOKEN_END
+from .converters import (
+ revision_to_db, revision_from_db, release_to_db, release_from_db,
+)
+from .cql import CqlRunner
+from .schema import HASH_ALGORITHMS
+
+
+# Max block size of contents to return
+BULK_BLOCK_CONTENT_LEN_MAX = 10000
+
+
+def now():
+ return datetime.datetime.now(tz=datetime.timezone.utc)
+
+
+class CassandraStorage:
+ def __init__(self, hosts, keyspace, objstorage,
+ port=9042, journal_writer=None):
+ self._cql_runner = CqlRunner(hosts, keyspace, port)
+
+ self.objstorage = get_objstorage(**objstorage)
+
+ if journal_writer:
+ self.journal_writer = get_journal_writer(**journal_writer)
+ else:
+ self.journal_writer = None
+
+ def check_config(self, *, check_write):
+ self._cql_runner.check_read()
+
+ return True
+
+ def _content_add(self, contents, with_data):
+ contents = [Content.from_dict(c) for c in contents]
+
+ # Filter-out content already in the database.
+ contents = [c for c in contents
+ if not self._cql_runner.content_get_from_pk(c.to_dict())]
+
+ if self.journal_writer:
+ for content in contents:
+ content = content.to_dict()
+ if 'data' in content:
+ del content['data']
+ self.journal_writer.write_addition('content', content)
+
+ count_contents = 0
+ count_content_added = 0
+ count_content_bytes_added = 0
+
+ for content in contents:
+ # First insert to the objstorage, if the endpoint is
+ # `content_add` (as opposed to `content_add_metadata`).
+ # TODO: this should probably be done in concurrently to inserting
+ # in index tables (but still before the main table; so an entry is
+ # only added to the main table after everything else was
+ # successfully inserted.
+ count_contents += 1
+ if content.status != 'absent':
+ count_content_added += 1
+ if with_data:
+ content_data = content.data
+ count_content_bytes_added += len(content_data)
+ self.objstorage.add(content_data, content.sha1)
+
+ # Then add to index tables
+ for algo in HASH_ALGORITHMS:
+ self._cql_runner.content_index_add_one(algo, content)
+
+ # Then to the main table
+ self._cql_runner.content_add_one(content)
+
+ # Note that we check for collisions *after* inserting. This
+ # differs significantly from the pgsql storage, but checking
+ # before insertion does not provide any guarantee in case
+ # another thread inserts the colliding hash at the same time.
+ #
+ # The proper way to do it would probably be a BATCH, but this
+ # would be inefficient because of the number of partitions we
+ # need to affect (len(HASH_ALGORITHMS)+1, which is currently 5)
+ for algo in {'sha1', 'sha1_git'}:
+ pks = self._cql_runner.content_get_pks_from_single_hash(
+ algo, content.get_hash(algo))
+ if len(pks) > 1:
+ # There are more than the one we just inserted.
+ from .. import HashCollision
+ raise HashCollision(algo, content.get_hash(algo), pks)
+
+ summary = {
+ 'content:add': count_content_added,
+ 'skipped_content:add': count_contents - count_content_added,
+ }
+
+ if with_data:
+ summary['content:add:bytes'] = count_content_bytes_added
+
+ return summary
+
+ def content_add(self, content):
+ content = [c.copy() for c in content] # semi-shallow copy
+ for item in content:
+ item['ctime'] = now()
+ return self._content_add(content, with_data=True)
+
+ def content_update(self, content, keys=[]):
+ raise NotImplementedError(
+ 'content_update is not supported by the Cassandra backend')
+
+ def content_add_metadata(self, content):
+ return self._content_add(content, with_data=False)
+
+ def content_get(self, content):
+ if len(content) > BULK_BLOCK_CONTENT_LEN_MAX:
+ raise ValueError(
+ "Sending at most %s contents." % BULK_BLOCK_CONTENT_LEN_MAX)
+ for obj_id in content:
+ try:
+ data = self.objstorage.get(obj_id)
+ except ObjNotFoundError:
+ yield None
+ continue
+
+ yield {'sha1': obj_id, 'data': data}
+
+ def content_get_partition(
+ self, partition_id: int, nb_partitions: int, limit: int = 1000,
+ page_token: str = None):
+ if limit is None:
+ raise ValueError('Development error: limit should not be None')
+
+ # Compute start and end of the range of tokens covered by the
+ # requested partition
+ partition_size = (TOKEN_END-TOKEN_BEGIN)//nb_partitions
+ range_start = TOKEN_BEGIN + partition_id*partition_size
+ range_end = TOKEN_BEGIN + (partition_id+1)*partition_size
+
+ # offset the range start according to the `page_token`.
+ if page_token is not None:
+ if not (range_start <= int(page_token) <= range_end):
+ raise ValueError('Invalid page_token.')
+ range_start = int(page_token)
+
+ # Get the first rows of the range
+ rows = self._cql_runner.content_get_token_range(
+ range_start, range_end, limit)
+ rows = list(rows)
+
+ if len(rows) == limit:
+ next_page_token: Optional[str] = str(rows[-1].tok+1)
+ else:
+ next_page_token = None
+
+ return {
+ 'contents': [row._asdict() for row in rows
+ if row.status != 'absent'],
+ 'next_page_token': next_page_token,
+ }
+
+ def content_get_metadata(
+ self, contents: List[bytes]) -> Dict[bytes, List[Dict]]:
+ result: Dict[bytes, List[Dict]] = {sha1: [] for sha1 in contents}
+ for sha1 in contents:
+ # Get all (sha1, sha1_git, sha256, blake2s256) whose sha1
+ # matches the argument, from the index table ('content_by_sha1')
+ pks = self._cql_runner.content_get_pks_from_single_hash(
+ 'sha1', sha1)
+
+ if pks:
+ # TODO: what to do if there are more than one?
+ pk = pks[0]
+
+ # Query the main table ('content')
+ res = self._cql_runner.content_get_from_pk(pk._asdict())
+
+ # Rows in 'content' are inserted after corresponding
+ # rows in 'content_by_*', so we might be missing it
+ if res:
+ content_metadata = res._asdict()
+ content_metadata.pop('ctime')
+ result[content_metadata['sha1']].append(content_metadata)
+ return result
+
+ def skipped_content_missing(self, contents):
+ # TODO
+ raise NotImplementedError('not yet supported for Cassandra')
+
+ def content_find(self, content):
+ # Find an algorithm that is common to all the requested contents.
+ # It will be used to do an initial filtering efficiently.
+ filter_algos = list(set(content).intersection(HASH_ALGORITHMS))
+ if not filter_algos:
+ raise ValueError('content keys must contain at least one of: '
+ '%s' % ', '.join(sorted(HASH_ALGORITHMS)))
+ common_algo = filter_algos[0]
+
+ # Find all contents whose common_algo matches one at least one
+ # of the requests.
+ found_pks = self._cql_runner.content_get_pks_from_single_hash(
+ common_algo, content[common_algo])
+ found_pks = [pk._asdict() for pk in found_pks]
+
+ # Filter with the other hash algorithms.
+ for algo in filter_algos[1:]:
+ found_pks = [pk for pk in found_pks if pk[algo] == content[algo]]
+
+ results = []
+ for pk in found_pks:
+ # Query the main table ('content').
+ res = self._cql_runner.content_get_from_pk(pk)
+
+ # Rows in 'content' are inserted after corresponding
+ # rows in 'content_by_*', so we might be missing it
+ if res:
+ results.append({
+ **res._asdict(),
+ 'ctime': res.ctime.replace(tzinfo=datetime.timezone.utc)
+ })
+ return results
+
+ def content_missing(self, content, key_hash='sha1'):
+ for cont in content:
+ res = self.content_find(cont)
+ if not res:
+ yield cont[key_hash]
+ if any(c['status'] == 'missing' for c in res):
+ yield cont[key_hash]
+
+ def content_missing_per_sha1(self, contents):
+ return self.content_missing([{'sha1': c for c in contents}])
+
+ def content_missing_per_sha1_git(self, contents):
+ return self.content_missing([{'sha1_git': c for c in contents}],
+ key_hash='sha1_git')
+
+ def content_get_random(self):
+ return self._cql_runner.content_get_random().sha1_git
+
+ def directory_add(self, directories):
+ directories = list(directories)
+
+ # Filter out directories that are already inserted.
+ missing = self.directory_missing([dir_['id'] for dir_ in directories])
+ directories = [dir_ for dir_ in directories if dir_['id'] in missing]
+
+ if self.journal_writer:
+ self.journal_writer.write_additions('directory', directories)
+
+ for directory in directories:
+ directory = Directory.from_dict(directory)
+
+ # Add directory entries to the 'directory_entry' table
+ for entry in directory.entries:
+ entry = entry.to_dict()
+ entry['directory_id'] = directory.id
+ self._cql_runner.directory_entry_add_one(entry)
+
+ # Add the directory *after* adding all the entries, so someone
+ # calling snapshot_get_branch in the meantime won't end up
+ # with half the entries.
+ self._cql_runner.directory_add_one(directory.id)
+
+ return {'directory:add': len(missing)}
+
+ def directory_missing(self, directories):
+ return self._cql_runner.directory_missing(directories)
+
+ def _join_dentry_to_content(self, dentry):
+ keys = (
+ 'status',
+ 'sha1',
+ 'sha1_git',
+ 'sha256',
+ 'length',
+ )
+ ret = dict.fromkeys(keys)
+ ret.update(dentry.to_dict())
+ if ret['type'] == 'file':
+ content = self.content_find({'sha1_git': ret['target']})
+ if content:
+ content = content[0]
+ for key in keys:
+ ret[key] = content[key]
+ return ret
+
+ def _directory_ls(self, directory_id, recursive, prefix=b''):
+ if self.directory_missing([directory_id]):
+ return
+ rows = list(self._cql_runner.directory_entry_get([directory_id]))
+
+ for row in rows:
+ # Build and yield the directory entry dict
+ entry = row._asdict()
+ del entry['directory_id']
+ entry = DirectoryEntry.from_dict(entry)
+ ret = self._join_dentry_to_content(entry)
+ ret['name'] = prefix + ret['name']
+ ret['dir_id'] = directory_id
+ yield ret
+
+ if recursive and ret['type'] == 'dir':
+ yield from self._directory_ls(
+ ret['target'], True, prefix + ret['name'] + b'/')
+
+ def directory_entry_get_by_path(self, directory, paths):
+ return self._directory_entry_get_by_path(directory, paths, b'')
+
+ def _directory_entry_get_by_path(self, directory, paths, prefix):
+ if not paths:
+ return
+
+ contents = list(self.directory_ls(directory))
+
+ if not contents:
+ return
+
+ def _get_entry(entries, name):
+ """Finds the entry with the requested name, prepends the
+ prefix (to get its full path), and returns it.
+
+ If no entry has that name, returns None."""
+ for entry in entries:
+ if entry['name'] == name:
+ entry = entry.copy()
+ entry['name'] = prefix + entry['name']
+ return entry
+
+ first_item = _get_entry(contents, paths[0])
+
+ if len(paths) == 1:
+ return first_item
+
+ if not first_item or first_item['type'] != 'dir':
+ return
+
+ return self._directory_entry_get_by_path(
+ first_item['target'], paths[1:], prefix + paths[0] + b'/')
+
+ def directory_ls(self, directory, recursive=False):
+ yield from self._directory_ls(directory, recursive)
+
+ def directory_get_random(self):
+ return self._cql_runner.directory_get_random().id
+
+ def revision_add(self, revisions):
+ revisions = list(revisions)
+
+ # Filter-out revisions already in the database
+ missing = self.revision_missing([rev['id'] for rev in revisions])
+ revisions = [rev for rev in revisions if rev['id'] in missing]
+
+ if self.journal_writer:
+ self.journal_writer.write_additions('revision', revisions)
+
+ for revision in revisions:
+ revision = revision_to_db(revision)
+
+ if revision:
+ # Add parents first
+ for (rank, parent) in enumerate(revision.parents):
+ self._cql_runner.revision_parent_add_one(
+ revision.id, rank, parent)
+
+ # Then write the main revision row.
+ # Writing this after all parents were written ensures that
+ # read endpoints don't return a partial view while writing
+ # the parents
+ self._cql_runner.revision_add_one(revision)
+
+ return {'revision:add': len(revisions)}
+
+ def revision_missing(self, revisions):
+ return self._cql_runner.revision_missing(revisions)
+
+ def revision_get(self, revisions):
+ rows = self._cql_runner.revision_get(revisions)
+ revs = {}
+ for row in rows:
+ # TODO: use a single query to get all parents?
+ # (it might have lower latency, but requires more code and more
+ # bandwidth, because revision id would be part of each returned
+ # row)
+ parent_rows = self._cql_runner.revision_parent_get(row.id)
+ # parent_rank is the clustering key, so results are already
+ # sorted by rank.
+ parents = [row.parent_id for row in parent_rows]
+
+ rev = Revision(**row._asdict(), parents=parents)
+
+ rev = revision_from_db(rev)
+ revs[rev.id] = rev.to_dict()
+
+ for rev_id in revisions:
+ yield revs.get(rev_id)
+
+ def _get_parent_revs(self, rev_ids, seen, limit, short):
+ if limit and len(seen) >= limit:
+ return
+ rev_ids = [id_ for id_ in rev_ids if id_ not in seen]
+ if not rev_ids:
+ return
+ seen |= set(rev_ids)
+
+ # We need this query, even if short=True, to return consistent
+ # results (ie. not return only a subset of a revision's parents
+ # if it is being written)
+ if short:
+ rows = self._cql_runner.revision_get_ids(rev_ids)
+ else:
+ rows = self._cql_runner.revision_get(rev_ids)
+
+ for row in rows:
+ # TODO: use a single query to get all parents?
+ # (it might have less latency, but requires less code and more
+ # bandwidth (because revision id would be part of each returned
+ # row)
+ parent_rows = self._cql_runner.revision_parent_get(row.id)
+
+ # parent_rank is the clustering key, so results are already
+ # sorted by rank.
+ parents = [row.parent_id for row in parent_rows]
+
+ if short:
+ yield (row.id, parents)
+ else:
+ rev = revision_from_db(Revision(
+ **row._asdict(), parents=parents))
+ yield rev.to_dict()
+ yield from self._get_parent_revs(parents, seen, limit, short)
+
+ def revision_log(self, revisions, limit=None):
+ """Fetch revision entry from the given root revisions.
+
+ Args:
+ revisions: array of root revision to lookup
+ limit: limitation on the output result. Default to None.
+
+ Yields:
+ List of revision log from such revisions root.
+
+ """
+ seen = set()
+ yield from self._get_parent_revs(revisions, seen, limit, False)
+
+ def revision_shortlog(self, revisions, limit=None):
+ """Fetch the shortlog for the given revisions
+
+ Args:
+ revisions: list of root revisions to lookup
+ limit: depth limitation for the output
+
+ Yields:
+ a list of (id, parents) tuples.
+
+ """
+ seen = set()
+ yield from self._get_parent_revs(revisions, seen, limit, True)
+
+ def revision_get_random(self):
+ return self._cql_runner.revision_get_random().id
+
+ def release_add(self, releases):
+ releases = list(releases)
+ missing = self.release_missing([rel['id'] for rel in releases])
+ releases = [rel for rel in releases if rel['id'] in missing]
+
+ if self.journal_writer:
+ self.journal_writer.write_additions('release', releases)
+
+ for release in releases:
+ release = release_to_db(release)
+
+ if release:
+ self._cql_runner.release_add_one(release)
+
+ return {'release:add': len(missing)}
+
+ def release_missing(self, releases):
+ return self._cql_runner.release_missing(releases)
+
+ def release_get(self, releases):
+ rows = self._cql_runner.release_get(releases)
+ rels = {}
+ for row in rows:
+ release = Release(**row._asdict())
+ release = release_from_db(release)
+ rels[row.id] = release.to_dict()
+
+ for rel_id in releases:
+ yield rels.get(rel_id)
+
+ def release_get_random(self):
+ return self._cql_runner.release_get_random().id
+
+ def snapshot_add(self, snapshots):
+ snapshots = list(snapshots)
+ missing = self._cql_runner.snapshot_missing(
+ [snp['id'] for snp in snapshots])
+ snapshots = [snp for snp in snapshots if snp['id'] in missing]
+
+ for snapshot in snapshots:
+ if self.journal_writer:
+ self.journal_writer.write_addition('snapshot', snapshot)
+
+ # Add branches
+ for (branch_name, branch) in snapshot['branches'].items():
+ if branch is None:
+ branch = {'target_type': None, 'target': None}
+ self._cql_runner.snapshot_branch_add_one({
+ 'snapshot_id': snapshot['id'],
+ 'name': branch_name,
+ 'target_type': branch['target_type'],
+ 'target': branch['target'],
+ })
+
+ # Add the snapshot *after* adding all the branches, so someone
+ # calling snapshot_get_branch in the meantime won't end up
+ # with half the branches.
+ self._cql_runner.snapshot_add_one(snapshot['id'])
+
+ return {'snapshot:add': len(snapshots)}
+
+ def snapshot_missing(self, snapshots):
+ return self._cql_runner.snapshot_missing(snapshots)
+
+ def snapshot_get(self, snapshot_id):
+ return self.snapshot_get_branches(snapshot_id)
+
+ def snapshot_get_by_origin_visit(self, origin, visit):
+ try:
+ visit = self._cql_runner.origin_visit_get_one(origin, visit)
+ except IndexError:
+ return None
+
+ return self.snapshot_get(visit.snapshot)
+
+ def snapshot_get_latest(self, origin, allowed_statuses=None):
+ visit = self.origin_visit_get_latest(
+ origin,
+ allowed_statuses=allowed_statuses,
+ require_snapshot=True)
+
+ if visit:
+ assert visit['snapshot']
+ if self._cql_runner.snapshot_missing([visit['snapshot']]):
+ raise ValueError('Visit references unknown snapshot')
+ return self.snapshot_get_branches(visit['snapshot'])
+
+ def snapshot_count_branches(self, snapshot_id):
+ if self._cql_runner.snapshot_missing([snapshot_id]):
+ # Makes sure we don't fetch branches for a snapshot that is
+ # being added.
+ return None
+ rows = list(self._cql_runner.snapshot_count_branches(snapshot_id))
+ assert len(rows) == 1
+ (nb_none, counts) = rows[0].counts
+ counts = dict(counts)
+ if nb_none:
+ counts[None] = nb_none
+ return counts
+
+ def snapshot_get_branches(self, snapshot_id, branches_from=b'',
+ branches_count=1000, target_types=None):
+ if self._cql_runner.snapshot_missing([snapshot_id]):
+ # Makes sure we don't fetch branches for a snapshot that is
+ # being added.
+ return None
+
+ branches = []
+ while len(branches) < branches_count+1:
+ new_branches = list(self._cql_runner.snapshot_branch_get(
+ snapshot_id, branches_from, branches_count+1))
+
+ if not new_branches:
+ break
+
+ branches_from = new_branches[-1].name
+
+ new_branches_filtered = new_branches
+
+ # Filter by target_type
+ if target_types:
+ new_branches_filtered = [
+ branch for branch in new_branches_filtered
+ if branch.target is not None
+ and branch.target_type in target_types]
+
+ branches.extend(new_branches_filtered)
+
+ if len(new_branches) < branches_count+1:
+ break
+
+ if len(branches) > branches_count:
+ last_branch = branches.pop(-1).name
+ else:
+ last_branch = None
+
+ branches = {
+ branch.name: {
+ 'target': branch.target,
+ 'target_type': branch.target_type,
+ } if branch.target else None
+ for branch in branches
+ }
+
+ return {
+ 'id': snapshot_id,
+ 'branches': branches,
+ 'next_branch': last_branch,
+ }
+
+ def snapshot_get_random(self):
+ return self._cql_runner.snapshot_get_random().id
+
+ def object_find_by_sha1_git(self, ids):
+ results = {id_: [] for id_ in ids}
+ missing_ids = set(ids)
+
+ # Mind the order, revision is the most likely one for a given ID,
+ # so we check revisions first.
+ queries = [
+ ('revision', self._cql_runner.revision_missing),
+ ('release', self._cql_runner.release_missing),
+ ('content', self._cql_runner.content_missing_by_sha1_git),
+ ('directory', self._cql_runner.directory_missing),
+ ]
+
+ for (object_type, query_fn) in queries:
+ found_ids = missing_ids - set(query_fn(missing_ids))
+ for sha1_git in found_ids:
+ results[sha1_git].append({
+ 'sha1_git': sha1_git,
+ 'type': object_type,
+ })
+ missing_ids.remove(sha1_git)
+
+ if not missing_ids:
+ # We found everything, skipping the next queries.
+ break
+
+ return results
+
+ def origin_get(self, origins):
+ if isinstance(origins, dict):
+ # Old API
+ return_single = True
+ origins = [origins]
+ else:
+ return_single = False
+
+ if any('id' in origin for origin in origins):
+ raise ValueError('Origin ids are not supported.')
+
+ results = [self.origin_get_one(origin) for origin in origins]
+
+ if return_single:
+ assert len(results) == 1
+ return results[0]
+ else:
+ return results
+
+ def origin_get_one(self, origin):
+ if 'id' in origin:
+ raise ValueError('Origin ids are not supported.')
+ rows = self._cql_runner.origin_get_by_url(origin['url'])
+
+ rows = list(rows)
+ if rows:
+ assert len(rows) == 1
+ result = rows[0]._asdict()
+ return {
+ 'url': result['url'],
+ }
+ else:
+ return None
+
+ def origin_get_by_sha1(self, sha1s):
+ results = []
+ for sha1 in sha1s:
+ rows = self._cql_runner.origin_get_by_sha1(sha1)
+ if rows:
+ results.append({'url': rows.one().url})
+ else:
+ results.append(None)
+ return results
+
+ def origin_list(self, page_token: Optional[str] = None, limit: int = 100
+ ) -> dict:
+ # Compute what token to begin the listing from
+ start_token = TOKEN_BEGIN
+ if page_token:
+ start_token = int(page_token)
+ if not (TOKEN_BEGIN <= start_token <= TOKEN_END):
+ raise ValueError('Invalid page_token.')
+
+ rows = self._cql_runner.origin_list(start_token, limit)
+ rows = list(rows)
+
+ if len(rows) == limit:
+ next_page_token: Optional[str] = str(rows[-1].tok+1)
+ else:
+ next_page_token = None
+
+ return {
+ 'origins': [{'url': row.url} for row in rows],
+ 'next_page_token': next_page_token,
+ }
+
+ def origin_search(self, url_pattern, offset=0, limit=50,
+ regexp=False, with_visit=False):
+ # TODO: remove this endpoint, swh-search should be used instead.
+ origins = self._cql_runner.origin_iter_all()
+ if regexp:
+ pat = re.compile(url_pattern)
+ origins = [orig for orig in origins if pat.search(orig.url)]
+ else:
+ origins = [orig for orig in origins if url_pattern in orig.url]
+
+ if with_visit:
+ origins = [orig for orig in origins
+ if orig.next_visit_id > 1]
+
+ return [
+ {
+ 'url': orig.url,
+ }
+ for orig in origins[offset:offset+limit]]
+
+ def origin_add(self, origins):
+ origins = list(origins)
+ if any('id' in origin for origin in origins):
+ raise ValueError('Origins must not already have an id.')
+ results = []
+ for origin in origins:
+ self.origin_add_one(origin)
+ results.append(origin)
+ return results
+
+ def origin_add_one(self, origin):
+ known_origin = self.origin_get_one(origin)
+
+ if known_origin:
+ origin_url = known_origin['url']
+ else:
+ if self.journal_writer:
+ self.journal_writer.write_addition('origin', origin)
+
+ self._cql_runner.origin_add_one(origin)
+ origin_url = origin['url']
+
+ return origin_url
+
+ def origin_visit_add(self, origin, date, type):
+ origin_url = origin # TODO: rename the argument
+
+ if isinstance(date, str):
+ date = dateutil.parser.parse(date)
+
+ origin = self.origin_get_one({'url': origin_url})
+
+ if not origin:
+ return None
+
+ visit_id = self._cql_runner.origin_generate_unique_visit_id(origin_url)
+
+ visit = {
+ 'origin': origin_url,
+ 'date': date,
+ 'type': type,
+ 'status': 'ongoing',
+ 'snapshot': None,
+ 'metadata': None,
+ 'visit': visit_id
+ }
+
+ if self.journal_writer:
+ self.journal_writer.write_addition('origin_visit', visit)
+
+ self._cql_runner.origin_visit_add_one(visit)
+
+ return {
+ 'origin': origin_url,
+ 'visit': visit_id,
+ }
+
+ def origin_visit_update(self, origin, visit_id, status=None,
+ metadata=None, snapshot=None):
+ origin_url = origin # TODO: rename the argument
+
+ # Get the existing data of the visit
+ row = self._cql_runner.origin_visit_get_one(origin_url, visit_id)
+ if not row:
+ raise ValueError('This origin visit does not exist.')
+ visit = OriginVisit.from_dict(self._format_origin_visit_row(row))
+
+ updates = {}
+ if status:
+ updates['status'] = status
+ if metadata:
+ updates['metadata'] = metadata
+ if snapshot:
+ updates['snapshot'] = snapshot
+
+ visit = attr.evolve(visit, **updates)
+
+ if self.journal_writer:
+ self.journal_writer.write_update('origin_visit', visit)
+
+ self._cql_runner.origin_visit_update(origin_url, visit_id, updates)
+
+ def origin_visit_upsert(self, visits):
+ visits = [visit.copy() for visit in visits]
+ for visit in visits:
+ if isinstance(visit['date'], str):
+ visit['date'] = dateutil.parser.parse(visit['date'])
+
+ if self.journal_writer:
+ for visit in visits:
+ self.journal_writer.write_addition('origin_visit', visit)
+
+ for visit in visits:
+ visit = visit.copy()
+ if visit.get('metadata'):
+ visit['metadata'] = json.dumps(visit['metadata'])
+ self._cql_runner.origin_visit_upsert(visit)
+
+ @staticmethod
+ def _format_origin_visit_row(visit):
+ return {
+ **visit._asdict(),
+ 'origin': visit.origin,
+ 'date': visit.date.replace(tzinfo=datetime.timezone.utc),
+ 'metadata': (json.loads(visit.metadata)
+ if visit.metadata else None),
+ }
+
+ def origin_visit_get(self, origin, last_visit=None, limit=None):
+ rows = self._cql_runner.origin_visit_get(origin, last_visit, limit)
+
+ yield from map(self._format_origin_visit_row, rows)
+
+ def origin_visit_find_by_date(self, origin, visit_date):
+ # Iterator over all the visits of the origin
+ # This should be ok for now, as there aren't too many visits
+ # per origin.
+ visits = list(self._cql_runner.origin_visit_get_all(origin))
+
+ def key(visit):
+ dt = visit.date.replace(tzinfo=datetime.timezone.utc) - visit_date
+ return (abs(dt), -visit.visit)
+
+ if visits:
+ visit = min(visits, key=key)
+ return visit._asdict()
+
+ def origin_visit_get_by(self, origin, visit):
+ visit = self._cql_runner.origin_visit_get_one(origin, visit)
+ if visit:
+ return self._format_origin_visit_row(visit)
+
+ def origin_visit_get_latest(
+ self, origin, allowed_statuses=None, require_snapshot=False):
+ visit = self._cql_runner.origin_visit_get_latest(
+ origin,
+ allowed_statuses=allowed_statuses,
+ require_snapshot=require_snapshot)
+ if visit:
+ return self._format_origin_visit_row(visit)
+
+ def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]:
+ back_in_the_day = now() - datetime.timedelta(weeks=12) # 3 months back
+
+ # Random position to start iteration at
+ start_token = random.randint(TOKEN_BEGIN, TOKEN_END)
+
+ # Iterator over all visits, ordered by token(origins) then visit_id
+ rows = self._cql_runner.origin_visit_iter(start_token)
+ for row in rows:
+ visit = self._format_origin_visit_row(row)
+ if visit['date'] > back_in_the_day \
+ and visit['status'] == 'full':
+ return visit
+ else:
+ return None
+
+ def tool_add(self, tools):
+ inserted = []
+ for tool in tools:
+ tool = tool.copy()
+ tool_json = tool.copy()
+ tool_json['configuration'] = json.dumps(
+ tool['configuration'], sort_keys=True).encode()
+ id_ = self._cql_runner.tool_get_one_uuid(**tool_json)
+ if not id_:
+ id_ = uuid.uuid1()
+ tool_json['id'] = id_
+ self._cql_runner.tool_by_uuid_add_one(tool_json)
+ self._cql_runner.tool_add_one(tool_json)
+ tool['id'] = id_
+ inserted.append(tool)
+ return inserted
+
+ def tool_get(self, tool):
+ id_ = self._cql_runner.tool_get_one_uuid(
+ tool['name'], tool['version'],
+ json.dumps(tool['configuration'], sort_keys=True).encode())
+ if id_:
+ tool = tool.copy()
+ tool['id'] = id_
+ return tool
+ else:
+ return None
+
+ def stat_counters(self):
+ rows = self._cql_runner.stat_counters()
+ keys = (
+ 'content', 'directory', 'origin', 'origin_visit',
+ 'release', 'revision', 'skipped_content', 'snapshot')
+ stats = {key: 0 for key in keys}
+ stats.update({row.object_type: row.count for row in rows})
+ return stats
+
+ def refresh_stat_counters(self):
+ pass
+
+ def origin_metadata_add(self, origin_url, ts, provider, tool, metadata):
+ # TODO
+ raise NotImplementedError('not yet supported for Cassandra')
+
+ def origin_metadata_get_by(self, origin_url, provider_type=None):
+ # TODO
+ raise NotImplementedError('not yet supported for Cassandra')
+
+ def metadata_provider_add(self, provider_name, provider_type, provider_url,
+ metadata):
+ # TODO
+ raise NotImplementedError('not yet supported for Cassandra')
+
+ def metadata_provider_get(self, provider_id):
+ # TODO
+ raise NotImplementedError('not yet supported for Cassandra')
+
+ def metadata_provider_get_by(self, provider):
+ # TODO
+ raise NotImplementedError('not yet supported for Cassandra')
diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/tests/test_cassandra.py
@@ -0,0 +1,272 @@
+# Copyright (C) 2018-2019 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import os
+import signal
+import socket
+import subprocess
+import time
+
+import pytest
+
+from swh.storage import get_storage
+from swh.storage.cassandra import create_keyspace
+
+from swh.storage.tests.test_storage import TestStorage as _TestStorage
+from swh.storage.tests.test_storage import TestStorageGeneratedData \
+ as _TestStorageGeneratedData
+
+
+CONFIG_TEMPLATE = '''
+data_file_directories:
+ - {data_dir}/data
+commitlog_directory: {data_dir}/commitlog
+hints_directory: {data_dir}/hints
+saved_caches_directory: {data_dir}/saved_caches
+
+commitlog_sync: periodic
+commitlog_sync_period_in_ms: 1000000
+partitioner: org.apache.cassandra.dht.Murmur3Partitioner
+endpoint_snitch: SimpleSnitch
+seed_provider:
+ - class_name: org.apache.cassandra.locator.SimpleSeedProvider
+ parameters:
+ - seeds: "127.0.0.1"
+
+storage_port: {storage_port}
+native_transport_port: {native_transport_port}
+start_native_transport: true
+listen_address: 127.0.0.1
+
+enable_user_defined_functions: true
+'''
+
+
+def free_port():
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.bind(('127.0.0.1', 0))
+ port = sock.getsockname()[1]
+ sock.close()
+ return port
+
+
+def wait_for_peer(addr, port):
+ wait_until = time.time() + 20
+ while time.time() < wait_until:
+ try:
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.connect((addr, port))
+ except ConnectionRefusedError:
+ time.sleep(0.1)
+ else:
+ sock.close()
+ return True
+ return False
+
+
+@pytest.fixture(scope='session')
+def cassandra_cluster(tmpdir_factory):
+ cassandra_conf = tmpdir_factory.mktemp('cassandra_conf')
+ cassandra_data = tmpdir_factory.mktemp('cassandra_data')
+ cassandra_log = tmpdir_factory.mktemp('cassandra_log')
+ native_transport_port = free_port()
+ storage_port = free_port()
+ jmx_port = free_port()
+
+ with open(str(cassandra_conf.join('cassandra.yaml')), 'w') as fd:
+ fd.write(CONFIG_TEMPLATE.format(
+ data_dir=str(cassandra_data),
+ storage_port=storage_port,
+ native_transport_port=native_transport_port,
+ ))
+
+ if os.environ.get('LOG_CASSANDRA'):
+ stdout = stderr = None
+ else:
+ stdout = stderr = subprocess.DEVNULL
+ proc = subprocess.Popen(
+ [
+ '/usr/sbin/cassandra',
+ '-Dcassandra.config=file://%s/cassandra.yaml' % cassandra_conf,
+ '-Dcassandra.logdir=%s' % cassandra_log,
+ '-Dcassandra.jmx.local.port=%d' % jmx_port,
+ '-Dcassandra-foreground=yes',
+ ],
+ start_new_session=True,
+ env={
+ 'MAX_HEAP_SIZE': '300M',
+ 'HEAP_NEWSIZE': '50M',
+ 'JVM_OPTS': '-Xlog:gc=error:file=%s/gc.log' % cassandra_log,
+ },
+ stdout=stdout,
+ stderr=stderr,
+ )
+
+ running = wait_for_peer('127.0.0.1', native_transport_port)
+
+ if running:
+ yield (['127.0.0.1'], native_transport_port)
+
+ if not running or os.environ.get('LOG_CASSANDRA'):
+ with open(str(cassandra_log.join('debug.log'))) as fd:
+ print(fd.read())
+
+ if not running:
+ raise Exception('cassandra process stopped unexpectedly.')
+
+ pgrp = os.getpgid(proc.pid)
+ os.killpg(pgrp, signal.SIGKILL)
+
+
+class RequestHandler:
+ def on_request(self, rf):
+ if hasattr(rf.message, 'query'):
+ print()
+ print(rf.message.query)
+
+
+# tests are executed using imported classes (TestStorage and
+# TestStorageGeneratedData) using overloaded swh_storage fixture
+# below
+
+@pytest.fixture
+def swh_storage(cassandra_cluster):
+ (hosts, port) = cassandra_cluster
+ keyspace = os.urandom(10).hex()
+
+ create_keyspace(hosts, keyspace, port)
+
+ storage = get_storage(
+ 'cassandra',
+ hosts=hosts, port=port,
+ keyspace=keyspace,
+ journal_writer={
+ 'cls': 'memory',
+ },
+ objstorage={
+ 'cls': 'memory',
+ 'args': {},
+ },
+ )
+
+ yield storage
+
+ storage._cql_runner._session.execute(
+ 'DROP KEYSPACE "%s"' % keyspace)
+
+
+@pytest.mark.cassandra
+class TestCassandraStorage(_TestStorage):
+ @pytest.mark.skip('postgresql-specific test')
+ def test_content_add_db(self):
+ pass
+
+ @pytest.mark.skip('content_update is not yet implemented for Cassandra')
+ def test_content_update(self):
+ pass
+
+ @pytest.mark.skip('postgresql-specific test')
+ def test_skipped_content_add_db(self):
+ pass
+
+ @pytest.mark.skip('postgresql-specific test')
+ def test_content_add_metadata_db(self):
+ pass
+
+ @pytest.mark.skip(
+ 'not implemented, see https://forge.softwareheritage.org/T1633')
+ def test_skipped_content_add(self):
+ pass
+
+ @pytest.mark.skip(
+ 'The "person" table of the pgsql is a legacy thing, and not '
+ 'supported by the cassandra backend.')
+ def test_person_fullname_unicity(self):
+ pass
+
+ @pytest.mark.skip(
+ 'The "person" table of the pgsql is a legacy thing, and not '
+ 'supported by the cassandra backend.')
+ def test_person_get(self):
+ pass
+
+ @pytest.mark.skip('Not yet implemented')
+ def test_metadata_provider_add(self):
+ pass
+
+ @pytest.mark.skip('Not yet implemented')
+ def test_metadata_provider_get(self):
+ pass
+
+ @pytest.mark.skip('Not yet implemented')
+ def test_metadata_provider_get_by(self):
+ pass
+
+ @pytest.mark.skip('Not yet implemented')
+ def test_origin_metadata_add(self):
+ pass
+
+ @pytest.mark.skip('Not yet implemented')
+ def test_origin_metadata_get(self):
+ pass
+
+ @pytest.mark.skip('Not yet implemented')
+ def test_origin_metadata_get_by_provider_type(self):
+ pass
+
+ @pytest.mark.skip('Not supported by Cassandra')
+ def test_origin_count(self):
+ pass
+
+
+@pytest.mark.cassandra
+class TestCassandraStorageGeneratedData(_TestStorageGeneratedData):
+ @pytest.mark.skip('Not supported by Cassandra')
+ def test_origin_count(self):
+ pass
+
+ @pytest.mark.skip('Not supported by Cassandra')
+ def test_origin_get_range(self):
+ pass
+
+ @pytest.mark.skip('Not supported by Cassandra')
+ def test_origin_get_range_from_zero(self):
+ pass
+
+ @pytest.mark.skip('Not supported by Cassandra')
+ def test_generate_content_get_range_limit(self):
+ pass
+
+ @pytest.mark.skip('Not supported by Cassandra')
+ def test_generate_content_get_range_no_limit(self):
+ pass
+
+ @pytest.mark.skip('Not supported by Cassandra')
+ def test_generate_content_get_range(self):
+ pass
+
+ @pytest.mark.skip('Not supported by Cassandra')
+ def test_generate_content_get_range_empty(self):
+ pass
+
+ @pytest.mark.skip('Not supported by Cassandra')
+ def test_generate_content_get_range_limit_none(self):
+ pass
+
+ @pytest.mark.skip('Not supported by Cassandra')
+ def test_generate_content_get_range_full(self):
+ pass
+
+ @pytest.mark.skip('Not supported by Cassandra')
+ def test_origin_count_with_visit_no_visits(self):
+ pass
+
+ @pytest.mark.skip('Not supported by Cassandra')
+ def test_origin_count_with_visit_with_visits_and_snapshot(self):
+ pass
+
+ @pytest.mark.skip('Not supported by Cassandra')
+ def test_origin_count_with_visit_with_visits_no_snapshot(self):
+ pass
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -1385,9 +1385,14 @@
swh_storage.origin_add_one(data.origin2)
origin_url = data.origin2['url']
+ date_visit = datetime.datetime.now(datetime.timezone.utc)
+
+ # Round to milliseconds before insertion, so equality doesn't fail
+ # after a round-trip through a DB (eg. Cassandra)
+ date_visit = date_visit.replace(
+ microsecond=round(date_visit.microsecond, -3))
# when
- date_visit = datetime.datetime.now(datetime.timezone.utc)
origin_visit1 = swh_storage.origin_visit_add(
origin_url,
type=data.type_visit1,
@@ -1429,6 +1434,14 @@
# when
date_visit = datetime.datetime.now(datetime.timezone.utc)
date_visit2 = date_visit + datetime.timedelta(minutes=1)
+
+ # Round to milliseconds before insertion, so equality doesn't fail
+ # after a round-trip through a DB (eg. Cassandra)
+ date_visit = date_visit.replace(
+ microsecond=round(date_visit.microsecond, -3))
+ date_visit2 = date_visit2.replace(
+ microsecond=round(date_visit2.microsecond, -3))
+
origin_visit1 = swh_storage.origin_visit_add(
origin_url,
date=date_visit,
@@ -1491,13 +1504,21 @@
origin_url = data.origin['url']
date_visit = datetime.datetime.now(datetime.timezone.utc)
+ date_visit2 = date_visit + datetime.timedelta(minutes=1)
+
+ # Round to milliseconds before insertion, so equality doesn't fail
+ # after a round-trip through a DB (eg. Cassandra)
+ date_visit = date_visit.replace(
+ microsecond=round(date_visit.microsecond, -3))
+ date_visit2 = date_visit2.replace(
+ microsecond=round(date_visit2.microsecond, -3))
+
origin_visit1 = swh_storage.origin_visit_add(
origin_url,
date=date_visit,
type=data.type_visit1,
)
- date_visit2 = date_visit + datetime.timedelta(minutes=1)
origin_visit2 = swh_storage.origin_visit_add(
origin_url,
date=date_visit2,
diff --git a/tox.ini b/tox.ini
--- a/tox.ini
+++ b/tox.ini
@@ -7,8 +7,12 @@
deps =
pytest-cov
dev: ipdb
+passenv =
+ LOG_CASSANDRA
commands =
pytest \
+ !cassandra: -m "not cassandra" \
+ cassandra: -m "cassandra" \
!slow: --hypothesis-profile=fast \
slow: --hypothesis-profile=slow \
--cov={envsitepackagesdir}/swh/storage \