diff --git a/PKG-INFO b/PKG-INFO index 11e02d9..79d9098 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,112 +1,114 @@ Metadata-Version: 1.1 Name: cassandra-driver -Version: 3.16.0 +Version: 3.20.2 Summary: Python driver for Cassandra Home-page: http://github.com/datastax/python-driver -Author:: Tyler Hobbs -Author-email:: tyler@datastax.com +Author: Tyler Hobbs +Author-email: tyler@datastax.com License: UNKNOWN Description: DataStax Python Driver for Apache Cassandra =========================================== .. image:: https://travis-ci.org/datastax/python-driver.png?branch=master :target: https://travis-ci.org/datastax/python-driver A modern, `feature-rich `_ and highly-tunable Python client library for Apache Cassandra (2.1+) using exclusively Cassandra's binary protocol and Cassandra Query Language v3. - The driver supports Python 2.7, 3.4, 3.5, and 3.6. + The driver supports Python 2.7, 3.4, 3.5, 3.6 and 3.7. If you require compatibility with DataStax Enterprise, use the `DataStax Enterprise Python Driver `_. **Note:** DataStax products do not support big-endian systems. Feedback Requested ------------------ **Help us focus our efforts!** Provide your input on the `Platform and Runtime Survey `_ (we kept it short). Features -------- * `Synchronous `_ and `Asynchronous `_ APIs * `Simple, Prepared, and Batch statements `_ * Asynchronous IO, parallel execution, request pipelining * `Connection pooling `_ * Automatic node discovery * `Automatic reconnection `_ * Configurable `load balancing `_ and `retry policies `_ * `Concurrent execution utilities `_ * `Object mapper `_ + * `Connecting to DataStax Apollo database (cloud) `_ Installation ------------ Installation through pip is recommended:: $ pip install cassandra-driver For more complete installation instructions, see the `installation guide `_. Documentation ------------- The documentation can be found online `here `_. A couple of links for getting up to speed: * `Installation `_ * `Getting started guide `_ * `API docs `_ * `Performance tips `_ Object Mapper ------------- cqlengine (originally developed by Blake Eggleston and Jon Haddad, with contributions from the community) is now maintained as an integral part of this package. Refer to `documentation here `_. Contributing ------------ See `CONTRIBUTING.md `_. Reporting Problems ------------------ Please report any bugs and make any feature requests on the `JIRA `_ issue tracker. If you would like to contribute, please feel free to open a pull request. Getting Help ------------ Your best options for getting help with the driver are the `mailing list `_ and the ``#datastax-drivers`` channel in the `DataStax Academy Slack `_. License ------- - Copyright 2013-2017 DataStax + Copyright DataStax, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Keywords: cassandra,cql,orm Platform: UNKNOWN Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: Apache Software License Classifier: Natural Language :: English Classifier: Operating System :: OS Independent Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 2.7 Classifier: Programming Language :: Python :: 3.4 Classifier: Programming Language :: Python :: 3.5 Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 Classifier: Programming Language :: Python :: Implementation :: CPython Classifier: Programming Language :: Python :: Implementation :: PyPy Classifier: Topic :: Software Development :: Libraries :: Python Modules diff --git a/README.rst b/README.rst index f14cc77..b98463c 100644 --- a/README.rst +++ b/README.rst @@ -1,88 +1,89 @@ DataStax Python Driver for Apache Cassandra =========================================== .. image:: https://travis-ci.org/datastax/python-driver.png?branch=master :target: https://travis-ci.org/datastax/python-driver A modern, `feature-rich `_ and highly-tunable Python client library for Apache Cassandra (2.1+) using exclusively Cassandra's binary protocol and Cassandra Query Language v3. -The driver supports Python 2.7, 3.4, 3.5, and 3.6. +The driver supports Python 2.7, 3.4, 3.5, 3.6 and 3.7. If you require compatibility with DataStax Enterprise, use the `DataStax Enterprise Python Driver `_. **Note:** DataStax products do not support big-endian systems. Feedback Requested ------------------ **Help us focus our efforts!** Provide your input on the `Platform and Runtime Survey `_ (we kept it short). Features -------- * `Synchronous `_ and `Asynchronous `_ APIs * `Simple, Prepared, and Batch statements `_ * Asynchronous IO, parallel execution, request pipelining * `Connection pooling `_ * Automatic node discovery * `Automatic reconnection `_ * Configurable `load balancing `_ and `retry policies `_ * `Concurrent execution utilities `_ * `Object mapper `_ +* `Connecting to DataStax Apollo database (cloud) `_ Installation ------------ Installation through pip is recommended:: $ pip install cassandra-driver For more complete installation instructions, see the `installation guide `_. Documentation ------------- The documentation can be found online `here `_. A couple of links for getting up to speed: * `Installation `_ * `Getting started guide `_ * `API docs `_ * `Performance tips `_ Object Mapper ------------- cqlengine (originally developed by Blake Eggleston and Jon Haddad, with contributions from the community) is now maintained as an integral part of this package. Refer to `documentation here `_. Contributing ------------ See `CONTRIBUTING.md `_. Reporting Problems ------------------ Please report any bugs and make any feature requests on the `JIRA `_ issue tracker. If you would like to contribute, please feel free to open a pull request. Getting Help ------------ Your best options for getting help with the driver are the `mailing list `_ and the ``#datastax-drivers`` channel in the `DataStax Academy Slack `_. License ------- -Copyright 2013-2017 DataStax +Copyright DataStax, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. diff --git a/cassandra/__init__.py b/cassandra/__init__.py index 94a2bc9..38aef2f 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -1,698 +1,703 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging class NullHandler(logging.Handler): def emit(self, record): pass logging.getLogger('cassandra').addHandler(NullHandler()) -__version_info__ = (3, 16, 0) +__version_info__ = (3, 20, 2) __version__ = '.'.join(map(str, __version_info__)) class ConsistencyLevel(object): """ Spcifies how many replicas must respond for an operation to be considered a success. By default, ``ONE`` is used for all operations. """ ANY = 0 """ Only requires that one replica receives the write *or* the coordinator stores a hint to replay later. Valid only for writes. """ ONE = 1 """ Only one replica needs to respond to consider the operation a success """ TWO = 2 """ Two replicas must respond to consider the operation a success """ THREE = 3 """ Three replicas must respond to consider the operation a success """ QUORUM = 4 """ ``ceil(RF/2)`` replicas must respond to consider the operation a success """ ALL = 5 """ All replicas must respond to consider the operation a success """ LOCAL_QUORUM = 6 """ Requires a quorum of replicas in the local datacenter """ EACH_QUORUM = 7 """ Requires a quorum of replicas in each datacenter """ SERIAL = 8 """ For conditional inserts/updates that utilize Cassandra's lightweight transactions, this requires consensus among all replicas for the modified data. """ LOCAL_SERIAL = 9 """ Like :attr:`~ConsistencyLevel.SERIAL`, but only requires consensus among replicas in the local datacenter. """ LOCAL_ONE = 10 """ Sends a request only to replicas in the local datacenter and waits for one response. 
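As a minimal usage sketch (the ``session`` object, table name, and ``user_id`` are assumed to already exist), a level such as this one can be applied per statement::

        from cassandra import ConsistencyLevel
        from cassandra.query import SimpleStatement

        # ask for a single response from a local-datacenter replica
        query = SimpleStatement("SELECT * FROM users WHERE id=%s",
                                consistency_level=ConsistencyLevel.LOCAL_ONE)
        session.execute(query, (user_id,))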
""" + @staticmethod + def is_serial(cl): + return cl == ConsistencyLevel.SERIAL or cl == ConsistencyLevel.LOCAL_SERIAL + + ConsistencyLevel.value_to_name = { ConsistencyLevel.ANY: 'ANY', ConsistencyLevel.ONE: 'ONE', ConsistencyLevel.TWO: 'TWO', ConsistencyLevel.THREE: 'THREE', ConsistencyLevel.QUORUM: 'QUORUM', ConsistencyLevel.ALL: 'ALL', ConsistencyLevel.LOCAL_QUORUM: 'LOCAL_QUORUM', ConsistencyLevel.EACH_QUORUM: 'EACH_QUORUM', ConsistencyLevel.SERIAL: 'SERIAL', ConsistencyLevel.LOCAL_SERIAL: 'LOCAL_SERIAL', ConsistencyLevel.LOCAL_ONE: 'LOCAL_ONE' } ConsistencyLevel.name_to_value = { 'ANY': ConsistencyLevel.ANY, 'ONE': ConsistencyLevel.ONE, 'TWO': ConsistencyLevel.TWO, 'THREE': ConsistencyLevel.THREE, 'QUORUM': ConsistencyLevel.QUORUM, 'ALL': ConsistencyLevel.ALL, 'LOCAL_QUORUM': ConsistencyLevel.LOCAL_QUORUM, 'EACH_QUORUM': ConsistencyLevel.EACH_QUORUM, 'SERIAL': ConsistencyLevel.SERIAL, 'LOCAL_SERIAL': ConsistencyLevel.LOCAL_SERIAL, 'LOCAL_ONE': ConsistencyLevel.LOCAL_ONE } def consistency_value_to_name(value): return ConsistencyLevel.value_to_name[value] if value is not None else "Not Set" class ProtocolVersion(object): """ Defines native protocol versions supported by this driver. """ V1 = 1 """ v1, supported in Cassandra 1.2-->2.2 """ V2 = 2 """ v2, supported in Cassandra 2.0-->2.2; added support for lightweight transactions, batch operations, and automatic query paging. """ V3 = 3 """ v3, supported in Cassandra 2.1-->3.x+; added support for protocol-level client-side timestamps (see :attr:`.Session.use_client_timestamp`), serial consistency levels for :class:`~.BatchStatement`, and an improved connection pool. """ V4 = 4 """ v4, supported in Cassandra 2.2-->3.x+; added a number of new types, server warnings, new failure messages, and custom payloads. Details in the `project docs `_ """ V5 = 5 """ v5, in beta from 3.x+ """ SUPPORTED_VERSIONS = (V5, V4, V3, V2, V1) """ A tuple of all supported protocol versions """ BETA_VERSIONS = (V5,) """ A tuple of all beta protocol versions """ MIN_SUPPORTED = min(SUPPORTED_VERSIONS) """ Minimum protocol version supported by this driver. """ MAX_SUPPORTED = max(SUPPORTED_VERSIONS) """ Maximum protocol versioni supported by this driver. """ @classmethod def get_lower_supported(cls, previous_version): """ Return the lower supported protocol version. Beta versions are omitted. """ try: version = next(v for v in sorted(ProtocolVersion.SUPPORTED_VERSIONS, reverse=True) if v not in ProtocolVersion.BETA_VERSIONS and v < previous_version) except StopIteration: version = 0 return version @classmethod def uses_int_query_flags(cls, version): return version >= cls.V5 @classmethod def uses_prepare_flags(cls, version): return version >= cls.V5 @classmethod def uses_prepared_metadata(cls, version): return version >= cls.V5 @classmethod def uses_error_code_map(cls, version): return version >= cls.V5 @classmethod def uses_keyspace_flag(cls, version): return version >= cls.V5 class WriteType(object): """ For usage with :class:`.RetryPolicy`, this describe a type of write operation. """ SIMPLE = 0 """ A write to a single partition key. Such writes are guaranteed to be atomic and isolated. """ BATCH = 1 """ A write to multiple partition keys that used the distributed batch log to ensure atomicity. """ UNLOGGED_BATCH = 2 """ A write to multiple partition keys that did not use the distributed batch log. Atomicity for such writes is not guaranteed. """ COUNTER = 3 """ A counter write (for one or multiple partition keys). 
Such writes should not be replayed in order to avoid overcount. """ BATCH_LOG = 4 """ The initial write to the distributed batch log that Cassandra performs internally before a BATCH write. """ CAS = 5 """ A lighweight-transaction write, such as "DELETE ... IF EXISTS". """ VIEW = 6 """ This WriteType is only seen in results for requests that were unable to complete MV operations. """ CDC = 7 """ This WriteType is only seen in results for requests that were unable to complete CDC operations. """ WriteType.name_to_value = { 'SIMPLE': WriteType.SIMPLE, 'BATCH': WriteType.BATCH, 'UNLOGGED_BATCH': WriteType.UNLOGGED_BATCH, 'COUNTER': WriteType.COUNTER, 'BATCH_LOG': WriteType.BATCH_LOG, 'CAS': WriteType.CAS, 'VIEW': WriteType.VIEW, 'CDC': WriteType.CDC } WriteType.value_to_name = {v: k for k, v in WriteType.name_to_value.items()} class SchemaChangeType(object): DROPPED = 'DROPPED' CREATED = 'CREATED' UPDATED = 'UPDATED' class SchemaTargetType(object): KEYSPACE = 'KEYSPACE' TABLE = 'TABLE' TYPE = 'TYPE' FUNCTION = 'FUNCTION' AGGREGATE = 'AGGREGATE' class SignatureDescriptor(object): def __init__(self, name, argument_types): self.name = name self.argument_types = argument_types @property def signature(self): """ function signature string in the form 'name([type0[,type1[...]]])' can be used to uniquely identify overloaded function names within a keyspace """ return self.format_signature(self.name, self.argument_types) @staticmethod def format_signature(name, argument_types): return "%s(%s)" % (name, ','.join(t for t in argument_types)) def __repr__(self): return "%s(%s, %s)" % (self.__class__.__name__, self.name, self.argument_types) class UserFunctionDescriptor(SignatureDescriptor): """ Describes a User function by name and argument signature """ name = None """ name of the function """ argument_types = None """ Ordered list of CQL argument type names comprising the type signature """ class UserAggregateDescriptor(SignatureDescriptor): """ Describes a User aggregate function by name and argument signature """ name = None """ name of the aggregate """ argument_types = None """ Ordered list of CQL argument type names comprising the type signature """ class DriverException(Exception): """ Base for all exceptions explicitly raised by the driver. """ pass class RequestExecutionException(DriverException): """ Base for request execution exceptions returned from the server. """ pass class Unavailable(RequestExecutionException): """ There were not enough live replicas to satisfy the requested consistency level, so the coordinator node immediately failed the request without forwarding it to any replicas. """ consistency = None """ The requested :class:`ConsistencyLevel` """ required_replicas = None """ The number of replicas that needed to be live to complete the operation """ alive_replicas = None """ The number of replicas that were actually alive """ def __init__(self, summary_message, consistency=None, required_replicas=None, alive_replicas=None): self.consistency = consistency self.required_replicas = required_replicas self.alive_replicas = alive_replicas Exception.__init__(self, summary_message + ' info=' + repr({'consistency': consistency_value_to_name(consistency), 'required_replicas': required_replicas, 'alive_replicas': alive_replicas})) class Timeout(RequestExecutionException): """ Replicas failed to respond to the coordinator node before timing out. 
""" consistency = None """ The requested :class:`ConsistencyLevel` """ required_responses = None """ The number of required replica responses """ received_responses = None """ The number of replicas that responded before the coordinator timed out the operation """ def __init__(self, summary_message, consistency=None, required_responses=None, received_responses=None, **kwargs): self.consistency = consistency self.required_responses = required_responses self.received_responses = received_responses if "write_type" in kwargs: kwargs["write_type"] = WriteType.value_to_name[kwargs["write_type"]] info = {'consistency': consistency_value_to_name(consistency), 'required_responses': required_responses, 'received_responses': received_responses} info.update(kwargs) Exception.__init__(self, summary_message + ' info=' + repr(info)) class ReadTimeout(Timeout): """ A subclass of :exc:`Timeout` for read operations. This indicates that the replicas failed to respond to the coordinator node before the configured timeout. This timeout is configured in ``cassandra.yaml`` with the ``read_request_timeout_in_ms`` and ``range_request_timeout_in_ms`` options. """ data_retrieved = None """ A boolean indicating whether the requested data was retrieved by the coordinator from any replicas before it timed out the operation """ def __init__(self, message, data_retrieved=None, **kwargs): Timeout.__init__(self, message, **kwargs) self.data_retrieved = data_retrieved class WriteTimeout(Timeout): """ A subclass of :exc:`Timeout` for write operations. This indicates that the replicas failed to respond to the coordinator node before the configured timeout. This timeout is configured in ``cassandra.yaml`` with the ``write_request_timeout_in_ms`` option. """ write_type = None """ The type of write operation, enum on :class:`~cassandra.policies.WriteType` """ def __init__(self, message, write_type=None, **kwargs): kwargs["write_type"] = write_type Timeout.__init__(self, message, **kwargs) self.write_type = write_type class CDCWriteFailure(RequestExecutionException): """ Hit limit on data in CDC folder, writes are rejected """ def __init__(self, message): Exception.__init__(self, message) class CoordinationFailure(RequestExecutionException): """ Replicas sent a failure to the coordinator. """ consistency = None """ The requested :class:`ConsistencyLevel` """ required_responses = None """ The number of required replica responses """ received_responses = None """ The number of replicas that responded before the coordinator timed out the operation """ failures = None """ The number of replicas that sent a failure message """ error_code_map = None """ A map of inet addresses to error codes representing replicas that sent a failure message. Only set when `protocol_version` is 5 or higher. 
""" def __init__(self, summary_message, consistency=None, required_responses=None, received_responses=None, failures=None, error_code_map=None): self.consistency = consistency self.required_responses = required_responses self.received_responses = received_responses self.failures = failures self.error_code_map = error_code_map info_dict = { 'consistency': consistency_value_to_name(consistency), 'required_responses': required_responses, 'received_responses': received_responses, 'failures': failures } if error_code_map is not None: # make error codes look like "0x002a" formatted_map = dict((addr, '0x%04x' % err_code) for (addr, err_code) in error_code_map.items()) info_dict['error_code_map'] = formatted_map Exception.__init__(self, summary_message + ' info=' + repr(info_dict)) class ReadFailure(CoordinationFailure): """ A subclass of :exc:`CoordinationFailure` for read operations. This indicates that the replicas sent a failure message to the coordinator. """ data_retrieved = None """ A boolean indicating whether the requested data was retrieved by the coordinator from any replicas before it timed out the operation """ def __init__(self, message, data_retrieved=None, **kwargs): CoordinationFailure.__init__(self, message, **kwargs) self.data_retrieved = data_retrieved class WriteFailure(CoordinationFailure): """ A subclass of :exc:`CoordinationFailure` for write operations. This indicates that the replicas sent a failure message to the coordinator. """ write_type = None """ The type of write operation, enum on :class:`~cassandra.policies.WriteType` """ def __init__(self, message, write_type=None, **kwargs): CoordinationFailure.__init__(self, message, **kwargs) self.write_type = write_type class FunctionFailure(RequestExecutionException): """ User Defined Function failed during execution """ keyspace = None """ Keyspace of the function """ function = None """ Name of the function """ arg_types = None """ List of argument type names of the function """ def __init__(self, summary_message, keyspace, function, arg_types): self.keyspace = keyspace self.function = function self.arg_types = arg_types Exception.__init__(self, summary_message) class RequestValidationException(DriverException): """ Server request validation failed """ pass class ConfigurationException(RequestValidationException): """ Server indicated request errro due to current configuration """ pass class AlreadyExists(ConfigurationException): """ An attempt was made to create a keyspace or table that already exists. """ keyspace = None """ The name of the keyspace that already exists, or, if an attempt was made to create a new table, the keyspace that the table is in. """ table = None """ The name of the table that already exists, or, if an attempt was make to create a keyspace, :const:`None`. """ def __init__(self, keyspace=None, table=None): if table: message = "Table '%s.%s' already exists" % (keyspace, table) else: message = "Keyspace '%s' already exists" % (keyspace,) Exception.__init__(self, message) self.keyspace = keyspace self.table = table class InvalidRequest(RequestValidationException): """ A query was made that was invalid for some reason, such as trying to set the keyspace for a connection to a nonexistent keyspace. """ pass class Unauthorized(RequestValidationException): """ The current user is not authorized to perform the requested operation. """ pass class AuthenticationFailed(DriverException): """ Failed to authenticate. 
""" pass class OperationTimedOut(DriverException): """ The operation took longer than the specified (client-side) timeout to complete. This is not an error generated by Cassandra, only the driver. """ errors = None """ A dict of errors keyed by the :class:`~.Host` against which they occurred. """ last_host = None """ The last :class:`~.Host` this operation was attempted against. """ def __init__(self, errors=None, last_host=None): self.errors = errors self.last_host = last_host message = "errors=%s, last_host=%s" % (self.errors, self.last_host) Exception.__init__(self, message) class UnsupportedOperation(DriverException): """ An attempt was made to use a feature that is not supported by the selected protocol version. See :attr:`Cluster.protocol_version` for more details. """ pass class UnresolvableContactPoints(DriverException): """ The driver was unable to resolve any provided hostnames. Note that this is *not* raised when a :class:`.Cluster` is created with no contact points, only when lookup fails for all hosts """ pass diff --git a/cassandra/cluster.py b/cassandra/cluster.py index e119605..8fcbe33 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1,4409 +1,4651 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ This module houses the main classes you will interact with, :class:`.Cluster` and :class:`.Session`. 
""" from __future__ import absolute_import import atexit -from collections import defaultdict, Mapping +from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures from copy import copy from functools import partial, wraps from itertools import groupby, count import logging from warnings import warn from random import random import six from six.moves import filter, range, queue as Queue import socket import sys import time from threading import Lock, RLock, Thread, Event import weakref from weakref import WeakValueDictionary try: from weakref import WeakSet except ImportError: from cassandra.util import WeakSet # NOQA from cassandra import (ConsistencyLevel, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, SchemaTargetType, DriverException, ProtocolVersion, UnresolvableContactPoints) +from cassandra.auth import PlainTextAuthProvider from cassandra.connection import (ConnectionException, ConnectionShutdown, - ConnectionHeartbeat, ProtocolVersionUnsupported) + ConnectionHeartbeat, ProtocolVersionUnsupported, + EndPoint, DefaultEndPoint, DefaultEndPointFactory, + SniEndPointFactory) from cassandra.cqltypes import UserType from cassandra.encoder import Encoder from cassandra.protocol import (QueryMessage, ResultMessage, ErrorMessage, ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableErrorMessage, OverloadedErrorMessage, PrepareMessage, ExecuteMessage, PreparedQueryNotFound, IsBootstrappingErrorMessage, + TruncateError, ServerError, BatchMessage, RESULT_KIND_PREPARED, RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS, RESULT_KIND_SCHEMA_CHANGE, ProtocolHandler) from cassandra.metadata import Metadata, protect_name, murmur3 from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy, ExponentialReconnectionPolicy, HostDistance, RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan, NoSpeculativeExecutionPolicy) from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler, HostConnectionPool, HostConnection, NoConnectionsAvailable) from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement, BatchStatement, bind_params, QueryTrace, TraceUnavailable, named_tuple_factory, dict_factory, tuple_factory, FETCH_SIZE_UNSET) from cassandra.timestamps import MonotonicTimestampGenerator +from cassandra.compat import Mapping +from cassandra.datastax import cloud as dscloud def _is_eventlet_monkey_patched(): if 'eventlet.patcher' not in sys.modules: return False import eventlet.patcher return eventlet.patcher.is_monkey_patched('socket') def _is_gevent_monkey_patched(): if 'gevent.monkey' not in sys.modules: return False import gevent.socket return socket.socket is gevent.socket.socket # default to gevent when we are monkey patched with gevent, eventlet when # monkey patched with eventlet, otherwise if libev is available, use that as # the default because it's fastest. Otherwise, use asyncore. 
if _is_gevent_monkey_patched(): from cassandra.io.geventreactor import GeventConnection as DefaultConnection elif _is_eventlet_monkey_patched(): from cassandra.io.eventletreactor import EventletConnection as DefaultConnection else: try: from cassandra.io.libevreactor import LibevConnection as DefaultConnection # NOQA except ImportError: from cassandra.io.asyncorereactor import AsyncoreConnection as DefaultConnection # NOQA # Forces load of utf8 encoding module to avoid deadlock that occurs # if code that is being imported tries to import the module in a seperate # thread. # See http://bugs.python.org/issue10923 "".encode('utf8') log = logging.getLogger(__name__) DEFAULT_MIN_REQUESTS = 5 DEFAULT_MAX_REQUESTS = 100 DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST = 2 DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST = 8 DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST = 1 DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST = 2 _NOT_SET = object() class NoHostAvailable(Exception): """ Raised when an operation is attempted but all connections are busy, defunct, closed, or resulted in errors when used. """ errors = None """ A map of the form ``{ip: exception}`` which details the particular Exception that was caught for each host the operation was attempted against. """ def __init__(self, message, errors): Exception.__init__(self, message, errors) self.errors = errors def _future_completed(future): """ Helper for run_in_executor() """ exc = future.exception() if exc: log.debug("Failed to run task on executor", exc_info=exc) def run_in_executor(f): """ A decorator to run the given method in the ThreadPoolExecutor. """ @wraps(f) def new_f(self, *args, **kwargs): if self.is_shutdown: return try: future = self.executor.submit(f, self, *args, **kwargs) future.add_done_callback(_future_completed) except Exception: log.exception("Failed to submit task to executor") return new_f _clusters_for_shutdown = set() def _register_cluster_shutdown(cluster): _clusters_for_shutdown.add(cluster) def _discard_cluster_shutdown(cluster): _clusters_for_shutdown.discard(cluster) def _shutdown_clusters(): clusters = _clusters_for_shutdown.copy() # copy because shutdown modifies the global set "discard" for cluster in clusters: cluster.shutdown() atexit.register(_shutdown_clusters) def default_lbp_factory(): if murmur3 is not None: return TokenAwarePolicy(DCAwareRoundRobinPolicy()) return DCAwareRoundRobinPolicy() def _addrinfo_or_none(contact_point, port): """ A helper function that wraps socket.getaddrinfo and returns None when it fails to, e.g. resolve one of the hostnames. Used to address PYTHON-895. """ try: return socket.getaddrinfo(contact_point, port, socket.AF_UNSPEC, socket.SOCK_STREAM) except socket.gaierror: log.debug('Could not resolve hostname "{}" ' 'with port {}'.format(contact_point, port)) return None def _resolve_contact_points(contact_points, port): resolved = tuple(_addrinfo_or_none(p, port) for p in contact_points) if resolved and all((x is None for x in resolved)): raise UnresolvableContactPoints(contact_points, port) resolved = tuple(r for r in resolved if r is not None) return [endpoint[4][0] for addrinfo in resolved for endpoint in addrinfo] +def _execution_profile_to_string(name): + if name is EXEC_PROFILE_DEFAULT: + return 'EXEC_PROFILE_DEFAULT' + return '"%s"' % (name,) + + class ExecutionProfile(object): load_balancing_policy = None """ An instance of :class:`.policies.LoadBalancingPolicy` or one of its subclasses. Used in determining host distance for establishing connections, and routing requests. 
Defaults to ``TokenAwarePolicy(DCAwareRoundRobinPolicy())`` if not specified """ retry_policy = None """ An instance of :class:`.policies.RetryPolicy` instance used when :class:`.Statement` objects do not have a :attr:`~.Statement.retry_policy` explicitly set. Defaults to :class:`.RetryPolicy` if not specified """ consistency_level = ConsistencyLevel.LOCAL_ONE """ :class:`.ConsistencyLevel` used when not specified on a :class:`.Statement`. """ serial_consistency_level = None """ Serial :class:`.ConsistencyLevel` used when not specified on a :class:`.Statement` (for LWT conditional statements). """ request_timeout = 10.0 """ Request timeout used when not overridden in :meth:`.Session.execute` """ - row_factory = staticmethod(tuple_factory) + row_factory = staticmethod(named_tuple_factory) """ A callable to format results, accepting ``(colnames, rows)`` where ``colnames`` is a list of column names, and ``rows`` is a list of tuples, with each tuple representing a row of parsed values. Some example implementations: - :func:`cassandra.query.tuple_factory` - return a result row as a tuple - :func:`cassandra.query.named_tuple_factory` - return a result row as a named tuple - :func:`cassandra.query.dict_factory` - return a result row as a dict - :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict """ speculative_execution_policy = None """ An instance of :class:`.policies.SpeculativeExecutionPolicy` Defaults to :class:`.NoSpeculativeExecutionPolicy` if not specified """ - # indicates if lbp was set explicitly or uses default values + # indicates if set explicitly or uses default values _load_balancing_policy_explicit = False + _consistency_level_explicit = False def __init__(self, load_balancing_policy=_NOT_SET, retry_policy=None, - consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None, + consistency_level=_NOT_SET, serial_consistency_level=None, request_timeout=10.0, row_factory=named_tuple_factory, speculative_execution_policy=None): if load_balancing_policy is _NOT_SET: self._load_balancing_policy_explicit = False self.load_balancing_policy = default_lbp_factory() else: self._load_balancing_policy_explicit = True self.load_balancing_policy = load_balancing_policy + + if consistency_level is _NOT_SET: + self._consistency_level_explicit = False + self.consistency_level = ConsistencyLevel.LOCAL_ONE + else: + self._consistency_level_explicit = True + self.consistency_level = consistency_level + self.retry_policy = retry_policy or RetryPolicy() - self.consistency_level = consistency_level + + if (serial_consistency_level is not None and + not ConsistencyLevel.is_serial(serial_consistency_level)): + raise ValueError("serial_consistency_level must be either " + "ConsistencyLevel.SERIAL " + "or ConsistencyLevel.LOCAL_SERIAL.") self.serial_consistency_level = serial_consistency_level + self.request_timeout = request_timeout self.row_factory = row_factory self.speculative_execution_policy = speculative_execution_policy or NoSpeculativeExecutionPolicy() class ProfileManager(object): def __init__(self): self.profiles = dict() def _profiles_without_explicit_lbps(self): names = (profile_name for profile_name, profile in self.profiles.items() if not profile._load_balancing_policy_explicit) return tuple( 'EXEC_PROFILE_DEFAULT' if n is EXEC_PROFILE_DEFAULT else n for n in names ) def distance(self, host): distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values()) return HostDistance.LOCAL if HostDistance.LOCAL in distances else 
\ HostDistance.REMOTE if HostDistance.REMOTE in distances else \ HostDistance.IGNORED def populate(self, cluster, hosts): for p in self.profiles.values(): p.load_balancing_policy.populate(cluster, hosts) def check_supported(self): for p in self.profiles.values(): p.load_balancing_policy.check_supported() def on_up(self, host): for p in self.profiles.values(): p.load_balancing_policy.on_up(host) def on_down(self, host): for p in self.profiles.values(): p.load_balancing_policy.on_down(host) def on_add(self, host): for p in self.profiles.values(): p.load_balancing_policy.on_add(host) def on_remove(self, host): for p in self.profiles.values(): p.load_balancing_policy.on_remove(host) @property def default(self): """ internal-only; no checks are done because this entry is populated on cluster init """ return self.profiles[EXEC_PROFILE_DEFAULT] EXEC_PROFILE_DEFAULT = object() """ Key for the ``Cluster`` default execution profile, used when no other profile is selected in ``Session.execute(execution_profile)``. Use this as the key in ``Cluster(execution_profiles)`` to override the default profile. """ class _ConfigMode(object): UNCOMMITTED = 0 LEGACY = 1 PROFILES = 2 class Cluster(object): """ The main class to use when interacting with a Cassandra cluster. Typically, one instance of this class will be created for each separate Cassandra cluster that your application interacts with. Example usage:: >>> from cassandra.cluster import Cluster >>> cluster = Cluster(['192.168.1.1', '192.168.1.2']) >>> session = cluster.connect() >>> session.execute("CREATE KEYSPACE ...") >>> ... >>> cluster.shutdown() ``Cluster`` and ``Session`` also provide context management functions which implicitly handle shutdown when leaving scope. """ contact_points = ['127.0.0.1'] """ - The list of contact points to try connecting for cluster discovery. + The list of contact points to try connecting for cluster discovery. A + contact point can be a string (ip, hostname) or a + :class:`.connection.EndPoint` instance. Defaults to loopback interface. Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit local_dc set (as is the default), the DC is chosen from an arbitrary host in contact_points. In this case, contact_points should contain only nodes from a single, local DC. Note: In the next major version, if you specify contact points, you will also be required to also explicitly specify a load-balancing policy. This change will help prevent cases where users had hard-to-debug issues surrounding unintuitive default load-balancing policy behavior. """ # tracks if contact_points was set explicitly or with default values _contact_points_explicit = None port = 9042 """ The server-side port to open connections to. Defaults to 9042. """ cql_version = None """ If a specific version of CQL should be used, this may be set to that string version. Otherwise, the highest CQL version supported by the server will be automatically used. """ protocol_version = ProtocolVersion.V4 """ The maximum version of the native protocol to use. See :class:`.ProtocolVersion` for more information about versions. If not set in the constructor, the driver will automatically downgrade version based on a negotiation with the server, but it is most efficient to set this to the maximum supported by your version of Cassandra. Setting this will also prevent conflicting versions negotiated if your cluster is upgraded. 
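A short sketch of pinning the version explicitly (the contact point below is a placeholder)::

        from cassandra.cluster import Cluster
        from cassandra import ProtocolVersion

        cluster = Cluster(['10.0.0.1'], protocol_version=ProtocolVersion.V4)
        session = cluster.connect()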
""" allow_beta_protocol_version = False no_compact = False """ Setting true injects a flag in all messages that makes the server accept and use "beta" protocol version. Used for testing new protocol features incrementally before the new version is complete. """ compression = True """ Controls compression for communications between the driver and Cassandra. If left as the default of :const:`True`, either lz4 or snappy compression may be used, depending on what is supported by both the driver and Cassandra. If both are fully supported, lz4 will be preferred. You may also set this to 'snappy' or 'lz4' to request that specific compression type. Setting this to :const:`False` disables compression. """ _auth_provider = None _auth_provider_callable = None @property def auth_provider(self): """ When :attr:`~.Cluster.protocol_version` is 2 or higher, this should be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`, such as :class:`~.PlainTextAuthProvider`. When :attr:`~.Cluster.protocol_version` is 1, this should be a function that accepts one argument, the IP address of a node, and returns a dict of credentials for that node. When not using authentication, this should be left as :const:`None`. """ return self._auth_provider @auth_provider.setter # noqa def auth_provider(self, value): if not value: self._auth_provider = value return try: self._auth_provider_callable = value.new_authenticator except AttributeError: if self.protocol_version > 1: raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider " "interface when protocol_version >= 2") elif not callable(value): raise TypeError("auth_provider must be callable when protocol_version == 1") self._auth_provider_callable = value self._auth_provider = value _load_balancing_policy = None @property def load_balancing_policy(self): """ An instance of :class:`.policies.LoadBalancingPolicy` or one of its subclasses. .. versionchanged:: 2.6.0 Defaults to :class:`~.TokenAwarePolicy` (:class:`~.DCAwareRoundRobinPolicy`). when using CPython (where the murmur3 extension is available). :class:`~.DCAwareRoundRobinPolicy` otherwise. Default local DC will be chosen from contact points. **Please see** :class:`~.DCAwareRoundRobinPolicy` **for a discussion on default behavior with respect to DC locality and remote nodes.** """ return self._load_balancing_policy @load_balancing_policy.setter def load_balancing_policy(self, lbp): if self._config_mode == _ConfigMode.PROFILES: raise ValueError("Cannot set Cluster.load_balancing_policy while using Configuration Profiles. Set this in a profile instead.") self._load_balancing_policy = lbp self._config_mode = _ConfigMode.LEGACY @property def _default_load_balancing_policy(self): return self.profile_manager.default.load_balancing_policy reconnection_policy = ExponentialReconnectionPolicy(1.0, 600.0) """ An instance of :class:`.policies.ReconnectionPolicy`. Defaults to an instance of :class:`.ExponentialReconnectionPolicy` with a base delay of one second and a max delay of ten minutes. """ _default_retry_policy = RetryPolicy() @property def default_retry_policy(self): """ A default :class:`.policies.RetryPolicy` instance to use for all :class:`.Statement` objects which do not have a :attr:`~.Statement.retry_policy` explicitly set. """ return self._default_retry_policy @default_retry_policy.setter def default_retry_policy(self, policy): if self._config_mode == _ConfigMode.PROFILES: raise ValueError("Cannot set Cluster.default_retry_policy while using Configuration Profiles. 
Set this in a profile instead.") self._default_retry_policy = policy self._config_mode = _ConfigMode.LEGACY conviction_policy_factory = SimpleConvictionPolicy """ A factory function which creates instances of :class:`.policies.ConvictionPolicy`. Defaults to :class:`.policies.SimpleConvictionPolicy`. """ address_translator = IdentityTranslator() """ :class:`.policies.AddressTranslator` instance to be used in translating server node addresses to driver connection addresses. """ connect_to_remote_hosts = True """ If left as :const:`True`, hosts that are considered :attr:`~.HostDistance.REMOTE` by the :attr:`~.Cluster.load_balancing_policy` will have a connection opened to them. Otherwise, they will not have a connection opened to them. Note that the default load balancing policy ignores remote hosts by default. .. versionadded:: 2.1.0 """ metrics_enabled = False """ Whether or not metric collection is enabled. If enabled, :attr:`.metrics` will be an instance of :class:`~cassandra.metrics.Metrics`. """ metrics = None """ An instance of :class:`cassandra.metrics.Metrics` if :attr:`.metrics_enabled` is :const:`True`, else :const:`None`. """ ssl_options = None """ - A optional dict which will be used as kwargs for ``ssl.wrap_socket()`` - when new sockets are created. This should be used when client encryption - is enabled in Cassandra. + Using ssl_options without ssl_context is deprecated and will be removed in the + next major release. + + An optional dict which will be used as kwargs for ``ssl.SSLContext.wrap_socket`` (or + ``ssl.wrap_socket()`` if used without ssl_context) when new sockets are created. + This should be used when client encryption is enabled in Cassandra. + + The following documentation only applies when ssl_options is used without ssl_context. By default, a ``ca_certs`` value should be supplied (the value should be a string pointing to the location of the CA certs file), and you probably want to specify ``ssl_version`` as ``ssl.PROTOCOL_TLSv1`` to match Cassandra's default protocol. .. versionchanged:: 3.3.0 In addition to ``wrap_socket`` kwargs, clients may also specify ``'check_hostname': True`` to verify the cert hostname as outlined in RFC 2818 and RFC 6125. Note that this requires the certificate to be transferred, so should almost always require the option ``'cert_reqs': ssl.CERT_REQUIRED``. Note also that this functionality was not built into Python standard library until (2.7.9, 3.2). To enable this mechanism in earlier versions, patch ``ssl.match_hostname`` with a custom or `back-ported function `_. """ + ssl_context = None + """ + An optional ``ssl.SSLContext`` instance which will be used when new sockets are created. + This should be used when client encryption is enabled in Cassandra. + + ``wrap_socket`` options can be set using :attr:`~Cluster.ssl_options`. ssl_options will + be used as kwargs for ``ssl.SSLContext.wrap_socket``. + + .. versionadded:: 3.17.0 + """ + sockopts = None """ An optional list of tuples which will be used as arguments to ``socket.setsockopt()`` for all created sockets. Note: some drivers find setting TCPNODELAY beneficial in the context of their execution model. It was not found generally beneficial for this driver. To try with your own workload, set ``sockopts = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]`` """ max_schema_agreement_wait = 10 """ The maximum duration (in seconds) that the driver will wait for schema agreement across the cluster. Defaults to ten seconds. 
If set <= 0, the driver will bypass schema agreement waits altogether. """ metadata = None """ An instance of :class:`cassandra.metadata.Metadata`. """ connection_class = DefaultConnection """ This determines what event loop system will be used for managing I/O with Cassandra. These are the current options: * :class:`cassandra.io.asyncorereactor.AsyncoreConnection` * :class:`cassandra.io.libevreactor.LibevConnection` * :class:`cassandra.io.eventletreactor.EventletConnection` (requires monkey-patching - see doc for details) * :class:`cassandra.io.geventreactor.GeventConnection` (requires monkey-patching - see doc for details) * :class:`cassandra.io.twistedreactor.TwistedConnection` * EXPERIMENTAL: :class:`cassandra.io.asyncioreactor.AsyncioConnection` By default, ``AsyncoreConnection`` will be used, which uses the ``asyncore`` module in the Python standard library. If ``libev`` is installed, ``LibevConnection`` will be used instead. If ``gevent`` or ``eventlet`` monkey-patching is detected, the corresponding connection class will be used automatically. ``AsyncioConnection``, which uses the ``asyncio`` module in the Python standard library, is also available, but currently experimental. Note that it requires ``asyncio`` features that were only introduced in the 3.4 line in 3.4.6, and in the 3.5 line in 3.5.1. """ control_connection_timeout = 2.0 """ A timeout, in seconds, for queries made by the control connection, such as querying the current schema and information about nodes in the cluster. If set to :const:`None`, there will be no timeout for these queries. """ idle_heartbeat_interval = 30 """ Interval, in seconds, on which to heartbeat idle connections. This helps keep connections open through network devices that expire idle connections. It also helps discover bad connections early in low-traffic scenarios. Setting to zero disables heartbeats. """ idle_heartbeat_timeout = 30 """ Timeout, in seconds, on which the heartbeat wait for idle connection responses. Lowering this value can help to discover bad connections earlier. """ schema_event_refresh_window = 2 """ Window, in seconds, within which a schema component will be refreshed after receiving a schema_change event. The driver delays a random amount of time in the range [0.0, window) before executing the refresh. This serves two purposes: 1.) Spread the refresh for deployments with large fanout from C* to client tier, preventing a 'thundering herd' problem with many clients refreshing simultaneously. 2.) Remove redundant refreshes. Redundant events arriving within the delay period are discarded, and only one refresh is executed. Setting this to zero will execute refreshes immediately. Setting this negative will disable schema refreshes in response to push events (refreshes will still occur in response to schema change responses to DDL statements executed by Sessions of this Cluster). """ topology_event_refresh_window = 10 """ Window, in seconds, within which the node and token list will be refreshed after receiving a topology_change event. Setting this to zero will execute refreshes immediately. Setting this negative will disable node refreshes in response to push events. See :attr:`.schema_event_refresh_window` for discussion of rationale """ status_event_refresh_window = 2 """ Window, in seconds, within which the driver will start the reconnect after receiving a status_change event. Setting this to zero will connect immediately. 
This is primarily used to avoid 'thundering herd' in deployments with large fanout from cluster to clients. When nodes come up, clients attempt to reprepare prepared statements (depending on :attr:`.reprepare_on_up`), and establish connection pools. This can cause a rush of connections and queries if not mitigated with this factor. """ prepare_on_all_hosts = True """ Specifies whether statements should be prepared on all hosts, or just one. This can reasonably be disabled on long-running applications with numerous clients preparing statements on startup, where a randomized initial condition of the load balancing policy can be expected to distribute prepares from different clients across the cluster. """ reprepare_on_up = True """ Specifies whether all known prepared statements should be prepared on a node when it comes up. May be used to avoid overwhelming a node on return, or if it is supposed that the node was only marked down due to network. If statements are not reprepared, they are prepared on the first execution, causing an extra roundtrip for one or more client requests. """ connect_timeout = 5 """ Timeout, in seconds, for creating new connections. This timeout covers the entire connection negotiation, including TCP establishment, options passing, and authentication. """ timestamp_generator = None """ An object, shared between all sessions created by this cluster instance, that generates timestamps when client-side timestamp generation is enabled. By default, each :class:`Cluster` uses a new :class:`~.MonotonicTimestampGenerator`. Applications can set this value for custom timestamp behavior. See the documentation for :meth:`Session.timestamp_generator`. """ + cloud = None + """ + A dict of the cloud configuration. Example:: + + { + # path to the secure connect bundle + 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip' + } + + The zip file will be temporarily extracted in the same directory to + load the configuration and certificates. + """ + @property def schema_metadata_enabled(self): """ Flag indicating whether internal schema metadata is updated. When disabled, the driver does not populate Cluster.metadata.keyspaces on connect, or on schema change events. This can be used to speed initial connection, and reduce load on client and server during operation. Turning this off gives away token aware request routing, and programmatic inspection of the metadata model. """ return self.control_connection._schema_meta_enabled @schema_metadata_enabled.setter def schema_metadata_enabled(self, enabled): self.control_connection._schema_meta_enabled = bool(enabled) @property def token_metadata_enabled(self): """ Flag indicating whether internal token metadata is updated. When disabled, the driver does not query node token information on connect, or on topology change events. This can be used to speed initial connection, and reduce load on client and server during operation. It is most useful in large clusters using vnodes, where the token map can be expensive to compute. Turning this off gives away token aware request routing, and programmatic inspection of the token ring. """ return self.control_connection._token_meta_enabled @token_metadata_enabled.setter def token_metadata_enabled(self, enabled): self.control_connection._token_meta_enabled = bool(enabled) + endpoint_factory = None + """ + An :class:`~.connection.EndPointFactory` instance to use internally when creating + a socket connection to a node. You can ignore this unless you need a special + connection mechanism. 
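For illustration only, a hedged sketch mirroring what the cloud configuration path does internally (the proxy address and ``host_ids`` below are placeholders)::

        from cassandra.cluster import Cluster
        from cassandra.connection import SniEndPointFactory

        factory = SniEndPointFactory("sni-proxy.example.com", 9042)
        cluster = Cluster(
            contact_points=[factory.create_from_sni(host_id) for host_id in host_ids],
            endpoint_factory=factory)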
+ """ + profile_manager = None _config_mode = _ConfigMode.UNCOMMITTED sessions = None control_connection = None scheduler = None executor = None is_shutdown = False _is_setup = False _prepared_statements = None _prepared_statement_lock = None _idle_heartbeat = None _protocol_version_explicit = False _discount_down_events = True _user_types = None """ A map of {keyspace: {type_name: UserType}} """ _listeners = None _listener_lock = None def __init__(self, contact_points=_NOT_SET, port=9042, compression=True, auth_provider=None, load_balancing_policy=None, reconnection_policy=None, default_retry_policy=None, conviction_policy_factory=None, metrics_enabled=False, connection_class=None, ssl_options=None, sockopts=None, cql_version=None, protocol_version=_NOT_SET, executor_threads=2, max_schema_agreement_wait=10, control_connection_timeout=2.0, idle_heartbeat_interval=30, schema_event_refresh_window=2, topology_event_refresh_window=10, connect_timeout=5, schema_metadata_enabled=True, token_metadata_enabled=True, address_translator=None, status_event_refresh_window=2, prepare_on_all_hosts=True, reprepare_on_up=True, execution_profiles=None, allow_beta_protocol_version=False, timestamp_generator=None, idle_heartbeat_timeout=30, - no_compact=False): + no_compact=False, + ssl_context=None, + endpoint_factory=None, + cloud=None): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as extablishing connection pools or refreshing metadata. Any of the mutable Cluster attributes may be set as keyword arguments to the constructor. """ + + if cloud is not None: + if contact_points is not _NOT_SET or endpoint_factory or ssl_context or ssl_options: + raise ValueError("contact_points, endpoint_factory, ssl_context, and ssl_options " + "cannot be specified with a cloud configuration") + + cloud_config = dscloud.get_cloud_config(cloud) + + ssl_context = cloud_config.ssl_context + ssl_options = {'check_hostname': True} + if (auth_provider is None and cloud_config.username + and cloud_config.password): + auth_provider = PlainTextAuthProvider(cloud_config.username, cloud_config.password) + + endpoint_factory = SniEndPointFactory(cloud_config.sni_host, cloud_config.sni_port) + contact_points = [ + endpoint_factory.create_from_sni(host_id) + for host_id in cloud_config.host_ids + ] + if contact_points is not None: if contact_points is _NOT_SET: self._contact_points_explicit = False contact_points = ['127.0.0.1'] else: self._contact_points_explicit = True if isinstance(contact_points, six.string_types): raise TypeError("contact_points should not be a string, it should be a sequence (e.g. 
list) of strings") if None in contact_points: raise ValueError("contact_points should not contain None (it can resolve to localhost)") self.contact_points = contact_points self.port = port - self.contact_points_resolved = _resolve_contact_points(self.contact_points, - self.port) + self.endpoint_factory = endpoint_factory or DefaultEndPointFactory(port=self.port) + self.endpoint_factory.configure(self) + + raw_contact_points = [cp for cp in self.contact_points if not isinstance(cp, EndPoint)] + self.endpoints_resolved = [cp for cp in self.contact_points if isinstance(cp, EndPoint)] + + try: + self.endpoints_resolved += [DefaultEndPoint(address, self.port) + for address in _resolve_contact_points(raw_contact_points, self.port)] + except UnresolvableContactPoints: + # rethrow if no EndPoint was provided + if not self.endpoints_resolved: + raise self.compression = compression if protocol_version is not _NOT_SET: self.protocol_version = protocol_version self._protocol_version_explicit = True self.allow_beta_protocol_version = allow_beta_protocol_version self.no_compact = no_compact self.auth_provider = auth_provider if load_balancing_policy is not None: if isinstance(load_balancing_policy, type): raise TypeError("load_balancing_policy should not be a class, it should be an instance of that class") self.load_balancing_policy = load_balancing_policy else: self._load_balancing_policy = default_lbp_factory() # set internal attribute to avoid committing to legacy config mode if reconnection_policy is not None: if isinstance(reconnection_policy, type): raise TypeError("reconnection_policy should not be a class, it should be an instance of that class") self.reconnection_policy = reconnection_policy if default_retry_policy is not None: if isinstance(default_retry_policy, type): raise TypeError("default_retry_policy should not be a class, it should be an instance of that class") self.default_retry_policy = default_retry_policy if conviction_policy_factory is not None: if not callable(conviction_policy_factory): raise ValueError("conviction_policy_factory must be callable") self.conviction_policy_factory = conviction_policy_factory if address_translator is not None: if isinstance(address_translator, type): raise TypeError("address_translator should not be a class, it should be an instance of that class") self.address_translator = address_translator if connection_class is not None: self.connection_class = connection_class if timestamp_generator is not None: if not callable(timestamp_generator): raise ValueError("timestamp_generator must be callable") self.timestamp_generator = timestamp_generator else: self.timestamp_generator = MonotonicTimestampGenerator() self.profile_manager = ProfileManager() - self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile(self.load_balancing_policy, - self.default_retry_policy, - Session._default_consistency_level, - Session._default_serial_consistency_level, - Session._default_timeout, - Session._row_factory) + self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile( + self.load_balancing_policy, + self.default_retry_policy, + request_timeout=Session._default_timeout, + row_factory=Session._row_factory + ) # legacy mode if either of these is not default if load_balancing_policy or default_retry_policy: if execution_profiles: raise ValueError("Clusters constructed with execution_profiles should not specify legacy parameters " "load_balancing_policy or default_retry_policy. 
Configure this in a profile instead.") self._config_mode = _ConfigMode.LEGACY warn("Legacy execution parameters will be removed in 4.0. Consider using " "execution profiles.", DeprecationWarning) else: if execution_profiles: self.profile_manager.profiles.update(execution_profiles) self._config_mode = _ConfigMode.PROFILES if self._contact_points_explicit: if self._config_mode is _ConfigMode.PROFILES: default_lbp_profiles = self.profile_manager._profiles_without_explicit_lbps() if default_lbp_profiles: log.warning( 'Cluster.__init__ called with contact_points ' 'specified, but load-balancing policies are not ' 'specified in some ExecutionProfiles. In the next ' 'major version, this will raise an error; please ' 'specify a load-balancing policy. ' '(contact_points = {cp}, ' 'EPs without explicit LBPs = {eps})' ''.format(cp=contact_points, eps=default_lbp_profiles)) else: if load_balancing_policy is None: log.warning( 'Cluster.__init__ called with contact_points ' 'specified, but no load_balancing_policy. In the next ' 'major version, this will raise an error; please ' 'specify a load-balancing policy. ' '(contact_points = {cp}, lbp = {lbp})' ''.format(cp=contact_points, lbp=load_balancing_policy)) self.metrics_enabled = metrics_enabled + + if ssl_options and not ssl_context: + warn('Using ssl_options without ssl_context is ' + 'deprecated and will result in an error in ' + 'the next major release. Please use ssl_context ' + 'to prepare for that release.', + DeprecationWarning) + self.ssl_options = ssl_options + self.ssl_context = ssl_context self.sockopts = sockopts self.cql_version = cql_version self.max_schema_agreement_wait = max_schema_agreement_wait self.control_connection_timeout = control_connection_timeout self.idle_heartbeat_interval = idle_heartbeat_interval self.idle_heartbeat_timeout = idle_heartbeat_timeout self.schema_event_refresh_window = schema_event_refresh_window self.topology_event_refresh_window = topology_event_refresh_window self.status_event_refresh_window = status_event_refresh_window self.connect_timeout = connect_timeout self.prepare_on_all_hosts = prepare_on_all_hosts self.reprepare_on_up = reprepare_on_up self._listeners = set() self._listener_lock = Lock() # let Session objects be GC'ed (and shutdown) when the user no longer # holds a reference. 
self.sessions = WeakSet() self.metadata = Metadata() self.control_connection = None self._prepared_statements = WeakValueDictionary() self._prepared_statement_lock = Lock() self._user_types = defaultdict(dict) self._min_requests_per_connection = { HostDistance.LOCAL: DEFAULT_MIN_REQUESTS, HostDistance.REMOTE: DEFAULT_MIN_REQUESTS } self._max_requests_per_connection = { HostDistance.LOCAL: DEFAULT_MAX_REQUESTS, HostDistance.REMOTE: DEFAULT_MAX_REQUESTS } self._core_connections_per_host = { HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST, HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST } self._max_connections_per_host = { HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST, HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST } - self.executor = ThreadPoolExecutor(max_workers=executor_threads) + self.executor = self._create_thread_pool_executor(max_workers=executor_threads) self.scheduler = _Scheduler(self.executor) self._lock = RLock() if self.metrics_enabled: from cassandra.metrics import Metrics self.metrics = Metrics(weakref.proxy(self)) self.control_connection = ControlConnection( self, self.control_connection_timeout, self.schema_event_refresh_window, self.topology_event_refresh_window, self.status_event_refresh_window, schema_metadata_enabled, token_metadata_enabled) + def _create_thread_pool_executor(self, **kwargs): + """ + Create a ThreadPoolExecutor for the cluster. In most cases, the built-in + `concurrent.futures.ThreadPoolExecutor` is used. + + Python 3.7 and Eventlet cause the `concurrent.futures.ThreadPoolExecutor` + to hang indefinitely. In that case, the user needs to have the `futurist` + package so we can use the `futurist.GreenThreadPoolExecutor` class instead. + + :param kwargs: All keyword args are passed to the ThreadPoolExecutor constructor. + :return: A ThreadPoolExecutor instance. + """ + tpe_class = ThreadPoolExecutor + if sys.version_info[0] >= 3 and sys.version_info[1] >= 7: + try: + from cassandra.io.eventletreactor import EventletConnection + is_eventlet = issubclass(self.connection_class, EventletConnection) + except: + # Eventlet is not available or can't be detected + return tpe_class(**kwargs) + + if is_eventlet: + try: + from futurist import GreenThreadPoolExecutor + tpe_class = GreenThreadPoolExecutor + except ImportError: + # futurist is not available + raise ImportError( + ("Python 3.7 and Eventlet cause the `concurrent.futures.ThreadPoolExecutor` " + "to hang indefinitely. If you want to use the Eventlet reactor, you " + "need to install the `futurist` package to allow the driver to use " + "the GreenThreadPoolExecutor. See https://github.com/eventlet/eventlet/issues/508 " + "for more details.")) + + return tpe_class(**kwargs) + def register_user_type(self, keyspace, user_type, klass): """ Registers a class to use to represent a particular user-defined type. Query parameters for this user-defined type will be assumed to be instances of `klass`. Result sets for this user-defined type will be instances of `klass`. If no class is registered for a user-defined type, a namedtuple will be used for result sets, and non-prepared statements may not encode parameters for this type correctly. `keyspace` is the name of the keyspace that the UDT is defined in. `user_type` is the string name of the UDT to register the mapping for. `klass` should be a class with attributes whose names match the fields of the user-defined type. The constructor must accepts kwargs for each of the fields in the UDT. 
This method should only be called after the type has been created within Cassandra. Example:: cluster = Cluster(protocol_version=3) session = cluster.connect() session.set_keyspace('mykeyspace') session.execute("CREATE TYPE address (street text, zipcode int)") session.execute("CREATE TABLE users (id int PRIMARY KEY, location address)") # create a class to map to the "address" UDT class Address(object): def __init__(self, street, zipcode): self.street = street self.zipcode = zipcode cluster.register_user_type('mykeyspace', 'address', Address) # insert a row using an instance of Address session.execute("INSERT INTO users (id, location) VALUES (%s, %s)", (0, Address("123 Main St.", 78723))) # results will include Address instances results = session.execute("SELECT * FROM users") row = results[0] print row.id, row.location.street, row.location.zipcode """ if self.protocol_version < 3: log.warning("User Type serialization is only supported in native protocol version 3+ (%d in use). " "CQL encoding for simple statements will still work, but named tuples will " "be returned when reading type %s.%s.", self.protocol_version, keyspace, user_type) self._user_types[keyspace][user_type] = klass for session in tuple(self.sessions): session.user_type_registered(keyspace, user_type, klass) UserType.evict_udt_class(keyspace, user_type) def add_execution_profile(self, name, profile, pool_wait_timeout=5): """ Adds an :class:`.ExecutionProfile` to the cluster. This makes it available for use by ``name`` in :meth:`.Session.execute` and :meth:`.Session.execute_async`. This method will raise if the profile already exists. Normally profiles will be injected at cluster initialization via ``Cluster(execution_profiles)``. This method provides a way of adding them dynamically. Adding a new profile updates the connection pools according to the specified ``load_balancing_policy``. By default, this method will wait up to five seconds for the pool creation to complete, so the profile can be used immediately upon return. This behavior can be controlled using ``pool_wait_timeout`` (see `concurrent.futures.wait `_ for timeout semantics). """ if not isinstance(profile, ExecutionProfile): raise TypeError("profile must be an instance of ExecutionProfile") if self._config_mode == _ConfigMode.LEGACY: raise ValueError("Cannot add execution profiles when legacy parameters are set explicitly.") if name in self.profile_manager.profiles: raise ValueError("Profile {} already exists".format(name)) contact_points_but_no_lbp = ( self._contact_points_explicit and not profile._load_balancing_policy_explicit) if contact_points_but_no_lbp: log.warning( 'Tried to add an ExecutionProfile with name {name}. ' '{self} was explicitly configured with contact_points, but ' '{ep} was not explicitly configured with a ' 'load_balancing_policy. In the next major version, trying to ' 'add an ExecutionProfile without an explicitly configured LBP ' 'to a cluster with explicitly configured contact_points will ' 'raise an exception; please specify a load-balancing policy ' 'in the ExecutionProfile.' 
- ''.format(name=repr(name), self=self, ep=profile)) + ''.format(name=_execution_profile_to_string(name), self=self, ep=profile)) self.profile_manager.profiles[name] = profile profile.load_balancing_policy.populate(self, self.metadata.all_hosts()) # on_up after populate allows things like DCA LBP to choose default local dc for host in filter(lambda h: h.is_up, self.metadata.all_hosts()): profile.load_balancing_policy.on_up(host) futures = set() for session in tuple(self.sessions): + self._set_default_dbaas_consistency(session) futures.update(session.update_created_pools()) _, not_done = wait_futures(futures, pool_wait_timeout) if not_done: raise OperationTimedOut("Failed to create all new connection pools in the %ss timeout." % pool_wait_timeout) def get_min_requests_per_connection(self, host_distance): return self._min_requests_per_connection[host_distance] def set_min_requests_per_connection(self, host_distance, min_requests): """ Sets a threshold for concurrent requests per connection, below which connections will be considered for disposal (down to core connections; see :meth:`~Cluster.set_core_connections_per_host`). Pertains to connection pool management in protocol versions {1,2}. """ if self.protocol_version >= 3: raise UnsupportedOperation( "Cluster.set_min_requests_per_connection() only has an effect " "when using protocol_version 1 or 2.") if min_requests < 0 or min_requests > 126 or \ min_requests >= self._max_requests_per_connection[host_distance]: raise ValueError("min_requests must be 0-126 and less than the max_requests for this host_distance (%d)" % (self._max_requests_per_connection[host_distance],)) self._min_requests_per_connection[host_distance] = min_requests def get_max_requests_per_connection(self, host_distance): return self._max_requests_per_connection[host_distance] def set_max_requests_per_connection(self, host_distance, max_requests): """ Sets a threshold for concurrent requests per connection, above which new connections will be created to a host (up to max connections; see :meth:`~Cluster.set_max_connections_per_host`). Pertains to connection pool management in protocol versions {1,2}. """ if self.protocol_version >= 3: raise UnsupportedOperation( "Cluster.set_max_requests_per_connection() only has an effect " "when using protocol_version 1 or 2.") if max_requests < 1 or max_requests > 127 or \ max_requests <= self._min_requests_per_connection[host_distance]: raise ValueError("max_requests must be 1-127 and greater than the min_requests for this host_distance (%d)" % (self._min_requests_per_connection[host_distance],)) self._max_requests_per_connection[host_distance] = max_requests def get_core_connections_per_host(self, host_distance): """ Gets the minimum number of connections per Session that will be opened for each host with :class:`~.HostDistance` equal to `host_distance`. The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for :attr:`~HostDistance.REMOTE`. This property is ignored if :attr:`~.Cluster.protocol_version` is 3 or higher. """ return self._core_connections_per_host[host_distance] def set_core_connections_per_host(self, host_distance, core_connections): """ Sets the minimum number of connections per Session that will be opened for each host with :class:`~.HostDistance` equal to `host_distance`. The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for :attr:`~HostDistance.REMOTE`. Protocol versions 1 and 2 are limited in the number of concurrent requests they can send per connection.
The driver implements connection pooling to support higher levels of concurrency. If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this is not supported (there is always one connection per host, unless the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) and using this will result in an :exc:`~.UnsupportedOperation`. """ if self.protocol_version >= 3: raise UnsupportedOperation( "Cluster.set_core_connections_per_host() only has an effect " "when using protocol_version 1 or 2.") old = self._core_connections_per_host[host_distance] self._core_connections_per_host[host_distance] = core_connections if old < core_connections: self._ensure_core_connections() def get_max_connections_per_host(self, host_distance): """ Gets the maximum number of connections per Session that will be opened for each host with :class:`~.HostDistance` equal to `host_distance`. The default is 8 for :attr:`~HostDistance.LOCAL` and 2 for :attr:`~HostDistance.REMOTE`. This property is ignored if :attr:`~.Cluster.protocol_version` is 3 or higher. """ return self._max_connections_per_host[host_distance] def set_max_connections_per_host(self, host_distance, max_connections): """ Sets the maximum number of connections per Session that will be opened for each host with :class:`~.HostDistance` equal to `host_distance`. The default is 8 for :attr:`~HostDistance.LOCAL` and 2 for :attr:`~HostDistance.REMOTE`. If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this is not supported (there is always one connection per host, unless the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) and using this will result in an :exc:`~.UnsupportedOperation`. """ if self.protocol_version >= 3: raise UnsupportedOperation( "Cluster.set_max_connections_per_host() only has an effect " "when using protocol_version 1 or 2.") self._max_connections_per_host[host_distance] = max_connections - def connection_factory(self, address, *args, **kwargs): + def connection_factory(self, endpoint, *args, **kwargs): """ Called to create a new connection with proper configuration. Intended for internal use only.
""" - kwargs = self._make_connection_kwargs(address, kwargs) - return self.connection_class.factory(address, self.connect_timeout, *args, **kwargs) + kwargs = self._make_connection_kwargs(endpoint, kwargs) + return self.connection_class.factory(endpoint, self.connect_timeout, *args, **kwargs) def _make_connection_factory(self, host, *args, **kwargs): - kwargs = self._make_connection_kwargs(host.address, kwargs) - return partial(self.connection_class.factory, host.address, self.connect_timeout, *args, **kwargs) + kwargs = self._make_connection_kwargs(host.endpoint, kwargs) + return partial(self.connection_class.factory, host.endpoint, self.connect_timeout, *args, **kwargs) - def _make_connection_kwargs(self, address, kwargs_dict): + def _make_connection_kwargs(self, endpoint, kwargs_dict): if self._auth_provider_callable: - kwargs_dict.setdefault('authenticator', self._auth_provider_callable(address)) + kwargs_dict.setdefault('authenticator', self._auth_provider_callable(endpoint.address)) kwargs_dict.setdefault('port', self.port) kwargs_dict.setdefault('compression', self.compression) kwargs_dict.setdefault('sockopts', self.sockopts) kwargs_dict.setdefault('ssl_options', self.ssl_options) + kwargs_dict.setdefault('ssl_context', self.ssl_context) kwargs_dict.setdefault('cql_version', self.cql_version) kwargs_dict.setdefault('protocol_version', self.protocol_version) kwargs_dict.setdefault('user_type_map', self._user_types) kwargs_dict.setdefault('allow_beta_protocol_version', self.allow_beta_protocol_version) kwargs_dict.setdefault('no_compact', self.no_compact) return kwargs_dict - def protocol_downgrade(self, host_addr, previous_version): + def protocol_downgrade(self, host_endpoint, previous_version): if self._protocol_version_explicit: raise DriverException("ProtocolError returned from server while using explicitly set client protocol_version %d" % (previous_version,)) new_version = ProtocolVersion.get_lower_supported(previous_version) if new_version < ProtocolVersion.MIN_SUPPORTED: raise DriverException( "Cannot downgrade protocol version below minimum supported version: %d" % (ProtocolVersion.MIN_SUPPORTED,)) log.warning("Downgrading core protocol version from %d to %d for %s. " "To avoid this, it is best practice to explicitly set Cluster(protocol_version) to the version supported by your cluster. " - "http://datastax.github.io/python-driver/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_addr) + "http://datastax.github.io/python-driver/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_endpoint) self.protocol_version = new_version def connect(self, keyspace=None, wait_for_all_pools=False): """ - Creates and returns a new :class:`~.Session` object. If `keyspace` - is specified, that keyspace will be the default keyspace for + Creates and returns a new :class:`~.Session` object. + + If `keyspace` is specified, that keyspace will be the default keyspace for operations on the ``Session``. + + `wait_for_all_pools` specifies whether this call should wait for all connection pools to be + established or attempted. Default is `False`, which means it will return when the first + successful connection is established. Remaining pools are added asynchronously. 
""" with self._lock: if self.is_shutdown: raise DriverException("Cluster is already shut down") if not self._is_setup: log.debug("Connecting to cluster, contact points: %s; protocol version: %s", self.contact_points, self.protocol_version) self.connection_class.initialize_reactor() _register_cluster_shutdown(self) - for address in self.contact_points_resolved: - host, new = self.add_host(address, signal=False) + for endpoint in self.endpoints_resolved: + host, new = self.add_host(endpoint, signal=False) if new: host.set_up() for listener in self.listeners: listener.on_add(host) self.profile_manager.populate( weakref.proxy(self), self.metadata.all_hosts()) self.load_balancing_policy.populate( weakref.proxy(self), self.metadata.all_hosts() ) try: self.control_connection.connect() # we set all contact points up for connecting, but we won't infer state after this - for address in self.contact_points_resolved: - h = self.metadata.get_host(address) + for endpoint in self.endpoints_resolved: + h = self.metadata.get_host(endpoint) if h and self.profile_manager.distance(h) == HostDistance.IGNORED: h.is_up = None log.debug("Control connection created") except Exception: log.exception("Control connection failed to connect, " "shutting down Cluster:") self.shutdown() raise self.profile_manager.check_supported() # todo: rename this method if self.idle_heartbeat_interval: self._idle_heartbeat = ConnectionHeartbeat( self.idle_heartbeat_interval, self.get_connection_holders, timeout=self.idle_heartbeat_timeout ) self._is_setup = True session = self._new_session(keyspace) if wait_for_all_pools: wait_futures(session._initial_connect_futures) + + self._set_default_dbaas_consistency(session) + return session + def _set_default_dbaas_consistency(self, session): + if session.cluster.metadata.dbaas: + for profile in self.profile_manager.profiles.values(): + if not profile._consistency_level_explicit: + profile.consistency_level = ConsistencyLevel.LOCAL_QUORUM + session._default_consistency_level = ConsistencyLevel.LOCAL_QUORUM + def get_connection_holders(self): holders = [] for s in tuple(self.sessions): holders.extend(s.get_pools()) holders.append(self.control_connection) return holders def shutdown(self): """ Closes all sessions and connection associated with this Cluster. To ensure all connections are properly closed, **you should always call shutdown() on a Cluster instance when you are done with it**. Once shutdown, a Cluster should not be used for any purpose. 
""" with self._lock: if self.is_shutdown: return else: self.is_shutdown = True if self._idle_heartbeat: self._idle_heartbeat.stop() self.scheduler.shutdown() self.control_connection.shutdown() for session in tuple(self.sessions): session.shutdown() self.executor.shutdown() _discard_cluster_shutdown(self) def __enter__(self): return self def __exit__(self, *args): self.shutdown() def _new_session(self, keyspace): session = Session(self, self.metadata.all_hosts(), keyspace) self._session_register_user_types(session) self.sessions.add(session) return session def _session_register_user_types(self, session): for keyspace, type_map in six.iteritems(self._user_types): for udt_name, klass in six.iteritems(type_map): session.user_type_registered(keyspace, udt_name, klass) def _cleanup_failed_on_up_handling(self, host): self.profile_manager.on_down(host) self.control_connection.on_down(host) for session in tuple(self.sessions): session.remove_pool(host) self._start_reconnector(host, is_host_addition=False) def _on_up_future_completed(self, host, futures, results, lock, finished_future): with lock: futures.discard(finished_future) try: results.append(finished_future.result()) except Exception as exc: results.append(exc) if futures: return try: # all futures have completed at this point for exc in [f for f in results if isinstance(f, Exception)]: log.error("Unexpected failure while marking node %s up:", host, exc_info=exc) self._cleanup_failed_on_up_handling(host) return if not all(results): log.debug("Connection pool could not be created, not marking node %s up", host) self._cleanup_failed_on_up_handling(host) return log.info("Connection pools established for node %s", host) # mark the host as up and notify all listeners host.set_up() for listener in self.listeners: listener.on_up(host) finally: with host.lock: host._currently_handling_node_up = False # see if there are any pools to add or remove now that the host is marked up for session in tuple(self.sessions): session.update_created_pools() def on_up(self, host): """ Intended for internal use only. 
""" if self.is_shutdown: return log.debug("Waiting to acquire lock for handling up status of node %s", host) with host.lock: if host._currently_handling_node_up: log.debug("Another thread is already handling up status of node %s", host) return if host.is_up: log.debug("Host %s was already marked up", host) return host._currently_handling_node_up = True log.debug("Starting to handle up status of node %s", host) have_future = False futures = set() try: log.info("Host %s may be up; will prepare queries and open connection pool", host) reconnector = host.get_and_set_reconnection_handler(None) if reconnector: log.debug("Now that host %s is up, cancelling the reconnection handler", host) reconnector.cancel() if self.profile_manager.distance(host) != HostDistance.IGNORED: self._prepare_all_queries(host) log.debug("Done preparing all queries for host %s, ", host) for session in tuple(self.sessions): session.remove_pool(host) log.debug("Signalling to load balancing policies that host %s is up", host) self.profile_manager.on_up(host) log.debug("Signalling to control connection that host %s is up", host) self.control_connection.on_up(host) log.debug("Attempting to open new connection pools for host %s", host) futures_lock = Lock() futures_results = [] callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock) for session in tuple(self.sessions): future = session.add_or_renew_pool(host, is_host_addition=False) if future is not None: have_future = True future.add_done_callback(callback) futures.add(future) except Exception: log.exception("Unexpected failure handling node %s being marked up:", host) for future in futures: future.cancel() self._cleanup_failed_on_up_handling(host) with host.lock: host._currently_handling_node_up = False raise else: if not have_future: with host.lock: host.set_up() host._currently_handling_node_up = False # for testing purposes return futures def _start_reconnector(self, host, is_host_addition): if self.profile_manager.distance(host) == HostDistance.IGNORED: return schedule = self.reconnection_policy.new_schedule() # in order to not hold references to this Cluster open and prevent # proper shutdown when the program ends, we'll just make a closure # of the current Cluster attributes to create new Connections with conn_factory = self._make_connection_factory(host) reconnector = _HostReconnectionHandler( host, conn_factory, is_host_addition, self.on_add, self.on_up, self.scheduler, schedule, host.get_and_set_reconnection_handler, new_handler=None) old_reconnector = host.get_and_set_reconnection_handler(reconnector) if old_reconnector: log.debug("Old host reconnector found for %s, cancelling", host) old_reconnector.cancel() log.debug("Starting reconnector for host %s", host) reconnector.start() @run_in_executor def on_down(self, host, is_host_addition, expect_host_to_be_down=False): """ Intended for internal use only. 
""" if self.is_shutdown: return with host.lock: was_up = host.is_up # ignore down signals if we have open pools to the host # this is to avoid closing pools when a control connection host became isolated if self._discount_down_events and self.profile_manager.distance(host) != HostDistance.IGNORED: connected = False for session in tuple(self.sessions): pool_states = session.get_pool_state() pool_state = pool_states.get(host) if pool_state: connected |= pool_state['open_count'] > 0 if connected: return host.set_down() if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): return log.warning("Host %s has been marked down", host) self.profile_manager.on_down(host) self.control_connection.on_down(host) for session in tuple(self.sessions): session.on_down(host) for listener in self.listeners: listener.on_down(host) self._start_reconnector(host, is_host_addition) def on_add(self, host, refresh_nodes=True): if self.is_shutdown: return log.debug("Handling new host %r and notifying listeners", host) distance = self.profile_manager.distance(host) if distance != HostDistance.IGNORED: self._prepare_all_queries(host) log.debug("Done preparing queries for new host %r", host) self.profile_manager.on_add(host) self.control_connection.on_add(host, refresh_nodes) if distance == HostDistance.IGNORED: log.debug("Not adding connection pool for new host %r because the " "load balancing policy has marked it as IGNORED", host) self._finalize_add(host, set_up=False) return futures_lock = Lock() futures_results = [] futures = set() def future_completed(future): with futures_lock: futures.discard(future) try: futures_results.append(future.result()) except Exception as exc: futures_results.append(exc) if futures: return log.debug('All futures have completed for added host %s', host) for exc in [f for f in futures_results if isinstance(f, Exception)]: log.error("Unexpected failure while adding node %s, will not mark up:", host, exc_info=exc) return if not all(futures_results): log.warning("Connection pool could not be created, not marking node %s up", host) return self._finalize_add(host) have_future = False for session in tuple(self.sessions): future = session.add_or_renew_pool(host, is_host_addition=True) if future is not None: have_future = True futures.add(future) future.add_done_callback(future_completed) if not have_future: self._finalize_add(host) def _finalize_add(self, host, set_up=True): if set_up: host.set_up() for listener in self.listeners: listener.on_add(host) # see if there are any pools to add or remove now that the host is marked up for session in tuple(self.sessions): session.update_created_pools() def on_remove(self, host): if self.is_shutdown: return log.debug("Removing host %s", host) host.set_down() self.profile_manager.on_remove(host) for session in tuple(self.sessions): session.on_remove(host) for listener in self.listeners: listener.on_remove(host) self.control_connection.on_remove(host) def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False): is_down = host.signal_connection_failure(connection_exc) if is_down: self.on_down(host, is_host_addition, expect_host_to_be_down) return is_down - def add_host(self, address, datacenter=None, rack=None, signal=True, refresh_nodes=True): + def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True): """ Called when adding initial contact points and when the control connection subsequently discovers a new node. 
Returns a Host instance, and a flag indicating whether it was new in the metadata. Intended for internal use only. """ - host, new = self.metadata.add_or_return_host(Host(address, self.conviction_policy_factory, datacenter, rack)) + host, new = self.metadata.add_or_return_host(Host(endpoint, self.conviction_policy_factory, datacenter, rack)) if new and signal: log.info("New Cassandra host %r discovered", host) self.on_add(host, refresh_nodes) return host, new def remove_host(self, host): """ Called when the control connection observes that a node has left the ring. Intended for internal use only. """ if host and self.metadata.remove_host(host): log.info("Cassandra host %s removed", host) self.on_remove(host) def register_listener(self, listener): """ Adds a :class:`cassandra.policies.HostStateListener` subclass instance to the list of listeners to be notified when a host is added, removed, marked up, or marked down. """ with self._listener_lock: self._listeners.add(listener) def unregister_listener(self, listener): """ Removes a registered listener. """ with self._listener_lock: self._listeners.remove(listener) @property def listeners(self): with self._listener_lock: return self._listeners.copy() def _ensure_core_connections(self): """ If any host has fewer than the configured number of core connections open, attempt to open connections until that number is met. """ for session in tuple(self.sessions): for pool in tuple(session._pools.values()): pool.ensure_core_connections() @staticmethod def _validate_refresh_schema(keyspace, table, usertype, function, aggregate): if any((table, usertype, function, aggregate)): if not keyspace: raise ValueError("keyspace is required to refresh specific sub-entity {table, usertype, function, aggregate}") if sum(1 for e in (table, usertype, function) if e) > 1: raise ValueError("{table, usertype, function, aggregate} are mutually exclusive") @staticmethod def _target_type_from_refresh_args(keyspace, table, usertype, function, aggregate): if aggregate: return SchemaTargetType.AGGREGATE elif function: return SchemaTargetType.FUNCTION elif usertype: return SchemaTargetType.TYPE elif table: return SchemaTargetType.TABLE elif keyspace: return SchemaTargetType.KEYSPACE return None def get_control_connection_host(self): """ Returns the control connection host metadata. """ connection = self.control_connection._connection - host = connection.host if connection else None - return self.metadata.get_host(host) if host else None + endpoint = connection.endpoint if connection else None + return self.metadata.get_host(endpoint) if endpoint else None def refresh_schema_metadata(self, max_schema_agreement_wait=None): """ Synchronously refresh all schema metadata. By default, the timeout for this operation is governed by :attr:`~.Cluster.max_schema_agreement_wait` and :attr:`~.Cluster.control_connection_timeout`. Passing max_schema_agreement_wait here overrides :attr:`~.Cluster.max_schema_agreement_wait`. Setting max_schema_agreement_wait <= 0 will bypass schema agreement and refresh schema immediately. An Exception is raised if schema refresh fails for any reason. """ if not self.control_connection.refresh_schema(schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("Schema metadata was not refreshed. See log for details.") def refresh_keyspace_metadata(self, keyspace, max_schema_agreement_wait=None): """ Synchronously refresh keyspace metadata. This applies to keyspace-level information such as replication and durability settings. 
It does not refresh tables, types, etc. contained in the keyspace. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.KEYSPACE, keyspace=keyspace, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("Keyspace metadata was not refreshed. See log for details.") def refresh_table_metadata(self, keyspace, table, max_schema_agreement_wait=None): """ Synchronously refresh table metadata. This applies to a table, and any triggers or indexes attached to the table. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=table, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("Table metadata was not refreshed. See log for details.") def refresh_materialized_view_metadata(self, keyspace, view, max_schema_agreement_wait=None): """ Synchronously refresh materialized view metadata. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=view, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("View metadata was not refreshed. See log for details.") def refresh_user_type_metadata(self, keyspace, user_type, max_schema_agreement_wait=None): """ Synchronously refresh user defined type metadata. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TYPE, keyspace=keyspace, type=user_type, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("User Type metadata was not refreshed. See log for details.") def refresh_user_function_metadata(self, keyspace, function, max_schema_agreement_wait=None): """ Synchronously refresh user defined function metadata. ``function`` is a :class:`cassandra.UserFunctionDescriptor`. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.FUNCTION, keyspace=keyspace, function=function, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("User Function metadata was not refreshed. See log for details.") def refresh_user_aggregate_metadata(self, keyspace, aggregate, max_schema_agreement_wait=None): """ Synchronously refresh user defined aggregate metadata. ``aggregate`` is a :class:`cassandra.UserAggregateDescriptor`. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.AGGREGATE, keyspace=keyspace, aggregate=aggregate, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("User Aggregate metadata was not refreshed. See log for details.") def refresh_nodes(self, force_token_rebuild=False): """ Synchronously refresh the node list and token metadata `force_token_rebuild` can be used to rebuild the token map metadata, even if no new nodes are discovered. An Exception is raised if node refresh fails for any reason. 
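For example, a one-line sketch assuming an existing ``cluster`` instance::

    cluster.refresh_nodes(force_token_rebuild=True)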
""" if not self.control_connection.refresh_node_list_and_token_map(force_token_rebuild): raise DriverException("Node list was not refreshed. See log for details.") def set_meta_refresh_enabled(self, enabled): """ *Deprecated:* set :attr:`~.Cluster.schema_metadata_enabled` :attr:`~.Cluster.token_metadata_enabled` instead Sets a flag to enable (True) or disable (False) all metadata refresh queries. This applies to both schema and node topology. Disabling this is useful to minimize refreshes during multiple changes. Meta refresh must be enabled for the driver to become aware of any cluster topology changes or schema updates. """ warn("Cluster.set_meta_refresh_enabled is deprecated and will be removed in 4.0. Set " "Cluster.schema_metadata_enabled and Cluster.token_metadata_enabled instead.", DeprecationWarning) self.schema_metadata_enabled = enabled self.token_metadata_enabled = enabled @classmethod def _send_chunks(cls, connection, host, chunks, set_keyspace=False): for ks_chunk in chunks: messages = [PrepareMessage(query=s.query_string, keyspace=s.keyspace if set_keyspace else None) for s in ks_chunk] # TODO: make this timeout configurable somehow? responses = connection.wait_for_responses(*messages, timeout=5.0, fail_on_error=False) for success, response in responses: if not success: log.debug("Got unexpected response when preparing " "statement on host %s: %r", host, response) def _prepare_all_queries(self, host): if not self._prepared_statements or not self.reprepare_on_up: return log.debug("Preparing all known prepared statements against host %s", host) connection = None try: - connection = self.connection_factory(host.address) - statements = self._prepared_statements.values() + connection = self.connection_factory(host.endpoint) + statements = list(self._prepared_statements.values()) if ProtocolVersion.uses_keyspace_flag(self.protocol_version): # V5 protocol and higher, no need to set the keyspace chunks = [] for i in range(0, len(statements), 10): chunks.append(statements[i:i + 10]) self._send_chunks(connection, host, chunks, True) else: for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace): if keyspace is not None: connection.set_keyspace_blocking(keyspace) # prepare 10 statements at a time ks_statements = list(ks_statements) chunks = [] for i in range(0, len(ks_statements), 10): chunks.append(ks_statements[i:i + 10]) self._send_chunks(connection, host, chunks) log.debug("Done preparing all known prepared statements against host %s", host) except OperationTimedOut as timeout: log.warning("Timed out trying to prepare all statements on host %s: %s", host, timeout) except (ConnectionException, socket.error) as exc: log.warning("Error trying to prepare all statements on host %s: %r", host, exc) except Exception: log.exception("Error trying to prepare all statements on host %s", host) finally: if connection: connection.close() def add_prepared(self, query_id, prepared_statement): with self._prepared_statement_lock: self._prepared_statements[query_id] = prepared_statement class Session(object): """ A collection of connection pools for each host in the cluster. Instances of this class should not be created directly, only using :meth:`.Cluster.connect()`. Queries and statements can be executed through ``Session`` instances using the :meth:`~.Session.execute()` and :meth:`~.Session.execute_async()` methods. 
Example usage:: >>> session = cluster.connect() >>> session.set_keyspace("mykeyspace") >>> session.execute("SELECT * FROM mycf") """ cluster = None hosts = None keyspace = None is_shutdown = False _row_factory = staticmethod(named_tuple_factory) @property def row_factory(self): """ The format to return row results in. By default, each returned row will be a named tuple. You can alternatively use any of the following: - :func:`cassandra.query.tuple_factory` - return a result row as a tuple - :func:`cassandra.query.named_tuple_factory` - return a result row as a named tuple - :func:`cassandra.query.dict_factory` - return a result row as a dict - :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict """ return self._row_factory @row_factory.setter def row_factory(self, rf): self._validate_set_legacy_config('row_factory', rf) _default_timeout = 10.0 @property def default_timeout(self): """ A default timeout, measured in seconds, for queries executed through :meth:`.execute()` or :meth:`.execute_async()`. This default may be overridden with the `timeout` parameter for either of those methods. Setting this to :const:`None` will cause no timeouts to be set by default. Please see :meth:`.ResponseFuture.result` for details on the scope and effect of this timeout. .. versionadded:: 2.0.0 """ return self._default_timeout @default_timeout.setter def default_timeout(self, timeout): self._validate_set_legacy_config('default_timeout', timeout) _default_consistency_level = ConsistencyLevel.LOCAL_ONE @property def default_consistency_level(self): """ *Deprecated:* use execution profiles instead The default :class:`~ConsistencyLevel` for operations executed through this session. This default may be overridden by setting the :attr:`~.Statement.consistency_level` on individual statements. .. versionadded:: 1.2.0 .. versionchanged:: 3.0.0 default changed from ONE to LOCAL_ONE """ return self._default_consistency_level @default_consistency_level.setter def default_consistency_level(self, cl): """ *Deprecated:* use execution profiles instead """ warn("Setting the consistency level at the session level will be removed in 4.0. Consider using " "execution profiles and setting the desired consistency level to the EXEC_PROFILE_DEFAULT profile." , DeprecationWarning) self._validate_set_legacy_config('default_consistency_level', cl) _default_serial_consistency_level = None @property def default_serial_consistency_level(self): """ The default :class:`~ConsistencyLevel` for serial phase of conditional updates executed through this session. This default may be overridden by setting the :attr:`~.Statement.serial_consistency_level` on individual statements. Only valid for ``protocol_version >= 2``. """ return self._default_serial_consistency_level @default_serial_consistency_level.setter def default_serial_consistency_level(self, cl): + if (cl is not None and + not ConsistencyLevel.is_serial(cl)): + raise ValueError("default_serial_consistency_level must be either " + "ConsistencyLevel.SERIAL " + "or ConsistencyLevel.LOCAL_SERIAL.") + self._validate_set_legacy_config('default_serial_consistency_level', cl) max_trace_wait = 2.0 """ The maximum amount of time (in seconds) the driver will wait for trace details to be populated server-side for a query before giving up. If the `trace` parameter for :meth:`~.execute()` or :meth:`~.execute_async()` is :const:`True`, the driver will repeatedly attempt to fetch trace details for the query (using exponential backoff) until this limit is hit.
If the limit is passed, an error will be logged and the :attr:`.Statement.trace` will be left as :const:`None`. """ default_fetch_size = 5000 """ By default, this many rows will be fetched at a time. Setting this to :const:`None` will disable automatic paging for large query results. The fetch size can also be specified per-query through :attr:`.Statement.fetch_size`. This only takes effect when protocol version 2 or higher is used. See :attr:`.Cluster.protocol_version` for details. .. versionadded:: 2.0.0 """ use_client_timestamp = True """ When using protocol version 3 or higher, write timestamps may be supplied client-side at the protocol level. (Normally they are generated server-side by the coordinator node.) Note that timestamps specified within a CQL query will override this timestamp. .. versionadded:: 2.1.0 """ timestamp_generator = None """ When :attr:`use_client_timestamp` is set, sessions call this object and use the result as the timestamp. (Note that timestamps specified within a CQL query will override this timestamp.) By default, a new :class:`~.MonotonicTimestampGenerator` is created for each :class:`Cluster` instance. Applications can set this value for custom timestamp behavior. For example, an application could share a timestamp generator across :class:`Cluster` objects to guarantee that the application will use unique, increasing timestamps across clusters, or set it to ``lambda: int(time.time() * 1e6)`` if losing records over clock inconsistencies is acceptable for the application. Custom :attr:`timestamp_generator` s should be callable, and calling them should return an integer representing microseconds since some point in time, typically UNIX epoch. .. versionadded:: 3.8.0 """ encoder = None """ A :class:`~cassandra.encoder.Encoder` instance that will be used when formatting query parameters for non-prepared statements. This is not used for prepared statements (because prepared statements give the driver more information about what CQL types are expected, allowing it to accept a wider range of python types). The encoder uses a mapping from python types to encoder methods (for specific CQL types). This mapping can be modified by users as they see fit. Methods of :class:`~cassandra.encoder.Encoder` should be used for mapping values if possible, because they take precautions to avoid injections and properly sanitize data. Example:: cluster = Cluster() session = cluster.connect("mykeyspace") session.encoder.mapping[tuple] = session.encoder.cql_encode_tuple session.execute("CREATE TABLE mytable (k int PRIMARY KEY, col tuple)") session.execute("INSERT INTO mytable (k, col) VALUES (%s, %s)", [0, (123, 'abc')]) .. versionadded:: 2.1.0 """ client_protocol_handler = ProtocolHandler """ Specifies a protocol handler that will be used for client-initiated requests (i.e. no internal driver requests). This can be used to override or extend features such as message or type ser/des. The default pure python implementation is :class:`cassandra.protocol.ProtocolHandler`. When compiled with Cython, there are also built-in faster alternatives.
See :ref:`faster_deser` """ _lock = None _pools = None _profile_manager = None _metrics = None _request_init_callbacks = None def __init__(self, cluster, hosts, keyspace=None): self.cluster = cluster self.hosts = hosts self.keyspace = keyspace self._lock = RLock() self._pools = {} self._profile_manager = cluster.profile_manager self._metrics = cluster.metrics self._request_init_callbacks = [] self._protocol_version = self.cluster.protocol_version self.encoder = Encoder() # create connection pools in parallel self._initial_connect_futures = set() for host in hosts: future = self.add_or_renew_pool(host, is_host_addition=False) if future: self._initial_connect_futures.add(future) futures = wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED) while futures.not_done and not any(f.result() for f in futures.done): futures = wait_futures(futures.not_done, return_when=FIRST_COMPLETED) if not any(f.result() for f in self._initial_connect_futures): msg = "Unable to connect to any servers" if self.keyspace: msg += " using keyspace '%s'" % self.keyspace raise NoHostAvailable(msg, [h.address for h in hosts]) - def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, custom_payload=None, execution_profile=EXEC_PROFILE_DEFAULT, paging_state=None): + def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, + custom_payload=None, execution_profile=EXEC_PROFILE_DEFAULT, + paging_state=None, host=None): """ Execute the given query and synchronously wait for the response. If an error is encountered while executing the query, an Exception will be raised. `query` may be a query string or an instance of :class:`cassandra.query.Statement`. `parameters` may be a sequence or dict of parameters to bind. If a sequence is used, ``%s`` should be used as the placeholder for each argument. If a dict is used, ``%(name)s`` style placeholders must be used. `timeout` should specify a floating-point timeout (in seconds) after which an :exc:`.OperationTimedOut` exception will be raised if the query has not completed. If not set, the timeout defaults to :attr:`~.Session.default_timeout`. If set to :const:`None`, there is no timeout. Please see :meth:`.ResponseFuture.result` for details on the scope and effect of this timeout. If `trace` is set to :const:`True`, the query will be sent with tracing enabled. The trace details can be obtained using the returned :class:`.ResultSet` object. `custom_payload` is a :ref:`custom_payload` dict to be passed to the server. If `query` is a Statement with its own custom_payload, the message payload will be a union of the two, with the values specified here taking precedence. `execution_profile` is the execution profile to use for this request. It can be a key to a profile configured via :meth:`Cluster.add_execution_profile` or an instance (from :meth:`Session.execution_profile_clone_update`, for example). `paging_state` is an optional paging state, reused from a previous :class:`ResultSet`. + + `host` is the :class:`cassandra.pool.Host` that should handle the query. If the host specified is down or + not yet connected, the query will fail with :class:`NoHostAvailable`. Using this is + discouraged except in a few cases, e.g., querying node-local tables and applying schema changes.
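A short parameter-binding sketch; the ``users`` table and its columns are illustrative::

    session.execute("INSERT INTO users (id, name) VALUES (%s, %s)", (0, 'alice'))
    rows = session.execute("SELECT name FROM users WHERE id = %(id)s", {'id': 0})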
""" - return self.execute_async(query, parameters, trace, custom_payload, timeout, execution_profile, paging_state).result() + return self.execute_async(query, parameters, trace, custom_payload, + timeout, execution_profile, paging_state, host).result() - def execute_async(self, query, parameters=None, trace=False, custom_payload=None, timeout=_NOT_SET, execution_profile=EXEC_PROFILE_DEFAULT, paging_state=None): + def execute_async(self, query, parameters=None, trace=False, custom_payload=None, + timeout=_NOT_SET, execution_profile=EXEC_PROFILE_DEFAULT, + paging_state=None, host=None): """ Execute the given query and return a :class:`~.ResponseFuture` object which callbacks may be attached to for asynchronous response delivery. You may also call :meth:`~.ResponseFuture.result()` on the :class:`.ResponseFuture` to synchronously block for results at any time. See :meth:`Session.execute` for parameter definitions. Example usage:: >>> session = cluster.connect() >>> future = session.execute_async("SELECT * FROM mycf") >>> def log_results(results): ... for row in results: ... log.info("Results: %s", row) >>> def log_error(exc): >>> log.error("Operation failed: %s", exc) >>> future.add_callbacks(log_results, log_error) Async execution with blocking wait for results:: >>> future = session.execute_async("SELECT * FROM mycf") >>> # do other stuff... >>> try: ... results = future.result() ... except Exception: ... log.exception("Operation failed:") """ - future = self._create_response_future(query, parameters, trace, custom_payload, timeout, execution_profile, paging_state) + future = self._create_response_future( + query, parameters, trace, custom_payload, timeout, + execution_profile, paging_state, host) future._protocol_handler = self.client_protocol_handler self._on_request(future) future.send_request() return future - def _create_response_future(self, query, parameters, trace, custom_payload, timeout, execution_profile=EXEC_PROFILE_DEFAULT, paging_state=None): + def _create_response_future(self, query, parameters, trace, custom_payload, + timeout, execution_profile=EXEC_PROFILE_DEFAULT, + paging_state=None, host=None): """ Returns the ResponseFuture before calling send_request() on it """ prepared_statement = None if isinstance(query, six.string_types): query = SimpleStatement(query) elif isinstance(query, PreparedStatement): query = query.bind(parameters) if self.cluster._config_mode == _ConfigMode.LEGACY: if execution_profile is not EXEC_PROFILE_DEFAULT: raise ValueError("Cannot specify execution_profile while using legacy parameters.") if timeout is _NOT_SET: timeout = self.default_timeout cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level serial_cl = query.serial_consistency_level if query.serial_consistency_level is not None else self.default_serial_consistency_level retry_policy = query.retry_policy or self.cluster.default_retry_policy row_factory = self.row_factory load_balancing_policy = self.cluster.load_balancing_policy spec_exec_policy = None else: - execution_profile = self._get_execution_profile(execution_profile) + execution_profile = self._maybe_get_execution_profile(execution_profile) if timeout is _NOT_SET: timeout = execution_profile.request_timeout cl = query.consistency_level if query.consistency_level is not None else execution_profile.consistency_level serial_cl = query.serial_consistency_level if query.serial_consistency_level is not None else execution_profile.serial_consistency_level retry_policy = query.retry_policy 
or execution_profile.retry_policy row_factory = execution_profile.row_factory load_balancing_policy = execution_profile.load_balancing_policy spec_exec_policy = execution_profile.speculative_execution_policy fetch_size = query.fetch_size if fetch_size is FETCH_SIZE_UNSET and self._protocol_version >= 2: fetch_size = self.default_fetch_size elif self._protocol_version == 1: fetch_size = None start_time = time.time() if self._protocol_version >= 3 and self.use_client_timestamp: timestamp = self.cluster.timestamp_generator() else: timestamp = None if isinstance(query, SimpleStatement): query_string = query.query_string statement_keyspace = query.keyspace if ProtocolVersion.uses_keyspace_flag(self._protocol_version) else None if parameters: query_string = bind_params(query_string, parameters, self.encoder) message = QueryMessage( query_string, cl, serial_cl, fetch_size, timestamp=timestamp, keyspace=statement_keyspace) elif isinstance(query, BoundStatement): prepared_statement = query.prepared_statement message = ExecuteMessage( prepared_statement.query_id, query.values, cl, serial_cl, fetch_size, timestamp=timestamp, skip_meta=bool(prepared_statement.result_metadata), result_metadata_id=prepared_statement.result_metadata_id) elif isinstance(query, BatchStatement): if self._protocol_version < 2: raise UnsupportedOperation( "BatchStatement execution is only supported with protocol version " "2 or higher (supported in Cassandra 2.0 and higher). Consider " "setting Cluster.protocol_version to 2 to support this operation.") statement_keyspace = query.keyspace if ProtocolVersion.uses_keyspace_flag(self._protocol_version) else None message = BatchMessage( query.batch_type, query._statements_and_parameters, cl, serial_cl, timestamp, statement_keyspace) message.tracing = trace message.update_custom_payload(query.custom_payload) message.update_custom_payload(custom_payload) message.allow_beta_protocol_version = self.cluster.allow_beta_protocol_version message.paging_state = paging_state spec_exec_plan = spec_exec_policy.new_plan(query.keyspace or self.keyspace, query) if query.is_idempotent and spec_exec_policy else None return ResponseFuture( self, message, query, timeout, metrics=self._metrics, prepared_statement=prepared_statement, retry_policy=retry_policy, row_factory=row_factory, - load_balancer=load_balancing_policy, start_time=start_time, speculative_execution_plan=spec_exec_plan) + load_balancer=load_balancing_policy, start_time=start_time, speculative_execution_plan=spec_exec_plan, + host=host) - def _get_execution_profile(self, ep): + def get_execution_profile(self, name): + """ + Returns the execution profile associated with the provided ``name``. + + :param name: The name (or key) of the execution profile. + """ profiles = self.cluster.profile_manager.profiles try: - return ep if isinstance(ep, ExecutionProfile) else profiles[ep] + return profiles[name] except KeyError: - raise ValueError("Invalid execution_profile: '%s'; valid profiles are %s" % (ep, profiles.keys())) + eps = [_execution_profile_to_string(ep) for ep in profiles.keys()] + raise ValueError("Invalid execution_profile: %s; valid profiles are: %s." % ( + _execution_profile_to_string(name), ', '.join(eps))) + + def _maybe_get_execution_profile(self, ep): + return ep if isinstance(ep, ExecutionProfile) else self.get_execution_profile(ep) def execution_profile_clone_update(self, ep, **kwargs): """ Returns a clone of the ``ep`` profile. ``kwargs`` can be specified to update attributes of the returned profile. 
This is a shallow clone, so any objects referenced by the profile are shared. This means Load Balancing Policy is maintained by inclusion in the active profiles. It also means updating any other rich objects will be seen by the active profile. In cases where this is not desirable, be sure to replace the instance instead of manipulating the shared object. """ - clone = copy(self._get_execution_profile(ep)) + clone = copy(self._maybe_get_execution_profile(ep)) for attr, value in kwargs.items(): setattr(clone, attr, value) return clone def add_request_init_listener(self, fn, *args, **kwargs): """ Adds a callback with arguments to be called when any request is created. It will be invoked as `fn(response_future, *args, **kwargs)` after each client request is created, - and before the request is sent\*. This can be used to create extensions by adding result callbacks to the + and before the request is sent. This can be used to create extensions by adding result callbacks to the response future. - \* where `response_future` is the :class:`.ResponseFuture` for the request. + `response_future` is the :class:`.ResponseFuture` for the request. Note that the init callback is done on the client thread creating the request, so you may need to consider synchronization if you have multiple threads. Any callbacks added to the response future will be executed on the event loop thread, so the normal advice about minimizing cycles and avoiding blocking apply (see Note in :meth:`.ResponseFuture.add_callbacks`. See `this example `_ in the source tree for an example. """ self._request_init_callbacks.append((fn, args, kwargs)) def remove_request_init_listener(self, fn, *args, **kwargs): """ Removes a callback and arguments from the list. See :meth:`.Session.add_request_init_listener`. """ self._request_init_callbacks.remove((fn, args, kwargs)) def _on_request(self, response_future): for fn, args, kwargs in self._request_init_callbacks: fn(response_future, *args, **kwargs) def prepare(self, query, custom_payload=None, keyspace=None): """ Prepares a query string, returning a :class:`~cassandra.query.PreparedStatement` instance which can be used as follows:: >>> session = cluster.connect("mykeyspace") >>> query = "INSERT INTO users (id, name, age) VALUES (?, ?, ?)" >>> prepared = session.prepare(query) >>> session.execute(prepared, (user.id, user.name, user.age)) Or you may bind values to the prepared statement ahead of time:: >>> prepared = session.prepare(query) >>> bound_stmt = prepared.bind((user.id, user.name, user.age)) >>> session.execute(bound_stmt) Of course, prepared statements may (and should) be reused:: >>> prepared = session.prepare(query) >>> for user in users: ... bound = prepared.bind((user.id, user.name, user.age)) ... session.execute(bound) Alternatively, if :attr:`~.Cluster.protocol_version` is 5 or higher (requires Cassandra 4.0+), the keyspace can be specified as a parameter. This will allow you to avoid specifying the keyspace in the query without specifying a keyspace in :meth:`~.Cluster.connect`. It even will let you prepare and use statements against a keyspace other than the one originally specified on connection: >>> analyticskeyspace_prepared = session.prepare( ... "INSERT INTO user_activity id, last_activity VALUES (?, ?)", ... keyspace="analyticskeyspace") # note the different keyspace **Important**: PreparedStatements should be prepared only once. Preparing the same query more than once will likely affect performance. 
`custom_payload` is a key value map to be passed along with the prepare message. See :ref:`custom_payload`. """ message = PrepareMessage(query=query, keyspace=keyspace) future = ResponseFuture(self, message, query=None, timeout=self.default_timeout) try: future.send_request() query_id, bind_metadata, pk_indexes, result_metadata, result_metadata_id = future.result() except Exception: log.exception("Error preparing query:") raise prepared_keyspace = keyspace if keyspace else None prepared_statement = PreparedStatement.from_message( query_id, bind_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace, self._protocol_version, result_metadata, result_metadata_id) prepared_statement.custom_payload = future.custom_payload self.cluster.add_prepared(query_id, prepared_statement) if self.cluster.prepare_on_all_hosts: host = future._current_host try: self.prepare_on_all_hosts(prepared_statement.query_string, host, prepared_keyspace) except Exception: log.exception("Error preparing query on all hosts:") return prepared_statement def prepare_on_all_hosts(self, query, excluded_host, keyspace=None): """ Prepare the given query on all hosts, excluding ``excluded_host``. Intended for internal use only. """ futures = [] for host in tuple(self._pools.keys()): if host != excluded_host and host.is_up: future = ResponseFuture(self, PrepareMessage(query=query, keyspace=keyspace), None, self.default_timeout) # we don't care about errors preparing against specific hosts, # since we can always prepare them as needed when the prepared # statement is used. Just log errors and continue on. try: request_id = future._query(host) except Exception: log.exception("Error preparing query for host %s:", host) continue if request_id is None: # the error has already been logged by ResponsFuture log.debug("Failed to prepare query for host %s: %r", host, future._errors.get(host)) continue futures.append((host, future)) for host, future in futures: try: future.result() except Exception: log.exception("Error preparing query for host %s:", host) def shutdown(self): """ Close all connections. ``Session`` instances should not be used for any purpose after being shutdown. """ with self._lock: if self.is_shutdown: return else: self.is_shutdown = True # PYTHON-673. If shutdown was called shortly after session init, avoid # a race by cancelling any initial connection attempts haven't started, # then blocking on any that have. for future in self._initial_connect_futures: future.cancel() wait_futures(self._initial_connect_futures) for pool in tuple(self._pools.values()): pool.shutdown() def __enter__(self): return self def __exit__(self, *args): self.shutdown() def __del__(self): try: # Ensure all connections are closed, in case the Session object is deleted by the GC self.shutdown() except: # Ignore all errors. Shutdown errors can be caught by the user # when cluster.shutdown() is called explicitly. pass def add_or_renew_pool(self, host, is_host_addition): """ For internal use only. """ distance = self._profile_manager.distance(host) if distance == HostDistance.IGNORED: return None def run_add_or_renew_pool(): try: if self._protocol_version >= 3: new_pool = HostConnection(host, distance, self) else: + # TODO remove host pool again ??? 
new_pool = HostConnectionPool(host, distance, self) except AuthenticationFailed as auth_exc: - conn_exc = ConnectionException(str(auth_exc), host=host) + conn_exc = ConnectionException(str(auth_exc), endpoint=host) self.cluster.signal_connection_failure(host, conn_exc, is_host_addition) return False except Exception as conn_exc: log.warning("Failed to create connection pool for new host %s:", host, exc_info=conn_exc) # the host itself will still be marked down, so we need to pass # a special flag to make sure the reconnector is created self.cluster.signal_connection_failure( host, conn_exc, is_host_addition, expect_host_to_be_down=True) return False previous = self._pools.get(host) with self._lock: while new_pool._keyspace != self.keyspace: self._lock.release() set_keyspace_event = Event() errors_returned = [] def callback(pool, errors): errors_returned.extend(errors) set_keyspace_event.set() new_pool._set_keyspace_for_all_conns(self.keyspace, callback) set_keyspace_event.wait(self.cluster.connect_timeout) if not set_keyspace_event.is_set() or errors_returned: log.warning("Failed setting keyspace for pool after keyspace changed during connect: %s", errors_returned) self.cluster.on_down(host, is_host_addition) new_pool.shutdown() self._lock.acquire() return False self._lock.acquire() self._pools[host] = new_pool log.debug("Added pool for host %s to session", host) if previous: previous.shutdown() return True return self.submit(run_add_or_renew_pool) def remove_pool(self, host): pool = self._pools.pop(host, None) if pool: log.debug("Removed connection pool for %r", host) return self.submit(pool.shutdown) else: return None def update_created_pools(self): """ When the set of live nodes change, the loadbalancer will change its mind on host distances. It might change it on the node that came/left but also on other nodes (for instance, if a node dies, another previously ignored node may be now considered). This method ensures that all hosts for which a pool should exist have one, and hosts that shouldn't don't. For internal use only. """ futures = set() for host in self.cluster.metadata.all_hosts(): distance = self._profile_manager.distance(host) pool = self._pools.get(host) future = None if not pool or pool.is_shutdown: # we don't eagerly set is_up on previously ignored hosts. None is included here # to allow us to attempt connections to hosts that have gone from ignored to something # else. if distance != HostDistance.IGNORED and host.is_up in (True, None): future = self.add_or_renew_pool(host, False) elif distance != pool.host_distance: # the distance has changed if distance == HostDistance.IGNORED: future = self.remove_pool(host) else: pool.host_distance = distance if future: futures.add(future) return futures def on_down(self, host): """ Called by the parent Cluster instance when a node is marked down. Only intended for internal use. """ future = self.remove_pool(host) if future: future.add_done_callback(lambda f: self.update_created_pools()) def on_remove(self, host): """ Internal """ self.on_down(host) def set_keyspace(self, keyspace): """ Set the default keyspace for all queries made through this Session. This operation blocks until complete. """ self.execute('USE %s' % (protect_name(keyspace),)) def _set_keyspace_for_all_pools(self, keyspace, callback): """ Asynchronously sets the keyspace on all pools. When all pools have set all of their connections, `callback` will be called with a dictionary of all errors that occurred, keyed by the `Host` that they occurred against. 
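The blocking ``set_keyspace`` above and the asynchronous ``_set_keyspace_for_all_pools`` back the usual ways of selecting a working keyspace; the keyspace name below is illustrative::

    session = cluster.connect("mykeyspace")   # set at connect time
    session.set_keyspace("mykeyspace")        # blocking USE across all pooled connections
    session.execute("USE mykeyspace")         # same effect, issued explicitly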
""" with self._lock: self.keyspace = keyspace remaining_callbacks = set(self._pools.values()) errors = {} if not remaining_callbacks: callback(errors) return def pool_finished_setting_keyspace(pool, host_errors): remaining_callbacks.remove(pool) if host_errors: errors[pool.host] = host_errors if not remaining_callbacks: callback(host_errors) for pool in tuple(self._pools.values()): pool._set_keyspace_for_all_conns(keyspace, pool_finished_setting_keyspace) def user_type_registered(self, keyspace, user_type, klass): """ Called by the parent Cluster instance when the user registers a new mapping from a user-defined type to a class. Intended for internal use only. """ try: ks_meta = self.cluster.metadata.keyspaces[keyspace] except KeyError: raise UserTypeDoesNotExist( 'Keyspace %s does not exist or has not been discovered by the driver' % (keyspace,)) try: type_meta = ks_meta.user_types[user_type] except KeyError: raise UserTypeDoesNotExist( 'User type %s does not exist in keyspace %s' % (user_type, keyspace)) field_names = type_meta.field_names if six.PY2: # go from unicode to string to avoid decode errors from implicit # decode when formatting non-ascii values field_names = [fn.encode('utf-8') for fn in field_names] def encode(val): return '{ %s }' % ' , '.join('%s : %s' % ( field_name, self.encoder.cql_encode_all_types(getattr(val, field_name, None)) ) for field_name in field_names) self.encoder.mapping[klass] = encode def submit(self, fn, *args, **kwargs): """ Internal """ if not self.is_shutdown: return self.cluster.executor.submit(fn, *args, **kwargs) def get_pool_state(self): return dict((host, pool.get_state()) for host, pool in tuple(self._pools.items())) def get_pools(self): return self._pools.values() def _validate_set_legacy_config(self, attr_name, value): if self.cluster._config_mode == _ConfigMode.PROFILES: raise ValueError("Cannot set Session.%s while using Configuration Profiles. Set this in a profile instead." % (attr_name,)) setattr(self, '_' + attr_name, value) self.cluster._config_mode = _ConfigMode.LEGACY class UserTypeDoesNotExist(Exception): """ An attempt was made to use a user-defined type that does not exist. .. versionadded:: 2.1.0 """ pass class _ControlReconnectionHandler(_ReconnectionHandler): """ Internal """ def __init__(self, control_connection, *args, **kwargs): _ReconnectionHandler.__init__(self, *args, **kwargs) self.control_connection = weakref.proxy(control_connection) def try_reconnect(self): return self.control_connection._reconnect_internal() def on_reconnection(self, connection): self.control_connection._set_new_connection(connection) def on_exception(self, exc, next_delay): # TODO only overridden to add logging, so add logging if isinstance(exc, AuthenticationFailed): return False else: log.debug("Error trying to reconnect control connection: %r", exc) return True def _watch_callback(obj_weakref, method_name, *args, **kwargs): """ A callback handler for the ControlConnection that tolerates weak references. """ obj = obj_weakref() if obj is None: return getattr(obj, method_name)(*args, **kwargs) def _clear_watcher(conn, expiring_weakref): """ Called when the ControlConnection object is about to be finalized. This clears watchers on the underlying Connection object. 
""" try: conn.control_conn_disposed() except ReferenceError: pass class ControlConnection(object): """ Internal """ _SELECT_PEERS = "SELECT * FROM system.peers" - _SELECT_PEERS_NO_TOKENS = "SELECT peer, data_center, rack, rpc_address, release_version, schema_version FROM system.peers" + _SELECT_PEERS_NO_TOKENS = "SELECT host_id, peer, data_center, rack, rpc_address, release_version, schema_version FROM system.peers" _SELECT_LOCAL = "SELECT * FROM system.local WHERE key='local'" - _SELECT_LOCAL_NO_TOKENS = "SELECT cluster_name, data_center, rack, partitioner, release_version, schema_version FROM system.local WHERE key='local'" + _SELECT_LOCAL_NO_TOKENS = "SELECT host_id, cluster_name, data_center, rack, partitioner, release_version, schema_version FROM system.local WHERE key='local'" + # Used only when token_metadata_enabled is set to False + _SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS = "SELECT rpc_address FROM system.local WHERE key='local'" - _SELECT_SCHEMA_PEERS = "SELECT peer, rpc_address, schema_version FROM system.peers" + _SELECT_SCHEMA_PEERS = "SELECT peer, host_id, rpc_address, schema_version FROM system.peers" _SELECT_SCHEMA_LOCAL = "SELECT schema_version FROM system.local WHERE key='local'" _is_shutdown = False _timeout = None _protocol_version = None _schema_event_refresh_window = None _topology_event_refresh_window = None _status_event_refresh_window = None _schema_meta_enabled = True _token_meta_enabled = True # for testing purposes _time = time def __init__(self, cluster, timeout, schema_event_refresh_window, topology_event_refresh_window, status_event_refresh_window, schema_meta_enabled=True, token_meta_enabled=True): # use a weak reference to allow the Cluster instance to be GC'ed (and # shutdown) since implementing __del__ disables the cycle detector self._cluster = weakref.proxy(cluster) self._connection = None self._timeout = timeout self._schema_event_refresh_window = schema_event_refresh_window self._topology_event_refresh_window = topology_event_refresh_window self._status_event_refresh_window = status_event_refresh_window self._schema_meta_enabled = schema_meta_enabled self._token_meta_enabled = token_meta_enabled self._lock = RLock() self._schema_agreement_lock = Lock() self._reconnection_handler = None self._reconnection_lock = RLock() self._event_schedule_times = {} def connect(self): if self._is_shutdown: return self._protocol_version = self._cluster.protocol_version self._set_new_connection(self._reconnect_internal()) + self._cluster.metadata.dbaas = self._connection._product_type == dscloud.PRODUCT_APOLLO + def _set_new_connection(self, conn): """ Replace existing connection (if there is one) and close it. """ with self._lock: old = self._connection self._connection = conn if old: log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn) old.close() def _reconnect_internal(self): """ Tries to connect to each host in the query plan until one succeeds or every attempt fails. If successful, a new Connection will be returned. Otherwise, :exc:`NoHostAvailable` will be raised with an "errors" arg that is a dict mapping host addresses to the exception that was raised when an attempt was made to open a connection to that host. 
""" errors = {} lbp = ( self._cluster.load_balancing_policy if self._cluster._config_mode == _ConfigMode.LEGACY else self._cluster._default_load_balancing_policy ) for host in lbp.make_query_plan(): try: return self._try_connect(host) except ConnectionException as exc: - errors[host.address] = exc + errors[str(host.endpoint)] = exc log.warning("[control connection] Error connecting to %s:", host, exc_info=True) self._cluster.signal_connection_failure(host, exc, is_host_addition=False) except Exception as exc: - errors[host.address] = exc + errors[str(host.endpoint)] = exc log.warning("[control connection] Error connecting to %s:", host, exc_info=True) if self._is_shutdown: raise DriverException("[control connection] Reconnection in progress during shutdown") raise NoHostAvailable("Unable to connect to any servers", errors) def _try_connect(self, host): """ Creates a new Connection, registers for pushed events, and refreshes node/token and schema metadata. """ log.debug("[control connection] Opening new connection to %s", host) while True: try: - connection = self._cluster.connection_factory(host.address, is_control_connection=True) + connection = self._cluster.connection_factory(host.endpoint, is_control_connection=True) if self._is_shutdown: connection.close() raise DriverException("Reconnecting during shutdown") break except ProtocolVersionUnsupported as e: - self._cluster.protocol_downgrade(host.address, e.startup_version) + self._cluster.protocol_downgrade(host.endpoint, e.startup_version) log.debug("[control connection] Established new connection %r, " "registering watchers and refreshing schema and topology", connection) # use weak references in both directions # _clear_watcher will be called when this ControlConnection is about to be finalized # _watch_callback will get the actual callback from the Connection and relay it to # this object (after a dereferencing a weakref) self_weakref = weakref.ref(self, partial(_clear_watcher, weakref.proxy(connection))) try: connection.register_watchers({ "TOPOLOGY_CHANGE": partial(_watch_callback, self_weakref, '_handle_topology_change'), "STATUS_CHANGE": partial(_watch_callback, self_weakref, '_handle_status_change'), "SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change') }, register_timeout=self._timeout) sel_peers = self._SELECT_PEERS if self._token_meta_enabled else self._SELECT_PEERS_NO_TOKENS sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS peers_query = QueryMessage(query=sel_peers, consistency_level=ConsistencyLevel.ONE) local_query = QueryMessage(query=sel_local, consistency_level=ConsistencyLevel.ONE) shared_results = connection.wait_for_responses( peers_query, local_query, timeout=self._timeout) self._refresh_node_list_and_token_map(connection, preloaded_results=shared_results) self._refresh_schema(connection, preloaded_results=shared_results, schema_agreement_wait=-1) except Exception: connection.close() raise return connection def reconnect(self): if self._is_shutdown: return self._submit(self._reconnect) def _reconnect(self): log.debug("[control connection] Attempting to reconnect") try: self._set_new_connection(self._reconnect_internal()) except NoHostAvailable: # make a retry schedule (which includes backoff) schedule = self._cluster.reconnection_policy.new_schedule() with self._reconnection_lock: # cancel existing reconnection attempts if self._reconnection_handler: self._reconnection_handler.cancel() # when a connection is successfully made, 
_set_new_connection # will be called with the new connection and then our # _reconnection_handler will be cleared out self._reconnection_handler = _ControlReconnectionHandler( self, self._cluster.scheduler, schedule, self._get_and_set_reconnection_handler, new_handler=None) self._reconnection_handler.start() except Exception: log.debug("[control connection] error reconnecting", exc_info=True) raise def _get_and_set_reconnection_handler(self, new_handler): """ Called by the _ControlReconnectionHandler when a new connection is successfully created. Clears out the _reconnection_handler on this ControlConnection. """ with self._reconnection_lock: old = self._reconnection_handler self._reconnection_handler = new_handler return old def _submit(self, *args, **kwargs): try: if not self._cluster.is_shutdown: return self._cluster.executor.submit(*args, **kwargs) except ReferenceError: pass return None def shutdown(self): # stop trying to reconnect (if we are) with self._reconnection_lock: if self._reconnection_handler: self._reconnection_handler.cancel() with self._lock: if self._is_shutdown: return else: self._is_shutdown = True log.debug("Shutting down control connection") if self._connection: self._connection.close() self._connection = None def refresh_schema(self, force=False, **kwargs): try: if self._connection: return self._refresh_schema(self._connection, force=force, **kwargs) except ReferenceError: pass # our weak reference to the Cluster is no good except Exception: log.debug("[control connection] Error refreshing schema", exc_info=True) self._signal_error() return False def _refresh_schema(self, connection, preloaded_results=None, schema_agreement_wait=None, force=False, **kwargs): if self._cluster.is_shutdown: return False agreed = self.wait_for_schema_agreement(connection, preloaded_results=preloaded_results, wait_time=schema_agreement_wait) if not self._schema_meta_enabled and not force: log.debug("[control connection] Skipping schema refresh because schema metadata is disabled") return False if not agreed: log.debug("Skipping schema refresh due to lack of schema agreement") return False self._cluster.metadata.refresh(connection, self._timeout, **kwargs) return True def refresh_node_list_and_token_map(self, force_token_rebuild=False): try: if self._connection: self._refresh_node_list_and_token_map(self._connection, force_token_rebuild=force_token_rebuild) return True except ReferenceError: pass # our weak reference to the Cluster is no good except Exception: log.debug("[control connection] Error refreshing node list and token map", exc_info=True) self._signal_error() return False def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, force_token_rebuild=False): if preloaded_results: log.debug("[control connection] Refreshing node list and token map using preloaded results") peers_result = preloaded_results[0] local_result = preloaded_results[1] else: cl = ConsistencyLevel.ONE if not self._token_meta_enabled: log.debug("[control connection] Refreshing node list without token map") sel_peers = self._SELECT_PEERS_NO_TOKENS sel_local = self._SELECT_LOCAL_NO_TOKENS else: log.debug("[control connection] Refreshing node list and token map") sel_peers = self._SELECT_PEERS sel_local = self._SELECT_LOCAL peers_query = QueryMessage(query=sel_peers, consistency_level=cl) local_query = QueryMessage(query=sel_local, consistency_level=cl) peers_result, local_result = connection.wait_for_responses( peers_query, local_query, timeout=self._timeout) peers_result = 
dict_factory(*peers_result.results) partitioner = None token_map = {} found_hosts = set() if local_result.results: - found_hosts.add(connection.host) + found_hosts.add(connection.endpoint) local_rows = dict_factory(*(local_result.results)) local_row = local_rows[0] cluster_name = local_row["cluster_name"] self._cluster.metadata.cluster_name = cluster_name partitioner = local_row.get("partitioner") tokens = local_row.get("tokens") - host = self._cluster.metadata.get_host(connection.host) + host = self._cluster.metadata.get_host(connection.endpoint) if host: datacenter = local_row.get("data_center") rack = local_row.get("rack") self._update_location_info(host, datacenter, rack) + host.host_id = local_row.get("host_id") host.listen_address = local_row.get("listen_address") host.broadcast_address = local_row.get("broadcast_address") + + host.broadcast_rpc_address = self._address_from_row(local_row) + if host.broadcast_rpc_address is None: + if self._token_meta_enabled: + # local rpc_address is not available, use the connection endpoint + host.broadcast_rpc_address = connection.endpoint.address + else: + # local rpc_address has not been queried yet, try to fetch it + # separately, which might fail because C* < 2.1.6 doesn't have rpc_address + # in system.local. See CASSANDRA-9436. + local_rpc_address_query = QueryMessage(query=self._SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS, + consistency_level=ConsistencyLevel.ONE) + success, local_rpc_address_result = connection.wait_for_response( + local_rpc_address_query, timeout=self._timeout, fail_on_error=False) + if success: + row = dict_factory(*local_rpc_address_result.results) + host.broadcast_rpc_address = row[0]['rpc_address'] + else: + host.broadcast_rpc_address = connection.endpoint.address + host.release_version = local_row.get("release_version") host.dse_version = local_row.get("dse_version") host.dse_workload = local_row.get("workload") if partitioner and tokens: token_map[host] = tokens # Check metadata.partitioner to see if we haven't built anything yet. If # every node in the cluster was in the contact points, we won't discover # any new nodes, so we need this additional check. (See PYTHON-90) should_rebuild_token_map = force_token_rebuild or self._cluster.metadata.partitioner is None for row in peers_result: - addr = self._rpc_from_peer_row(row) + endpoint = self._cluster.endpoint_factory.create(row) tokens = row.get("tokens", None) if 'tokens' in row and not tokens: # it was selected, but empty - log.warning("Excluding host (%s) with no tokens in system.peers table of %s." % (addr, connection.host)) + log.warning("Excluding host (%s) with no tokens in system.peers table of %s." % (endpoint, connection.endpoint)) continue - if addr in found_hosts: - log.warning("Found multiple hosts with the same rpc_address (%s). Excluding peer %s", addr, row.get("peer")) + if endpoint in found_hosts: + log.warning("Found multiple hosts with the same endpoint (%s). 
Excluding peer %s", endpoint, row.get("peer")) continue - found_hosts.add(addr) + found_hosts.add(endpoint) - host = self._cluster.metadata.get_host(addr) + host = self._cluster.metadata.get_host(endpoint) datacenter = row.get("data_center") rack = row.get("rack") if host is None: - log.debug("[control connection] Found new host to connect to: %s", addr) - host, _ = self._cluster.add_host(addr, datacenter, rack, signal=True, refresh_nodes=False) + log.debug("[control connection] Found new host to connect to: %s", endpoint) + host, _ = self._cluster.add_host(endpoint, datacenter, rack, signal=True, refresh_nodes=False) should_rebuild_token_map = True else: should_rebuild_token_map |= self._update_location_info(host, datacenter, rack) + host.host_id = row.get("host_id") host.broadcast_address = row.get("peer") + host.broadcast_rpc_address = self._address_from_row(row) host.release_version = row.get("release_version") host.dse_version = row.get("dse_version") host.dse_workload = row.get("workload") if partitioner and tokens: token_map[host] = tokens for old_host in self._cluster.metadata.all_hosts(): - if old_host.address != connection.host and old_host.address not in found_hosts: + if old_host.endpoint.address != connection.endpoint and old_host.endpoint not in found_hosts: should_rebuild_token_map = True log.debug("[control connection] Removing host not found in peers metadata: %r", old_host) self._cluster.remove_host(old_host) log.debug("[control connection] Finished fetching ring info") if partitioner and should_rebuild_token_map: log.debug("[control connection] Rebuilding token map due to topology changes") self._cluster.metadata.rebuild_token_map(partitioner, token_map) def _update_location_info(self, host, datacenter, rack): if host.datacenter == datacenter and host.rack == rack: return False # If the dc/rack information changes, we need to update the load balancing policy. # For that, we remove and re-add the node against the policy. Not the most elegant, and assumes # that the policy will update correctly, but in practice this should work. self._cluster.profile_manager.on_down(host) host.set_location_info(datacenter, rack) self._cluster.profile_manager.on_up(host) return True def _delay_for_event_type(self, event_type, delay_window): # this serves to order processing correlated events (received within the window) # the window and randomization still have the desired effect of skew across client instances next_time = self._event_schedule_times.get(event_type, 0) now = self._time.time() if now <= next_time: this_time = next_time + 0.01 delay = this_time - now else: delay = random() * delay_window this_time = now + delay self._event_schedule_times[event_type] = this_time return delay - def _refresh_nodes_if_not_up(self, addr): + def _refresh_nodes_if_not_up(self, host): """ Used to mitigate refreshes for nodes that are already known. Some versions of the server send superfluous NEW_NODE messages in addition to UP events. 
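The refresh above is what populates ``cluster.metadata``, including the newly stored ``host_id`` and ``broadcast_rpc_address``; a quick way to inspect what the control connection discovered, using attribute names per this version of the driver::

    for host in cluster.metadata.all_hosts():
        print(host.endpoint, host.host_id, host.datacenter, host.rack,
              host.release_version, host.is_up)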
""" - host = self._cluster.metadata.get_host(addr) if not host or not host.is_up: self.refresh_node_list_and_token_map() def _handle_topology_change(self, event): change_type = event["change_type"] - addr = self._translate_address(event["address"][0]) + host = self._cluster.metadata.get_host(event["address"][0]) if change_type == "NEW_NODE" or change_type == "MOVED_NODE": if self._topology_event_refresh_window >= 0: delay = self._delay_for_event_type('topology_change', self._topology_event_refresh_window) - self._cluster.scheduler.schedule_unique(delay, self._refresh_nodes_if_not_up, addr) + self._cluster.scheduler.schedule_unique(delay, self._refresh_nodes_if_not_up, host) elif change_type == "REMOVED_NODE": - host = self._cluster.metadata.get_host(addr) self._cluster.scheduler.schedule_unique(0, self._cluster.remove_host, host) def _handle_status_change(self, event): change_type = event["change_type"] - addr = self._translate_address(event["address"][0]) - host = self._cluster.metadata.get_host(addr) + host = self._cluster.metadata.get_host(event["address"][0]) if change_type == "UP": delay = self._delay_for_event_type('status_change', self._status_event_refresh_window) if host is None: # this is the first time we've seen the node self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map) else: self._cluster.scheduler.schedule_unique(delay, self._cluster.on_up, host) elif change_type == "DOWN": # Note that there is a slight risk we can receive the event late and thus # mark the host down even though we already had reconnected successfully. # But it is unlikely, and don't have too much consequence since we'll try reconnecting # right away, so we favor the detection to make the Host.is_up more accurate. if host is not None: # this will be run by the scheduler self._cluster.on_down(host, is_host_addition=False) - def _translate_address(self, addr): - return self._cluster.address_translator.translate(addr) - def _handle_schema_change(self, event): if self._schema_event_refresh_window < 0: return delay = self._delay_for_event_type('schema_change', self._schema_event_refresh_window) self._cluster.scheduler.schedule_unique(delay, self.refresh_schema, **event) def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None): total_timeout = wait_time if wait_time is not None else self._cluster.max_schema_agreement_wait if total_timeout <= 0: return True # Each schema change typically generates two schema refreshes, one # from the response type and one from the pushed notification. Holding # a lock is just a simple way to cut down on the number of schema queries # we'll make. 
with self._schema_agreement_lock: if self._is_shutdown: return if not connection: connection = self._connection if preloaded_results: log.debug("[control connection] Attempting to use preloaded results for schema agreement") peers_result = preloaded_results[0] local_result = preloaded_results[1] - schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.host) + schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.endpoint) if schema_mismatches is None: return True log.debug("[control connection] Waiting for schema agreement") start = self._time.time() elapsed = 0 cl = ConsistencyLevel.ONE schema_mismatches = None while elapsed < total_timeout: peers_query = QueryMessage(query=self._SELECT_SCHEMA_PEERS, consistency_level=cl) local_query = QueryMessage(query=self._SELECT_SCHEMA_LOCAL, consistency_level=cl) try: timeout = min(self._timeout, total_timeout - elapsed) peers_result, local_result = connection.wait_for_responses( peers_query, local_query, timeout=timeout) except OperationTimedOut as timeout: log.debug("[control connection] Timed out waiting for " "response during schema agreement check: %s", timeout) elapsed = self._time.time() - start continue except ConnectionShutdown: if self._is_shutdown: log.debug("[control connection] Aborting wait for schema match due to shutdown") return None else: raise - schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.host) + schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.endpoint) if schema_mismatches is None: return True log.debug("[control connection] Schemas mismatched, trying again") self._time.sleep(0.2) elapsed = self._time.time() - start log.warning("Node %s is reporting a schema disagreement: %s", - connection.host, schema_mismatches) + connection.endpoint, schema_mismatches) return False def _get_schema_mismatches(self, peers_result, local_result, local_address): peers_result = dict_factory(*peers_result.results) versions = defaultdict(set) if local_result.results: local_row = dict_factory(*local_result.results)[0] if local_row.get("schema_version"): versions[local_row.get("schema_version")].add(local_address) for row in peers_result: schema_ver = row.get('schema_version') if not schema_ver: continue - addr = self._rpc_from_peer_row(row) - peer = self._cluster.metadata.get_host(addr) + endpoint = self._cluster.endpoint_factory.create(row) + peer = self._cluster.metadata.get_host(endpoint) if peer and peer.is_up is not False: - versions[schema_ver].add(addr) + versions[schema_ver].add(endpoint) if len(versions) == 1: log.debug("[control connection] Schemas match") return None return dict((version, list(nodes)) for version, nodes in six.iteritems(versions)) - def _rpc_from_peer_row(self, row): - addr = row.get("rpc_address") + def _address_from_row(self, row): + """ + Parse the broadcast rpc address from a row and return it untranslated. 
+ """ + addr = None + if "rpc_address" in row: + addr = row.get("rpc_address") # peers and local + if "native_transport_address" in row: + addr = row.get("native_transport_address") if not addr or addr in ["0.0.0.0", "::"]: addr = row.get("peer") - return self._translate_address(addr) + return addr + def _signal_error(self): with self._lock: if self._is_shutdown: return # try just signaling the cluster, as this will trigger a reconnect # as part of marking the host down if self._connection and self._connection.is_defunct: - host = self._cluster.metadata.get_host(self._connection.host) + host = self._cluster.metadata.get_host(self._connection.endpoint) # host may be None if it's already been removed, but that indicates # that errors have already been reported, so we're fine if host: self._cluster.signal_connection_failure( host, self._connection.last_error, is_host_addition=False) return # if the connection is not defunct or the host already left, reconnect # manually self.reconnect() def on_up(self, host): pass def on_down(self, host): conn = self._connection - if conn and conn.host == host.address and \ + if conn and conn.endpoint == host.endpoint and \ self._reconnection_handler is None: log.debug("[control connection] Control connection host (%s) is " "considered down, starting reconnection", host) # this will result in a task being submitted to the executor to reconnect self.reconnect() def on_add(self, host, refresh_nodes=True): if refresh_nodes: self.refresh_node_list_and_token_map(force_token_rebuild=True) def on_remove(self, host): c = self._connection - if c and c.host == host.address: + if c and c.endpoint == host.endpoint: log.debug("[control connection] Control connection host (%s) is being removed. Reconnecting", host) # refresh will be done on reconnect self.reconnect() else: self.refresh_node_list_and_token_map(force_token_rebuild=True) def get_connections(self): c = getattr(self, '_connection', None) return [c] if c else [] def return_connection(self, connection): if connection is self._connection and (connection.is_defunct or connection.is_closed): self.reconnect() def _stop_scheduler(scheduler, thread): try: if not scheduler.is_shutdown: scheduler.shutdown() except ReferenceError: pass thread.join() class _Scheduler(Thread): _queue = None _scheduled_tasks = None _executor = None is_shutdown = False def __init__(self, executor): self._queue = Queue.PriorityQueue() self._scheduled_tasks = set() self._count = count() self._executor = executor Thread.__init__(self, name="Task Scheduler") self.daemon = True self.start() def shutdown(self): try: log.debug("Shutting down Cluster Scheduler") except AttributeError: # this can happen on interpreter shutdown pass self.is_shutdown = True self._queue.put_nowait((0, 0, None)) self.join() def schedule(self, delay, fn, *args, **kwargs): self._insert_task(delay, (fn, args, tuple(kwargs.items()))) def schedule_unique(self, delay, fn, *args, **kwargs): task = (fn, args, tuple(kwargs.items())) if task not in self._scheduled_tasks: self._insert_task(delay, task) else: log.debug("Ignoring schedule_unique for already-scheduled task: %r", task) def _insert_task(self, delay, task): if not self.is_shutdown: run_at = time.time() + delay self._scheduled_tasks.add(task) self._queue.put_nowait((run_at, next(self._count), task)) else: log.debug("Ignoring scheduled task after shutdown: %r", task) def run(self): while True: if self.is_shutdown: return try: while True: run_at, i, task = self._queue.get(block=True, timeout=None) if self.is_shutdown: if 
task: log.debug("Not executing scheduled task due to Scheduler shutdown") return if run_at <= time.time(): self._scheduled_tasks.discard(task) fn, args, kwargs = task kwargs = dict(kwargs) future = self._executor.submit(fn, *args, **kwargs) future.add_done_callback(self._log_if_failed) else: self._queue.put_nowait((run_at, i, task)) break except Queue.Empty: pass time.sleep(0.1) def _log_if_failed(self, future): exc = future.exception() if exc: log.warning( "An internally scheduled tasked failed with an unhandled exception:", exc_info=exc) def refresh_schema_and_set_result(control_conn, response_future, connection, **kwargs): try: log.debug("Refreshing schema in response to schema change. " "%s", kwargs) response_future.is_schema_agreed = control_conn._refresh_schema(connection, **kwargs) except Exception: log.exception("Exception refreshing schema in response to schema change:") response_future.session.submit(control_conn.refresh_schema, **kwargs) finally: response_future._set_final_result(None) class ResponseFuture(object): """ An asynchronous response delivery mechanism that is returned from calls to :meth:`.Session.execute_async()`. There are two ways for results to be delivered: - Synchronously, by calling :meth:`.result()` - Asynchronously, by attaching callback and errback functions via :meth:`.add_callback()`, :meth:`.add_errback()`, and :meth:`.add_callbacks()`. """ query = None """ The :class:`~.Statement` instance that is being executed through this :class:`.ResponseFuture`. """ is_schema_agreed = True """ For DDL requests, this may be set ``False`` if the schema agreement poll after the response fails. Always ``True`` for non-DDL requests. """ request_encoded_size = None """ Size of the request message sent """ coordinator_host = None """ The host from which we recieved a response """ attempted_hosts = None """ A list of hosts tried, including all speculative executions, retries, and pages """ session = None row_factory = None message = None default_timeout = None _retry_policy = None _profile_manager = None _req_id = None _final_result = _NOT_SET _col_names = None _col_types = None _final_exception = None _query_traces = None _callbacks = None _errbacks = None _current_host = None _connection = None _query_retries = 0 _start_time = None _metrics = None _paging_state = None _custom_payload = None _warnings = None _timer = None _protocol_handler = ProtocolHandler _spec_execution_plan = NoSpeculativeExecutionPlan() + _host = None _warned_timeout = False def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None, - retry_policy=RetryPolicy(), row_factory=None, load_balancer=None, start_time=None, speculative_execution_plan=None): + retry_policy=RetryPolicy(), row_factory=None, load_balancer=None, start_time=None, + speculative_execution_plan=None, host=None): self.session = session # TODO: normalize handling of retry policy and row factory self.row_factory = row_factory or session.row_factory self._load_balancer = load_balancer or session.cluster._default_load_balancing_policy self.message = message self.query = query self.timeout = timeout self._retry_policy = retry_policy self._metrics = metrics self.prepared_statement = prepared_statement self._callback_lock = Lock() self._start_time = start_time or time.time() + self._host = host self._spec_execution_plan = speculative_execution_plan or self._spec_execution_plan self._make_query_plan() self._event = Event() self._errors = {} self._callbacks = [] self._errbacks = [] self.attempted_hosts = [] 
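``_spec_execution_plan`` above comes from the profile's speculative execution policy. A sketch of enabling it; the delay and attempt count are illustrative, and only idempotent statements are eligible::

    from cassandra.cluster import ExecutionProfile
    from cassandra.policies import ConstantSpeculativeExecutionPolicy
    from cassandra.query import SimpleStatement

    profile = ExecutionProfile(
        speculative_execution_policy=ConstantSpeculativeExecutionPolicy(delay=0.1, max_attempts=2))
    cluster.add_execution_profile('fast-reads', profile)

    stmt = SimpleStatement("SELECT * FROM system.local", is_idempotent=True)
    session.execute(stmt, execution_profile='fast-reads')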
self._start_timer() @property def _time_remaining(self): if self.timeout is None: return None return (self._start_time + self.timeout) - time.time() def _start_timer(self): if self._timer is None: spec_delay = self._spec_execution_plan.next_execution(self._current_host) if spec_delay >= 0: if self._time_remaining is None or self._time_remaining > spec_delay: self._timer = self.session.cluster.connection_class.create_timer(spec_delay, self._on_speculative_execute) return if self._time_remaining is not None: self._timer = self.session.cluster.connection_class.create_timer(self._time_remaining, self._on_timeout) def _cancel_timer(self): if self._timer: self._timer.cancel() def _on_timeout(self, _attempts=0): """ Called when the request associated with this ResponseFuture times out. This function may reschedule itself. The ``_attempts`` parameter tracks the number of times this has happened. This parameter should only be set in those cases, where ``_on_timeout`` reschedules itself. """ # PYTHON-853: for short timeouts, we sometimes race with our __init__ if self._connection is None and _attempts < 3: self._timer = self.session.cluster.connection_class.create_timer( 0.01, partial(self._on_timeout, _attempts=_attempts + 1) ) return if self._connection is not None: try: self._connection._requests.pop(self._req_id) - # This prevents the race condition of the - # event loop thread just receiving the waited message - # If it arrives after this, it will be ignored + # PYTHON-1044 + # This request might have been removed from the connection after the latter was defunct by heartbeat. + # We should still raise OperationTimedOut to reject the future so that the main event thread will not + # wait for it endlessly except KeyError: + key = "Connection defunct by heartbeat" + errors = {key: "Client request timeout. See Session.execute[_async](timeout)"} + self._set_final_exception(OperationTimedOut(errors, self._current_host)) return pool = self.session._pools.get(self._current_host) if pool and not pool.is_shutdown: with self._connection.lock: self._connection.request_ids.append(self._req_id) pool.return_connection(self._connection) errors = self._errors if not errors: if self.is_schema_agreed: - key = self._current_host.address if self._current_host else 'no host queried before timeout' + key = str(self._current_host.endpoint) if self._current_host else 'no host queried before timeout' errors = {key: "Client request timeout. See Session.execute[_async](timeout)"} else: connection = self.session.cluster.control_connection._connection - host = connection.host if connection else 'unknown' + host = str(connection.endpoint) if connection else 'unknown' errors = {host: "Request timed out while waiting for schema agreement. See Session.execute[_async](timeout) and Cluster.max_schema_agreement_wait."} self._set_final_exception(OperationTimedOut(errors, self._current_host)) def _on_speculative_execute(self): self._timer = None if not self._event.is_set(): # PYTHON-836, the speculative queries must be after # the query is sent from the main thread, otherwise the # query from the main thread may raise NoHostAvailable # if the _query_plan has been exhausted by the specualtive queries. # This also prevents a race condition accessing the iterator. 
# We reschedule this call until the main thread has succeeded # making a query if not self.attempted_hosts: self._timer = self.session.cluster.connection_class.create_timer(0.01, self._on_speculative_execute) return if self._time_remaining is not None: if self._time_remaining <= 0: self._on_timeout() return self.send_request(error_no_hosts=False) self._start_timer() - def _make_query_plan(self): - # convert the list/generator/etc to an iterator so that subsequent - # calls to send_request (which retries may do) will resume where - # they last left off - self.query_plan = iter(self._load_balancer.make_query_plan(self.session.keyspace, self.query)) + # set the query_plan according to the load balancing policy, + # or to the explicit host target if set + if self._host: + # returning a single value effectively disables retries + self.query_plan = [self._host] + else: + # convert the list/generator/etc to an iterator so that subsequent + # calls to send_request (which retries may do) will resume where + # they last left off + self.query_plan = iter(self._load_balancer.make_query_plan(self.session.keyspace, self.query)) def send_request(self, error_no_hosts=True): """ Internal """ # query_plan is an iterator, so this will resume where we last left # off if send_request() is called multiple times for host in self.query_plan: req_id = self._query(host) if req_id is not None: self._req_id = req_id return True if self.timeout is not None and time.time() - self._start_time > self.timeout: self._on_timeout() return True if error_no_hosts: self._set_final_exception(NoHostAvailable( "Unable to complete the operation against any hosts", self._errors)) return False def _query(self, host, message=None, cb=None): if message is None: message = self.message pool = self.session._pools.get(host) if not pool: self._errors[host] = ConnectionException("Host has been marked down or removed") return None elif pool.is_shutdown: self._errors[host] = ConnectionException("Pool is shutdown") return None self._current_host = host connection = None try: # TODO get connectTimeout from cluster settings connection, request_id = pool.borrow_connection(timeout=2.0) self._connection = connection result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] if cb is None: cb = partial(self._set_result, host, connection, pool) self.request_encoded_size = connection.send_msg(message, request_id, cb=cb, encoder=self._protocol_handler.encode_message, decoder=self._protocol_handler.decode_message, result_metadata=result_meta) self.attempted_hosts.append(host) return request_id except NoConnectionsAvailable as exc: log.debug("All connections for host %s are at capacity, moving to the next host", host) self._errors[host] = exc return None except Exception as exc: log.debug("Error querying host %s", host, exc_info=True) self._errors[host] = exc if self._metrics is not None: self._metrics.on_connection_error() if connection: pool.return_connection(connection) return None @property def has_more_pages(self): """ Returns :const:`True` if there are more pages left in the query results, :const:`False` otherwise. This should only be checked after the first page has been returned. .. versionadded:: 2.0.0 """ return self._paging_state is not None @property def warnings(self): """ Warnings returned from the server, if any. This will only be set for protocol_version 4+. Warnings may be returned for such things as oversized batches, or too many tombstones in slice queries. 
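``_make_query_plan`` above short-circuits the load balancer when an explicit ``host`` is supplied. Assuming the corresponding ``host`` keyword that this change supports on ``Session.execute``/``execute_async``, targeting a single node looks roughly like this (the address is illustrative)::

    target = cluster.metadata.get_host('10.0.0.1')
    if target is not None and target.is_up:
        # the single-element query plan disables retries on other hosts
        rows = session.execute("SELECT * FROM system.local", host=target)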
Ensure the future is complete before trying to access this property (call :meth:`.result()`, or after callback is invoked). Otherwise it may throw if the response has not been received. """ # TODO: When timers are introduced, just make this wait if not self._event.is_set(): raise DriverException("warnings cannot be retrieved before ResponseFuture is finalized") return self._warnings @property def custom_payload(self): """ The custom payload returned from the server, if any. This will only be set by Cassandra servers implementing a custom QueryHandler, and only for protocol_version 4+. Ensure the future is complete before trying to access this property (call :meth:`.result()`, or after callback is invoked). Otherwise it may throw if the response has not been received. :return: :ref:`custom_payload`. """ # TODO: When timers are introduced, just make this wait if not self._event.is_set(): raise DriverException("custom_payload cannot be retrieved before ResponseFuture is finalized") return self._custom_payload def start_fetching_next_page(self): """ If there are more pages left in the query result, this asynchronously starts fetching the next page. If there are no pages left, :exc:`.QueryExhausted` is raised. Also see :attr:`.has_more_pages`. This should only be called after the first page has been returned. .. versionadded:: 2.0.0 """ if not self._paging_state: raise QueryExhausted() self._make_query_plan() self.message.paging_state = self._paging_state self._event.clear() self._final_result = _NOT_SET self._final_exception = None self._start_timer() self.send_request() def _reprepare(self, prepare_message, host, connection, pool): cb = partial(self.session.submit, self._execute_after_prepare, host, connection, pool) request_id = self._query(host, prepare_message, cb=cb) if request_id is None: # try to submit the original prepared statement on some other host self.send_request() def _set_result(self, host, connection, pool, response): try: self.coordinator_host = host if pool: pool.return_connection(connection) trace_id = getattr(response, 'trace_id', None) if trace_id: if not self._query_traces: self._query_traces = [] self._query_traces.append(QueryTrace(trace_id, self.session)) self._warnings = getattr(response, 'warnings', None) self._custom_payload = getattr(response, 'custom_payload', None) if isinstance(response, ResultMessage): if response.kind == RESULT_KIND_SET_KEYSPACE: session = getattr(self, 'session', None) # since we're running on the event loop thread, we need to # use a non-blocking method for setting the keyspace on # all connections in this session, otherwise the event # loop thread will deadlock waiting for keyspaces to be # set. This uses a callback chain which ends with # self._set_keyspace_completed() being called in the # event loop thread. 
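``has_more_pages`` and ``start_fetching_next_page`` enable fully asynchronous paging; a sketch of the usual handler pattern, where the keyspace/table and fetch size are illustrative::

    from threading import Event

    from cassandra.query import SimpleStatement

    class PagedHandler(object):
        def __init__(self, future):
            self.future = future
            self.finished = Event()
            self.error = None
            future.add_callbacks(self.on_page, self.on_error)

        def on_page(self, rows):
            for row in rows:
                pass                       # process each row of the current page
            if self.future.has_more_pages:
                self.future.start_fetching_next_page()
            else:
                self.finished.set()

        def on_error(self, exc):
            self.error = exc
            self.finished.set()

    query = SimpleStatement("SELECT * FROM mykeyspace.users", fetch_size=100)
    handler = PagedHandler(session.execute_async(query))
    handler.finished.wait()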
if session: session._set_keyspace_for_all_pools( response.results, self._set_keyspace_completed) elif response.kind == RESULT_KIND_SCHEMA_CHANGE: # refresh the schema before responding, but do it in another # thread instead of the event loop thread self.is_schema_agreed = False self.session.submit( refresh_schema_and_set_result, self.session.cluster.control_connection, self, connection, **response.results) else: results = getattr(response, 'results', None) if results is not None and response.kind == RESULT_KIND_ROWS: self._paging_state = response.paging_state self._col_types = response.col_types self._col_names = results[0] results = self.row_factory(*results) self._set_final_result(results) elif isinstance(response, ErrorMessage): retry_policy = self._retry_policy if isinstance(response, ReadTimeoutErrorMessage): if self._metrics is not None: self._metrics.on_read_timeout() retry = retry_policy.on_read_timeout( self.query, retry_num=self._query_retries, **response.info) elif isinstance(response, WriteTimeoutErrorMessage): if self._metrics is not None: self._metrics.on_write_timeout() retry = retry_policy.on_write_timeout( self.query, retry_num=self._query_retries, **response.info) elif isinstance(response, UnavailableErrorMessage): if self._metrics is not None: self._metrics.on_unavailable() retry = retry_policy.on_unavailable( self.query, retry_num=self._query_retries, **response.info) - elif isinstance(response, OverloadedErrorMessage): + elif isinstance(response, (OverloadedErrorMessage, + IsBootstrappingErrorMessage, + TruncateError, ServerError)): + log.warning("Host %s error: %s.", host, response.summary) if self._metrics is not None: self._metrics.on_other_error() - # need to retry against a different host here - log.warning("Host %s is overloaded, retrying against a different " - "host", host) - self._retry(reuse_connection=False, consistency_level=None, host=host) - return - elif isinstance(response, IsBootstrappingErrorMessage): - if self._metrics is not None: - self._metrics.on_other_error() - # need to retry against a different host here - self._retry(reuse_connection=False, consistency_level=None, host=host) - return + retry = retry_policy.on_request_error( + self.query, self.message.consistency_level, error=response, + retry_num=self._query_retries) elif isinstance(response, PreparedQueryNotFound): if self.prepared_statement: query_id = self.prepared_statement.query_id assert query_id == response.info, \ "Got different query ID in server response (%s) than we " \ "had before (%s)" % (response.info, query_id) else: query_id = response.info try: prepared_statement = self.session.cluster._prepared_statements[query_id] except KeyError: if not self.prepared_statement: log.error("Tried to execute unknown prepared statement: id=%s", query_id.encode('hex')) self._set_final_exception(response) return else: prepared_statement = self.prepared_statement self.session.cluster._prepared_statements[query_id] = prepared_statement current_keyspace = self._connection.keyspace prepared_keyspace = prepared_statement.keyspace if not ProtocolVersion.uses_keyspace_flag(self.session.cluster.protocol_version) \ and prepared_keyspace and current_keyspace != prepared_keyspace: self._set_final_exception( ValueError("The Session's current keyspace (%s) does " "not match the keyspace the statement was " "prepared with (%s)" % (current_keyspace, prepared_keyspace))) return log.debug("Re-preparing unrecognized prepared statement against host %s: %s", host, prepared_statement.query_string) prepared_keyspace 
= prepared_statement.keyspace \ if ProtocolVersion.uses_keyspace_flag(self.session.cluster.protocol_version) else None prepare_message = PrepareMessage(query=prepared_statement.query_string, keyspace=prepared_keyspace) # since this might block, run on the executor to avoid hanging # the event loop thread self.session.submit(self._reprepare, prepare_message, host, connection, pool) return else: if hasattr(response, 'to_exception'): self._set_final_exception(response.to_exception()) else: self._set_final_exception(response) return - retry_type, consistency = retry - if retry_type in (RetryPolicy.RETRY, RetryPolicy.RETRY_NEXT_HOST): - self._query_retries += 1 - reuse = retry_type == RetryPolicy.RETRY - self._retry(reuse, consistency, host) - elif retry_type is RetryPolicy.RETHROW: - self._set_final_exception(response.to_exception()) - else: # IGNORE - if self._metrics is not None: - self._metrics.on_ignore() - self._set_final_result(None) - self._errors[host] = response.to_exception() + self._handle_retry_decision(retry, response, host) elif isinstance(response, ConnectionException): if self._metrics is not None: self._metrics.on_connection_error() if not isinstance(response, ConnectionShutdown): self._connection.defunct(response) - self._retry(reuse_connection=False, consistency_level=None, host=host) + retry = self._retry_policy.on_request_error( + self.query, self.message.consistency_level, error=response, + retry_num=self._query_retries) + self._handle_retry_decision(retry, response, host) elif isinstance(response, Exception): if hasattr(response, 'to_exception'): self._set_final_exception(response.to_exception()) else: self._set_final_exception(response) else: # we got some other kind of response message msg = "Got unexpected message: %r" % (response,) exc = ConnectionException(msg, host) self._cancel_timer() self._connection.defunct(exc) self._set_final_exception(exc) except Exception as exc: # almost certainly caused by a bug, but we need to set something here log.exception("Unexpected exception while handling result in ResponseFuture:") self._set_final_exception(exc) def _set_keyspace_completed(self, errors): if not errors: self._set_final_result(None) else: self._set_final_exception(ConnectionException( "Failed to set keyspace on all hosts: %s" % (errors,))) def _execute_after_prepare(self, host, connection, pool, response): """ Handle the response to our attempt to prepare a statement. If it succeeded, run the original query again against the same host. 
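The refactor above routes overloaded, bootstrapping, truncate and server errors through a single ``on_request_error`` hook on the retry policy. A minimal sketch of overriding it; the policy and profile names are illustrative, and the default implementation is assumed to retry on the next host::

    from cassandra.cluster import ExecutionProfile
    from cassandra.policies import RetryPolicy

    class FailFastRetryPolicy(RetryPolicy):
        def on_request_error(self, query, consistency, error, retry_num):
            # surface server-side errors immediately instead of trying another host
            return self.RETHROW, None

    profile = ExecutionProfile(retry_policy=FailFastRetryPolicy())
    cluster.add_execution_profile('fail-fast', profile)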
""" if pool: pool.return_connection(connection) if self._final_exception: return if isinstance(response, ResultMessage): if response.kind == RESULT_KIND_PREPARED: if self.prepared_statement: # result metadata is the only thing that could have # changed from an alter (_, _, _, self.prepared_statement.result_metadata, new_metadata_id) = response.results if new_metadata_id is not None: self.prepared_statement.result_metadata_id = new_metadata_id # use self._query to re-use the same host and # at the same time properly borrow the connection request_id = self._query(host) if request_id is None: # this host errored out, move on to the next self.send_request() else: self._set_final_exception(ConnectionException( "Got unexpected response when preparing statement " "on host %s: %s" % (host, response))) elif isinstance(response, ErrorMessage): if hasattr(response, 'to_exception'): self._set_final_exception(response.to_exception()) else: self._set_final_exception(response) elif isinstance(response, ConnectionException): log.debug("Connection error when preparing statement on host %s: %s", host, response) # try again on a different host, preparing again if necessary self._errors[host] = response self.send_request() else: self._set_final_exception(ConnectionException( "Got unexpected response type when preparing " "statement on host %s: %s" % (host, response))) def _set_final_result(self, response): self._cancel_timer() if self._metrics is not None: self._metrics.request_timer.addValue(time.time() - self._start_time) with self._callback_lock: self._final_result = response # save off current callbacks inside lock for execution outside it # -- prevents case where _final_result is set, then a callback is # added and executed on the spot, then executed again as a # registered callback to_call = tuple( partial(fn, response, *args, **kwargs) for (fn, args, kwargs) in self._callbacks ) self._event.set() # apply each callback for callback_partial in to_call: callback_partial() def _set_final_exception(self, response): self._cancel_timer() if self._metrics is not None: self._metrics.request_timer.addValue(time.time() - self._start_time) with self._callback_lock: self._final_exception = response # save off current errbacks inside lock for execution outside it -- # prevents case where _final_exception is set, then an errback is # added and executed on the spot, then executed again as a # registered errback to_call = tuple( partial(fn, response, *args, **kwargs) for (fn, args, kwargs) in self._errbacks ) self._event.set() # apply each callback for callback_partial in to_call: callback_partial() + def _handle_retry_decision(self, retry_decision, response, host): + + def exception_from_response(response): + if hasattr(response, 'to_exception'): + return response.to_exception() + else: + return response + + retry_type, consistency = retry_decision + if retry_type in (RetryPolicy.RETRY, RetryPolicy.RETRY_NEXT_HOST): + self._query_retries += 1 + reuse = retry_type == RetryPolicy.RETRY + self._retry(reuse, consistency, host) + elif retry_type is RetryPolicy.RETHROW: + self._set_final_exception(exception_from_response(response)) + else: # IGNORE + if self._metrics is not None: + self._metrics.on_ignore() + self._set_final_result(None) + + self._errors[host] = exception_from_response(response) + def _retry(self, reuse_connection, consistency_level, host): if self._final_exception: # the connection probably broke while we were waiting # to retry the operation return if self._metrics is not None: self._metrics.on_retry() if 
consistency_level is not None: self.message.consistency_level = consistency_level # don't retry on the event loop thread self.session.submit(self._retry_task, reuse_connection, host) def _retry_task(self, reuse_connection, host): if self._final_exception: # the connection probably broke while we were waiting # to retry the operation return if reuse_connection and self._query(host) is not None: return # otherwise, move onto another host self.send_request() def result(self): """ Return the final result or raise an Exception if errors were encountered. If the final result or error has not been set yet, this method will block until it is set, or the timeout set for the request expires. Timeout is specified in the Session request execution functions. If the timeout is exceeded, an :exc:`cassandra.OperationTimedOut` will be raised. This is a client-side timeout. For more information about server-side coordinator timeouts, see :class:`.policies.RetryPolicy`. Example usage:: >>> future = session.execute_async("SELECT * FROM mycf") >>> # do other stuff... >>> try: ... rows = future.result() ... for row in rows: ... ... # process results ... except Exception: ... log.exception("Operation failed:") """ self._event.wait() if self._final_result is not _NOT_SET: return ResultSet(self, self._final_result) else: raise self._final_exception def get_query_trace_ids(self): """ Returns the trace session ids for this future, if tracing was enabled (does not fetch trace data). """ return [trace.trace_id for trace in self._query_traces] def get_query_trace(self, max_wait=None, query_cl=ConsistencyLevel.LOCAL_ONE): """ Fetches and returns the query trace of the last response, or `None` if tracing was not enabled. Note that this may raise an exception if there are problems retrieving the trace details from Cassandra. If the trace is not available after `max_wait`, :exc:`cassandra.query.TraceUnavailable` will be raised. If the ResponseFuture is not done (async execution) and you try to retrieve the trace, :exc:`cassandra.query.TraceUnavailable` will be raised. `query_cl` is the consistency level used to poll the trace tables. """ if self._final_result is _NOT_SET and self._final_exception is None: raise TraceUnavailable( "Trace information was not available. The ResponseFuture is not done.") if self._query_traces: return self._get_query_trace(len(self._query_traces) - 1, max_wait, query_cl) def get_all_query_traces(self, max_wait_per=None, query_cl=ConsistencyLevel.LOCAL_ONE): """ Fetches and returns the query traces for all query pages, if tracing was enabled. See note in :meth:`~.get_query_trace` regarding possible exceptions. """ if self._query_traces: return [self._get_query_trace(i, max_wait_per, query_cl) for i in range(len(self._query_traces))] return [] def _get_query_trace(self, i, max_wait, query_cl): trace = self._query_traces[i] if not trace.events: trace.populate(max_wait=max_wait, query_cl=query_cl) return trace def add_callback(self, fn, *args, **kwargs): """ Attaches a callback function to be called when the final results arrive. By default, `fn` will be called with the results as the first and only argument. If `*args` or `**kwargs` are supplied, they will be passed through as additional positional or keyword arguments to `fn`. If an error is hit while executing the operation, a callback attached here will not be called. Use :meth:`.add_errback()` or :meth:`add_callbacks()` if you wish to handle that case. 
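For instance, a minimal sketch of pairing a result handler with an error handler (``handle_rows`` and ``handle_error`` are hypothetical application functions, not part of the driver)::

            >>> future = session.execute_async("SELECT * FROM users")
            >>> future.add_callback(handle_rows)
            >>> future.add_errback(handle_error)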
If the final result has already been seen when this method is called, the callback will be called immediately (before this method returns). Note: in the case that the result is not available when the callback is added, the callback is executed by the IO event thread. This means that the callback should not block or attempt further synchronous requests, because no further IO will be processed until the callback returns. **Important**: if the callback you attach results in an exception being raised, **the exception will be ignored**, so please ensure your callback handles all error cases that you care about. Usage example:: >>> session = cluster.connect("mykeyspace") >>> def handle_results(rows, start_time, should_log=False): ... if should_log: ... log.info("Total time: %f", time.time() - start_time) ... ... >>> future = session.execute_async("SELECT * FROM users") >>> future.add_callback(handle_results, time.time(), should_log=True) """ run_now = False with self._callback_lock: # Always add fn to self._callbacks, even when we're about to # execute it, to prevent races with functions like # start_fetching_next_page that reset _final_result self._callbacks.append((fn, args, kwargs)) if self._final_result is not _NOT_SET: run_now = True if run_now: fn(self._final_result, *args, **kwargs) return self def add_errback(self, fn, *args, **kwargs): """ Like :meth:`.add_callback()`, but handles error cases. An Exception instance will be passed as the first positional argument to `fn`. """ run_now = False with self._callback_lock: # Always add fn to self._errbacks, even when we're about to execute # it, to prevent races with functions like start_fetching_next_page # that reset _final_exception self._errbacks.append((fn, args, kwargs)) if self._final_exception: run_now = True if run_now: fn(self._final_exception, *args, **kwargs) return self def add_callbacks(self, callback, errback, callback_args=(), callback_kwargs=None, errback_args=(), errback_kwargs=None): """ A convenient combination of :meth:`.add_callback()` and :meth:`.add_errback()`. Example usage:: >>> session = cluster.connect() >>> query = "SELECT * FROM mycf" >>> future = session.execute_async(query) >>> def log_results(results, level='debug'): ... for row in results: ... log.log(level, "Result: %s", row) >>> def log_error(exc, query): ... log.error("Query '%s' failed: %s", query, exc) >>> future.add_callbacks( ... callback=log_results, callback_kwargs={'level': 'info'}, ... errback=log_error, errback_args=(query,)) """ self.add_callback(callback, *callback_args, **(callback_kwargs or {})) self.add_errback(errback, *errback_args, **(errback_kwargs or {})) def clear_callbacks(self): with self._callback_lock: self._callbacks = [] self._errbacks = [] def __str__(self): result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result return "<ResponseFuture: query='%s' request_id=%s result=%s exception=%s coordinator_host=%s>" % (self.query, self._req_id, result, self._final_exception, self.coordinator_host) __repr__ = __str__ class QueryExhausted(Exception): """ Raised when :meth:`.ResponseFuture.start_fetching_next_page()` is called and there are no more pages. You can check :attr:`.ResponseFuture.has_more_pages` before calling to avoid this. .. versionadded:: 2.0.0 """ pass class ResultSet(object): """ An iterator over the rows from a query result. Also supplies basic equality and indexing methods for backward-compatibility. These methods materialize the entire result set (loading all pages), and should only be used if the total result size is understood.
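For example (a sketch), indexing a result (``results[2]``) or comparing one for equality (``results == expected_rows``) will first load every remaining page into memory.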
Warnings are emitted when paged results are materialized in this fashion. You can treat this as a normal iterator over rows:: >>> from cassandra.query import SimpleStatement >>> statement = SimpleStatement("SELECT * FROM users", fetch_size=10) >>> for user_row in session.execute(statement): ... process_user(user_row) Whenever there are no more rows in the current page, the next page will be fetched transparently. However, note that it *is* possible for an :class:`Exception` to be raised while fetching the next page, just like you might see on a normal call to ``session.execute()``. """ def __init__(self, response_future, initial_response): self.response_future = response_future self.column_names = response_future._col_names self.column_types = response_future._col_types self._set_current_rows(initial_response) self._page_iter = None self._list_mode = False @property def has_more_pages(self): """ True if the last response indicated more pages; False otherwise """ return self.response_future.has_more_pages @property def current_rows(self): """ The list of current page rows. May be empty if the result was empty, or this is the last page. """ return self._current_rows or [] def one(self): """ Return a single row of the results or None if empty. This is basically a shortcut to `result_set.current_rows[0]` and should only be used when you know a query returns a single row. Consider using an iterator if the ResultSet contains more than one row. """ row = None if self._current_rows: try: row = self._current_rows[0] except TypeError: # generator object is not subscriptable, PYTHON-1026 row = next(iter(self._current_rows)) return row def __iter__(self): if self._list_mode: return iter(self._current_rows) self._page_iter = iter(self._current_rows) return self def next(self): try: return next(self._page_iter) except StopIteration: if not self.response_future.has_more_pages: if not self._list_mode: self._current_rows = [] raise self.fetch_next_page() self._page_iter = iter(self._current_rows) return next(self._page_iter) __next__ = next def fetch_next_page(self): """ Manually, synchronously fetch the next page. Supplied for manually retrieving pages and inspecting :meth:`~.current_page`. It is not necessary to call this when iterating through results; paging happens implicitly in iteration. """ if self.response_future.has_more_pages: self.response_future.start_fetching_next_page() result = self.response_future.result() self._current_rows = result._current_rows # ResultSet has already _set_current_rows to the appropriate form else: self._current_rows = [] def _set_current_rows(self, result): if isinstance(result, Mapping): self._current_rows = [result] if result else [] return try: iter(result) # can't check directly for generator types because cython generators are different self._current_rows = result except TypeError: self._current_rows = [result] if result else [] def _fetch_all(self): self._current_rows = list(self) self._page_iter = None def _enter_list_mode(self, operator): if self._list_mode: return if self._page_iter: raise RuntimeError("Cannot use %s when results have been iterated." 
% operator) if self.response_future.has_more_pages: log.warning("Using %s on paged results causes entire result set to be materialized.", operator) self._fetch_all() # done regardless of paging status in case the row factory produces a generator self._list_mode = True def __eq__(self, other): self._enter_list_mode("equality operator") return self._current_rows == other def __getitem__(self, i): if i == 0: warn("ResultSet indexing support will be removed in 4.0. Consider using " "ResultSet.one() to get a single row.", DeprecationWarning) self._enter_list_mode("index operator") return self._current_rows[i] def __nonzero__(self): return bool(self._current_rows) __bool__ = __nonzero__ def get_query_trace(self, max_wait_sec=None): """ Gets the last query trace from the associated future. See :meth:`.ResponseFuture.get_query_trace` for details. """ return self.response_future.get_query_trace(max_wait_sec) def get_all_query_traces(self, max_wait_sec_per=None): """ Gets all query traces from the associated future. See :meth:`.ResponseFuture.get_all_query_traces` for details. """ return self.response_future.get_all_query_traces(max_wait_sec_per) @property def was_applied(self): """ For LWT results, returns whether the transaction was applied. Result is indeterminate if called on a result that was not an LWT request or on a :class:`.query.BatchStatement` containing LWT. In the latter case, either the whole batch succeeds or fails. Only valid when one of the internal row factories is in use. """ if self.response_future.row_factory not in (named_tuple_factory, dict_factory, tuple_factory): raise RuntimeError("Cannot determine LWT result with row factory %s" % (self.response_future.row_factory,)) is_batch_statement = isinstance(self.response_future.query, BatchStatement) if is_batch_statement and (not self.column_names or self.column_names[0] != "[applied]"): raise RuntimeError("No LWT were present in the BatchStatement") if not is_batch_statement and len(self.current_rows) != 1: raise RuntimeError("LWT result should have exactly one row. This has %d." % (len(self.current_rows))) row = self.current_rows[0] if isinstance(row, tuple): return row[0] else: return row['[applied]'] @property def paging_state(self): """ Server paging state of the query. Can be `None` if the query was not paged. The driver treats paging state as opaque, but it may contain primary key data, so applications may want to avoid sending this to untrusted parties. """ return self.response_future._paging_state diff --git a/cassandra/compat.py b/cassandra/compat.py new file mode 100644 index 0000000..83c1b10 --- /dev/null +++ b/cassandra/compat.py @@ -0,0 +1,20 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
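The shim that follows lets the rest of the driver import ``Mapping`` from a single place on both Python 2 and Python 3. A minimal sketch of the intended consumption (the caller below is hypothetical, not part of this patch)::

    from cassandra.compat import Mapping

    def is_mapping_row(value):
        # dict-like rows (e.g. produced by dict_factory) satisfy the Mapping ABC
        return isinstance(value, Mapping)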
+ +import six + +if six.PY2: + from collections import Mapping +elif six.PY3: + from collections.abc import Mapping diff --git a/cassandra/connection.py b/cassandra/connection.py index f017bf5..ba08ae2 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1,1137 +1,1418 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import # to enable import io from stdlib from collections import defaultdict, deque import errno -from functools import wraps, partial +from functools import wraps, partial, total_ordering from heapq import heappush, heappop import io import logging import six from six.moves import range import socket import struct import sys from threading import Thread, Event, RLock import time try: import ssl except ImportError: ssl = None # NOQA if 'gevent.monkey' in sys.modules: from gevent.queue import Queue, Empty else: from six.moves.queue import Queue, Empty # noqa from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut, ProtocolVersion from cassandra.marshal import int32_pack from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage, StartupMessage, ErrorMessage, CredentialsMessage, QueryMessage, ResultMessage, ProtocolHandler, InvalidRequestException, SupportedMessage, AuthResponseMessage, AuthChallengeMessage, AuthSuccessMessage, ProtocolException, RegisterMessage) from cassandra.util import OrderedDict log = logging.getLogger(__name__) # We use an ordered dictionary and specifically add lz4 before # snappy so that lz4 will be preferred. Changing the order of this # will change the compression preferences for the driver. locally_supported_compressions = OrderedDict() try: import lz4 except ImportError: pass else: # The compress and decompress functions we need were moved from the lz4 to # the lz4.block namespace, so we try both here. try: from lz4 import block as lz4_block except ImportError: lz4_block = lz4 + try: + lz4_block.compress + lz4_block.decompress + except AttributeError: + raise ImportError( + 'lz4 not imported correctly. Imported object should have ' + '.compress and .decompress attributes but does not. ' + 'Please file a bug report on JIRA. 
(Imported object was ' + '{lz4_block})'.format(lz4_block=repr(lz4_block)) + ) + # Cassandra writes the uncompressed message length in big endian order, # but the lz4 lib requires little endian order, so we wrap these # functions to handle that def lz4_compress(byts): # write length in big-endian instead of little-endian return int32_pack(len(byts)) + lz4_block.compress(byts)[4:] def lz4_decompress(byts): # flip from big-endian to little-endian return lz4_block.decompress(byts[3::-1] + byts[4:]) locally_supported_compressions['lz4'] = (lz4_compress, lz4_decompress) try: import snappy except ImportError: pass else: # work around apparently buggy snappy decompress def decompress(byts): if byts == '\x00': return '' return snappy.decompress(byts) locally_supported_compressions['snappy'] = (snappy.compress, decompress) +DRIVER_NAME, DRIVER_VERSION = 'DataStax Python Driver', sys.modules['cassandra'].__version__ PROTOCOL_VERSION_MASK = 0x7f HEADER_DIRECTION_FROM_CLIENT = 0x00 HEADER_DIRECTION_TO_CLIENT = 0x80 HEADER_DIRECTION_MASK = 0x80 frame_header_v1_v2 = struct.Struct('>BbBi') frame_header_v3 = struct.Struct('>BhBi') +class EndPoint(object): + """ + Represents the information to connect to a cassandra node. + """ + + @property + def address(self): + """ + The IP address of the node. This is the RPC address the driver uses when connecting to the node + """ + raise NotImplementedError() + + @property + def port(self): + """ + The port of the node. + """ + raise NotImplementedError() + + @property + def ssl_options(self): + """ + SSL options specific to this endpoint. + """ + return None + + @property + def socket_family(self): + """ + The socket family of the endpoint. + """ + return socket.AF_UNSPEC + + def resolve(self): + """ + Resolve the endpoint to an address/port. This is called + only on socket connection. + """ + raise NotImplementedError() + + +class EndPointFactory(object): + + cluster = None + + def configure(self, cluster): + """ + This is called by the cluster during its initialization. + """ + self.cluster = cluster + return self + + def create(self, row): + """ + Create an EndPoint from a system.peers row. + """ + raise NotImplementedError() + + +@total_ordering +class DefaultEndPoint(EndPoint): + """ + Default EndPoint implementation, basically just an address and port. + """ + + def __init__(self, address, port=9042): + self._address = address + self._port = port + + @property + def address(self): + return self._address + + @property + def port(self): + return self._port + + def resolve(self): + return self._address, self._port + + def __eq__(self, other): + return isinstance(other, DefaultEndPoint) and \ + self.address == other.address and self.port == other.port + + def __hash__(self): + return hash((self.address, self.port)) + + def __lt__(self, other): + return (self.address, self.port) < (other.address, other.port) + + def __str__(self): + return str("%s:%d" % (self.address, self.port)) + + def __repr__(self): + return "<%s: %s:%d>" % (self.__class__.__name__, self.address, self.port) + + +class DefaultEndPointFactory(EndPointFactory): + + port = None + """ + If set, force all endpoints to use this port. 
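For example (a sketch of the intent)::

        factory = DefaultEndPointFactory(port=9142)
        # every EndPoint built from a system.peers row now uses port 9142,
        # regardless of the port the node advertises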
+ """ + + def __init__(self, port=None): + self.port = port + + def create(self, row): + addr = None + if "rpc_address" in row: + addr = row.get("rpc_address") + if "native_transport_address" in row: + addr = row.get("native_transport_address") + if not addr or addr in ["0.0.0.0", "::"]: + addr = row.get("peer") + + # create the endpoint with the translated address + return DefaultEndPoint( + self.cluster.address_translator.translate(addr), + self.port if self.port is not None else 9042) + + +@total_ordering +class SniEndPoint(EndPoint): + """SNI Proxy EndPoint implementation.""" + + def __init__(self, proxy_address, server_name, port=9042): + self._proxy_address = proxy_address + self._index = 0 + self._resolved_address = None # resolved address + self._port = port + self._server_name = server_name + self._ssl_options = {'server_hostname': server_name} + + @property + def address(self): + return self._proxy_address + + @property + def port(self): + return self._port + + @property + def ssl_options(self): + return self._ssl_options + + def resolve(self): + try: + resolved_addresses = socket.getaddrinfo(self._proxy_address, self._port, + socket.AF_UNSPEC, socket.SOCK_STREAM) + except socket.gaierror: + log.debug('Could not resolve sni proxy hostname "%s" ' + 'with port %d' % (self._proxy_address, self._port)) + raise + + # round-robin pick + self._resolved_address = sorted(addr[4][0] for addr in resolved_addresses)[self._index % len(resolved_addresses)] + self._index += 1 + + return self._resolved_address, self._port + + def __eq__(self, other): + return (isinstance(other, SniEndPoint) and + self.address == other.address and self.port == other.port and + self._server_name == other._server_name) + + def __hash__(self): + return hash((self.address, self.port, self._server_name)) + + def __lt__(self, other): + return ((self.address, self.port, self._server_name) < + (other.address, other.port, self._server_name)) + + def __str__(self): + return str("%s:%d:%s" % (self.address, self.port, self._server_name)) + + def __repr__(self): + return "<%s: %s:%d:%s>" % (self.__class__.__name__, + self.address, self.port, self._server_name) + + +class SniEndPointFactory(EndPointFactory): + + def __init__(self, proxy_address, port): + self._proxy_address = proxy_address + self._port = port + + def create(self, row): + host_id = row.get("host_id") + if host_id is None: + raise ValueError("No host_id to create the SniEndPoint") + + return SniEndPoint(self._proxy_address, str(host_id), self._port) + + def create_from_sni(self, sni): + return SniEndPoint(self._proxy_address, sni, self._port) + + +@total_ordering +class UnixSocketEndPoint(EndPoint): + """ + Unix Socket EndPoint implementation. 
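For example (a sketch; the socket path is hypothetical)::

        endpoint = UnixSocketEndPoint('/var/run/cassandra/cql.sock')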
+ """ + + def __init__(self, unix_socket_path): + self._unix_socket_path = unix_socket_path + + @property + def address(self): + return self._unix_socket_path + + @property + def port(self): + return None + + @property + def socket_family(self): + return socket.AF_UNIX + + def resolve(self): + return self.address, None + + def __eq__(self, other): + return (isinstance(other, UnixSocketEndPoint) and + self._unix_socket_path == other._unix_socket_path) + + def __hash__(self): + return hash(self._unix_socket_path) + + def __lt__(self, other): + return self._unix_socket_path < other._unix_socket_path + + def __str__(self): + return str("%s" % (self._unix_socket_path,)) + + def __repr__(self): + return "<%s: %s>" % (self.__class__.__name__, self._unix_socket_path) + + class _Frame(object): def __init__(self, version, flags, stream, opcode, body_offset, end_pos): self.version = version self.flags = flags self.stream = stream self.opcode = opcode self.body_offset = body_offset self.end_pos = end_pos def __eq__(self, other): # facilitates testing if isinstance(other, _Frame): return (self.version == other.version and self.flags == other.flags and self.stream == other.stream and self.opcode == other.opcode and self.body_offset == other.body_offset and self.end_pos == other.end_pos) return NotImplemented def __str__(self): return "ver({0}); flags({1:04b}); stream({2}); op({3}); offset({4}); len({5})".format(self.version, self.flags, self.stream, self.opcode, self.body_offset, self.end_pos - self.body_offset) - NONBLOCKING = (errno.EAGAIN, errno.EWOULDBLOCK) class ConnectionException(Exception): """ An unrecoverable error was hit when attempting to use a connection, or the connection was already closed or defunct. """ - def __init__(self, message, host=None): + def __init__(self, message, endpoint=None): Exception.__init__(self, message) - self.host = host + self.endpoint = endpoint + + @property + def host(self): + return self.endpoint.address class ConnectionShutdown(ConnectionException): """ Raised when a connection has been marked as defunct or has been closed. """ pass class ProtocolVersionUnsupported(ConnectionException): """ Server rejected startup message due to unsupported protocol version """ - def __init__(self, host, startup_version): - msg = "Unsupported protocol version on %s: %d" % (host, startup_version) - super(ProtocolVersionUnsupported, self).__init__(msg, host) + def __init__(self, endpoint, startup_version): + msg = "Unsupported protocol version on %s: %d" % (endpoint, startup_version) + super(ProtocolVersionUnsupported, self).__init__(msg, endpoint) self.startup_version = startup_version class ConnectionBusy(Exception): """ An attempt was made to send a message through a :class:`.Connection` that was already at the max number of in-flight operations. """ pass class ProtocolError(Exception): """ Communication did not match the protocol that this driver expects. 
""" pass def defunct_on_error(f): @wraps(f) def wrapper(self, *args, **kwargs): try: return f(self, *args, **kwargs) except Exception as exc: self.defunct(exc) return wrapper DEFAULT_CQL_VERSION = '3.0.0' if six.PY3: def int_from_buf_item(i): return i else: int_from_buf_item = ord class Connection(object): CALLBACK_ERR_THREAD_THRESHOLD = 100 in_buffer_size = 4096 out_buffer_size = 4096 cql_version = None no_compact = False protocol_version = ProtocolVersion.MAX_SUPPORTED keyspace = None compression = True compressor = None decompressor = None + endpoint = None ssl_options = None + ssl_context = None last_error = None # The current number of operations that are in flight. More precisely, # the number of request IDs that are currently in use. in_flight = 0 # Max concurrent requests allowed per connection. This is set optimistically high, allowing # all request ids to be used in protocol version 3+. Normally concurrency would be controlled # at a higher level by the application or concurrent.execute_concurrent. This attribute # is for lower-level integrations that want some upper bound without reimplementing. max_in_flight = 2 ** 15 # A set of available request IDs. When using the v3 protocol or higher, # this will not initially include all request IDs in order to save memory, # but the set will grow if it is exhausted. request_ids = None # Tracks the highest used request ID in order to help with growing the # request_ids set highest_request_id = 0 is_defunct = False is_closed = False lock = None user_type_map = None msg_received = False is_unsupported_proto_version = False is_control_connection = False signaled_error = False # used for flagging at the pool level allow_beta_protocol_version = False _iobuf = None _current_frame = None _socket = None _socket_impl = socket _ssl_impl = ssl _check_hostname = False + _product_type = None def __init__(self, host='127.0.0.1', port=9042, authenticator=None, ssl_options=None, sockopts=None, compression=True, cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False, - user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False): - self.host = host - self.port = port + user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False, + ssl_context=None): + + # TODO next major rename host to endpoint and remove port kwarg. + self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port) + self.authenticator = authenticator self.ssl_options = ssl_options.copy() if ssl_options else None + self.ssl_context = ssl_context self.sockopts = sockopts self.compression = compression self.cql_version = cql_version self.protocol_version = protocol_version self.is_control_connection = is_control_connection self.user_type_map = user_type_map self.connect_timeout = connect_timeout self.allow_beta_protocol_version = allow_beta_protocol_version self.no_compact = no_compact self._push_watchers = defaultdict(set) self._requests = {} self._iobuf = io.BytesIO() if ssl_options: self._check_hostname = bool(self.ssl_options.pop('check_hostname', False)) if self._check_hostname: if not getattr(ssl, 'match_hostname', None): raise RuntimeError("ssl_options specify 'check_hostname', but ssl.match_hostname is not provided. 
" "Patch or upgrade Python to use this option.") + self.ssl_options.update(self.endpoint.ssl_options or {}) + elif self.endpoint.ssl_options: + self.ssl_options = self.endpoint.ssl_options + if protocol_version >= 3: self.max_request_id = min(self.max_in_flight - 1, (2 ** 15) - 1) # Don't fill the deque with 2**15 items right away. Start with some and add # more if needed. initial_size = min(300, self.max_in_flight) self.request_ids = deque(range(initial_size)) self.highest_request_id = initial_size - 1 else: self.max_request_id = min(self.max_in_flight, (2 ** 7) - 1) self.request_ids = deque(range(self.max_request_id + 1)) self.highest_request_id = self.max_request_id self.lock = RLock() self.connected_event = Event() + @property + def host(self): + return self.endpoint.address + + @property + def port(self): + return self.endpoint.port + @classmethod def initialize_reactor(cls): """ Called once by Cluster.connect(). This should be used by implementations to set up any resources that will be shared across connections. """ pass @classmethod def handle_fork(cls): """ Called after a forking. This should cleanup any remaining reactor state from the parent process. """ pass @classmethod def create_timer(cls, timeout, callback): raise NotImplementedError() @classmethod - def factory(cls, host, timeout, *args, **kwargs): + def factory(cls, endpoint, timeout, *args, **kwargs): """ A factory function which returns connections which have succeeded in connecting and are ready for service (or raises an exception otherwise). """ start = time.time() kwargs['connect_timeout'] = timeout - conn = cls(host, *args, **kwargs) + conn = cls(endpoint, *args, **kwargs) elapsed = time.time() - start conn.connected_event.wait(timeout - elapsed) if conn.last_error: if conn.is_unsupported_proto_version: - raise ProtocolVersionUnsupported(host, conn.protocol_version) + raise ProtocolVersionUnsupported(endpoint, conn.protocol_version) raise conn.last_error elif not conn.connected_event.is_set(): conn.close() raise OperationTimedOut("Timed out creating connection (%s seconds)" % timeout) else: return conn + def _get_socket_addresses(self): + address, port = self.endpoint.resolve() + + if hasattr(socket, 'AF_UNIX') and self.endpoint.socket_family == socket.AF_UNIX: + return [(socket.AF_UNIX, socket.SOCK_STREAM, 0, None, address)] + + addresses = socket.getaddrinfo(address, port, self.endpoint.socket_family, socket.SOCK_STREAM) + if not addresses: + raise ConnectionException("getaddrinfo returned empty list for %s" % (self.endpoint,)) + + return addresses + def _connect_socket(self): sockerr = None - addresses = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) - if not addresses: - raise ConnectionException("getaddrinfo returned empty list for %s" % (self.host,)) - for (af, socktype, proto, canonname, sockaddr) in addresses: + addresses = self._get_socket_addresses() + for (af, socktype, proto, _, sockaddr) in addresses: try: self._socket = self._socket_impl.socket(af, socktype, proto) - if self.ssl_options: + if self.ssl_context: + self._socket = self.ssl_context.wrap_socket(self._socket, + **(self.ssl_options or {})) + elif self.ssl_options: if not self._ssl_impl: raise RuntimeError("This version of Python was not compiled with SSL support") self._socket = self._ssl_impl.wrap_socket(self._socket, **self.ssl_options) self._socket.settimeout(self.connect_timeout) self._socket.connect(sockaddr) self._socket.settimeout(None) if self._check_hostname: - 
ssl.match_hostname(self._socket.getpeercert(), self.host) + ssl.match_hostname(self._socket.getpeercert(), self.endpoint.address) sockerr = None break except socket.error as err: if self._socket: self._socket.close() self._socket = None sockerr = err if sockerr: - raise socket.error(sockerr.errno, "Tried connecting to %s. Last error: %s" % ([a[4] for a in addresses], sockerr.strerror or sockerr)) + raise socket.error(sockerr.errno, "Tried connecting to %s. Last error: %s" % + ([a[4] for a in addresses], sockerr.strerror or sockerr)) if self.sockopts: for args in self.sockopts: self._socket.setsockopt(*args) def close(self): raise NotImplementedError() def defunct(self, exc): with self.lock: if self.is_defunct or self.is_closed: return self.is_defunct = True exc_info = sys.exc_info() # if we are not handling an exception, just use the passed exception, and don't try to format exc_info with the message if any(exc_info): log.debug("Defuncting connection (%s) to %s:", - id(self), self.host, exc_info=exc_info) + id(self), self.endpoint, exc_info=exc_info) else: log.debug("Defuncting connection (%s) to %s: %s", - id(self), self.host, exc) + id(self), self.endpoint, exc) self.last_error = exc self.close() self.error_all_requests(exc) self.connected_event.set() return exc def error_all_requests(self, exc): with self.lock: requests = self._requests self._requests = {} if not requests: return new_exc = ConnectionShutdown(str(exc)) def try_callback(cb): try: cb(new_exc) except Exception: log.warning("Ignoring unhandled exception while erroring requests for a " "failed connection (%s) to host %s:", - id(self), self.host, exc_info=True) + id(self), self.endpoint, exc_info=True) # run first callback from this thread to ensure pool state before leaving cb, _, _ = requests.popitem()[1] try_callback(cb) if not requests: return # additional requests are optionally errored from a separate thread # The default callback and retry logic is fairly expensive -- we don't # want to tie up the event thread when there are many requests def err_all_callbacks(): for cb, _, _ in requests.values(): try_callback(cb) if len(requests) < Connection.CALLBACK_ERR_THREAD_THRESHOLD: err_all_callbacks() else: # daemon thread here because we want to stay decoupled from the cluster TPE # TODO: would it make sense to just have a driver-global TPE? t = Thread(target=err_all_callbacks) t.daemon = True t.start() def get_request_id(self): """ This must be called while self.lock is held. 
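A minimal sketch of the expected calling pattern, mirroring how wait_for_responses and HeartbeatFuture use it (``conn`` is a hypothetical Connection instance)::

            with conn.lock:
                if conn.in_flight < conn.max_request_id:
                    conn.in_flight += 1
                    request_id = conn.get_request_id()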
""" try: return self.request_ids.popleft() except IndexError: new_request_id = self.highest_request_id + 1 # in_flight checks should guarantee this assert new_request_id <= self.max_request_id self.highest_request_id = new_request_id return self.highest_request_id def handle_pushed(self, response): log.debug("Message pushed from server: %r", response) for cb in self._push_watchers.get(response.event_type, []): try: cb(response.event_args) except Exception: log.exception("Pushed event handler errored, ignoring:") def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=None): if self.is_defunct: - raise ConnectionShutdown("Connection to %s is defunct" % self.host) + raise ConnectionShutdown("Connection to %s is defunct" % self.endpoint) elif self.is_closed: - raise ConnectionShutdown("Connection to %s is closed" % self.host) + raise ConnectionShutdown("Connection to %s is closed" % self.endpoint) # queue the decoder function with the request # this allows us to inject custom functions per request to encode, decode messages self._requests[request_id] = (cb, decoder, result_metadata) msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor, allow_beta_protocol_version=self.allow_beta_protocol_version) self.push(msg) return len(msg) - def wait_for_response(self, msg, timeout=None): - return self.wait_for_responses(msg, timeout=timeout)[0] + def wait_for_response(self, msg, timeout=None, **kwargs): + return self.wait_for_responses(msg, timeout=timeout, **kwargs)[0] def wait_for_responses(self, *msgs, **kwargs): """ Returns a list of (success, response) tuples. If success is False, response will be an Exception. Otherwise, response will be the normal query response. If fail_on_error was left as True and one of the requests failed, the corresponding Exception will be raised. """ if self.is_closed or self.is_defunct: raise ConnectionShutdown("Connection %s is already closed" % (self, )) timeout = kwargs.get('timeout') fail_on_error = kwargs.get('fail_on_error', True) waiter = ResponseWaiter(self, len(msgs), fail_on_error) # busy wait for sufficient space on the connection messages_sent = 0 while True: needed = len(msgs) - messages_sent with self.lock: available = min(needed, self.max_request_id - self.in_flight + 1) request_ids = [self.get_request_id() for _ in range(available)] self.in_flight += available for i, request_id in enumerate(request_ids): self.send_msg(msgs[messages_sent + i], request_id, partial(waiter.got_response, index=messages_sent + i)) messages_sent += available if messages_sent == len(msgs): break else: if timeout is not None: timeout -= 0.01 if timeout <= 0.0: raise OperationTimedOut() time.sleep(0.01) try: return waiter.deliver(timeout) except OperationTimedOut: raise except Exception as exc: self.defunct(exc) raise def register_watcher(self, event_type, callback, register_timeout=None): """ Register a callback for a given event type. """ self._push_watchers[event_type].add(callback) self.wait_for_response( RegisterMessage(event_list=[event_type]), timeout=register_timeout) def register_watchers(self, type_callback_dict, register_timeout=None): """ Register multiple callback/event type pairs, expressed as a dict. 
""" for event_type, callback in type_callback_dict.items(): self._push_watchers[event_type].add(callback) self.wait_for_response( RegisterMessage(event_list=type_callback_dict.keys()), timeout=register_timeout) def control_conn_disposed(self): self.is_control_connection = False self._push_watchers = {} @defunct_on_error def _read_frame_header(self): buf = self._iobuf.getvalue() pos = len(buf) if pos: version = int_from_buf_item(buf[0]) & PROTOCOL_VERSION_MASK if version > ProtocolVersion.MAX_SUPPORTED: raise ProtocolError("This version of the driver does not support protocol version %d" % version) frame_header = frame_header_v3 if version >= 3 else frame_header_v1_v2 # this frame header struct is everything after the version byte header_size = frame_header.size + 1 if pos >= header_size: flags, stream, op, body_len = frame_header.unpack_from(buf, 1) if body_len < 0: raise ProtocolError("Received negative body length: %r" % body_len) self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size) return pos def _reset_frame(self): self._iobuf = io.BytesIO(self._iobuf.read()) self._iobuf.seek(0, 2) # io.SEEK_END == 2 (constant not present in 2.6) self._current_frame = None def process_io_buffer(self): while True: if not self._current_frame: pos = self._read_frame_header() else: pos = self._iobuf.tell() if not self._current_frame or pos < self._current_frame.end_pos: # we don't have a complete header yet or we # already saw a header, but we don't have a # complete message yet return else: frame = self._current_frame self._iobuf.seek(frame.body_offset) msg = self._iobuf.read(frame.end_pos - frame.body_offset) self.process_msg(frame, msg) self._reset_frame() @defunct_on_error def process_msg(self, header, body): self.msg_received = True stream_id = header.stream if stream_id < 0: callback = None decoder = ProtocolHandler.decode_message result_metadata = None else: try: callback, decoder, result_metadata = self._requests.pop(stream_id) # This can only happen if the stream_id was # removed due to an OperationTimedOut except KeyError: return with self.lock: self.request_ids.append(stream_id) try: response = decoder(header.version, self.user_type_map, stream_id, header.flags, header.opcode, body, self.decompressor, result_metadata) except Exception as exc: log.exception("Error decoding response from Cassandra. 
" "%s; buffer: %r", header, self._iobuf.getvalue()) if callback is not None: callback(exc) self.defunct(exc) return try: if stream_id >= 0: if isinstance(response, ProtocolException): if 'unsupported protocol version' in response.message: self.is_unsupported_proto_version = True else: log.error("Closing connection %s due to protocol error: %s", self, response.summary_msg()) self.defunct(response) if callback is not None: callback(response) else: self.handle_pushed(response) except Exception: log.exception("Callback handler errored, ignoring:") @defunct_on_error def _send_options_message(self): - if self.cql_version is None and (not self.compression or not locally_supported_compressions): - log.debug("Not sending options message for new connection(%s) to %s " - "because compression is disabled and a cql version was not " - "specified", id(self), self.host) - self._compressor = None - self.cql_version = DEFAULT_CQL_VERSION - self._send_startup_message(no_compact=self.no_compact) - else: - log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.host) - self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response) + log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.endpoint) + self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response) @defunct_on_error def _handle_options_response(self, options_response): if self.is_defunct: return if not isinstance(options_response, SupportedMessage): if isinstance(options_response, ConnectionException): raise options_response else: log.error("Did not get expected SupportedMessage response; " "instead, got: %s", options_response) raise ConnectionException("Did not get expected SupportedMessage " "response; instead, got: %s" % (options_response,)) log.debug("Received options response on new connection (%s) from %s", - id(self), self.host) + id(self), self.endpoint) supported_cql_versions = options_response.cql_versions remote_supported_compressions = options_response.options['COMPRESSION'] + self._product_type = options_response.options.get('PRODUCT_TYPE', [None])[0] if self.cql_version: if self.cql_version not in supported_cql_versions: raise ProtocolError( "cql_version %r is not supported by remote (w/ native " "protocol). Supported versions: %r" % (self.cql_version, supported_cql_versions)) else: self.cql_version = supported_cql_versions[0] self._compressor = None compression_type = None if self.compression: overlap = (set(locally_supported_compressions.keys()) & set(remote_supported_compressions)) if len(overlap) == 0: log.debug("No available compression types supported on both ends." " locally supported: %r. 
remotely supported: %r", locally_supported_compressions.keys(), remote_supported_compressions) else: compression_type = None if isinstance(self.compression, six.string_types): # the user picked a specific compression type ('snappy' or 'lz4') if self.compression not in remote_supported_compressions: raise ProtocolError( "The requested compression type (%s) is not supported by the Cassandra server at %s" - % (self.compression, self.host)) + % (self.compression, self.endpoint)) compression_type = self.compression else: # our locally supported compressions are ordered to prefer # lz4, if available for k in locally_supported_compressions.keys(): if k in overlap: compression_type = k break # set the decompressor here, but set the compressor only after # a successful Ready message self._compressor, self.decompressor = \ locally_supported_compressions[compression_type] self._send_startup_message(compression_type, no_compact=self.no_compact) @defunct_on_error def _send_startup_message(self, compression=None, no_compact=False): log.debug("Sending StartupMessage on %s", self) - opts = {} + opts = {'DRIVER_NAME': DRIVER_NAME, + 'DRIVER_VERSION': DRIVER_VERSION} if compression: opts['COMPRESSION'] = compression if no_compact: opts['NO_COMPACT'] = 'true' sm = StartupMessage(cqlversion=self.cql_version, options=opts) self.send_msg(sm, self.get_request_id(), cb=self._handle_startup_response) log.debug("Sent StartupMessage on %s", self) @defunct_on_error def _handle_startup_response(self, startup_response, did_authenticate=False): if self.is_defunct: return if isinstance(startup_response, ReadyMessage): if self.authenticator: log.warning("An authentication challenge was not sent, " "this is suspicious because the driver expects " "authentication (configured authenticator = %s)", self.authenticator.__class__.__name__) - log.debug("Got ReadyMessage on new connection (%s) from %s", id(self), self.host) + log.debug("Got ReadyMessage on new connection (%s) from %s", id(self), self.endpoint) if self._compressor: self.compressor = self._compressor self.connected_event.set() elif isinstance(startup_response, AuthenticateMessage): log.debug("Got AuthenticateMessage on new connection (%s) from %s: %s", - id(self), self.host, startup_response.authenticator) + id(self), self.endpoint, startup_response.authenticator) if self.authenticator is None: raise AuthenticationFailed('Remote end requires authentication.') if isinstance(self.authenticator, dict): log.debug("Sending credentials-based auth response on %s", self) cm = CredentialsMessage(creds=self.authenticator) callback = partial(self._handle_startup_response, did_authenticate=True) self.send_msg(cm, self.get_request_id(), cb=callback) else: log.debug("Sending SASL-based auth response on %s", self) self.authenticator.server_authenticator_class = startup_response.authenticator initial_response = self.authenticator.initial_response() initial_response = "" if initial_response is None else initial_response self.send_msg(AuthResponseMessage(initial_response), self.get_request_id(), self._handle_auth_response) elif isinstance(startup_response, ErrorMessage): log.debug("Received ErrorMessage on new connection (%s) from %s: %s", - id(self), self.host, startup_response.summary_msg()) + id(self), self.endpoint, startup_response.summary_msg()) if did_authenticate: raise AuthenticationFailed( "Failed to authenticate to %s: %s" % - (self.host, startup_response.summary_msg())) + (self.endpoint, startup_response.summary_msg())) else: raise ConnectionException( "Failed to 
initialize new connection to %s: %s" - % (self.host, startup_response.summary_msg())) + % (self.endpoint, startup_response.summary_msg())) elif isinstance(startup_response, ConnectionShutdown): - log.debug("Connection to %s was closed during the startup handshake", (self.host)) + log.debug("Connection to %s was closed during the startup handshake", (self.endpoint)) raise startup_response else: msg = "Unexpected response during Connection setup: %r" log.error(msg, startup_response) raise ProtocolError(msg % (startup_response,)) @defunct_on_error def _handle_auth_response(self, auth_response): if self.is_defunct: return if isinstance(auth_response, AuthSuccessMessage): log.debug("Connection %s successfully authenticated", self) self.authenticator.on_authentication_success(auth_response.token) if self._compressor: self.compressor = self._compressor self.connected_event.set() elif isinstance(auth_response, AuthChallengeMessage): response = self.authenticator.evaluate_challenge(auth_response.challenge) msg = AuthResponseMessage("" if response is None else response) log.debug("Responding to auth challenge on %s", self) self.send_msg(msg, self.get_request_id(), self._handle_auth_response) elif isinstance(auth_response, ErrorMessage): log.debug("Received ErrorMessage on new connection (%s) from %s: %s", - id(self), self.host, auth_response.summary_msg()) + id(self), self.endpoint, auth_response.summary_msg()) raise AuthenticationFailed( "Failed to authenticate to %s: %s" % - (self.host, auth_response.summary_msg())) + (self.endpoint, auth_response.summary_msg())) elif isinstance(auth_response, ConnectionShutdown): - log.debug("Connection to %s was closed during the authentication process", self.host) + log.debug("Connection to %s was closed during the authentication process", self.endpoint) raise auth_response else: msg = "Unexpected response during Connection authentication to %s: %r" - log.error(msg, self.host, auth_response) - raise ProtocolError(msg % (self.host, auth_response)) + log.error(msg, self.endpoint, auth_response) + raise ProtocolError(msg % (self.endpoint, auth_response)) def set_keyspace_blocking(self, keyspace): if not keyspace or keyspace == self.keyspace: return query = QueryMessage(query='USE "%s"' % (keyspace,), consistency_level=ConsistencyLevel.ONE) try: result = self.wait_for_response(query) except InvalidRequestException as ire: # the keyspace probably doesn't exist raise ire.to_exception() except Exception as exc: conn_exc = ConnectionException( - "Problem while setting keyspace: %r" % (exc,), self.host) + "Problem while setting keyspace: %r" % (exc,), self.endpoint) self.defunct(conn_exc) raise conn_exc if isinstance(result, ResultMessage): self.keyspace = keyspace else: conn_exc = ConnectionException( - "Problem while setting keyspace: %r" % (result,), self.host) + "Problem while setting keyspace: %r" % (result,), self.endpoint) self.defunct(conn_exc) raise conn_exc def set_keyspace_async(self, keyspace, callback): """ Use this in order to avoid deadlocking the event loop thread. When the operation completes, `callback` will be called with two arguments: this connection and an Exception if an error occurred, otherwise :const:`None`. This method will always increment :attr:`.in_flight` attribute, even if it doesn't need to make a request, just to maintain an ":attr:`.in_flight` is incremented" invariant. """ # Here we increment in_flight unconditionally, whether we need to issue # a request or not. 
This is bad, but allows callers -- specifically # _set_keyspace_for_all_conns -- to assume that we increment # self.in_flight during this call. This allows the passed callback to # safely call HostConnection{Pool,}.return_connection on this # Connection. # # We use a busy wait on the lock here because: # - we'll only spin if the connection is at max capacity, which is very # unlikely for a set_keyspace call # - it allows us to avoid signaling a condition every time a request completes while True: with self.lock: if self.in_flight < self.max_request_id: self.in_flight += 1 break time.sleep(0.001) if not keyspace or keyspace == self.keyspace: callback(self, None) return query = QueryMessage(query='USE "%s"' % (keyspace,), consistency_level=ConsistencyLevel.ONE) def process_result(result): if isinstance(result, ResultMessage): self.keyspace = keyspace callback(self, None) elif isinstance(result, InvalidRequestException): callback(self, result.to_exception()) else: callback(self, self.defunct(ConnectionException( - "Problem while setting keyspace: %r" % (result,), self.host))) + "Problem while setting keyspace: %r" % (result,), self.endpoint))) # We've incremented self.in_flight above, so we "have permission" to # acquire a new request id request_id = self.get_request_id() self.send_msg(query, request_id, process_result) @property def is_idle(self): return not self.msg_received def reset_idle(self): self.msg_received = False def __str__(self): status = "" if self.is_defunct: status = " (defunct)" elif self.is_closed: status = " (closed)" - return "<%s(%r) %s:%d%s>" % (self.__class__.__name__, id(self), self.host, self.port, status) + return "<%s(%r) %s%s>" % (self.__class__.__name__, id(self), self.endpoint, status) __repr__ = __str__ class ResponseWaiter(object): def __init__(self, connection, num_responses, fail_on_error): self.connection = connection self.pending = num_responses self.fail_on_error = fail_on_error self.error = None self.responses = [None] * num_responses self.event = Event() def got_response(self, response, index): with self.connection.lock: self.connection.in_flight -= 1 if isinstance(response, Exception): if hasattr(response, 'to_exception'): response = response.to_exception() if self.fail_on_error: self.error = response self.event.set() else: self.responses[index] = (False, response) else: if not self.fail_on_error: self.responses[index] = (True, response) else: self.responses[index] = response self.pending -= 1 if not self.pending: self.event.set() def deliver(self, timeout=None): """ If fail_on_error was set to False, a list of (success, response) tuples will be returned. If success is False, response will be an Exception. Otherwise, response will be the normal query response. If fail_on_error was left as True and one of the requests failed, the corresponding Exception will be raised. Otherwise, the normal response will be returned. 
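A minimal sketch of consuming both shapes through wait_for_responses (``conn`` and ``msgs`` are hypothetical)::

            results = conn.wait_for_responses(*msgs, timeout=10.0, fail_on_error=False)
            for success, response in results:
                if not success:
                    log.warning("Request failed: %r", response)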
""" self.event.wait(timeout) if self.error: raise self.error elif not self.event.is_set(): raise OperationTimedOut() else: return self.responses class HeartbeatFuture(object): def __init__(self, connection, owner): self._exception = None self._event = Event() self.connection = connection self.owner = owner log.debug("Sending options message heartbeat on idle connection (%s) %s", - id(connection), connection.host) + id(connection), connection.endpoint) with connection.lock: if connection.in_flight <= connection.max_request_id: connection.in_flight += 1 connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback) else: self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold") self._event.set() def wait(self, timeout): self._event.wait(timeout) if self._event.is_set(): if self._exception: raise self._exception else: - raise OperationTimedOut("Connection heartbeat timeout after %s seconds" % (timeout,), self.connection.host) + raise OperationTimedOut("Connection heartbeat timeout after %s seconds" % (timeout,), self.connection.endpoint) def _options_callback(self, response): if isinstance(response, SupportedMessage): log.debug("Received options response on connection (%s) from %s", - id(self.connection), self.connection.host) + id(self.connection), self.connection.endpoint) else: if isinstance(response, ConnectionException): self._exception = response else: self._exception = ConnectionException("Received unexpected response to OptionsMessage: %s" % (response,)) self._event.set() class ConnectionHeartbeat(Thread): def __init__(self, interval_sec, get_connection_holders, timeout): Thread.__init__(self, name="Connection heartbeat") self._interval = interval_sec self._timeout = timeout self._get_connection_holders = get_connection_holders self._shutdown_event = Event() self.daemon = True self.start() class ShutdownException(Exception): pass def run(self): self._shutdown_event.wait(self._interval) while not self._shutdown_event.is_set(): start_time = time.time() futures = [] failed_connections = [] try: for connections, owner in [(o.get_connections(), o) for o in self._get_connection_holders()]: for connection in connections: self._raise_if_stopped() if not (connection.is_defunct or connection.is_closed): if connection.is_idle: try: futures.append(HeartbeatFuture(connection, owner)) except Exception as e: log.warning("Failed sending heartbeat message on connection (%s) to %s", - id(connection), connection.host) + id(connection), connection.endpoint) failed_connections.append((connection, owner, e)) else: connection.reset_idle() else: log.debug("Cannot send heartbeat message on connection (%s) to %s", - id(connection), connection.host) + id(connection), connection.endpoint) # make sure the owner sees this defunt/closed connection owner.return_connection(connection) self._raise_if_stopped() # Wait max `self._timeout` seconds for all HeartbeatFutures to complete timeout = self._timeout start_time = time.time() for f in futures: self._raise_if_stopped() connection = f.connection try: f.wait(timeout) # TODO: move this, along with connection locks in pool, down into Connection with connection.lock: connection.in_flight -= 1 connection.reset_idle() except Exception as e: log.warning("Heartbeat failed for connection (%s) to %s", - id(connection), connection.host) + id(connection), connection.endpoint) failed_connections.append((f.connection, f.owner, e)) timeout = self._timeout - (time.time() - start_time) for connection, owner, 
exc in failed_connections: self._raise_if_stopped() if not connection.is_control_connection: # Only HostConnection supports shutdown_on_error owner.shutdown_on_error = True connection.defunct(exc) owner.return_connection(connection) except self.ShutdownException: pass except Exception: log.error("Failed connection heartbeat", exc_info=True) elapsed = time.time() - start_time self._shutdown_event.wait(max(self._interval - elapsed, 0.01)) def stop(self): self._shutdown_event.set() self.join() def _raise_if_stopped(self): if self._shutdown_event.is_set(): raise self.ShutdownException() class Timer(object): canceled = False def __init__(self, timeout, callback): self.end = time.time() + timeout self.callback = callback def __lt__(self, other): return self.end < other.end def cancel(self): self.canceled = True def finish(self, time_now): if self.canceled: return True if time_now >= self.end: self.callback() return True return False class TimerManager(object): def __init__(self): self._queue = [] self._new_timers = [] def add_timer(self, timer): """ called from client thread with a Timer object """ self._new_timers.append((timer.end, timer)) def service_timeouts(self): """ run callbacks on all expired timers Called from the event thread :return: next end time, or None """ queue = self._queue if self._new_timers: new_timers = self._new_timers while new_timers: heappush(queue, new_timers.pop()) if queue: now = time.time() while queue: try: timer = queue[0][1] if timer.finish(now): heappop(queue) else: return timer.end except Exception: log.exception("Exception while servicing timeout callback: ") @property def next_timeout(self): try: return self._queue[0][0] except IndexError: pass diff --git a/cassandra/cqlengine/columns.py b/cassandra/cqlengine/columns.py index 5808f60..4911612 100644 --- a/cassandra/cqlengine/columns.py +++ b/cassandra/cqlengine/columns.py @@ -1,1079 +1,1079 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy, copy from datetime import date, datetime, timedelta import logging import six from uuid import UUID as _UUID from cassandra import util from cassandra.cqltypes import SimpleDateType, _cqltypes, UserType from cassandra.cqlengine import ValidationError from cassandra.cqlengine.functions import get_total_seconds from cassandra.util import Duration as _Duration log = logging.getLogger(__name__) class BaseValueManager(object): def __init__(self, instance, column, value): self.instance = instance self.column = column self.value = value self.previous_value = None self.explicit = False @property def deleted(self): return self.column._val_is_null(self.value) and (self.explicit or not self.column._val_is_null(self.previous_value)) @property def changed(self): """ Indicates whether or not this value has changed. 
        :rtype: boolean
        """
        if self.explicit:
            return self.value != self.previous_value

        if isinstance(self.column, BaseContainerColumn):
            default_value = self.column.get_default()
            if self.column._val_is_null(default_value):
                return not self.column._val_is_null(self.value) and self.value != self.previous_value
            elif self.previous_value is None:
                return self.value != default_value
            return self.value != self.previous_value

        return False

    def reset_previous_value(self):
        self.previous_value = deepcopy(self.value)

    def getval(self):
        return self.value

    def setval(self, val):
        self.value = val
        self.explicit = True

    def delval(self):
        self.value = None

    def get_property(self):
        _get = lambda slf: self.getval()
        _set = lambda slf, val: self.setval(val)
        _del = lambda slf: self.delval()

        if self.column.can_delete:
            return property(_get, _set, _del)
        else:
            return property(_get, _set)


class Column(object):

    # the cassandra type this column maps to
    db_type = None
    value_manager = BaseValueManager

    instance_counter = 0

    _python_type_hashable = True

    primary_key = False
    """
    bool flag, indicates this column is a primary key. The first primary key defined
    on a model is the partition key (unless partition keys are set), all others are cluster keys
    """

    partition_key = False
    """
    indicates that this column should be the partition key, defining
    more than one partition key column creates a compound partition key
    """

    index = False
    """
    bool flag, indicates an index should be created for this column
    """

    custom_index = False
    """
    bool flag, indicates an index is managed outside of cqlengine. This is useful
    if you want to do filter queries on fields that have custom indexes.
    """

    db_field = None
    """
    the fieldname this field will map to in the database
    """

    default = None
    """
    the default value, can be a value or a callable (no args)
    """

    required = False
    """
    boolean, is the field required? Model validation will raise an
    exception if required is set to True and there is a None value assigned
    """

    clustering_order = None
    """
    only applicable on clustering keys (primary keys that are not partition keys)
    determines the order that the clustering keys are sorted on disk
    """

    discriminator_column = False
    """
    boolean, if set to True, this column will be used for discriminating records
    of inherited models.

    Should only be set on a column of an abstract model being used for inheritance.

    There may only be one discriminator column per model. See :attr:`~.__discriminator_value__`
    for how to specify the value of this column on specialized models.
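    For example, a hypothetical specialized model pair (sketch, not part of this patch)::

        from cassandra.cqlengine import columns
        from cassandra.cqlengine.models import Model

        class Pet(Model):
            __table_name__ = 'pet'
            owner_id = columns.UUID(primary_key=True)
            pet_id = columns.UUID(primary_key=True)
            pet_type = columns.Text(discriminator_column=True)

        class Cat(Pet):
            __discriminator_value__ = 'cat'
            cuteness = columns.Float()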
""" static = False """ boolean, if set to True, this is a static column, with a single value per partition """ def __init__(self, primary_key=False, partition_key=False, index=False, db_field=None, default=None, required=False, clustering_order=None, discriminator_column=False, static=False, custom_index=False): self.partition_key = partition_key self.primary_key = partition_key or primary_key self.index = index self.custom_index = custom_index self.db_field = db_field self.default = default self.required = required self.clustering_order = clustering_order self.discriminator_column = discriminator_column # the column name in the model definition self.column_name = None self._partition_key_index = None self.static = static self.value = None # keep track of instantiation order self.position = Column.instance_counter Column.instance_counter += 1 def __ne__(self, other): if isinstance(other, Column): return self.position != other.position return NotImplemented def __eq__(self, other): if isinstance(other, Column): return self.position == other.position return NotImplemented def __lt__(self, other): if isinstance(other, Column): return self.position < other.position return NotImplemented def __le__(self, other): if isinstance(other, Column): return self.position <= other.position return NotImplemented def __gt__(self, other): if isinstance(other, Column): return self.position > other.position return NotImplemented def __ge__(self, other): if isinstance(other, Column): return self.position >= other.position return NotImplemented def __hash__(self): return id(self) def validate(self, value): """ Returns a cleaned and validated value. Raises a ValidationError if there's a problem """ if value is None: if self.required: raise ValidationError('{0} - None values are not allowed'.format(self.column_name or self.db_field)) return value def to_python(self, value): """ Converts data from the database into python values raises a ValidationError if the value can't be converted """ return value def to_database(self, value): """ Converts python value into database value """ return value @property def has_default(self): return self.default is not None @property def is_primary_key(self): return self.primary_key @property def can_delete(self): return not self.primary_key def get_default(self): if self.has_default: if callable(self.default): return self.default() else: return self.default def get_column_def(self): """ Returns a column definition for CQL table definition """ static = "static" if self.static else "" return '{0} {1} {2}'.format(self.cql, self.db_type, static) # TODO: make columns use cqltypes under the hood # until then, this bridges the gap in using types along with cassandra.metadata for CQL generation def cql_parameterized_type(self): return self.db_type def set_column_name(self, name): """ Sets the column name during document class construction This value will be ignored if db_field is set in __init__ """ self.column_name = name @property def db_field_name(self): """ Returns the name of the cql name of this column """ - return self.db_field or self.column_name + return self.db_field if self.db_field is not None else self.column_name @property def db_index_name(self): """ Returns the name of the cql index """ return 'index_{0}'.format(self.db_field_name) @property def has_index(self): return self.index or self.custom_index @property def cql(self): return self.get_cql() def get_cql(self): return '"{0}"'.format(self.db_field_name) def _val_is_null(self, val): """ determines if the given value 
equates to a null value for the given column type """ return val is None @property def sub_types(self): return [] @property def cql_type(self): return _cqltypes[self.db_type] class Blob(Column): """ Stores a raw binary value """ db_type = 'blob' def to_database(self, value): if not isinstance(value, (six.binary_type, bytearray)): raise Exception("expecting a binary, got a %s" % type(value)) val = super(Bytes, self).to_database(value) return bytearray(val) Bytes = Blob class Inet(Column): """ Stores an IP address in IPv4 or IPv6 format """ db_type = 'inet' class Text(Column): """ Stores a UTF-8 encoded string """ db_type = 'text' def __init__(self, min_length=None, max_length=None, **kwargs): """ :param int min_length: Sets the minimum length of this string, for validation purposes. Defaults to 1 if this is a ``required`` column. Otherwise, None. :param int max_length: Sets the maximum length of this string, for validation purposes. """ self.min_length = ( 1 if min_length is None and kwargs.get('required', False) else min_length) self.max_length = max_length if self.min_length is not None: if self.min_length < 0: raise ValueError( 'Minimum length is not allowed to be negative.') if self.max_length is not None: if self.max_length < 0: raise ValueError( 'Maximum length is not allowed to be negative.') if self.min_length is not None and self.max_length is not None: if self.max_length < self.min_length: raise ValueError( 'Maximum length must be greater or equal ' 'to minimum length.') super(Text, self).__init__(**kwargs) def validate(self, value): value = super(Text, self).validate(value) if not isinstance(value, (six.string_types, bytearray)) and value is not None: raise ValidationError('{0} {1} is not a string'.format(self.column_name, type(value))) if self.max_length is not None: if value and len(value) > self.max_length: raise ValidationError('{0} is longer than {1} characters'.format(self.column_name, self.max_length)) if self.min_length: if (self.min_length and not value) or len(value) < self.min_length: raise ValidationError('{0} is shorter than {1} characters'.format(self.column_name, self.min_length)) return value class Ascii(Text): """ Stores a US-ASCII character string """ db_type = 'ascii' def validate(self, value): """ Only allow ASCII and None values. Check against US-ASCII, a.k.a. 7-bit ASCII, a.k.a. ISO646-US, a.k.a. the Basic Latin block of the Unicode character set. Source: https://github.com/apache/cassandra/blob /3dcbe90e02440e6ee534f643c7603d50ca08482b/src/java/org/apache/cassandra /serializers/AsciiSerializer.java#L29 """ value = super(Ascii, self).validate(value) if value: charset = value if isinstance( value, (bytearray, )) else map(ord, value) if not set(range(128)).issuperset(charset): raise ValidationError( '{!r} is not an ASCII string.'.format(value)) return value class Integer(Column): """ Stores a 32-bit signed integer value """ db_type = 'int' def validate(self, value): val = super(Integer, self).validate(value) if val is None: return try: return int(val) except (TypeError, ValueError): raise ValidationError("{0} {1} can't be converted to integral value".format(self.column_name, value)) def to_python(self, value): return self.validate(value) def to_database(self, value): return self.validate(value) class TinyInt(Integer): """ Stores an 8-bit signed integer value .. versionadded:: 2.6.0 requires C* 2.2+ and protocol v4+ """ db_type = 'tinyint' class SmallInt(Integer): """ Stores a 16-bit signed integer value .. 
versionadded:: 2.6.0 requires C* 2.2+ and protocol v4+ """ db_type = 'smallint' class BigInt(Integer): """ Stores a 64-bit signed integer value """ db_type = 'bigint' class VarInt(Column): """ Stores an arbitrary-precision integer """ db_type = 'varint' def validate(self, value): val = super(VarInt, self).validate(value) if val is None: return try: return int(val) except (TypeError, ValueError): raise ValidationError( "{0} {1} can't be converted to integral value".format(self.column_name, value)) def to_python(self, value): return self.validate(value) def to_database(self, value): return self.validate(value) class CounterValueManager(BaseValueManager): def __init__(self, instance, column, value): super(CounterValueManager, self).__init__(instance, column, value) self.value = self.value or 0 self.previous_value = self.previous_value or 0 class Counter(Integer): """ Stores a counter that can be incremented and decremented """ db_type = 'counter' value_manager = CounterValueManager def __init__(self, index=False, db_field=None, required=False): super(Counter, self).__init__( primary_key=False, partition_key=False, index=index, db_field=db_field, default=0, required=required, ) class DateTime(Column): """ Stores a datetime value """ db_type = 'timestamp' truncate_microseconds = False """ Set this ``True`` to have model instances truncate the date, quantizing it in the same way it will be in the database. This allows equality comparison between assigned values and values read back from the database:: DateTime.truncate_microseconds = True assert Model.create(id=0, d=datetime.utcnow()) == Model.objects(id=0).first() Defaults to ``False`` to preserve legacy behavior. May change in the future. """ def to_python(self, value): if value is None: return if isinstance(value, datetime): if DateTime.truncate_microseconds: us = value.microsecond truncated_us = us // 1000 * 1000 return value - timedelta(microseconds=us - truncated_us) else: return value elif isinstance(value, date): return datetime(*(value.timetuple()[:6])) return datetime.utcfromtimestamp(value) def to_database(self, value): value = super(DateTime, self).to_database(value) if value is None: return if not isinstance(value, datetime): if isinstance(value, date): value = datetime(value.year, value.month, value.day) else: raise ValidationError("{0} '{1}' is not a datetime object".format(self.column_name, value)) epoch = datetime(1970, 1, 1, tzinfo=value.tzinfo) offset = get_total_seconds(epoch.tzinfo.utcoffset(epoch)) if epoch.tzinfo else 0 return int((get_total_seconds(value - epoch) - offset) * 1000) class Date(Column): """ Stores a simple date, with no time-of-day .. versionchanged:: 2.6.0 removed overload of Date and DateTime. DateTime is a drop-in replacement for legacy models requires C* 2.2+ and protocol v4+ """ db_type = 'date' def to_database(self, value): if value is None: return # need to translate to int version because some dates are not representable in # string form (datetime limitation) d = value if isinstance(value, util.Date) else util.Date(value) return d.days_from_epoch + SimpleDateType.EPOCH_OFFSET_DAYS def to_python(self, value): if value is None: return if isinstance(value, util.Date): return value if isinstance(value, datetime): value = value.date() return util.Date(value) class Time(Column): """ Stores a timezone-naive time-of-day, with nanosecond precision .. 
versionadded:: 2.6.0 requires C* 2.2+ and protocol v4+ """ db_type = 'time' def to_database(self, value): value = super(Time, self).to_database(value) if value is None: return # str(util.Time) yields desired CQL encoding return value if isinstance(value, util.Time) else util.Time(value) def to_python(self, value): value = super(Time, self).to_database(value) if value is None: return if isinstance(value, util.Time): return value return util.Time(value) class Duration(Column): """ Stores a duration (months, days, nanoseconds) .. versionadded:: 3.10.0 requires C* 3.10+ and protocol v4+ """ db_type = 'duration' def validate(self, value): val = super(Duration, self).validate(value) if val is None: return if not isinstance(val, _Duration): raise TypeError('{0} {1} is not a valid Duration.'.format(self.column_name, value)) return val class UUID(Column): """ Stores a type 1 or 4 UUID """ db_type = 'uuid' def validate(self, value): val = super(UUID, self).validate(value) if val is None: return if isinstance(val, _UUID): return val if isinstance(val, six.string_types): try: return _UUID(val) except ValueError: # fall-through to error pass raise ValidationError("{0} {1} is not a valid uuid".format( self.column_name, value)) def to_python(self, value): return self.validate(value) def to_database(self, value): return self.validate(value) class TimeUUID(UUID): """ UUID containing timestamp """ db_type = 'timeuuid' class Boolean(Column): """ Stores a boolean True or False value """ db_type = 'boolean' def validate(self, value): """ Always returns a Python boolean. """ value = super(Boolean, self).validate(value) if value is not None: value = bool(value) return value def to_python(self, value): return self.validate(value) class BaseFloat(Column): def validate(self, value): value = super(BaseFloat, self).validate(value) if value is None: return try: return float(value) except (TypeError, ValueError): raise ValidationError("{0} {1} is not a valid float".format(self.column_name, value)) def to_python(self, value): return self.validate(value) def to_database(self, value): return self.validate(value) class Float(BaseFloat): """ Stores a single-precision floating-point value """ db_type = 'float' class Double(BaseFloat): """ Stores a double-precision floating-point value """ db_type = 'double' class Decimal(Column): """ Stores a variable precision decimal value """ db_type = 'decimal' def validate(self, value): from decimal import Decimal as _Decimal from decimal import InvalidOperation val = super(Decimal, self).validate(value) if val is None: return try: return _Decimal(repr(val)) if isinstance(val, float) else _Decimal(val) except InvalidOperation: raise ValidationError("{0} '{1}' can't be coerced to decimal".format(self.column_name, val)) def to_python(self, value): return self.validate(value) def to_database(self, value): return self.validate(value) class BaseCollectionColumn(Column): """ Base Container type for collection-like columns. 
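    A hypothetical model combining scalar and collection columns (sketch, not part of this
    patch; the Set, List and Map columns are defined below)::

        from cassandra.cqlengine import columns
        from cassandra.cqlengine.models import Model

        class UserProfile(Model):
            __table_name__ = 'user_profile'
            user_id = columns.UUID(primary_key=True)
            emails = columns.Set(columns.Text)
            recent_logins = columns.List(columns.DateTime)
            settings = columns.Map(columns.Text, columns.Text)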
http://cassandra.apache.org/doc/cql3/CQL-3.0.html#collections """ def __init__(self, types, **kwargs): """ :param types: a sequence of sub types in this collection """ instances = [] for t in types: inheritance_comparator = issubclass if isinstance(t, type) else isinstance if not inheritance_comparator(t, Column): raise ValidationError("%s is not a column class" % (t,)) if t.db_type is None: raise ValidationError("%s is an abstract type" % (t,)) inst = t() if isinstance(t, type) else t if isinstance(t, BaseCollectionColumn): inst._freeze_db_type() instances.append(inst) self.types = instances super(BaseCollectionColumn, self).__init__(**kwargs) def validate(self, value): value = super(BaseCollectionColumn, self).validate(value) # It is dangerous to let collections have more than 65535. # See: https://issues.apache.org/jira/browse/CASSANDRA-5428 if value is not None and len(value) > 65535: raise ValidationError("{0} Collection can't have more than 65535 elements.".format(self.column_name)) return value def _val_is_null(self, val): return not val def _freeze_db_type(self): if not self.db_type.startswith('frozen'): self.db_type = "frozen<%s>" % (self.db_type,) @property def sub_types(self): return self.types @property def cql_type(self): return _cqltypes[self.__class__.__name__.lower()].apply_parameters([c.cql_type for c in self.types]) class Tuple(BaseCollectionColumn): """ Stores a fixed-length set of positional values http://docs.datastax.com/en/cql/3.1/cql/cql_reference/tupleType.html """ def __init__(self, *args, **kwargs): """ :param args: column types representing tuple composition """ if not args: raise ValueError("Tuple must specify at least one inner type") super(Tuple, self).__init__(args, **kwargs) self.db_type = 'tuple<{0}>'.format(', '.join(typ.db_type for typ in self.types)) def validate(self, value): val = super(Tuple, self).validate(value) if val is None: return if len(val) > len(self.types): raise ValidationError("Value %r has more fields than tuple definition (%s)" % (val, ', '.join(t for t in self.types))) return tuple(t.validate(v) for t, v in zip(self.types, val)) def to_python(self, value): if value is None: return tuple() return tuple(t.to_python(v) for t, v in zip(self.types, value)) def to_database(self, value): if value is None: return return tuple(t.to_database(v) for t, v in zip(self.types, value)) class BaseContainerColumn(BaseCollectionColumn): pass class Set(BaseContainerColumn): """ Stores a set of unordered, unique values http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_set_t.html """ _python_type_hashable = False def __init__(self, value_type, strict=True, default=set, **kwargs): """ :param value_type: a column class indicating the types of the value :param strict: sets whether non set values will be coerced to set type on validation, or raise a validation error, defaults to True """ self.strict = strict super(Set, self).__init__((value_type,), default=default, **kwargs) self.value_col = self.types[0] if not self.value_col._python_type_hashable: raise ValidationError("Cannot create a Set with unhashable value type (see PYTHON-494)") self.db_type = 'set<{0}>'.format(self.value_col.db_type) def validate(self, value): val = super(Set, self).validate(value) if val is None: return types = (set, util.SortedSet) if self.strict else (set, util.SortedSet, list, tuple) if not isinstance(val, types): if self.strict: raise ValidationError('{0} {1} is not a set object'.format(self.column_name, val)) else: raise ValidationError('{0} {1} cannot be coerced to 
a set object'.format(self.column_name, val)) if None in val: raise ValidationError("{0} None not allowed in a set".format(self.column_name)) # TODO: stop doing this conversion because it doesn't support non-hashable collections as keys (cassandra does) # will need to start using the cassandra.util types in the next major rev (PYTHON-494) return set(self.value_col.validate(v) for v in val) def to_python(self, value): if value is None: return set() return set(self.value_col.to_python(v) for v in value) def to_database(self, value): if value is None: return None return set(self.value_col.to_database(v) for v in value) class List(BaseContainerColumn): """ Stores a list of ordered values http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_list_t.html """ _python_type_hashable = False def __init__(self, value_type, default=list, **kwargs): """ :param value_type: a column class indicating the types of the value """ super(List, self).__init__((value_type,), default=default, **kwargs) self.value_col = self.types[0] self.db_type = 'list<{0}>'.format(self.value_col.db_type) def validate(self, value): val = super(List, self).validate(value) if val is None: return if not isinstance(val, (set, list, tuple)): raise ValidationError('{0} {1} is not a list object'.format(self.column_name, val)) if None in val: raise ValidationError("{0} None is not allowed in a list".format(self.column_name)) return [self.value_col.validate(v) for v in val] def to_python(self, value): if value is None: return [] return [self.value_col.to_python(v) for v in value] def to_database(self, value): if value is None: return None return [self.value_col.to_database(v) for v in value] class Map(BaseContainerColumn): """ Stores a key -> value map (dictionary) - http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_map_t.html + https://docs.datastax.com/en/dse/6.7/cql/cql/cql_using/useMap.html """ _python_type_hashable = False def __init__(self, key_type, value_type, default=dict, **kwargs): """ :param key_type: a column class indicating the types of the key :param value_type: a column class indicating the types of the value """ super(Map, self).__init__((key_type, value_type), default=default, **kwargs) self.key_col = self.types[0] self.value_col = self.types[1] if not self.key_col._python_type_hashable: raise ValidationError("Cannot create a Map with unhashable key type (see PYTHON-494)") self.db_type = 'map<{0}, {1}>'.format(self.key_col.db_type, self.value_col.db_type) def validate(self, value): val = super(Map, self).validate(value) if val is None: return if not isinstance(val, (dict, util.OrderedMap)): raise ValidationError('{0} {1} is not a dict object'.format(self.column_name, val)) if None in val: raise ValidationError("{0} None is not allowed in a map".format(self.column_name)) # TODO: stop doing this conversion because it doesn't support non-hashable collections as keys (cassandra does) # will need to start using the cassandra.util types in the next major rev (PYTHON-494) return dict((self.key_col.validate(k), self.value_col.validate(v)) for k, v in val.items()) def to_python(self, value): if value is None: return {} if value is not None: return dict((self.key_col.to_python(k), self.value_col.to_python(v)) for k, v in value.items()) def to_database(self, value): if value is None: return None return dict((self.key_col.to_database(k), self.value_col.to_database(v)) for k, v in value.items()) class UDTValueManager(BaseValueManager): @property def changed(self): if self.explicit: return self.value != 
self.previous_value default_value = self.column.get_default() if not self.column._val_is_null(default_value): return self.value != default_value elif self.previous_value is None: return not self.column._val_is_null(self.value) and self.value.has_changed_fields() return False def reset_previous_value(self): if self.value is not None: self.value.reset_changed_fields() self.previous_value = copy(self.value) class UserDefinedType(Column): """ User Defined Type column http://www.datastax.com/documentation/cql/3.1/cql/cql_using/cqlUseUDT.html These columns are represented by a specialization of :class:`cassandra.cqlengine.usertype.UserType`. Please see :ref:`user_types` for examples and discussion. """ value_manager = UDTValueManager def __init__(self, user_type, **kwargs): """ :param type user_type: specifies the :class:`~.cqlengine.usertype.UserType` model of the column """ self.user_type = user_type self.db_type = "frozen<%s>" % user_type.type_name() super(UserDefinedType, self).__init__(**kwargs) @property def sub_types(self): return list(self.user_type._fields.values()) @property def cql_type(self): return UserType.make_udt_class(keyspace='', udt_name=self.user_type.type_name(), field_names=[c.db_field_name for c in self.user_type._fields.values()], field_types=[c.cql_type for c in self.user_type._fields.values()]) def validate(self, value): val = super(UserDefinedType, self).validate(value) if val is None: return val.validate() return val def to_python(self, value): if value is None: return copied_value = deepcopy(value) for name, field in self.user_type._fields.items(): if copied_value[name] is not None or isinstance(field, BaseContainerColumn): copied_value[name] = field.to_python(copied_value[name]) return copied_value def to_database(self, value): if value is None: return copied_value = deepcopy(value) for name, field in self.user_type._fields.items(): if copied_value[name] is not None or isinstance(field, BaseContainerColumn): copied_value[name] = field.to_database(copied_value[name]) return copied_value def resolve_udts(col_def, out_list): for col in col_def.sub_types: resolve_udts(col, out_list) if isinstance(col_def, UserDefinedType): out_list.append(col_def.user_type) class _PartitionKeysToken(Column): """ virtual column representing token of partition columns. Used by filter(pk__token=Token(...)) filters """ def __init__(self, model): - self.partition_columns = model._partition_keys.values() + self.partition_columns = list(model._partition_keys.values()) super(_PartitionKeysToken, self).__init__(partition_key=True) @property def db_field_name(self): return 'token({0})'.format(', '.join(['"{0}"'.format(c.db_field_name) for c in self.partition_columns])) diff --git a/cassandra/cqlengine/connection.py b/cassandra/cqlengine/connection.py index 88371e9..884e04e 100644 --- a/cassandra/cqlengine/connection.py +++ b/cassandra/cqlengine/connection.py @@ -1,379 +1,384 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
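# Illustrative sketch (not part of this patch): the typical entry points this module
# provides for the cqlengine mapper; host addresses, keyspace and model names below are
# placeholders.
#
#     from cassandra.cqlengine import connection
#     from cassandra.cqlengine.management import sync_table
#
#     connection.setup(['127.0.0.1'], default_keyspace='example_ks')
#     sync_table(ExampleModel)                       # hypothetical Model subclass
#     connection.unregister_connection('default')    # tear down when finished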
from collections import defaultdict import logging import six import threading -from cassandra.cluster import Cluster, _NOT_SET, NoHostAvailable, UserTypeDoesNotExist, ConsistencyLevel +from cassandra.cluster import Cluster, _ConfigMode, _NOT_SET, NoHostAvailable, UserTypeDoesNotExist, ConsistencyLevel from cassandra.query import SimpleStatement, dict_factory from cassandra.cqlengine import CQLEngineException from cassandra.cqlengine.statements import BaseCQLStatement log = logging.getLogger(__name__) NOT_SET = _NOT_SET # required for passing timeout to Session.execute cluster = None session = None # connections registry DEFAULT_CONNECTION = object() _connections = {} # Because type models may be registered before a connection is present, # and because sessions may be replaced, we must register UDTs here, in order # to have them registered when a new session is established. udt_by_keyspace = defaultdict(dict) def format_log_context(msg, connection=None, keyspace=None): """Format log message to add keyspace and connection context""" connection_info = connection or 'DEFAULT_CONNECTION' if keyspace: msg = '[Connection: {0}, Keyspace: {1}] {2}'.format(connection_info, keyspace, msg) else: msg = '[Connection: {0}] {1}'.format(connection_info, msg) return msg class UndefinedKeyspaceException(CQLEngineException): pass class Connection(object): """CQLEngine Connection""" name = None hosts = None consistency = None retry_connect = False lazy_connect = False lazy_connect_lock = None cluster_options = None cluster = None session = None def __init__(self, name, hosts, consistency=None, lazy_connect=False, retry_connect=False, cluster_options=None): self.hosts = hosts self.name = name self.consistency = consistency self.lazy_connect = lazy_connect self.retry_connect = retry_connect self.cluster_options = cluster_options if cluster_options else {} self.lazy_connect_lock = threading.RLock() @classmethod def from_session(cls, name, session): instance = cls(name=name, hosts=session.hosts) instance.cluster, instance.session = session.cluster, session instance.setup_session() return instance def setup(self): """Setup the connection""" global cluster, session if 'username' in self.cluster_options or 'password' in self.cluster_options: raise CQLEngineException("Username & Password are now handled by using the native driver's auth_provider") if self.lazy_connect: return self.cluster = Cluster(self.hosts, **self.cluster_options) try: self.session = self.cluster.connect() log.debug(format_log_context("connection initialized with internally created session", connection=self.name)) except NoHostAvailable: if self.retry_connect: log.warning(format_log_context("connect failed, setting up for re-attempt on first use", connection=self.name)) self.lazy_connect = True raise - if self.consistency is not None: - self.session.default_consistency_level = self.consistency - if DEFAULT_CONNECTION in _connections and _connections[DEFAULT_CONNECTION] == self: cluster = _connections[DEFAULT_CONNECTION].cluster session = _connections[DEFAULT_CONNECTION].session self.setup_session() def setup_session(self): - self.session.row_factory = dict_factory + if self.cluster._config_mode == _ConfigMode.PROFILES: + self.cluster.profile_manager.default.row_factory = dict_factory + if self.consistency is not None: + self.cluster.profile_manager.default.consistency_level = self.consistency + else: + self.session.row_factory = dict_factory + if self.consistency is not None: + self.session.default_consistency_level = self.consistency enc = 
self.session.encoder enc.mapping[tuple] = enc.cql_encode_tuple _register_known_types(self.session.cluster) def handle_lazy_connect(self): # if lazy_connect is False, it means the cluster is setup and ready # No need to acquire the lock if not self.lazy_connect: return with self.lazy_connect_lock: # lazy_connect might have been set to False by another thread while waiting the lock # In this case, do nothing. if self.lazy_connect: log.debug(format_log_context("Lazy connect enabled", connection=self.name)) self.lazy_connect = False self.setup() def register_connection(name, hosts=None, consistency=None, lazy_connect=False, retry_connect=False, cluster_options=None, default=False, session=None): """ Add a connection to the connection registry. ``hosts`` and ``session`` are mutually exclusive, and ``consistency``, ``lazy_connect``, ``retry_connect``, and ``cluster_options`` only work with ``hosts``. Using ``hosts`` will create a new :class:`cassandra.cluster.Cluster` and :class:`cassandra.cluster.Session`. :param list hosts: list of hosts, (``contact_points`` for :class:`cassandra.cluster.Cluster`). :param int consistency: The default :class:`~.ConsistencyLevel` for the registered connection's new session. Default is the same as :attr:`.Session.default_consistency_level`. For use with ``hosts`` only; will fail when used with ``session``. :param bool lazy_connect: True if should not connect until first use. For use with ``hosts`` only; will fail when used with ``session``. :param bool retry_connect: True if we should retry to connect even if there was a connection failure initially. For use with ``hosts`` only; will fail when used with ``session``. :param dict cluster_options: A dict of options to be used as keyword arguments to :class:`cassandra.cluster.Cluster`. For use with ``hosts`` only; will fail when used with ``session``. :param bool default: If True, set the new connection as the cqlengine default :param Session session: A :class:`cassandra.cluster.Session` to be used in the created connection. 
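    Example (sketch, not part of this patch; host addresses, connection names and
    ``my_session`` are placeholders)::

        from cassandra.cqlengine import connection

        # create a new Cluster/Session from contact points
        connection.register_connection('analytics', hosts=['10.0.0.1', '10.0.0.2'])

        # or reuse an existing Session and make it the cqlengine default
        connection.register_connection('from_session', session=my_session, default=True)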
""" if name in _connections: log.warning("Registering connection '{0}' when it already exists.".format(name)) if session is not None: invalid_config_args = (hosts is not None or consistency is not None or lazy_connect is not False or retry_connect is not False or cluster_options is not None) if invalid_config_args: raise CQLEngineException( "Session configuration arguments and 'session' argument are mutually exclusive" ) conn = Connection.from_session(name, session=session) - conn.setup_session() else: # use hosts argument - if consistency is None: - consistency = ConsistencyLevel.LOCAL_ONE conn = Connection( name, hosts=hosts, consistency=consistency, lazy_connect=lazy_connect, retry_connect=retry_connect, cluster_options=cluster_options ) conn.setup() _connections[name] = conn if default: set_default_connection(name) return conn def unregister_connection(name): global cluster, session if name not in _connections: return if DEFAULT_CONNECTION in _connections and _connections[name] == _connections[DEFAULT_CONNECTION]: del _connections[DEFAULT_CONNECTION] cluster = None session = None conn = _connections[name] if conn.cluster: conn.cluster.shutdown() del _connections[name] log.debug("Connection '{0}' has been removed from the registry.".format(name)) def set_default_connection(name): global cluster, session if name not in _connections: raise CQLEngineException("Connection '{0}' doesn't exist.".format(name)) log.debug("Connection '{0}' has been set as default.".format(name)) _connections[DEFAULT_CONNECTION] = _connections[name] cluster = _connections[name].cluster session = _connections[name].session def get_connection(name=None): if not name: name = DEFAULT_CONNECTION if name not in _connections: raise CQLEngineException("Connection name '{0}' doesn't exist in the registry.".format(name)) conn = _connections[name] conn.handle_lazy_connect() return conn def default(): """ Configures the default connection to localhost, using the driver defaults (except for row_factory) """ try: conn = get_connection() if conn.session: log.warning("configuring new default connection for cqlengine when one was already set") except: pass register_connection('default', hosts=None, default=True) log.debug("cqlengine connection initialized with default session to localhost") def set_session(s): """ Configures the default connection with a preexisting :class:`cassandra.cluster.Session` Note: the mapper presently requires a Session :attr:`~.row_factory` set to ``dict_factory``. 
This may be relaxed in the future """ try: conn = get_connection() except CQLEngineException: # no default connection set; initalize one register_connection('default', session=s, default=True) conn = get_connection() + else: + if conn.session: + log.warning("configuring new default session for cqlengine when one was already set") - if conn.session: - log.warning("configuring new default connection for cqlengine when one was already set") + if not any([ + s.cluster.profile_manager.default.row_factory is dict_factory and s.cluster._config_mode in [_ConfigMode.PROFILES, _ConfigMode.UNCOMMITTED], + s.row_factory is dict_factory and s.cluster._config_mode in [_ConfigMode.LEGACY, _ConfigMode.UNCOMMITTED], + ]): + raise CQLEngineException("Failed to initialize: row_factory must be 'dict_factory'") - if s.row_factory is not dict_factory: - raise CQLEngineException("Failed to initialize: 'Session.row_factory' must be 'dict_factory'.") conn.session = s conn.cluster = s.cluster # Set default keyspace from given session's keyspace if conn.session.keyspace: from cassandra.cqlengine import models models.DEFAULT_KEYSPACE = conn.session.keyspace conn.setup_session() log.debug("cqlengine default connection initialized with %s", s) def setup( hosts, default_keyspace, consistency=None, lazy_connect=False, retry_connect=False, **kwargs): """ Setup a the driver connection used by the mapper :param list hosts: list of hosts, (``contact_points`` for :class:`cassandra.cluster.Cluster`) :param str default_keyspace: The default keyspace to use :param int consistency: The global default :class:`~.ConsistencyLevel` - default is the same as :attr:`.Session.default_consistency_level` :param bool lazy_connect: True if should not connect until first use :param bool retry_connect: True if we should retry to connect even if there was a connection failure initially :param \*\*kwargs: Pass-through keyword arguments for :class:`cassandra.cluster.Cluster` """ from cassandra.cqlengine import models models.DEFAULT_KEYSPACE = default_keyspace register_connection('default', hosts=hosts, consistency=consistency, lazy_connect=lazy_connect, retry_connect=retry_connect, cluster_options=kwargs, default=True) def execute(query, params=None, consistency_level=None, timeout=NOT_SET, connection=None): conn = get_connection(connection) if not conn.session: raise CQLEngineException("It is required to setup() cqlengine before executing queries") if isinstance(query, SimpleStatement): pass # elif isinstance(query, BaseCQLStatement): params = query.get_context() query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size) elif isinstance(query, six.string_types): query = SimpleStatement(query, consistency_level=consistency_level) - log.debug(format_log_context(query.query_string, connection=connection)) + log.debug(format_log_context('Query: {}, Params: {}'.format(query.query_string, params), connection=connection)) result = conn.session.execute(query, params, timeout=timeout) return result def get_session(connection=None): conn = get_connection(connection) return conn.session def get_cluster(connection=None): conn = get_connection(connection) if not conn.cluster: raise CQLEngineException("%s.cluster is not configured. Call one of the setup or default functions first." 
% __name__) return conn.cluster def register_udt(keyspace, type_name, klass, connection=None): udt_by_keyspace[keyspace][type_name] = klass try: cluster = get_cluster(connection) except CQLEngineException: cluster = None if cluster: try: cluster.register_user_type(keyspace, type_name, klass) except UserTypeDoesNotExist: pass # new types are covered in management sync functions def _register_known_types(cluster): from cassandra.cqlengine import models for ks_name, name_type_map in udt_by_keyspace.items(): for type_name, klass in name_type_map.items(): try: cluster.register_user_type(ks_name or models.DEFAULT_KEYSPACE, type_name, klass) except UserTypeDoesNotExist: pass # new types are covered in management sync functions diff --git a/cassandra/cqlengine/statements.py b/cassandra/cqlengine/statements.py index 57aac56..c6ceb16 100644 --- a/cassandra/cqlengine/statements.py +++ b/cassandra/cqlengine/statements.py @@ -1,905 +1,908 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from datetime import datetime, timedelta import time import six from six.moves import filter from cassandra.query import FETCH_SIZE_UNSET from cassandra.cqlengine import columns from cassandra.cqlengine import UnicodeMixin from cassandra.cqlengine.functions import QueryValue from cassandra.cqlengine.operators import BaseWhereOperator, InOperator, EqualsOperator, IsNotNullOperator class StatementException(Exception): pass class ValueQuoter(UnicodeMixin): def __init__(self, value): self.value = value def __unicode__(self): from cassandra.encoder import cql_quote if isinstance(self.value, (list, tuple)): return '[' + ', '.join([cql_quote(v) for v in self.value]) + ']' elif isinstance(self.value, dict): return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k, v in self.value.items()]) + '}' elif isinstance(self.value, set): return '{' + ', '.join([cql_quote(v) for v in self.value]) + '}' return cql_quote(self.value) def __eq__(self, other): if isinstance(other, self.__class__): return self.value == other.value return False class InQuoter(ValueQuoter): def __unicode__(self): from cassandra.encoder import cql_quote return '(' + ', '.join([cql_quote(v) for v in self.value]) + ')' class BaseClause(UnicodeMixin): def __init__(self, field, value): self.field = field self.value = value self.context_id = None def __unicode__(self): raise NotImplementedError def __hash__(self): return hash(self.field) ^ hash(self.value) def __eq__(self, other): if isinstance(other, self.__class__): return self.field == other.field and self.value == other.value return False def __ne__(self, other): return not self.__eq__(other) def get_context_size(self): """ returns the number of entries this clause will add to the query context """ return 1 def set_context_id(self, i): """ sets the value placeholder that will be used in the query """ self.context_id = i def update_context(self, ctx): """ updates the query context with this clauses values """ assert isinstance(ctx, dict) ctx[str(self.context_id)] = self.value class 
WhereClause(BaseClause): """ a single where statement used in queries """ def __init__(self, field, operator, value, quote_field=True): """ :param field: :param operator: :param value: :param quote_field: hack to get the token function rendering properly :return: """ if not isinstance(operator, BaseWhereOperator): raise StatementException( "operator must be of type {0}, got {1}".format(BaseWhereOperator, type(operator)) ) super(WhereClause, self).__init__(field, value) self.operator = operator self.query_value = self.value if isinstance(self.value, QueryValue) else QueryValue(self.value) self.quote_field = quote_field def __unicode__(self): field = ('"{0}"' if self.quote_field else '{0}').format(self.field) return u'{0} {1} {2}'.format(field, self.operator, six.text_type(self.query_value)) def __hash__(self): return super(WhereClause, self).__hash__() ^ hash(self.operator) def __eq__(self, other): if super(WhereClause, self).__eq__(other): return self.operator.__class__ == other.operator.__class__ return False def get_context_size(self): return self.query_value.get_context_size() def set_context_id(self, i): super(WhereClause, self).set_context_id(i) self.query_value.set_context_id(i) def update_context(self, ctx): if isinstance(self.operator, InOperator): ctx[str(self.context_id)] = InQuoter(self.value) else: self.query_value.update_context(ctx) class IsNotNullClause(WhereClause): def __init__(self, field): super(IsNotNullClause, self).__init__(field, IsNotNullOperator(), '') def __unicode__(self): field = ('"{0}"' if self.quote_field else '{0}').format(self.field) return u'{0} {1}'.format(field, self.operator) def update_context(self, ctx): pass def get_context_size(self): return 0 # alias for convenience IsNotNull = IsNotNullClause class AssignmentClause(BaseClause): """ a single variable st statement """ def __unicode__(self): return u'"{0}" = %({1})s'.format(self.field, self.context_id) def insert_tuple(self): return self.field, self.context_id class ConditionalClause(BaseClause): """ A single variable iff statement """ def __unicode__(self): return u'"{0}" = %({1})s'.format(self.field, self.context_id) def insert_tuple(self): return self.field, self.context_id class ContainerUpdateTypeMapMeta(type): def __init__(cls, name, bases, dct): if not hasattr(cls, 'type_map'): cls.type_map = {} else: cls.type_map[cls.col_type] = cls super(ContainerUpdateTypeMapMeta, cls).__init__(name, bases, dct) @six.add_metaclass(ContainerUpdateTypeMapMeta) class ContainerUpdateClause(AssignmentClause): def __init__(self, field, value, operation=None, previous=None): super(ContainerUpdateClause, self).__init__(field, value) self.previous = previous self._assignments = None self._operation = operation self._analyzed = False def _analyze(self): raise NotImplementedError def get_context_size(self): raise NotImplementedError def update_context(self, ctx): raise NotImplementedError class SetUpdateClause(ContainerUpdateClause): """ updates a set collection """ col_type = columns.Set _additions = None _removals = None def __unicode__(self): qs = [] ctx_id = self.context_id if (self.previous is None and self._assignments is None and self._additions is None and self._removals is None): qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)] if self._assignments is not None: qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)] ctx_id += 1 if self._additions is not None: qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)] ctx_id += 1 if self._removals is not None: qs += ['"{0}" = "{0}" - 
%({1})s'.format(self.field, ctx_id)] return ', '.join(qs) def _analyze(self): """ works out the updates to be performed """ if self.value is None or self.value == self.previous: pass elif self._operation == "add": self._additions = self.value elif self._operation == "remove": self._removals = self.value elif self.previous is None: self._assignments = self.value else: # partial update time self._additions = (self.value - self.previous) or None self._removals = (self.previous - self.value) or None self._analyzed = True def get_context_size(self): if not self._analyzed: self._analyze() if (self.previous is None and not self._assignments and self._additions is None and self._removals is None): return 1 return int(bool(self._assignments)) + int(bool(self._additions)) + int(bool(self._removals)) def update_context(self, ctx): if not self._analyzed: self._analyze() ctx_id = self.context_id if (self.previous is None and self._assignments is None and self._additions is None and self._removals is None): ctx[str(ctx_id)] = set() if self._assignments is not None: ctx[str(ctx_id)] = self._assignments ctx_id += 1 if self._additions is not None: ctx[str(ctx_id)] = self._additions ctx_id += 1 if self._removals is not None: ctx[str(ctx_id)] = self._removals class ListUpdateClause(ContainerUpdateClause): """ updates a list collection """ col_type = columns.List _append = None _prepend = None def __unicode__(self): if not self._analyzed: self._analyze() qs = [] ctx_id = self.context_id if self._assignments is not None: qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)] ctx_id += 1 if self._prepend is not None: qs += ['"{0}" = %({1})s + "{0}"'.format(self.field, ctx_id)] ctx_id += 1 if self._append is not None: qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)] return ', '.join(qs) def get_context_size(self): if not self._analyzed: self._analyze() return int(self._assignments is not None) + int(bool(self._append)) + int(bool(self._prepend)) def update_context(self, ctx): if not self._analyzed: self._analyze() ctx_id = self.context_id if self._assignments is not None: ctx[str(ctx_id)] = self._assignments ctx_id += 1 if self._prepend is not None: ctx[str(ctx_id)] = self._prepend ctx_id += 1 if self._append is not None: ctx[str(ctx_id)] = self._append def _analyze(self): """ works out the updates to be performed """ if self.value is None or self.value == self.previous: pass elif self._operation == "append": self._append = self.value elif self._operation == "prepend": self._prepend = self.value elif self.previous is None: self._assignments = self.value elif len(self.value) < len(self.previous): # if elements have been removed, # rewrite the whole list self._assignments = self.value elif len(self.previous) == 0: # if we're updating from an empty # list, do a complete insert self._assignments = self.value else: # the max start idx we want to compare search_space = len(self.value) - max(0, len(self.previous) - 1) # the size of the sub lists we want to look at search_size = len(self.previous) for i in range(search_space): # slice boundary j = i + search_size sub = self.value[i:j] idx_cmp = lambda idx: self.previous[idx] == sub[idx] if idx_cmp(0) and idx_cmp(-1) and self.previous == sub: self._prepend = self.value[:i] or None self._append = self.value[j:] or None break # if both append and prepend are still None after looking # at both lists, an insert statement will be created if self._prepend is self._append is None: self._assignments = self.value self._analyzed = True class 
MapUpdateClause(ContainerUpdateClause): """ updates a map collection """ col_type = columns.Map _updates = None _removals = None def _analyze(self): if self._operation == "update": self._updates = self.value.keys() elif self._operation == "remove": self._removals = {v for v in self.value.keys()} else: if self.previous is None: self._updates = sorted([k for k, v in self.value.items()]) else: self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None self._analyzed = True def get_context_size(self): if self.is_assignment: return 1 return int((len(self._updates or []) * 2) + int(bool(self._removals))) def update_context(self, ctx): ctx_id = self.context_id if self.is_assignment: ctx[str(ctx_id)] = {} elif self._removals is not None: ctx[str(ctx_id)] = self._removals else: for key in self._updates or []: val = self.value.get(key) ctx[str(ctx_id)] = key ctx[str(ctx_id + 1)] = val ctx_id += 2 @property def is_assignment(self): if not self._analyzed: self._analyze() return self.previous is None and not self._updates and not self._removals def __unicode__(self): qs = [] ctx_id = self.context_id if self.is_assignment: qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)] elif self._removals is not None: qs += ['"{0}" = "{0}" - %({1})s'.format(self.field, ctx_id)] ctx_id += 1 else: for _ in self._updates or []: qs += ['"{0}"[%({1})s] = %({2})s'.format(self.field, ctx_id, ctx_id + 1)] ctx_id += 2 return ', '.join(qs) class CounterUpdateClause(AssignmentClause): col_type = columns.Counter def __init__(self, field, value, previous=None): super(CounterUpdateClause, self).__init__(field, value) self.previous = previous or 0 def get_context_size(self): return 1 def update_context(self, ctx): ctx[str(self.context_id)] = abs(self.value - self.previous) def __unicode__(self): delta = self.value - self.previous sign = '-' if delta < 0 else '+' return '"{0}" = "{0}" {1} %({2})s'.format(self.field, sign, self.context_id) class BaseDeleteClause(BaseClause): pass class FieldDeleteClause(BaseDeleteClause): """ deletes a field from a row """ def __init__(self, field): super(FieldDeleteClause, self).__init__(field, None) def __unicode__(self): return '"{0}"'.format(self.field) def update_context(self, ctx): pass def get_context_size(self): return 0 class MapDeleteClause(BaseDeleteClause): """ removes keys from a map """ def __init__(self, field, value, previous=None): super(MapDeleteClause, self).__init__(field, value) self.value = self.value or {} self.previous = previous or {} self._analyzed = False self._removals = None def _analyze(self): self._removals = sorted([k for k in self.previous if k not in self.value]) self._analyzed = True def update_context(self, ctx): if not self._analyzed: self._analyze() for idx, key in enumerate(self._removals): ctx[str(self.context_id + idx)] = key def get_context_size(self): if not self._analyzed: self._analyze() return len(self._removals) def __unicode__(self): if not self._analyzed: self._analyze() return ', '.join(['"{0}"[%({1})s]'.format(self.field, self.context_id + i) for i in range(len(self._removals))]) class BaseCQLStatement(UnicodeMixin): """ The base cql statement class """ def __init__(self, table, timestamp=None, where=None, fetch_size=None, conditionals=None): super(BaseCQLStatement, self).__init__() self.table = table self.context_id = 0 self.context_counter = self.context_id self.timestamp = timestamp self.fetch_size = fetch_size if fetch_size else FETCH_SIZE_UNSET self.where_clauses = [] for clause in where or []: 
self._add_where_clause(clause) self.conditionals = [] for conditional in conditionals or []: self.add_conditional_clause(conditional) def _update_part_key_values(self, field_index_map, clauses, parts): for clause in filter(lambda c: c.field in field_index_map, clauses): parts[field_index_map[clause.field]] = clause.value def partition_key_values(self, field_index_map): parts = [None] * len(field_index_map) self._update_part_key_values(field_index_map, (w for w in self.where_clauses if w.operator.__class__ == EqualsOperator), parts) return parts def add_where(self, column, operator, value, quote_field=True): value = column.to_database(value) clause = WhereClause(column.db_field_name, operator, value, quote_field) self._add_where_clause(clause) def _add_where_clause(self, clause): clause.set_context_id(self.context_counter) self.context_counter += clause.get_context_size() self.where_clauses.append(clause) def get_context(self): """ returns the context dict for this statement :rtype: dict """ ctx = {} for clause in self.where_clauses or []: clause.update_context(ctx) return ctx def add_conditional_clause(self, clause): """ Adds a iff clause to this statement :param clause: The clause that will be added to the iff statement :type clause: ConditionalClause """ clause.set_context_id(self.context_counter) self.context_counter += clause.get_context_size() self.conditionals.append(clause) def _get_conditionals(self): return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.conditionals])) def get_context_size(self): return len(self.get_context()) def update_context_id(self, i): self.context_id = i self.context_counter = self.context_id for clause in self.where_clauses: clause.set_context_id(self.context_counter) self.context_counter += clause.get_context_size() @property def timestamp_normalized(self): """ we're expecting self.timestamp to be either a long, int, a datetime, or a timedelta :return: """ if not self.timestamp: return None if isinstance(self.timestamp, six.integer_types): return self.timestamp if isinstance(self.timestamp, timedelta): tmp = datetime.now() + self.timestamp else: tmp = self.timestamp return int(time.mktime(tmp.timetuple()) * 1e+6 + tmp.microsecond) def __unicode__(self): raise NotImplementedError def __repr__(self): return self.__unicode__() @property def _where(self): return 'WHERE {0}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses])) class SelectStatement(BaseCQLStatement): """ a cql select statement """ def __init__(self, table, fields=None, count=False, where=None, order_by=None, limit=None, allow_filtering=False, distinct_fields=None, fetch_size=None): """ :param where :type where list of cqlengine.statements.WhereClause """ super(SelectStatement, self).__init__( table, where=where, fetch_size=fetch_size ) self.fields = [fields] if isinstance(fields, six.string_types) else (fields or []) self.distinct_fields = distinct_fields self.count = count self.order_by = [order_by] if isinstance(order_by, six.string_types) else order_by self.limit = limit self.allow_filtering = allow_filtering def __unicode__(self): qs = ['SELECT'] if self.distinct_fields: if self.count: qs += ['DISTINCT COUNT({0})'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))] else: qs += ['DISTINCT {0}'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))] elif self.count: qs += ['COUNT(*)'] else: qs += [', '.join(['"{0}"'.format(f) for f in self.fields]) if self.fields else '*'] qs += ['FROM', self.table] if self.where_clauses: qs += 
[self._where] if self.order_by and not self.count: qs += ['ORDER BY {0}'.format(', '.join(six.text_type(o) for o in self.order_by))] if self.limit: qs += ['LIMIT {0}'.format(self.limit)] if self.allow_filtering: qs += ['ALLOW FILTERING'] return ' '.join(qs) class AssignmentStatement(BaseCQLStatement): """ value assignment statements """ def __init__(self, table, assignments=None, where=None, ttl=None, timestamp=None, conditionals=None): super(AssignmentStatement, self).__init__( table, where=where, conditionals=conditionals ) self.ttl = ttl self.timestamp = timestamp # add assignments self.assignments = [] for assignment in assignments or []: self._add_assignment_clause(assignment) def update_context_id(self, i): super(AssignmentStatement, self).update_context_id(i) for assignment in self.assignments: assignment.set_context_id(self.context_counter) self.context_counter += assignment.get_context_size() def partition_key_values(self, field_index_map): parts = super(AssignmentStatement, self).partition_key_values(field_index_map) self._update_part_key_values(field_index_map, self.assignments, parts) return parts def add_assignment(self, column, value): value = column.to_database(value) clause = AssignmentClause(column.db_field_name, value) self._add_assignment_clause(clause) def _add_assignment_clause(self, clause): clause.set_context_id(self.context_counter) self.context_counter += clause.get_context_size() self.assignments.append(clause) @property def is_empty(self): return len(self.assignments) == 0 def get_context(self): ctx = super(AssignmentStatement, self).get_context() for clause in self.assignments: clause.update_context(ctx) return ctx class InsertStatement(AssignmentStatement): """ an cql insert statement """ def __init__(self, table, assignments=None, where=None, ttl=None, timestamp=None, if_not_exists=False): super(InsertStatement, self).__init__(table, assignments=assignments, where=where, ttl=ttl, timestamp=timestamp) self.if_not_exists = if_not_exists def __unicode__(self): qs = ['INSERT INTO {0}'.format(self.table)] # get column names and context placeholders fields = [a.insert_tuple() for a in self.assignments] columns, values = zip(*fields) qs += ["({0})".format(', '.join(['"{0}"'.format(c) for c in columns]))] qs += ['VALUES'] qs += ["({0})".format(', '.join(['%({0})s'.format(v) for v in values]))] if self.if_not_exists: qs += ["IF NOT EXISTS"] + using_options = [] if self.ttl: - qs += ["USING TTL {0}".format(self.ttl)] + using_options += ["TTL {}".format(self.ttl)] if self.timestamp: - qs += ["USING TIMESTAMP {0}".format(self.timestamp_normalized)] + using_options += ["TIMESTAMP {}".format(self.timestamp_normalized)] + if using_options: + qs += ["USING {}".format(" AND ".join(using_options))] return ' '.join(qs) class UpdateStatement(AssignmentStatement): """ an cql update select statement """ def __init__(self, table, assignments=None, where=None, ttl=None, timestamp=None, conditionals=None, if_exists=False): super(UpdateStatement, self). 
__init__(table, assignments=assignments, where=where, ttl=ttl, timestamp=timestamp, conditionals=conditionals) self.if_exists = if_exists def __unicode__(self): qs = ['UPDATE', self.table] using_options = [] if self.ttl: using_options += ["TTL {0}".format(self.ttl)] if self.timestamp: using_options += ["TIMESTAMP {0}".format(self.timestamp_normalized)] if using_options: qs += ["USING {0}".format(" AND ".join(using_options))] qs += ['SET'] qs += [', '.join([six.text_type(c) for c in self.assignments])] if self.where_clauses: qs += [self._where] if len(self.conditionals) > 0: qs += [self._get_conditionals()] if self.if_exists: qs += ["IF EXISTS"] return ' '.join(qs) def get_context(self): ctx = super(UpdateStatement, self).get_context() for clause in self.conditionals: clause.update_context(ctx) return ctx def update_context_id(self, i): super(UpdateStatement, self).update_context_id(i) for conditional in self.conditionals: conditional.set_context_id(self.context_counter) self.context_counter += conditional.get_context_size() def add_update(self, column, value, operation=None, previous=None): value = column.to_database(value) col_type = type(column) container_update_type = ContainerUpdateClause.type_map.get(col_type) if container_update_type: previous = column.to_database(previous) clause = container_update_type(column.db_field_name, value, operation, previous) elif col_type == columns.Counter: clause = CounterUpdateClause(column.db_field_name, value, previous) else: clause = AssignmentClause(column.db_field_name, value) if clause.get_context_size(): # this is to exclude map removals from updates. Can go away if we drop support for C* < 1.2.4 and remove two-phase updates self._add_assignment_clause(clause) class DeleteStatement(BaseCQLStatement): """ a cql delete statement """ def __init__(self, table, fields=None, where=None, timestamp=None, conditionals=None, if_exists=False): super(DeleteStatement, self).__init__( table, where=where, timestamp=timestamp, conditionals=conditionals ) self.fields = [] if isinstance(fields, six.string_types): fields = [fields] for field in fields or []: self.add_field(field) self.if_exists = if_exists def update_context_id(self, i): super(DeleteStatement, self).update_context_id(i) for field in self.fields: field.set_context_id(self.context_counter) self.context_counter += field.get_context_size() for t in self.conditionals: t.set_context_id(self.context_counter) self.context_counter += t.get_context_size() def get_context(self): ctx = super(DeleteStatement, self).get_context() for field in self.fields: field.update_context(ctx) for clause in self.conditionals: clause.update_context(ctx) return ctx def add_field(self, field): if isinstance(field, six.string_types): field = FieldDeleteClause(field) if not isinstance(field, BaseClause): raise StatementException("only instances of AssignmentClause can be added to statements") field.set_context_id(self.context_counter) self.context_counter += field.get_context_size() self.fields.append(field) def __unicode__(self): qs = ['DELETE'] if self.fields: qs += [', '.join(['{0}'.format(f) for f in self.fields])] qs += ['FROM', self.table] delete_option = [] if self.timestamp: delete_option += ["TIMESTAMP {0}".format(self.timestamp_normalized)] if delete_option: qs += [" USING {0} ".format(" AND ".join(delete_option))] if self.where_clauses: qs += [self._where] if self.conditionals: qs += [self._get_conditionals()] if self.if_exists: qs += ["IF EXISTS"] return ' '.join(qs) diff --git a/cassandra/cqltypes.py 
b/cassandra/cqltypes.py index 22dec9a..55bb022 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -1,1079 +1,1147 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Representation of Cassandra data types. These classes should make it simple for the library (and caller software) to deal with Cassandra-style Java class type names and CQL type specifiers, and convert between them cleanly. Parameterized types are fully supported in both flavors. Once you have the right Type object for the type you want, you can use it to serialize, deserialize, or retrieve the corresponding CQL or Cassandra type strings. """ # NOTE: # If/when the need arises for interpret types from CQL string literals in # different ways (for https://issues.apache.org/jira/browse/CASSANDRA-3799, # for example), these classes would be a good place to tack on # .from_cql_literal() and .as_cql_literal() classmethods (or whatever). from __future__ import absolute_import # to enable import io from stdlib +import ast from binascii import unhexlify import calendar from collections import namedtuple from decimal import Decimal import io import logging import re import socket import time import six from six.moves import range import sys from uuid import UUID import warnings if six.PY3: import ipaddress from cassandra.marshal import (int8_pack, int8_unpack, int16_pack, int16_unpack, uint16_pack, uint16_unpack, uint32_pack, uint32_unpack, int32_pack, int32_unpack, int64_pack, int64_unpack, float_pack, float_unpack, double_pack, double_unpack, varint_pack, varint_unpack, vints_pack, vints_unpack) from cassandra import util apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.' cassandra_empty_type = 'org.apache.cassandra.db.marshal.EmptyType' cql_empty_type = 'empty' log = logging.getLogger(__name__) if six.PY3: _number_types = frozenset((int, float)) long = int def _name_from_hex_string(encoded_name): bin_str = unhexlify(encoded_name) return bin_str.decode('ascii') else: _number_types = frozenset((int, long, float)) _name_from_hex_string = unhexlify def trim_if_startswith(s, prefix): if s.startswith(prefix): return s[len(prefix):] return s _casstypes = {} _cqltypes = {} cql_type_scanner = re.Scanner(( ('frozen', None), (r'[a-zA-Z0-9_]+', lambda s, t: t), (r'[\s,<>]', None), )) def cql_types_from_string(cql_type): return cql_type_scanner.scan(cql_type)[0] class CassandraTypeType(type): """ The CassandraType objects in this module will normally be used directly, rather than through instances of those types. They can be instantiated, of course, but the type information is what this driver mainly needs. This metaclass registers CassandraType classes in the global by-cassandra-typename and by-cql-typename registries, unless their class name starts with an underscore. 
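For illustration (a hypothetical class, not one shipped by this module), defining class ExampleType(_CassandraType): typename = 'example' would register the new class as _casstypes['ExampleType'] and _cqltypes['example'], since the class name does not start with an underscore and the typename does not start with the org.apache.cassandra.db.marshal. prefix.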
""" def __new__(metacls, name, bases, dct): dct.setdefault('cassname', name) cls = type.__new__(metacls, name, bases, dct) if not name.startswith('_'): _casstypes[name] = cls if not cls.typename.startswith(apache_cassandra_type_prefix): _cqltypes[cls.typename] = cls return cls casstype_scanner = re.Scanner(( (r'[()]', lambda s, t: t), (r'[a-zA-Z0-9_.:=>]+', lambda s, t: t), (r'[\s,]', None), )) +def cqltype_to_python(cql_string): + """ + Given a cql type string, creates a list that can be manipulated in python + Example: + int -> ['int'] + frozen> -> ['frozen', ['tuple', ['text', 'int']]] + """ + scanner = re.Scanner(( + (r'[a-zA-Z0-9_]+', lambda s, t: "'{}'".format(t)), + (r'<', lambda s, t: ', ['), + (r'>', lambda s, t: ']'), + (r'[, ]', lambda s, t: t), + (r'".*?"', lambda s, t: "'{}'".format(t)), + )) + + scanned_tokens = scanner.scan(cql_string)[0] + hierarchy = ast.literal_eval(''.join(scanned_tokens)) + return [hierarchy] if isinstance(hierarchy, str) else list(hierarchy) + + +def python_to_cqltype(types): + """ + Opposite of the `cql_to_python` function. Given a python list, creates a cql type string from the representation + Example: + ['int'] -> int + ['frozen', ['tuple', ['text', 'int']]] -> frozen> + """ + scanner = re.Scanner(( + (r"'[a-zA-Z0-9_]+'", lambda s, t: t[1:-1]), + (r'^\[', lambda s, t: None), + (r'\]$', lambda s, t: None), + (r',\s*\[', lambda s, t: '<'), + (r'\]', lambda s, t: '>'), + (r'[, ]', lambda s, t: t), + (r'\'".*?"\'', lambda s, t: t[1:-1]), + )) + + scanned_tokens = scanner.scan(repr(types))[0] + cql = ''.join(scanned_tokens).replace('\\\\', '\\') + return cql + + +def _strip_frozen_from_python(types): + """ + Given a python list representing a cql type, removes 'frozen' + Example: + ['frozen', ['tuple', ['text', 'int']]] -> ['tuple', ['text', 'int']] + """ + while 'frozen' in types: + index = types.index('frozen') + types = types[:index] + types[index + 1] + types[index + 2:] + new_types = [_strip_frozen_from_python(item) if isinstance(item, list) else item for item in types] + return new_types + + +def strip_frozen(cql): + """ + Given a cql type string, and removes frozen + Example: + frozen> -> tuple + """ + types = cqltype_to_python(cql) + types_without_frozen = _strip_frozen_from_python(types) + cql = python_to_cqltype(types_without_frozen) + return cql + + def lookup_casstype_simple(casstype): """ Given a Cassandra type name (either fully distinguished or not), hand back the CassandraType class responsible for it. If a name is not recognized, a custom _UnrecognizedType subclass will be created for it. This function does not handle complex types (so no type parameters-- nothing with parentheses). Use lookup_casstype() instead if you might need that. 
""" shortname = trim_if_startswith(casstype, apache_cassandra_type_prefix) try: typeclass = _casstypes[shortname] except KeyError: typeclass = mkUnrecognizedType(casstype) return typeclass def parse_casstype_args(typestring): tokens, remainder = casstype_scanner.scan(typestring) if remainder: raise ValueError("weird characters %r at end" % remainder) # use a stack of (types, names) lists args = [([], [])] for tok in tokens: if tok == '(': args.append(([], [])) elif tok == ')': types, names = args.pop() prev_types, prev_names = args[-1] prev_types[-1] = prev_types[-1].apply_parameters(types, names) else: types, names = args[-1] parts = re.split(':|=>', tok) tok = parts.pop() if parts: names.append(parts[0]) else: names.append(None) ctype = lookup_casstype_simple(tok) types.append(ctype) # return the first (outer) type, which will have all parameters applied return args[0][0][0] def lookup_casstype(casstype): """ Given a Cassandra type as a string (possibly including parameters), hand back the CassandraType class responsible for it. If a name is not recognized, a custom _UnrecognizedType subclass will be created for it. Example: >>> lookup_casstype('org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.Int32Type)') """ if isinstance(casstype, (CassandraType, CassandraTypeType)): return casstype try: return parse_casstype_args(casstype) except (ValueError, AssertionError, IndexError) as e: raise ValueError("Don't know how to parse type string %r: %s" % (casstype, e)) def is_reversed_casstype(data_type): return issubclass(data_type, ReversedType) class EmptyValue(object): """ See _CassandraType.support_empty_values """ def __str__(self): return "EMPTY" __repr__ = __str__ EMPTY = EmptyValue() @six.add_metaclass(CassandraTypeType) class _CassandraType(object): subtypes = () num_subtypes = 0 empty_binary_ok = False support_empty_values = False """ Back in the Thrift days, empty strings were used for "null" values of all types, including non-string types. For most users, an empty string value in an int column is the same as being null/not present, so the driver normally returns None in this case. (For string-like types, it *will* return an empty string by default instead of None.) To avoid this behavior, set this to :const:`True`. Instead of returning None for empty string values, the EMPTY singleton (an instance of EmptyValue) will be returned. """ def __repr__(self): return '<%s( %r )>' % (self.cql_parameterized_type(), self.val) @classmethod def from_binary(cls, byts, protocol_version): """ Deserialize a bytestring into a value. See the deserialize() method for more information. This method differs in that if None or the empty string is passed in, None may be returned. """ if byts is None: return None elif len(byts) == 0 and not cls.empty_binary_ok: return EMPTY if cls.support_empty_values else None return cls.deserialize(byts, protocol_version) @classmethod def to_binary(cls, val, protocol_version): """ Serialize a value into a bytestring. See the serialize() method for more information. This method differs in that if None is passed in, the result is the empty string. """ return b'' if val is None else cls.serialize(val, protocol_version) @staticmethod def deserialize(byts, protocol_version): """ Given a bytestring, deserialize into a value according to the protocol for this type. Note that this does not create a new instance of this class; it merely gives back a value that would be appropriate to go inside an instance of this class. 
""" return byts @staticmethod def serialize(val, protocol_version): """ Given a value appropriate for this class, serialize it according to the protocol for this type and return the corresponding bytestring. """ return val @classmethod def cass_parameterized_type_with(cls, subtypes, full=False): """ Return the name of this type as it would be expressed by Cassandra, optionally fully qualified. If subtypes is not None, it is expected to be a list of other CassandraType subclasses, and the output string includes the Cassandra names for those subclasses as well, as parameters to this one. Example: >>> LongType.cass_parameterized_type_with(()) 'LongType' >>> LongType.cass_parameterized_type_with((), full=True) 'org.apache.cassandra.db.marshal.LongType' >>> SetType.cass_parameterized_type_with([DecimalType], full=True) 'org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.DecimalType)' """ cname = cls.cassname if full and '.' not in cname: cname = apache_cassandra_type_prefix + cname if not subtypes: return cname sublist = ', '.join(styp.cass_parameterized_type(full=full) for styp in subtypes) return '%s(%s)' % (cname, sublist) @classmethod def apply_parameters(cls, subtypes, names=None): """ Given a set of other CassandraTypes, create a new subtype of this type using them as parameters. This is how composite types are constructed. >>> MapType.apply_parameters([DateType, BooleanType]) `subtypes` will be a sequence of CassandraTypes. If provided, `names` will be an equally long sequence of column names or Nones. """ if cls.num_subtypes != 'UNKNOWN' and len(subtypes) != cls.num_subtypes: raise ValueError("%s types require %d subtypes (%d given)" % (cls.typename, cls.num_subtypes, len(subtypes))) newname = cls.cass_parameterized_type_with(subtypes) if six.PY2 and isinstance(newname, unicode): newname = newname.encode('utf-8') return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname, 'fieldnames': names}) @classmethod def cql_parameterized_type(cls): """ Return a CQL type specifier for this type. If this type has parameters, they are included in standard CQL <> notation. """ if not cls.subtypes: return cls.typename return '%s<%s>' % (cls.typename, ', '.join(styp.cql_parameterized_type() for styp in cls.subtypes)) @classmethod def cass_parameterized_type(cls, full=False): """ Return a Cassandra type specifier for this type. If this type has parameters, they are included in the standard () notation. 
""" return cls.cass_parameterized_type_with(cls.subtypes, full=full) # it's initially named with a _ to avoid registering it as a real type, but # client programs may want to use the name still for isinstance(), etc CassandraType = _CassandraType class _UnrecognizedType(_CassandraType): num_subtypes = 'UNKNOWN' if six.PY3: def mkUnrecognizedType(casstypename): return CassandraTypeType(casstypename, (_UnrecognizedType,), {'typename': "'%s'" % casstypename}) else: def mkUnrecognizedType(casstypename): # noqa return CassandraTypeType(casstypename.encode('utf8'), (_UnrecognizedType,), {'typename': "'%s'" % casstypename}) class BytesType(_CassandraType): typename = 'blob' empty_binary_ok = True @staticmethod def serialize(val, protocol_version): return six.binary_type(val) class DecimalType(_CassandraType): typename = 'decimal' @staticmethod def deserialize(byts, protocol_version): scale = int32_unpack(byts[:4]) unscaled = varint_unpack(byts[4:]) return Decimal('%de%d' % (unscaled, -scale)) @staticmethod def serialize(dec, protocol_version): try: sign, digits, exponent = dec.as_tuple() except AttributeError: try: sign, digits, exponent = Decimal(dec).as_tuple() except Exception: raise TypeError("Invalid type for Decimal value: %r", dec) unscaled = int(''.join([str(digit) for digit in digits])) if sign: unscaled *= -1 scale = int32_pack(-exponent) unscaled = varint_pack(unscaled) return scale + unscaled class UUIDType(_CassandraType): typename = 'uuid' @staticmethod def deserialize(byts, protocol_version): return UUID(bytes=byts) @staticmethod def serialize(uuid, protocol_version): try: return uuid.bytes except AttributeError: raise TypeError("Got a non-UUID object for a UUID value") class BooleanType(_CassandraType): typename = 'boolean' @staticmethod def deserialize(byts, protocol_version): return bool(int8_unpack(byts)) @staticmethod def serialize(truth, protocol_version): return int8_pack(truth) class ByteType(_CassandraType): typename = 'tinyint' @staticmethod def deserialize(byts, protocol_version): return int8_unpack(byts) @staticmethod def serialize(byts, protocol_version): return int8_pack(byts) if six.PY2: class AsciiType(_CassandraType): typename = 'ascii' empty_binary_ok = True else: class AsciiType(_CassandraType): typename = 'ascii' empty_binary_ok = True @staticmethod def deserialize(byts, protocol_version): return byts.decode('ascii') @staticmethod def serialize(var, protocol_version): try: return var.encode('ascii') except UnicodeDecodeError: return var class FloatType(_CassandraType): typename = 'float' @staticmethod def deserialize(byts, protocol_version): return float_unpack(byts) @staticmethod def serialize(byts, protocol_version): return float_pack(byts) class DoubleType(_CassandraType): typename = 'double' @staticmethod def deserialize(byts, protocol_version): return double_unpack(byts) @staticmethod def serialize(byts, protocol_version): return double_pack(byts) class LongType(_CassandraType): typename = 'bigint' @staticmethod def deserialize(byts, protocol_version): return int64_unpack(byts) @staticmethod def serialize(byts, protocol_version): return int64_pack(byts) class Int32Type(_CassandraType): typename = 'int' @staticmethod def deserialize(byts, protocol_version): return int32_unpack(byts) @staticmethod def serialize(byts, protocol_version): return int32_pack(byts) class IntegerType(_CassandraType): typename = 'varint' @staticmethod def deserialize(byts, protocol_version): return varint_unpack(byts) @staticmethod def serialize(byts, protocol_version): return 
varint_pack(byts) class InetAddressType(_CassandraType): typename = 'inet' @staticmethod def deserialize(byts, protocol_version): if len(byts) == 16: return util.inet_ntop(socket.AF_INET6, byts) else: # util.inet_pton could also handle, but this is faster # since we've already determined the AF return socket.inet_ntoa(byts) @staticmethod def serialize(addr, protocol_version): try: if ':' in addr: return util.inet_pton(socket.AF_INET6, addr) else: # util.inet_pton could also handle, but this is faster # since we've already determined the AF return socket.inet_aton(addr) except: if six.PY3 and isinstance(addr, (ipaddress.IPv4Address, ipaddress.IPv6Address)): return addr.packed raise ValueError("can't interpret %r as an inet address" % (addr,)) class CounterColumnType(LongType): typename = 'counter' cql_timestamp_formats = ( '%Y-%m-%d %H:%M', '%Y-%m-%d %H:%M:%S', '%Y-%m-%dT%H:%M', '%Y-%m-%dT%H:%M:%S', '%Y-%m-%d' ) _have_warned_about_timestamps = False class DateType(_CassandraType): typename = 'timestamp' @staticmethod def interpret_datestring(val): if val[-5] in ('+', '-'): offset = (int(val[-4:-2]) * 3600 + int(val[-2:]) * 60) * int(val[-5] + '1') val = val[:-5] else: offset = -time.timezone for tformat in cql_timestamp_formats: try: tval = time.strptime(val, tformat) except ValueError: continue # scale seconds to millis for the raw value return (calendar.timegm(tval) + offset) * 1e3 else: raise ValueError("can't interpret %r as a date" % (val,)) @staticmethod def deserialize(byts, protocol_version): timestamp = int64_unpack(byts) / 1000.0 return util.datetime_from_timestamp(timestamp) @staticmethod def serialize(v, protocol_version): try: # v is datetime timestamp_seconds = calendar.timegm(v.utctimetuple()) timestamp = timestamp_seconds * 1e3 + getattr(v, 'microsecond', 0) / 1e3 except AttributeError: try: timestamp = calendar.timegm(v.timetuple()) * 1e3 except AttributeError: # Ints and floats are valid timestamps too if type(v) not in _number_types: raise TypeError('DateType arguments must be a datetime, date, or timestamp') timestamp = v return int64_pack(long(timestamp)) class TimestampType(DateType): pass class TimeUUIDType(DateType): typename = 'timeuuid' def my_timestamp(self): return util.unix_time_from_uuid1(self.val) @staticmethod def deserialize(byts, protocol_version): return UUID(bytes=byts) @staticmethod def serialize(timeuuid, protocol_version): try: return timeuuid.bytes except AttributeError: raise TypeError("Got a non-UUID object for a UUID value") class SimpleDateType(_CassandraType): typename = 'date' date_format = "%Y-%m-%d" # Values of the 'date'` type are encoded as 32-bit unsigned integers # representing a number of days with epoch (January 1st, 1970) at the center of the # range (2^31). 
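# As an illustrative check of this encoding: 1970-01-01 serializes to the unsigned value 2**31 and 1970-01-02 to 2**31 + 1, while deserialize subtracts the same offset to recover a util.Date.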
EPOCH_OFFSET_DAYS = 2 ** 31 @staticmethod def deserialize(byts, protocol_version): days = uint32_unpack(byts) - SimpleDateType.EPOCH_OFFSET_DAYS return util.Date(days) @staticmethod def serialize(val, protocol_version): try: days = val.days_from_epoch except AttributeError: if isinstance(val, six.integer_types): # the DB wants offset int values, but util.Date init takes days from epoch # here we assume int values are offset, as they would appear in CQL # short circuit to avoid subtracting just to add offset return uint32_pack(val) days = util.Date(val).days_from_epoch return uint32_pack(days + SimpleDateType.EPOCH_OFFSET_DAYS) class ShortType(_CassandraType): typename = 'smallint' @staticmethod def deserialize(byts, protocol_version): return int16_unpack(byts) @staticmethod def serialize(byts, protocol_version): return int16_pack(byts) class TimeType(_CassandraType): typename = 'time' @staticmethod def deserialize(byts, protocol_version): return util.Time(int64_unpack(byts)) @staticmethod def serialize(val, protocol_version): try: nano = val.nanosecond_time except AttributeError: nano = util.Time(val).nanosecond_time return int64_pack(nano) class DurationType(_CassandraType): typename = 'duration' @staticmethod def deserialize(byts, protocol_version): months, days, nanoseconds = vints_unpack(byts) return util.Duration(months, days, nanoseconds) @staticmethod def serialize(duration, protocol_version): try: m, d, n = duration.months, duration.days, duration.nanoseconds except AttributeError: raise TypeError('DurationType arguments must be a Duration.') return vints_pack([m, d, n]) class UTF8Type(_CassandraType): typename = 'text' empty_binary_ok = True @staticmethod def deserialize(byts, protocol_version): return byts.decode('utf8') @staticmethod def serialize(ustr, protocol_version): try: return ustr.encode('utf-8') except UnicodeDecodeError: # already utf-8 return ustr class VarcharType(UTF8Type): typename = 'varchar' class _ParameterizedType(_CassandraType): num_subtypes = 'UNKNOWN' @classmethod def deserialize(cls, byts, protocol_version): if not cls.subtypes: raise NotImplementedError("can't deserialize unparameterized %s" % cls.typename) return cls.deserialize_safe(byts, protocol_version) @classmethod def serialize(cls, val, protocol_version): if not cls.subtypes: raise NotImplementedError("can't serialize unparameterized %s" % cls.typename) return cls.serialize_safe(val, protocol_version) class _SimpleParameterizedType(_ParameterizedType): @classmethod def deserialize_safe(cls, byts, protocol_version): subtype, = cls.subtypes if protocol_version >= 3: unpack = int32_unpack length = 4 else: unpack = uint16_unpack length = 2 numelements = unpack(byts[:length]) p = length result = [] inner_proto = max(3, protocol_version) for _ in range(numelements): itemlen = unpack(byts[p:p + length]) p += length item = byts[p:p + itemlen] p += itemlen result.append(subtype.from_binary(item, inner_proto)) return cls.adapter(result) @classmethod def serialize_safe(cls, items, protocol_version): if isinstance(items, six.string_types): raise TypeError("Received a string for a type that expects a sequence") subtype, = cls.subtypes pack = int32_pack if protocol_version >= 3 else uint16_pack buf = io.BytesIO() buf.write(pack(len(items))) inner_proto = max(3, protocol_version) for item in items: itembytes = subtype.to_binary(item, inner_proto) buf.write(pack(len(itembytes))) buf.write(itembytes) return buf.getvalue() class ListType(_SimpleParameterizedType): typename = 'list' num_subtypes = 1 adapter = list 
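# An illustrative sketch (not code from this module) of how a collection type
# built from the classes above round-trips, assuming protocol version 3:
#
#     int_list = ListType.apply_parameters([Int32Type])
#     data = int_list.serialize([1, 2, 3], 3)    # 4-byte element count, then length-prefixed items
#     int_list.deserialize(data, 3)              # -> [1, 2, 3]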
class SetType(_SimpleParameterizedType): typename = 'set' num_subtypes = 1 adapter = util.sortedset class MapType(_ParameterizedType): typename = 'map' num_subtypes = 2 @classmethod def deserialize_safe(cls, byts, protocol_version): key_type, value_type = cls.subtypes if protocol_version >= 3: unpack = int32_unpack length = 4 else: unpack = uint16_unpack length = 2 numelements = unpack(byts[:length]) p = length themap = util.OrderedMapSerializedKey(key_type, protocol_version) inner_proto = max(3, protocol_version) for _ in range(numelements): key_len = unpack(byts[p:p + length]) p += length keybytes = byts[p:p + key_len] p += key_len val_len = unpack(byts[p:p + length]) p += length valbytes = byts[p:p + val_len] p += val_len key = key_type.from_binary(keybytes, inner_proto) val = value_type.from_binary(valbytes, inner_proto) themap._insert_unchecked(key, keybytes, val) return themap @classmethod def serialize_safe(cls, themap, protocol_version): key_type, value_type = cls.subtypes pack = int32_pack if protocol_version >= 3 else uint16_pack buf = io.BytesIO() buf.write(pack(len(themap))) try: items = six.iteritems(themap) except AttributeError: raise TypeError("Got a non-map object for a map value") inner_proto = max(3, protocol_version) for key, val in items: keybytes = key_type.to_binary(key, inner_proto) valbytes = value_type.to_binary(val, inner_proto) buf.write(pack(len(keybytes))) buf.write(keybytes) buf.write(pack(len(valbytes))) buf.write(valbytes) return buf.getvalue() class TupleType(_ParameterizedType): typename = 'tuple' @classmethod def deserialize_safe(cls, byts, protocol_version): proto_version = max(3, protocol_version) p = 0 values = [] for col_type in cls.subtypes: if p == len(byts): break itemlen = int32_unpack(byts[p:p + 4]) p += 4 if itemlen >= 0: item = byts[p:p + itemlen] p += itemlen else: item = None # collections inside UDTs are always encoded with at least the # version 3 format values.append(col_type.from_binary(item, proto_version)) if len(values) < len(cls.subtypes): nones = [None] * (len(cls.subtypes) - len(values)) values = values + nones return tuple(values) @classmethod def serialize_safe(cls, val, protocol_version): if len(val) > len(cls.subtypes): raise ValueError("Expected %d items in a tuple, but got %d: %s" % (len(cls.subtypes), len(val), val)) proto_version = max(3, protocol_version) buf = io.BytesIO() for item, subtype in zip(val, cls.subtypes): if item is not None: packed_item = subtype.to_binary(item, proto_version) buf.write(int32_pack(len(packed_item))) buf.write(packed_item) else: buf.write(int32_pack(-1)) return buf.getvalue() @classmethod def cql_parameterized_type(cls): subtypes_string = ', '.join(sub.cql_parameterized_type() for sub in cls.subtypes) return 'frozen>' % (subtypes_string,) class UserType(TupleType): typename = "org.apache.cassandra.db.marshal.UserType" _cache = {} _module = sys.modules[__name__] @classmethod def make_udt_class(cls, keyspace, udt_name, field_names, field_types): assert len(field_names) == len(field_types) if six.PY2 and isinstance(udt_name, unicode): udt_name = udt_name.encode('utf-8') instance = cls._cache.get((keyspace, udt_name)) if not instance or instance.fieldnames != field_names or instance.subtypes != field_types: instance = type(udt_name, (cls,), {'subtypes': field_types, 'cassname': cls.cassname, 'typename': udt_name, 'fieldnames': field_names, 'keyspace': keyspace, 'mapped_class': None, 'tuple_type': cls._make_registered_udt_namedtuple(keyspace, udt_name, field_names)}) cls._cache[(keyspace, 
udt_name)] = instance return instance @classmethod def evict_udt_class(cls, keyspace, udt_name): if six.PY2 and isinstance(udt_name, unicode): udt_name = udt_name.encode('utf-8') try: del cls._cache[(keyspace, udt_name)] except KeyError: pass @classmethod def apply_parameters(cls, subtypes, names): keyspace = subtypes[0].cass_parameterized_type() # when parsed from cassandra type, the keyspace is created as an unrecognized cass type; This gets the name back udt_name = _name_from_hex_string(subtypes[1].cassname) field_names = tuple(_name_from_hex_string(encoded_name) for encoded_name in names[2:]) # using tuple here to match what comes into make_udt_class from other sources (for caching equality test) return cls.make_udt_class(keyspace, udt_name, field_names, tuple(subtypes[2:])) @classmethod def cql_parameterized_type(cls): return "frozen<%s>" % (cls.typename,) @classmethod def deserialize_safe(cls, byts, protocol_version): values = super(UserType, cls).deserialize_safe(byts, protocol_version) if cls.mapped_class: return cls.mapped_class(**dict(zip(cls.fieldnames, values))) elif cls.tuple_type: return cls.tuple_type(*values) else: return tuple(values) @classmethod def serialize_safe(cls, val, protocol_version): proto_version = max(3, protocol_version) buf = io.BytesIO() for i, (fieldname, subtype) in enumerate(zip(cls.fieldnames, cls.subtypes)): # first treat as a tuple, else by custom type try: item = val[i] except TypeError: item = getattr(val, fieldname) if item is not None: packed_item = subtype.to_binary(item, proto_version) buf.write(int32_pack(len(packed_item))) buf.write(packed_item) else: buf.write(int32_pack(-1)) return buf.getvalue() @classmethod def _make_registered_udt_namedtuple(cls, keyspace, name, field_names): # this is required to make the type resolvable via this module... # required when unregistered udts are pickled for use as keys in # util.OrderedMap t = cls._make_udt_tuple_type(name, field_names) if t: qualified_name = "%s_%s" % (keyspace, name) setattr(cls._module, qualified_name, t) return t @classmethod def _make_udt_tuple_type(cls, name, field_names): # fallback to positional named, then unnamed tuples # for CQL identifiers that aren't valid in Python, try: t = namedtuple(name, field_names) except ValueError: try: t = namedtuple(name, util._positional_rename_invalid_identifiers(field_names)) log.warning("could not create a namedtuple for '%s' because one or more " "field names are not valid Python identifiers (%s); " "returning positionally-named fields" % (name, field_names)) except ValueError: t = None log.warning("could not create a namedtuple for '%s' because the name is " "not a valid Python identifier; will return tuples in " "its place" % (name,)) return t class CompositeType(_ParameterizedType): typename = "org.apache.cassandra.db.marshal.CompositeType" @classmethod def cql_parameterized_type(cls): """ There is no CQL notation for Composites, so we override this. 
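As an illustrative example, a composite parameterized with UTF8Type and Int32Type renders as the quoted string 'org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UTF8Type, org.apache.cassandra.db.marshal.Int32Type)'.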
""" typestring = cls.cass_parameterized_type(full=True) return "'%s'" % (typestring,) @classmethod def deserialize_safe(cls, byts, protocol_version): result = [] for subtype in cls.subtypes: if not byts: # CompositeType can have missing elements at the end break element_length = uint16_unpack(byts[:2]) element = byts[2:2 + element_length] # skip element length, element, and the EOC (one byte) byts = byts[2 + element_length + 1:] result.append(subtype.from_binary(element, protocol_version)) return tuple(result) class DynamicCompositeType(_ParameterizedType): typename = "org.apache.cassandra.db.marshal.DynamicCompositeType" @classmethod def cql_parameterized_type(cls): sublist = ', '.join('%s=>%s' % (alias, typ.cass_parameterized_type(full=True)) for alias, typ in zip(cls.fieldnames, cls.subtypes)) return "'%s(%s)'" % (cls.typename, sublist) class ColumnToCollectionType(_ParameterizedType): """ This class only really exists so that we can cleanly evaluate types when Cassandra includes this. We don't actually need or want the extra information. """ typename = "org.apache.cassandra.db.marshal.ColumnToCollectionType" class ReversedType(_ParameterizedType): typename = "org.apache.cassandra.db.marshal.ReversedType" num_subtypes = 1 @classmethod def deserialize_safe(cls, byts, protocol_version): subtype, = cls.subtypes return subtype.from_binary(byts, protocol_version) @classmethod def serialize_safe(cls, val, protocol_version): subtype, = cls.subtypes return subtype.to_binary(val, protocol_version) class FrozenType(_ParameterizedType): typename = "frozen" num_subtypes = 1 @classmethod def deserialize_safe(cls, byts, protocol_version): subtype, = cls.subtypes return subtype.from_binary(byts, protocol_version) @classmethod def serialize_safe(cls, val, protocol_version): subtype, = cls.subtypes return subtype.to_binary(val, protocol_version) def is_counter_type(t): if isinstance(t, six.string_types): t = lookup_casstype(t) return issubclass(t, CounterColumnType) def cql_typename(casstypename): """ Translate a Cassandra-style type specifier (optionally-fully-distinguished Java class names for data types, along with optional parameters) into a CQL-style type specifier. >>> cql_typename('DateType') 'timestamp' >>> cql_typename('org.apache.cassandra.db.marshal.ListType(IntegerType)') 'list' """ return lookup_casstype(casstypename).cql_parameterized_type() diff --git a/cassandra/datastax/__init__.py b/cassandra/datastax/__init__.py new file mode 100644 index 0000000..2c9ca17 --- /dev/null +++ b/cassandra/datastax/__init__.py @@ -0,0 +1,13 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cassandra/datastax/cloud/__init__.py b/cassandra/datastax/cloud/__init__.py new file mode 100644 index 0000000..ed9435e --- /dev/null +++ b/cassandra/datastax/cloud/__init__.py @@ -0,0 +1,167 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +import json +import tempfile +import shutil +from six.moves.urllib.request import urlopen + +_HAS_SSL = True +try: + from ssl import SSLContext, PROTOCOL_TLSv1, CERT_REQUIRED +except: + _HAS_SSL = False + +from zipfile import ZipFile + +# 2.7 vs 3.x +try: + from zipfile import BadZipFile +except: + from zipfile import BadZipfile as BadZipFile + +from cassandra import DriverException + +log = logging.getLogger(__name__) + +__all__ = ['get_cloud_config'] + +PRODUCT_APOLLO = "DATASTAX_APOLLO" + + +class CloudConfig(object): + + username = None + password = None + host = None + port = None + keyspace = None + local_dc = None + ssl_context = None + + sni_host = None + sni_port = None + host_ids = None + + @classmethod + def from_dict(cls, d): + c = cls() + + c.port = d.get('port', None) + try: + c.port = int(d['port']) + except: + pass + + c.username = d.get('username', None) + c.password = d.get('password', None) + c.host = d.get('host', None) + c.keyspace = d.get('keyspace', None) + c.local_dc = d.get('localDC', None) + + return c + + +def get_cloud_config(cloud_config): + if not _HAS_SSL: + raise DriverException("A Python installation with SSL is required to connect to a cloud cluster.") + + if 'secure_connect_bundle' not in cloud_config: + raise ValueError("The cloud config doesn't have a secure_connect_bundle specified.") + + try: + config = read_cloud_config_from_zip(cloud_config) + except BadZipFile: + raise ValueError("Unable to open the zip file for the cloud config. Check your secure connect bundle.") + + return read_metadata_info(config, cloud_config) + + +def read_cloud_config_from_zip(cloud_config): + secure_bundle = cloud_config['secure_connect_bundle'] + with ZipFile(secure_bundle) as zipfile: + base_dir = os.path.dirname(secure_bundle) + tmp_dir = tempfile.mkdtemp(dir=base_dir) + try: + zipfile.extractall(path=tmp_dir) + return parse_cloud_config(os.path.join(tmp_dir, 'config.json'), cloud_config) + finally: + shutil.rmtree(tmp_dir) + + +def parse_cloud_config(path, cloud_config): + with open(path, 'r') as stream: + data = json.load(stream) + + config = CloudConfig.from_dict(data) + config_dir = os.path.dirname(path) + + if 'ssl_context' in cloud_config: + config.ssl_context = cloud_config['ssl_context'] + else: + # Load the ssl_context before we delete the temporary directory + ca_cert_location = os.path.join(config_dir, 'ca.crt') + cert_location = os.path.join(config_dir, 'cert') + key_location = os.path.join(config_dir, 'key') + config.ssl_context = _ssl_context_from_cert(ca_cert_location, cert_location, key_location) + + return config + + +def read_metadata_info(config, cloud_config): + url = "https://{}:{}/metadata".format(config.host, config.port) + timeout = cloud_config['connect_timeout'] if 'connect_timeout' in cloud_config else 5 + try: + response = urlopen(url, context=config.ssl_context, timeout=timeout) + except Exception as e: + log.exception(e) + raise DriverException("Unable to connect to the metadata service at %s" % url) + + if response.code != 200: + raise DriverException(("Error while fetching the metadata at: %s. 
" + "The service returned error code %d." % (url, response.code))) + return parse_metadata_info(config, response.read().decode('utf-8')) + + +def parse_metadata_info(config, http_data): + try: + data = json.loads(http_data) + except: + msg = "Failed to load cluster metadata" + raise DriverException(msg) + + contact_info = data['contact_info'] + config.local_dc = contact_info['local_dc'] + + proxy_info = contact_info['sni_proxy_address'].split(':') + config.sni_host = proxy_info[0] + try: + config.sni_port = int(proxy_info[1]) + except: + config.sni_port = 9042 + + config.host_ids = [host_id for host_id in contact_info['contact_points']] + + return config + + +def _ssl_context_from_cert(ca_cert_location, cert_location, key_location): + ssl_context = SSLContext(PROTOCOL_TLSv1) + ssl_context.load_verify_locations(ca_cert_location) + ssl_context.verify_mode = CERT_REQUIRED + ssl_context.load_cert_chain(certfile=cert_location, keyfile=key_location) + + return ssl_context diff --git a/cassandra/io/asyncioreactor.py b/cassandra/io/asyncioreactor.py index f8835e8..b386388 100644 --- a/cassandra/io/asyncioreactor.py +++ b/cassandra/io/asyncioreactor.py @@ -1,215 +1,215 @@ from cassandra.connection import Connection, ConnectionShutdown import asyncio import logging import os import socket import ssl from threading import Lock, Thread, get_ident log = logging.getLogger(__name__) # This module uses ``yield from`` and ``@asyncio.coroutine`` over ``await`` and # ``async def`` for pre-Python-3.5 compatibility, so keep in mind that the # managed coroutines are generator-based, not native coroutines. See PEP 492: # https://www.python.org/dev/peps/pep-0492/#coroutine-objects try: asyncio.run_coroutine_threadsafe except AttributeError: raise ImportError( 'Cannot use asyncioreactor without access to ' 'asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)' ) class AsyncioTimer(object): """ An ``asyncioreactor``-specific Timer. Similar to :class:`.connection.Timer, but with a slightly different API due to limitations in the underlying ``call_later`` interface. Not meant to be used with a :class:`.connection.TimerManager`. """ @property def end(self): raise NotImplementedError('{} is not compatible with TimerManager and ' 'does not implement .end()') def __init__(self, timeout, callback, loop): delayed = self._call_delayed_coro(timeout=timeout, callback=callback, loop=loop) self._handle = asyncio.run_coroutine_threadsafe(delayed, loop=loop) @staticmethod @asyncio.coroutine def _call_delayed_coro(timeout, callback, loop): yield from asyncio.sleep(timeout, loop=loop) return callback() def __lt__(self, other): try: return self._handle < other._handle except AttributeError: raise NotImplemented def cancel(self): self._handle.cancel() def finish(self): # connection.Timer method not implemented here because we can't inspect # the Handle returned from call_later raise NotImplementedError('{} is not compatible with TimerManager and ' 'does not implement .finish()') class AsyncioConnection(Connection): """ An experimental implementation of :class:`.Connection` that uses the ``asyncio`` module in the Python standard library for its event loop. Note that it requires ``asyncio`` features that were only introduced in the 3.4 line in 3.4.6, and in the 3.5 line in 3.5.1. 
""" _loop = None _pid = os.getpid() _lock = Lock() _loop_thread = None _write_queue = None def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) self._connect_socket() self._socket.setblocking(0) self._write_queue = asyncio.Queue(loop=self._loop) # see initialize_reactor -- loop is running in a separate thread, so we # have to use a threadsafe call self._read_watcher = asyncio.run_coroutine_threadsafe( self.handle_read(), loop=self._loop ) self._write_watcher = asyncio.run_coroutine_threadsafe( self.handle_write(), loop=self._loop ) self._send_options_message() @classmethod def initialize_reactor(cls): with cls._lock: if cls._pid != os.getpid(): cls._loop = None if cls._loop is None: cls._loop = asyncio.new_event_loop() asyncio.set_event_loop(cls._loop) if not cls._loop_thread: # daemonize so the loop will be shut down on interpreter # shutdown cls._loop_thread = Thread(target=cls._loop.run_forever, daemon=True, name="asyncio_thread") cls._loop_thread.start() @classmethod def create_timer(cls, timeout, callback): return AsyncioTimer(timeout, callback, loop=cls._loop) def close(self): with self.lock: if self.is_closed: return self.is_closed = True # close from the loop thread to avoid races when removing file # descriptors asyncio.run_coroutine_threadsafe( self._close(), loop=self._loop ) @asyncio.coroutine def _close(self): - log.debug("Closing connection (%s) to %s" % (id(self), self.host)) + log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint)) if self._write_watcher: self._write_watcher.cancel() if self._read_watcher: self._read_watcher.cancel() if self._socket: self._loop.remove_writer(self._socket.fileno()) self._loop.remove_reader(self._socket.fileno()) self._socket.close() - log.debug("Closed socket to %s" % (self.host,)) + log.debug("Closed socket to %s" % (self.endpoint,)) if not self.is_defunct: self.error_all_requests( - ConnectionShutdown("Connection to %s was closed" % self.host)) + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) # don't leave in-progress operations hanging self.connected_event.set() def push(self, data): buff_size = self.out_buffer_size if len(data) > buff_size: for i in range(0, len(data), buff_size): self._push_chunk(data[i:i + buff_size]) else: self._push_chunk(data) def _push_chunk(self, chunk): if self._loop_thread.ident != get_ident(): asyncio.run_coroutine_threadsafe( self._write_queue.put(chunk), loop=self._loop ) else: # avoid races/hangs by just scheduling this, not using threadsafe self._loop.create_task(self._write_queue.put(chunk)) @asyncio.coroutine def handle_write(self): while True: try: next_msg = yield from self._write_queue.get() if next_msg: yield from self._loop.sock_sendall(self._socket, next_msg) except socket.error as err: log.debug("Exception in send for %s: %s", self, err) self.defunct(err) return except asyncio.CancelledError: return @asyncio.coroutine def handle_read(self): while True: try: buf = yield from self._loop.sock_recv(self._socket, self.in_buffer_size) self._iobuf.write(buf) # sock_recv expects EWOULDBLOCK if socket provides no data, but # nonblocking ssl sockets raise these instead, so we handle them # ourselves by yielding to the event loop, where the socket will # get the reading/writing it "wants" before retrying except (ssl.SSLWantWriteError, ssl.SSLWantReadError): yield continue except socket.error as err: log.debug("Exception during socket recv for %s: %s", self, err) self.defunct(err) return # leave the read loop except asyncio.CancelledError: return if 
buf and self._iobuf.tell(): self.process_io_buffer() else: log.debug("Connection %s closed by server", self) self.close() return diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py index 91431aa..d3dd0cf 100644 --- a/cassandra/io/asyncorereactor.py +++ b/cassandra/io/asyncorereactor.py @@ -1,461 +1,464 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import atexit from collections import deque from functools import partial import logging import os import socket import sys from threading import Lock, Thread, Event import time import weakref import sys from six.moves import range try: from weakref import WeakSet except ImportError: from cassandra.util import WeakSet # noqa import asyncore try: import ssl except ImportError: ssl = None # NOQA from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING, Timer, TimerManager log = logging.getLogger(__name__) _dispatcher_map = {} def _cleanup(loop): if loop: loop._cleanup() class WaitableTimer(Timer): def __init__(self, timeout, callback): Timer.__init__(self, timeout, callback) self.callback = callback self.event = Event() self.final_exception = None def finish(self, time_now): try: finished = Timer.finish(self, time_now) if finished: self.event.set() return True return False except Exception as e: self.final_exception = e self.event.set() return True def wait(self, timeout=None): self.event.wait(timeout) if self.final_exception: raise self.final_exception class _PipeWrapper(object): def __init__(self, fd): self.fd = fd def fileno(self): return self.fd def close(self): os.close(self.fd) def getsockopt(self, level, optname, buflen=None): # act like an unerrored socket for the asyncore error handling if level == socket.SOL_SOCKET and optname == socket.SO_ERROR and not buflen: return 0 raise NotImplementedError() class _AsyncoreDispatcher(asyncore.dispatcher): def __init__(self, socket): asyncore.dispatcher.__init__(self, map=_dispatcher_map) # inject after to avoid base class validation self.set_socket(socket) self._notified = False def writable(self): return False def validate(self): assert not self._notified self.notify_loop() assert self._notified self.loop(0.1) assert not self._notified def loop(self, timeout): asyncore.loop(timeout=timeout, use_poll=True, map=_dispatcher_map, count=1) class _AsyncorePipeDispatcher(_AsyncoreDispatcher): def __init__(self): self.read_fd, self.write_fd = os.pipe() _AsyncoreDispatcher.__init__(self, _PipeWrapper(self.read_fd)) def writable(self): return False def handle_read(self): while len(os.read(self.read_fd, 4096)) == 4096: pass self._notified = False def notify_loop(self): if not self._notified: self._notified = True os.write(self.write_fd, b'x') class _AsyncoreUDPDispatcher(_AsyncoreDispatcher): """ Experimental alternate dispatcher for avoiding busy wait in the asyncore loop. It is not used by default because it relies on local port binding. Port scanning is not implemented, so multiple clients on one host will collide. 
This address would need to be set per instance, or this could be specialized to scan until an address is found. To use:: from cassandra.io.asyncorereactor import _AsyncoreUDPDispatcher, AsyncoreLoop AsyncoreLoop._loop_dispatch_class = _AsyncoreUDPDispatcher """ bind_address = ('localhost', 10000) def __init__(self): self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self._socket.bind(self.bind_address) self._socket.setblocking(0) _AsyncoreDispatcher.__init__(self, self._socket) def handle_read(self): try: d = self._socket.recvfrom(1) while d and d[1]: d = self._socket.recvfrom(1) except socket.error as e: pass self._notified = False def notify_loop(self): if not self._notified: self._notified = True self._socket.sendto(b'', self.bind_address) def loop(self, timeout): asyncore.loop(timeout=timeout, use_poll=False, map=_dispatcher_map, count=1) class _BusyWaitDispatcher(object): max_write_latency = 0.001 """ Timeout pushed down to asyncore select/poll. Dictates the amount of time it will sleep before coming back to check if anything is writable. """ def notify_loop(self): pass def loop(self, timeout): if not _dispatcher_map: time.sleep(0.005) count = timeout // self.max_write_latency asyncore.loop(timeout=self.max_write_latency, use_poll=True, map=_dispatcher_map, count=count) def validate(self): pass def close(self): pass class AsyncoreLoop(object): timer_resolution = 0.1 # used as the max interval to be in the io loop before returning to service timeouts _loop_dispatch_class = _AsyncorePipeDispatcher if os.name != 'nt' else _BusyWaitDispatcher def __init__(self): self._pid = os.getpid() self._loop_lock = Lock() self._started = False self._shutdown = False self._thread = None self._timers = TimerManager() try: dispatcher = self._loop_dispatch_class() dispatcher.validate() log.debug("Validated loop dispatch with %s", self._loop_dispatch_class) except Exception: log.exception("Failed validating loop dispatch with %s. Using busy wait execution instead.", self._loop_dispatch_class) dispatcher.close() dispatcher = _BusyWaitDispatcher() self._loop_dispatcher = dispatcher def maybe_start(self): should_start = False did_acquire = False try: did_acquire = self._loop_lock.acquire(False) if did_acquire and not self._started: self._started = True should_start = True finally: if did_acquire: self._loop_lock.release() if should_start: self._thread = Thread(target=self._run_loop, name="asyncore_cassandra_driver_event_loop") self._thread.daemon = True self._thread.start() def wake_loop(self): self._loop_dispatcher.notify_loop() def _run_loop(self): log.debug("Starting asyncore event loop") with self._loop_lock: while not self._shutdown: try: self._loop_dispatcher.loop(self.timer_resolution) self._timers.service_timeouts() except Exception: log.debug("Asyncore event loop stopped unexepectedly", exc_info=True) break self._started = False log.debug("Asyncore event loop ended") def add_timer(self, timer): self._timers.add_timer(timer) # This function is called from a different thread than the event loop # thread, so for this call to be thread safe, we must wake up the loop # in case it's stuck at a select self.wake_loop() def _cleanup(self): global _dispatcher_map self._shutdown = True if not self._thread: return log.debug("Waiting for event loop thread to join...") self._thread.join(timeout=1.0) if self._thread.is_alive(): log.warning( "Event loop thread could not be joined, so shutdown may not be clean. 
" "Please call Cluster.shutdown() to avoid this.") log.debug("Event loop thread was joined") # Ensure all connections are closed and in-flight requests cancelled for conn in tuple(_dispatcher_map.values()): if conn is not self._loop_dispatcher: conn.close() self._timers.service_timeouts() # Once all the connections are closed, close the dispatcher self._loop_dispatcher.close() log.debug("Dispatchers were closed") _global_loop = None atexit.register(partial(_cleanup, _global_loop)) class AsyncoreConnection(Connection, asyncore.dispatcher): """ An implementation of :class:`.Connection` that uses the ``asyncore`` module in the Python standard library for its event loop. """ _writable = False _readable = False @classmethod def initialize_reactor(cls): global _global_loop if not _global_loop: _global_loop = AsyncoreLoop() else: current_pid = os.getpid() if _global_loop._pid != current_pid: log.debug("Detected fork, clearing and reinitializing reactor state") cls.handle_fork() _global_loop = AsyncoreLoop() @classmethod def handle_fork(cls): global _dispatcher_map, _global_loop _dispatcher_map = {} if _global_loop: _global_loop._cleanup() _global_loop = None @classmethod def create_timer(cls, timeout, callback): timer = Timer(timeout, callback) _global_loop.add_timer(timer) return timer def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) self.deque = deque() self.deque_lock = Lock() self._connect_socket() # start the event loop if needed _global_loop.maybe_start() init_handler = WaitableTimer( timeout=0, callback=partial(asyncore.dispatcher.__init__, self, self._socket, _dispatcher_map) ) _global_loop.add_timer(init_handler) init_handler.wait(kwargs["connect_timeout"]) self._writable = True self._readable = True self._send_options_message() def close(self): with self.lock: if self.is_closed: return self.is_closed = True - log.debug("Closing connection (%s) to %s", id(self), self.host) + log.debug("Closing connection (%s) to %s", id(self), self.endpoint) self._writable = False self._readable = False # We don't have to wait for this to be closed, we can just schedule it self.create_timer(0, partial(asyncore.dispatcher.close, self)) - log.debug("Closed socket to %s", self.host) + log.debug("Closed socket to %s", self.endpoint) if not self.is_defunct: self.error_all_requests( - ConnectionShutdown("Connection to %s was closed" % self.host)) + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) #This happens when the connection is shutdown while waiting for the ReadyMessage if not self.connected_event.is_set(): - self.last_error = ConnectionShutdown("Connection to %s was closed" % self.host) + self.last_error = ConnectionShutdown("Connection to %s was closed" % self.endpoint) # don't leave in-progress operations hanging self.connected_event.set() def handle_error(self): self.defunct(sys.exc_info()[1]) def handle_close(self): log.debug("Connection %s closed by server", self) self.close() def handle_write(self): while True: with self.deque_lock: try: next_msg = self.deque.popleft() except IndexError: self._writable = False return try: sent = self.send(next_msg) self._readable = True except socket.error as err: - if (err.args[0] in NONBLOCKING): + if (err.args[0] in NONBLOCKING or + err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE)): with self.deque_lock: self.deque.appendleft(next_msg) else: self.defunct(err) return else: if sent < len(next_msg): with self.deque_lock: self.deque.appendleft(next_msg[sent:]) if sent == 0: return def 
handle_read(self): try: while True: buf = self.recv(self.in_buffer_size) self._iobuf.write(buf) if len(buf) < self.in_buffer_size: break except socket.error as err: if ssl and isinstance(err, ssl.SSLError): if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): - return + if not self._iobuf.tell(): + return else: self.defunct(err) return elif err.args[0] in NONBLOCKING: - return + if not self._iobuf.tell(): + return else: self.defunct(err) return if self._iobuf.tell(): self.process_io_buffer() if not self._requests and not self.is_control_connection: self._readable = False def push(self, data): sabs = self.out_buffer_size if len(data) > sabs: chunks = [] for i in range(0, len(data), sabs): chunks.append(data[i:i + sabs]) else: chunks = [data] with self.deque_lock: self.deque.extend(chunks) self._writable = True _global_loop.wake_loop() def writable(self): return self._writable def readable(self): return self._readable or (self.is_control_connection and not (self.is_defunct or self.is_closed)) diff --git a/cassandra/io/eventletreactor.py b/cassandra/io/eventletreactor.py index bc01f75..2b16ef6 100644 --- a/cassandra/io/eventletreactor.py +++ b/cassandra/io/eventletreactor.py @@ -1,155 +1,155 @@ # Copyright 2014 Symantec Corporation # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Originally derived from MagnetoDB source: # https://github.com/stackforge/magnetodb/blob/2015.1.0b1/magnetodb/common/cassandra/io/eventletreactor.py import eventlet from eventlet.green import socket from eventlet.queue import Queue from greenlet import GreenletExit import logging from threading import Event import time from six.moves import xrange from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager log = logging.getLogger(__name__) class EventletConnection(Connection): """ An implementation of :class:`.Connection` that utilizes ``eventlet``. This implementation assumes all eventlet monkey patching is active. It is not tested with partial patching. """ _read_watcher = None _write_watcher = None _socket_impl = eventlet.green.socket _ssl_impl = eventlet.green.ssl _timers = None _timeout_watcher = None _new_timer = None @classmethod def initialize_reactor(cls): eventlet.monkey_patch() if not cls._timers: cls._timers = TimerManager() cls._timeout_watcher = eventlet.spawn(cls.service_timeouts) cls._new_timer = Event() @classmethod def create_timer(cls, timeout, callback): timer = Timer(timeout, callback) cls._timers.add_timer(timer) cls._new_timer.set() return timer @classmethod def service_timeouts(cls): """ cls._timeout_watcher runs in this loop forever. It is usually waiting for the next timeout on the cls._new_timer Event. When new timers are added, that event is set so that the watcher can wake up and possibly set an earlier timeout. 
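For instance (illustrative), a timer registered through create_timer(5.0, callback) is added to cls._timers and sets cls._new_timer, so this loop wakes, recomputes the next deadline, and then waits at most until that deadline (or 10000 seconds when no timers are pending).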
""" timer_manager = cls._timers while True: next_end = timer_manager.service_timeouts() sleep_time = max(next_end - time.time(), 0) if next_end else 10000 cls._new_timer.wait(sleep_time) cls._new_timer.clear() def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) self._write_queue = Queue() self._connect_socket() self._read_watcher = eventlet.spawn(lambda: self.handle_read()) self._write_watcher = eventlet.spawn(lambda: self.handle_write()) self._send_options_message() def close(self): with self.lock: if self.is_closed: return self.is_closed = True - log.debug("Closing connection (%s) to %s" % (id(self), self.host)) + log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint)) cur_gthread = eventlet.getcurrent() if self._read_watcher and self._read_watcher != cur_gthread: self._read_watcher.kill() if self._write_watcher and self._write_watcher != cur_gthread: self._write_watcher.kill() if self._socket: self._socket.close() - log.debug("Closed socket to %s" % (self.host,)) + log.debug("Closed socket to %s" % (self.endpoint,)) if not self.is_defunct: self.error_all_requests( - ConnectionShutdown("Connection to %s was closed" % self.host)) + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) # don't leave in-progress operations hanging self.connected_event.set() def handle_close(self): log.debug("connection closed by server") self.close() def handle_write(self): while True: try: next_msg = self._write_queue.get() self._socket.sendall(next_msg) except socket.error as err: log.debug("Exception during socket send for %s: %s", self, err) self.defunct(err) return # Leave the write loop except GreenletExit: # graceful greenthread exit return def handle_read(self): while True: try: buf = self._socket.recv(self.in_buffer_size) self._iobuf.write(buf) except socket.error as err: log.debug("Exception during socket recv for %s: %s", self, err) self.defunct(err) return # leave the read loop except GreenletExit: # graceful greenthread exit return if buf and self._iobuf.tell(): self.process_io_buffer() else: log.debug("Connection %s closed by server", self) self.close() return def push(self, data): chunk_size = self.out_buffer_size for i in xrange(0, len(data), chunk_size): self._write_queue.put(data[i:i + chunk_size]) diff --git a/cassandra/io/geventreactor.py b/cassandra/io/geventreactor.py index bbf9e83..ebc664d 100644 --- a/cassandra/io/geventreactor.py +++ b/cassandra/io/geventreactor.py @@ -1,138 +1,138 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import gevent import gevent.event from gevent.queue import Queue from gevent import socket import gevent.ssl import logging import time from six.moves import range from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager log = logging.getLogger(__name__) class GeventConnection(Connection): """ An implementation of :class:`.Connection` that utilizes ``gevent``. This implementation assumes all gevent monkey patching is active. 
It is not tested with partial patching. """ _read_watcher = None _write_watcher = None _socket_impl = gevent.socket _ssl_impl = gevent.ssl _timers = None _timeout_watcher = None _new_timer = None @classmethod def initialize_reactor(cls): if not cls._timers: cls._timers = TimerManager() cls._timeout_watcher = gevent.spawn(cls.service_timeouts) cls._new_timer = gevent.event.Event() @classmethod def create_timer(cls, timeout, callback): timer = Timer(timeout, callback) cls._timers.add_timer(timer) cls._new_timer.set() return timer @classmethod def service_timeouts(cls): timer_manager = cls._timers timer_event = cls._new_timer while True: next_end = timer_manager.service_timeouts() sleep_time = max(next_end - time.time(), 0) if next_end else 10000 timer_event.wait(sleep_time) timer_event.clear() def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) self._write_queue = Queue() self._connect_socket() self._read_watcher = gevent.spawn(self.handle_read) self._write_watcher = gevent.spawn(self.handle_write) self._send_options_message() def close(self): with self.lock: if self.is_closed: return self.is_closed = True - log.debug("Closing connection (%s) to %s" % (id(self), self.host)) + log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint)) if self._read_watcher: self._read_watcher.kill(block=False) if self._write_watcher: self._write_watcher.kill(block=False) if self._socket: self._socket.close() - log.debug("Closed socket to %s" % (self.host,)) + log.debug("Closed socket to %s" % (self.endpoint,)) if not self.is_defunct: self.error_all_requests( - ConnectionShutdown("Connection to %s was closed" % self.host)) + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) # don't leave in-progress operations hanging self.connected_event.set() def handle_close(self): log.debug("connection closed by server") self.close() def handle_write(self): while True: try: next_msg = self._write_queue.get() self._socket.sendall(next_msg) except socket.error as err: log.debug("Exception in send for %s: %s", self, err) self.defunct(err) return def handle_read(self): while True: try: buf = self._socket.recv(self.in_buffer_size) self._iobuf.write(buf) except socket.error as err: log.debug("Exception in read for %s: %s", self, err) self.defunct(err) return # leave the read loop if buf and self._iobuf.tell(): self.process_io_buffer() else: log.debug("Connection %s closed by server", self) self.close() return def push(self, data): chunk_size = self.out_buffer_size for i in range(0, len(data), chunk_size): self._write_queue.put(data[i:i + chunk_size]) diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py index 21111b0..7d4bf8e 100644 --- a/cassandra/io/libevreactor.py +++ b/cassandra/io/libevreactor.py @@ -1,375 +1,377 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
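# The eventlet and gevent connections above share one shape: a writer greenlet that
# drains a Queue with sendall(), a reader greenlet that loops on recv() and feeds the
# connection's buffer, and push() splitting outgoing data into out_buffer_size chunks.
# A minimal standalone sketch of that pattern with gevent (GreenletPipe and the
# handle_bytes callback are hypothetical names, not driver APIs):

import gevent
from gevent.queue import Queue

class GreenletPipe(object):
    def __init__(self, sock, handle_bytes, chunk_size=8192):
        self._sock = sock
        self._handle_bytes = handle_bytes   # invoked with each chunk that is read
        self._chunk_size = chunk_size
        self._write_queue = Queue()
        # one greenlet per direction, mirroring _read_watcher / _write_watcher
        self._reader = gevent.spawn(self._read_loop)
        self._writer = gevent.spawn(self._write_loop)

    def push(self, data):
        # chunk large payloads so no single sendall() call dominates the writer
        for i in range(0, len(data), self._chunk_size):
            self._write_queue.put(data[i:i + self._chunk_size])

    def _write_loop(self):
        while True:
            self._sock.sendall(self._write_queue.get())

    def _read_loop(self):
        while True:
            buf = self._sock.recv(self._chunk_size)
            if not buf:             # empty read means the peer closed the socket
                break
            self._handle_bytes(buf)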
import atexit from collections import deque from functools import partial import logging import os import socket import ssl from threading import Lock, Thread import time -import weakref from six.moves import range from cassandra.connection import (Connection, ConnectionShutdown, NONBLOCKING, Timer, TimerManager) try: import cassandra.io.libevwrapper as libev except ImportError: raise ImportError( "The C extension needed to use libev was not found. This " "probably means that you didn't have the required build dependencies " "when installing the driver. See " "http://datastax.github.io/python-driver/installation.html#c-extensions " "for instructions on installing build dependencies and building " "the C extension.") log = logging.getLogger(__name__) def _cleanup(loop): if loop: loop._cleanup() class LibevLoop(object): def __init__(self): self._pid = os.getpid() self._loop = libev.Loop() self._notifier = libev.Async(self._loop) self._notifier.start() # prevent _notifier from keeping the loop from returning self._loop.unref() self._started = False self._shutdown = False self._lock = Lock() self._lock_thread = Lock() self._thread = None # set of all connections; only replaced with a new copy # while holding _conn_set_lock, never modified in place self._live_conns = set() # newly created connections that need their write/read watcher started self._new_conns = set() # recently closed connections that need their write/read watcher stopped self._closed_conns = set() self._conn_set_lock = Lock() self._preparer = libev.Prepare(self._loop, self._loop_will_run) # prevent _preparer from keeping the loop from returning self._loop.unref() self._preparer.start() self._timers = TimerManager() self._loop_timer = libev.Timer(self._loop, self._on_loop_timer) def maybe_start(self): should_start = False with self._lock: if not self._started: log.debug("Starting libev event loop") self._started = True should_start = True if should_start: with self._lock_thread: if not self._shutdown: self._thread = Thread(target=self._run_loop, name="event_loop") self._thread.daemon = True self._thread.start() self._notifier.send() def _run_loop(self): while True: self._loop.start() # there are still active watchers, no deadlock with self._lock: if not self._shutdown and self._live_conns: log.debug("Restarting event loop") continue else: # all Connections have been closed, no active watchers log.debug("All Connections currently closed, event loop ended") self._started = False break def _cleanup(self): self._shutdown = True if not self._thread: return for conn in self._live_conns | self._new_conns | self._closed_conns: conn.close() for watcher in (conn._write_watcher, conn._read_watcher): if watcher: watcher.stop() self.notify() # wake the timer watcher # PYTHON-752 Thread might have just been created and not started with self._lock_thread: self._thread.join(timeout=1.0) if self._thread.is_alive(): log.warning( "Event loop thread could not be joined, so shutdown may not be clean. 
" "Please call Cluster.shutdown() to avoid this.") log.debug("Event loop thread was joined") def add_timer(self, timer): self._timers.add_timer(timer) self._notifier.send() # wake up in case this timer is earlier def _update_timer(self): if not self._shutdown: next_end = self._timers.service_timeouts() if next_end: self._loop_timer.start(next_end - time.time()) # timer handles negative values else: self._loop_timer.stop() def _on_loop_timer(self): self._timers.service_timeouts() def notify(self): self._notifier.send() def connection_created(self, conn): with self._conn_set_lock: new_live_conns = self._live_conns.copy() new_live_conns.add(conn) self._live_conns = new_live_conns new_new_conns = self._new_conns.copy() new_new_conns.add(conn) self._new_conns = new_new_conns def connection_destroyed(self, conn): with self._conn_set_lock: new_live_conns = self._live_conns.copy() new_live_conns.discard(conn) self._live_conns = new_live_conns new_closed_conns = self._closed_conns.copy() new_closed_conns.add(conn) self._closed_conns = new_closed_conns self._notifier.send() def _loop_will_run(self, prepare): changed = False for conn in self._live_conns: if not conn.deque and conn._write_watcher_is_active: if conn._write_watcher: conn._write_watcher.stop() conn._write_watcher_is_active = False changed = True elif conn.deque and not conn._write_watcher_is_active: conn._write_watcher.start() conn._write_watcher_is_active = True changed = True if self._new_conns: with self._conn_set_lock: to_start = self._new_conns self._new_conns = set() for conn in to_start: conn._read_watcher.start() changed = True if self._closed_conns: with self._conn_set_lock: to_stop = self._closed_conns self._closed_conns = set() for conn in to_stop: if conn._write_watcher: conn._write_watcher.stop() # clear reference cycles from IO callback del conn._write_watcher if conn._read_watcher: conn._read_watcher.stop() # clear reference cycles from IO callback del conn._read_watcher changed = True # TODO: update to do connection management, timer updates through dedicated async 'notifier' callbacks self._update_timer() if changed: self._notifier.send() _global_loop = None atexit.register(partial(_cleanup, _global_loop)) class LibevConnection(Connection): """ An implementation of :class:`.Connection` that uses libev for its event loop. 
""" _write_watcher_is_active = False _read_watcher = None _write_watcher = None _socket = None @classmethod def initialize_reactor(cls): global _global_loop if not _global_loop: _global_loop = LibevLoop() else: if _global_loop._pid != os.getpid(): log.debug("Detected fork, clearing and reinitializing reactor state") cls.handle_fork() _global_loop = LibevLoop() @classmethod def handle_fork(cls): global _global_loop if _global_loop: _global_loop._cleanup() _global_loop = None @classmethod def create_timer(cls, timeout, callback): timer = Timer(timeout, callback) _global_loop.add_timer(timer) return timer def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) self.deque = deque() self._deque_lock = Lock() self._connect_socket() self._socket.setblocking(0) with _global_loop._lock: self._read_watcher = libev.IO(self._socket.fileno(), libev.EV_READ, _global_loop._loop, self.handle_read) self._write_watcher = libev.IO(self._socket.fileno(), libev.EV_WRITE, _global_loop._loop, self.handle_write) self._send_options_message() _global_loop.connection_created(self) # start the global event loop if needed _global_loop.maybe_start() def close(self): with self.lock: if self.is_closed: return self.is_closed = True - log.debug("Closing connection (%s) to %s", id(self), self.host) + log.debug("Closing connection (%s) to %s", id(self), self.endpoint) _global_loop.connection_destroyed(self) self._socket.close() - log.debug("Closed socket to %s", self.host) + log.debug("Closed socket to %s", self.endpoint) # don't leave in-progress operations hanging if not self.is_defunct: self.error_all_requests( - ConnectionShutdown("Connection to %s was closed" % self.host)) + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) def handle_write(self, watcher, revents, errno=None): if revents & libev.EV_ERROR: if errno: exc = IOError(errno, os.strerror(errno)) else: exc = Exception("libev reported an error") self.defunct(exc) return while True: try: with self._deque_lock: next_msg = self.deque.popleft() except IndexError: return try: sent = self._socket.send(next_msg) except socket.error as err: - if (err.args[0] in NONBLOCKING): + if (err.args[0] in NONBLOCKING or + err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE)): with self._deque_lock: self.deque.appendleft(next_msg) else: self.defunct(err) return else: if sent < len(next_msg): with self._deque_lock: self.deque.appendleft(next_msg[sent:]) def handle_read(self, watcher, revents, errno=None): if revents & libev.EV_ERROR: if errno: exc = IOError(errno, os.strerror(errno)) else: exc = Exception("libev reported an error") self.defunct(exc) return try: while True: buf = self._socket.recv(self.in_buffer_size) self._iobuf.write(buf) if len(buf) < self.in_buffer_size: break except socket.error as err: if ssl and isinstance(err, ssl.SSLError): if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): - return + if not self._iobuf.tell(): + return else: self.defunct(err) return elif err.args[0] in NONBLOCKING: - return + if not self._iobuf.tell(): + return else: self.defunct(err) return if self._iobuf.tell(): self.process_io_buffer() else: log.debug("Connection %s closed by server", self) self.close() def push(self, data): sabs = self.out_buffer_size if len(data) > sabs: chunks = [] for i in range(0, len(data), sabs): chunks.append(data[i:i + sabs]) else: chunks = [data] with self._deque_lock: self.deque.extend(chunks) _global_loop.notify() diff --git a/cassandra/io/twistedreactor.py 
b/cassandra/io/twistedreactor.py index 3611cdf..1dbe9d8 100644 --- a/cassandra/io/twistedreactor.py +++ b/cassandra/io/twistedreactor.py @@ -1,317 +1,317 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Module that implements an event loop based on twisted ( https://twistedmatrix.com ). """ import atexit from functools import partial import logging from threading import Thread, Lock import time from twisted.internet import reactor, protocol import weakref from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager log = logging.getLogger(__name__) def _cleanup(cleanup_weakref): try: cleanup_weakref()._cleanup() except ReferenceError: return class TwistedConnectionProtocol(protocol.Protocol): """ Twisted Protocol class for handling data received and connection made events. """ def __init__(self): self.connection = None def dataReceived(self, data): """ Callback function that is called when data has been received on the connection. Reaches back to the Connection object and queues the data for processing. """ self.connection._iobuf.write(data) self.connection.handle_read() def connectionMade(self): """ Callback function that is called when a connection has succeeded. Reaches back to the Connection object and confirms that the connection is ready. """ try: # Non SSL connection self.connection = self.transport.connector.factory.conn except AttributeError: # SSL connection self.connection = self.transport.connector.factory.wrappedFactory.conn self.connection.client_connection_made(self.transport) def connectionLost(self, reason): # reason is a Failure instance self.connection.defunct(reason.value) class TwistedConnectionClientFactory(protocol.ClientFactory): def __init__(self, connection): # ClientFactory does not define __init__() in parent classes # and does not inherit from object. self.conn = connection def buildProtocol(self, addr): """ Twisted function that defines which kind of protocol to use in the ClientFactory. """ return TwistedConnectionProtocol() def clientConnectionFailed(self, connector, reason): """ Overridden twisted callback which is called when the connection attempt fails. """ log.debug("Connect failed: %s", reason) self.conn.defunct(reason.value) def clientConnectionLost(self, connector, reason): """ Overridden twisted callback which is called when the connection goes away (cleanly or otherwise). It should be safe to call defunct() here instead of just close, because we can assume that if the connection was closed cleanly, there are no requests to error out. If this assumption turns out to be false, we can call close() instead of defunct() when "reason" is an appropriate type. 
""" log.debug("Connect lost: %s", reason) self.conn.defunct(reason.value) class TwistedLoop(object): _lock = None _thread = None _timeout_task = None _timeout = None def __init__(self): self._lock = Lock() self._timers = TimerManager() def maybe_start(self): with self._lock: if not reactor.running: self._thread = Thread(target=reactor.run, name="cassandra_driver_twisted_event_loop", kwargs={'installSignalHandlers': False}) self._thread.daemon = True self._thread.start() atexit.register(partial(_cleanup, weakref.ref(self))) def _cleanup(self): if self._thread: reactor.callFromThread(reactor.stop) self._thread.join(timeout=1.0) if self._thread.is_alive(): log.warning("Event loop thread could not be joined, so " "shutdown may not be clean. Please call " "Cluster.shutdown() to avoid this.") log.debug("Event loop thread was joined") def add_timer(self, timer): self._timers.add_timer(timer) # callFromThread to schedule from the loop thread, where # the timeout task can safely be modified reactor.callFromThread(self._schedule_timeout, timer.end) def _schedule_timeout(self, next_timeout): if next_timeout: delay = max(next_timeout - time.time(), 0) if self._timeout_task and self._timeout_task.active(): if next_timeout < self._timeout: self._timeout_task.reset(delay) self._timeout = next_timeout else: self._timeout_task = reactor.callLater(delay, self._on_loop_timer) self._timeout = next_timeout def _on_loop_timer(self): self._timers.service_timeouts() self._schedule_timeout(self._timers.next_timeout) try: from twisted.internet import ssl import OpenSSL.crypto from OpenSSL.crypto import load_certificate, FILETYPE_PEM class _SSLContextFactory(ssl.ClientContextFactory): def __init__(self, ssl_options, check_hostname, host): self.ssl_options = ssl_options self.check_hostname = check_hostname self.host = host def getContext(self): # This version has to be OpenSSL.SSL.DESIRED_VERSION # instead of ssl.DESIRED_VERSION as in other loops self.method = self.ssl_options["ssl_version"] context = ssl.ClientContextFactory.getContext(self) if "certfile" in self.ssl_options: context.use_certificate_file(self.ssl_options["certfile"]) if "keyfile" in self.ssl_options: context.use_privatekey_file(self.ssl_options["keyfile"]) if "ca_certs" in self.ssl_options: x509 = load_certificate(FILETYPE_PEM, open(self.ssl_options["ca_certs"]).read()) store = context.get_cert_store() store.add_cert(x509) if "cert_reqs" in self.ssl_options: # This expects OpenSSL.SSL.VERIFY_NONE/OpenSSL.SSL.VERIFY_PEER # or OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT context.set_verify(self.ssl_options["cert_reqs"], callback=self.verify_callback) return context def verify_callback(self, connection, x509, errnum, errdepth, ok): if ok: if self.check_hostname and self.host != x509.get_subject().commonName: return False return ok _HAS_SSL = True except ImportError as e: _HAS_SSL = False class TwistedConnection(Connection): """ An implementation of :class:`.Connection` that utilizes the Twisted event loop. """ _loop = None @classmethod def initialize_reactor(cls): if not cls._loop: cls._loop = TwistedLoop() @classmethod def create_timer(cls, timeout, callback): timer = Timer(timeout, callback) cls._loop.add_timer(timer) return timer def __init__(self, *args, **kwargs): """ Initialization method. Note that we can't call reactor methods directly here because it's not thread-safe, so we schedule the reactor/connection stuff to be run from the event loop thread when it gets the chance. 
""" Connection.__init__(self, *args, **kwargs) self.is_closed = True self.connector = None self.transport = None reactor.callFromThread(self.add_connection) self._loop.maybe_start() def add_connection(self): """ Convenience function to connect and store the resulting connector. """ if self.ssl_options: if not _HAS_SSL: raise ImportError( str(e) + ', pyOpenSSL must be installed to enable SSL support with the Twisted event loop' ) self.connector = reactor.connectSSL( - host=self.host, port=self.port, + host=self.endpoint.address, port=self.port, factory=TwistedConnectionClientFactory(self), - contextFactory=_SSLContextFactory(self.ssl_options, self._check_hostname, self.host), + contextFactory=_SSLContextFactory(self.ssl_options, self._check_hostname, self.endpoint.address), timeout=self.connect_timeout) else: self.connector = reactor.connectTCP( - host=self.host, port=self.port, + host=self.endpoint.address, port=self.port, factory=TwistedConnectionClientFactory(self), timeout=self.connect_timeout) def client_connection_made(self, transport): """ Called by twisted protocol when a connection attempt has succeeded. """ with self.lock: self.is_closed = False self.transport = transport self._send_options_message() def close(self): """ Disconnect and error-out all requests. """ with self.lock: if self.is_closed: return self.is_closed = True - log.debug("Closing connection (%s) to %s", id(self), self.host) + log.debug("Closing connection (%s) to %s", id(self), self.endpoint) reactor.callFromThread(self.connector.disconnect) - log.debug("Closed socket to %s", self.host) + log.debug("Closed socket to %s", self.endpoint) if not self.is_defunct: self.error_all_requests( - ConnectionShutdown("Connection to %s was closed" % self.host)) + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) # don't leave in-progress operations hanging self.connected_event.set() def handle_read(self): """ Process the incoming data buffer. """ self.process_io_buffer() def push(self, data): """ This function is called when outgoing data should be queued for sending. Note that we can't call transport.write() directly because it is not thread-safe, so we schedule it to run from within the event loop when it gets the chance. """ reactor.callFromThread(self.transport.write, data) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 377ea4d..1824b3f 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -1,2821 +1,2845 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from binascii import unhexlify from bisect import bisect_left -from collections import defaultdict, Mapping +from collections import defaultdict from functools import total_ordering from hashlib import md5 from itertools import islice, cycle import json import logging import re import six from six.moves import zip import sys from threading import RLock import struct import random murmur3 = None try: from cassandra.murmur3 import murmur3 except ImportError as e: pass from cassandra import SignatureDescriptor, ConsistencyLevel, InvalidRequest, Unauthorized import cassandra.cqltypes as types from cassandra.encoder import Encoder from cassandra.marshal import varint_unpack from cassandra.protocol import QueryMessage from cassandra.query import dict_factory, bind_params -from cassandra.util import OrderedDict +from cassandra.util import OrderedDict, Version from cassandra.pool import HostDistance - +from cassandra.connection import EndPoint +from cassandra.compat import Mapping log = logging.getLogger(__name__) cql_keywords = set(( 'add', 'aggregate', 'all', 'allow', 'alter', 'and', 'apply', 'as', 'asc', 'ascii', 'authorize', 'batch', 'begin', 'bigint', 'blob', 'boolean', 'by', 'called', 'clustering', 'columnfamily', 'compact', 'contains', 'count', 'counter', 'create', 'custom', 'date', 'decimal', 'delete', 'desc', 'describe', 'distinct', 'double', 'drop', 'entries', 'execute', 'exists', 'filtering', 'finalfunc', 'float', 'from', 'frozen', 'full', 'function', 'functions', 'grant', 'if', 'in', 'index', 'inet', 'infinity', 'initcond', 'input', 'insert', 'int', 'into', 'is', 'json', 'key', 'keys', 'keyspace', 'keyspaces', 'language', 'limit', 'list', 'login', 'map', 'materialized', 'modify', 'nan', 'nologin', 'norecursive', 'nosuperuser', 'not', 'null', 'of', 'on', 'options', 'or', 'order', 'password', 'permission', 'permissions', 'primary', 'rename', 'replace', 'returns', 'revoke', 'role', 'roles', 'schema', 'select', 'set', 'sfunc', 'smallint', 'static', 'storage', 'stype', 'superuser', 'table', 'text', 'time', 'timestamp', 'timeuuid', 'tinyint', 'to', 'token', 'trigger', 'truncate', 'ttl', 'tuple', 'type', 'unlogged', 'update', 'use', 'user', 'users', 'using', 'uuid', 'values', 'varchar', 'varint', 'view', 'where', 'with', 'writetime' )) """ Set of keywords in CQL. Derived from .../cassandra/src/java/org/apache/cassandra/cql3/Cql.g """ cql_keywords_unreserved = set(( 'aggregate', 'all', 'as', 'ascii', 'bigint', 'blob', 'boolean', 'called', 'clustering', 'compact', 'contains', 'count', 'counter', 'custom', 'date', 'decimal', 'distinct', 'double', 'exists', 'filtering', 'finalfunc', 'float', 'frozen', 'function', 'functions', 'inet', 'initcond', 'input', 'int', 'json', 'key', 'keys', 'keyspaces', 'language', 'list', 'login', 'map', 'nologin', 'nosuperuser', 'options', 'password', 'permission', 'permissions', 'returns', 'role', 'roles', 'sfunc', 'smallint', 'static', 'storage', 'stype', 'superuser', 'text', 'time', 'timestamp', 'timeuuid', 'tinyint', 'trigger', 'ttl', 'tuple', 'type', 'user', 'users', 'uuid', 'values', 'varchar', 'varint', 'writetime' )) """ Set of unreserved keywords in CQL. Derived from .../cassandra/src/java/org/apache/cassandra/cql3/Cql.g """ cql_keywords_reserved = cql_keywords - cql_keywords_unreserved """ Set of reserved keywords in CQL. """ _encoder = Encoder() class Metadata(object): """ Holds a representation of the cluster schema and topology. """ cluster_name = None """ The string name of the cluster. 
""" keyspaces = None """ A map from keyspace names to matching :class:`~.KeyspaceMetadata` instances. """ partitioner = None """ The string name of the partitioner for the cluster. """ token_map = None """ A :class:`~.TokenMap` instance describing the ring topology. """ + dbaas = False + """ A boolean indicating if connected to a DBaaS cluster """ + def __init__(self): self.keyspaces = {} + self.dbaas = False self._hosts = {} self._hosts_lock = RLock() def export_schema_as_string(self): """ Returns a string that can be executed as a query in order to recreate the entire schema. The string is formatted to be human readable. """ return "\n\n".join(ks.export_as_string() for ks in self.keyspaces.values()) def refresh(self, connection, timeout, target_type=None, change_type=None, **kwargs): - server_version = self.get_host(connection.host).release_version + server_version = self.get_host(connection.endpoint).release_version parser = get_schema_parser(connection, server_version, timeout) if not target_type: self._rebuild_all(parser) return tt_lower = target_type.lower() try: parse_method = getattr(parser, 'get_' + tt_lower) meta = parse_method(self.keyspaces, **kwargs) if meta: update_method = getattr(self, '_update_' + tt_lower) if tt_lower == 'keyspace' and connection.protocol_version < 3: # we didn't have 'type' target in legacy protocol versions, so we need to query those too user_types = parser.get_types_map(self.keyspaces, **kwargs) self._update_keyspace(meta, user_types) else: update_method(meta) else: drop_method = getattr(self, '_drop_' + tt_lower) drop_method(**kwargs) except AttributeError: raise ValueError("Unknown schema target_type: '%s'" % target_type) def _rebuild_all(self, parser): current_keyspaces = set() for keyspace_meta in parser.get_all_keyspaces(): current_keyspaces.add(keyspace_meta.name) old_keyspace_meta = self.keyspaces.get(keyspace_meta.name, None) self.keyspaces[keyspace_meta.name] = keyspace_meta if old_keyspace_meta: self._keyspace_updated(keyspace_meta.name) else: self._keyspace_added(keyspace_meta.name) # remove not-just-added keyspaces removed_keyspaces = [name for name in self.keyspaces.keys() if name not in current_keyspaces] self.keyspaces = dict((name, meta) for name, meta in self.keyspaces.items() if name in current_keyspaces) for ksname in removed_keyspaces: self._keyspace_removed(ksname) def _update_keyspace(self, keyspace_meta, new_user_types=None): ks_name = keyspace_meta.name old_keyspace_meta = self.keyspaces.get(ks_name, None) self.keyspaces[ks_name] = keyspace_meta if old_keyspace_meta: keyspace_meta.tables = old_keyspace_meta.tables keyspace_meta.user_types = new_user_types if new_user_types is not None else old_keyspace_meta.user_types keyspace_meta.indexes = old_keyspace_meta.indexes keyspace_meta.functions = old_keyspace_meta.functions keyspace_meta.aggregates = old_keyspace_meta.aggregates keyspace_meta.views = old_keyspace_meta.views if (keyspace_meta.replication_strategy != old_keyspace_meta.replication_strategy): self._keyspace_updated(ks_name) else: self._keyspace_added(ks_name) def _drop_keyspace(self, keyspace): if self.keyspaces.pop(keyspace, None): self._keyspace_removed(keyspace) def _update_table(self, meta): try: keyspace_meta = self.keyspaces[meta.keyspace_name] # this is unfortunate, but protocol v4 does not differentiate # between events for tables and views. .get_table will # return one or the other based on the query results. # Here we deal with that. 
if isinstance(meta, TableMetadata): keyspace_meta._add_table_metadata(meta) else: keyspace_meta._add_view_metadata(meta) except KeyError: # can happen if keyspace disappears while processing async event pass def _drop_table(self, keyspace, table): try: keyspace_meta = self.keyspaces[keyspace] keyspace_meta._drop_table_metadata(table) # handles either table or view except KeyError: # can happen if keyspace disappears while processing async event pass def _update_type(self, type_meta): try: self.keyspaces[type_meta.keyspace].user_types[type_meta.name] = type_meta except KeyError: # can happen if keyspace disappears while processing async event pass def _drop_type(self, keyspace, type): try: self.keyspaces[keyspace].user_types.pop(type, None) except KeyError: # can happen if keyspace disappears while processing async event pass def _update_function(self, function_meta): try: self.keyspaces[function_meta.keyspace].functions[function_meta.signature] = function_meta except KeyError: # can happen if keyspace disappears while processing async event pass def _drop_function(self, keyspace, function): try: self.keyspaces[keyspace].functions.pop(function.signature, None) except KeyError: pass def _update_aggregate(self, aggregate_meta): try: self.keyspaces[aggregate_meta.keyspace].aggregates[aggregate_meta.signature] = aggregate_meta except KeyError: pass def _drop_aggregate(self, keyspace, aggregate): try: self.keyspaces[keyspace].aggregates.pop(aggregate.signature, None) except KeyError: pass def _keyspace_added(self, ksname): if self.token_map: self.token_map.rebuild_keyspace(ksname, build_if_absent=False) def _keyspace_updated(self, ksname): if self.token_map: self.token_map.rebuild_keyspace(ksname, build_if_absent=False) def _keyspace_removed(self, ksname): if self.token_map: self.token_map.remove_keyspace(ksname) def rebuild_token_map(self, partitioner, token_map): """ Rebuild our view of the topology from fresh rows from the system topology tables. For internal use only. """ self.partitioner = partitioner if partitioner.endswith('RandomPartitioner'): token_class = MD5Token elif partitioner.endswith('Murmur3Partitioner'): token_class = Murmur3Token elif partitioner.endswith('ByteOrderedPartitioner'): token_class = BytesToken else: self.token_map = None return token_to_host_owner = {} ring = [] for host, token_strings in six.iteritems(token_map): for token_string in token_strings: token = token_class.from_string(token_string) ring.append(token) token_to_host_owner[token] = host all_tokens = sorted(ring) self.token_map = TokenMap( token_class, token_to_host_owner, all_tokens, self) def get_replicas(self, keyspace, key): """ Returns a list of :class:`.Host` instances that are replicas for a given partition key. """ t = self.token_map if not t: return [] try: return t.get_replicas(keyspace, t.token_class.from_key(key)) except NoMurmur3: return [] def can_support_partitioner(self): if self.partitioner.endswith('Murmur3Partitioner') and murmur3 is None: return False else: return True def add_or_return_host(self, host): """ Returns a tuple (host, new), where ``host`` is a Host instance, and ``new`` is a bool indicating whether the host was newly added. 
""" with self._hosts_lock: try: - return self._hosts[host.address], False + return self._hosts[host.endpoint], False except KeyError: - self._hosts[host.address] = host + self._hosts[host.endpoint] = host return host, True def remove_host(self, host): with self._hosts_lock: - return bool(self._hosts.pop(host.address, False)) + return bool(self._hosts.pop(host.endpoint, False)) + + def get_host(self, endpoint_or_address): + """ + Find a host in the metadata for a specific endpoint. If a string inet address is passed, + iterate all hosts to match the :attr:`~.pool.Host.broadcast_rpc_address` attribute. + """ + if not isinstance(endpoint_or_address, EndPoint): + return self._get_host_by_address(endpoint_or_address) - def get_host(self, address): - return self._hosts.get(address) + return self._hosts.get(endpoint_or_address) + + def _get_host_by_address(self, address): + for host in six.itervalues(self._hosts): + if host.broadcast_rpc_address == address: + return host + return None def all_hosts(self): """ Returns a list of all known :class:`.Host` instances in the cluster. """ with self._hosts_lock: return list(self._hosts.values()) REPLICATION_STRATEGY_CLASS_PREFIX = "org.apache.cassandra.locator." def trim_if_startswith(s, prefix): if s.startswith(prefix): return s[len(prefix):] return s _replication_strategies = {} class ReplicationStrategyTypeType(type): def __new__(metacls, name, bases, dct): dct.setdefault('name', name) cls = type.__new__(metacls, name, bases, dct) if not name.startswith('_'): _replication_strategies[name] = cls return cls @six.add_metaclass(ReplicationStrategyTypeType) class _ReplicationStrategy(object): options_map = None @classmethod def create(cls, strategy_class, options_map): if not strategy_class: return None strategy_name = trim_if_startswith(strategy_class, REPLICATION_STRATEGY_CLASS_PREFIX) rs_class = _replication_strategies.get(strategy_name, None) if rs_class is None: rs_class = _UnknownStrategyBuilder(strategy_name) _replication_strategies[strategy_name] = rs_class try: rs_instance = rs_class(options_map) except Exception as exc: log.warning("Failed creating %s with options %s: %s", strategy_name, options_map, exc) return None return rs_instance def make_token_replica_map(self, token_to_host_owner, ring): raise NotImplementedError() def export_for_schema(self): raise NotImplementedError() ReplicationStrategy = _ReplicationStrategy class _UnknownStrategyBuilder(object): def __init__(self, name): self.name = name def __call__(self, options_map): strategy_instance = _UnknownStrategy(self.name, options_map) return strategy_instance class _UnknownStrategy(ReplicationStrategy): def __init__(self, name, options_map): self.name = name self.options_map = options_map.copy() if options_map is not None else dict() self.options_map['class'] = self.name def __eq__(self, other): return (isinstance(other, _UnknownStrategy) and self.name == other.name and self.options_map == other.options_map) def export_for_schema(self): """ Returns a string version of these replication options which are suitable for use in a CREATE KEYSPACE statement. """ if self.options_map: return dict((str(key), str(value)) for key, value in self.options_map.items()) return "{'class': '%s'}" % (self.name, ) def make_token_replica_map(self, token_to_host_owner, ring): return {} class SimpleStrategy(ReplicationStrategy): replication_factor = None """ The replication factor for this keyspace. 
""" def __init__(self, options_map): try: self.replication_factor = int(options_map['replication_factor']) except Exception: raise ValueError("SimpleStrategy requires an integer 'replication_factor' option") def make_token_replica_map(self, token_to_host_owner, ring): replica_map = {} for i in range(len(ring)): j, hosts = 0, list() while len(hosts) < self.replication_factor and j < len(ring): token = ring[(i + j) % len(ring)] host = token_to_host_owner[token] if host not in hosts: hosts.append(host) j += 1 replica_map[ring[i]] = hosts return replica_map def export_for_schema(self): """ Returns a string version of these replication options which are suitable for use in a CREATE KEYSPACE statement. """ return "{'class': 'SimpleStrategy', 'replication_factor': '%d'}" \ % (self.replication_factor,) def __eq__(self, other): if not isinstance(other, SimpleStrategy): return False return self.replication_factor == other.replication_factor class NetworkTopologyStrategy(ReplicationStrategy): dc_replication_factors = None """ A map of datacenter names to the replication factor for that DC. """ def __init__(self, dc_replication_factors): self.dc_replication_factors = dict( (str(k), int(v)) for k, v in dc_replication_factors.items()) def make_token_replica_map(self, token_to_host_owner, ring): dc_rf_map = dict((dc, int(rf)) for dc, rf in self.dc_replication_factors.items() if rf > 0) # build a map of DCs to lists of indexes into `ring` for tokens that # belong to that DC dc_to_token_offset = defaultdict(list) dc_racks = defaultdict(set) hosts_per_dc = defaultdict(set) for i, token in enumerate(ring): host = token_to_host_owner[token] dc_to_token_offset[host.datacenter].append(i) if host.datacenter and host.rack: dc_racks[host.datacenter].add(host.rack) hosts_per_dc[host.datacenter].add(host) # A map of DCs to an index into the dc_to_token_offset value for that dc. # This is how we keep track of advancing around the ring for each DC. 
dc_to_current_index = defaultdict(int) replica_map = defaultdict(list) for i in range(len(ring)): replicas = replica_map[ring[i]] # go through each DC and find the replicas in that DC for dc in dc_to_token_offset.keys(): if dc not in dc_rf_map: continue # advance our per-DC index until we're up to at least the # current token in the ring token_offsets = dc_to_token_offset[dc] index = dc_to_current_index[dc] num_tokens = len(token_offsets) while index < num_tokens and token_offsets[index] < i: index += 1 dc_to_current_index[dc] = index replicas_remaining = dc_rf_map[dc] replicas_this_dc = 0 skipped_hosts = [] racks_placed = set() racks_this_dc = dc_racks[dc] hosts_this_dc = len(hosts_per_dc[dc]) - for token_offset in islice(cycle(token_offsets), index, index + num_tokens): + + for token_offset_index in six.moves.range(index, index+num_tokens): + if token_offset_index >= len(token_offsets): + token_offset_index = token_offset_index - len(token_offsets) + + token_offset = token_offsets[token_offset_index] host = token_to_host_owner[ring[token_offset]] if replicas_remaining == 0 or replicas_this_dc == hosts_this_dc: break if host in replicas: continue if host.rack in racks_placed and len(racks_placed) < len(racks_this_dc): skipped_hosts.append(host) continue replicas.append(host) replicas_this_dc += 1 replicas_remaining -= 1 racks_placed.add(host.rack) if len(racks_placed) == len(racks_this_dc): for host in skipped_hosts: if replicas_remaining == 0: break replicas.append(host) replicas_remaining -= 1 del skipped_hosts[:] return replica_map def export_for_schema(self): """ Returns a string version of these replication options which are suitable for use in a CREATE KEYSPACE statement. """ ret = "{'class': 'NetworkTopologyStrategy'" for dc, repl_factor in sorted(self.dc_replication_factors.items()): ret += ", '%s': '%d'" % (dc, repl_factor) return ret + "}" def __eq__(self, other): if not isinstance(other, NetworkTopologyStrategy): return False return self.dc_replication_factors == other.dc_replication_factors class LocalStrategy(ReplicationStrategy): def __init__(self, options_map): pass def make_token_replica_map(self, token_to_host_owner, ring): return {} def export_for_schema(self): """ Returns a string version of these replication options which are suitable for use in a CREATE KEYSPACE statement. """ return "{'class': 'LocalStrategy'}" def __eq__(self, other): return isinstance(other, LocalStrategy) class KeyspaceMetadata(object): """ A representation of the schema for a single keyspace. """ name = None """ The string name of the keyspace. """ durable_writes = True """ A boolean indicating whether durable writes are enabled for this keyspace or not. """ replication_strategy = None """ A :class:`.ReplicationStrategy` subclass object. """ tables = None """ A map from table names to instances of :class:`~.TableMetadata`. """ indexes = None """ A dict mapping index names to :class:`.IndexMetadata` instances. """ user_types = None """ A map from user-defined type names to instances of :class:`~cassandra.metadata.UserType`. .. versionadded:: 2.1.0 """ functions = None """ A map from user-defined function signatures to instances of :class:`~cassandra.metadata.Function`. .. versionadded:: 2.6.0 """ aggregates = None """ A map from user-defined aggregate signatures to instances of :class:`~cassandra.metadata.Aggregate`. .. versionadded:: 2.6.0 """ views = None """ A dict mapping view names to :class:`.MaterializedViewMetadata` instances. 
""" virtual = False """ A boolean indicating if this is a virtual keyspace or not. Always ``False`` for clusters running pre-4.0 versions of Cassandra. .. versionadded:: 3.15 """ _exc_info = None """ set if metadata parsing failed """ def __init__(self, name, durable_writes, strategy_class, strategy_options): self.name = name self.durable_writes = durable_writes self.replication_strategy = ReplicationStrategy.create(strategy_class, strategy_options) self.tables = {} self.indexes = {} self.user_types = {} self.functions = {} self.aggregates = {} self.views = {} def export_as_string(self): """ Returns a CQL query string that can be used to recreate the entire keyspace, including user-defined types and tables. """ cql = "\n\n".join([self.as_cql_query() + ';'] + self.user_type_strings() + [f.export_as_string() for f in self.functions.values()] + [a.export_as_string() for a in self.aggregates.values()] + [t.export_as_string() for t in self.tables.values()]) if self._exc_info: import traceback ret = "/*\nWarning: Keyspace %s is incomplete because of an error processing metadata.\n" % \ (self.name) for line in traceback.format_exception(*self._exc_info): ret += line ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % cql return ret if self.virtual: return ("/*\nWarning: Keyspace {ks} is a virtual keyspace and cannot be recreated with CQL.\n" "Structure, for reference:*/\n" "{cql}\n" "").format(ks=self.name, cql=cql) return cql def as_cql_query(self): """ Returns a CQL query string that can be used to recreate just this keyspace, not including user-defined types and tables. """ if self.virtual: return "// VIRTUAL KEYSPACE {}".format(protect_name(self.name)) ret = "CREATE KEYSPACE %s WITH replication = %s " % ( protect_name(self.name), self.replication_strategy.export_for_schema()) return ret + (' AND durable_writes = %s' % ("true" if self.durable_writes else "false")) def user_type_strings(self): user_type_strings = [] user_types = self.user_types.copy() keys = sorted(user_types.keys()) for k in keys: if k in user_types: self.resolve_user_types(k, user_types, user_type_strings) return user_type_strings def resolve_user_types(self, key, user_types, user_type_strings): user_type = user_types.pop(key) for type_name in user_type.field_types: for sub_type in types.cql_types_from_string(type_name): if sub_type in user_types: self.resolve_user_types(sub_type, user_types, user_type_strings) user_type_strings.append(user_type.export_as_string()) def _add_table_metadata(self, table_metadata): old_indexes = {} old_meta = self.tables.get(table_metadata.name, None) if old_meta: # views are not queried with table, so they must be transferred to new table_metadata.views = old_meta.views # indexes will be updated with what is on the new metadata old_indexes = old_meta.indexes # note the intentional order of add before remove # this makes sure the maps are never absent something that existed before this update for index_name, index_metadata in six.iteritems(table_metadata.indexes): self.indexes[index_name] = index_metadata for index_name in (n for n in old_indexes if n not in table_metadata.indexes): self.indexes.pop(index_name, None) self.tables[table_metadata.name] = table_metadata def _drop_table_metadata(self, table_name): table_meta = self.tables.pop(table_name, None) if table_meta: for index_name in table_meta.indexes: self.indexes.pop(index_name, None) for view_name in table_meta.views: self.views.pop(view_name, None) return # we can't tell table 
drops from views, so drop both # (name is unique among them, within a keyspace) view_meta = self.views.pop(table_name, None) if view_meta: try: self.tables[view_meta.base_table_name].views.pop(table_name, None) except KeyError: pass def _add_view_metadata(self, view_metadata): try: self.tables[view_metadata.base_table_name].views[view_metadata.name] = view_metadata self.views[view_metadata.name] = view_metadata except KeyError: pass class UserType(object): """ A user defined type, as created by ``CREATE TYPE`` statements. User-defined types were introduced in Cassandra 2.1. .. versionadded:: 2.1.0 """ keyspace = None """ The string name of the keyspace in which this type is defined. """ name = None """ The name of this type. """ field_names = None """ An ordered list of the names for each field in this user-defined type. """ field_types = None """ An ordered list of the types for each field in this user-defined type. """ def __init__(self, keyspace, name, field_names, field_types): self.keyspace = keyspace self.name = name # non-frozen collections can return None self.field_names = field_names or [] self.field_types = field_types or [] def as_cql_query(self, formatted=False): """ Returns a CQL query that can be used to recreate this type. If `formatted` is set to :const:`True`, extra whitespace will be added to make the query more readable. """ ret = "CREATE TYPE %s.%s (%s" % ( protect_name(self.keyspace), protect_name(self.name), "\n" if formatted else "") if formatted: field_join = ",\n" padding = " " else: field_join = ", " padding = "" fields = [] for field_name, field_type in zip(self.field_names, self.field_types): fields.append("%s %s" % (protect_name(field_name), field_type)) ret += field_join.join("%s%s" % (padding, field) for field in fields) ret += "\n)" if formatted else ")" return ret def export_as_string(self): return self.as_cql_query(formatted=True) + ';' class Aggregate(object): """ A user defined aggregate function, as created by ``CREATE AGGREGATE`` statements. Aggregate functions were introduced in Cassandra 2.2 .. versionadded:: 2.6.0 """ keyspace = None """ The string name of the keyspace in which this aggregate is defined """ name = None """ The name of this aggregate """ argument_types = None """ An ordered list of the types for each argument to the aggregate """ final_func = None """ Name of a final function """ initial_condition = None """ Initial condition of the aggregate """ return_type = None """ Return type of the aggregate """ state_func = None """ Name of a state function """ state_type = None """ Type of the aggregate state """ def __init__(self, keyspace, name, argument_types, state_func, state_type, final_func, initial_condition, return_type): self.keyspace = keyspace self.name = name self.argument_types = argument_types self.state_func = state_func self.state_type = state_type self.final_func = final_func self.initial_condition = initial_condition self.return_type = return_type def as_cql_query(self, formatted=False): """ Returns a CQL query that can be used to recreate this aggregate. If `formatted` is set to :const:`True`, extra whitespace will be added to make the query more readable. 
""" sep = '\n ' if formatted else ' ' keyspace = protect_name(self.keyspace) name = protect_name(self.name) - type_list = ', '.join(self.argument_types) + type_list = ', '.join([types.strip_frozen(arg_type) for arg_type in self.argument_types]) state_func = protect_name(self.state_func) - state_type = self.state_type + state_type = types.strip_frozen(self.state_type) ret = "CREATE AGGREGATE %(keyspace)s.%(name)s(%(type_list)s)%(sep)s" \ "SFUNC %(state_func)s%(sep)s" \ "STYPE %(state_type)s" % locals() ret += ''.join((sep, 'FINALFUNC ', protect_name(self.final_func))) if self.final_func else '' ret += ''.join((sep, 'INITCOND ', self.initial_condition)) if self.initial_condition is not None else '' return ret def export_as_string(self): return self.as_cql_query(formatted=True) + ';' @property def signature(self): return SignatureDescriptor.format_signature(self.name, self.argument_types) class Function(object): """ A user defined function, as created by ``CREATE FUNCTION`` statements. User-defined functions were introduced in Cassandra 2.2 .. versionadded:: 2.6.0 """ keyspace = None """ The string name of the keyspace in which this function is defined """ name = None """ The name of this function """ argument_types = None """ An ordered list of the types for each argument to the function """ argument_names = None """ An ordered list of the names of each argument to the function """ return_type = None """ Return type of the function """ language = None """ Language of the function body """ body = None """ Function body string """ called_on_null_input = None """ Flag indicating whether this function should be called for rows with null values (convenience function to avoid handling nulls explicitly if the result will just be null) """ def __init__(self, keyspace, name, argument_types, argument_names, return_type, language, body, called_on_null_input): self.keyspace = keyspace self.name = name self.argument_types = argument_types # argument_types (frozen>) will always be a list # argument_name is not frozen in C* < 3.0 and may return None self.argument_names = argument_names or [] self.return_type = return_type self.language = language self.body = body self.called_on_null_input = called_on_null_input def as_cql_query(self, formatted=False): """ Returns a CQL query that can be used to recreate this function. If `formatted` is set to :const:`True`, extra whitespace will be added to make the query more readable. """ sep = '\n ' if formatted else ' ' keyspace = protect_name(self.keyspace) name = protect_name(self.name) - arg_list = ', '.join(["%s %s" % (protect_name(n), t) + arg_list = ', '.join(["%s %s" % (protect_name(n), types.strip_frozen(t)) for n, t in zip(self.argument_names, self.argument_types)]) typ = self.return_type lang = self.language body = self.body on_null = "CALLED" if self.called_on_null_input else "RETURNS NULL" return "CREATE FUNCTION %(keyspace)s.%(name)s(%(arg_list)s)%(sep)s" \ "%(on_null)s ON NULL INPUT%(sep)s" \ "RETURNS %(typ)s%(sep)s" \ "LANGUAGE %(lang)s%(sep)s" \ "AS $$%(body)s$$" % locals() def export_as_string(self): return self.as_cql_query(formatted=True) + ';' @property def signature(self): return SignatureDescriptor.format_signature(self.name, self.argument_types) class TableMetadata(object): """ A representation of the schema for a single table. """ keyspace_name = None """ String name of this Table's keyspace """ name = None """ The string name of the table. 
""" partition_key = None """ A list of :class:`.ColumnMetadata` instances representing the columns in the partition key for this table. This will always hold at least one column. """ clustering_key = None """ A list of :class:`.ColumnMetadata` instances representing the columns in the clustering key for this table. These are all of the :attr:`.primary_key` columns that are not in the :attr:`.partition_key`. Note that a table may have no clustering keys, in which case this will be an empty list. """ @property def primary_key(self): """ A list of :class:`.ColumnMetadata` representing the components of the primary key for this table. """ return self.partition_key + self.clustering_key columns = None """ A dict mapping column names to :class:`.ColumnMetadata` instances. """ indexes = None """ A dict mapping index names to :class:`.IndexMetadata` instances. """ is_compact_storage = False options = None """ A dict mapping table option names to their specific settings for this table. """ compaction_options = { "min_compaction_threshold": "min_threshold", "max_compaction_threshold": "max_threshold", "compaction_strategy_class": "class"} triggers = None """ A dict mapping trigger names to :class:`.TriggerMetadata` instances. """ views = None """ A dict mapping view names to :class:`.MaterializedViewMetadata` instances. """ _exc_info = None """ set if metadata parsing failed """ virtual = False """ A boolean indicating if this is a virtual table or not. Always ``False`` for clusters running pre-4.0 versions of Cassandra. .. versionadded:: 3.15 """ @property def is_cql_compatible(self): """ A boolean indicating if this table can be represented as CQL in export """ if self.virtual: return False comparator = getattr(self, 'comparator', None) if comparator: # no compact storage with more than one column beyond PK if there # are clustering columns incompatible = (self.is_compact_storage and len(self.columns) > len(self.primary_key) + 1 and len(self.clustering_key) >= 1) return not incompatible return True extensions = None """ Metadata describing configuration for table extensions """ def __init__(self, keyspace_name, name, partition_key=None, clustering_key=None, columns=None, triggers=None, options=None, virtual=False): self.keyspace_name = keyspace_name self.name = name self.partition_key = [] if partition_key is None else partition_key self.clustering_key = [] if clustering_key is None else clustering_key self.columns = OrderedDict() if columns is None else columns self.indexes = {} self.options = {} if options is None else options self.comparator = None self.triggers = OrderedDict() if triggers is None else triggers self.views = {} self.virtual = virtual def export_as_string(self): """ Returns a string of CQL queries that can be used to recreate this table along with all indexes on it. The returned string is formatted to be human readable. 
""" if self._exc_info: import traceback ret = "/*\nWarning: Table %s.%s is incomplete because of an error processing metadata.\n" % \ (self.keyspace_name, self.name) for line in traceback.format_exception(*self._exc_info): ret += line ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % self._all_as_cql() elif not self.is_cql_compatible: # If we can't produce this table with CQL, comment inline ret = "/*\nWarning: Table %s.%s omitted because it has constructs not compatible with CQL (was created via legacy API).\n" % \ (self.keyspace_name, self.name) ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % self._all_as_cql() elif self.virtual: ret = ('/*\nWarning: Table {ks}.{tab} is a virtual table and cannot be recreated with CQL.\n' 'Structure, for reference:\n' '{cql}\n*/').format(ks=self.keyspace_name, tab=self.name, cql=self._all_as_cql()) else: ret = self._all_as_cql() return ret def _all_as_cql(self): ret = self.as_cql_query(formatted=True) ret += ";" for index in self.indexes.values(): ret += "\n%s;" % index.as_cql_query() for trigger_meta in self.triggers.values(): ret += "\n%s;" % (trigger_meta.as_cql_query(),) for view_meta in self.views.values(): ret += "\n\n%s;" % (view_meta.as_cql_query(formatted=True),) if self.extensions: registry = _RegisteredExtensionType._extension_registry for k in six.viewkeys(registry) & self.extensions: # no viewkeys on OrderedMapSerializeKey ext = registry[k] cql = ext.after_table_cql(self, k, self.extensions[k]) if cql: ret += "\n\n%s" % (cql,) return ret def as_cql_query(self, formatted=False): """ Returns a CQL query that can be used to recreate this table (index creations are not included). If `formatted` is set to :const:`True`, extra whitespace will be added to make the query human readable. 
""" ret = "%s TABLE %s.%s (%s" % ( ('VIRTUAL' if self.virtual else 'CREATE'), protect_name(self.keyspace_name), protect_name(self.name), "\n" if formatted else "") if formatted: column_join = ",\n" padding = " " else: column_join = ", " padding = "" columns = [] for col in self.columns.values(): columns.append("%s %s%s" % (protect_name(col.name), col.cql_type, ' static' if col.is_static else '')) if len(self.partition_key) == 1 and not self.clustering_key: columns[0] += " PRIMARY KEY" ret += column_join.join("%s%s" % (padding, col) for col in columns) # primary key if len(self.partition_key) > 1 or self.clustering_key: ret += "%s%sPRIMARY KEY (" % (column_join, padding) if len(self.partition_key) > 1: ret += "(%s)" % ", ".join(protect_name(col.name) for col in self.partition_key) else: ret += protect_name(self.partition_key[0].name) if self.clustering_key: ret += ", %s" % ", ".join(protect_name(col.name) for col in self.clustering_key) ret += ")" # properties ret += "%s) WITH " % ("\n" if formatted else "") ret += self._property_string(formatted, self.clustering_key, self.options, self.is_compact_storage) return ret @classmethod def _property_string(cls, formatted, clustering_key, options_map, is_compact_storage=False): properties = [] if is_compact_storage: properties.append("COMPACT STORAGE") if clustering_key: cluster_str = "CLUSTERING ORDER BY " inner = [] for col in clustering_key: ordering = "DESC" if col.is_reversed else "ASC" inner.append("%s %s" % (protect_name(col.name), ordering)) cluster_str += "(%s)" % ", ".join(inner) properties.append(cluster_str) properties.extend(cls._make_option_strings(options_map)) join_str = "\n AND " if formatted else " AND " return join_str.join(properties) @classmethod def _make_option_strings(cls, options_map): ret = [] options_copy = dict(options_map.items()) actual_options = json.loads(options_copy.pop('compaction_strategy_options', '{}')) value = options_copy.pop("compaction_strategy_class", None) actual_options.setdefault("class", value) compaction_option_strings = ["'%s': '%s'" % (k, v) for k, v in actual_options.items()] ret.append('compaction = {%s}' % ', '.join(compaction_option_strings)) for system_table_name in cls.compaction_options.keys(): options_copy.pop(system_table_name, None) # delete if present options_copy.pop('compaction_strategy_option', None) if not options_copy.get('compression'): params = json.loads(options_copy.pop('compression_parameters', '{}')) param_strings = ["'%s': '%s'" % (k, v) for k, v in params.items()] ret.append('compression = {%s}' % ', '.join(param_strings)) for name, value in options_copy.items(): if value is not None: if name == "comment": value = value or "" ret.append("%s = %s" % (name, protect_value(value))) return list(sorted(ret)) class TableExtensionInterface(object): """ Defines CQL/DDL for Cassandra table extensions. """ # limited API for now. Could be expanded as new extension types materialize -- "extend_option_strings", for example @classmethod def after_table_cql(cls, ext_key, ext_blob): """ Called to produce CQL/DDL to follow the table definition. Should contain requisite terminating semicolon(s). 
""" pass class _RegisteredExtensionType(type): _extension_registry = {} def __new__(mcs, name, bases, dct): cls = super(_RegisteredExtensionType, mcs).__new__(mcs, name, bases, dct) if name != 'RegisteredTableExtension': mcs._extension_registry[cls.name] = cls return cls @six.add_metaclass(_RegisteredExtensionType) class RegisteredTableExtension(TableExtensionInterface): """ Extending this class registers it by name (associated by key in the `system_schema.tables.extensions` map). """ name = None """ Name of the extension (key in the map) """ def protect_name(name): return maybe_escape_name(name) def protect_names(names): return [protect_name(n) for n in names] def protect_value(value): if value is None: return 'NULL' if isinstance(value, (int, float, bool)): return str(value).lower() return "'%s'" % value.replace("'", "''") valid_cql3_word_re = re.compile(r'^[a-z][0-9a-z_]*$') def is_valid_name(name): if name is None: return False if name.lower() in cql_keywords_reserved: return False return valid_cql3_word_re.match(name) is not None def maybe_escape_name(name): if is_valid_name(name): return name return escape_name(name) def escape_name(name): return '"%s"' % (name.replace('"', '""'),) class ColumnMetadata(object): """ A representation of a single column in a table. """ table = None """ The :class:`.TableMetadata` this column belongs to. """ name = None """ The string name of this column. """ cql_type = None """ The CQL type for the column. """ is_static = False """ If this column is static (available in Cassandra 2.1+), this will be :const:`True`, otherwise :const:`False`. """ is_reversed = False """ If this column is reversed (DESC) as in clustering order """ _cass_type = None def __init__(self, table_metadata, column_name, cql_type, is_static=False, is_reversed=False): self.table = table_metadata self.name = column_name self.cql_type = cql_type self.is_static = is_static self.is_reversed = is_reversed def __str__(self): return "%s %s" % (self.name, self.cql_type) class IndexMetadata(object): """ A representation of a secondary index on a column. """ keyspace_name = None """ A string name of the keyspace. """ table_name = None """ A string name of the table this index is on. """ name = None """ A string name for the index. """ kind = None """ A string representing the kind of index (COMPOSITE, CUSTOM,...). """ index_options = {} """ A dict of index options. """ def __init__(self, keyspace_name, table_name, index_name, kind, index_options): self.keyspace_name = keyspace_name self.table_name = table_name self.name = index_name self.kind = kind self.index_options = index_options def as_cql_query(self): """ Returns a CQL query that can be used to recreate this index. """ options = dict(self.index_options) index_target = options.pop("target") if self.kind != "CUSTOM": return "CREATE INDEX %s ON %s.%s (%s)" % ( protect_name(self.name), protect_name(self.keyspace_name), protect_name(self.table_name), index_target) else: class_name = options.pop("class_name") ret = "CREATE CUSTOM INDEX %s ON %s.%s (%s) USING '%s'" % ( protect_name(self.name), protect_name(self.keyspace_name), protect_name(self.table_name), index_target, class_name) if options: # PYTHON-1008: `ret` will always be a unicode opts_cql_encoded = _encoder.cql_encode_all_types(options, as_text_type=True) ret += " WITH OPTIONS = %s" % opts_cql_encoded return ret def export_as_string(self): """ Returns a CQL query string that can be used to recreate this index. 
""" return self.as_cql_query() + ';' class TokenMap(object): """ Information about the layout of the ring. """ token_class = None """ A subclass of :class:`.Token`, depending on what partitioner the cluster uses. """ token_to_host_owner = None """ A map of :class:`.Token` objects to the :class:`.Host` that owns that token. """ tokens_to_hosts_by_ks = None """ A map of keyspace names to a nested map of :class:`.Token` objects to sets of :class:`.Host` objects. """ ring = None """ An ordered list of :class:`.Token` instances in the ring. """ _metadata = None def __init__(self, token_class, token_to_host_owner, all_tokens, metadata): self.token_class = token_class self.ring = all_tokens self.token_to_host_owner = token_to_host_owner self.tokens_to_hosts_by_ks = {} self._metadata = metadata self._rebuild_lock = RLock() def rebuild_keyspace(self, keyspace, build_if_absent=False): with self._rebuild_lock: try: current = self.tokens_to_hosts_by_ks.get(keyspace, None) if (build_if_absent and current is None) or (not build_if_absent and current is not None): ks_meta = self._metadata.keyspaces.get(keyspace) if ks_meta: replica_map = self.replica_map_for_keyspace(self._metadata.keyspaces[keyspace]) self.tokens_to_hosts_by_ks[keyspace] = replica_map except Exception: # should not happen normally, but we don't want to blow up queries because of unexpected meta state # bypass until new map is generated self.tokens_to_hosts_by_ks[keyspace] = {} log.exception("Failed creating a token map for keyspace '%s' with %s. PLEASE REPORT THIS: https://datastax-oss.atlassian.net/projects/PYTHON", keyspace, self.token_to_host_owner) def replica_map_for_keyspace(self, ks_metadata): strategy = ks_metadata.replication_strategy if strategy: return strategy.make_token_replica_map(self.token_to_host_owner, self.ring) else: return None def remove_keyspace(self, keyspace): self.tokens_to_hosts_by_ks.pop(keyspace, None) def get_replicas(self, keyspace, token): """ Get a set of :class:`.Host` instances representing all of the replica nodes for a given :class:`.Token`. """ tokens_to_hosts = self.tokens_to_hosts_by_ks.get(keyspace, None) if tokens_to_hosts is None: self.rebuild_keyspace(keyspace, build_if_absent=True) tokens_to_hosts = self.tokens_to_hosts_by_ks.get(keyspace, None) if tokens_to_hosts: # The values in self.ring correspond to the end of the # token range up to and including the value listed. point = bisect_left(self.ring, token) if point == len(self.ring): return tokens_to_hosts[self.ring[0]] else: return tokens_to_hosts[self.ring[point]] return [] @total_ordering class Token(object): """ Abstract class representing a token. """ def __init__(self, token): self.value = token @classmethod def hash_fn(cls, key): return key @classmethod def from_key(cls, key): return cls(cls.hash_fn(key)) @classmethod def from_string(cls, token_string): raise NotImplementedError() def __eq__(self, other): return self.value == other.value def __lt__(self, other): return self.value < other.value def __hash__(self): return hash(self.value) def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self.value) __str__ = __repr__ MIN_LONG = -(2 ** 63) MAX_LONG = (2 ** 63) - 1 class NoMurmur3(Exception): pass class HashToken(Token): @classmethod def from_string(cls, token_string): """ `token_string` should be the string representation from the server. """ # The hash partitioners just store the deciman value return cls(int(token_string)) class Murmur3Token(HashToken): """ A token for ``Murmur3Partitioner``. 
""" @classmethod def hash_fn(cls, key): if murmur3 is not None: h = int(murmur3(key)) return h if h != MIN_LONG else MAX_LONG else: raise NoMurmur3() def __init__(self, token): """ `token` is an int or string representing the token. """ self.value = int(token) class MD5Token(HashToken): """ A token for ``RandomPartitioner``. """ @classmethod def hash_fn(cls, key): if isinstance(key, six.text_type): key = key.encode('UTF-8') return abs(varint_unpack(md5(key).digest())) class BytesToken(Token): """ A token for ``ByteOrderedPartitioner``. """ @classmethod def from_string(cls, token_string): """ `token_string` should be the string representation from the server. """ # unhexlify works fine with unicode input in everythin but pypy3, where it Raises "TypeError: 'str' does not support the buffer interface" if isinstance(token_string, six.text_type): token_string = token_string.encode('ascii') # The BOP stores a hex string return cls(unhexlify(token_string)) class TriggerMetadata(object): """ A representation of a trigger for a table. """ table = None """ The :class:`.TableMetadata` this trigger belongs to. """ name = None """ The string name of this trigger. """ options = None """ A dict mapping trigger option names to their specific settings for this table. """ def __init__(self, table_metadata, trigger_name, options=None): self.table = table_metadata self.name = trigger_name self.options = options def as_cql_query(self): ret = "CREATE TRIGGER %s ON %s.%s USING %s" % ( protect_name(self.name), protect_name(self.table.keyspace_name), protect_name(self.table.name), protect_value(self.options['class']) ) return ret def export_as_string(self): return self.as_cql_query() + ';' class _SchemaParser(object): def __init__(self, connection, timeout): self.connection = connection self.timeout = timeout def _handle_results(self, success, result, expected_failures=tuple()): """ Given a bool and a ResultSet (the form returned per result from Connection.wait_for_responses), return a dictionary containing the results. Used to process results from asynchronous queries to system tables. ``expected_failures`` will usually be used to allow callers to ignore ``InvalidRequest`` errors caused by a missing system keyspace. For example, some DSE versions report a 4.X server version, but do not have virtual tables. Thus, running against 4.X servers, SchemaParserV4 uses expected_failures to make a best-effort attempt to read those keyspaces, but treat them as empty if they're not found. :param success: A boolean representing whether or not the query succeeded :param result: The resultset in question. :expected_failures: An Exception class or an iterable thereof. If the query failed, but raised an instance of an expected failure class, this will ignore the failure and return an empty list. 
""" if not success and isinstance(result, expected_failures): return [] elif success: return dict_factory(*result.results) if result else [] else: raise result def _query_build_row(self, query_string, build_func): result = self._query_build_rows(query_string, build_func) return result[0] if result else None def _query_build_rows(self, query_string, build_func): query = QueryMessage(query=query_string, consistency_level=ConsistencyLevel.ONE) responses = self.connection.wait_for_responses((query), timeout=self.timeout, fail_on_error=False) (success, response) = responses[0] if success: result = dict_factory(*response.results) return [build_func(row) for row in result] elif isinstance(response, InvalidRequest): log.debug("user types table not found") return [] else: raise response class SchemaParserV22(_SchemaParser): _SELECT_KEYSPACES = "SELECT * FROM system.schema_keyspaces" _SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies" _SELECT_COLUMNS = "SELECT * FROM system.schema_columns" _SELECT_TRIGGERS = "SELECT * FROM system.schema_triggers" _SELECT_TYPES = "SELECT * FROM system.schema_usertypes" _SELECT_FUNCTIONS = "SELECT * FROM system.schema_functions" _SELECT_AGGREGATES = "SELECT * FROM system.schema_aggregates" _table_name_col = 'columnfamily_name' _function_agg_arument_type_col = 'signature' recognized_table_options = ( "comment", "read_repair_chance", "dclocal_read_repair_chance", # kept to be safe, but see _build_table_options() "local_read_repair_chance", "replicate_on_write", "gc_grace_seconds", "bloom_filter_fp_chance", "caching", "compaction_strategy_class", "compaction_strategy_options", "min_compaction_threshold", "max_compaction_threshold", "compression_parameters", "min_index_interval", "max_index_interval", "index_interval", "speculative_retry", "rows_per_partition_to_cache", "memtable_flush_period_in_ms", "populate_io_cache_on_flush", "compression", "default_time_to_live") def __init__(self, connection, timeout): super(SchemaParserV22, self).__init__(connection, timeout) self.keyspaces_result = [] self.tables_result = [] self.columns_result = [] self.triggers_result = [] self.types_result = [] self.functions_result = [] self.aggregates_result = [] self.keyspace_table_rows = defaultdict(list) self.keyspace_table_col_rows = defaultdict(lambda: defaultdict(list)) self.keyspace_type_rows = defaultdict(list) self.keyspace_func_rows = defaultdict(list) self.keyspace_agg_rows = defaultdict(list) self.keyspace_table_trigger_rows = defaultdict(lambda: defaultdict(list)) def get_all_keyspaces(self): self._query_all() for row in self.keyspaces_result: keyspace_meta = self._build_keyspace_metadata(row) try: for table_row in self.keyspace_table_rows.get(keyspace_meta.name, []): table_meta = self._build_table_metadata(table_row) keyspace_meta._add_table_metadata(table_meta) for usertype_row in self.keyspace_type_rows.get(keyspace_meta.name, []): usertype = self._build_user_type(usertype_row) keyspace_meta.user_types[usertype.name] = usertype for fn_row in self.keyspace_func_rows.get(keyspace_meta.name, []): fn = self._build_function(fn_row) keyspace_meta.functions[fn.signature] = fn for agg_row in self.keyspace_agg_rows.get(keyspace_meta.name, []): agg = self._build_aggregate(agg_row) keyspace_meta.aggregates[agg.signature] = agg except Exception: log.exception("Error while parsing metadata for keyspace %s. 
Metadata model will be incomplete.", keyspace_meta.name) keyspace_meta._exc_info = sys.exc_info() yield keyspace_meta def get_table(self, keyspaces, keyspace, table): cl = ConsistencyLevel.ONE where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col,), (keyspace, table), _encoder) cf_query = QueryMessage(query=self._SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl) col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl) triggers_query = QueryMessage(query=self._SELECT_TRIGGERS + where_clause, consistency_level=cl) (cf_success, cf_result), (col_success, col_result), (triggers_success, triggers_result) \ = self.connection.wait_for_responses(cf_query, col_query, triggers_query, timeout=self.timeout, fail_on_error=False) table_result = self._handle_results(cf_success, cf_result) col_result = self._handle_results(col_success, col_result) # the triggers table doesn't exist in C* 1.2 triggers_result = self._handle_results(triggers_success, triggers_result, expected_failures=InvalidRequest) if table_result: return self._build_table_metadata(table_result[0], col_result, triggers_result) def get_type(self, keyspaces, keyspace, type): where_clause = bind_params(" WHERE keyspace_name = %s AND type_name = %s", (keyspace, type), _encoder) return self._query_build_row(self._SELECT_TYPES + where_clause, self._build_user_type) def get_types_map(self, keyspaces, keyspace): where_clause = bind_params(" WHERE keyspace_name = %s", (keyspace,), _encoder) types = self._query_build_rows(self._SELECT_TYPES + where_clause, self._build_user_type) return dict((t.name, t) for t in types) def get_function(self, keyspaces, keyspace, function): where_clause = bind_params(" WHERE keyspace_name = %%s AND function_name = %%s AND %s = %%s" % (self._function_agg_arument_type_col,), (keyspace, function.name, function.argument_types), _encoder) return self._query_build_row(self._SELECT_FUNCTIONS + where_clause, self._build_function) def get_aggregate(self, keyspaces, keyspace, aggregate): where_clause = bind_params(" WHERE keyspace_name = %%s AND aggregate_name = %%s AND %s = %%s" % (self._function_agg_arument_type_col,), (keyspace, aggregate.name, aggregate.argument_types), _encoder) return self._query_build_row(self._SELECT_AGGREGATES + where_clause, self._build_aggregate) def get_keyspace(self, keyspaces, keyspace): where_clause = bind_params(" WHERE keyspace_name = %s", (keyspace,), _encoder) return self._query_build_row(self._SELECT_KEYSPACES + where_clause, self._build_keyspace_metadata) @classmethod def _build_keyspace_metadata(cls, row): try: ksm = cls._build_keyspace_metadata_internal(row) except Exception: name = row["keyspace_name"] ksm = KeyspaceMetadata(name, False, 'UNKNOWN', {}) ksm._exc_info = sys.exc_info() # capture exc_info before log because nose (test) logging clears it in certain circumstances log.exception("Error while parsing metadata for keyspace %s row(%s)", name, row) return ksm @staticmethod def _build_keyspace_metadata_internal(row): name = row["keyspace_name"] durable_writes = row["durable_writes"] strategy_class = row["strategy_class"] strategy_options = json.loads(row["strategy_options"]) return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) @classmethod def _build_user_type(cls, usertype_row): field_types = list(map(cls._schema_type_to_cql, usertype_row['field_types'])) return UserType(usertype_row['keyspace_name'], usertype_row['type_name'], usertype_row['field_names'], field_types) 
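# Hedged sketch: pre-3.0 schema tables report types as
# org.apache.cassandra.db.marshal class strings, so this parser converts them
# through lookup_casstype() before emitting CQL; SchemaParserV3 further down
# overrides _schema_type_to_cql() to a pass-through because system_schema
# already stores CQL type names.  Module path assumed to match the `types`
# alias used throughout this file.
from cassandra import cqltypes

cass_type = cqltypes.lookup_casstype(
    'org.apache.cassandra.db.marshal.MapType('
    'org.apache.cassandra.db.marshal.UTF8Type,'
    'org.apache.cassandra.db.marshal.Int32Type)')
print(cass_type.cql_parameterized_type())   # -> map<text, int>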
@classmethod def _build_function(cls, function_row): return_type = cls._schema_type_to_cql(function_row['return_type']) return Function(function_row['keyspace_name'], function_row['function_name'], function_row[cls._function_agg_arument_type_col], function_row['argument_names'], return_type, function_row['language'], function_row['body'], function_row['called_on_null_input']) @classmethod def _build_aggregate(cls, aggregate_row): cass_state_type = types.lookup_casstype(aggregate_row['state_type']) initial_condition = aggregate_row['initcond'] if initial_condition is not None: initial_condition = _encoder.cql_encode_all_types(cass_state_type.deserialize(initial_condition, 3)) state_type = _cql_from_cass_type(cass_state_type) return_type = cls._schema_type_to_cql(aggregate_row['return_type']) return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'], aggregate_row['signature'], aggregate_row['state_func'], state_type, aggregate_row['final_func'], initial_condition, return_type) def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): keyspace_name = row["keyspace_name"] cfname = row[self._table_name_col] col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][cfname] trigger_rows = trigger_rows or self.keyspace_table_trigger_rows[keyspace_name][cfname] if not col_rows: # CASSANDRA-8487 log.warning("Building table metadata with no column meta for %s.%s", keyspace_name, cfname) table_meta = TableMetadata(keyspace_name, cfname) try: comparator = types.lookup_casstype(row["comparator"]) table_meta.comparator = comparator is_dct_comparator = issubclass(comparator, types.DynamicCompositeType) is_composite_comparator = issubclass(comparator, types.CompositeType) column_name_types = comparator.subtypes if is_composite_comparator else (comparator,) num_column_name_components = len(column_name_types) last_col = column_name_types[-1] column_aliases = row.get("column_aliases", None) clustering_rows = [r for r in col_rows if r.get('type', None) == "clustering_key"] if len(clustering_rows) > 1: clustering_rows = sorted(clustering_rows, key=lambda row: row.get('component_index')) if column_aliases is not None: column_aliases = json.loads(column_aliases) if not column_aliases: # json load failed or column_aliases empty PYTHON-562 column_aliases = [r.get('column_name') for r in clustering_rows] if is_composite_comparator: if issubclass(last_col, types.ColumnToCollectionType): # collections is_compact = False has_value = False clustering_size = num_column_name_components - 2 elif (len(column_aliases) == num_column_name_components - 1 and issubclass(last_col, types.UTF8Type)): # aliases? 
is_compact = False has_value = False clustering_size = num_column_name_components - 1 else: # compact table is_compact = True has_value = column_aliases or not col_rows clustering_size = num_column_name_components # Some thrift tables define names in composite types (see PYTHON-192) if not column_aliases and hasattr(comparator, 'fieldnames'): column_aliases = filter(None, comparator.fieldnames) else: is_compact = True if column_aliases or not col_rows or is_dct_comparator: has_value = True clustering_size = num_column_name_components else: has_value = False clustering_size = 0 # partition key partition_rows = [r for r in col_rows if r.get('type', None) == "partition_key"] if len(partition_rows) > 1: partition_rows = sorted(partition_rows, key=lambda row: row.get('component_index')) key_aliases = row.get("key_aliases") if key_aliases is not None: key_aliases = json.loads(key_aliases) if key_aliases else [] else: # In 2.0+, we can use the 'type' column. In 3.0+, we have to use it. key_aliases = [r.get('column_name') for r in partition_rows] key_validator = row.get("key_validator") if key_validator is not None: key_type = types.lookup_casstype(key_validator) key_types = key_type.subtypes if issubclass(key_type, types.CompositeType) else [key_type] else: key_types = [types.lookup_casstype(r.get('validator')) for r in partition_rows] for i, col_type in enumerate(key_types): if len(key_aliases) > i: column_name = key_aliases[i] elif i == 0: column_name = "key" else: column_name = "key%d" % i col = ColumnMetadata(table_meta, column_name, col_type.cql_parameterized_type()) table_meta.columns[column_name] = col table_meta.partition_key.append(col) # clustering key for i in range(clustering_size): if len(column_aliases) > i: column_name = column_aliases[i] else: column_name = "column%d" % (i + 1) data_type = column_name_types[i] cql_type = _cql_from_cass_type(data_type) is_reversed = types.is_reversed_casstype(data_type) col = ColumnMetadata(table_meta, column_name, cql_type, is_reversed=is_reversed) table_meta.columns[column_name] = col table_meta.clustering_key.append(col) # value alias (if present) if has_value: value_alias_rows = [r for r in col_rows if r.get('type', None) == "compact_value"] if not key_aliases: # TODO are we checking the right thing here? value_alias = "value" else: value_alias = row.get("value_alias", None) if value_alias is None and value_alias_rows: # CASSANDRA-8487 # In 2.0+, we can use the 'type' column. In 3.0+, we have to use it. 
value_alias = value_alias_rows[0].get('column_name') default_validator = row.get("default_validator") if default_validator: validator = types.lookup_casstype(default_validator) else: if value_alias_rows: # CASSANDRA-8487 validator = types.lookup_casstype(value_alias_rows[0].get('validator')) cql_type = _cql_from_cass_type(validator) col = ColumnMetadata(table_meta, value_alias, cql_type) if value_alias: # CASSANDRA-8487 table_meta.columns[value_alias] = col # other normal columns for col_row in col_rows: column_meta = self._build_column_metadata(table_meta, col_row) - if column_meta.name: + if column_meta.name is not None: table_meta.columns[column_meta.name] = column_meta index_meta = self._build_index_metadata(column_meta, col_row) if index_meta: table_meta.indexes[index_meta.name] = index_meta for trigger_row in trigger_rows: trigger_meta = self._build_trigger_metadata(table_meta, trigger_row) table_meta.triggers[trigger_meta.name] = trigger_meta table_meta.options = self._build_table_options(row) table_meta.is_compact_storage = is_compact except Exception: table_meta._exc_info = sys.exc_info() log.exception("Error while parsing metadata for table %s.%s row(%s) columns(%s)", keyspace_name, cfname, row, col_rows) return table_meta def _build_table_options(self, row): """ Setup the mostly-non-schema table options, like caching settings """ options = dict((o, row.get(o)) for o in self.recognized_table_options if o in row) # the option name when creating tables is "dclocal_read_repair_chance", # but the column name in system.schema_columnfamilies is # "local_read_repair_chance". We'll store this as dclocal_read_repair_chance, # since that's probably what users are expecting (and we need it for the # CREATE TABLE statement anyway). if "local_read_repair_chance" in options: val = options.pop("local_read_repair_chance") options["dclocal_read_repair_chance"] = val return options @classmethod def _build_column_metadata(cls, table_metadata, row): name = row["column_name"] type_string = row["validator"] data_type = types.lookup_casstype(type_string) cql_type = _cql_from_cass_type(data_type) is_static = row.get("type", None) == "static" is_reversed = types.is_reversed_casstype(data_type) column_meta = ColumnMetadata(table_metadata, name, cql_type, is_static, is_reversed) column_meta._cass_type = data_type return column_meta @staticmethod def _build_index_metadata(column_metadata, row): index_name = row.get("index_name") kind = row.get("index_type") if index_name or kind: options = row.get("index_options") options = json.loads(options) if options else {} options = options or {} # if the json parsed to None, init empty dict # generate a CQL index identity string target = protect_name(column_metadata.name) if kind != "CUSTOM": if "index_keys" in options: target = 'keys(%s)' % (target,) elif "index_values" in options: # don't use any "function" for collection values pass else: # it might be a "full" index on a frozen collection, but # we need to check the data type to verify that, because # there is no special index option for full-collection # indexes. 
data_type = column_metadata._cass_type collection_types = ('map', 'set', 'list') if data_type.typename == "frozen" and data_type.subtypes[0].typename in collection_types: # no index option for full-collection index target = 'full(%s)' % (target,) options['target'] = target return IndexMetadata(column_metadata.table.keyspace_name, column_metadata.table.name, index_name, kind, options) @staticmethod def _build_trigger_metadata(table_metadata, row): name = row["trigger_name"] options = row["trigger_options"] trigger_meta = TriggerMetadata(table_metadata, name, options) return trigger_meta def _query_all(self): cl = ConsistencyLevel.ONE queries = [ QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl), QueryMessage(query=self._SELECT_COLUMN_FAMILIES, consistency_level=cl), QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl), QueryMessage(query=self._SELECT_TYPES, consistency_level=cl), QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl), QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl), QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl) ] ((ks_success, ks_result), (table_success, table_result), (col_success, col_result), (types_success, types_result), (functions_success, functions_result), (aggregates_success, aggregates_result), (triggers_success, triggers_result)) = ( self.connection.wait_for_responses(*queries, timeout=self.timeout, fail_on_error=False) ) self.keyspaces_result = self._handle_results(ks_success, ks_result) self.tables_result = self._handle_results(table_success, table_result) self.columns_result = self._handle_results(col_success, col_result) # if we're connected to Cassandra < 2.0, the triggers table will not exist if triggers_success: self.triggers_result = dict_factory(*triggers_result.results) else: if isinstance(triggers_result, InvalidRequest): log.debug("triggers table not found") elif isinstance(triggers_result, Unauthorized): log.warning("this version of Cassandra does not allow access to schema_triggers metadata with authorization enabled (CASSANDRA-7967); " "The driver will operate normally, but will not reflect triggers in the local metadata model, or schema strings.") else: raise triggers_result # if we're connected to Cassandra < 2.1, the usertypes table will not exist if types_success: self.types_result = dict_factory(*types_result.results) else: if isinstance(types_result, InvalidRequest): log.debug("user types table not found") self.types_result = {} else: raise types_result # functions were introduced in Cassandra 2.2 if functions_success: self.functions_result = dict_factory(*functions_result.results) else: if isinstance(functions_result, InvalidRequest): log.debug("user functions table not found") else: raise functions_result # aggregates were introduced in Cassandra 2.2 if aggregates_success: self.aggregates_result = dict_factory(*aggregates_result.results) else: if isinstance(aggregates_result, InvalidRequest): log.debug("user aggregates table not found") else: raise aggregates_result self._aggregate_results() def _aggregate_results(self): m = self.keyspace_table_rows for row in self.tables_result: m[row["keyspace_name"]].append(row) m = self.keyspace_table_col_rows for row in self.columns_result: ksname = row["keyspace_name"] cfname = row[self._table_name_col] m[ksname][cfname].append(row) m = self.keyspace_type_rows for row in self.types_result: m[row["keyspace_name"]].append(row) m = self.keyspace_func_rows for row in self.functions_result: m[row["keyspace_name"]].append(row) m = 
self.keyspace_agg_rows for row in self.aggregates_result: m[row["keyspace_name"]].append(row) m = self.keyspace_table_trigger_rows for row in self.triggers_result: ksname = row["keyspace_name"] cfname = row[self._table_name_col] m[ksname][cfname].append(row) @staticmethod def _schema_type_to_cql(type_string): cass_type = types.lookup_casstype(type_string) return _cql_from_cass_type(cass_type) class SchemaParserV3(SchemaParserV22): _SELECT_KEYSPACES = "SELECT * FROM system_schema.keyspaces" _SELECT_TABLES = "SELECT * FROM system_schema.tables" _SELECT_COLUMNS = "SELECT * FROM system_schema.columns" _SELECT_INDEXES = "SELECT * FROM system_schema.indexes" _SELECT_TRIGGERS = "SELECT * FROM system_schema.triggers" _SELECT_TYPES = "SELECT * FROM system_schema.types" _SELECT_FUNCTIONS = "SELECT * FROM system_schema.functions" _SELECT_AGGREGATES = "SELECT * FROM system_schema.aggregates" _SELECT_VIEWS = "SELECT * FROM system_schema.views" _table_name_col = 'table_name' _function_agg_arument_type_col = 'argument_types' recognized_table_options = ( 'bloom_filter_fp_chance', 'caching', 'cdc', 'comment', 'compaction', 'compression', 'crc_check_chance', 'dclocal_read_repair_chance', 'default_time_to_live', 'gc_grace_seconds', 'max_index_interval', 'memtable_flush_period_in_ms', 'min_index_interval', 'read_repair_chance', 'speculative_retry') def __init__(self, connection, timeout): super(SchemaParserV3, self).__init__(connection, timeout) self.indexes_result = [] self.keyspace_table_index_rows = defaultdict(lambda: defaultdict(list)) self.keyspace_view_rows = defaultdict(list) def get_all_keyspaces(self): for keyspace_meta in super(SchemaParserV3, self).get_all_keyspaces(): for row in self.keyspace_view_rows[keyspace_meta.name]: view_meta = self._build_view_metadata(row) keyspace_meta._add_view_metadata(view_meta) yield keyspace_meta def get_table(self, keyspaces, keyspace, table): cl = ConsistencyLevel.ONE where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col), (keyspace, table), _encoder) cf_query = QueryMessage(query=self._SELECT_TABLES + where_clause, consistency_level=cl) col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl) indexes_query = QueryMessage(query=self._SELECT_INDEXES + where_clause, consistency_level=cl) triggers_query = QueryMessage(query=self._SELECT_TRIGGERS + where_clause, consistency_level=cl) # in protocol v4 we don't know if this event is a view or a table, so we look for both where_clause = bind_params(" WHERE keyspace_name = %s AND view_name = %s", (keyspace, table), _encoder) view_query = QueryMessage(query=self._SELECT_VIEWS + where_clause, consistency_level=cl) ((cf_success, cf_result), (col_success, col_result), (indexes_sucess, indexes_result), (triggers_success, triggers_result), (view_success, view_result)) = ( self.connection.wait_for_responses( cf_query, col_query, indexes_query, triggers_query, view_query, timeout=self.timeout, fail_on_error=False) ) table_result = self._handle_results(cf_success, cf_result) col_result = self._handle_results(col_success, col_result) if table_result: indexes_result = self._handle_results(indexes_sucess, indexes_result) triggers_result = self._handle_results(triggers_success, triggers_result) return self._build_table_metadata(table_result[0], col_result, triggers_result, indexes_result) view_result = self._handle_results(view_success, view_result) if view_result: return self._build_view_metadata(view_result[0], col_result) @staticmethod def 
_build_keyspace_metadata_internal(row): name = row["keyspace_name"] durable_writes = row["durable_writes"] strategy_options = dict(row["replication"]) strategy_class = strategy_options.pop("class") return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) @staticmethod def _build_aggregate(aggregate_row): return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'], aggregate_row['argument_types'], aggregate_row['state_func'], aggregate_row['state_type'], aggregate_row['final_func'], aggregate_row['initcond'], aggregate_row['return_type']) def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_rows=None, virtual=False): keyspace_name = row["keyspace_name"] table_name = row[self._table_name_col] col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][table_name] trigger_rows = trigger_rows or self.keyspace_table_trigger_rows[keyspace_name][table_name] index_rows = index_rows or self.keyspace_table_index_rows[keyspace_name][table_name] table_meta = TableMetadataV3(keyspace_name, table_name, virtual=virtual) try: table_meta.options = self._build_table_options(row) flags = row.get('flags', set()) if flags: - compact_static = False - table_meta.is_compact_storage = 'dense' in flags or 'super' in flags or 'compound' not in flags is_dense = 'dense' in flags + compact_static = not is_dense and 'super' not in flags and 'compound' not in flags + table_meta.is_compact_storage = is_dense or 'super' in flags or 'compound' not in flags elif virtual: compact_static = False table_meta.is_compact_storage = False is_dense = False else: compact_static = True table_meta.is_compact_storage = True is_dense = False self._build_table_columns(table_meta, col_rows, compact_static, is_dense, virtual) for trigger_row in trigger_rows: trigger_meta = self._build_trigger_metadata(table_meta, trigger_row) table_meta.triggers[trigger_meta.name] = trigger_meta for index_row in index_rows: index_meta = self._build_index_metadata(table_meta, index_row) if index_meta: table_meta.indexes[index_meta.name] = index_meta table_meta.extensions = row.get('extensions', {}) except Exception: table_meta._exc_info = sys.exc_info() log.exception("Error while parsing metadata for table %s.%s row(%s) columns(%s)", keyspace_name, table_name, row, col_rows) return table_meta def _build_table_options(self, row): """ Setup the mostly-non-schema table options, like caching settings """ return dict((o, row.get(o)) for o in self.recognized_table_options if o in row) def _build_table_columns(self, meta, col_rows, compact_static=False, is_dense=False, virtual=False): # partition key partition_rows = [r for r in col_rows if r.get('kind', None) == "partition_key"] if len(partition_rows) > 1: partition_rows = sorted(partition_rows, key=lambda row: row.get('position')) for r in partition_rows: # we have to add meta here (and not in the later loop) because TableMetadata.columns is an # OrderedDict, and it assumes keys are inserted first, in order, when exporting CQL column_meta = self._build_column_metadata(meta, r) meta.columns[column_meta.name] = column_meta meta.partition_key.append(meta.columns[r.get('column_name')]) # clustering key if not compact_static: clustering_rows = [r for r in col_rows if r.get('kind', None) == "clustering"] if len(clustering_rows) > 1: clustering_rows = sorted(clustering_rows, key=lambda row: row.get('position')) for r in clustering_rows: column_meta = self._build_column_metadata(meta, r) meta.columns[column_meta.name] = column_meta 
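# Hedged, standalone mirror of the flag handling changed above, to make the
# dense/super/compound combinations explicit (logic copied from
# _build_table_metadata; names are local to this sketch).
def _storage_from_flags(flags):
    if flags:
        is_dense = 'dense' in flags
        compact_static = not is_dense and 'super' not in flags and 'compound' not in flags
        is_compact = is_dense or 'super' in flags or 'compound' not in flags
    else:
        # non-virtual rows without flags are thrift "static" column families
        is_dense, compact_static, is_compact = False, True, True
    return is_dense, compact_static, is_compact

assert _storage_from_flags({'compound'}) == (False, False, False)         # regular CQL table
assert _storage_from_flags({'dense', 'compound'}) == (True, False, True)  # dense COMPACT STORAGE
assert _storage_from_flags({'super'}) == (False, False, True)             # thrift super CF
assert _storage_from_flags(set()) == (False, True, True)                  # thrift static CF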
meta.clustering_key.append(meta.columns[r.get('column_name')]) for col_row in (r for r in col_rows if r.get('kind', None) not in ('partition_key', 'clustering_key')): column_meta = self._build_column_metadata(meta, col_row) if is_dense and column_meta.cql_type == types.cql_empty_type: continue if compact_static and not column_meta.is_static: # for compact static tables, we omit the clustering key and value, and only add the logical columns. # They are marked not static so that it generates appropriate CQL continue if compact_static: column_meta.is_static = False meta.columns[column_meta.name] = column_meta def _build_view_metadata(self, row, col_rows=None): keyspace_name = row["keyspace_name"] view_name = row["view_name"] base_table_name = row["base_table_name"] include_all_columns = row["include_all_columns"] where_clause = row["where_clause"] col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][view_name] view_meta = MaterializedViewMetadata(keyspace_name, view_name, base_table_name, include_all_columns, where_clause, self._build_table_options(row)) self._build_table_columns(view_meta, col_rows) view_meta.extensions = row.get('extensions', {}) return view_meta @staticmethod def _build_column_metadata(table_metadata, row): name = row["column_name"] cql_type = row["type"] is_static = row.get("kind", None) == "static" is_reversed = row["clustering_order"].upper() == "DESC" column_meta = ColumnMetadata(table_metadata, name, cql_type, is_static, is_reversed) return column_meta @staticmethod def _build_index_metadata(table_metadata, row): index_name = row.get("index_name") kind = row.get("kind") if index_name or kind: index_options = row.get("options") return IndexMetadata(table_metadata.keyspace_name, table_metadata.name, index_name, kind, index_options) else: return None @staticmethod def _build_trigger_metadata(table_metadata, row): name = row["trigger_name"] options = row["options"] trigger_meta = TriggerMetadata(table_metadata, name, options) return trigger_meta def _query_all(self): cl = ConsistencyLevel.ONE queries = [ QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl), QueryMessage(query=self._SELECT_TABLES, consistency_level=cl), QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl), QueryMessage(query=self._SELECT_TYPES, consistency_level=cl), QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl), QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl), QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl), QueryMessage(query=self._SELECT_INDEXES, consistency_level=cl), QueryMessage(query=self._SELECT_VIEWS, consistency_level=cl) ] ((ks_success, ks_result), (table_success, table_result), (col_success, col_result), (types_success, types_result), (functions_success, functions_result), (aggregates_success, aggregates_result), (triggers_success, triggers_result), (indexes_success, indexes_result), (views_success, views_result)) = self.connection.wait_for_responses( *queries, timeout=self.timeout, fail_on_error=False ) self.keyspaces_result = self._handle_results(ks_success, ks_result) self.tables_result = self._handle_results(table_success, table_result) self.columns_result = self._handle_results(col_success, col_result) self.triggers_result = self._handle_results(triggers_success, triggers_result) self.types_result = self._handle_results(types_success, types_result) self.functions_result = self._handle_results(functions_success, functions_result) self.aggregates_result = self._handle_results(aggregates_success, 
aggregates_result) self.indexes_result = self._handle_results(indexes_success, indexes_result) self.views_result = self._handle_results(views_success, views_result) self._aggregate_results() def _aggregate_results(self): super(SchemaParserV3, self)._aggregate_results() m = self.keyspace_table_index_rows for row in self.indexes_result: ksname = row["keyspace_name"] cfname = row[self._table_name_col] m[ksname][cfname].append(row) m = self.keyspace_view_rows for row in self.views_result: m[row["keyspace_name"]].append(row) @staticmethod def _schema_type_to_cql(type_string): return type_string class SchemaParserV4(SchemaParserV3): recognized_table_options = tuple( opt for opt in SchemaParserV3.recognized_table_options if opt not in ( # removed in V4: CASSANDRA-13910 'dclocal_read_repair_chance', 'read_repair_chance' ) ) _SELECT_VIRTUAL_KEYSPACES = 'SELECT * from system_virtual_schema.keyspaces' _SELECT_VIRTUAL_TABLES = 'SELECT * from system_virtual_schema.tables' _SELECT_VIRTUAL_COLUMNS = 'SELECT * from system_virtual_schema.columns' def __init__(self, connection, timeout): super(SchemaParserV4, self).__init__(connection, timeout) self.virtual_keyspaces_rows = defaultdict(list) self.virtual_tables_rows = defaultdict(list) self.virtual_columns_rows = defaultdict(lambda: defaultdict(list)) def _query_all(self): cl = ConsistencyLevel.ONE # todo: this duplicates V3; we should find a way for _query_all methods # to extend each other. queries = [ # copied from V3 QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl), QueryMessage(query=self._SELECT_TABLES, consistency_level=cl), QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl), QueryMessage(query=self._SELECT_TYPES, consistency_level=cl), QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl), QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl), QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl), QueryMessage(query=self._SELECT_INDEXES, consistency_level=cl), QueryMessage(query=self._SELECT_VIEWS, consistency_level=cl), # V4-only queries QueryMessage(query=self._SELECT_VIRTUAL_KEYSPACES, consistency_level=cl), QueryMessage(query=self._SELECT_VIRTUAL_TABLES, consistency_level=cl), QueryMessage(query=self._SELECT_VIRTUAL_COLUMNS, consistency_level=cl) ] responses = self.connection.wait_for_responses( *queries, timeout=self.timeout, fail_on_error=False) ( # copied from V3 (ks_success, ks_result), (table_success, table_result), (col_success, col_result), (types_success, types_result), (functions_success, functions_result), (aggregates_success, aggregates_result), (triggers_success, triggers_result), (indexes_success, indexes_result), (views_success, views_result), # V4-only responses (virtual_ks_success, virtual_ks_result), (virtual_table_success, virtual_table_result), (virtual_column_success, virtual_column_result) ) = responses # copied from V3 self.keyspaces_result = self._handle_results(ks_success, ks_result) self.tables_result = self._handle_results(table_success, table_result) self.columns_result = self._handle_results(col_success, col_result) self.triggers_result = self._handle_results(triggers_success, triggers_result) self.types_result = self._handle_results(types_success, types_result) self.functions_result = self._handle_results(functions_success, functions_result) self.aggregates_result = self._handle_results(aggregates_success, aggregates_result) self.indexes_result = self._handle_results(indexes_success, indexes_result) self.views_result = self._handle_results(views_success, 
views_result) # V4-only results # These tables don't exist in some DSE versions reporting 4.X so we can # ignore them if we got an error self.virtual_keyspaces_result = self._handle_results( virtual_ks_success, virtual_ks_result, expected_failures=InvalidRequest ) self.virtual_tables_result = self._handle_results( virtual_table_success, virtual_table_result, expected_failures=InvalidRequest ) self.virtual_columns_result = self._handle_results( virtual_column_success, virtual_column_result, expected_failures=InvalidRequest ) self._aggregate_results() def _aggregate_results(self): super(SchemaParserV4, self)._aggregate_results() m = self.virtual_tables_rows for row in self.virtual_tables_result: m[row["keyspace_name"]].append(row) m = self.virtual_columns_rows for row in self.virtual_columns_result: ks_name = row['keyspace_name'] tab_name = row[self._table_name_col] m[ks_name][tab_name].append(row) def get_all_keyspaces(self): for x in super(SchemaParserV4, self).get_all_keyspaces(): yield x for row in self.virtual_keyspaces_result: ks_name = row['keyspace_name'] keyspace_meta = self._build_keyspace_metadata(row) keyspace_meta.virtual = True for table_row in self.virtual_tables_rows.get(ks_name, []): table_name = table_row[self._table_name_col] col_rows = self.virtual_columns_rows[ks_name][table_name] keyspace_meta._add_table_metadata( self._build_table_metadata(table_row, col_rows=col_rows, virtual=True) ) yield keyspace_meta @staticmethod def _build_keyspace_metadata_internal(row): # necessary fields that aren't int virtual ks row["durable_writes"] = row.get("durable_writes", None) row["replication"] = row.get("replication", {}) row["replication"]["class"] = row["replication"].get("class", None) return super(SchemaParserV4, SchemaParserV4)._build_keyspace_metadata_internal(row) class TableMetadataV3(TableMetadata): compaction_options = {} option_maps = ['compaction', 'compression', 'caching'] @property def is_cql_compatible(self): return True @classmethod def _make_option_strings(cls, options_map): ret = [] options_copy = dict(options_map.items()) for option in cls.option_maps: value = options_copy.get(option) if isinstance(value, Mapping): del options_copy[option] params = ("'%s': '%s'" % (k, v) for k, v in value.items()) ret.append("%s = {%s}" % (option, ', '.join(params))) for name, value in options_copy.items(): if value is not None: if name == "comment": value = value or "" ret.append("%s = %s" % (name, protect_value(value))) return list(sorted(ret)) class MaterializedViewMetadata(object): """ A representation of a materialized view on a table """ keyspace_name = None """ A string name of the view.""" name = None """ A string name of the view.""" base_table_name = None """ A string name of the base table for this view.""" partition_key = None """ A list of :class:`.ColumnMetadata` instances representing the columns in the partition key for this view. This will always hold at least one column. """ clustering_key = None """ A list of :class:`.ColumnMetadata` instances representing the columns in the clustering key for this view. Note that a table may have no clustering keys, in which case this will be an empty list. """ columns = None """ A dict mapping column names to :class:`.ColumnMetadata` instances. """ include_all_columns = None """ A flag indicating whether the view was created AS SELECT * """ where_clause = None """ String WHERE clause for the view select statement. 
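# Hedged sketch for the virtual-schema handling above: virtual keyspaces
# (Cassandra 4.0's system_views and system_virtual_schema) appear in the same
# metadata model, flagged so that export only emits a commented structure.
# `cluster` stands for a connected cassandra.cluster.Cluster instance.
for ks_meta in cluster.metadata.keyspaces.values():
    if ks_meta.virtual:
        for table_meta in ks_meta.tables.values():
            assert table_meta.virtual and not table_meta.is_cql_compatible
            print(table_meta.export_as_string())   # wrapped in a /* ... */ warning block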
From server metadata """ options = None """ A dict mapping table option names to their specific settings for this view. """ extensions = None """ Metadata describing configuration for table extensions """ def __init__(self, keyspace_name, view_name, base_table_name, include_all_columns, where_clause, options): self.keyspace_name = keyspace_name self.name = view_name self.base_table_name = base_table_name self.partition_key = [] self.clustering_key = [] self.columns = OrderedDict() self.include_all_columns = include_all_columns self.where_clause = where_clause self.options = options or {} def as_cql_query(self, formatted=False): """ Returns a CQL query that can be used to recreate this function. If `formatted` is set to :const:`True`, extra whitespace will be added to make the query more readable. """ sep = '\n ' if formatted else ' ' keyspace = protect_name(self.keyspace_name) name = protect_name(self.name) selected_cols = '*' if self.include_all_columns else ', '.join(protect_name(col.name) for col in self.columns.values()) base_table = protect_name(self.base_table_name) where_clause = self.where_clause part_key = ', '.join(protect_name(col.name) for col in self.partition_key) if len(self.partition_key) > 1: pk = "((%s)" % part_key else: pk = "(%s" % part_key if self.clustering_key: pk += ", %s" % ', '.join(protect_name(col.name) for col in self.clustering_key) pk += ")" properties = TableMetadataV3._property_string(formatted, self.clustering_key, self.options) ret = ("CREATE MATERIALIZED VIEW %(keyspace)s.%(name)s AS%(sep)s" "SELECT %(selected_cols)s%(sep)s" "FROM %(keyspace)s.%(base_table)s%(sep)s" "WHERE %(where_clause)s%(sep)s" "PRIMARY KEY %(pk)s%(sep)s" "WITH %(properties)s") % locals() if self.extensions: registry = _RegisteredExtensionType._extension_registry for k in six.viewkeys(registry) & self.extensions: # no viewkeys on OrderedMapSerializeKey ext = registry[k] cql = ext.after_table_cql(self, k, self.extensions[k]) if cql: ret += "\n\n%s" % (cql,) return ret def export_as_string(self): return self.as_cql_query(formatted=True) + ";" def get_schema_parser(connection, server_version, timeout): - server_major_version = int(server_version.split('.')[0]) - if server_major_version >= 4: + version = Version(server_version) + if version >= Version('4-a'): return SchemaParserV4(connection, timeout) - if server_major_version >= 3: + if version >= Version('3.0.0'): return SchemaParserV3(connection, timeout) else: # we could further specialize by version. Right now just refactoring the # multi-version parser we have as of C* 2.2.0rc1. return SchemaParserV22(connection, timeout) def _cql_from_cass_type(cass_type): """ A string representation of the type for this column, such as "varchar" or "map". """ if issubclass(cass_type, types.ReversedType): return cass_type.subtypes[0].cql_parameterized_type() else: return cass_type.cql_parameterized_type() NO_VALID_REPLICA = object() def group_keys_by_replica(session, keyspace, table, keys): """ Returns a :class:`dict` with the keys grouped per host. This can be used to more accurately group by IN clause or to batch the keys per host. 
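# Hedged sketch of the version gate above: comparing Version objects instead
# of int(major) lets 4.0 pre-release strings ('4.0-alpha1', '4.0-beta2')
# already select SchemaParserV4, while 3.x keeps V3 and 2.x keeps V22.
# `Version` is assumed to be the comparable helper this module imports
# (cassandra.util.Version in the released driver).
for reported in ('2.2.8', '3.11.4', '4.0-alpha1', '4.0.0'):
    v = Version(reported)
    parser_cls = (SchemaParserV4 if v >= Version('4-a')
                  else SchemaParserV3 if v >= Version('3.0.0')
                  else SchemaParserV22)
    print("%s -> %s" % (reported, parser_cls.__name__))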
If a valid replica is not found for a particular key it will be grouped under :class:`~.NO_VALID_REPLICA` Example usage:: result = group_keys_by_replica( session, "system", "peers", (("127.0.0.1", ), ("127.0.0.2", )) ) """ cluster = session.cluster partition_keys = cluster.metadata.keyspaces[keyspace].tables[table].partition_key serializers = list(types._cqltypes[partition_key.cql_type] for partition_key in partition_keys) keys_per_host = defaultdict(list) distance = cluster._default_load_balancing_policy.distance for key in keys: serialized_key = [serializer.serialize(pk, cluster.protocol_version) for serializer, pk in zip(serializers, key)] if len(serialized_key) == 1: routing_key = serialized_key[0] else: routing_key = b"".join(struct.pack(">H%dsB" % len(p), len(p), p, 0) for p in serialized_key) all_replicas = cluster.metadata.get_replicas(keyspace, routing_key) # First check if there are local replicas valid_replicas = [host for host in all_replicas if host.is_up and distance(host) == HostDistance.LOCAL] if not valid_replicas: valid_replicas = [host for host in all_replicas if host.is_up] if valid_replicas: keys_per_host[random.choice(valid_replicas)].append(key) else: # We will group under this statement all the keys for which # we haven't found a valid replica keys_per_host[NO_VALID_REPLICA].append(key) return dict(keys_per_host) + diff --git a/cassandra/policies.py b/cassandra/policies.py index fdd96d5..d610666 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -1,1049 +1,1104 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from itertools import islice, cycle, groupby, repeat import logging from random import randint, shuffle from threading import Lock import socket +import warnings from cassandra import WriteType as WT # This is done this way because WriteType was originally # defined here and in order not to break the API. # It may removed in the next mayor. WriteType = WT from cassandra import ConsistencyLevel, OperationTimedOut log = logging.getLogger(__name__) class HostDistance(object): """ A measure of how "distant" a node is from the client, which may influence how the load balancer distributes requests and how many connections are opened to the node. """ IGNORED = -1 """ A node with this distance should never be queried or have connections opened to it. """ LOCAL = 0 """ Nodes with ``LOCAL`` distance will be preferred for operations under some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`) and will have a greater number of connections opened against them by default. This distance is typically used for nodes within the same datacenter as the client. """ REMOTE = 1 """ Nodes with ``REMOTE`` distance will be treated as a last resort by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`) and will have a smaller number of connections opened against them by default. This distance is typically used for nodes outside of the datacenter that the client is running in. 
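# Hedged, standalone illustration of the composite routing-key encoding used
# above: with a multi-column partition key each serialized component is framed
# as <2-byte big-endian length><bytes><0x00> and the frames are concatenated,
# matching Cassandra's composite key layout.  Toy byte strings below.
import struct
from binascii import hexlify

def composite_routing_key(serialized_components):
    if len(serialized_components) == 1:
        return serialized_components[0]
    return b"".join(struct.pack(">H%dsB" % len(p), len(p), p, 0)
                    for p in serialized_components)

assert hexlify(composite_routing_key([b'\x00\x01', b'abc'])) == b'0002000100000361626300'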
""" class HostStateListener(object): def on_up(self, host): """ Called when a node is marked up. """ raise NotImplementedError() def on_down(self, host): """ Called when a node is marked down. """ raise NotImplementedError() def on_add(self, host): """ Called when a node is added to the cluster. The newly added node should be considered up. """ raise NotImplementedError() def on_remove(self, host): """ Called when a node is removed from the cluster. """ raise NotImplementedError() class LoadBalancingPolicy(HostStateListener): """ Load balancing policies are used to decide how to distribute requests among all possible coordinator nodes in the cluster. In particular, they may focus on querying "near" nodes (those in a local datacenter) or on querying nodes who happen to be replicas for the requested data. You may also use subclasses of :class:`.LoadBalancingPolicy` for custom behavior. """ _hosts_lock = None def __init__(self): self._hosts_lock = Lock() def distance(self, host): """ Returns a measure of how remote a :class:`~.pool.Host` is in terms of the :class:`.HostDistance` enums. """ raise NotImplementedError() def populate(self, cluster, hosts): """ This method is called to initialize the load balancing policy with a set of :class:`.Host` instances before its first use. The `cluster` parameter is an instance of :class:`.Cluster`. """ raise NotImplementedError() def make_query_plan(self, working_keyspace=None, query=None): """ Given a :class:`~.query.Statement` instance, return a iterable of :class:`.Host` instances which should be queried in that order. A generator may work well for custom implementations of this method. Note that the `query` argument may be :const:`None` when preparing statements. `working_keyspace` should be the string name of the current keyspace, as set through :meth:`.Session.set_keyspace()` or with a ``USE`` statement. """ raise NotImplementedError() def check_supported(self): """ This will be called after the cluster Metadata has been initialized. If the load balancing policy implementation cannot be supported for some reason (such as a missing C extension), this is the point at which it should raise an exception. """ pass class RoundRobinPolicy(LoadBalancingPolicy): """ A subclass of :class:`.LoadBalancingPolicy` which evenly distributes queries across all nodes in the cluster, regardless of what datacenter the nodes may be in. - - This load balancing policy is used by default. 
""" _live_hosts = frozenset(()) _position = 0 def populate(self, cluster, hosts): self._live_hosts = frozenset(hosts) if len(hosts) > 1: self._position = randint(0, len(hosts) - 1) def distance(self, host): return HostDistance.LOCAL def make_query_plan(self, working_keyspace=None, query=None): # not thread-safe, but we don't care much about lost increments # for the purposes of load balancing pos = self._position self._position += 1 hosts = self._live_hosts length = len(hosts) if length: pos %= length return islice(cycle(hosts), pos, pos + length) else: return [] def on_up(self, host): with self._hosts_lock: self._live_hosts = self._live_hosts.union((host, )) def on_down(self, host): with self._hosts_lock: self._live_hosts = self._live_hosts.difference((host, )) def on_add(self, host): with self._hosts_lock: self._live_hosts = self._live_hosts.union((host, )) def on_remove(self, host): with self._hosts_lock: self._live_hosts = self._live_hosts.difference((host, )) class DCAwareRoundRobinPolicy(LoadBalancingPolicy): """ Similar to :class:`.RoundRobinPolicy`, but prefers hosts in the local datacenter and only uses nodes in remote datacenters as a last resort. """ local_dc = None used_hosts_per_remote_dc = 0 def __init__(self, local_dc='', used_hosts_per_remote_dc=0): """ The `local_dc` parameter should be the name of the datacenter (such as is reported by ``nodetool ring``) that should be considered local. If not specified, the driver will choose a local_dc based on the first host among :attr:`.Cluster.contact_points` having a valid DC. If relying on this mechanism, all specified contact points should be nodes in a single, local DC. `used_hosts_per_remote_dc` controls how many nodes in each remote datacenter will have connections opened against them. In other words, `used_hosts_per_remote_dc` hosts will be considered :attr:`~.HostDistance.REMOTE` and the rest will be considered :attr:`~.HostDistance.IGNORED`. By default, all remote hosts are ignored. 
""" self.local_dc = local_dc self.used_hosts_per_remote_dc = used_hosts_per_remote_dc self._dc_live_hosts = {} self._position = 0 - self._contact_points = [] + self._endpoints = [] LoadBalancingPolicy.__init__(self) def _dc(self, host): return host.datacenter or self.local_dc def populate(self, cluster, hosts): for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)): self._dc_live_hosts[dc] = tuple(set(dc_hosts)) if not self.local_dc: - self._contact_points = cluster.contact_points_resolved + self._endpoints = [ + endpoint + for endpoint in cluster.endpoints_resolved] self._position = randint(0, len(hosts) - 1) if hosts else 0 def distance(self, host): dc = self._dc(host) if dc == self.local_dc: return HostDistance.LOCAL if not self.used_hosts_per_remote_dc: return HostDistance.IGNORED else: dc_hosts = self._dc_live_hosts.get(dc) if not dc_hosts: return HostDistance.IGNORED if host in list(dc_hosts)[:self.used_hosts_per_remote_dc]: return HostDistance.REMOTE else: return HostDistance.IGNORED def make_query_plan(self, working_keyspace=None, query=None): # not thread-safe, but we don't care much about lost increments # for the purposes of load balancing pos = self._position self._position += 1 local_live = self._dc_live_hosts.get(self.local_dc, ()) pos = (pos % len(local_live)) if local_live else 0 for host in islice(cycle(local_live), pos, pos + len(local_live)): yield host # the dict can change, so get candidate DCs iterating over keys of a copy other_dcs = [dc for dc in self._dc_live_hosts.copy().keys() if dc != self.local_dc] for dc in other_dcs: remote_live = self._dc_live_hosts.get(dc, ()) for host in remote_live[:self.used_hosts_per_remote_dc]: yield host def on_up(self, host): # not worrying about threads because this will happen during # control connection startup/refresh if not self.local_dc and host.datacenter: - if host.address in self._contact_points: + if host.endpoint in self._endpoints: self.local_dc = host.datacenter log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); " "if incorrect, please specify a local_dc to the constructor, " "or limit contact points to local cluster nodes" % - (self.local_dc, host.address)) - del self._contact_points + (self.local_dc, host.endpoint)) + del self._endpoints dc = self._dc(host) with self._hosts_lock: current_hosts = self._dc_live_hosts.get(dc, ()) if host not in current_hosts: self._dc_live_hosts[dc] = current_hosts + (host, ) def on_down(self, host): dc = self._dc(host) with self._hosts_lock: current_hosts = self._dc_live_hosts.get(dc, ()) if host in current_hosts: hosts = tuple(h for h in current_hosts if h != host) if hosts: self._dc_live_hosts[dc] = hosts else: del self._dc_live_hosts[dc] def on_add(self, host): self.on_up(host) def on_remove(self, host): self.on_down(host) class TokenAwarePolicy(LoadBalancingPolicy): """ A :class:`.LoadBalancingPolicy` wrapper that adds token awareness to a child policy. This alters the child policy's behavior so that it first attempts to send queries to :attr:`~.HostDistance.LOCAL` replicas (as determined by the child policy) based on the :class:`.Statement`'s :attr:`~.Statement.routing_key`. If :attr:`.shuffle_replicas` is truthy, these replicas will be yielded in a random order. Once those hosts are exhausted, the remaining hosts in the child policy's query plan will be used in the order provided by the child policy. If no :attr:`~.Statement.routing_key` is set on the query, the child policy's query plan will be used as is. 
""" _child_policy = None _cluster_metadata = None shuffle_replicas = False """ Yield local replicas in a random order. """ def __init__(self, child_policy, shuffle_replicas=False): self._child_policy = child_policy self.shuffle_replicas = shuffle_replicas def populate(self, cluster, hosts): self._cluster_metadata = cluster.metadata self._child_policy.populate(cluster, hosts) def check_supported(self): if not self._cluster_metadata.can_support_partitioner(): raise RuntimeError( '%s cannot be used with the cluster partitioner (%s) because ' 'the relevant C extension for this driver was not compiled. ' 'See the installation instructions for details on building ' 'and installing the C extensions.' % (self.__class__.__name__, self._cluster_metadata.partitioner)) def distance(self, *args, **kwargs): return self._child_policy.distance(*args, **kwargs) def make_query_plan(self, working_keyspace=None, query=None): if query and query.keyspace: keyspace = query.keyspace else: keyspace = working_keyspace child = self._child_policy if query is None: for host in child.make_query_plan(keyspace, query): yield host else: routing_key = query.routing_key if routing_key is None or keyspace is None: for host in child.make_query_plan(keyspace, query): yield host else: replicas = self._cluster_metadata.get_replicas(keyspace, routing_key) if self.shuffle_replicas: shuffle(replicas) for replica in replicas: if replica.is_up and \ child.distance(replica) == HostDistance.LOCAL: yield replica for host in child.make_query_plan(keyspace, query): # skip if we've already listed this host if host not in replicas or \ child.distance(host) == HostDistance.REMOTE: yield host def on_up(self, *args, **kwargs): return self._child_policy.on_up(*args, **kwargs) def on_down(self, *args, **kwargs): return self._child_policy.on_down(*args, **kwargs) def on_add(self, *args, **kwargs): return self._child_policy.on_add(*args, **kwargs) def on_remove(self, *args, **kwargs): return self._child_policy.on_remove(*args, **kwargs) class WhiteListRoundRobinPolicy(RoundRobinPolicy): """ A subclass of :class:`.RoundRobinPolicy` which evenly distributes queries across all nodes in the cluster, regardless of what datacenter the nodes may be in, but only if that node exists in the list of allowed nodes This policy is addresses the issue described in https://datastax-oss.atlassian.net/browse/JAVA-145 Where connection errors occur when connection attempts are made to private IP addresses remotely """ def __init__(self, hosts): """ The `hosts` parameter should be a sequence of hosts to permit connections to. """ self._allowed_hosts = hosts self._allowed_hosts_resolved = [endpoint[4][0] for a in self._allowed_hosts for endpoint in socket.getaddrinfo(a, None, socket.AF_UNSPEC, socket.SOCK_STREAM)] RoundRobinPolicy.__init__(self) def populate(self, cluster, hosts): self._live_hosts = frozenset(h for h in hosts if h.address in self._allowed_hosts_resolved) if len(hosts) <= 1: self._position = 0 else: self._position = randint(0, len(hosts) - 1) def distance(self, host): if host.address in self._allowed_hosts_resolved: return HostDistance.LOCAL else: return HostDistance.IGNORED def on_up(self, host): if host.address in self._allowed_hosts_resolved: RoundRobinPolicy.on_up(self, host) def on_add(self, host): if host.address in self._allowed_hosts_resolved: RoundRobinPolicy.on_add(self, host) class HostFilterPolicy(LoadBalancingPolicy): """ A :class:`.LoadBalancingPolicy` subclass configured with a child policy, and a single-argument predicate. 
This policy defers to the child policy for hosts where ``predicate(host)`` is truthy. Hosts for which ``predicate(host)`` is falsey will be considered :attr:`.IGNORED`, and will not be used in a query plan. This can be used in the cases where you need a whitelist or blacklist policy, e.g. to prepare for decommissioning nodes or for testing: .. code-block:: python def address_is_ignored(host): return host.address in [ignored_address0, ignored_address1] blacklist_filter_policy = HostFilterPolicy( child_policy=RoundRobinPolicy(), predicate=address_is_ignored ) cluster = Cluster( primary_host, load_balancing_policy=blacklist_filter_policy, ) See the note in the :meth:`.make_query_plan` documentation for a caveat on how wrapping ordering polices (e.g. :class:`.RoundRobinPolicy`) may break desirable properties of the wrapped policy. Please note that whitelist and blacklist policies are not recommended for general, day-to-day use. You probably want something like :class:`.DCAwareRoundRobinPolicy`, which prefers a local DC but has fallbacks, over a brute-force method like whitelisting or blacklisting. """ def __init__(self, child_policy, predicate): """ :param child_policy: an instantiated :class:`.LoadBalancingPolicy` that this one will defer to. :param predicate: a one-parameter function that takes a :class:`.Host`. If it returns a falsey value, the :class:`.Host` will be :attr:`.IGNORED` and not returned in query plans. """ super(HostFilterPolicy, self).__init__() self._child_policy = child_policy self._predicate = predicate def on_up(self, host, *args, **kwargs): return self._child_policy.on_up(host, *args, **kwargs) def on_down(self, host, *args, **kwargs): return self._child_policy.on_down(host, *args, **kwargs) def on_add(self, host, *args, **kwargs): return self._child_policy.on_add(host, *args, **kwargs) def on_remove(self, host, *args, **kwargs): return self._child_policy.on_remove(host, *args, **kwargs) @property def predicate(self): """ A predicate, set on object initialization, that takes a :class:`.Host` and returns a value. If the value is falsy, the :class:`.Host` is :class:`~HostDistance.IGNORED`. If the value is truthy, :class:`.HostFilterPolicy` defers to the child policy to determine the host's distance. This is a read-only value set in ``__init__``, implemented as a ``property``. """ return self._predicate def distance(self, host): """ Checks if ``predicate(host)``, then returns :attr:`~HostDistance.IGNORED` if falsey, and defers to the child policy otherwise. """ if self.predicate(host): return self._child_policy.distance(host) else: return HostDistance.IGNORED def populate(self, cluster, hosts): self._child_policy.populate(cluster=cluster, hosts=hosts) def make_query_plan(self, working_keyspace=None, query=None): """ Defers to the child policy's :meth:`.LoadBalancingPolicy.make_query_plan` and filters the results. Note that this filtering may break desirable properties of the wrapped policy in some cases. For instance, imagine if you configure this policy to filter out ``host2``, and to wrap a round-robin policy that rotates through three hosts in the order ``host1, host2, host3``, ``host2, host3, host1``, ``host3, host1, host2``, repeating. This policy will yield ``host1, host3``, ``host3, host1``, ``host3, host1``, disproportionately favoring ``host3``. 
""" child_qp = self._child_policy.make_query_plan( working_keyspace=working_keyspace, query=query ) for host in child_qp: if self.predicate(host): yield host def check_supported(self): return self._child_policy.check_supported() class ConvictionPolicy(object): """ A policy which decides when hosts should be considered down based on the types of failures and the number of failures. If custom behavior is needed, this class may be subclassed. """ def __init__(self, host): """ `host` is an instance of :class:`.Host`. """ self.host = host def add_failure(self, connection_exc): """ Implementations should return :const:`True` if the host should be convicted, :const:`False` otherwise. """ raise NotImplementedError() def reset(self): """ Implementations should clear out any convictions or state regarding the host. """ raise NotImplementedError() class SimpleConvictionPolicy(ConvictionPolicy): """ The default implementation of :class:`ConvictionPolicy`, which simply marks a host as down after the first failure of any kind. """ def add_failure(self, connection_exc): return not isinstance(connection_exc, OperationTimedOut) def reset(self): pass class ReconnectionPolicy(object): """ This class and its subclasses govern how frequently an attempt is made to reconnect to nodes that are marked as dead. If custom behavior is needed, this class may be subclassed. """ def new_schedule(self): """ This should return a finite or infinite iterable of delays (each as a floating point number of seconds) inbetween each failed reconnection attempt. Note that if the iterable is finite, reconnection attempts will cease once the iterable is exhausted. """ raise NotImplementedError() class ConstantReconnectionPolicy(ReconnectionPolicy): """ A :class:`.ReconnectionPolicy` subclass which sleeps for a fixed delay inbetween each reconnection attempt. """ def __init__(self, delay, max_attempts=64): """ `delay` should be a floating point number of seconds to wait inbetween each attempt. `max_attempts` should be a total number of attempts to be made before giving up, or :const:`None` to continue reconnection attempts forever. The default is 64. """ if delay < 0: raise ValueError("delay must not be negative") if max_attempts is not None and max_attempts < 0: raise ValueError("max_attempts must not be negative") self.delay = delay self.max_attempts = max_attempts def new_schedule(self): if self.max_attempts: return repeat(self.delay, self.max_attempts) return repeat(self.delay) class ExponentialReconnectionPolicy(ReconnectionPolicy): """ A :class:`.ReconnectionPolicy` subclass which exponentially increases the length of the delay inbetween each reconnection attempt up to a set maximum delay. + + A random amount of jitter (+/- 15%) will be added to the pure exponential + delay value to avoid the situations where many reconnection handlers are + trying to reconnect at exactly the same time. """ # TODO: max_attempts is 64 to preserve legacy default behavior # consider changing to None in major release to prevent the policy # giving up forever def __init__(self, base_delay, max_delay, max_attempts=64): """ `base_delay` and `max_delay` should be in floating point units of seconds. `max_attempts` should be a total number of attempts to be made before giving up, or :const:`None` to continue reconnection attempts forever. The default is 64. 
""" if base_delay < 0 or max_delay < 0: raise ValueError("Delays may not be negative") if max_delay < base_delay: raise ValueError("Max delay must be greater than base delay") if max_attempts is not None and max_attempts < 0: raise ValueError("max_attempts must not be negative") self.base_delay = base_delay self.max_delay = max_delay self.max_attempts = max_attempts def new_schedule(self): i, overflowed = 0, False while self.max_attempts is None or i < self.max_attempts: if overflowed: yield self.max_delay else: try: - yield min(self.base_delay * (2 ** i), self.max_delay) + yield self._add_jitter(min(self.base_delay * (2 ** i), self.max_delay)) except OverflowError: overflowed = True yield self.max_delay i += 1 + # Adds -+ 15% to the delay provided + def _add_jitter(self, value): + jitter = randint(85, 115) + delay = (jitter * value) / 100 + return min(max(self.base_delay, delay), self.max_delay) + class RetryPolicy(object): """ A policy that describes whether to retry, rethrow, or ignore coordinator timeout and unavailable failures. These are failures reported from the server side. Timeouts are configured by `settings in cassandra.yaml `_. Unavailable failures occur when the coordinator cannot acheive the consistency level for a request. For further information see the method descriptions below. To specify a default retry policy, set the :attr:`.Cluster.default_retry_policy` attribute to an instance of this class or one of its subclasses. To specify a retry policy per query, set the :attr:`.Statement.retry_policy` attribute to an instance of this class or one of its subclasses. If custom behavior is needed for retrying certain operations, this class may be subclassed. """ RETRY = 0 """ This should be returned from the below methods if the operation should be retried on the same connection. """ RETHROW = 1 """ This should be returned from the below methods if the failure should be propagated and no more retries attempted. """ IGNORE = 2 """ This should be returned from the below methods if the failure should be ignored but no more retries should be attempted. """ RETRY_NEXT_HOST = 3 """ This should be returned from the below methods if the operation should be retried on another connection. """ def on_read_timeout(self, query, consistency, required_responses, received_responses, data_retrieved, retry_num): """ This is called when a read operation times out from the coordinator's perspective (i.e. a replica did not respond to the coordinator in time). It should return a tuple with two items: one of the class enums (such as :attr:`.RETRY`) and a :class:`.ConsistencyLevel` to retry the operation at or :const:`None` to keep the same consistency level. `query` is the :class:`.Statement` that timed out. `consistency` is the :class:`.ConsistencyLevel` that the operation was attempted at. The `required_responses` and `received_responses` parameters describe how many replicas needed to respond to meet the requested consistency level and how many actually did respond before the coordinator timed out the request. `data_retrieved` is a boolean indicating whether any of those responses contained data (as opposed to just a digest). `retry_num` counts how many times the operation has been retried, so the first time this method is called, `retry_num` will be 0. By default, operations will be retried at most once, and only if a sufficient number of replicas responded (with data digests). 
""" if retry_num != 0: return self.RETHROW, None elif received_responses >= required_responses and not data_retrieved: return self.RETRY, consistency else: return self.RETHROW, None def on_write_timeout(self, query, consistency, write_type, required_responses, received_responses, retry_num): """ This is called when a write operation times out from the coordinator's perspective (i.e. a replica did not respond to the coordinator in time). `query` is the :class:`.Statement` that timed out. `consistency` is the :class:`.ConsistencyLevel` that the operation was attempted at. `write_type` is one of the :class:`.WriteType` enums describing the type of write operation. The `required_responses` and `received_responses` parameters describe how many replicas needed to acknowledge the write to meet the requested consistency level and how many replicas actually did acknowledge the write before the coordinator timed out the request. `retry_num` counts how many times the operation has been retried, so the first time this method is called, `retry_num` will be 0. By default, failed write operations will retried at most once, and they will only be retried if the `write_type` was :attr:`~.WriteType.BATCH_LOG`. """ if retry_num != 0: return self.RETHROW, None elif write_type == WriteType.BATCH_LOG: return self.RETRY, consistency else: return self.RETHROW, None def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num): """ This is called when the coordinator node determines that a read or write operation cannot be successful because the number of live replicas are too low to meet the requested :class:`.ConsistencyLevel`. - This means that the read or write operation was never forwared to + This means that the read or write operation was never forwarded to any replicas. `query` is the :class:`.Statement` that failed. `consistency` is the :class:`.ConsistencyLevel` that the operation was attempted at. `required_replicas` is the number of replicas that would have needed to acknowledge the operation to meet the requested consistency level. `alive_replicas` is the number of replicas that the coordinator considered alive at the time of the request. `retry_num` counts how many times the operation has been retried, so the first time this method is called, `retry_num` will be 0. - By default, no retries will be attempted and the error will be re-raised. + By default, if this is the first retry, it triggers a retry on the next + host in the query plan with the same consistency level. If this is not the + first retry, no retries will be attempted and the error will be re-raised. + """ + return (self.RETRY_NEXT_HOST, None) if retry_num == 0 else (self.RETHROW, None) + + def on_request_error(self, query, consistency, error, retry_num): + """ + This is called when an unexpected error happens. This can be in the + following situations: + + * On a connection error + * On server errors: overloaded, isBootstrapping, serverError, etc. + + `query` is the :class:`.Statement` that timed out. + + `consistency` is the :class:`.ConsistencyLevel` that the operation was + attempted at. + + `error` the instance of the exception. + + `retry_num` counts how many times the operation has been retried, so + the first time this method is called, `retry_num` will be 0. + + The default, it triggers a retry on the next host in the query plan + with the same consistency level. 
""" - return (self.RETRY_NEXT_HOST, consistency) if retry_num == 0 else (self.RETHROW, None) + # TODO revisit this for the next major + # To preserve the same behavior than before, we don't take retry_num into account + return self.RETRY_NEXT_HOST, None class FallthroughRetryPolicy(RetryPolicy): """ A retry policy that never retries and always propagates failures to the application. """ def on_read_timeout(self, *args, **kwargs): return self.RETHROW, None def on_write_timeout(self, *args, **kwargs): return self.RETHROW, None def on_unavailable(self, *args, **kwargs): return self.RETHROW, None + def on_request_error(self, *args, **kwargs): + return self.RETHROW, None + class DowngradingConsistencyRetryPolicy(RetryPolicy): """ + *Deprecated:* This retry policy will be removed in the next major release. + A retry policy that sometimes retries with a lower consistency level than the one initially requested. **BEWARE**: This policy may retry queries using a lower consistency level than the one initially requested. By doing so, it may break consistency guarantees. In other words, if you use this retry policy, there are cases (documented below) where a read at :attr:`~.QUORUM` *may not* see a preceding write at :attr:`~.QUORUM`. Do not use this policy unless you have understood the cases where this can happen and are ok with that. It is also recommended to subclass this class so that queries that required a consistency level downgrade can be recorded (so that repairs can be made later, etc). This policy implements the same retries as :class:`.RetryPolicy`, but on top of that, it also retries in the following cases: * On a read timeout: if the number of replicas that responded is greater than one but lower than is required by the requested consistency level, the operation is retried at a lower consistency level. * On a write timeout: if the operation is an :attr:`~.UNLOGGED_BATCH` and at least one replica acknowledged the write, the operation is retried at a lower consistency level. Furthermore, for other write types, if at least one replica acknowledged the write, the timeout is ignored. * On an unavailable exception: if at least one replica is alive, the operation is retried at a lower consistency level. The reasoning behind this retry policy is as follows: if, based on the information the Cassandra coordinator node returns, retrying the operation with the initially requested consistency has a chance to succeed, do it. Otherwise, if based on that information we know the initially requested consistency level cannot be achieved currently, then: * For writes, ignore the exception (thus silently failing the consistency requirement) if we know the write has been persisted on at least one replica. * For reads, try reading at a lower consistency level (thus silently failing the consistency requirement). In other words, this policy implements the idea that if the requested consistency level cannot be achieved, the next best thing for writes is to make sure the data is persisted, and that reading something is better than reading nothing, even if there is a risk of reading stale data. 
""" + def __init__(self, *args, **kwargs): + super(DowngradingConsistencyRetryPolicy, self).__init__(*args, **kwargs) + warnings.warn('DowngradingConsistencyRetryPolicy is deprecated ' + 'and will be removed in the next major release.', + DeprecationWarning) + def _pick_consistency(self, num_responses): if num_responses >= 3: return self.RETRY, ConsistencyLevel.THREE elif num_responses >= 2: return self.RETRY, ConsistencyLevel.TWO elif num_responses >= 1: return self.RETRY, ConsistencyLevel.ONE else: return self.RETHROW, None def on_read_timeout(self, query, consistency, required_responses, received_responses, data_retrieved, retry_num): if retry_num != 0: return self.RETHROW, None + elif ConsistencyLevel.is_serial(consistency): + # Downgrading does not make sense for a CAS read query + return self.RETHROW, None elif received_responses < required_responses: return self._pick_consistency(received_responses) elif not data_retrieved: return self.RETRY, consistency else: return self.RETHROW, None def on_write_timeout(self, query, consistency, write_type, required_responses, received_responses, retry_num): if retry_num != 0: return self.RETHROW, None if write_type in (WriteType.SIMPLE, WriteType.BATCH, WriteType.COUNTER): if received_responses > 0: # persisted on at least one replica return self.IGNORE, None else: return self.RETHROW, None elif write_type == WriteType.UNLOGGED_BATCH: return self._pick_consistency(received_responses) elif write_type == WriteType.BATCH_LOG: return self.RETRY, consistency return self.RETHROW, None def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num): if retry_num != 0: return self.RETHROW, None + elif ConsistencyLevel.is_serial(consistency): + # failed at the paxos phase of a LWT, retry on the next host + return self.RETRY_NEXT_HOST, None else: return self._pick_consistency(alive_replicas) class AddressTranslator(object): """ Interface for translating cluster-defined endpoints. The driver discovers nodes using server metadata and topology change events. Normally, the endpoint defined by the server is the right way to connect to a node. In some environments, these addresses may not be reachable, or not preferred (public vs. private IPs in cloud environments, suboptimal routing, etc). This interface allows for translating from server defined endpoints to preferred addresses for driver connections. *Note:* :attr:`~Cluster.contact_points` provided while creating the :class:`~.Cluster` instance are not translated using this mechanism -- only addresses received from Cassandra nodes are. """ def translate(self, addr): """ Accepts the node ip address, and returns a translated address to be used connecting to this node. """ raise NotImplementedError() class IdentityTranslator(AddressTranslator): """ Returns the endpoint with no translation """ def translate(self, addr): return addr class EC2MultiRegionTranslator(AddressTranslator): """ Resolves private ips of the hosts in the same datacenter as the client, and public ips of hosts in other datacenters. """ def translate(self, addr): """ Reverse DNS the public broadcast_address, then lookup that hostname to get the AWS-resolved IP, which will point to the private IP address within the same datacenter. 
""" # get family of this address so we translate to the same family = socket.getaddrinfo(addr, 0, socket.AF_UNSPEC, socket.SOCK_STREAM)[0][0] host = socket.getfqdn(addr) for a in socket.getaddrinfo(host, 0, family, socket.SOCK_STREAM): try: return a[4][0] except Exception: pass return addr class SpeculativeExecutionPolicy(object): """ Interface for specifying speculative execution plans """ def new_plan(self, keyspace, statement): """ Returns :param keyspace: :param statement: :return: """ raise NotImplementedError() class SpeculativeExecutionPlan(object): def next_execution(self, host): raise NotImplementedError() class NoSpeculativeExecutionPlan(SpeculativeExecutionPlan): def next_execution(self, host): return -1 class NoSpeculativeExecutionPolicy(SpeculativeExecutionPolicy): def new_plan(self, keyspace, statement): return NoSpeculativeExecutionPlan() class ConstantSpeculativeExecutionPolicy(SpeculativeExecutionPolicy): """ A speculative execution policy that sends a new query every X seconds (**delay**) for a maximum of Y attempts (**max_attempts**). """ def __init__(self, delay, max_attempts): self.delay = delay self.max_attempts = max_attempts class ConstantSpeculativeExecutionPlan(SpeculativeExecutionPlan): def __init__(self, delay, max_attempts): self.delay = delay self.remaining = max_attempts def next_execution(self, host): if self.remaining > 0: self.remaining -= 1 return self.delay else: return -1 def new_plan(self, keyspace, statement): return self.ConstantSpeculativeExecutionPlan(self.delay, self.max_attempts) diff --git a/cassandra/pool.py b/cassandra/pool.py index 1d6bcf4..cd814ef 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -1,796 +1,818 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Connection pooling and host management. """ from functools import total_ordering import logging import socket import time from threading import Lock, RLock, Condition import weakref try: from weakref import WeakSet except ImportError: from cassandra.util import WeakSet # NOQA from cassandra import AuthenticationFailed -from cassandra.connection import ConnectionException +from cassandra.connection import ConnectionException, EndPoint, DefaultEndPoint from cassandra.policies import HostDistance log = logging.getLogger(__name__) class NoConnectionsAvailable(Exception): """ All existing connections to a given host are busy, or there are no open connections. """ pass @total_ordering class Host(object): """ Represents a single Cassandra node. """ - address = None + endpoint = None """ - The IP address of the node. This is the RPC address the driver uses when connecting to the node + The :class:`~.connection.EndPoint` to connect to the node. """ broadcast_address = None """ broadcast address configured for the node, *if available* ('peer' in system.peers table). This is not present in the ``system.local`` table for older versions of Cassandra. It is also not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. 
""" + broadcast_rpc_address = None + """ + The broadcast rpc address of the node (`native_address` or `rpc_address`). + """ + listen_address = None """ listen address configured for the node, *if available*. This is only available in the ``system.local`` table for newer versions of Cassandra. It is also not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. Usually the same as ``broadcast_address`` unless configured differently in cassandra.yaml. """ conviction_policy = None """ A :class:`~.ConvictionPolicy` instance for determining when this node should be marked up or down. """ is_up = None """ :const:`True` if the node is considered up, :const:`False` if it is considered down, and :const:`None` if it is not known if the node is up or down. """ release_version = None """ release_version as queried from the control connection system tables """ + host_id = None + """ + The unique identifier of the cassandra node + """ + dse_version = None """ dse_version as queried from the control connection system tables. Only populated when connecting to DSE with this property available. Not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. """ dse_workload = None """ DSE workload queried from the control connection system tables. Only populated when connecting to DSE with this property available. Not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. """ _datacenter = None _rack = None _reconnection_handler = None lock = None _currently_handling_node_up = False - def __init__(self, inet_address, conviction_policy_factory, datacenter=None, rack=None): - if inet_address is None: - raise ValueError("inet_address may not be None") + def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=None, host_id=None): + if endpoint is None: + raise ValueError("endpoint may not be None") if conviction_policy_factory is None: raise ValueError("conviction_policy_factory may not be None") - self.address = inet_address + self.endpoint = endpoint if isinstance(endpoint, EndPoint) else DefaultEndPoint(endpoint) self.conviction_policy = conviction_policy_factory(self) + self.host_id = host_id self.set_location_info(datacenter, rack) self.lock = RLock() + @property + def address(self): + """ + The IP address of the endpoint. This is the RPC address the driver uses when connecting to the node. + """ + # backward compatibility + return self.endpoint.address + @property def datacenter(self): """ The datacenter the node is in. """ return self._datacenter @property def rack(self): """ The rack the node is in. """ return self._rack def set_location_info(self, datacenter, rack): """ Sets the datacenter and rack for this node. Intended for internal use (by the control connection, which periodically checks the ring topology) only. """ self._datacenter = datacenter self._rack = rack def set_up(self): if not self.is_up: - log.debug("Host %s is now marked up", self.address) + log.debug("Host %s is now marked up", self.endpoint) self.conviction_policy.reset() self.is_up = True def set_down(self): self.is_up = False def signal_connection_failure(self, connection_exc): return self.conviction_policy.add_failure(connection_exc) def is_currently_reconnecting(self): return self._reconnection_handler is not None def get_and_set_reconnection_handler(self, new_handler): """ Atomically replaces the reconnection handler for this host. Intended for internal use only. 
""" with self.lock: old = self._reconnection_handler self._reconnection_handler = new_handler return old def __eq__(self, other): - return self.address == other.address + if isinstance(other, Host): + return self.endpoint == other.endpoint + else: # TODO Backward compatibility, remove next major + return self.endpoint.address == other def __hash__(self): - return hash(self.address) + return hash(self.endpoint) def __lt__(self, other): - return self.address < other.address + return self.endpoint < other.endpoint def __str__(self): - return str(self.address) + return str(self.endpoint) def __repr__(self): dc = (" %s" % (self._datacenter,)) if self._datacenter else "" - return "<%s: %s%s>" % (self.__class__.__name__, self.address, dc) + return "<%s: %s%s>" % (self.__class__.__name__, self.endpoint, dc) class _ReconnectionHandler(object): """ Abstract class for attempting reconnections with a given schedule and scheduler. """ _cancelled = False def __init__(self, scheduler, schedule, callback, *callback_args, **callback_kwargs): self.scheduler = scheduler self.schedule = schedule self.callback = callback self.callback_args = callback_args self.callback_kwargs = callback_kwargs def start(self): if self._cancelled: log.debug("Reconnection handler was cancelled before starting") return first_delay = next(self.schedule) self.scheduler.schedule(first_delay, self.run) def run(self): if self._cancelled: return conn = None try: conn = self.try_reconnect() except Exception as exc: try: next_delay = next(self.schedule) except StopIteration: # the schedule has been exhausted next_delay = None # call on_exception for logging purposes even if next_delay is None if self.on_exception(exc, next_delay): if next_delay is None: log.warning( "Will not continue to retry reconnection attempts " "due to an exhausted retry schedule") else: self.scheduler.schedule(next_delay, self.run) else: if not self._cancelled: self.on_reconnection(conn) self.callback(*(self.callback_args), **(self.callback_kwargs)) finally: if conn: conn.close() def cancel(self): self._cancelled = True def try_reconnect(self): """ Subclasses must implement this method. It should attempt to open a new Connection and return it; if a failure occurs, an Exception should be raised. """ raise NotImplementedError() def on_reconnection(self, connection): """ Called when a new Connection is successfully opened. Nothing is done by default. """ pass def on_exception(self, exc, next_delay): """ Called when an Exception is raised when trying to connect. `exc` is the Exception that was raised and `next_delay` is the number of seconds (as a float) that the handler will wait before attempting to connect again. Subclasses should return :const:`False` if no more attempts to connection should be made, :const:`True` otherwise. The default behavior is to always retry unless the error is an :exc:`.AuthenticationFailed` instance. 
""" if isinstance(exc, AuthenticationFailed): return False else: return True class _HostReconnectionHandler(_ReconnectionHandler): def __init__(self, host, connection_factory, is_host_addition, on_add, on_up, *args, **kwargs): _ReconnectionHandler.__init__(self, *args, **kwargs) self.is_host_addition = is_host_addition self.on_add = on_add self.on_up = on_up self.host = host self.connection_factory = connection_factory def try_reconnect(self): return self.connection_factory() def on_reconnection(self, connection): log.info("Successful reconnection to %s, marking node up if it isn't already", self.host) if self.is_host_addition: self.on_add(self.host) else: self.on_up(self.host) def on_exception(self, exc, next_delay): if isinstance(exc, AuthenticationFailed): return False else: log.warning("Error attempting to reconnect to %s, scheduling retry in %s seconds: %s", self.host, next_delay, exc) log.debug("Reconnection error details", exc_info=True) return True class HostConnection(object): """ When using v3 of the native protocol, this is used instead of a connection pool per host (HostConnectionPool) due to the increased in-flight capacity of individual connections. """ host = None host_distance = None is_shutdown = False shutdown_on_error = False _session = None _connection = None _lock = None _keyspace = None def __init__(self, host, host_distance, session): self.host = host self.host_distance = host_distance self._session = weakref.proxy(session) self._lock = Lock() # this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool. self._stream_available_condition = Condition(self._lock) self._is_replacing = False if host_distance == HostDistance.IGNORED: log.debug("Not opening connection to ignored host %s", self.host) return elif host_distance == HostDistance.REMOTE and not session.cluster.connect_to_remote_hosts: log.debug("Not opening connection to remote host %s", self.host) return log.debug("Initializing connection for host %s", self.host) - self._connection = session.cluster.connection_factory(host.address) + self._connection = session.cluster.connection_factory(host.endpoint) self._keyspace = session.keyspace if self._keyspace: self._connection.set_keyspace_blocking(self._keyspace) log.debug("Finished initializing connection for host %s", self.host) def borrow_connection(self, timeout): if self.is_shutdown: raise ConnectionException( "Pool for %s is shutdown" % (self.host,), self.host) conn = self._connection if not conn: raise NoConnectionsAvailable() start = time.time() remaining = timeout while True: with conn.lock: if conn.in_flight <= conn.max_request_id: conn.in_flight += 1 return conn, conn.get_request_id() if timeout is not None: remaining = timeout - time.time() + start if remaining < 0: break with self._stream_available_condition: self._stream_available_condition.wait(remaining) raise NoConnectionsAvailable("All request IDs are currently in use") def return_connection(self, connection): with connection.lock: connection.in_flight -= 1 with self._stream_available_condition: self._stream_available_condition.notify() if connection.is_defunct or connection.is_closed: if connection.signaled_error and not self.shutdown_on_error: return is_down = False if not connection.signaled_error: log.debug("Defunct or closed connection (%s) returned to pool, potentially " "marking host %s as down", id(connection), self.host) is_down = self._session.cluster.signal_connection_failure( self.host, 
connection.last_error, is_host_addition=False) connection.signaled_error = True if self.shutdown_on_error and not is_down: is_down = True self._session.cluster.on_down(self.host, is_host_addition=False) if is_down: self.shutdown() else: self._connection = None with self._lock: if self._is_replacing: return self._is_replacing = True self._session.submit(self._replace, connection) def _replace(self, connection): with self._lock: if self.is_shutdown: return log.debug("Replacing connection (%s) to %s", id(connection), self.host) try: - conn = self._session.cluster.connection_factory(self.host.address) + conn = self._session.cluster.connection_factory(self.host) if self._keyspace: conn.set_keyspace_blocking(self._keyspace) self._connection = conn except Exception: - log.warning("Failed reconnecting %s. Retrying." % (self.host.address,)) + log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,)) self._session.submit(self._replace, connection) else: with self._lock: self._is_replacing = False self._stream_available_condition.notify() def shutdown(self): with self._lock: if self.is_shutdown: return else: self.is_shutdown = True self._stream_available_condition.notify_all() if self._connection: self._connection.close() self._connection = None def _set_keyspace_for_all_conns(self, keyspace, callback): if self.is_shutdown or not self._connection: return def connection_finished_setting_keyspace(conn, error): self.return_connection(conn) errors = [] if not error else [error] callback(self, errors) self._keyspace = keyspace self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace) def get_connections(self): c = self._connection return [c] if c else [] def get_state(self): connection = self._connection open_count = 1 if connection and not (connection.is_closed or connection.is_defunct) else 0 in_flights = [connection.in_flight] if connection else [] return {'shutdown': self.is_shutdown, 'open_count': open_count, 'in_flights': in_flights} @property def open_count(self): connection = self._connection return 1 if connection and not (connection.is_closed or connection.is_defunct) else 0 _MAX_SIMULTANEOUS_CREATION = 1 _MIN_TRASH_INTERVAL = 10 class HostConnectionPool(object): """ Used to pool connections to a host for v1 and v2 native protocol. 
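As a hedged note, which pool implementation is used follows from the negotiated protocol version; forcing an older version (placeholder contact point) selects the multi-connection pool::

    from cassandra.cluster import Cluster

    # With protocol v1/v2 the driver builds a HostConnectionPool per host;
    # with v3+ it uses the single-connection HostConnection shown above.
    cluster = Cluster(['10.0.0.1'], protocol_version=2)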
""" host = None host_distance = None is_shutdown = False open_count = 0 _scheduled_for_creation = 0 _next_trash_allowed_at = 0 _keyspace = None def __init__(self, host, host_distance, session): self.host = host self.host_distance = host_distance self._session = weakref.proxy(session) self._lock = RLock() self._conn_available_condition = Condition() log.debug("Initializing new connection pool for host %s", self.host) core_conns = session.cluster.get_core_connections_per_host(host_distance) - self._connections = [session.cluster.connection_factory(host.address) + self._connections = [session.cluster.connection_factory(host.endpoint) for i in range(core_conns)] self._keyspace = session.keyspace if self._keyspace: for conn in self._connections: conn.set_keyspace_blocking(self._keyspace) self._trash = set() self._next_trash_allowed_at = time.time() self.open_count = core_conns log.debug("Finished initializing new connection pool for host %s", self.host) def borrow_connection(self, timeout): if self.is_shutdown: raise ConnectionException( "Pool for %s is shutdown" % (self.host,), self.host) conns = self._connections if not conns: # handled specially just for simpler code log.debug("Detected empty pool, opening core conns to %s", self.host) core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance) with self._lock: # we check the length of self._connections again # along with self._scheduled_for_creation while holding the lock # in case multiple threads hit this condition at the same time to_create = core_conns - (len(self._connections) + self._scheduled_for_creation) for i in range(to_create): self._scheduled_for_creation += 1 self._session.submit(self._create_new_connection) # in_flight is incremented by wait_for_conn conn = self._wait_for_conn(timeout) return conn else: # note: it would be nice to push changes to these config settings # to pools instead of doing a new lookup on every # borrow_connection() call max_reqs = self._session.cluster.get_max_requests_per_connection(self.host_distance) max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance) least_busy = min(conns, key=lambda c: c.in_flight) request_id = None # to avoid another thread closing this connection while # trashing it (through the return_connection process), hold # the connection lock from this point until we've incremented # its in_flight count need_to_wait = False with least_busy.lock: if least_busy.in_flight < least_busy.max_request_id: least_busy.in_flight += 1 request_id = least_busy.get_request_id() else: # once we release the lock, wait for another connection need_to_wait = True if need_to_wait: # wait_for_conn will increment in_flight on the conn least_busy, request_id = self._wait_for_conn(timeout) # if we have too many requests on this connection but we still # have space to open a new connection against this host, go ahead # and schedule the creation of a new connection if least_busy.in_flight >= max_reqs and len(self._connections) < max_conns: self._maybe_spawn_new_connection() return least_busy, request_id def _maybe_spawn_new_connection(self): with self._lock: if self._scheduled_for_creation >= _MAX_SIMULTANEOUS_CREATION: return if self.open_count >= self._session.cluster.get_max_connections_per_host(self.host_distance): return self._scheduled_for_creation += 1 log.debug("Submitting task for creation of new Connection to %s", self.host) self._session.submit(self._create_new_connection) def _create_new_connection(self): try: self._add_conn_if_under_max() except 
(ConnectionException, socket.error) as exc: log.warning("Failed to create new connection to %s: %s", self.host, exc) except Exception: log.exception("Unexpectedly failed to create new connection") finally: with self._lock: self._scheduled_for_creation -= 1 def _add_conn_if_under_max(self): max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance) with self._lock: if self.is_shutdown: return True if self.open_count >= max_conns: return True self.open_count += 1 log.debug("Going to open new connection to host %s", self.host) try: - conn = self._session.cluster.connection_factory(self.host.address) + conn = self._session.cluster.connection_factory(self.host.endpoint) if self._keyspace: conn.set_keyspace_blocking(self._session.keyspace) self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL with self._lock: new_connections = self._connections[:] + [conn] self._connections = new_connections log.debug("Added new connection (%s) to pool for host %s, signaling availablility", id(conn), self.host) self._signal_available_conn() return True except (ConnectionException, socket.error) as exc: log.warning("Failed to add new connection to pool for host %s: %s", self.host, exc) with self._lock: self.open_count -= 1 if self._session.cluster.signal_connection_failure(self.host, exc, is_host_addition=False): self.shutdown() return False except AuthenticationFailed: with self._lock: self.open_count -= 1 return False def _await_available_conn(self, timeout): with self._conn_available_condition: self._conn_available_condition.wait(timeout) def _signal_available_conn(self): with self._conn_available_condition: self._conn_available_condition.notify() def _signal_all_available_conn(self): with self._conn_available_condition: self._conn_available_condition.notify_all() def _wait_for_conn(self, timeout): start = time.time() remaining = timeout while remaining > 0: # wait on our condition for the possibility that a connection # is useable self._await_available_conn(remaining) # self.shutdown() may trigger the above Condition if self.is_shutdown: raise ConnectionException("Pool is shutdown") conns = self._connections if conns: least_busy = min(conns, key=lambda c: c.in_flight) with least_busy.lock: if least_busy.in_flight < least_busy.max_request_id: least_busy.in_flight += 1 return least_busy, least_busy.get_request_id() remaining = timeout - (time.time() - start) raise NoConnectionsAvailable() def return_connection(self, connection): with connection.lock: connection.in_flight -= 1 in_flight = connection.in_flight if connection.is_defunct or connection.is_closed: if not connection.signaled_error: log.debug("Defunct or closed connection (%s) returned to pool, potentially " "marking host %s as down", id(connection), self.host) is_down = self._session.cluster.signal_connection_failure( self.host, connection.last_error, is_host_addition=False) connection.signaled_error = True if is_down: self.shutdown() else: self._replace(connection) else: if connection in self._trash: with connection.lock: if connection.in_flight == 0: with self._lock: if connection in self._trash: self._trash.remove(connection) log.debug("Closing trashed connection (%s) to %s", id(connection), self.host) connection.close() return core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance) min_reqs = self._session.cluster.get_min_requests_per_connection(self.host_distance) # we can use in_flight here without holding the connection lock # because the fact that in_flight dipped below the min at 
some # point is enough to start the trashing procedure if len(self._connections) > core_conns and in_flight <= min_reqs and \ time.time() >= self._next_trash_allowed_at: self._maybe_trash_connection(connection) else: self._signal_available_conn() def _maybe_trash_connection(self, connection): core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance) did_trash = False with self._lock: if connection not in self._connections: return if self.open_count > core_conns: did_trash = True self.open_count -= 1 new_connections = self._connections[:] new_connections.remove(connection) self._connections = new_connections with connection.lock: if connection.in_flight == 0: log.debug("Skipping trash and closing unused connection (%s) to %s", id(connection), self.host) connection.close() # skip adding it to the trash if we're already closing it return self._trash.add(connection) if did_trash: self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL log.debug("Trashed connection (%s) to %s", id(connection), self.host) def _replace(self, connection): should_replace = False with self._lock: if connection in self._connections: new_connections = self._connections[:] new_connections.remove(connection) self._connections = new_connections self.open_count -= 1 should_replace = True if should_replace: log.debug("Replacing connection (%s) to %s", id(connection), self.host) connection.close() self._session.submit(self._retrying_replace) else: log.debug("Closing connection (%s) to %s", id(connection), self.host) connection.close() def _retrying_replace(self): replaced = False try: replaced = self._add_conn_if_under_max() except Exception: log.exception("Failed replacing connection to %s", self.host) if not replaced: log.debug("Failed replacing connection to %s. Retrying.", self.host) self._session.submit(self._retrying_replace) def shutdown(self): with self._lock: if self.is_shutdown: return else: self.is_shutdown = True self._signal_all_available_conn() for conn in self._connections: conn.close() self.open_count -= 1 for conn in self._trash: conn.close() def ensure_core_connections(self): if self.is_shutdown: return core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance) with self._lock: to_create = core_conns - (len(self._connections) + self._scheduled_for_creation) for i in range(to_create): self._scheduled_for_creation += 1 self._session.submit(self._create_new_connection) def _set_keyspace_for_all_conns(self, keyspace, callback): """ Asynchronously sets the keyspace for all connections. When all connections have been set, `callback` will be called with two arguments: this pool, and a list of any errors that occurred. """ remaining_callbacks = set(self._connections) errors = [] if not remaining_callbacks: callback(self, errors) return def connection_finished_setting_keyspace(conn, error): self.return_connection(conn) remaining_callbacks.remove(conn) if error: errors.append(error) if not remaining_callbacks: callback(self, errors) self._keyspace = keyspace for conn in self._connections: conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace) def get_connections(self): return self._connections def get_state(self): in_flights = [c.in_flight for c in self._connections] return {'shutdown': self.is_shutdown, 'open_count': self.open_count, 'in_flights': in_flights} diff --git a/cassandra/query.py b/cassandra/query.py index b2193d6..74a9896 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -1,1089 +1,1089 @@ # Copyright DataStax, Inc. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ This module holds classes for working with prepared statements and specifying consistency levels and retry policies for individual queries. """ from collections import namedtuple from datetime import datetime, timedelta import re import struct import time import six from six.moves import range, zip import warnings from cassandra import ConsistencyLevel, OperationTimedOut from cassandra.util import unix_time_from_uuid1 from cassandra.encoder import Encoder import cassandra.encoder from cassandra.protocol import _UNSET_VALUE from cassandra.util import OrderedDict, _sanitize_identifiers import logging log = logging.getLogger(__name__) UNSET_VALUE = _UNSET_VALUE """ Specifies an unset value when binding a prepared statement. Unset values are ignored, allowing prepared statements to be used without specify See https://issues.apache.org/jira/browse/CASSANDRA-7304 for further details on semantics. .. versionadded:: 2.6.0 Only valid when using native protocol v4+ """ NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]') START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*') END_BADCHAR_REGEX = re.compile('[^a-zA-Z0-9_]*$') _clean_name_cache = {} def _clean_column_name(name): try: return _clean_name_cache[name] except KeyError: clean = NON_ALPHA_REGEX.sub("_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name))) _clean_name_cache[name] = clean return clean def tuple_factory(colnames, rows): """ Returns each row as a tuple Example:: >>> from cassandra.query import tuple_factory >>> session = cluster.connect('mykeyspace') >>> session.row_factory = tuple_factory >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") >>> print rows[0] ('Bob', 42) .. versionchanged:: 2.0.0 moved from ``cassandra.decoder`` to ``cassandra.query`` """ return rows class PseudoNamedTupleRow(object): """ Helper class for pseudo_named_tuple_factory. These objects provide an __iter__ interface, as well as index- and attribute-based access to values, but otherwise do not attempt to implement the full namedtuple or iterable interface. """ def __init__(self, ordered_dict): self._dict = ordered_dict self._tuple = tuple(ordered_dict.values()) def __getattr__(self, name): return self._dict[name] def __getitem__(self, idx): return self._tuple[idx] def __iter__(self): return iter(self._tuple) def __repr__(self): return '{t}({od})'.format(t=self.__class__.__name__, od=self._dict) def pseudo_namedtuple_factory(colnames, rows): """ Returns each row as a :class:`.PseudoNamedTupleRow`. This is the fallback factory for cases where :meth:`.named_tuple_factory` fails to create rows. """ return [PseudoNamedTupleRow(od) for od in ordered_dict_factory(colnames, rows)] def named_tuple_factory(colnames, rows): """ Returns each row as a `namedtuple `_. This is the default row factory. 
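A small local sketch of the fallback row type defined above; ``PseudoNamedTupleRow`` supports attribute, index, and iteration access without the namedtuple field-count limits::

    from collections import OrderedDict
    from cassandra.query import PseudoNamedTupleRow

    row = PseudoNamedTupleRow(OrderedDict([('name', 'Bob'), ('age', 42)]))
    print(row.name)      # 'Bob' -- attribute access
    print(row[1])        # 42    -- positional access
    print(tuple(row))    # ('Bob', 42)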
Example:: >>> from cassandra.query import named_tuple_factory >>> session = cluster.connect('mykeyspace') >>> session.row_factory = named_tuple_factory >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") >>> user = rows[0] >>> # you can access field by their name: >>> print "name: %s, age: %d" % (user.name, user.age) name: Bob, age: 42 >>> # or you can access fields by their position (like a tuple) >>> name, age = user >>> print "name: %s, age: %d" % (name, age) name: Bob, age: 42 >>> name = user[0] >>> age = user[1] >>> print "name: %s, age: %d" % (name, age) name: Bob, age: 42 .. versionchanged:: 2.0.0 moved from ``cassandra.decoder`` to ``cassandra.query`` """ clean_column_names = map(_clean_column_name, colnames) try: Row = namedtuple('Row', clean_column_names) except SyntaxError: warnings.warn( "Failed creating namedtuple for a result because there were too " "many columns. This is due to a Python limitation that affects " "namedtuple in Python 3.0-3.6 (see issue18896). The row will be " "created with {substitute_factory_name}, which lacks some namedtuple " "features and is slower. To avoid slower performance accessing " "values on row objects, Upgrade to Python 3.7, or use a different " "row factory. (column names: {colnames})".format( substitute_factory_name=pseudo_namedtuple_factory.__name__, colnames=colnames ) ) return pseudo_namedtuple_factory(colnames, rows) except Exception: clean_column_names = list(map(_clean_column_name, colnames)) # create list because py3 map object will be consumed by first attempt log.warning("Failed creating named tuple for results with column names %s (cleaned: %s) " "(see Python 'namedtuple' documentation for details on name rules). " "Results will be returned with positional names. " "Avoid this by choosing different names, using SELECT \"\" AS aliases, " "or specifying a different row_factory on your Session" % (colnames, clean_column_names)) Row = namedtuple('Row', _sanitize_identifiers(clean_column_names)) return [Row(*row) for row in rows] def dict_factory(colnames, rows): """ Returns each row as a dict. Example:: >>> from cassandra.query import dict_factory >>> session = cluster.connect('mykeyspace') >>> session.row_factory = dict_factory >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") >>> print rows[0] {u'age': 42, u'name': u'Bob'} .. versionchanged:: 2.0.0 moved from ``cassandra.decoder`` to ``cassandra.query`` """ return [dict(zip(colnames, row)) for row in rows] def ordered_dict_factory(colnames, rows): """ Like :meth:`~cassandra.query.dict_factory`, but returns each row as an OrderedDict, so the order of the columns is preserved. .. versionchanged:: 2.0.0 moved from ``cassandra.decoder`` to ``cassandra.query`` """ return [OrderedDict(zip(colnames, row)) for row in rows] FETCH_SIZE_UNSET = object() class Statement(object): """ An abstract class representing a single query. There are three subclasses: :class:`.SimpleStatement`, :class:`.BoundStatement`, and :class:`.BatchStatement`. These can be passed to :meth:`.Session.execute()`. """ retry_policy = None """ An instance of a :class:`cassandra.policies.RetryPolicy` or one of its subclasses. This controls when a query will be retried and how it will be retried. """ consistency_level = None """ The :class:`.ConsistencyLevel` to be used for this operation. Defaults to :const:`None`, which means that the default consistency level for the Session this is executed in will be used. """ fetch_size = FETCH_SIZE_UNSET """ How many rows will be fetched at a time. 
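The row factories above operate on plain ``(colnames, rows)`` pairs, so their behavior can be checked locally without a cluster::

    from cassandra.query import tuple_factory, dict_factory, ordered_dict_factory

    colnames = ['name', 'age']
    rows = [('Bob', 42), ('Alice', 33)]

    assert tuple_factory(colnames, rows) == [('Bob', 42), ('Alice', 33)]
    assert dict_factory(colnames, rows)[0] == {'name': 'Bob', 'age': 42}
    assert list(ordered_dict_factory(colnames, rows)[0].items()) == [('name', 'Bob'), ('age', 42)]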
This overrides the default of :attr:`.Session.default_fetch_size` This only takes effect when protocol version 2 or higher is used. See :attr:`.Cluster.protocol_version` for details. .. versionadded:: 2.0.0 """ keyspace = None """ The string name of the keyspace this query acts on. This is used when :class:`~.TokenAwarePolicy` is configured for :attr:`.Cluster.load_balancing_policy` It is set implicitly on :class:`.BoundStatement`, and :class:`.BatchStatement`, but must be set explicitly on :class:`.SimpleStatement`. .. versionadded:: 2.1.3 """ custom_payload = None """ :ref:`custom_payload` to be passed to the server. These are only allowed when using protocol version 4 or higher. .. versionadded:: 2.6.0 """ is_idempotent = False """ Flag indicating whether this statement is safe to run multiple times in speculative execution. """ _serial_consistency_level = None _routing_key = None def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None, is_idempotent=False): if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy') if retry_policy is not None: self.retry_policy = retry_policy if consistency_level is not None: self.consistency_level = consistency_level self._routing_key = routing_key if serial_consistency_level is not None: self.serial_consistency_level = serial_consistency_level if fetch_size is not FETCH_SIZE_UNSET: self.fetch_size = fetch_size if keyspace is not None: self.keyspace = keyspace if custom_payload is not None: self.custom_payload = custom_payload self.is_idempotent = is_idempotent def _key_parts_packed(self, parts): for p in parts: l = len(p) yield struct.pack(">H%dsB" % l, l, p, 0) def _get_routing_key(self): return self._routing_key def _set_routing_key(self, key): if isinstance(key, (list, tuple)): if len(key) == 1: self._routing_key = key[0] else: self._routing_key = b"".join(self._key_parts_packed(key)) else: self._routing_key = key def _del_routing_key(self): self._routing_key = None routing_key = property( _get_routing_key, _set_routing_key, _del_routing_key, """ The :attr:`~.TableMetadata.partition_key` portion of the primary key, which can be used to determine which nodes are replicas for the query. If the partition key is a composite, a list or tuple must be passed in. Each key component should be in its packed (binary) format, so all components should be strings. """) def _get_serial_consistency_level(self): return self._serial_consistency_level def _set_serial_consistency_level(self, serial_consistency_level): - acceptable = (None, ConsistencyLevel.SERIAL, ConsistencyLevel.LOCAL_SERIAL) - if serial_consistency_level not in acceptable: + if (serial_consistency_level is not None and + not ConsistencyLevel.is_serial(serial_consistency_level)): raise ValueError( "serial_consistency_level must be either ConsistencyLevel.SERIAL " "or ConsistencyLevel.LOCAL_SERIAL") self._serial_consistency_level = serial_consistency_level def _del_serial_consistency_level(self): self._serial_consistency_level = None serial_consistency_level = property( _get_serial_consistency_level, _set_serial_consistency_level, _del_serial_consistency_level, """ The serial consistency level is only used by conditional updates (``INSERT``, ``UPDATE`` and ``DELETE`` with an ``IF`` condition). 
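As described for the ``routing_key`` property above, composite partition keys are passed as a sequence of pre-serialized components. A brief sketch; the table name and packed values are hypothetical and shown purely for illustration::

    from cassandra import ConsistencyLevel
    from cassandra.query import SimpleStatement

    stmt = SimpleStatement(
        "INSERT INTO users_by_name_and_bucket (name, bucket, age) VALUES (%s, %s, %s)",
        keyspace='mykeyspace',
        consistency_level=ConsistencyLevel.QUORUM,
        is_idempotent=True)
    # composite partition key: pass each component in its packed (binary) form
    stmt.routing_key = [b'Bob', b'\x00\x00\x00\x2a']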
For those, the ``serial_consistency_level`` defines the consistency level of the serial phase (or "paxos" phase) while the normal :attr:`~.consistency_level` defines the consistency for the "learn" phase, i.e. what type of reads will be guaranteed to see the update right away. For example, if a conditional write has a :attr:`~.consistency_level` of :attr:`~.ConsistencyLevel.QUORUM` (and is successful), then a :attr:`~.ConsistencyLevel.QUORUM` read is guaranteed to see that write. But if the regular :attr:`~.consistency_level` of that write is :attr:`~.ConsistencyLevel.ANY`, then only a read with a :attr:`~.consistency_level` of :attr:`~.ConsistencyLevel.SERIAL` is guaranteed to see it (even a read with consistency :attr:`~.ConsistencyLevel.ALL` is not guaranteed to be enough). The serial consistency can only be one of :attr:`~.ConsistencyLevel.SERIAL` or :attr:`~.ConsistencyLevel.LOCAL_SERIAL`. While ``SERIAL`` guarantees full linearizability (with other ``SERIAL`` updates), ``LOCAL_SERIAL`` only guarantees it in the local data center. The serial consistency level is ignored for any query that is not a conditional update. Serial reads should use the regular :attr:`consistency_level`. Serial consistency levels may only be used against Cassandra 2.0+ and the :attr:`~.Cluster.protocol_version` must be set to 2 or higher. See :doc:`/lwt` for a discussion on how to work with results returned from conditional statements. .. versionadded:: 2.0.0 """) class SimpleStatement(Statement): """ A simple, un-prepared query. """ def __init__(self, query_string, retry_policy=None, consistency_level=None, routing_key=None, serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None, is_idempotent=False): """ `query_string` should be a literal CQL statement with the exception of parameter placeholders that will be filled through the `parameters` argument of :meth:`.Session.execute()`. See :class:`Statement` attributes for a description of the other parameters. """ Statement.__init__(self, retry_policy, consistency_level, routing_key, serial_consistency_level, fetch_size, keyspace, custom_payload, is_idempotent) self._query_string = query_string @property def query_string(self): return self._query_string def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'' % (self.query_string, consistency)) __repr__ = __str__ class PreparedStatement(object): """ A statement that has been prepared against at least one Cassandra node. Instances of this class should not be created directly, but through :meth:`.Session.prepare()`. A :class:`.PreparedStatement` should be prepared only once. Re-preparing a statement may affect performance (as the operation requires a network roundtrip). |prepared_stmt_head|: Do not use ``*`` in prepared statements if you might change the schema of the table being queried. The driver and server each maintain a map between metadata for a schema and statements that were prepared against that schema. When a user changes a schema, e.g. by adding or removing a column, the server invalidates its mappings involving that schema. However, there is currently no way to propagate that invalidation to drivers. Thus, after a schema change, the driver will incorrectly interpret the results of ``SELECT *`` queries prepared before the schema change. This is currently being addressed in `CASSANDRA-10786 `_. .. 
|prepared_stmt_head| raw:: html A note about * in prepared statements """ column_metadata = None #TODO: make this bind_metadata in next major retry_policy = None consistency_level = None custom_payload = None fetch_size = FETCH_SIZE_UNSET keyspace = None # change to prepared_keyspace in major release protocol_version = None query_id = None query_string = None result_metadata = None result_metadata_id = None routing_key_indexes = None _routing_key_index_set = None - serial_consistency_level = None + serial_consistency_level = None # TODO never used? def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace, protocol_version, result_metadata, result_metadata_id): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes self.query_string = query self.keyspace = keyspace self.protocol_version = protocol_version self.result_metadata = result_metadata self.result_metadata_id = result_metadata_id self.is_idempotent = False @classmethod def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, query, prepared_keyspace, protocol_version, result_metadata, result_metadata_id): if not column_metadata: return PreparedStatement(column_metadata, query_id, None, query, prepared_keyspace, protocol_version, result_metadata, result_metadata_id) if pk_indexes: routing_key_indexes = pk_indexes else: routing_key_indexes = None first_col = column_metadata[0] ks_meta = cluster_metadata.keyspaces.get(first_col.keyspace_name) if ks_meta: table_meta = ks_meta.tables.get(first_col.table_name) if table_meta: partition_key_columns = table_meta.partition_key # make a map of {column_name: index} for each column in the statement statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata)) # a list of which indexes in the statement correspond to partition key items try: routing_key_indexes = [statement_indexes[c.name] for c in partition_key_columns] except KeyError: # we're missing a partition key component in the prepared pass # statement; just leave routing_key_indexes as None return PreparedStatement(column_metadata, query_id, routing_key_indexes, query, prepared_keyspace, protocol_version, result_metadata, result_metadata_id) def bind(self, values): """ Creates and returns a :class:`BoundStatement` instance using `values`. See :meth:`BoundStatement.bind` for rules on input ``values``. """ return BoundStatement(self).bind(values) def is_routing_key_index(self, i): if self._routing_key_index_set is None: self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set() return i in self._routing_key_index_set def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'' % (self.query_string, consistency)) __repr__ = __str__ class BoundStatement(Statement): """ A prepared statement that has been bound to a particular set of values. These may be created directly or through :meth:`.PreparedStatement.bind()`. """ prepared_statement = None """ The :class:`PreparedStatement` instance that this was created from. """ values = None """ The sequence of values that were bound to the prepared statement. """ def __init__(self, prepared_statement, retry_policy=None, consistency_level=None, routing_key=None, serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None): """ `prepared_statement` should be an instance of :class:`PreparedStatement`. 
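A hedged usage sketch of the prepare/bind flow described above (assumes a connected ``session`` and a ``users`` table); with protocol v4+, :data:`UNSET_VALUE` can be bound to leave a column untouched::

    from cassandra.query import UNSET_VALUE

    prepared = session.prepare("INSERT INTO users (name, age) VALUES (?, ?)")

    bound = prepared.bind(('Bob', 42))      # explicit BoundStatement
    session.execute(bound)

    session.execute(prepared, ('Alice', 33))            # bind at execution time
    session.execute(prepared, ('Carol', UNSET_VALUE))   # protocol v4+: age left unset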
See :class:`Statement` attributes for a description of the other parameters. """ self.prepared_statement = prepared_statement self.retry_policy = prepared_statement.retry_policy self.consistency_level = prepared_statement.consistency_level self.serial_consistency_level = prepared_statement.serial_consistency_level self.fetch_size = prepared_statement.fetch_size self.custom_payload = prepared_statement.custom_payload self.is_idempotent = prepared_statement.is_idempotent self.values = [] meta = prepared_statement.column_metadata if meta: self.keyspace = meta[0].keyspace_name Statement.__init__(self, retry_policy, consistency_level, routing_key, serial_consistency_level, fetch_size, keyspace, custom_payload, prepared_statement.is_idempotent) def bind(self, values): """ Binds a sequence of values for the prepared statement parameters and returns this instance. Note that `values` *must* be: * a sequence, even if you are only binding one value, or * a dict that relates 1-to-1 between dict keys and columns .. versionchanged:: 2.6.0 :data:`~.UNSET_VALUE` was introduced. These can be bound as positional parameters in a sequence, or by name in a dict. Additionally, when using protocol v4+: * short sequences will be extended to match bind parameters with UNSET_VALUE * names may be omitted from a dict with UNSET_VALUE implied. .. versionchanged:: 3.0.0 method will not throw if extra keys are present in bound dict (PYTHON-178) """ if values is None: values = () proto_version = self.prepared_statement.protocol_version col_meta = self.prepared_statement.column_metadata # special case for binding dicts if isinstance(values, dict): values_dict = values values = [] # sort values accordingly for col in col_meta: try: values.append(values_dict[col.name]) except KeyError: if proto_version >= 4: values.append(UNSET_VALUE) else: raise KeyError( 'Column name `%s` not found in bound dict.' % (col.name)) value_len = len(values) col_meta_len = len(col_meta) if value_len > col_meta_len: raise ValueError( "Too many arguments provided to bind() (got %d, expected %d)" % (len(values), len(col_meta))) # this is fail-fast for clarity pre-v4. When v4 can be assumed, # the error will be better reported when UNSET_VALUE is implicitly added. if proto_version < 4 and self.prepared_statement.routing_key_indexes and \ value_len < len(self.prepared_statement.routing_key_indexes): raise ValueError( "Too few arguments provided to bind() (got %d, required %d for routing key)" % (value_len, len(self.prepared_statement.routing_key_indexes))) self.raw_values = values self.values = [] for value, col_spec in zip(values, col_meta): if value is None: self.values.append(None) elif value is UNSET_VALUE: if proto_version >= 4: self._append_unset_value() else: raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) else: try: self.values.append(col_spec.type.serialize(value, proto_version)) except (TypeError, struct.error) as exc: actual_type = type(value) message = ('Received an argument of invalid type for column "%s". 
' 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) raise TypeError(message) if proto_version >= 4: diff = col_meta_len - len(self.values) if diff: for _ in range(diff): self._append_unset_value() return self def _append_unset_value(self): next_index = len(self.values) if self.prepared_statement.is_routing_key_index(next_index): col_meta = self.prepared_statement.column_metadata[next_index] raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name) self.values.append(UNSET_VALUE) @property def routing_key(self): if not self.prepared_statement.routing_key_indexes: return None if self._routing_key is not None: return self._routing_key routing_indexes = self.prepared_statement.routing_key_indexes if len(routing_indexes) == 1: self._routing_key = self.values[routing_indexes[0]] else: self._routing_key = b"".join(self._key_parts_packed(self.values[i] for i in routing_indexes)) return self._routing_key def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'' % (self.prepared_statement.query_string, self.raw_values, consistency)) __repr__ = __str__ class BatchType(object): """ A BatchType is used with :class:`.BatchStatement` instances to control the atomicity of the batch operation. .. versionadded:: 2.0.0 """ LOGGED = None """ Atomic batch operation. """ UNLOGGED = None """ Non-atomic batch operation. """ COUNTER = None """ Batches of counter operations. """ def __init__(self, name, value): self.name = name self.value = value def __str__(self): return self.name def __repr__(self): return "BatchType.%s" % (self.name, ) BatchType.LOGGED = BatchType("LOGGED", 0) BatchType.UNLOGGED = BatchType("UNLOGGED", 1) BatchType.COUNTER = BatchType("COUNTER", 2) class BatchStatement(Statement): """ A protocol-level batch of operations which are applied atomically by default. .. versionadded:: 2.0.0 """ batch_type = None """ The :class:`.BatchType` for the batch operation. Defaults to :attr:`.BatchType.LOGGED`. """ serial_consistency_level = None """ The same as :attr:`.Statement.serial_consistency_level`, but is only supported when using protocol version 3 or higher. """ _statements_and_parameters = None _session = None def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, consistency_level=None, serial_consistency_level=None, session=None, custom_payload=None): """ `batch_type` specifies The :class:`.BatchType` for the batch operation. Defaults to :attr:`.BatchType.LOGGED`. `retry_policy` should be a :class:`~.RetryPolicy` instance for controlling retries on the operation. `consistency_level` should be a :class:`~.ConsistencyLevel` value to be used for all operations in the batch. `custom_payload` is a :ref:`custom_payload` passed to the server. Note: as Statement objects are added to the batch, this map is updated with any values found in their custom payloads. These are only allowed when using protocol version 4 or higher. Example usage: .. code-block:: python insert_user = session.prepare("INSERT INTO users (name, age) VALUES (?, ?)") batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM) for (name, age) in users_to_insert: batch.add(insert_user, (name, age)) session.execute(batch) You can also mix different types of operations within a batch: .. 
code-block:: python batch = BatchStatement() batch.add(SimpleStatement("INSERT INTO users (name, age) VALUES (%s, %s)"), (name, age)) batch.add(SimpleStatement("DELETE FROM pending_users WHERE name=%s"), (name,)) session.execute(batch) .. versionadded:: 2.0.0 .. versionchanged:: 2.1.0 Added `serial_consistency_level` as a parameter .. versionchanged:: 2.6.0 Added `custom_payload` as a parameter """ self.batch_type = batch_type self._statements_and_parameters = [] self._session = session Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level, serial_consistency_level=serial_consistency_level, custom_payload=custom_payload) def clear(self): """ This is a convenience method to clear a batch statement for reuse. *Note:* it should not be used concurrently with uncompleted execution futures executing the same ``BatchStatement``. """ del self._statements_and_parameters[:] self.keyspace = None self.routing_key = None if self.custom_payload: self.custom_payload.clear() def add(self, statement, parameters=None): """ Adds a :class:`.Statement` and optional sequence of parameters to be used with the statement to the batch. Like with other statements, parameters must be a sequence, even if there is only one item. """ if isinstance(statement, six.string_types): if parameters: encoder = Encoder() if self._session is None else self._session.encoder statement = bind_params(statement, parameters, encoder) self._add_statement_and_params(False, statement, ()) elif isinstance(statement, PreparedStatement): query_id = statement.query_id bound_statement = statement.bind(() if parameters is None else parameters) self._update_state(bound_statement) self._add_statement_and_params(True, query_id, bound_statement.values) elif isinstance(statement, BoundStatement): if parameters: raise ValueError( "Parameters cannot be passed with a BoundStatement " "to BatchStatement.add()") self._update_state(statement) self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values) else: # it must be a SimpleStatement query_string = statement.query_string if parameters: encoder = Encoder() if self._session is None else self._session.encoder query_string = bind_params(query_string, parameters, encoder) self._update_state(statement) self._add_statement_and_params(False, query_string, ()) return self def add_all(self, statements, parameters): """ Adds a sequence of :class:`.Statement` objects and a matching sequence of parameters to the batch. Statement and parameter sequences must be of equal length or one will be truncated. :const:`None` can be used in the parameters position where are needed. """ for statement, value in zip(statements, parameters): self.add(statement, value) def _add_statement_and_params(self, is_prepared, statement, parameters): if len(self._statements_and_parameters) >= 0xFFFF: raise ValueError("Batch statement cannot contain more than %d statements." 
% 0xFFFF) self._statements_and_parameters.append((is_prepared, statement, parameters)) def _maybe_set_routing_attributes(self, statement): if self.routing_key is None: if statement.keyspace and statement.routing_key: self.routing_key = statement.routing_key self.keyspace = statement.keyspace def _update_custom_payload(self, statement): if statement.custom_payload: if self.custom_payload is None: self.custom_payload = {} self.custom_payload.update(statement.custom_payload) def _update_state(self, statement): self._maybe_set_routing_attributes(statement) self._update_custom_payload(statement) def __len__(self): return len(self._statements_and_parameters) def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'' % (self.batch_type, len(self), consistency)) __repr__ = __str__ ValueSequence = cassandra.encoder.ValueSequence """ A wrapper class that is used to specify that a sequence of values should be treated as a CQL list of values instead of a single column collection when used as part of the `parameters` argument for :meth:`.Session.execute()`. This is typically needed when supplying a list of keys to select. For example:: >>> my_user_ids = ('alice', 'bob', 'charles') >>> query = "SELECT * FROM users WHERE user_id IN %s" >>> session.execute(query, parameters=[ValueSequence(my_user_ids)]) """ def bind_params(query, params, encoder): if six.PY2 and isinstance(query, six.text_type): query = query.encode('utf-8') if isinstance(params, dict): return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in six.iteritems(params)) else: return query % tuple(encoder.cql_encode_all_types(v) for v in params) class TraceUnavailable(Exception): """ Raised when complete trace details cannot be fetched from Cassandra. """ pass class QueryTrace(object): """ A trace of the duration and events that occurred when executing an operation. """ trace_id = None """ :class:`uuid.UUID` unique identifier for this tracing session. Matches the ``session_id`` column in ``system_traces.sessions`` and ``system_traces.events``. """ request_type = None """ A string that very generally describes the traced operation. """ duration = None """ A :class:`datetime.timedelta` measure of the duration of the query. """ client = None """ The IP address of the client that issued this request This is only available when using Cassandra 2.2+ """ coordinator = None """ The IP address of the host that acted as coordinator for this request. """ parameters = None """ A :class:`dict` of parameters for the traced operation, such as the specific query string. """ started_at = None """ A UTC :class:`datetime.datetime` object describing when the operation was started. """ events = None """ A chronologically sorted list of :class:`.TraceEvent` instances representing the steps the traced operation went through. This corresponds to the rows in ``system_traces.events`` for this tracing session. """ _session = None _SELECT_SESSIONS_FORMAT = "SELECT * FROM system_traces.sessions WHERE session_id = %s" _SELECT_EVENTS_FORMAT = "SELECT * FROM system_traces.events WHERE session_id = %s" _BASE_RETRY_SLEEP = 0.003 def __init__(self, trace_id, session): self.trace_id = trace_id self._session = session def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): """ Retrieves the actual tracing details from Cassandra and populates the attributes of this instance. Because tracing details are stored asynchronously by Cassandra, this may need to retry the session detail fetch. 
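A hedged sketch of requesting and reading a trace with the classes above (assumes a connected ``session``)::

    result = session.execute("SELECT * FROM system.local", trace=True)
    trace = result.get_query_trace()        # a QueryTrace, fetched from system_traces
    print(trace.coordinator, trace.duration, trace.parameters)
    for event in trace.events:
        print(event.source_elapsed, event.description)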
If the trace is still not available after `max_wait` seconds, :exc:`.TraceUnavailable` will be raised; if `max_wait` is :const:`None`, this will retry forever. `wait_for_complete=False` bypasses the wait for duration to be populated. This can be used to query events from partial sessions. `query_cl` specifies a consistency level to use for polling the trace tables, if it should be different than the session default. """ attempt = 0 start = time.time() while True: time_spent = time.time() - start if max_wait is not None and time_spent >= max_wait: raise TraceUnavailable( "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." % (max_wait,)) log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id) session_results = self._execute( SimpleStatement(self._SELECT_SESSIONS_FORMAT, consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) # PYTHON-730: There is race condition that the duration mutation is written before started_at the for fast queries is_complete = session_results and session_results[0].duration is not None and session_results[0].started_at is not None if not session_results or (wait_for_complete and not is_complete): time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt)) attempt += 1 continue if is_complete: log.debug("Fetched trace info for trace ID: %s", self.trace_id) else: log.debug("Fetching parital trace info for trace ID: %s", self.trace_id) session_row = session_results[0] self.request_type = session_row.request self.duration = timedelta(microseconds=session_row.duration) if is_complete else None self.started_at = session_row.started_at self.coordinator = session_row.coordinator self.parameters = session_row.parameters # since C* 2.2 self.client = getattr(session_row, 'client', None) log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id) time_spent = time.time() - start event_results = self._execute( SimpleStatement(self._SELECT_EVENTS_FORMAT, consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) log.debug("Fetched trace events for trace ID: %s", self.trace_id) self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) for r in event_results) break def _execute(self, query, parameters, time_spent, max_wait): timeout = (max_wait - time_spent) if max_wait is not None else None future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout) # in case the user switched the row factory, set it to namedtuple for this query future.row_factory = named_tuple_factory future.send_request() try: return future.result() except OperationTimedOut: raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) def __str__(self): return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \ % (self.request_type, self.trace_id, self.coordinator, self.started_at, self.duration, self.parameters) class TraceEvent(object): """ Representation of a single event within a query trace. """ description = None """ A brief description of the event. """ datetime = None """ A UTC :class:`datetime.datetime` marking when the event occurred. """ source = None """ The IP address of the node this event occurred on. """ source_elapsed = None """ A :class:`datetime.timedelta` measuring the amount of time until this event occurred starting from when :attr:`.source` first received the query. """ thread_name = None """ The name of the thread that this event occurred on. 
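As noted above, ``wait_for_complete=False`` lets you inspect events from a still-running trace; a small sketch, assuming ``trace`` is a :class:`QueryTrace` obtained as in the previous example::

    trace.populate(max_wait=5.0, wait_for_complete=False)   # don't block on `duration`
    for event in trace.events or ():
        print(event.datetime, event.thread_name, event.description)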
""" def __init__(self, description, timeuuid, source, source_elapsed, thread_name): self.description = description self.datetime = datetime.utcfromtimestamp(unix_time_from_uuid1(timeuuid)) self.source = source if source_elapsed is not None: self.source_elapsed = timedelta(microseconds=source_elapsed) else: self.source_elapsed = None self.thread_name = thread_name def __str__(self): return "%s on %s[%s] at %s" % (self.description, self.source, self.thread_name, self.datetime) diff --git a/cassandra/util.py b/cassandra/util.py index 28d2c72..efb3a95 100644 --- a/cassandra/util.py +++ b/cassandra/util.py @@ -1,1234 +1,1328 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import with_statement import calendar import datetime from functools import total_ordering import random import six import uuid import sys DATETIME_EPOC = datetime.datetime(1970, 1, 1) assert sys.byteorder in ('little', 'big') is_little_endian = sys.byteorder == 'little' def datetime_from_timestamp(timestamp): """ Creates a timezone-agnostic datetime from timestamp (in seconds) in a consistent manner. Works around a Windows issue with large negative timestamps (PYTHON-119), and rounding differences in Python 3.4 (PYTHON-340). :param timestamp: a unix timestamp, in seconds """ dt = DATETIME_EPOC + datetime.timedelta(seconds=timestamp) return dt def unix_time_from_uuid1(uuid_arg): """ Converts a version 1 :class:`uuid.UUID` to a timestamp with the same precision as :meth:`time.time()` returns. This is useful for examining the results of queries returning a v1 :class:`~uuid.UUID`. :param uuid_arg: a version 1 :class:`~uuid.UUID` """ return (uuid_arg.time - 0x01B21DD213814000) / 1e7 def datetime_from_uuid1(uuid_arg): """ Creates a timezone-agnostic datetime from the timestamp in the specified type-1 UUID. :param uuid_arg: a version 1 :class:`~uuid.UUID` """ return datetime_from_timestamp(unix_time_from_uuid1(uuid_arg)) def min_uuid_from_time(timestamp): """ Generates the minimum TimeUUID (type 1) for a given timestamp, as compared by Cassandra. See :func:`uuid_from_time` for argument and return types. """ return uuid_from_time(timestamp, 0x808080808080, 0x80) # Cassandra does byte-wise comparison; fill with min signed bytes (0x80 = -128) def max_uuid_from_time(timestamp): """ Generates the maximum TimeUUID (type 1) for a given timestamp, as compared by Cassandra. See :func:`uuid_from_time` for argument and return types. """ return uuid_from_time(timestamp, 0x7f7f7f7f7f7f, 0x3f7f) # Max signed bytes (0x7f = 127) def uuid_from_time(time_arg, node=None, clock_seq=None): """ Converts a datetime or timestamp to a type 1 :class:`uuid.UUID`. :param time_arg: The time to use for the timestamp portion of the UUID. This can either be a :class:`datetime` object or a timestamp in seconds (as returned from :meth:`time.time()`). :type datetime: :class:`datetime` or timestamp :param node: None integer for the UUID (up to 48 bits). If not specified, this field is randomized. 
:type node: long :param clock_seq: Clock sequence field for the UUID (up to 14 bits). If not specified, a random sequence is generated. :type clock_seq: int :rtype: :class:`uuid.UUID` """ if hasattr(time_arg, 'utctimetuple'): seconds = int(calendar.timegm(time_arg.utctimetuple())) microseconds = (seconds * 1e6) + time_arg.time().microsecond else: microseconds = int(time_arg * 1e6) # 0x01b21dd213814000 is the number of 100-ns intervals between the # UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00. intervals = int(microseconds * 10) + 0x01b21dd213814000 time_low = intervals & 0xffffffff time_mid = (intervals >> 32) & 0xffff time_hi_version = (intervals >> 48) & 0x0fff if clock_seq is None: clock_seq = random.getrandbits(14) else: if clock_seq > 0x3fff: raise ValueError('clock_seq is out of range (need a 14-bit value)') clock_seq_low = clock_seq & 0xff clock_seq_hi_variant = 0x80 | ((clock_seq >> 8) & 0x3f) if node is None: node = random.getrandbits(48) return uuid.UUID(fields=(time_low, time_mid, time_hi_version, clock_seq_hi_variant, clock_seq_low, node), version=1) LOWEST_TIME_UUID = uuid.UUID('00000000-0000-1000-8080-808080808080') """ The lowest possible TimeUUID, as sorted by Cassandra. """ HIGHEST_TIME_UUID = uuid.UUID('ffffffff-ffff-1fff-bf7f-7f7f7f7f7f7f') """ The highest possible TimeUUID, as sorted by Cassandra. """ try: from collections import OrderedDict except ImportError: # OrderedDict from Python 2.7+ # Copyright (c) 2009 Raymond Hettinger # # Permission is hereby granted, free of charge, to any person # obtaining a copy of this software and associated documentation files # (the "Software"), to deal in the Software without restriction, # including without limitation the rights to use, copy, modify, merge, # publish, distribute, sublicense, and/or sell copies of the Software, # and to permit persons to whom the Software is furnished to do so, # subject to the following conditions: # # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT # HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR # OTHER DEALINGS IN THE SOFTWARE. from UserDict import DictMixin class OrderedDict(dict, DictMixin): # noqa """ A dictionary which maintains the insertion order of keys. """ def __init__(self, *args, **kwds): """ A dictionary which maintains the insertion order of keys. 
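Returning to ``uuid_from_time`` above: the ``node`` and ``clock_seq`` fields are randomized when omitted, so pin them if you need reproducible UUIDs (the values below are arbitrary)::

    from cassandra.util import uuid_from_time

    a = uuid_from_time(1559390400.0, node=0x123456789abc, clock_seq=0x1234)
    b = uuid_from_time(1559390400.0, node=0x123456789abc, clock_seq=0x1234)
    print(a.version)   # 1
    print(a == b)      # True -- fully determined by time, node and clock_seq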
""" if len(args) > 1: raise TypeError('expected at most 1 arguments, got %d' % len(args)) try: self.__end except AttributeError: self.clear() self.update(*args, **kwds) def clear(self): self.__end = end = [] end += [None, end, end] # sentinel node for doubly linked list self.__map = {} # key --> [key, prev, next] dict.clear(self) def __setitem__(self, key, value): if key not in self: end = self.__end curr = end[1] curr[2] = end[1] = self.__map[key] = [key, curr, end] dict.__setitem__(self, key, value) def __delitem__(self, key): dict.__delitem__(self, key) key, prev, next = self.__map.pop(key) prev[2] = next next[1] = prev def __iter__(self): end = self.__end curr = end[2] while curr is not end: yield curr[0] curr = curr[2] def __reversed__(self): end = self.__end curr = end[1] while curr is not end: yield curr[0] curr = curr[1] def popitem(self, last=True): if not self: raise KeyError('dictionary is empty') if last: key = next(reversed(self)) else: key = next(iter(self)) value = self.pop(key) return key, value def __reduce__(self): items = [[k, self[k]] for k in self] tmp = self.__map, self.__end del self.__map, self.__end inst_dict = vars(self).copy() self.__map, self.__end = tmp if inst_dict: return (self.__class__, (items,), inst_dict) return self.__class__, (items,) def keys(self): return list(self) setdefault = DictMixin.setdefault update = DictMixin.update pop = DictMixin.pop values = DictMixin.values items = DictMixin.items iterkeys = DictMixin.iterkeys itervalues = DictMixin.itervalues iteritems = DictMixin.iteritems def __repr__(self): if not self: return '%s()' % (self.__class__.__name__,) return '%s(%r)' % (self.__class__.__name__, self.items()) def copy(self): return self.__class__(self) @classmethod def fromkeys(cls, iterable, value=None): d = cls() for key in iterable: d[key] = value return d def __eq__(self, other): if isinstance(other, OrderedDict): if len(self) != len(other): return False for p, q in zip(self.items(), other.items()): if p != q: return False return True return dict.__eq__(self, other) def __ne__(self, other): return not self == other # WeakSet from Python 2.7+ (https://code.google.com/p/weakrefset) from _weakref import ref class _IterationGuard(object): # This context manager registers itself in the current iterators of the # weak container, such as to delay all removals until the context manager # exits. # This technique should be relatively thread-safe (since sets are). 
def __init__(self, weakcontainer): # Don't create cycles self.weakcontainer = ref(weakcontainer) def __enter__(self): w = self.weakcontainer() if w is not None: w._iterating.add(self) return self def __exit__(self, e, t, b): w = self.weakcontainer() if w is not None: s = w._iterating s.remove(self) if not s: w._commit_removals() class WeakSet(object): def __init__(self, data=None): self.data = set() def _remove(item, selfref=ref(self)): self = selfref() if self is not None: if self._iterating: self._pending_removals.append(item) else: self.data.discard(item) self._remove = _remove # A list of keys to be removed self._pending_removals = [] self._iterating = set() if data is not None: self.update(data) def _commit_removals(self): l = self._pending_removals discard = self.data.discard while l: discard(l.pop()) def __iter__(self): with _IterationGuard(self): for itemref in self.data: item = itemref() if item is not None: yield item def __len__(self): return sum(x() is not None for x in self.data) def __contains__(self, item): return ref(item) in self.data def __reduce__(self): return (self.__class__, (list(self),), getattr(self, '__dict__', None)) __hash__ = None def add(self, item): if self._pending_removals: self._commit_removals() self.data.add(ref(item, self._remove)) def clear(self): if self._pending_removals: self._commit_removals() self.data.clear() def copy(self): return self.__class__(self) def pop(self): if self._pending_removals: self._commit_removals() while True: try: itemref = self.data.pop() except KeyError: raise KeyError('pop from empty WeakSet') item = itemref() if item is not None: return item def remove(self, item): if self._pending_removals: self._commit_removals() self.data.remove(ref(item)) def discard(self, item): if self._pending_removals: self._commit_removals() self.data.discard(ref(item)) def update(self, other): if self._pending_removals: self._commit_removals() if isinstance(other, self.__class__): self.data.update(other.data) else: for element in other: self.add(element) def __ior__(self, other): self.update(other) return self # Helper functions for simple delegating methods. 
def _apply(self, other, method): if not isinstance(other, self.__class__): other = self.__class__(other) newdata = method(other.data) newset = self.__class__() newset.data = newdata return newset def difference(self, other): return self._apply(other, self.data.difference) __sub__ = difference def difference_update(self, other): if self._pending_removals: self._commit_removals() if self is other: self.data.clear() else: self.data.difference_update(ref(item) for item in other) def __isub__(self, other): if self._pending_removals: self._commit_removals() if self is other: self.data.clear() else: self.data.difference_update(ref(item) for item in other) return self def intersection(self, other): return self._apply(other, self.data.intersection) __and__ = intersection def intersection_update(self, other): if self._pending_removals: self._commit_removals() self.data.intersection_update(ref(item) for item in other) def __iand__(self, other): if self._pending_removals: self._commit_removals() self.data.intersection_update(ref(item) for item in other) return self def issubset(self, other): return self.data.issubset(ref(item) for item in other) __lt__ = issubset def __le__(self, other): return self.data <= set(ref(item) for item in other) def issuperset(self, other): return self.data.issuperset(ref(item) for item in other) __gt__ = issuperset def __ge__(self, other): return self.data >= set(ref(item) for item in other) def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented return self.data == set(ref(item) for item in other) def symmetric_difference(self, other): return self._apply(other, self.data.symmetric_difference) __xor__ = symmetric_difference def symmetric_difference_update(self, other): if self._pending_removals: self._commit_removals() if self is other: self.data.clear() else: self.data.symmetric_difference_update(ref(item) for item in other) def __ixor__(self, other): if self._pending_removals: self._commit_removals() if self is other: self.data.clear() else: self.data.symmetric_difference_update(ref(item) for item in other) return self def union(self, other): return self._apply(other, self.data.union) __or__ = union def isdisjoint(self, other): return len(self.intersection(other)) == 0 -from bisect import bisect_left - - class SortedSet(object): ''' A sorted set based on sorted list A sorted set implementation is used in this case because it does not require its elements to be immutable/hashable. 
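A local sketch of the ``SortedSet`` described below; unlike the built-in ``set``, elements only need to be orderable, not hashable::

    from cassandra.util import SortedSet

    s = SortedSet([3, 1, 2])
    s.add(0)
    print(list(s))              # [0, 1, 2, 3]

    # unhashable (list) elements are fine as long as they compare:
    s2 = SortedSet([[2, 'b'], [1, 'a']])
    print(list(s2))             # [[1, 'a'], [2, 'b']]
    print([1, 'a'] in s2)       # True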
#Not implemented: update functions, inplace operators ''' def __init__(self, iterable=()): self._items = [] self.update(iterable) def __len__(self): return len(self._items) def __getitem__(self, i): return self._items[i] def __iter__(self): return iter(self._items) def __reversed__(self): return reversed(self._items) def __repr__(self): return '%s(%r)' % ( self.__class__.__name__, self._items) def __reduce__(self): return self.__class__, (self._items,) def __eq__(self, other): if isinstance(other, self.__class__): return self._items == other._items else: try: return len(other) == len(self._items) and all(item in self for item in other) except TypeError: return NotImplemented def __ne__(self, other): if isinstance(other, self.__class__): return self._items != other._items else: try: return len(other) != len(self._items) or any(item not in self for item in other) except TypeError: return NotImplemented def __le__(self, other): return self.issubset(other) def __lt__(self, other): return len(other) > len(self._items) and self.issubset(other) def __ge__(self, other): return self.issuperset(other) def __gt__(self, other): return len(self._items) > len(other) and self.issuperset(other) def __and__(self, other): return self._intersect(other) __rand__ = __and__ def __iand__(self, other): isect = self._intersect(other) self._items = isect._items return self def __or__(self, other): return self.union(other) __ror__ = __or__ def __ior__(self, other): union = self.union(other) self._items = union._items return self def __sub__(self, other): return self._diff(other) def __rsub__(self, other): return sortedset(other) - self def __isub__(self, other): diff = self._diff(other) self._items = diff._items return self def __xor__(self, other): return self.symmetric_difference(other) __rxor__ = __xor__ def __ixor__(self, other): sym_diff = self.symmetric_difference(other) self._items = sym_diff._items return self def __contains__(self, item): - i = bisect_left(self._items, item) + i = self._find_insertion(item) return i < len(self._items) and self._items[i] == item def __delitem__(self, i): del self._items[i] def __delslice__(self, i, j): del self._items[i:j] def add(self, item): - i = bisect_left(self._items, item) + i = self._find_insertion(item) if i < len(self._items): if self._items[i] != item: self._items.insert(i, item) else: self._items.append(item) def update(self, iterable): for i in iterable: self.add(i) def clear(self): del self._items[:] def copy(self): new = sortedset() new._items = list(self._items) return new def isdisjoint(self, other): return len(self._intersect(other)) == 0 def issubset(self, other): return len(self._intersect(other)) == len(self._items) def issuperset(self, other): return len(self._intersect(other)) == len(other) def pop(self): if not self._items: raise KeyError("pop from empty set") return self._items.pop() def remove(self, item): - i = bisect_left(self._items, item) + i = self._find_insertion(item) if i < len(self._items): if self._items[i] == item: self._items.pop(i) return raise KeyError('%r' % item) def union(self, *others): union = sortedset() union._items = list(self._items) for other in others: - if isinstance(other, self.__class__): - i = 0 - for item in other._items: - i = bisect_left(union._items, item, i) - if i < len(union._items): - if item != union._items[i]: - union._items.insert(i, item) - else: - union._items.append(item) - else: - for item in other: - union.add(item) + for item in other: + union.add(item) return union def intersection(self, *others): isect 
= self.copy() for other in others: isect = isect._intersect(other) if not isect: break return isect def difference(self, *others): diff = self.copy() for other in others: diff = diff._diff(other) if not diff: break return diff def symmetric_difference(self, other): diff_self_other = self._diff(other) diff_other_self = other.difference(self) return diff_self_other.union(diff_other_self) def _diff(self, other): diff = sortedset() - if isinstance(other, self.__class__): - i = 0 - for item in self._items: - i = bisect_left(other._items, item, i) - if i < len(other._items): - if item != other._items[i]: - diff._items.append(item) - else: - diff._items.append(item) - else: - for item in self._items: - if item not in other: - diff.add(item) + for item in self._items: + if item not in other: + diff.add(item) return diff def _intersect(self, other): isect = sortedset() - if isinstance(other, self.__class__): - i = 0 - for item in self._items: - i = bisect_left(other._items, item, i) - if i < len(other._items): - if item == other._items[i]: - isect._items.append(item) - else: - break - else: - for item in self._items: - if item in other: - isect.add(item) + for item in self._items: + if item in other: + isect.add(item) return isect + def _find_insertion(self, x): + # this uses bisect_left algorithm unless it has elements it can't compare, + # in which case it defaults to grouping non-comparable items at the beginning or end, + # and scanning sequentially to find an insertion point + a = self._items + lo = 0 + hi = len(a) + try: + while lo < hi: + mid = (lo + hi) // 2 + if a[mid] < x: lo = mid + 1 + else: hi = mid + except TypeError: + # could not compare a[mid] with x + # start scanning to find insertion point while swallowing type errors + lo = 0 + compared_one = False # flag is used to determine whether uncomparables are grouped at the front or back + while lo < hi: + try: + if a[lo] == x or a[lo] >= x: break + compared_one = True + except TypeError: + if compared_one: break + lo += 1 + return lo + sortedset = SortedSet # backwards-compatibility -from collections import Mapping +from cassandra.compat import Mapping from six.moves import cPickle class OrderedMap(Mapping): ''' An ordered map that accepts non-hashable types for keys. It also maintains the insertion order of items, behaving as OrderedDict in that regard. These maps are constructed and read just as normal mapping types, exept that they may contain arbitrary collections and other non-hashable items as keys:: >>> od = OrderedMap([({'one': 1, 'two': 2}, 'value'), ... ({'three': 3, 'four': 4}, 'value2')]) >>> list(od.keys()) [{'two': 2, 'one': 1}, {'three': 3, 'four': 4}] >>> list(od.values()) ['value', 'value2'] These constructs are needed to support nested collections in Cassandra 2.1.3+, - where frozen collections can be specified as parameters to others\*:: + where frozen collections can be specified as parameters to others:: CREATE TABLE example ( ... value map>, double> ... ) This class derives from the (immutable) Mapping API. Objects in these maps are not intended be modified. - \* Note: Because of the way Cassandra encodes nested types, when using the + Note: Because of the way Cassandra encodes nested types, when using the driver with nested collections, :attr:`~.Cluster.protocol_version` must be 3 or higher. 
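A local sketch of ``OrderedMap``: insertion order is preserved and keys may be non-hashable collections (a single-key dict is used here to keep the serialized-key lookup unambiguous)::

    from cassandra.util import OrderedMap

    om = OrderedMap([((1, 'a'), 'first'), ((2, 'b'), 'second')])
    print(list(om.keys()))      # [(1, 'a'), (2, 'b')] -- insertion order kept
    print(om[(2, 'b')])         # 'second'

    om2 = OrderedMap([({'one': 1}, 'value')])
    print(om2[{'one': 1}])      # 'value' -- dict used as a key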
''' def __init__(self, *args, **kwargs): if len(args) > 1: raise TypeError('expected at most 1 arguments, got %d' % len(args)) self._items = [] self._index = {} if args: e = args[0] if callable(getattr(e, 'keys', None)): for k in e.keys(): self._insert(k, e[k]) else: for k, v in e: self._insert(k, v) for k, v in six.iteritems(kwargs): self._insert(k, v) def _insert(self, key, value): flat_key = self._serialize_key(key) i = self._index.get(flat_key, -1) if i >= 0: self._items[i] = (key, value) else: self._items.append((key, value)) self._index[flat_key] = len(self._items) - 1 __setitem__ = _insert def __getitem__(self, key): try: index = self._index[self._serialize_key(key)] return self._items[index][1] except KeyError: raise KeyError(str(key)) def __delitem__(self, key): # not efficient -- for convenience only try: index = self._index.pop(self._serialize_key(key)) self._index = dict((k, i if i < index else i - 1) for k, i in self._index.items()) self._items.pop(index) except KeyError: raise KeyError(str(key)) def __iter__(self): for i in self._items: yield i[0] def __len__(self): return len(self._items) def __eq__(self, other): if isinstance(other, OrderedMap): return self._items == other._items try: d = dict(other) return len(d) == len(self._items) and all(i[1] == d[i[0]] for i in self._items) except KeyError: return False except TypeError: pass return NotImplemented def __repr__(self): return '%s([%s])' % ( self.__class__.__name__, ', '.join("(%r, %r)" % (k, v) for k, v in self._items)) def __str__(self): return '{%s}' % ', '.join("%r: %r" % (k, v) for k, v in self._items) def popitem(self): try: kv = self._items.pop() del self._index[self._serialize_key(kv[0])] return kv except IndexError: raise KeyError() def _serialize_key(self, key): return cPickle.dumps(key) class OrderedMapSerializedKey(OrderedMap): def __init__(self, cass_type, protocol_version): super(OrderedMapSerializedKey, self).__init__() self.cass_key_type = cass_type self.protocol_version = protocol_version def _insert_unchecked(self, key, flat_key, value): self._items.append((key, value)) self._index[flat_key] = len(self._items) - 1 def _serialize_key(self, key): return self.cass_key_type.serialize(key, self.protocol_version) import datetime import time if six.PY3: long = int @total_ordering class Time(object): ''' Idealized time, independent of day. 
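A local sketch of the ``Time`` type being defined here, which can be constructed from a string, an integer nanosecond count, or a ``datetime.time``::

    import datetime
    from cassandra.util import Time

    t = Time('13:30:54.234')                 # "HH:MM:SS[.mmmuuunnn]"
    print(t.hour, t.minute, t.second)        # 13 30 54
    print(t.nanosecond)                      # 234000000 (right-padded to nanoseconds)
    print(t == Time(datetime.time(13, 30, 54, 234000)))   # True
    print(t.time())                          # 13:30:54.234000 (truncated to microseconds)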
Up to nanosecond resolution ''' MICRO = 1000 MILLI = 1000 * MICRO SECOND = 1000 * MILLI MINUTE = 60 * SECOND HOUR = 60 * MINUTE DAY = 24 * HOUR nanosecond_time = 0 def __init__(self, value): """ Initializer value can be: - integer_type: absolute nanoseconds in the day - datetime.time: built-in time - string_type: a string time of the form "HH:MM:SS[.mmmuuunnn]" """ if isinstance(value, six.integer_types): self._from_timestamp(value) elif isinstance(value, datetime.time): self._from_time(value) elif isinstance(value, six.string_types): self._from_timestring(value) else: raise TypeError('Time arguments must be a whole number, datetime.time, or string') @property def hour(self): """ The hour component of this time (0-23) """ return self.nanosecond_time // Time.HOUR @property def minute(self): """ The minute component of this time (0-59) """ minutes = self.nanosecond_time // Time.MINUTE return minutes % 60 @property def second(self): """ The second component of this time (0-59) """ seconds = self.nanosecond_time // Time.SECOND return seconds % 60 @property def nanosecond(self): """ The fractional seconds component of the time, in nanoseconds """ return self.nanosecond_time % Time.SECOND def time(self): """ Return a built-in datetime.time (nanosecond precision truncated to micros). """ return datetime.time(hour=self.hour, minute=self.minute, second=self.second, microsecond=self.nanosecond // Time.MICRO) def _from_timestamp(self, t): if t >= Time.DAY: raise ValueError("value must be less than number of nanoseconds in a day (%d)" % Time.DAY) self.nanosecond_time = t def _from_timestring(self, s): try: parts = s.split('.') base_time = time.strptime(parts[0], "%H:%M:%S") self.nanosecond_time = (base_time.tm_hour * Time.HOUR + base_time.tm_min * Time.MINUTE + base_time.tm_sec * Time.SECOND) if len(parts) > 1: # right pad to 9 digits nano_time_str = parts[1] + "0" * (9 - len(parts[1])) self.nanosecond_time += int(nano_time_str) except ValueError: raise ValueError("can't interpret %r as a time" % (s,)) def _from_time(self, t): self.nanosecond_time = (t.hour * Time.HOUR + t.minute * Time.MINUTE + t.second * Time.SECOND + t.microsecond * Time.MICRO) def __hash__(self): return self.nanosecond_time def __eq__(self, other): if isinstance(other, Time): return self.nanosecond_time == other.nanosecond_time if isinstance(other, six.integer_types): return self.nanosecond_time == other return self.nanosecond_time % Time.MICRO == 0 and \ datetime.time(hour=self.hour, minute=self.minute, second=self.second, microsecond=self.nanosecond // Time.MICRO) == other def __ne__(self, other): return not self.__eq__(other) def __lt__(self, other): if not isinstance(other, Time): return NotImplemented return self.nanosecond_time < other.nanosecond_time def __repr__(self): return "Time(%s)" % self.nanosecond_time def __str__(self): return "%02d:%02d:%02d.%09d" % (self.hour, self.minute, self.second, self.nanosecond) @total_ordering class Date(object): ''' Idealized date: year, month, day Offers wider year range than datetime.date. For Dates that cannot be represented as a datetime.date (because datetime.MINYEAR, datetime.MAXYEAR), this type falls back to printing days_from_epoch offset. ''' MINUTE = 60 HOUR = 60 * MINUTE DAY = 24 * HOUR date_format = "%Y-%m-%d" days_from_epoch = 0 def __init__(self, value): """ Initializer value can be: - integer_type: absolute days from epoch (1970, 1, 1). Can be negative. 
- datetime.date: built-in date - string_type: a string time of the form "yyyy-mm-dd" """ if isinstance(value, six.integer_types): self.days_from_epoch = value elif isinstance(value, (datetime.date, datetime.datetime)): self._from_timetuple(value.timetuple()) elif isinstance(value, six.string_types): self._from_datestring(value) else: raise TypeError('Date arguments must be a whole number, datetime.date, or string') @property def seconds(self): """ Absolute seconds from epoch (can be negative) """ return self.days_from_epoch * Date.DAY def date(self): """ Return a built-in datetime.date for Dates falling in the years [datetime.MINYEAR, datetime.MAXYEAR] ValueError is raised for Dates outside this range. """ try: dt = datetime_from_timestamp(self.seconds) return datetime.date(dt.year, dt.month, dt.day) except Exception: raise ValueError("%r exceeds ranges for built-in datetime.date" % self) def _from_timetuple(self, t): self.days_from_epoch = calendar.timegm(t) // Date.DAY def _from_datestring(self, s): if s[0] == '+': s = s[1:] dt = datetime.datetime.strptime(s, self.date_format) self._from_timetuple(dt.timetuple()) def __hash__(self): return self.days_from_epoch def __eq__(self, other): if isinstance(other, Date): return self.days_from_epoch == other.days_from_epoch if isinstance(other, six.integer_types): return self.days_from_epoch == other try: return self.date() == other except Exception: return False def __ne__(self, other): return not self.__eq__(other) def __lt__(self, other): if not isinstance(other, Date): return NotImplemented return self.days_from_epoch < other.days_from_epoch def __repr__(self): return "Date(%s)" % self.days_from_epoch def __str__(self): try: dt = datetime_from_timestamp(self.seconds) return "%04d-%02d-%02d" % (dt.year, dt.month, dt.day) except: # If we overflow datetime.[MIN|MAX] return str(self.days_from_epoch) import socket if hasattr(socket, 'inet_pton'): inet_pton = socket.inet_pton inet_ntop = socket.inet_ntop else: """ Windows doesn't have socket.inet_pton and socket.inet_ntop until Python 3.4 This is an alternative impl using ctypes, based on this win_inet_pton project: https://github.com/hickeroar/win_inet_pton """ import ctypes class sockaddr(ctypes.Structure): """ Shared struct for ipv4 and ipv6. https://msdn.microsoft.com/en-us/library/windows/desktop/ms740496(v=vs.85).aspx ``__pad1`` always covers the port. When being used for ``sockaddr_in6``, ``ipv4_addr`` actually covers ``sin6_flowinfo``, resulting in proper alignment for ``ipv6_addr``. """ _fields_ = [("sa_family", ctypes.c_short), ("__pad1", ctypes.c_ushort), ("ipv4_addr", ctypes.c_byte * 4), ("ipv6_addr", ctypes.c_byte * 16), ("__pad2", ctypes.c_ulong)] if hasattr(ctypes, 'windll'): WSAStringToAddressA = ctypes.windll.ws2_32.WSAStringToAddressA WSAAddressToStringA = ctypes.windll.ws2_32.WSAAddressToStringA else: def not_windows(*args): raise OSError("IPv6 addresses cannot be handled on Windows. 
" "Missing ctypes.windll") WSAStringToAddressA = not_windows WSAAddressToStringA = not_windows def inet_pton(address_family, ip_string): if address_family == socket.AF_INET: return socket.inet_aton(ip_string) addr = sockaddr() addr.sa_family = address_family addr_size = ctypes.c_int(ctypes.sizeof(addr)) if WSAStringToAddressA( ip_string, address_family, None, ctypes.byref(addr), ctypes.byref(addr_size) ) != 0: raise socket.error(ctypes.FormatError()) if address_family == socket.AF_INET6: return ctypes.string_at(addr.ipv6_addr, 16) raise socket.error('unknown address family') def inet_ntop(address_family, packed_ip): if address_family == socket.AF_INET: return socket.inet_ntoa(packed_ip) addr = sockaddr() addr.sa_family = address_family addr_size = ctypes.c_int(ctypes.sizeof(addr)) ip_string = ctypes.create_string_buffer(128) ip_string_size = ctypes.c_int(ctypes.sizeof(ip_string)) if address_family == socket.AF_INET6: if len(packed_ip) != ctypes.sizeof(addr.ipv6_addr): raise socket.error('packed IP wrong length for inet_ntoa') ctypes.memmove(addr.ipv6_addr, packed_ip, 16) else: raise socket.error('unknown address family') if WSAAddressToStringA( ctypes.byref(addr), addr_size, None, ip_string, ctypes.byref(ip_string_size) ) != 0: raise socket.error(ctypes.FormatError()) return ip_string[:ip_string_size.value - 1] import keyword # similar to collections.namedtuple, reproduced here because Python 2.6 did not have the rename logic def _positional_rename_invalid_identifiers(field_names): names_out = list(field_names) for index, name in enumerate(field_names): if (not all(c.isalnum() or c == '_' for c in name) or keyword.iskeyword(name) or not name or name[0].isdigit() or name.startswith('_')): names_out[index] = 'field_%d_' % index return names_out def _sanitize_identifiers(field_names): names_out = _positional_rename_invalid_identifiers(field_names) if len(names_out) != len(set(names_out)): observed_names = set() for index, name in enumerate(names_out): while names_out[index] in observed_names: names_out[index] = "%s_" % (names_out[index],) observed_names.add(names_out[index]) return names_out class Duration(object): """ Cassandra Duration Type """ months = 0 days = 0 nanoseconds = 0 def __init__(self, months=0, days=0, nanoseconds=0): self.months = months self.days = days self.nanoseconds = nanoseconds def __eq__(self, other): return isinstance(other, self.__class__) and self.months == other.months and self.days == other.days and self.nanoseconds == other.nanoseconds def __repr__(self): return "Duration({0}, {1}, {2})".format(self.months, self.days, self.nanoseconds) def __str__(self): has_negative_values = self.months < 0 or self.days < 0 or self.nanoseconds < 0 return '%s%dmo%dd%dns' % ( '-' if has_negative_values else '', abs(self.months), abs(self.days), abs(self.nanoseconds) ) + + +@total_ordering +class Version(object): + """ + Internal minimalist class to compare versions. + A valid version is: .... + + TODO: when python2 support is removed, use packaging.version. + """ + + _version = None + major = None + minor = 0 + patch = 0 + build = 0 + prerelease = 0 + + def __init__(self, version): + self._version = version + if '-' in version: + version_without_prerelease, self.prerelease = version.split('-', 1) + else: + version_without_prerelease = version + parts = list(reversed(version_without_prerelease.split('.'))) + if len(parts) > 4: + raise ValueError("Invalid version: {}. 
Only 4 " + "components plus prerelease are supported".format(version)) + + self.major = int(parts.pop()) + self.minor = int(parts.pop()) if parts else 0 + self.patch = int(parts.pop()) if parts else 0 + + if parts: # we have a build version + build = parts.pop() + try: + self.build = int(build) + except ValueError: + self.build = build + + def __hash__(self): + return self._version + + def __repr__(self): + version_string = "Version({0}, {1}, {2}".format(self.major, self.minor, self.patch) + if self.build: + version_string += ", {}".format(self.build) + if self.prerelease: + version_string += ", {}".format(self.prerelease) + version_string += ")" + + return version_string + + def __str__(self): + return self._version + + @staticmethod + def _compare_version_part(version, other_version, cmp): + if not (isinstance(version, six.integer_types) and + isinstance(other_version, six.integer_types)): + version = str(version) + other_version = str(other_version) + + return cmp(version, other_version) + + def __eq__(self, other): + if not isinstance(other, Version): + return NotImplemented + + return (self.major == other.major and + self.minor == other.minor and + self.patch == other.patch and + self._compare_version_part(self.build, other.build, lambda s, o: s == o) and + self._compare_version_part(self.prerelease, other.prerelease, lambda s, o: s == o) + ) + + def __gt__(self, other): + if not isinstance(other, Version): + return NotImplemented + + is_major_ge = self.major >= other.major + is_minor_ge = self.minor >= other.minor + is_patch_ge = self.patch >= other.patch + is_build_gt = self._compare_version_part(self.build, other.build, lambda s, o: s > o) + is_build_ge = self._compare_version_part(self.build, other.build, lambda s, o: s >= o) + + # By definition, a prerelease comes BEFORE the actual release, so if a version + # doesn't have a prerelease, it's automatically greater than anything that does + if self.prerelease and not other.prerelease: + is_prerelease_gt = False + elif other.prerelease and not self.prerelease: + is_prerelease_gt = True + else: + is_prerelease_gt = self._compare_version_part(self.prerelease, other.prerelease, lambda s, o: s > o) \ + + return (self.major > other.major or + (is_major_ge and self.minor > other.minor) or + (is_major_ge and is_minor_ge and self.patch > other.patch) or + (is_major_ge and is_minor_ge and is_patch_ge and is_build_gt) or + (is_major_ge and is_minor_ge and is_patch_ge and is_build_ge and is_prerelease_gt) + ) diff --git a/cassandra_driver.egg-info/PKG-INFO b/cassandra_driver.egg-info/PKG-INFO index 11e02d9..79d9098 100644 --- a/cassandra_driver.egg-info/PKG-INFO +++ b/cassandra_driver.egg-info/PKG-INFO @@ -1,112 +1,114 @@ Metadata-Version: 1.1 Name: cassandra-driver -Version: 3.16.0 +Version: 3.20.2 Summary: Python driver for Cassandra Home-page: http://github.com/datastax/python-driver -Author:: Tyler Hobbs -Author-email:: tyler@datastax.com +Author: Tyler Hobbs +Author-email: tyler@datastax.com License: UNKNOWN Description: DataStax Python Driver for Apache Cassandra =========================================== .. image:: https://travis-ci.org/datastax/python-driver.png?branch=master :target: https://travis-ci.org/datastax/python-driver A modern, `feature-rich `_ and highly-tunable Python client library for Apache Cassandra (2.1+) using exclusively Cassandra's binary protocol and Cassandra Query Language v3. - The driver supports Python 2.7, 3.4, 3.5, and 3.6. + The driver supports Python 2.7, 3.4, 3.5, 3.6 and 3.7. 
If you require compatibility with DataStax Enterprise, use the `DataStax Enterprise Python Driver `_. **Note:** DataStax products do not support big-endian systems. Feedback Requested ------------------ **Help us focus our efforts!** Provide your input on the `Platform and Runtime Survey `_ (we kept it short). Features -------- * `Synchronous `_ and `Asynchronous `_ APIs * `Simple, Prepared, and Batch statements `_ * Asynchronous IO, parallel execution, request pipelining * `Connection pooling `_ * Automatic node discovery * `Automatic reconnection `_ * Configurable `load balancing `_ and `retry policies `_ * `Concurrent execution utilities `_ * `Object mapper `_ + * `Connecting to DataStax Apollo database (cloud) `_ Installation ------------ Installation through pip is recommended:: $ pip install cassandra-driver For more complete installation instructions, see the `installation guide `_. Documentation ------------- The documentation can be found online `here `_. A couple of links for getting up to speed: * `Installation `_ * `Getting started guide `_ * `API docs `_ * `Performance tips `_ Object Mapper ------------- cqlengine (originally developed by Blake Eggleston and Jon Haddad, with contributions from the community) is now maintained as an integral part of this package. Refer to `documentation here `_. Contributing ------------ See `CONTRIBUTING.md `_. Reporting Problems ------------------ Please report any bugs and make any feature requests on the `JIRA `_ issue tracker. If you would like to contribute, please feel free to open a pull request. Getting Help ------------ Your best options for getting help with the driver are the `mailing list `_ and the ``#datastax-drivers`` channel in the `DataStax Academy Slack `_. License ------- - Copyright 2013-2017 DataStax + Copyright DataStax, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Keywords: cassandra,cql,orm Platform: UNKNOWN Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: Apache Software License Classifier: Natural Language :: English Classifier: Operating System :: OS Independent Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 2.7 Classifier: Programming Language :: Python :: 3.4 Classifier: Programming Language :: Python :: 3.5 Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 Classifier: Programming Language :: Python :: Implementation :: CPython Classifier: Programming Language :: Python :: Implementation :: PyPy Classifier: Topic :: Software Development :: Libraries :: Python Modules diff --git a/cassandra_driver.egg-info/SOURCES.txt b/cassandra_driver.egg-info/SOURCES.txt index 8680288..cb77c19 100644 --- a/cassandra_driver.egg-info/SOURCES.txt +++ b/cassandra_driver.egg-info/SOURCES.txt @@ -1,66 +1,200 @@ LICENSE MANIFEST.in README.rst ez_setup.py setup.py cassandra/__init__.py cassandra/auth.py cassandra/buffer.pxd cassandra/bytesio.pxd cassandra/bytesio.pyx cassandra/cluster.py cassandra/cmurmur3.c +cassandra/compat.py cassandra/concurrent.py cassandra/connection.py cassandra/cqltypes.py cassandra/cython_deps.py cassandra/cython_marshal.pyx cassandra/cython_utils.pxd cassandra/cython_utils.pyx cassandra/deserializers.pxd cassandra/deserializers.pyx cassandra/encoder.py cassandra/ioutils.pyx cassandra/marshal.py cassandra/metadata.py cassandra/metrics.py cassandra/murmur3.py cassandra/numpyFlags.h cassandra/numpy_parser.pyx cassandra/obj_parser.pyx cassandra/parsing.pxd cassandra/parsing.pyx cassandra/policies.py cassandra/pool.py cassandra/protocol.py cassandra/query.py cassandra/row_parser.pyx cassandra/timestamps.py cassandra/tuple.pxd cassandra/type_codes.pxd cassandra/type_codes.py cassandra/util.py cassandra/cqlengine/__init__.py cassandra/cqlengine/columns.py cassandra/cqlengine/connection.py cassandra/cqlengine/functions.py cassandra/cqlengine/management.py cassandra/cqlengine/models.py cassandra/cqlengine/named.py cassandra/cqlengine/operators.py cassandra/cqlengine/query.py cassandra/cqlengine/statements.py cassandra/cqlengine/usertype.py +cassandra/datastax/__init__.py +cassandra/datastax/cloud/__init__.py cassandra/io/__init__.py cassandra/io/asyncioreactor.py cassandra/io/asyncorereactor.py cassandra/io/eventletreactor.py cassandra/io/geventreactor.py cassandra/io/libevreactor.py cassandra/io/libevwrapper.c cassandra/io/twistedreactor.py cassandra_driver.egg-info/PKG-INFO cassandra_driver.egg-info/SOURCES.txt cassandra_driver.egg-info/dependency_links.txt cassandra_driver.egg-info/requires.txt -cassandra_driver.egg-info/top_level.txt \ No newline at end of file +cassandra_driver.egg-info/top_level.txt +tests/__init__.py +tests/integration/__init__.py +tests/integration/datatype_utils.py +tests/integration/util.py +tests/integration/cqlengine/__init__.py +tests/integration/cqlengine/base.py +tests/integration/cqlengine/test_batch_query.py +tests/integration/cqlengine/test_connections.py +tests/integration/cqlengine/test_consistency.py +tests/integration/cqlengine/test_context_query.py +tests/integration/cqlengine/test_ifexists.py +tests/integration/cqlengine/test_ifnotexists.py +tests/integration/cqlengine/test_lwt_conditional.py +tests/integration/cqlengine/test_timestamp.py +tests/integration/cqlengine/test_ttl.py +tests/integration/cqlengine/columns/__init__.py 
+tests/integration/cqlengine/columns/test_container_columns.py +tests/integration/cqlengine/columns/test_counter_column.py +tests/integration/cqlengine/columns/test_static_column.py +tests/integration/cqlengine/columns/test_validation.py +tests/integration/cqlengine/columns/test_value_io.py +tests/integration/cqlengine/connections/__init__.py +tests/integration/cqlengine/connections/test_connection.py +tests/integration/cqlengine/management/__init__.py +tests/integration/cqlengine/management/test_compaction_settings.py +tests/integration/cqlengine/management/test_management.py +tests/integration/cqlengine/model/__init__.py +tests/integration/cqlengine/model/test_class_construction.py +tests/integration/cqlengine/model/test_equality_operations.py +tests/integration/cqlengine/model/test_model.py +tests/integration/cqlengine/model/test_model_io.py +tests/integration/cqlengine/model/test_polymorphism.py +tests/integration/cqlengine/model/test_udts.py +tests/integration/cqlengine/model/test_updates.py +tests/integration/cqlengine/model/test_value_lists.py +tests/integration/cqlengine/operators/__init__.py +tests/integration/cqlengine/operators/test_where_operators.py +tests/integration/cqlengine/query/__init__.py +tests/integration/cqlengine/query/test_batch_query.py +tests/integration/cqlengine/query/test_datetime_queries.py +tests/integration/cqlengine/query/test_named.py +tests/integration/cqlengine/query/test_queryoperators.py +tests/integration/cqlengine/query/test_queryset.py +tests/integration/cqlengine/query/test_updates.py +tests/integration/cqlengine/statements/__init__.py +tests/integration/cqlengine/statements/test_assignment_clauses.py +tests/integration/cqlengine/statements/test_base_clause.py +tests/integration/cqlengine/statements/test_base_statement.py +tests/integration/cqlengine/statements/test_delete_statement.py +tests/integration/cqlengine/statements/test_insert_statement.py +tests/integration/cqlengine/statements/test_select_statement.py +tests/integration/cqlengine/statements/test_update_statement.py +tests/integration/cqlengine/statements/test_where_clause.py +tests/integration/long/__init__.py +tests/integration/long/test_consistency.py +tests/integration/long/test_failure_types.py +tests/integration/long/test_ipv6.py +tests/integration/long/test_large_data.py +tests/integration/long/test_loadbalancingpolicies.py +tests/integration/long/test_schema.py +tests/integration/long/test_ssl.py +tests/integration/long/utils.py +tests/integration/simulacron/__init__.py +tests/integration/simulacron/test_cluster.py +tests/integration/simulacron/test_connection.py +tests/integration/simulacron/test_policies.py +tests/integration/simulacron/utils.py +tests/integration/standard/__init__.py +tests/integration/standard/test_authentication.py +tests/integration/standard/test_client_warnings.py +tests/integration/standard/test_cluster.py +tests/integration/standard/test_concurrent.py +tests/integration/standard/test_connection.py +tests/integration/standard/test_control_connection.py +tests/integration/standard/test_custom_payload.py +tests/integration/standard/test_custom_protocol_handler.py +tests/integration/standard/test_cython_protocol_handlers.py +tests/integration/standard/test_dse.py +tests/integration/standard/test_metadata.py +tests/integration/standard/test_metrics.py +tests/integration/standard/test_policies.py +tests/integration/standard/test_prepared_statements.py +tests/integration/standard/test_query.py +tests/integration/standard/test_query_paging.py 
+tests/integration/standard/test_routing.py +tests/integration/standard/test_row_factories.py +tests/integration/standard/test_types.py +tests/integration/standard/test_udts.py +tests/integration/standard/utils.py +tests/integration/upgrade/__init__.py +tests/integration/upgrade/test_upgrade.py +tests/unit/__init__.py +tests/unit/test_cluster.py +tests/unit/test_concurrent.py +tests/unit/test_connection.py +tests/unit/test_control_connection.py +tests/unit/test_exception.py +tests/unit/test_marshalling.py +tests/unit/test_metadata.py +tests/unit/test_orderedmap.py +tests/unit/test_parameter_binding.py +tests/unit/test_policies.py +tests/unit/test_protocol.py +tests/unit/test_query.py +tests/unit/test_response_future.py +tests/unit/test_resultset.py +tests/unit/test_sortedset.py +tests/unit/test_time_util.py +tests/unit/test_timestamps.py +tests/unit/test_types.py +tests/unit/test_util_types.py +tests/unit/utils.py +tests/unit/cqlengine/__init__.py +tests/unit/cqlengine/test_columns.py +tests/unit/cqlengine/test_connection.py +tests/unit/cqlengine/test_udt.py +tests/unit/cython/__init__.py +tests/unit/cython/test_bytesio.py +tests/unit/cython/test_types.py +tests/unit/cython/test_utils.py +tests/unit/cython/utils.py +tests/unit/io/__init__.py +tests/unit/io/eventlet_utils.py +tests/unit/io/gevent_utils.py +tests/unit/io/test_asyncioreactor.py +tests/unit/io/test_asyncorereactor.py +tests/unit/io/test_eventletreactor.py +tests/unit/io/test_geventreactor.py +tests/unit/io/test_libevreactor.py +tests/unit/io/test_twistedreactor.py +tests/unit/io/utils.py \ No newline at end of file diff --git a/cassandra_driver.egg-info/requires.txt b/cassandra_driver.egg-info/requires.txt index 0b5cc57..e323a45 100644 --- a/cassandra_driver.egg-info/requires.txt +++ b/cassandra_driver.egg-info/requires.txt @@ -1,2 +1 @@ six>=1.9 -futures diff --git a/setup.py b/setup.py index 1b0ebf6..1259092 100644 --- a/setup.py +++ b/setup.py @@ -1,446 +1,448 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
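# Example invocations (an illustrative sketch based on the flags and environment variables handled further down in this file, not an exhaustive list): # $ pip install . --install-option="--no-cython" --install-option="--no-libev" # $ CASS_DRIVER_NO_EXTENSIONS=1 pip install . # $ CASS_DRIVER_BUILD_CONCURRENCY=8 python setup.py install # Each of these only skips (or parallelizes) the optional C extension builds; the driver itself runs without them.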
from __future__ import print_function import os import sys import warnings if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests": print("Running gevent tests") from gevent.monkey import patch_all patch_all() if __name__ == '__main__' and sys.argv[1] == "eventlet_nosetests": print("Running eventlet tests") from eventlet import monkey_patch monkey_patch() import ez_setup ez_setup.use_setuptools() from setuptools import setup from distutils.command.build_ext import build_ext from distutils.core import Extension from distutils.errors import (CCompilerError, DistutilsPlatformError, DistutilsExecError) from distutils.cmd import Command PY3 = sys.version_info[0] == 3 try: import subprocess has_subprocess = True except ImportError: has_subprocess = False from cassandra import __version__ long_description = "" with open("README.rst") as f: long_description = f.read() try: from nose.commands import nosetests except ImportError: gevent_nosetests = None eventlet_nosetests = None else: class gevent_nosetests(nosetests): description = "run nosetests with gevent monkey patching" class eventlet_nosetests(nosetests): description = "run nosetests with eventlet monkey patching" has_cqlengine = False if __name__ == '__main__' and sys.argv[1] == "install": try: import cqlengine has_cqlengine = True except ImportError: pass PROFILING = False class DocCommand(Command): description = "generate or test documentation" user_options = [("test", "t", "run doctests instead of generating documentation")] boolean_options = ["test"] def initialize_options(self): self.test = False def finalize_options(self): pass def run(self): if self.test: path = "docs/_build/doctest" mode = "doctest" else: path = "docs/_build/%s" % __version__ mode = "html" try: os.makedirs(path) except: pass if has_subprocess: # Prevent run with in-place extensions because cython-generated objects do not carry docstrings # http://docs.cython.org/src/userguide/special_methods.html#docstrings import glob for f in glob.glob("cassandra/*.so"): print("Removing '%s' to allow docs to run on pure python modules." %(f,)) os.unlink(f) # Build io extension to make import and docstrings work try: output = subprocess.check_output( ["python", "setup.py", "build_ext", "--inplace", "--force", "--no-murmur3", "--no-cython"], stderr=subprocess.STDOUT) except subprocess.CalledProcessError as exc: raise RuntimeError("Documentation step '%s' failed: %s: %s" % ("build_ext", exc, exc.output)) else: print(output) try: output = subprocess.check_output( ["sphinx-build", "-b", mode, "docs", path], stderr=subprocess.STDOUT) except subprocess.CalledProcessError as exc: raise RuntimeError("Documentation step '%s' failed: %s: %s" % (mode, exc, exc.output)) else: print(output) print("") print("Documentation step '%s' performed, results here:" % mode) print(" file://%s/%s/index.html" % (os.path.dirname(os.path.realpath(__file__)), path)) class BuildFailed(Exception): def __init__(self, ext): self.ext = ext murmur3_ext = Extension('cassandra.cmurmur3', sources=['cassandra/cmurmur3.c']) libev_ext = Extension('cassandra.io.libevwrapper', sources=['cassandra/io/libevwrapper.c'], include_dirs=['/usr/include/libev', '/usr/local/include', '/opt/local/include'], libraries=['ev'], library_dirs=['/usr/local/lib', '/opt/local/lib']) platform_unsupported_msg = \ """ =============================================================================== The optional C extensions are not supported on this platform. 
=============================================================================== """ arch_unsupported_msg = \ """ =============================================================================== The optional C extensions are not supported on big-endian systems. =============================================================================== """ pypy_unsupported_msg = \ """ ================================================================================= Some optional C extensions are not supported in PyPy. Only murmur3 will be built. ================================================================================= """ is_windows = os.name == 'nt' is_pypy = "PyPy" in sys.version if is_pypy: sys.stderr.write(pypy_unsupported_msg) is_supported_platform = sys.platform != "cli" and not sys.platform.startswith("java") is_supported_arch = sys.byteorder != "big" if not is_supported_platform: sys.stderr.write(platform_unsupported_msg) elif not is_supported_arch: sys.stderr.write(arch_unsupported_msg) try_extensions = "--no-extensions" not in sys.argv and is_supported_platform and is_supported_arch and not os.environ.get('CASS_DRIVER_NO_EXTENSIONS') try_murmur3 = try_extensions and "--no-murmur3" not in sys.argv try_libev = try_extensions and "--no-libev" not in sys.argv and not is_pypy and not is_windows try_cython = try_extensions and "--no-cython" not in sys.argv and not is_pypy and not os.environ.get('CASS_DRIVER_NO_CYTHON') try_cython &= 'egg_info' not in sys.argv # bypass setup_requires for pip egg_info calls, which will never have --install-option"--no-cython" coming fomr pip sys.argv = [a for a in sys.argv if a not in ("--no-murmur3", "--no-libev", "--no-cython", "--no-extensions")] build_concurrency = int(os.environ.get('CASS_DRIVER_BUILD_CONCURRENCY', '0')) class NoPatchExtension(Extension): # Older versions of setuptools.extension has a static flag which is set False before our # setup_requires lands Cython. It causes our *.pyx sources to be renamed to *.c in # the initializer. # The other workaround would be to manually generate sources, but that bypasses a lot # of the niceness cythonize embodies (setup build dir, conditional build, etc). # Newer setuptools does not have this problem because it checks for cython dynamically. # https://bitbucket.org/pypa/setuptools/commits/714c3144e08fd01a9f61d1c88411e76d2538b2e4 def __init__(self, *args, **kwargs): # bypass the patched init if possible if Extension.__bases__: base, = Extension.__bases__ base.__init__(self, *args, **kwargs) else: Extension.__init__(self, *args, **kwargs) class build_extensions(build_ext): error_message = """ =============================================================================== WARNING: could not compile %s. The C extensions are not required for the driver to run, but they add support for token-aware routing with the Murmur3Partitioner. On Windows, make sure Visual Studio or an SDK is installed, and your environment is configured to build for the appropriate architecture (matching your Python runtime). This is often a matter of using vcvarsall.bat from your install directory, or running from a command prompt in the Visual Studio Tools Start Menu. =============================================================================== """ if is_windows else """ =============================================================================== WARNING: could not compile %s. The C extensions are not required for the driver to run, but they add support for libev and token-aware routing with the Murmur3Partitioner. 
Linux users should ensure that GCC and the Python headers are available. On Ubuntu and Debian, this can be accomplished by running: $ sudo apt-get install build-essential python-dev On RedHat and RedHat-based systems like CentOS and Fedora: $ sudo yum install gcc python-devel On OSX, homebrew installations of Python should provide the necessary headers. libev Support ------------- For libev support, you will also need to install libev and its headers. On Debian/Ubuntu: $ sudo apt-get install libev4 libev-dev On RHEL/CentOS/Fedora: $ sudo yum install libev libev-devel On OSX, via homebrew: $ brew install libev =============================================================================== """ def run(self): try: self._setup_extensions() build_ext.run(self) except DistutilsPlatformError as exc: sys.stderr.write('%s\n' % str(exc)) warnings.warn(self.error_message % "C extensions.") def build_extensions(self): if build_concurrency > 1: self.check_extensions_list(self.extensions) import multiprocessing.pool multiprocessing.pool.ThreadPool(processes=build_concurrency).map(self.build_extension, self.extensions) else: build_ext.build_extensions(self) def build_extension(self, ext): try: build_ext.build_extension(self, ext) except (CCompilerError, DistutilsExecError, DistutilsPlatformError, IOError) as exc: sys.stderr.write('%s\n' % str(exc)) name = "The %s extension" % (ext.name,) warnings.warn(self.error_message % (name,)) def _setup_extensions(self): # We defer extension setup until this command to leveraage 'setup_requires' pulling in Cython before we # attempt to import anything self.extensions = [] if try_murmur3: self.extensions.append(murmur3_ext) if try_libev: self.extensions.append(libev_ext) if try_cython: try: from Cython.Build import cythonize cython_candidates = ['cluster', 'concurrent', 'connection', 'cqltypes', 'metadata', 'pool', 'protocol', 'query', 'util'] compile_args = [] if is_windows else ['-Wno-unused-function'] self.extensions.extend(cythonize( [Extension('cassandra.%s' % m, ['cassandra/%s.py' % m], extra_compile_args=compile_args) for m in cython_candidates], nthreads=build_concurrency, exclude_failures=True)) self.extensions.extend(cythonize(NoPatchExtension("*", ["cassandra/*.pyx"], extra_compile_args=compile_args), nthreads=build_concurrency)) except Exception: sys.stderr.write("Failed to cythonize one or more modules. 
These will not be compiled as extensions (optional).\n") def pre_build_check(): """ Try to verify build tools """ if os.environ.get('CASS_DRIVER_NO_PRE_BUILD_CHECK'): return True try: from distutils.ccompiler import new_compiler from distutils.sysconfig import customize_compiler from distutils.dist import Distribution # base build_ext just to emulate compiler option setup be = build_ext(Distribution()) be.initialize_options() be.finalize_options() # First, make sure we have a Python include directory have_python_include = any(os.path.isfile(os.path.join(p, 'Python.h')) for p in be.include_dirs) if not have_python_include: sys.stderr.write("Did not find 'Python.h' in %s.\n" % (be.include_dirs,)) return False compiler = new_compiler(compiler=be.compiler) customize_compiler(compiler) try: # We must be able to initialize the compiler if it has that method if hasattr(compiler, "initialize"): compiler.initialize() except: return False executables = [] if compiler.compiler_type in ('unix', 'cygwin'): executables = [compiler.executables[exe][0] for exe in ('compiler_so', 'linker_so')] elif compiler.compiler_type == 'nt': executables = [getattr(compiler, exe) for exe in ('cc', 'linker')] if executables: from distutils.spawn import find_executable for exe in executables: if not find_executable(exe): sys.stderr.write("Failed to find %s for compiler type %s.\n" % (exe, compiler.compiler_type)) return False except Exception as exc: sys.stderr.write('%s\n' % str(exc)) sys.stderr.write("Failed pre-build check. Attempting anyway.\n") # if we are unable to positively id the compiler type, or one of these assumptions fails, # just proceed as we would have without the check return True def run_setup(extensions): kw = {'cmdclass': {'doc': DocCommand}} if gevent_nosetests is not None: kw['cmdclass']['gevent_nosetests'] = gevent_nosetests if eventlet_nosetests is not None: kw['cmdclass']['eventlet_nosetests'] = eventlet_nosetests kw['cmdclass']['build_ext'] = build_extensions kw['ext_modules'] = [Extension('DUMMY', [])] # dummy extension makes sure build_ext is called for install if try_cython: # precheck compiler before adding to setup_requires # we don't actually negate try_cython because: # 1.) build_ext eats errors at compile time, letting the install complete while producing useful feedback # 2.) 
there could be a case where the python environment has cython installed but the system doesn't have build tools if pre_build_check(): - cython_dep = 'Cython>=0.20,!=0.25,<0.29' + cython_dep = 'Cython>=0.20,!=0.25,<0.30' user_specified_cython_version = os.environ.get('CASS_DRIVER_ALLOWED_CYTHON_VERSION') if user_specified_cython_version is not None: cython_dep = 'Cython==%s' % (user_specified_cython_version,) kw['setup_requires'] = [cython_dep] else: sys.stderr.write("Bypassing Cython setup requirement\n") dependencies = ['six >=1.9'] if not PY3: dependencies.append('futures') setup( name='cassandra-driver', version=__version__, description='Python driver for Cassandra', long_description=long_description, url='http://github.com/datastax/python-driver', author='Tyler Hobbs', author_email='tyler@datastax.com', - packages=['cassandra', 'cassandra.io', 'cassandra.cqlengine'], + packages=['cassandra', 'cassandra.io', 'cassandra.cqlengine', 'cassandra.datastax', + 'cassandra.datastax.cloud'], keywords='cassandra,cql,orm', include_package_data=True, install_requires=dependencies, tests_require=['nose', 'mock>=2.0.0', 'PyYAML', 'pytz', 'sure'], classifiers=[ 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'License :: OSI Approved :: Apache Software License', 'Natural Language :: English', 'Operating System :: OS Independent', 'Programming Language :: Python', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy', 'Topic :: Software Development :: Libraries :: Python Modules' ], **kw) run_setup(None) if has_cqlengine: warnings.warn("\n#######\n'cqlengine' package is present on path: %s\n" "cqlengine is now an integrated sub-package of this driver.\n" "It is recommended to remove this package to reduce the chance for conflicting usage" % cqlengine.__file__) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..6260583 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,116 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
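+# Illustrative sketch of how this test package is typically driven (EVENT_LOOP_MANAGER and VERIFY_CYTHON are the environment variables read below; the exact command line is only an example): +# $ EVENT_LOOP_MANAGER=gevent VERIFY_CYTHON=True nosetests tests/unit +# When EVENT_LOOP_MANAGER is unset it defaults to "libev", and the suite falls back to no explicit connection class if the libev extension is not installed.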
+ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa +import logging +import sys +import socket +import platform +import os +from concurrent.futures import ThreadPoolExecutor + +log = logging.getLogger() +log.setLevel('DEBUG') +# if nose didn't already attach a log handler, add one here +if not log.handlers: + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s [%(module)s:%(lineno)s]: %(message)s')) + log.addHandler(handler) + + +def is_eventlet_monkey_patched(): + if 'eventlet.patcher' not in sys.modules: + return False + import eventlet.patcher + return eventlet.patcher.is_monkey_patched('socket') + + +def is_gevent_monkey_patched(): + if 'gevent.monkey' not in sys.modules: + return False + import gevent.socket + return socket.socket is gevent.socket.socket + + +def is_gevent_time_monkey_patched(): + import gevent.monkey + return "time" in gevent.monkey.saved + + +def is_eventlet_time_monkey_patched(): + import eventlet + return eventlet.patcher.is_monkey_patched('time') + + +def is_monkey_patched(): + return is_gevent_monkey_patched() or is_eventlet_monkey_patched() + +thread_pool_executor_class = ThreadPoolExecutor + +EVENT_LOOP_MANAGER = os.getenv('EVENT_LOOP_MANAGER', "libev") +if "gevent" in EVENT_LOOP_MANAGER: + import gevent.monkey + gevent.monkey.patch_all() + from cassandra.io.geventreactor import GeventConnection + connection_class = GeventConnection +elif "eventlet" in EVENT_LOOP_MANAGER: + from eventlet import monkey_patch + monkey_patch() + + from cassandra.io.eventletreactor import EventletConnection + connection_class = EventletConnection + + try: + from futurist import GreenThreadPoolExecutor + thread_pool_executor_class = GreenThreadPoolExecutor + except: + # futurist is installed only with python >=3.7 + pass +elif "asyncore" in EVENT_LOOP_MANAGER: + from cassandra.io.asyncorereactor import AsyncoreConnection + connection_class = AsyncoreConnection +elif "twisted" in EVENT_LOOP_MANAGER: + from cassandra.io.twistedreactor import TwistedConnection + connection_class = TwistedConnection +elif "asyncio" in EVENT_LOOP_MANAGER: + from cassandra.io.asyncioreactor import AsyncioConnection + connection_class = AsyncioConnection + +else: + try: + from cassandra.io.libevreactor import LibevConnection + connection_class = LibevConnection + except ImportError: + connection_class = None + + +# If set to to true this will force the Cython tests to run regardless of whether they are installed +cython_env = os.getenv('VERIFY_CYTHON', "False") + + +VERIFY_CYTHON = False + +if(cython_env == 'True'): + VERIFY_CYTHON = True + + +def is_windows(): + return "Windows" in platform.system() + + +notwindows = unittest.skipUnless(not is_windows(), "This test is not adequate for windows") +notpypy = unittest.skipUnless(not platform.python_implementation() == 'PyPy', "This tests is not suitable for pypy") diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..c087200 --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1,835 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from cassandra.cluster import Cluster + +from tests import connection_class, EVENT_LOOP_MANAGER +Cluster.connection_class = connection_class + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa +from packaging.version import Version +import logging +import socket +import sys +import time +import traceback +import platform +from threading import Event +from subprocess import call +from itertools import groupby +import six +import shutil + +from cassandra import OperationTimedOut, ReadTimeout, ReadFailure, WriteTimeout, WriteFailure, AlreadyExists, \ + InvalidRequest +from cassandra.cluster import NoHostAvailable + +from cassandra.protocol import ConfigurationException + +try: + from ccmlib.dse_cluster import DseCluster + from ccmlib.cluster import Cluster as CCMCluster + from ccmlib.cluster_factory import ClusterFactory as CCMClusterFactory + from ccmlib import common +except ImportError as e: + CCMClusterFactory = None + +log = logging.getLogger(__name__) + +CLUSTER_NAME = 'test_cluster' +SINGLE_NODE_CLUSTER_NAME = 'single_node' +MULTIDC_CLUSTER_NAME = 'multidc_test_cluster' + +CCM_CLUSTER = None + +path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'ccm') +if not os.path.exists(path): + os.mkdir(path) + +cass_version = None +cql_version = None + + +def get_server_versions(): + """ + Probe system.local table to determine Cassandra and CQL version. + Returns a tuple of (cassandra_version, cql_version). 
+ """ + global cass_version, cql_version + + if cass_version is not None: + return (cass_version, cql_version) + + c = Cluster() + s = c.connect() + row = s.execute('SELECT cql_version, release_version FROM system.local')[0] + + cass_version = _tuple_version(row.release_version) + cql_version = _tuple_version(row.cql_version) + + c.shutdown() + + return (cass_version, cql_version) + + +def _tuple_version(version_string): + if '-' in version_string: + version_string = version_string[:version_string.index('-')] + + return tuple([int(p) for p in version_string.split('.')]) + + +def cmd_line_args_to_dict(env_var): + cmd_args_env = os.environ.get(env_var, None) + args = {} + if cmd_args_env: + cmd_args = cmd_args_env.strip().split(' ') + while cmd_args: + cmd_arg = cmd_args.pop(0) + cmd_arg_value = True if cmd_arg.startswith('--') else cmd_args.pop(0) + args[cmd_arg.lstrip('-')] = cmd_arg_value + return args + + +def _get_dse_version_from_cass(cass_version): + if cass_version.startswith('2.1'): + dse_ver = "4.8.15" + elif cass_version.startswith('3.0'): + dse_ver = "5.0.12" + elif cass_version.startswith('3.10') or cass_version.startswith('3.11'): + dse_ver = "5.1.7" + elif cass_version.startswith('4.0'): + dse_ver = "6.0" + else: + log.error("Unknown cassandra version found {0}, defaulting to 2.1".format(cass_version)) + dse_ver = "2.1" + return dse_ver + +USE_CASS_EXTERNAL = bool(os.getenv('USE_CASS_EXTERNAL', False)) +KEEP_TEST_CLUSTER = bool(os.getenv('KEEP_TEST_CLUSTER', False)) +SIMULACRON_JAR = os.getenv('SIMULACRON_JAR', None) +CLOUD_PROXY_PATH = os.getenv('CLOUD_PROXY_PATH', None) + +CASSANDRA_IP = os.getenv('CASSANDRA_IP', '127.0.0.1') +CASSANDRA_DIR = os.getenv('CASSANDRA_DIR', None) + +default_cassandra_version = '3.11.4' +cv_string = os.getenv('CASSANDRA_VERSION', default_cassandra_version) +mcv_string = os.getenv('MAPPED_CASSANDRA_VERSION', None) +try: + cassandra_version = Version(cv_string) # env var is set to test-dse +except: + # fallback to MAPPED_CASSANDRA_VERSION + cassandra_version = Version(mcv_string) +CASSANDRA_VERSION = Version(mcv_string) if mcv_string else cassandra_version +CCM_VERSION = cassandra_version if mcv_string else CASSANDRA_VERSION + +default_dse_version = _get_dse_version_from_cass(CASSANDRA_VERSION.base_version) + +DSE_VERSION = Version(os.getenv('DSE_VERSION', default_dse_version)) + +CCM_KWARGS = {} +if CASSANDRA_DIR: + log.info("Using Cassandra dir: %s", CASSANDRA_DIR) + CCM_KWARGS['install_dir'] = CASSANDRA_DIR + +else: + log.info('Using Cassandra version: %s', CASSANDRA_VERSION) + log.info('Using CCM version: %s', CCM_VERSION) + CCM_KWARGS['version'] = CCM_VERSION + +#This changes the default contact_point parameter in Cluster +def set_default_cass_ip(): + if CASSANDRA_IP.startswith("127.0.0."): + return + defaults = list(Cluster.__init__.__defaults__) + defaults = [[CASSANDRA_IP]] + defaults[1:] + try: + Cluster.__init__.__defaults__ = tuple(defaults) + except: + Cluster.__init__.__func__.__defaults__ = tuple(defaults) + + +def set_default_beta_flag_true(): + defaults = list(Cluster.__init__.__defaults__) + defaults = (defaults[:28] + [True] + defaults[29:]) + try: + Cluster.__init__.__defaults__ = tuple(defaults) + except: + Cluster.__init__.__func__.__defaults__ = tuple(defaults) + + +def get_default_protocol(): + if CASSANDRA_VERSION >= Version('4.0'): + set_default_beta_flag_true() + return 5 + elif CASSANDRA_VERSION >= Version('2.2'): + return 4 + elif CASSANDRA_VERSION >= Version('2.1'): + return 3 + elif CASSANDRA_VERSION >= Version('2.0'): 
+ return 2 + else: + return 1 + + +def get_supported_protocol_versions(): + """ + 1.2 -> 1 + 2.0 -> 2, 1 + 2.1 -> 3, 2, 1 + 2.2 -> 4, 3, 2, 1 + 3.X -> 4, 3 + 3.10 -> 5(beta),4,3 + """ + if CASSANDRA_VERSION >= Version('4.0'): + return (3, 4, 5) + elif CASSANDRA_VERSION >= Version('3.10'): + return (3, 4) + elif CASSANDRA_VERSION >= Version('3.0'): + return (3, 4) + elif CASSANDRA_VERSION >= Version('2.2'): + return (1, 2, 3, 4) + elif CASSANDRA_VERSION >= Version('2.1'): + return (1, 2, 3) + elif CASSANDRA_VERSION >= Version('2.0'): + return (1, 2) + else: + return (1, ) + + +def get_unsupported_lower_protocol(): + """ + This is used to determine the lowest protocol version that is NOT + supported by the version of C* running + """ + + if CASSANDRA_VERSION >= Version('3.0'): + return 2 + else: + return None + + +def get_unsupported_upper_protocol(): + """ + This is used to determine the highest protocol version that is NOT + supported by the version of C* running + """ + + if CASSANDRA_VERSION >= Version('2.2'): + return None + if CASSANDRA_VERSION >= Version('2.1'): + return 4 + elif CASSANDRA_VERSION >= Version('2.0'): + return 3 + else: + return None + + +default_protocol_version = get_default_protocol() + + +PROTOCOL_VERSION = int(os.getenv('PROTOCOL_VERSION', default_protocol_version)) + + +def local_decorator_creator(): + if USE_CASS_EXTERNAL or not CASSANDRA_IP.startswith("127.0.0."): + return unittest.skip('Tests only run against local C*') + + def _id_and_mark(f): + f.local = True + return f + + return _id_and_mark + +local = local_decorator_creator() +notprotocolv1 = unittest.skipUnless(PROTOCOL_VERSION > 1, 'Protocol v1 not supported') +lessthenprotocolv4 = unittest.skipUnless(PROTOCOL_VERSION < 4, 'Protocol versions 4 or greater not supported') +greaterthanprotocolv3 = unittest.skipUnless(PROTOCOL_VERSION >= 4, 'Protocol versions less than 4 are not supported') +protocolv5 = unittest.skipUnless(5 in get_supported_protocol_versions(), 'Protocol versions less than 5 are not supported') + +greaterthancass20 = unittest.skipUnless(CASSANDRA_VERSION >= Version('2.1'), 'Cassandra version 2.1 or greater required') +greaterthancass21 = unittest.skipUnless(CASSANDRA_VERSION >= Version('2.2'), 'Cassandra version 2.2 or greater required') +greaterthanorequalcass30 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.0'), 'Cassandra version 3.0 or greater required') +greaterthanorequalcass36 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.6'), 'Cassandra version 3.6 or greater required') +greaterthanorequalcass3_10 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.10'), 'Cassandra version 3.10 or greater required') +greaterthanorequalcass3_11 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.11'), 'Cassandra version 3.11 or greater required') +greaterthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION >= Version('4.0'), 'Cassandra version 4.0 or greater required') +lessthanorequalcass40 = unittest.skipIf(CASSANDRA_VERSION >= Version('4.0'), 'Cassandra version less than or equal to 4.0 required') +lessthancass30 = unittest.skipUnless(CASSANDRA_VERSION < Version('3.0'), 'Cassandra version less than 3.0 required') +pypy = unittest.skipUnless(platform.python_implementation() == "PyPy", "Test is skipped unless it's on PyPy") +notpy3 = unittest.skipIf(sys.version_info >= (3, 0), "Test not applicable for Python 3.x runtime") +requiresmallclockgranularity = unittest.skipIf("Windows" in platform.system() or "asyncore" in EVENT_LOOP_MANAGER, + "This test is not suitable for environments
with large clock granularity") +requiressimulacron = unittest.skipIf(SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"), "Simulacron jar hasn't been specified or C* version is 2.0") +requirescloudproxy = unittest.skipIf(CLOUD_PROXY_PATH is None, "Cloud Proxy path hasn't been specified") + + +def wait_for_node_socket(node, timeout): + binary_itf = node.network_interfaces['binary'] + if not common.check_socket_listening(binary_itf, timeout=timeout): + log.warning("Unable to connect to binary socket for node " + node.name) + else: + log.debug("Node %s is up and listening " % (node.name,)) + + +def check_socket_listening(itf, timeout=60): + end = time.time() + timeout + while time.time() <= end: + try: + sock = socket.socket() + sock.connect(itf) + sock.close() + return True + except socket.error: + # Try again in another 200ms + time.sleep(.2) + continue + return False + + +def get_cluster(): + return CCM_CLUSTER + + +def get_node(node_id): + return CCM_CLUSTER.nodes['node%s' % node_id] + + +def use_multidc(dc_list, workloads=[]): + use_cluster(MULTIDC_CLUSTER_NAME, dc_list, start=True, workloads=workloads) + + +def use_singledc(start=True, workloads=[]): + use_cluster(CLUSTER_NAME, [3], start=start, workloads=workloads) + + +def use_single_node(start=True, workloads=[]): + use_cluster(SINGLE_NODE_CLUSTER_NAME, [1], start=start, workloads=workloads) + + +def remove_cluster(): + if USE_CASS_EXTERNAL or KEEP_TEST_CLUSTER: + return + + global CCM_CLUSTER + if CCM_CLUSTER: + log.debug("Removing cluster {0}".format(CCM_CLUSTER.name)) + tries = 0 + while tries < 100: + try: + CCM_CLUSTER.remove() + CCM_CLUSTER = None + return + except OSError: + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + tries += 1 + time.sleep(1) + + raise RuntimeError("Failed to remove cluster after 100 attempts") + + +def is_current_cluster(cluster_name, node_counts): + global CCM_CLUSTER + if CCM_CLUSTER and CCM_CLUSTER.name == cluster_name: + if [len(list(nodes)) for dc, nodes in + groupby(CCM_CLUSTER.nodelist(), lambda n: n.data_center)] == node_counts: + return True + return False + + +def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None, set_keyspace=True, ccm_options=None, + configuration_options={}, dse_cluster=False, dse_options={}, + dse_version=None): + if not workloads: + workloads = [] + if (dse_version and not dse_cluster): + raise ValueError('specified dse_version {} but not dse_cluster'.format(dse_version)) + set_default_cass_ip() + + if ccm_options is None and dse_cluster: + ccm_options = {"version": dse_version or DSE_VERSION} + elif ccm_options is None: + ccm_options = CCM_KWARGS.copy() + + if 'version' in ccm_options and not isinstance(ccm_options['version'], Version): + ccm_options['version'] = Version(ccm_options['version']) + + cassandra_version = ccm_options.get('version', CCM_VERSION) + dse_version = ccm_options.get('version', DSE_VERSION) + + if 'version' in ccm_options: + ccm_options['version'] = ccm_options['version'].base_version + + global CCM_CLUSTER + if USE_CASS_EXTERNAL: + if CCM_CLUSTER: + log.debug("Using external CCM cluster {0}".format(CCM_CLUSTER.name)) + else: + log.debug("Using unnamed external cluster") + if set_keyspace and start: + setup_keyspace(ipformat=ipformat, wait=False) + return + + if is_current_cluster(cluster_name, nodes): + log.debug("Using existing cluster, matching topology: {0}".format(cluster_name)) + else: + if CCM_CLUSTER: + 
log.debug("Stopping existing cluster, topology mismatch: {0}".format(CCM_CLUSTER.name)) + CCM_CLUSTER.stop() + + try: + CCM_CLUSTER = CCMClusterFactory.load(path, cluster_name) + log.debug("Found existing CCM cluster, {0}; clearing.".format(cluster_name)) + CCM_CLUSTER.clear() + CCM_CLUSTER.set_install_dir(**ccm_options) + CCM_CLUSTER.set_configuration_options(configuration_options) + except Exception: + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + + ccm_options.update(cmd_line_args_to_dict('CCM_ARGS')) + + log.debug("Creating new CCM cluster, {0}, with args {1}".format(cluster_name, ccm_options)) + + # Make sure we cleanup old cluster dir if it exists + cluster_path = os.path.join(path, cluster_name) + if os.path.exists(cluster_path): + shutil.rmtree(cluster_path) + + if dse_cluster: + CCM_CLUSTER = DseCluster(path, cluster_name, **ccm_options) + CCM_CLUSTER.set_configuration_options({'start_native_transport': True}) + CCM_CLUSTER.set_configuration_options({'batch_size_warn_threshold_in_kb': 5}) + if dse_version >= Version('5.0'): + CCM_CLUSTER.set_configuration_options({'enable_user_defined_functions': True}) + CCM_CLUSTER.set_configuration_options({'enable_scripted_user_defined_functions': True}) + if 'spark' in workloads: + config_options = {"initial_spark_worker_resources": 0.1} + CCM_CLUSTER.set_dse_configuration_options(config_options) + common.switch_cluster(path, cluster_name) + CCM_CLUSTER.set_configuration_options(configuration_options) + CCM_CLUSTER.populate(nodes, ipformat=ipformat) + + CCM_CLUSTER.set_dse_configuration_options(dse_options) + else: + CCM_CLUSTER = CCMCluster(path, cluster_name, **ccm_options) + CCM_CLUSTER.set_configuration_options({'start_native_transport': True}) + if cassandra_version >= Version('2.2'): + CCM_CLUSTER.set_configuration_options({'enable_user_defined_functions': True}) + if cassandra_version >= Version('3.0'): + CCM_CLUSTER.set_configuration_options({'enable_scripted_user_defined_functions': True}) + common.switch_cluster(path, cluster_name) + CCM_CLUSTER.set_configuration_options(configuration_options) + CCM_CLUSTER.populate(nodes, ipformat=ipformat) + + try: + jvm_args = [] + # This will enable the Mirroring query handler which will echo our custom payload k,v pairs back + + if 'graph' not in workloads: + if PROTOCOL_VERSION >= 4: + jvm_args = [" -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler"] + if len(workloads) > 0: + for node in CCM_CLUSTER.nodes.values(): + node.set_workloads(workloads) + if start: + log.debug("Starting CCM cluster: {0}".format(cluster_name)) + CCM_CLUSTER.start(wait_for_binary_proto=True, wait_other_notice=True, jvm_args=jvm_args) + # Added to wait for slow nodes to start up + for node in CCM_CLUSTER.nodes.values(): + wait_for_node_socket(node, 120) + if set_keyspace: + setup_keyspace(ipformat=ipformat) + except Exception: + log.exception("Failed to start CCM cluster; removing cluster.") + + if os.name == "nt": + if CCM_CLUSTER: + for node in six.itervalues(CCM_CLUSTER.nodes): + os.system("taskkill /F /PID " + str(node.pid)) + else: + call(["pkill", "-9", "-f", ".ccm"]) + remove_cluster() + raise + return CCM_CLUSTER + + +def teardown_package(): + if USE_CASS_EXTERNAL or KEEP_TEST_CLUSTER: + return + # when multiple modules are run explicitly, this runs between them + # need to make sure CCM_CLUSTER is properly cleared for that case + remove_cluster() + for cluster_name in 
[CLUSTER_NAME, MULTIDC_CLUSTER_NAME]: + try: + cluster = CCMClusterFactory.load(path, cluster_name) + try: + cluster.remove() + log.info('Removed cluster: %s' % cluster_name) + except Exception: + log.exception('Failed to remove cluster: %s' % cluster_name) + + except Exception: + log.warning('Did not find cluster: %s' % cluster_name) + + +def execute_until_pass(session, query): + tries = 0 + while tries < 100: + try: + return session.execute(query) + except (ConfigurationException, AlreadyExists, InvalidRequest): + log.warning("Received already exists from query {0} not exiting".format(query)) + # keyspace/table was already created/dropped + return + except (OperationTimedOut, ReadTimeout, ReadFailure, WriteTimeout, WriteFailure): + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + tries += 1 + + raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(query)) + + +def execute_with_long_wait_retry(session, query, timeout=30): + tries = 0 + while tries < 10: + try: + return session.execute(query, timeout=timeout) + except (ConfigurationException, AlreadyExists): + log.warning("Received already exists from query {0} not exiting".format(query)) + # keyspace/table was already created/dropped + return + except (OperationTimedOut, ReadTimeout, ReadFailure, WriteTimeout, WriteFailure): + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + tries += 1 + + raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(query)) + + +def execute_with_retry_tolerant(session, query, retry_exceptions, escape_exception): + # TODO refactor above methods into this one for code reuse + tries = 0 + while tries < 100: + try: + tries += 1 + rs = session.execute(query) + return rs + except escape_exception: + return + except retry_exceptions: + time.sleep(.1) + + raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(query)) + + +def drop_keyspace_shutdown_cluster(keyspace_name, session, cluster): + try: + execute_with_long_wait_retry(session, "DROP KEYSPACE {0}".format(keyspace_name)) + except: + log.warning("Error encountered when droping keyspace {0}".format(keyspace_name)) + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + finally: + log.warning("Shutting down cluster") + cluster.shutdown() + + +def setup_keyspace(ipformat=None, wait=True): + # wait for nodes to startup + if wait: + time.sleep(10) + + if not ipformat: + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + else: + cluster = Cluster(contact_points=["::1"], protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + + try: + for ksname in ('test1rf', 'test2rf', 'test3rf'): + if ksname in cluster.metadata.keyspaces: + execute_until_pass(session, "DROP KEYSPACE %s" % ksname) + + ddl = ''' + CREATE KEYSPACE test3rf + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}''' + execute_with_long_wait_retry(session, ddl) + + ddl = ''' + CREATE KEYSPACE test2rf + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '2'}''' + execute_with_long_wait_retry(session, ddl) + + ddl = ''' + CREATE KEYSPACE test1rf + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}''' + execute_with_long_wait_retry(session, ddl) + + ddl_3f = ''' + CREATE TABLE test3rf.test ( + k int PRIMARY 
KEY, + v int )''' + execute_with_long_wait_retry(session, ddl_3f) + + ddl_1f = ''' + CREATE TABLE test1rf.test ( + k int PRIMARY KEY, + v int )''' + execute_with_long_wait_retry(session, ddl_1f) + + except Exception: + traceback.print_exc() + raise + finally: + cluster.shutdown() + + +class UpDownWaiter(object): + + def __init__(self, host): + self.down_event = Event() + self.up_event = Event() + host.monitor.register(self) + + def on_up(self, host): + self.up_event.set() + + def on_down(self, host): + self.down_event.set() + + def wait_for_down(self): + self.down_event.wait() + + def wait_for_up(self): + self.up_event.wait() + + +class BasicKeyspaceUnitTestCase(unittest.TestCase): + """ + This is a basic unit test case that provides various utility methods that can be leveraged for test case setup and + teardown + """ + @property + def keyspace_name(self): + return self.ks_name + + @property + def class_table_name(self): + return self.ks_name + + @property + def function_table_name(self): + return self._testMethodName.lower() + + @property + def keyspace_table_name(self): + return "{0}.{1}".format(self.keyspace_name, self._testMethodName.lower()) + + @classmethod + def drop_keyspace(cls): + execute_with_long_wait_retry(cls.session, "DROP KEYSPACE {0}".format(cls.ks_name)) + + @classmethod + def create_keyspace(cls, rf): + ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '{1}'}}".format(cls.ks_name, rf) + execute_with_long_wait_retry(cls.session, ddl) + + @classmethod + def common_setup(cls, rf, keyspace_creation=True, create_class_table=False, metrics=False): + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION, metrics_enabled=metrics) + cls.session = cls.cluster.connect(wait_for_all_pools=True) + cls.ks_name = cls.__name__.lower() + if keyspace_creation: + cls.create_keyspace(rf) + cls.cass_version, cls.cql_version = get_server_versions() + + if create_class_table: + + ddl = ''' + CREATE TABLE {0}.{1} ( + k int PRIMARY KEY, + v int )'''.format(cls.ks_name, cls.ks_name) + execute_until_pass(cls.session, ddl) + + def create_function_table(self): + ddl = ''' + CREATE TABLE {0}.{1} ( + k int PRIMARY KEY, + v int )'''.format(self.keyspace_name, self.function_table_name) + execute_until_pass(self.session, ddl) + + def drop_function_table(self): + ddl = "DROP TABLE {0}.{1} ".format(self.keyspace_name, self.function_table_name) + execute_until_pass(self.session, ddl) + + +class MockLoggingHandler(logging.Handler): + """Mock logging handler to check for expected logs.""" + + def __init__(self, *args, **kwargs): + self.reset() + logging.Handler.__init__(self, *args, **kwargs) + + def emit(self, record): + self.messages[record.levelname.lower()].append(record.getMessage()) + + def reset(self): + self.messages = { + 'debug': [], + 'info': [], + 'warning': [], + 'error': [], + 'critical': [], + } + + def get_message_count(self, level, sub_string): + count = 0 + for msg in self.messages.get(level): + if sub_string in msg: + count += 1 + return count + + def set_module_name(self, module_name): + """ + This is intended to be used as follows: + with MockLoggingHandler().set_module_name(connection.__name__) as mock_handler: + """ + self.module_name = module_name + return self + + def __enter__(self): + self.logger = logging.getLogger(self.module_name) + self.logger.addHandler(self) + return self + + def __exit__(self, *args): + pass + + +class BasicExistingKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): + """ + This is a basic unit test case that defines class-level
teardown and setup methods. It assumes that the keyspace is already defined, or is created as part of the test. + """ + @classmethod + def setUpClass(cls): + cls.common_setup(1, keyspace_creation=False) + + @classmethod + def tearDownClass(cls): + cls.cluster.shutdown() + + +class BasicSharedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): + """ + This is a basic unit test case that can be leveraged to scope a keyspace to a specific test class. + Creates a keyspace named after the test class with an rf of 1. + """ + @classmethod + def setUpClass(cls): + cls.common_setup(1) + + @classmethod + def tearDownClass(cls): + drop_keyspace_shutdown_cluster(cls.ks_name, cls.session, cls.cluster) + + +class BasicSharedKeyspaceUnitTestCaseRF1(BasicSharedKeyspaceUnitTestCase): + """ + This is a basic unit test case that can be leveraged to scope a keyspace to a specific test class. + Creates a keyspace named after the test class with an rf of 1. + """ + @classmethod + def setUpClass(self): + self.common_setup(1, True) + + +class BasicSharedKeyspaceUnitTestCaseRF2(BasicSharedKeyspaceUnitTestCase): + """ + This is a basic unit test case that can be leveraged to scope a keyspace to a specific test class. + Creates a keyspace named after the test class with an rf of 2, and a table named after the class. + """ + @classmethod + def setUpClass(self): + self.common_setup(2) + + +class BasicSharedKeyspaceUnitTestCaseRF3(BasicSharedKeyspaceUnitTestCase): + """ + This is a basic unit test case that can be leveraged to scope a keyspace to a specific test class. + Creates a keyspace named after the test class with an rf of 3. + """ + @classmethod + def setUpClass(self): + self.common_setup(3) + + +class BasicSharedKeyspaceUnitTestCaseRF3WM(BasicSharedKeyspaceUnitTestCase): + """ + This is a basic unit test case that can be leveraged to scope a keyspace to a specific test class. + Creates a keyspace named after the test class with an rf of 3 and metrics enabled. + """ + @classmethod + def setUpClass(self): + self.common_setup(3, True, True, True) + + +class BasicSharedKeyspaceUnitTestCaseWFunctionTable(BasicSharedKeyspaceUnitTestCase): + """ + This is a basic unit test case that can be leveraged to scope a keyspace to a specific test class. + Creates a keyspace named after the test class with an rf of 3 and a table named after the class; + the table is scoped to just the unit test and will be removed. + + """ + def setUp(self): + self.create_function_table() + + def tearDown(self): + self.drop_function_table() + + +class BasicSegregatedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): + """ + This unit test will create and tear down a keyspace for each individual unit test. + It has overhead and should only be used with complex unit tests where sharing a keyspace would + cause issues. + """ + def setUp(self): + self.common_setup(1) + + def tearDown(self): + drop_keyspace_shutdown_cluster(self.ks_name, self.session, self.cluster) + + +class BasicExistingSegregatedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): + """ + This unit test will create and tear down resources for each individual unit test. It assumes that the keyspace + already exists or is created as part of a test. + It has some overhead and should only be used when sharing a cluster/session is not feasible.
+ """ + def setUp(self): + self.common_setup(1, keyspace_creation=False) + + def tearDown(self): + self.cluster.shutdown() diff --git a/tests/integration/cqlengine/__init__.py b/tests/integration/cqlengine/__init__.py new file mode 100644 index 0000000..d098ea7 --- /dev/null +++ b/tests/integration/cqlengine/__init__.py @@ -0,0 +1,116 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import warnings +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa +from cassandra import ConsistencyLevel + +from cassandra.cqlengine import connection +from cassandra.cqlengine.management import create_keyspace_simple, drop_keyspace, CQLENG_ALLOW_SCHEMA_MANAGEMENT +import cassandra + +from tests.integration import get_server_versions, use_single_node, PROTOCOL_VERSION, CASSANDRA_IP, set_default_cass_ip +DEFAULT_KEYSPACE = 'cqlengine_test' + + +CQL_SKIP_EXECUTE = bool(os.getenv('CQL_SKIP_EXECUTE', False)) + + +def setup_package(): + warnings.simplefilter('always') # for testing warnings, make sure all are let through + os.environ[CQLENG_ALLOW_SCHEMA_MANAGEMENT] = '1' + + set_default_cass_ip() + use_single_node() + + setup_connection(DEFAULT_KEYSPACE) + create_keyspace_simple(DEFAULT_KEYSPACE, 1) + + +def teardown_package(): + connection.unregister_connection("default") + +def is_prepend_reversed(): + # do we have https://issues.apache.org/jira/browse/CASSANDRA-8733 ? + ver, _ = get_server_versions() + return not (ver >= (2, 0, 13) or ver >= (2, 1, 3)) + + +def setup_connection(keyspace_name): + connection.setup([CASSANDRA_IP], + consistency=ConsistencyLevel.ONE, + protocol_version=PROTOCOL_VERSION, + default_keyspace=keyspace_name) + + +class StatementCounter(object): + """ + Simple python object used to hold a count of the number of times + the wrapped method has been invoked + """ + def __init__(self, patched_func): + self.func = patched_func + self.counter = 0 + + def wrapped_execute(self, *args, **kwargs): + self.counter += 1 + return self.func(*args, **kwargs) + + def get_counter(self): + return self.counter + + +def execute_count(expected): + """ + A decorator used wrap cassandra.cqlengine.connection.execute. It counts the number of times this method is invoked + then compares it to the number expected. If they don't match it throws an assertion error. 
+ This function can be disabled by running the test harness with the env variable CQL_SKIP_EXECUTE=1 set. + """ + def innerCounter(fn): + def wrapped_function(*args, **kwargs): + # Create a counter to monkey patch into cassandra.cqlengine.connection.execute + count = StatementCounter(cassandra.cqlengine.connection.execute) + original_function = cassandra.cqlengine.connection.execute + # Monkey patch in our StatementCounter wrapper + cassandra.cqlengine.connection.execute = count.wrapped_execute + # Invoke the underlying unit test + to_return = fn(*args, **kwargs) + # Get the count from our monkey patched counter + count.get_counter() + # Un-monkey-patch our code + cassandra.cqlengine.connection.execute = original_function + # Check to see if we have a pre-existing test case to work from. + if len(args) == 0: + test_case = unittest.TestCase("__init__") + else: + test_case = args[0] + # Check to see if the count is what you expect + test_case.assertEqual(count.get_counter(), expected, msg="Expected number of cassandra.cqlengine.connection.execute calls ({0}) doesn't match actual number invoked ({1})".format(expected, count.get_counter())) + return to_return + # Name of the wrapped function must match the original or unittest will error out. + wrapped_function.__name__ = fn.__name__ + wrapped_function.__doc__ = fn.__doc__ + # Escape hatch + if CQL_SKIP_EXECUTE: + return fn + else: + return wrapped_function + + return innerCounter + + diff --git a/tests/integration/cqlengine/base.py b/tests/integration/cqlengine/base.py new file mode 100644 index 0000000..8a69033 --- /dev/null +++ b/tests/integration/cqlengine/base.py @@ -0,0 +1,54 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
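For orientation, here is a minimal sketch of how the execute_count decorator from tests/integration/cqlengine/__init__.py above is meant to be applied. The class name, test name, and the expected count of 2 (one INSERT plus one SELECT) are illustrative assumptions for this example and are not part of the patch itself.

from uuid import uuid4

from cassandra.cqlengine.management import sync_table, drop_table

from tests.integration.cqlengine import execute_count
from tests.integration.cqlengine.base import BaseCassEngTestCase, TestQueryUpdateModel


class HypotheticalExecuteCountTest(BaseCassEngTestCase):
    # Illustrative only: this test class is hypothetical and not part of the patch.

    @classmethod
    def setUpClass(cls):
        sync_table(TestQueryUpdateModel)

    @classmethod
    def tearDownClass(cls):
        drop_table(TestQueryUpdateModel)

    @execute_count(2)
    def test_issues_exactly_two_statements(self):
        # One INSERT (create) and one SELECT (first) should pass through
        # cassandra.cqlengine.connection.execute while this test body runs;
        # the decorator asserts that count unless CQL_SKIP_EXECUTE=1 is set.
        TestQueryUpdateModel.create(partition=uuid4(), cluster=1)
        TestQueryUpdateModel.objects.all().first()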
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import sys + +from cassandra.cqlengine.connection import get_session +from cassandra.cqlengine.models import Model +from cassandra.cqlengine import columns + +from uuid import uuid4 + +class TestQueryUpdateModel(Model): + + partition = columns.UUID(primary_key=True, default=uuid4) + cluster = columns.Integer(primary_key=True) + count = columns.Integer(required=False) + text = columns.Text(required=False, index=True) + text_set = columns.Set(columns.Text, required=False) + text_list = columns.List(columns.Text, required=False) + text_map = columns.Map(columns.Text, columns.Text, required=False) + +class BaseCassEngTestCase(unittest.TestCase): + + session = None + + def setUp(self): + self.session = get_session() + + def assertHasAttr(self, obj, attr): + self.assertTrue(hasattr(obj, attr), + "{0} doesn't have attribute: {1}".format(obj, attr)) + + def assertNotHasAttr(self, obj, attr): + self.assertFalse(hasattr(obj, attr), + "{0} shouldn't have the attribute: {1}".format(obj, attr)) + + if sys.version_info > (3, 0): + def assertItemsEqual(self, first, second, msg=None): + return self.assertCountEqual(first, second, msg) diff --git a/tests/integration/cqlengine/columns/__init__.py b/tests/integration/cqlengine/columns/__init__.py new file mode 100644 index 0000000..386372e --- /dev/null +++ b/tests/integration/cqlengine/columns/__init__.py @@ -0,0 +1,14 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/integration/cqlengine/columns/test_container_columns.py b/tests/integration/cqlengine/columns/test_container_columns.py new file mode 100644 index 0000000..2acf364 --- /dev/null +++ b/tests/integration/cqlengine/columns/test_container_columns.py @@ -0,0 +1,971 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
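As a quick, hypothetical illustration of the helpers defined in base.py above (assertHasAttr, assertNotHasAttr, and the Python 3 assertItemsEqual shim), a test might look like the following; the class name and the checked attribute names are made up for the example and are not part of the patch.

from tests.integration.cqlengine.base import BaseCassEngTestCase


class HypotheticalBaseUsageTest(BaseCassEngTestCase):

    def test_helper_assertions(self):
        # self.session is populated by BaseCassEngTestCase.setUp via get_session()
        self.assertHasAttr(self.session, 'execute')
        self.assertNotHasAttr(self.session, 'definitely_not_an_attribute')

        # On Python 3 the base class maps assertItemsEqual onto assertCountEqual,
        # so order-insensitive comparisons work on both major versions.
        self.assertItemsEqual([1, 2, 3], [3, 2, 1])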
+ +from datetime import datetime, timedelta +import json +import logging +import six +import sys +import traceback +from uuid import uuid4 +from packaging.version import Version + +from cassandra import WriteTimeout, OperationTimedOut +import cassandra.cqlengine.columns as columns +from cassandra.cqlengine.functions import get_total_seconds +from cassandra.cqlengine.models import Model, ValidationError +from cassandra.cqlengine.management import sync_table, drop_table + +from tests.integration import CASSANDRA_IP +from tests.integration.cqlengine import is_prepend_reversed +from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration import greaterthancass20, CASSANDRA_VERSION + +log = logging.getLogger(__name__) + + +class TestSetModel(Model): + partition = columns.UUID(primary_key=True, default=uuid4) + int_set = columns.Set(columns.Integer, required=False) + text_set = columns.Set(columns.Text, required=False) + + +class JsonTestColumn(columns.Column): + + db_type = 'text' + + def to_python(self, value): + if value is None: + return + if isinstance(value, six.string_types): + return json.loads(value) + else: + return value + + def to_database(self, value): + if value is None: + return + return json.dumps(value) + + +class TestSetColumn(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + drop_table(TestSetModel) + sync_table(TestSetModel) + + @classmethod + def tearDownClass(cls): + drop_table(TestSetModel) + + def test_add_none_fails(self): + self.assertRaises(ValidationError, TestSetModel.create, **{'int_set': set([None])}) + + def test_empty_set_initial(self): + """ + tests that sets are set() by default, should never be none + :return: + """ + m = TestSetModel.create() + m.int_set.add(5) + m.save() + + def test_deleting_last_item_should_succeed(self): + m = TestSetModel.create() + m.int_set.add(5) + m.save() + m.int_set.remove(5) + m.save() + + m = TestSetModel.get(partition=m.partition) + self.assertTrue(5 not in m.int_set) + + def test_blind_deleting_last_item_should_succeed(self): + m = TestSetModel.create() + m.int_set.add(5) + m.save() + + TestSetModel.objects(partition=m.partition).update(int_set=set()) + + m = TestSetModel.get(partition=m.partition) + self.assertTrue(5 not in m.int_set) + + def test_empty_set_retrieval(self): + m = TestSetModel.create() + m2 = TestSetModel.get(partition=m.partition) + m2.int_set.add(3) + + def test_io_success(self): + """ Tests that a basic usage works as expected """ + m1 = TestSetModel.create(int_set=set((1, 2)), text_set=set(('kai', 'andreas'))) + m2 = TestSetModel.get(partition=m1.partition) + + self.assertIsInstance(m2.int_set, set) + self.assertIsInstance(m2.text_set, set) + + self.assertIn(1, m2.int_set) + self.assertIn(2, m2.int_set) + + self.assertIn('kai', m2.text_set) + self.assertIn('andreas', m2.text_set) + + def test_type_validation(self): + """ + Tests that attempting to use the wrong types will raise an exception + """ + self.assertRaises(ValidationError, TestSetModel.create, **{'int_set': set(('string', True)), 'text_set': set((1, 3.0))}) + + def test_element_count_validation(self): + """ + Tests that big collections are detected and raise an exception. 
+ """ + while True: + try: + TestSetModel.create(text_set=set(str(uuid4()) for i in range(65535))) + break + except WriteTimeout: + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + except OperationTimedOut: + #This will happen if the host is remote + self.assertFalse(CASSANDRA_IP.startswith("127.0.0.")) + self.assertRaises(ValidationError, TestSetModel.create, **{'text_set': set(str(uuid4()) for i in range(65536))}) + + def test_partial_updates(self): + """ Tests that partial udpates work as expected """ + m1 = TestSetModel.create(int_set=set((1, 2, 3, 4))) + + m1.int_set.add(5) + m1.int_set.remove(1) + self.assertEqual(m1.int_set, set((2, 3, 4, 5))) + + m1.save() + + m2 = TestSetModel.get(partition=m1.partition) + self.assertEqual(m2.int_set, set((2, 3, 4, 5))) + + def test_instantiation_with_column_class(self): + """ + Tests that columns instantiated with a column class work properly + and that the class is instantiated in the constructor + """ + column = columns.Set(columns.Text) + self.assertIsInstance(column.value_col, columns.Text) + + def test_instantiation_with_column_instance(self): + """ + Tests that columns instantiated with a column instance work properly + """ + column = columns.Set(columns.Text(min_length=100)) + self.assertIsInstance(column.value_col, columns.Text) + + def test_to_python(self): + """ Tests that to_python of value column is called """ + column = columns.Set(JsonTestColumn) + val = set((1, 2, 3)) + db_val = column.to_database(val) + self.assertEqual(db_val, set(json.dumps(v) for v in val)) + py_val = column.to_python(db_val) + self.assertEqual(py_val, val) + + def test_default_empty_container_saving(self): + """ tests that the default empty container is not saved if it hasn't been updated """ + pkey = uuid4() + # create a row with set data + TestSetModel.create(partition=pkey, int_set=set((3, 4))) + # create another with no set data + TestSetModel.create(partition=pkey) + + m = TestSetModel.get(partition=pkey) + self.assertEqual(m.int_set, set((3, 4))) + + +class TestListModel(Model): + + partition = columns.UUID(primary_key=True, default=uuid4) + int_list = columns.List(columns.Integer, required=False) + text_list = columns.List(columns.Text, required=False) + + +class TestListColumn(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + drop_table(TestListModel) + sync_table(TestListModel) + + @classmethod + def tearDownClass(cls): + drop_table(TestListModel) + + def test_initial(self): + tmp = TestListModel.create() + tmp.int_list.append(1) + + def test_initial_retrieve(self): + tmp = TestListModel.create() + tmp2 = TestListModel.get(partition=tmp.partition) + tmp2.int_list.append(1) + + def test_io_success(self): + """ Tests that a basic usage works as expected """ + m1 = TestListModel.create(int_list=[1, 2], text_list=['kai', 'andreas']) + m2 = TestListModel.get(partition=m1.partition) + + self.assertIsInstance(m2.int_list, list) + self.assertIsInstance(m2.text_list, list) + + self.assertEqual(len(m2.int_list), 2) + self.assertEqual(len(m2.text_list), 2) + + self.assertEqual(m2.int_list[0], 1) + self.assertEqual(m2.int_list[1], 2) + + self.assertEqual(m2.text_list[0], 'kai') + self.assertEqual(m2.text_list[1], 'andreas') + + def test_type_validation(self): + """ + Tests that attempting to use the wrong types will raise an exception + """ + self.assertRaises(ValidationError, TestListModel.create, **{'int_list': ['string', True], 'text_list': [1, 3.0]}) + + 
def test_element_count_validation(self): + """ + Tests that big collections are detected and raise an exception. + """ + while True: + try: + TestListModel.create(text_list=[str(uuid4()) for i in range(65535)]) + break + except WriteTimeout: + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + self.assertRaises(ValidationError, TestListModel.create, **{'text_list': [str(uuid4()) for _ in range(65536)]}) + + def test_partial_updates(self): + """ Tests that partial udpates work as expected """ + full = list(range(10)) + initial = full[3:7] + + m1 = TestListModel.create(int_list=initial) + + m1.int_list = full + m1.save() + + if is_prepend_reversed(): + expected = full[2::-1] + full[3:] + else: + expected = full + + m2 = TestListModel.get(partition=m1.partition) + self.assertEqual(list(m2.int_list), expected) + + def test_instantiation_with_column_class(self): + """ + Tests that columns instantiated with a column class work properly + and that the class is instantiated in the constructor + """ + column = columns.List(columns.Text) + self.assertIsInstance(column.value_col, columns.Text) + + def test_instantiation_with_column_instance(self): + """ + Tests that columns instantiated with a column instance work properly + """ + column = columns.List(columns.Text(min_length=100)) + self.assertIsInstance(column.value_col, columns.Text) + + def test_to_python(self): + """ Tests that to_python of value column is called """ + column = columns.List(JsonTestColumn) + val = [1, 2, 3] + db_val = column.to_database(val) + self.assertEqual(db_val, [json.dumps(v) for v in val]) + py_val = column.to_python(db_val) + self.assertEqual(py_val, val) + + def test_default_empty_container_saving(self): + """ tests that the default empty container is not saved if it hasn't been updated """ + pkey = uuid4() + # create a row with list data + TestListModel.create(partition=pkey, int_list=[1, 2, 3, 4]) + # create another with no list data + TestListModel.create(partition=pkey) + + m = TestListModel.get(partition=pkey) + self.assertEqual(m.int_list, [1, 2, 3, 4]) + + def test_remove_entry_works(self): + pkey = uuid4() + tmp = TestListModel.create(partition=pkey, int_list=[1, 2]) + tmp.int_list.pop() + tmp.update() + tmp = TestListModel.get(partition=pkey) + self.assertEqual(tmp.int_list, [1]) + + def test_update_from_non_empty_to_empty(self): + pkey = uuid4() + tmp = TestListModel.create(partition=pkey, int_list=[1, 2]) + tmp.int_list = [] + tmp.update() + + tmp = TestListModel.get(partition=pkey) + self.assertEqual(tmp.int_list, []) + + def test_insert_none(self): + pkey = uuid4() + self.assertRaises(ValidationError, TestListModel.create, **{'partition': pkey, 'int_list': [None]}) + + def test_blind_list_updates_from_none(self): + """ Tests that updates from None work as expected """ + m = TestListModel.create(int_list=None) + expected = [1, 2] + m.int_list = expected + m.save() + + m2 = TestListModel.get(partition=m.partition) + self.assertEqual(m2.int_list, expected) + + TestListModel.objects(partition=m.partition).update(int_list=[]) + + m3 = TestListModel.get(partition=m.partition) + self.assertEqual(m3.int_list, []) + + +class TestMapModel(Model): + partition = columns.UUID(primary_key=True, default=uuid4) + int_map = columns.Map(columns.Integer, columns.UUID, required=False) + text_map = columns.Map(columns.Text, columns.DateTime, required=False) + + +class TestMapColumn(BaseCassEngTestCase): + + @classmethod + def 
setUpClass(cls): + drop_table(TestMapModel) + sync_table(TestMapModel) + + @classmethod + def tearDownClass(cls): + drop_table(TestMapModel) + + def test_empty_default(self): + tmp = TestMapModel.create() + tmp.int_map['blah'] = 1 + + def test_add_none_as_map_key(self): + self.assertRaises(ValidationError, TestMapModel.create, **{'int_map': {None: uuid4()}}) + + def test_empty_retrieve(self): + tmp = TestMapModel.create() + tmp2 = TestMapModel.get(partition=tmp.partition) + tmp2.int_map['blah'] = 1 + + def test_remove_last_entry_works(self): + tmp = TestMapModel.create() + tmp.text_map["blah"] = datetime.now() + tmp.save() + del tmp.text_map["blah"] + tmp.save() + + tmp = TestMapModel.get(partition=tmp.partition) + self.assertTrue("blah" not in tmp.int_map) + + def test_io_success(self): + """ Tests that a basic usage works as expected """ + k1 = uuid4() + k2 = uuid4() + now = datetime.now() + then = now + timedelta(days=1) + m1 = TestMapModel.create(int_map={1: k1, 2: k2}, + text_map={'now': now, 'then': then}) + m2 = TestMapModel.get(partition=m1.partition) + + self.assertTrue(isinstance(m2.int_map, dict)) + self.assertTrue(isinstance(m2.text_map, dict)) + + self.assertTrue(1 in m2.int_map) + self.assertTrue(2 in m2.int_map) + self.assertEqual(m2.int_map[1], k1) + self.assertEqual(m2.int_map[2], k2) + + self.assertAlmostEqual(get_total_seconds(now - m2.text_map['now']), 0, 2) + self.assertAlmostEqual(get_total_seconds(then - m2.text_map['then']), 0, 2) + + def test_type_validation(self): + """ + Tests that attempting to use the wrong types will raise an exception + """ + self.assertRaises(ValidationError, TestMapModel.create, **{'int_map': {'key': 2, uuid4(): 'val'}, 'text_map': {2: 5}}) + + def test_element_count_validation(self): + """ + Tests that big collections are detected and raise an exception. 
+ """ + while True: + try: + TestMapModel.create(text_map=dict((str(uuid4()), i) for i in range(65535))) + break + except WriteTimeout: + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + self.assertRaises(ValidationError, TestMapModel.create, **{'text_map': dict((str(uuid4()), i) for i in range(65536))}) + + def test_partial_updates(self): + """ Tests that partial udpates work as expected """ + now = datetime.now() + # derez it a bit + now = datetime(*now.timetuple()[:-3]) + early = now - timedelta(minutes=30) + earlier = early - timedelta(minutes=30) + later = now + timedelta(minutes=30) + + initial = {'now': now, 'early': earlier} + final = {'later': later, 'early': early} + + m1 = TestMapModel.create(text_map=initial) + + m1.text_map = final + m1.save() + + m2 = TestMapModel.get(partition=m1.partition) + self.assertEqual(m2.text_map, final) + + def test_updates_from_none(self): + """ Tests that updates from None work as expected """ + m = TestMapModel.create(int_map=None) + expected = {1: uuid4()} + m.int_map = expected + m.save() + + m2 = TestMapModel.get(partition=m.partition) + self.assertEqual(m2.int_map, expected) + + m2.int_map = None + m2.save() + m3 = TestMapModel.get(partition=m.partition) + self.assertNotEqual(m3.int_map, expected) + + def test_blind_updates_from_none(self): + """ Tests that updates from None work as expected """ + m = TestMapModel.create(int_map=None) + expected = {1: uuid4()} + m.int_map = expected + m.save() + + m2 = TestMapModel.get(partition=m.partition) + self.assertEqual(m2.int_map, expected) + + TestMapModel.objects(partition=m.partition).update(int_map={}) + + m3 = TestMapModel.get(partition=m.partition) + self.assertNotEqual(m3.int_map, expected) + + def test_updates_to_none(self): + """ Tests that setting the field to None works as expected """ + m = TestMapModel.create(int_map={1: uuid4()}) + m.int_map = None + m.save() + + m2 = TestMapModel.get(partition=m.partition) + self.assertEqual(m2.int_map, {}) + + def test_instantiation_with_column_class(self): + """ + Tests that columns instantiated with a column class work properly + and that the class is instantiated in the constructor + """ + column = columns.Map(columns.Text, columns.Integer) + self.assertIsInstance(column.key_col, columns.Text) + self.assertIsInstance(column.value_col, columns.Integer) + + def test_instantiation_with_column_instance(self): + """ + Tests that columns instantiated with a column instance work properly + """ + column = columns.Map(columns.Text(min_length=100), columns.Integer()) + self.assertIsInstance(column.key_col, columns.Text) + self.assertIsInstance(column.value_col, columns.Integer) + + def test_to_python(self): + """ Tests that to_python of value column is called """ + column = columns.Map(JsonTestColumn, JsonTestColumn) + val = {1: 2, 3: 4, 5: 6} + db_val = column.to_database(val) + self.assertEqual(db_val, dict((json.dumps(k), json.dumps(v)) for k, v in val.items())) + py_val = column.to_python(db_val) + self.assertEqual(py_val, val) + + def test_default_empty_container_saving(self): + """ tests that the default empty container is not saved if it hasn't been updated """ + pkey = uuid4() + tmap = {1: uuid4(), 2: uuid4()} + # create a row with set data + TestMapModel.create(partition=pkey, int_map=tmap) + # create another with no set data + TestMapModel.create(partition=pkey) + + m = TestMapModel.get(partition=pkey) + self.assertEqual(m.int_map, tmap) + + +class 
TestCamelMapModel(Model): + + partition = columns.UUID(primary_key=True, default=uuid4) + camelMap = columns.Map(columns.Text, columns.Integer, required=False) + + +class TestCamelMapColumn(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + drop_table(TestCamelMapModel) + sync_table(TestCamelMapModel) + + @classmethod + def tearDownClass(cls): + drop_table(TestCamelMapModel) + + def test_camelcase_column(self): + TestCamelMapModel.create(camelMap={'blah': 1}) + + +class TestTupleModel(Model): + + partition = columns.UUID(primary_key=True, default=uuid4) + int_tuple = columns.Tuple(columns.Integer, columns.Integer, columns.Integer, required=False) + text_tuple = columns.Tuple(columns.Text, columns.Text, required=False) + mixed_tuple = columns.Tuple(columns.Text, columns.Integer, columns.Text, required=False) + + +@greaterthancass20 +class TestTupleColumn(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + # Skip annotations don't seem to skip class level teradown and setup methods + if CASSANDRA_VERSION >= Version('2.1'): + drop_table(TestTupleModel) + sync_table(TestTupleModel) + + @classmethod + def tearDownClass(cls): + drop_table(TestTupleModel) + + def test_initial(self): + """ + Tests creation and insertion of tuple types with models + + @since 3.1 + @jira_ticket PYTHON-306 + @expected_result Model is successfully crated + + @test_category object_mapper + """ + tmp = TestTupleModel.create() + tmp.int_tuple = (1, 2, 3) + + def test_initial_retrieve(self): + """ + Tests creation and insertion of tuple types with models, + and their retrieval. + + @since 3.1 + @jira_ticket PYTHON-306 + @expected_result Model is successfully crated + + @test_category object_mapper + """ + + tmp = TestTupleModel.create() + tmp2 = tmp.get(partition=tmp.partition) + tmp2.int_tuple = (1, 2, 3) + + def test_io_success(self): + """ + Tests creation and insertion of various types with models, + and their retrieval. 
+ + @since 3.1 + @jira_ticket PYTHON-306 + @expected_result Model is successfully created and fetched correctly + + @test_category object_mapper + """ + m1 = TestTupleModel.create(int_tuple=(1, 2, 3, 5, 6), text_tuple=('kai', 'andreas'), mixed_tuple=('first', 2, 'Third')) + m2 = TestTupleModel.get(partition=m1.partition) + + self.assertIsInstance(m2.int_tuple, tuple) + self.assertIsInstance(m2.text_tuple, tuple) + self.assertIsInstance(m2.mixed_tuple, tuple) + + self.assertEqual((1, 2, 3), m2.int_tuple) + self.assertEqual(('kai', 'andreas'), m2.text_tuple) + self.assertEqual(('first', 2, 'Third'), m2.mixed_tuple) + + def test_type_validation(self): + """ + Tests to make sure tuple type validation occurs + + @since 3.1 + @jira_ticket PYTHON-306 + @expected_result validation errors are raised + + @test_category object_mapper + """ + self.assertRaises(ValidationError, TestTupleModel.create, **{'int_tuple': ('string', True), 'text_tuple': ('test', 'test'), 'mixed_tuple': ('one', 2, 'three')}) + self.assertRaises(ValidationError, TestTupleModel.create, **{'int_tuple': ('string', 'string'), 'text_tuple': (1, 3.0), 'mixed_tuple': ('one', 2, 'three')}) + self.assertRaises(ValidationError, TestTupleModel.create, **{'int_tuple': ('string', 'string'), 'text_tuple': ('test', 'test'), 'mixed_tuple': (1, "two", 3)}) + + def test_instantiation_with_column_class(self): + """ + Tests that columns instantiated with a column class work properly + and that the class is instantiated in the constructor + + @since 3.1 + @jira_ticket PYTHON-306 + @expected_result types are instantiated correctly + + @test_category object_mapper + """ + mixed_tuple = columns.Tuple(columns.Text, columns.Integer, columns.Text, required=False) + self.assertIsInstance(mixed_tuple.types[0], columns.Text) + self.assertIsInstance(mixed_tuple.types[1], columns.Integer) + self.assertIsInstance(mixed_tuple.types[2], columns.Text) + self.assertEqual(len(mixed_tuple.types), 3) + + def test_default_empty_container_saving(self): + """ + Tests that the default empty container is not saved if it hasn't been updated + + @since 3.1 + @jira_ticket PYTHON-306 + @expected_result empty tuple is not upserted + + @test_category object_mapper + """ + pkey = uuid4() + # create a row with tuple data + TestTupleModel.create(partition=pkey, int_tuple=(1, 2, 3)) + # create another with no tuple data + TestTupleModel.create(partition=pkey) + + m = TestTupleModel.get(partition=pkey) + self.assertEqual(m.int_tuple, (1, 2, 3)) + + def test_updates(self): + """ + Tests that updates can be preformed on tuple containers + + @since 3.1 + @jira_ticket PYTHON-306 + @expected_result tuple is replaced + + @test_category object_mapper + """ + initial = (1, 2) + replacement = (1, 2, 3) + + m1 = TestTupleModel.create(int_tuple=initial) + m1.int_tuple = replacement + m1.save() + + m2 = TestTupleModel.get(partition=m1.partition) + self.assertEqual(tuple(m2.int_tuple), replacement) + + def test_update_from_non_empty_to_empty(self): + """ + Tests that explcit none updates are processed correctly on tuples + + @since 3.1 + @jira_ticket PYTHON-306 + @expected_result tuple is replaced with none + + @test_category object_mapper + """ + pkey = uuid4() + tmp = TestTupleModel.create(partition=pkey, int_tuple=(1, 2, 3)) + tmp.int_tuple = (None) + tmp.update() + + tmp = TestTupleModel.get(partition=pkey) + self.assertEqual(tmp.int_tuple, (None)) + + def test_insert_none(self): + """ + Tests that Tuples can be created as none + + @since 3.1 + @jira_ticket PYTHON-306 + @expected_result 
tuple is created as none + + @test_category object_mapper + """ + pkey = uuid4() + tmp = TestTupleModel.create(partition=pkey, int_tuple=(None)) + self.assertEqual((None), tmp.int_tuple) + + def test_blind_tuple_updates_from_none(self): + """ + Tests that Tuples can be updated from none + + @since 3.1 + @jira_ticket PYTHON-306 + @expected_result tuple is created as none, but upserted to contain values + + @test_category object_mapper + """ + + m = TestTupleModel.create(int_tuple=None) + expected = (1, 2, 3) + m.int_tuple = expected + m.save() + + m2 = TestTupleModel.get(partition=m.partition) + self.assertEqual(m2.int_tuple, expected) + + TestTupleModel.objects(partition=m.partition).update(int_tuple=None) + + m3 = TestTupleModel.get(partition=m.partition) + self.assertEqual(m3.int_tuple, None) + + +class TestNestedModel(Model): + + partition = columns.UUID(primary_key=True, default=uuid4) + list_list = columns.List(columns.List(columns.Integer), required=False) + map_list = columns.Map(columns.Text, columns.List(columns.Text), required=False) + set_tuple = columns.Set(columns.Tuple(columns.Integer, columns.Integer), required=False) + + +@greaterthancass20 +class TestNestedType(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + # Skip annotations don't seem to skip class level teradown and setup methods + if CASSANDRA_VERSION >= Version('2.1'): + drop_table(TestNestedModel) + sync_table(TestNestedModel) + + @classmethod + def tearDownClass(cls): + drop_table(TestNestedModel) + + def test_initial(self): + """ + Tests creation and insertion of nested collection types with models + + @since 3.1 + @jira_ticket PYTHON-478 + @expected_result Model is successfully created + + @test_category object_mapper + """ + tmp = TestNestedModel.create() + tmp.list_list = [[1, 2, 3], [2, 3, 4], [3, 4, 5]] + + def test_initial_retrieve(self): + """ + Tests creation and insertion of nested collection types with models, + and their retrieval. + + @since 3.1 + @jira_ticket PYTHON-478 + @expected_result Model is successfully crated + + @test_category object_mapper + """ + + tmp = TestNestedModel.create() + tmp2 = tmp.get(partition=tmp.partition) + tmp2.list_list = [[1, 2, 3], [2, 3, 4], [3, 4, 5]] + tmp2.map_list = {'key1': ["text1", "text2", "text3"], "key2": ["text1", "text2", "text3"], "key3": ["text1", "text2", "text3"]} + tmp2.set_tuple = set(((1, 2), (3, 5), (4, 5))) + + def test_io_success(self): + """ + Tests creation and insertion of various nested collection types with models, + and their retrieval. 
+ + @since 3.1 + @jira_ticket PYTHON-378 + @expected_result Model is successfully created and fetched correctly + + @test_category object_mapper + """ + list_list_master = [[1, 2, 3], [2, 3, 4], [3, 4, 5]] + map_list_master = {'key1': ["text1", "text2", "text3"], "key2": ["text1", "text2", "text3"], "key3": ["text1", "text2", "text3"]} + set_tuple_master = set(((1, 2), (3, 5), (4, 5))) + + m1 = TestNestedModel.create(list_list=list_list_master, map_list=map_list_master, set_tuple=set_tuple_master) + m2 = TestNestedModel.get(partition=m1.partition) + + self.assertIsInstance(m2.list_list, list) + self.assertIsInstance(m2.list_list[0], list) + self.assertIsInstance(m2.map_list, dict) + self.assertIsInstance(m2.map_list.get("key2"), list) + + self.assertEqual(list_list_master, m2.list_list) + self.assertEqual(map_list_master, m2.map_list) + self.assertEqual(set_tuple_master, m2.set_tuple) + self.assertIsInstance(m2.set_tuple.pop(), tuple) + + def test_type_validation(self): + """ + Tests to make sure nested collection type validation occurs + + @since 3.1 + @jira_ticket PYTHON-478 + @expected_result validation errors are raised + + @test_category object_mapper + """ + list_list_bad_list_context = [['text', "text", "text"], ["text", "text", "text"], ["text", "text", "text"]] + list_list_no_list = ['text', "text", "text"] + + map_list_bad_value = {'key1': [1, 2, 3], "key2": [1, 2, 3], "key3": [1, 2, 3]} + map_list_bad_key = {1: ["text1", "text2", "text3"], 2: ["text1", "text2", "text3"], 3: ["text1", "text2", "text3"]} + + set_tuple_bad_tuple_value = set((("text", "text"), ("text", "text"), ("text", "text"))) + set_tuple_not_set = ['This', 'is', 'not', 'a', 'set'] + + self.assertRaises(ValidationError, TestNestedModel.create, **{'list_list': list_list_bad_list_context}) + self.assertRaises(ValidationError, TestNestedModel.create, **{'list_list': list_list_no_list}) + self.assertRaises(ValidationError, TestNestedModel.create, **{'map_list': map_list_bad_value}) + self.assertRaises(ValidationError, TestNestedModel.create, **{'map_list': map_list_bad_key}) + self.assertRaises(ValidationError, TestNestedModel.create, **{'set_tuple': set_tuple_bad_tuple_value}) + self.assertRaises(ValidationError, TestNestedModel.create, **{'set_tuple': set_tuple_not_set}) + + def test_instantiation_with_column_class(self): + """ + Tests that columns instantiated with a column class work properly + and that the class is instantiated in the constructor + + @since 3.1 + @jira_ticket PYTHON-478 + @expected_result types are instantiated correctly + + @test_category object_mapper + """ + list_list = columns.List(columns.List(columns.Integer), required=False) + map_list = columns.Map(columns.Text, columns.List(columns.Text), required=False) + set_tuple = columns.Set(columns.Tuple(columns.Integer, columns.Integer), required=False) + + self.assertIsInstance(list_list, columns.List) + self.assertIsInstance(list_list.types[0], columns.List) + self.assertIsInstance(map_list.types[0], columns.Text) + self.assertIsInstance(map_list.types[1], columns.List) + self.assertIsInstance(set_tuple.types[0], columns.Tuple) + + def test_default_empty_container_saving(self): + """ + Tests that the default empty nested collection container is not saved if it hasn't been updated + + @since 3.1 + @jira_ticket PYTHON-478 + @expected_result empty nested collection is not upserted + + @test_category object_mapper + """ + pkey = uuid4() + # create a row with tuple data + list_list_master = [[1, 2, 3], [2, 3, 4], [3, 4, 5]] + map_list_master = 
{'key1': ["text1", "text2", "text3"], "key2": ["text1", "text2", "text3"], "key3": ["text1", "text2", "text3"]} + set_tuple_master = set(((1, 2), (3, 5), (4, 5))) + + TestNestedModel.create(partition=pkey, list_list=list_list_master, map_list=map_list_master, set_tuple=set_tuple_master) + # create another with no tuple data + TestNestedModel.create(partition=pkey) + + m = TestNestedModel.get(partition=pkey) + self.assertEqual(m.list_list, list_list_master) + self.assertEqual(m.map_list, map_list_master) + self.assertEqual(m.set_tuple, set_tuple_master) + + def test_updates(self): + """ + Tests that updates can be preformed on nested collections + + @since 3.1 + @jira_ticket PYTHON-478 + @expected_result nested collection is replaced + + @test_category object_mapper + """ + list_list_initial = [[1, 2, 3], [2, 3, 4], [3, 4, 5]] + list_list_replacement = [[1, 2, 3], [3, 4, 5]] + set_tuple_initial = set(((1, 2), (3, 5), (4, 5))) + + map_list_initial = {'key1': ["text1", "text2", "text3"], "key2": ["text1", "text2", "text3"], "key3": ["text1", "text2", "text3"]} + map_list_replacement = {'key1': ["text1", "text2", "text3"], "key3": ["text1", "text2", "text3"]} + set_tuple_replacement = set(((7, 7), (7, 7), (4, 5))) + + m1 = TestNestedModel.create(list_list=list_list_initial, map_list=map_list_initial, set_tuple=set_tuple_initial) + m1.list_list = list_list_replacement + m1.map_list = map_list_replacement + m1.set_tuple = set_tuple_replacement + m1.save() + + m2 = TestNestedModel.get(partition=m1.partition) + self.assertEqual(m2.list_list, list_list_replacement) + self.assertEqual(m2.map_list, map_list_replacement) + self.assertEqual(m2.set_tuple, set_tuple_replacement) + + def test_update_from_non_empty_to_empty(self): + """ + Tests that explcit none updates are processed correctly on tuples + + @since 3.1 + @jira_ticket PYTHON-478 + @expected_result nested collection is replaced with none + + @test_category object_mapper + """ + list_list_initial = [[1, 2, 3], [2, 3, 4], [3, 4, 5]] + map_list_initial = {'key1': ["text1", "text2", "text3"], "key2": ["text1", "text2", "text3"], "key3": ["text1", "text2", "text3"]} + set_tuple_initial = set(((1, 2), (3, 5), (4, 5))) + tmp = TestNestedModel.create(list_list=list_list_initial, map_list=map_list_initial, set_tuple=set_tuple_initial) + tmp.list_list = [] + tmp.map_list = {} + tmp.set_tuple = set() + tmp.update() + + tmp = TestNestedModel.get(partition=tmp.partition) + self.assertEqual(tmp.list_list, []) + self.assertEqual(tmp.map_list, {}) + self.assertEqual(tmp.set_tuple, set()) + + def test_insert_none(self): + """ + Tests that Tuples can be created as none + + @since 3.1 + @jira_ticket PYTHON-478 + @expected_result nested collection is created as none + + @test_category object_mapper + """ + pkey = uuid4() + tmp = TestNestedModel.create(partition=pkey, list_list=(None), map_list=(None), set_tuple=(None)) + self.assertEqual([], tmp.list_list) + self.assertEqual({}, tmp.map_list) + self.assertEqual(set(), tmp.set_tuple) + + diff --git a/tests/integration/cqlengine/columns/test_counter_column.py b/tests/integration/cqlengine/columns/test_counter_column.py new file mode 100644 index 0000000..95792dd --- /dev/null +++ b/tests/integration/cqlengine/columns/test_counter_column.py @@ -0,0 +1,130 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import uuid4 + +from cassandra.cqlengine import columns +from cassandra.cqlengine.management import sync_table, drop_table +from cassandra.cqlengine.models import Model, ModelDefinitionException +from tests.integration.cqlengine.base import BaseCassEngTestCase + + +class TestCounterModel(Model): + + partition = columns.UUID(primary_key=True, default=uuid4) + cluster = columns.UUID(primary_key=True, default=uuid4) + counter = columns.Counter() + + +class TestClassConstruction(BaseCassEngTestCase): + + def test_defining_a_non_counter_column_fails(self): + """ Tests that defining a non counter column field in a model with a counter column fails """ + try: + class model(Model): + partition = columns.UUID(primary_key=True, default=uuid4) + counter = columns.Counter() + text = columns.Text() + self.fail("did not raise expected ModelDefinitionException") + except ModelDefinitionException: + pass + + + def test_defining_a_primary_key_counter_column_fails(self): + """ Tests that defining primary keys on counter columns fails """ + try: + class model(Model): + partition = columns.UUID(primary_key=True, default=uuid4) + cluster = columns.Counter(primary_ley=True) + counter = columns.Counter() + self.fail("did not raise expected TypeError") + except TypeError: + pass + + # force it + try: + class model(Model): + partition = columns.UUID(primary_key=True, default=uuid4) + cluster = columns.Counter() + cluster.primary_key = True + counter = columns.Counter() + self.fail("did not raise expected ModelDefinitionException") + except ModelDefinitionException: + pass + + +class TestCounterColumn(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + drop_table(TestCounterModel) + sync_table(TestCounterModel) + + @classmethod + def tearDownClass(cls): + drop_table(TestCounterModel) + + def test_updates(self): + """ Tests that counter updates work as intended """ + instance = TestCounterModel.create() + instance.counter += 5 + instance.save() + + actual = TestCounterModel.get(partition=instance.partition) + assert actual.counter == 5 + + def test_concurrent_updates(self): + """ Tests updates from multiple queries reaches the correct value """ + instance = TestCounterModel.create() + new1 = TestCounterModel.get(partition=instance.partition) + new2 = TestCounterModel.get(partition=instance.partition) + + new1.counter += 5 + new1.save() + new2.counter += 5 + new2.save() + + actual = TestCounterModel.get(partition=instance.partition) + assert actual.counter == 10 + + def test_update_from_none(self): + """ Tests that updating from None uses a create statement """ + instance = TestCounterModel() + instance.counter += 1 + instance.save() + + new = TestCounterModel.get(partition=instance.partition) + assert new.counter == 1 + + def test_new_instance_defaults_to_zero(self): + """ Tests that instantiating a new model instance will set the counter column to zero """ + instance = TestCounterModel() + assert instance.counter == 0 + + def test_save_after_no_update(self): + expected_value = 15 + instance = TestCounterModel.create() + instance.update(counter=expected_value) + + # read back + 
instance = TestCounterModel.get(partition=instance.partition) + self.assertEqual(instance.counter, expected_value) + + # save after doing nothing + instance.save() + self.assertEqual(instance.counter, expected_value) + + # make sure there was no increment + instance = TestCounterModel.get(partition=instance.partition) + self.assertEqual(instance.counter, expected_value) diff --git a/tests/integration/cqlengine/columns/test_static_column.py b/tests/integration/cqlengine/columns/test_static_column.py new file mode 100644 index 0000000..69e222d --- /dev/null +++ b/tests/integration/cqlengine/columns/test_static_column.py @@ -0,0 +1,94 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from uuid import uuid4 + +from cassandra.cqlengine import columns +from cassandra.cqlengine.management import sync_table, drop_table +from cassandra.cqlengine.models import Model + +from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration import PROTOCOL_VERSION + +# TODO: is this really a protocol limitation, or is it just C* version? +# good enough proxy for now +STATIC_SUPPORTED = PROTOCOL_VERSION >= 2 + +class TestStaticModel(Model): + partition = columns.UUID(primary_key=True, default=uuid4) + cluster = columns.UUID(primary_key=True, default=uuid4) + static = columns.Text(static=True) + text = columns.Text() + + +class TestStaticColumn(BaseCassEngTestCase): + + def setUp(cls): + if not STATIC_SUPPORTED: + raise unittest.SkipTest("only runs against the cql3 protocol v2.0") + super(TestStaticColumn, cls).setUp() + + @classmethod + def setUpClass(cls): + drop_table(TestStaticModel) + if STATIC_SUPPORTED: # setup and teardown run regardless of skip + sync_table(TestStaticModel) + + @classmethod + def tearDownClass(cls): + drop_table(TestStaticModel) + + def test_mixed_updates(self): + """ Tests that updates on both static and non-static columns work as intended """ + instance = TestStaticModel.create() + instance.static = "it's shared" + instance.text = "some text" + instance.save() + + u = TestStaticModel.get(partition=instance.partition) + u.static = "it's still shared" + u.text = "another text" + u.update() + actual = TestStaticModel.get(partition=u.partition) + + assert actual.static == "it's still shared" + + def test_static_only_updates(self): + """ Tests that updates on static only column work as intended """ + instance = TestStaticModel.create() + instance.static = "it's shared" + instance.text = "some text" + instance.save() + + u = TestStaticModel.get(partition=instance.partition) + u.static = "it's still shared" + u.update() + actual = TestStaticModel.get(partition=u.partition) + assert actual.static == "it's still shared" + + def test_static_with_null_cluster_key(self): + """ Tests that save/update/delete works for static column works when clustering key is null""" + instance = TestStaticModel.create(cluster=None, static = "it's shared") + instance.save() + 
+ u = TestStaticModel.get(partition=instance.partition) + u.static = "it's still shared" + u.update() + actual = TestStaticModel.get(partition=u.partition) + assert actual.static == "it's still shared" diff --git a/tests/integration/cqlengine/columns/test_validation.py b/tests/integration/cqlengine/columns/test_validation.py new file mode 100644 index 0000000..69682fd --- /dev/null +++ b/tests/integration/cqlengine/columns/test_validation.py @@ -0,0 +1,841 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import sys +from datetime import datetime, timedelta, date, tzinfo, time +from decimal import Decimal as D +from uuid import uuid4, uuid1 +from packaging.version import Version + +from cassandra import InvalidRequest +from cassandra.cqlengine.columns import (TimeUUID, Ascii, Text, Integer, BigInt, + VarInt, DateTime, Date, UUID, Boolean, + Decimal, Inet, Time, UserDefinedType, + Map, List, Set, Tuple, Double, Duration) +from cassandra.cqlengine.connection import execute +from cassandra.cqlengine.management import sync_table, drop_table +from cassandra.cqlengine.models import Model, ValidationError +from cassandra.cqlengine.usertype import UserType +from cassandra import util + +from tests.integration import PROTOCOL_VERSION, CASSANDRA_VERSION, greaterthanorequalcass30, greaterthanorequalcass3_11 +from tests.integration.cqlengine.base import BaseCassEngTestCase + + +class TestDatetime(BaseCassEngTestCase): + class DatetimeTest(Model): + + test_id = Integer(primary_key=True) + created_at = DateTime() + + @classmethod + def setUpClass(cls): + sync_table(cls.DatetimeTest) + + @classmethod + def tearDownClass(cls): + drop_table(cls.DatetimeTest) + + def test_datetime_io(self): + now = datetime.now() + self.DatetimeTest.objects.create(test_id=0, created_at=now) + dt2 = self.DatetimeTest.objects(test_id=0).first() + self.assertEqual(dt2.created_at.timetuple()[:6], now.timetuple()[:6]) + + def test_datetime_tzinfo_io(self): + class TZ(tzinfo): + def utcoffset(self, date_time): + return timedelta(hours=-1) + def dst(self, date_time): + return None + + now = datetime(1982, 1, 1, tzinfo=TZ()) + dt = self.DatetimeTest.objects.create(test_id=1, created_at=now) + dt2 = self.DatetimeTest.objects(test_id=1).first() + self.assertEqual(dt2.created_at.timetuple()[:6], (now + timedelta(hours=1)).timetuple()[:6]) + + @greaterthanorequalcass30 + def test_datetime_date_support(self): + today = date.today() + self.DatetimeTest.objects.create(test_id=2, created_at=today) + dt2 = self.DatetimeTest.objects(test_id=2).first() + self.assertEqual(dt2.created_at.isoformat(), datetime(today.year, today.month, today.day).isoformat()) + + result = self.DatetimeTest.objects.all().allow_filtering().filter(test_id=2).first() + self.assertEqual(result.created_at, datetime.combine(today, datetime.min.time())) + + result = self.DatetimeTest.objects.all().allow_filtering().filter(test_id=2, created_at=today).first() + 
self.assertEqual(result.created_at, datetime.combine(today, datetime.min.time())) + + def test_datetime_none(self): + dt = self.DatetimeTest.objects.create(test_id=3, created_at=None) + dt2 = self.DatetimeTest.objects(test_id=3).first() + self.assertIsNone(dt2.created_at) + + dts = self.DatetimeTest.objects.filter(test_id=3).values_list('created_at') + self.assertIsNone(dts[0][0]) + + def test_datetime_invalid(self): + dt_value= 'INVALID' + with self.assertRaises(TypeError): + self.DatetimeTest.objects.create(test_id=4, created_at=dt_value) + + def test_datetime_timestamp(self): + dt_value = 1454520554 + self.DatetimeTest.objects.create(test_id=5, created_at=dt_value) + dt2 = self.DatetimeTest.objects(test_id=5).first() + self.assertEqual(dt2.created_at, datetime.utcfromtimestamp(dt_value)) + + def test_datetime_large(self): + dt_value = datetime(2038, 12, 31, 10, 10, 10, 123000) + self.DatetimeTest.objects.create(test_id=6, created_at=dt_value) + dt2 = self.DatetimeTest.objects(test_id=6).first() + self.assertEqual(dt2.created_at, dt_value) + + def test_datetime_truncate_microseconds(self): + """ + Test to ensure that truncate microseconds works as expected. + This will be default behavior in the future and we will need to modify the tests to comply + with new behavior + + @since 3.2 + @jira_ticket PYTHON-273 + @expected_result microseconds should be to the nearest thousand when truncate is set. + + @test_category object_mapper + """ + DateTime.truncate_microseconds = True + try: + dt_value = datetime(2024, 12, 31, 10, 10, 10, 923567) + dt_truncated = datetime(2024, 12, 31, 10, 10, 10, 923000) + self.DatetimeTest.objects.create(test_id=6, created_at=dt_value) + dt2 = self.DatetimeTest.objects(test_id=6).first() + self.assertEqual(dt2.created_at,dt_truncated) + finally: + # We need to always return behavior to default + DateTime.truncate_microseconds = False + + +class TestBoolDefault(BaseCassEngTestCase): + class BoolDefaultValueTest(Model): + + test_id = Integer(primary_key=True) + stuff = Boolean(default=True) + + @classmethod + def setUpClass(cls): + sync_table(cls.BoolDefaultValueTest) + + def test_default_is_set(self): + tmp = self.BoolDefaultValueTest.create(test_id=1) + self.assertEqual(True, tmp.stuff) + tmp2 = self.BoolDefaultValueTest.get(test_id=1) + self.assertEqual(True, tmp2.stuff) + + +class TestBoolValidation(BaseCassEngTestCase): + class BoolValidationTest(Model): + + test_id = Integer(primary_key=True) + bool_column = Boolean() + + @classmethod + def setUpClass(cls): + sync_table(cls.BoolValidationTest) + + def test_validation_preserves_none(self): + test_obj = self.BoolValidationTest(test_id=1) + + test_obj.validate() + self.assertIsNone(test_obj.bool_column) + + +class TestVarInt(BaseCassEngTestCase): + class VarIntTest(Model): + + test_id = Integer(primary_key=True) + bignum = VarInt(primary_key=True) + + @classmethod + def setUpClass(cls): + sync_table(cls.VarIntTest) + + @classmethod + def tearDownClass(cls): + sync_table(cls.VarIntTest) + + def test_varint_io(self): + # TODO: this is a weird test. i changed the number from sys.maxint (which doesn't exist in python 3) + # to the giant number below and it broken between runs. 
+ long_int = 92834902384092834092384028340283048239048203480234823048230482304820348239 + int1 = self.VarIntTest.objects.create(test_id=0, bignum=long_int) + int2 = self.VarIntTest.objects(test_id=0).first() + self.assertEqual(int1.bignum, int2.bignum) + + with self.assertRaises(ValidationError): + self.VarIntTest.objects.create(test_id=0, bignum="not_a_number") + + +class DataType(): + @classmethod + def setUpClass(cls): + if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"): + return + + class DataTypeTest(Model): + test_id = Integer(primary_key=True) + class_param = cls.db_klass() + + cls.model_class = DataTypeTest + sync_table(cls.model_class) + + @classmethod + def tearDownClass(cls): + if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"): + return + drop_table(cls.model_class) + + def setUp(self): + if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"): + raise unittest.SkipTest("Protocol v4 datatypes " + "require native protocol 4+ and C* version >=3.0, " + "currently using protocol {0} and C* version {1}". + format(PROTOCOL_VERSION, CASSANDRA_VERSION)) + + def _check_value_is_correct_in_db(self, value): + """ + Check that different ways of reading the value + from the model class give the same expected result + """ + if value is None: + result = self.model_class.objects.all().allow_filtering().filter(test_id=0).first() + self.assertIsNone(result.class_param) + + result = self.model_class.objects(test_id=0).first() + self.assertIsNone(result.class_param) + + else: + if not isinstance(value, self.python_klass): + value_to_compare = self.python_klass(value) + else: + value_to_compare = value + + result = self.model_class.objects(test_id=0).first() + self.assertIsInstance(result.class_param, self.python_klass) + self.assertEqual(result.class_param, value_to_compare) + + result = self.model_class.objects.all().allow_filtering().filter(test_id=0).first() + self.assertIsInstance(result.class_param, self.python_klass) + self.assertEqual(result.class_param, value_to_compare) + + result = self.model_class.objects.all().allow_filtering().filter(test_id=0, class_param=value).first() + self.assertIsInstance(result.class_param, self.python_klass) + self.assertEqual(result.class_param, value_to_compare) + + return result + + def test_param_io(self): + first_value = self.first_value + second_value = self.second_value + third_value = self.third_value + + # Check value is correctly written/read from the DB + self.model_class.objects.create(test_id=0, class_param=first_value) + result = self._check_value_is_correct_in_db(first_value) + result.delete() + + # Check the previous value has been correctly deleted and write a new value + self.model_class.objects.create(test_id=0, class_param=second_value) + result = self._check_value_is_correct_in_db(second_value) + + # Check the value can be correctly updated from the Model class + result.update(class_param=third_value).save() + result = self._check_value_is_correct_in_db(third_value) + + # Check None is correctly written to the DB + result.update(class_param=None).save() + self._check_value_is_correct_in_db(None) + + def test_param_none(self): + """ + Test that None value is correctly written to the db + and then is correctly read + """ + self.model_class.objects.create(test_id=1, class_param=None) + dt2 = self.model_class.objects(test_id=1).first() + self.assertIsNone(dt2.class_param) + + dts = self.model_class.objects(test_id=1).values_list('class_param') + self.assertIsNone(dts[0][0]) + + +class TestDate(DataType, 
BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + cls.db_klass, cls.python_klass = ( + Date, + util.Date + ) + + cls.first_value, cls.second_value, cls.third_value = ( + datetime.utcnow(), + util.Date(datetime(1, 1, 1)), + datetime(1, 1, 2) + ) + super(TestDate, cls).setUpClass() + + +class TestTime(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + cls.db_klass, cls.python_klass = ( + Time, + util.Time + ) + cls.first_value, cls.second_value, cls.third_value = ( + None, + util.Time(time(2, 12, 7, 49)), + time(2, 12, 7, 50) + ) + super(TestTime, cls).setUpClass() + + +class TestDateTime(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + cls.db_klass, cls.python_klass = ( + DateTime, + datetime + ) + cls.first_value, cls.second_value, cls.third_value = ( + datetime(2017, 4, 13, 18, 34, 24, 317000), + datetime(1, 1, 1), + datetime(1, 1, 2) + ) + super(TestDateTime, cls).setUpClass() + + +class TestBoolean(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + cls.db_klass, cls.python_klass = ( + Boolean, + bool + ) + cls.first_value, cls.second_value, cls.third_value = ( + None, + False, + True + ) + super(TestBoolean, cls).setUpClass() + +@greaterthanorequalcass3_11 +class TestDuration(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + # setUpClass is executed despite the whole class being skipped + if CASSANDRA_VERSION >= Version("3.10"): + cls.db_klass, cls.python_klass = ( + Duration, + util.Duration + ) + cls.first_value, cls.second_value, cls.third_value = ( + util.Duration(0, 0, 0), + util.Duration(1, 2, 3), + util.Duration(0, 0, 0) + ) + super(TestDuration, cls).setUpClass() + + @classmethod + def tearDownClass(cls): + if CASSANDRA_VERSION >= Version("3.10"): + super(TestDuration, cls).tearDownClass() + + +class User(UserType): + # We use Date and Time to ensure to_python + # is called for these columns + age = Integer() + date_param = Date() + map_param = Map(Integer, Time) + list_param = List(Date) + set_param = Set(Date) + tuple_param = Tuple(Date, Decimal, Boolean, VarInt, Double, UUID) + + +class UserModel(Model): + test_id = Integer(primary_key=True) + class_param = UserDefinedType(User) + + +class TestUDT(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"): + return + + cls.db_klass, cls.python_klass = UserDefinedType, User + cls.first_value = User( + age=1, + date_param=datetime.utcnow(), + map_param={1: time(2, 12, 7, 50), 2: util.Time(time(2, 12, 7, 49))}, + list_param=[datetime(1, 1, 2), datetime(1, 1, 3)], + set_param=set((datetime(1, 1, 3), util.Date(datetime(1, 1, 1)))), + tuple_param=(datetime(1, 1, 3), 2, False, 1, 2.324, uuid4()) + ) + + cls.second_value = User( + age=1, + date_param=datetime.utcnow(), + map_param={1: time(2, 12, 7, 50), 2: util.Time(time(2, 12, 7, 49))}, + list_param=[datetime(1, 1, 2), datetime(1, 2, 3)], + set_param=None, + tuple_param=(datetime(1, 1, 2), 2, False, 1, 2.324, uuid4()) + ) + + cls.third_value = User( + age=2, + date_param=None, + map_param={1: time(2, 12, 7, 51), 2: util.Time(time(2, 12, 7, 49))}, + list_param=[datetime(1, 1, 2), datetime(1, 1, 4)], + set_param=set((datetime(1, 1, 3), util.Date(datetime(1, 1, 2)))), + tuple_param=(None, 3, False, None, 2.3214, uuid4()) + ) + + cls.model_class = UserModel + sync_table(cls.model_class) + + +class TestDecimal(BaseCassEngTestCase): + class DecimalTest(Model): + + test_id = Integer(primary_key=True) + dec_val = 
Decimal() + + @classmethod + def setUpClass(cls): + sync_table(cls.DecimalTest) + + @classmethod + def tearDownClass(cls): + drop_table(cls.DecimalTest) + + def test_decimal_io(self): + dt = self.DecimalTest.objects.create(test_id=0, dec_val=D('0.00')) + dt2 = self.DecimalTest.objects(test_id=0).first() + assert dt2.dec_val == dt.dec_val + + dt = self.DecimalTest.objects.create(test_id=0, dec_val=5) + dt2 = self.DecimalTest.objects(test_id=0).first() + assert dt2.dec_val == D('5') + + +class TestUUID(BaseCassEngTestCase): + class UUIDTest(Model): + + test_id = Integer(primary_key=True) + a_uuid = UUID(default=uuid4()) + + @classmethod + def setUpClass(cls): + sync_table(cls.UUIDTest) + + @classmethod + def tearDownClass(cls): + drop_table(cls.UUIDTest) + + def test_uuid_str_with_dashes(self): + a_uuid = uuid4() + t0 = self.UUIDTest.create(test_id=0, a_uuid=str(a_uuid)) + t1 = self.UUIDTest.get(test_id=0) + assert a_uuid == t1.a_uuid + + def test_uuid_str_no_dashes(self): + a_uuid = uuid4() + t0 = self.UUIDTest.create(test_id=1, a_uuid=a_uuid.hex) + t1 = self.UUIDTest.get(test_id=1) + assert a_uuid == t1.a_uuid + + def test_uuid_with_upcase(self): + a_uuid = uuid4() + val = str(a_uuid).upper() + t0 = self.UUIDTest.create(test_id=0, a_uuid=val) + t1 = self.UUIDTest.get(test_id=0) + assert a_uuid == t1.a_uuid + + +class TestTimeUUID(BaseCassEngTestCase): + class TimeUUIDTest(Model): + + test_id = Integer(primary_key=True) + timeuuid = TimeUUID(default=uuid1()) + + @classmethod + def setUpClass(cls): + sync_table(cls.TimeUUIDTest) + + @classmethod + def tearDownClass(cls): + drop_table(cls.TimeUUIDTest) + + def test_timeuuid_io(self): + """ + ensures that + :return: + """ + t0 = self.TimeUUIDTest.create(test_id=0) + t1 = self.TimeUUIDTest.get(test_id=0) + + assert t1.timeuuid.time == t1.timeuuid.time + + +class TestInteger(BaseCassEngTestCase): + class IntegerTest(Model): + + test_id = UUID(primary_key=True, default=lambda:uuid4()) + value = Integer(default=0, required=True) + + def test_default_zero_fields_validate(self): + """ Tests that integer columns with a default value of 0 validate """ + it = self.IntegerTest() + it.validate() + + +class TestBigInt(BaseCassEngTestCase): + class BigIntTest(Model): + + test_id = UUID(primary_key=True, default=lambda:uuid4()) + value = BigInt(default=0, required=True) + + def test_default_zero_fields_validate(self): + """ Tests that bigint columns with a default value of 0 validate """ + it = self.BigIntTest() + it.validate() + + +class TestAscii(BaseCassEngTestCase): + def test_min_length(self): + """ Test arbitrary minimal lengths requirements. """ + + Ascii(min_length=0).validate('') + Ascii(min_length=0, required=True).validate('') + + Ascii(min_length=0).validate(None) + Ascii(min_length=0).validate('kevin') + + Ascii(min_length=1).validate('k') + + Ascii(min_length=5).validate('kevin') + Ascii(min_length=5).validate('kevintastic') + + with self.assertRaises(ValidationError): + Ascii(min_length=1).validate('') + + with self.assertRaises(ValidationError): + Ascii(min_length=1).validate(None) + + with self.assertRaises(ValidationError): + Ascii(min_length=6).validate('') + + with self.assertRaises(ValidationError): + Ascii(min_length=6).validate(None) + + with self.assertRaises(ValidationError): + Ascii(min_length=6).validate('kevin') + + with self.assertRaises(ValueError): + Ascii(min_length=-1) + + def test_max_length(self): + """ Test arbitrary maximal lengths requirements. 
""" + Ascii(max_length=0).validate('') + Ascii(max_length=0).validate(None) + + Ascii(max_length=1).validate('') + Ascii(max_length=1).validate(None) + Ascii(max_length=1).validate('b') + + Ascii(max_length=5).validate('') + Ascii(max_length=5).validate(None) + Ascii(max_length=5).validate('b') + Ascii(max_length=5).validate('blake') + + with self.assertRaises(ValidationError): + Ascii(max_length=0).validate('b') + + with self.assertRaises(ValidationError): + Ascii(max_length=5).validate('blaketastic') + + with self.assertRaises(ValueError): + Ascii(max_length=-1) + + def test_length_range(self): + Ascii(min_length=0, max_length=0) + Ascii(min_length=0, max_length=1) + Ascii(min_length=10, max_length=10) + Ascii(min_length=10, max_length=11) + + with self.assertRaises(ValueError): + Ascii(min_length=10, max_length=9) + + with self.assertRaises(ValueError): + Ascii(min_length=1, max_length=0) + + def test_type_checking(self): + Ascii().validate('string') + Ascii().validate(u'unicode') + Ascii().validate(bytearray('bytearray', encoding='ascii')) + + with self.assertRaises(ValidationError): + Ascii().validate(5) + + with self.assertRaises(ValidationError): + Ascii().validate(True) + + Ascii().validate("!#$%&\'()*+,-./") + + with self.assertRaises(ValidationError): + Ascii().validate('Beyonc' + chr(233)) + + if sys.version_info < (3, 1): + with self.assertRaises(ValidationError): + Ascii().validate(u'Beyonc' + unichr(233)) + + def test_unaltering_validation(self): + """ Test the validation step doesn't re-interpret values. """ + self.assertEqual(Ascii().validate(''), '') + self.assertEqual(Ascii().validate(None), None) + self.assertEqual(Ascii().validate('yo'), 'yo') + + def test_non_required_validation(self): + """ Tests that validation is ok on none and blank values if required is False. """ + Ascii().validate('') + Ascii().validate(None) + + def test_required_validation(self): + """ Tests that validation raise on none and blank values if value required. """ + Ascii(required=True).validate('k') + + with self.assertRaises(ValidationError): + Ascii(required=True).validate('') + + with self.assertRaises(ValidationError): + Ascii(required=True).validate(None) + + # With min_length set. + Ascii(required=True, min_length=0).validate('k') + Ascii(required=True, min_length=1).validate('k') + + with self.assertRaises(ValidationError): + Ascii(required=True, min_length=2).validate('k') + + # With max_length set. + Ascii(required=True, max_length=1).validate('k') + + with self.assertRaises(ValidationError): + Ascii(required=True, max_length=2).validate('kevin') + + with self.assertRaises(ValueError): + Ascii(required=True, max_length=0) + + +class TestText(BaseCassEngTestCase): + + def test_min_length(self): + """ Test arbitrary minimal lengths requirements. 
""" + + Text(min_length=0).validate('') + Text(min_length=0, required=True).validate('') + + Text(min_length=0).validate(None) + Text(min_length=0).validate('blake') + + Text(min_length=1).validate('b') + + Text(min_length=5).validate('blake') + Text(min_length=5).validate('blaketastic') + + with self.assertRaises(ValidationError): + Text(min_length=1).validate('') + + with self.assertRaises(ValidationError): + Text(min_length=1).validate(None) + + with self.assertRaises(ValidationError): + Text(min_length=6).validate('') + + with self.assertRaises(ValidationError): + Text(min_length=6).validate(None) + + with self.assertRaises(ValidationError): + Text(min_length=6).validate('blake') + + with self.assertRaises(ValueError): + Text(min_length=-1) + + def test_max_length(self): + """ Test arbitrary maximal lengths requirements. """ + Text(max_length=0).validate('') + Text(max_length=0).validate(None) + + Text(max_length=1).validate('') + Text(max_length=1).validate(None) + Text(max_length=1).validate('b') + + Text(max_length=5).validate('') + Text(max_length=5).validate(None) + Text(max_length=5).validate('b') + Text(max_length=5).validate('blake') + + with self.assertRaises(ValidationError): + Text(max_length=0).validate('b') + + with self.assertRaises(ValidationError): + Text(max_length=5).validate('blaketastic') + + with self.assertRaises(ValueError): + Text(max_length=-1) + + def test_length_range(self): + Text(min_length=0, max_length=0) + Text(min_length=0, max_length=1) + Text(min_length=10, max_length=10) + Text(min_length=10, max_length=11) + + with self.assertRaises(ValueError): + Text(min_length=10, max_length=9) + + with self.assertRaises(ValueError): + Text(min_length=1, max_length=0) + + def test_type_checking(self): + Text().validate('string') + Text().validate(u'unicode') + Text().validate(bytearray('bytearray', encoding='ascii')) + + with self.assertRaises(ValidationError): + Text().validate(5) + + with self.assertRaises(ValidationError): + Text().validate(True) + + Text().validate("!#$%&\'()*+,-./") + Text().validate('Beyonc' + chr(233)) + if sys.version_info < (3, 1): + Text().validate(u'Beyonc' + unichr(233)) + + def test_unaltering_validation(self): + """ Test the validation step doesn't re-interpret values. """ + self.assertEqual(Text().validate(''), '') + self.assertEqual(Text().validate(None), None) + self.assertEqual(Text().validate('yo'), 'yo') + + def test_non_required_validation(self): + """ Tests that validation is ok on none and blank values if required is False """ + Text().validate('') + Text().validate(None) + + def test_required_validation(self): + """ Tests that validation raise on none and blank values if value required. """ + Text(required=True).validate('b') + + with self.assertRaises(ValidationError): + Text(required=True).validate('') + + with self.assertRaises(ValidationError): + Text(required=True).validate(None) + + # With min_length set. + Text(required=True, min_length=0).validate('b') + Text(required=True, min_length=1).validate('b') + + with self.assertRaises(ValidationError): + Text(required=True, min_length=2).validate('b') + + # With max_length set. 
+ Text(required=True, max_length=1).validate('b') + + with self.assertRaises(ValidationError): + Text(required=True, max_length=2).validate('blake') + + with self.assertRaises(ValueError): + Text(required=True, max_length=0) + + +class TestExtraFieldsRaiseException(BaseCassEngTestCase): + class TestModel(Model): + + id = UUID(primary_key=True, default=uuid4) + + def test_extra_field(self): + with self.assertRaises(ValidationError): + self.TestModel.create(bacon=5000) + + +class TestPythonDoesntDieWhenExtraFieldIsInCassandra(BaseCassEngTestCase): + class TestModel(Model): + + __table_name__ = 'alter_doesnt_break_running_app' + id = UUID(primary_key=True, default=uuid4) + + def test_extra_field(self): + drop_table(self.TestModel) + sync_table(self.TestModel) + self.TestModel.create() + execute("ALTER TABLE {0} add blah int".format(self.TestModel.column_family_name(include_keyspace=True))) + self.TestModel.objects.all() + + +class TestTimeUUIDFromDatetime(BaseCassEngTestCase): + def test_conversion_specific_date(self): + dt = datetime(1981, 7, 11, microsecond=555000) + + uuid = util.uuid_from_time(dt) + + from uuid import UUID + assert isinstance(uuid, UUID) + + ts = (uuid.time - 0x01b21dd213814000) / 1e7 # back to a timestamp + new_dt = datetime.utcfromtimestamp(ts) + + # checks that we created a UUID1 with the proper timestamp + assert new_dt == dt + + +class TestInet(BaseCassEngTestCase): + + class InetTestModel(Model): + id = UUID(primary_key=True, default=uuid4) + address = Inet() + + def setUp(self): + drop_table(self.InetTestModel) + sync_table(self.InetTestModel) + + def test_inet_saves(self): + tmp = self.InetTestModel.create(address="192.168.1.1") + + m = self.InetTestModel.get(id=tmp.id) + + assert m.address == "192.168.1.1" + + def test_non_address_fails(self): + # TODO: presently this only tests that the server blows it up. Is there supposed to be local validation? + with self.assertRaises(InvalidRequest): + self.InetTestModel.create(address="what is going on here?") diff --git a/tests/integration/cqlengine/columns/test_value_io.py b/tests/integration/cqlengine/columns/test_value_io.py new file mode 100644 index 0000000..243c2b0 --- /dev/null +++ b/tests/integration/cqlengine/columns/test_value_io.py @@ -0,0 +1,270 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
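The value-I/O tests that follow drive one generic round trip (create a row, read it back, compare) across many column types via BaseColumnIOTest. Spelled out by hand for a single Text column, the pattern looks roughly like the sketch below; the model name is invented here, the values are borrowed from the Text test case, and a registered default connection and keyspace are assumed::

    from cassandra.cqlengine import columns
    from cassandra.cqlengine.models import Model
    from cassandra.cqlengine.management import sync_table, drop_table

    class TextIOExample(Model):
        pkey = columns.Text(primary_key=True)
        data = columns.Text()

    sync_table(TextIOExample)
    written = TextIOExample.create(pkey='bacon', data='monkey')
    read_back = TextIOExample.get(pkey='bacon')
    assert written.data == read_back.data == 'monkey'
    drop_table(TextIOExample)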
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from datetime import datetime, timedelta, time +from decimal import Decimal +from uuid import uuid1, uuid4, UUID +import six + +from cassandra.cqlengine import columns +from cassandra.cqlengine.management import sync_table +from cassandra.cqlengine.management import drop_table +from cassandra.cqlengine.models import Model + +from cassandra.util import Date, Time + +from tests.integration import PROTOCOL_VERSION +from tests.integration.cqlengine.base import BaseCassEngTestCase + + +class BaseColumnIOTest(BaseCassEngTestCase): + """ + Tests that values are come out of cassandra in the format we expect + + To test a column type, subclass this test, define the column, and the primary key + and data values you want to test + """ + + # The generated test model is assigned here + _generated_model = None + + # the column we want to test + column = None + + # the values we want to test against, you can + # use a single value, or multiple comma separated values + pkey_val = None + data_val = None + + @classmethod + def setUpClass(cls): + super(BaseColumnIOTest, cls).setUpClass() + + # if the test column hasn't been defined, bail out + if not cls.column: + return + + # create a table with the given column + class IOTestModel(Model): + pkey = cls.column(primary_key=True) + data = cls.column() + + cls._generated_model = IOTestModel + sync_table(cls._generated_model) + + # tupleify the tested values + if not isinstance(cls.pkey_val, tuple): + cls.pkey_val = cls.pkey_val, + if not isinstance(cls.data_val, tuple): + cls.data_val = cls.data_val, + + @classmethod + def tearDownClass(cls): + super(BaseColumnIOTest, cls).tearDownClass() + if not cls.column: + return + drop_table(cls._generated_model) + + def comparator_converter(self, val): + """ If you want to convert the original value used to compare the model vales """ + return val + + def test_column_io(self): + """ Tests the given models class creates and retrieves values as expected """ + if not self.column: + return + for pkey, data in zip(self.pkey_val, self.data_val): + # create + m1 = self._generated_model.create(pkey=pkey, data=data) + + # get + m2 = self._generated_model.get(pkey=pkey) + assert m1.pkey == m2.pkey == self.comparator_converter(pkey), self.column + assert m1.data == m2.data == self.comparator_converter(data), self.column + + # delete + self._generated_model.filter(pkey=pkey).delete() + + +class TestBlobIO(BaseColumnIOTest): + + column = columns.Blob + pkey_val = six.b('blake'), uuid4().bytes + data_val = six.b('eggleston'), uuid4().bytes + + +class TestBlobIO2(BaseColumnIOTest): + + column = columns.Blob + pkey_val = bytearray(six.b('blake')), uuid4().bytes + data_val = bytearray(six.b('eggleston')), uuid4().bytes + + +class TestTextIO(BaseColumnIOTest): + + column = columns.Text + pkey_val = 'bacon' + data_val = 'monkey' + + +class TestNonBinaryTextIO(BaseColumnIOTest): + + column = columns.Text + pkey_val = 'bacon' + data_val = '0xmonkey' + + +class TestInteger(BaseColumnIOTest): + + column = columns.Integer + pkey_val = 5 + data_val = 6 + + +class TestBigInt(BaseColumnIOTest): + + column = columns.BigInt + pkey_val = 6 + data_val = pow(2, 63) - 1 + + +class TestDateTime(BaseColumnIOTest): + + column = columns.DateTime + + now = datetime(*datetime.now().timetuple()[:6]) + pkey_val = now + data_val = now + timedelta(days=1) + + +class TestUUID(BaseColumnIOTest): + + column = columns.UUID + + pkey_val = str(uuid4()), uuid4() + data_val = 
str(uuid4()), uuid4() + + def comparator_converter(self, val): + return val if isinstance(val, UUID) else UUID(val) + + +class TestTimeUUID(BaseColumnIOTest): + + column = columns.TimeUUID + + pkey_val = str(uuid1()), uuid1() + data_val = str(uuid1()), uuid1() + + def comparator_converter(self, val): + return val if isinstance(val, UUID) else UUID(val) + + +class TestFloatIO(BaseColumnIOTest): + + column = columns.Float + + pkey_val = 4.75 + data_val = -1.5 + + +class TestDoubleIO(BaseColumnIOTest): + + column = columns.Double + + pkey_val = 3.14 + data_val = -1982.11 + + +class TestDecimalIO(BaseColumnIOTest): + + column = columns.Decimal + + pkey_val = Decimal('1.35'), 5, '2.4' + data_val = Decimal('0.005'), 3.5, '8' + + def comparator_converter(self, val): + return Decimal(repr(val) if isinstance(val, float) else val) + + +class ProtocolV4Test(BaseColumnIOTest): + + @classmethod + def setUpClass(cls): + if PROTOCOL_VERSION >= 4: + super(ProtocolV4Test, cls).setUpClass() + + @classmethod + def tearDownClass(cls): + if PROTOCOL_VERSION >= 4: + super(ProtocolV4Test, cls).tearDownClass() + +class TestDate(ProtocolV4Test): + + def setUp(self): + if PROTOCOL_VERSION < 4: + raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) + + super(TestDate, self).setUp() + + column = columns.Date + + now = Date(datetime.now().date()) + pkey_val = now + data_val = Date(now.days_from_epoch + 1) + + +class TestTime(ProtocolV4Test): + + def setUp(self): + if PROTOCOL_VERSION < 4: + raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) + + super(TestTime, self).setUp() + + column = columns.Time + + pkey_val = Time(time(2, 12, 7, 48)) + data_val = Time(time(16, 47, 25, 7)) + + +class TestSmallInt(ProtocolV4Test): + + def setUp(self): + if PROTOCOL_VERSION < 4: + raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) + + super(TestSmallInt, self).setUp() + + column = columns.SmallInt + + pkey_val = 16768 + data_val = 32523 + + +class TestTinyInt(ProtocolV4Test): + + def setUp(self): + if PROTOCOL_VERSION < 4: + raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) + + super(TestTinyInt, self).setUp() + + column = columns.TinyInt + + pkey_val = 1 + data_val = 123 diff --git a/tests/integration/cqlengine/connections/__init__.py b/tests/integration/cqlengine/connections/__init__.py new file mode 100644 index 0000000..2c9ca17 --- /dev/null +++ b/tests/integration/cqlengine/connections/__init__.py @@ -0,0 +1,13 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
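The file that follows exercises cqlengine's connection registry (connection.setup, connection.default, register/unregister and set_session). As a point of reference only, here is a minimal sketch of the default-connection workflow those tests assume; the host address, keyspace name and protocol version are placeholders rather than values taken from this patch::

    from cassandra.cqlengine import connection
    from cassandra.cqlengine.management import create_keyspace_simple

    # Register the implicit "default" connection that models and querysets
    # use when no explicit connection name is given.
    connection.setup(['127.0.0.1'], default_keyspace='example_ks', protocol_version=4)
    create_keyspace_simple('example_ks', replication_factor=1)

Once a default connection is registered this way, sync_table and model CRUD calls route through it, which is the state the connection tests below repeatedly set up and tear down.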
diff --git a/tests/integration/cqlengine/connections/test_connection.py b/tests/integration/cqlengine/connections/test_connection.py new file mode 100644 index 0000000..bbc0231 --- /dev/null +++ b/tests/integration/cqlengine/connections/test_connection.py @@ -0,0 +1,198 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + + +from cassandra import ConsistencyLevel +from cassandra.cqlengine.models import Model +from cassandra.cqlengine import columns, connection, models +from cassandra.cqlengine.management import sync_table +from cassandra.cluster import Cluster, ExecutionProfile, _clusters_for_shutdown, _ConfigMode, EXEC_PROFILE_DEFAULT +from cassandra.policies import RoundRobinPolicy +from cassandra.query import dict_factory + +from tests.integration import CASSANDRA_IP, PROTOCOL_VERSION, execute_with_long_wait_retry, local +from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine import DEFAULT_KEYSPACE, setup_connection + + +class TestConnectModel(Model): + + id = columns.Integer(primary_key=True) + keyspace = columns.Text() + + +class ConnectionTest(unittest.TestCase): + def tearDown(self): + connection.unregister_connection("default") + + @local + def test_connection_setup_with_setup(self): + connection.setup(hosts=None, default_keyspace=None) + self.assertIsNotNone(connection.get_connection("default").cluster.metadata.get_host("127.0.0.1")) + + @local + def test_connection_setup_with_default(self): + connection.default() + self.assertIsNotNone(connection.get_connection("default").cluster.metadata.get_host("127.0.0.1")) + + def test_only_one_connection_is_created(self): + """ + Test to ensure that only one new connection is created by + connection.register_connection + + @since 3.12 + @jira_ticket PYTHON-814 + @expected_result Only one connection is created + + @test_category object_mapper + """ + number_of_clusters_before = len(_clusters_for_shutdown) + connection.default() + number_of_clusters_after = len(_clusters_for_shutdown) + self.assertEqual(number_of_clusters_after - number_of_clusters_before, 1) + + +class SeveralConnectionsTest(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + connection.unregister_connection('default') + cls.keyspace1 = 'ctest1' + cls.keyspace2 = 'ctest2' + super(SeveralConnectionsTest, cls).setUpClass() + cls.setup_cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.setup_session = cls.setup_cluster.connect() + ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '{1}'}}".format(cls.keyspace1, 1) + execute_with_long_wait_retry(cls.setup_session, ddl) + ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '{1}'}}".format(cls.keyspace2, 1) + execute_with_long_wait_retry(cls.setup_session, ddl) + + @classmethod + def tearDownClass(cls): + execute_with_long_wait_retry(cls.setup_session, "DROP KEYSPACE 
{0}".format(cls.keyspace1)) + execute_with_long_wait_retry(cls.setup_session, "DROP KEYSPACE {0}".format(cls.keyspace2)) + models.DEFAULT_KEYSPACE = DEFAULT_KEYSPACE + cls.setup_cluster.shutdown() + setup_connection(DEFAULT_KEYSPACE) + models.DEFAULT_KEYSPACE + + def setUp(self): + self.c = Cluster(protocol_version=PROTOCOL_VERSION) + self.session1 = self.c.connect(keyspace=self.keyspace1) + self.session1.row_factory = dict_factory + self.session2 = self.c.connect(keyspace=self.keyspace2) + self.session2.row_factory = dict_factory + + def tearDown(self): + self.c.shutdown() + + def test_connection_session_switch(self): + """ + Test to ensure that when the default keyspace is changed in a session and that session, + is set in the connection class, that the new defaul keyspace is honored. + + @since 3.1 + @jira_ticket PYTHON-486 + @expected_result CQLENGINE adopts whatever keyspace is passed in vai the set_session method as default + + @test_category object_mapper + """ + + connection.set_session(self.session1) + sync_table(TestConnectModel) + TCM1 = TestConnectModel.create(id=1, keyspace=self.keyspace1) + connection.set_session(self.session2) + sync_table(TestConnectModel) + TCM2 = TestConnectModel.create(id=1, keyspace=self.keyspace2) + connection.set_session(self.session1) + self.assertEqual(1, TestConnectModel.objects.count()) + self.assertEqual(TestConnectModel.objects.first(), TCM1) + connection.set_session(self.session2) + self.assertEqual(1, TestConnectModel.objects.count()) + self.assertEqual(TestConnectModel.objects.first(), TCM2) + + +class ConnectionModel(Model): + key = columns.Integer(primary_key=True) + some_data = columns.Text() + + +class ConnectionInitTest(unittest.TestCase): + def test_default_connection_uses_legacy(self): + connection.default() + conn = connection.get_connection() + self.assertEqual(conn.cluster._config_mode, _ConfigMode.LEGACY) + + def test_connection_with_legacy_settings(self): + connection.setup( + hosts=[CASSANDRA_IP], + default_keyspace=DEFAULT_KEYSPACE, + consistency=ConsistencyLevel.LOCAL_ONE + ) + conn = connection.get_connection() + self.assertEqual(conn.cluster._config_mode, _ConfigMode.LEGACY) + + def test_connection_from_session_with_execution_profile(self): + cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)}) + session = cluster.connect() + connection.default() + connection.set_session(session) + conn = connection.get_connection() + self.assertEqual(conn.cluster._config_mode, _ConfigMode.PROFILES) + + def test_connection_from_session_with_legacy_settings(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy()) + session = cluster.connect() + session.row_factory = dict_factory + connection.set_session(session) + conn = connection.get_connection() + self.assertEqual(conn.cluster._config_mode, _ConfigMode.LEGACY) + + def test_uncommitted_session_uses_legacy(self): + cluster = Cluster() + session = cluster.connect() + session.row_factory = dict_factory + connection.set_session(session) + conn = connection.get_connection() + self.assertEqual(conn.cluster._config_mode, _ConfigMode.LEGACY) + + def test_legacy_insert_query(self): + connection.setup( + hosts=[CASSANDRA_IP], + default_keyspace=DEFAULT_KEYSPACE, + consistency=ConsistencyLevel.LOCAL_ONE + ) + self.assertEqual(connection.get_connection().cluster._config_mode, _ConfigMode.LEGACY) + + sync_table(ConnectionModel) + ConnectionModel.objects.create(key=0, some_data='text0') + ConnectionModel.objects.create(key=1, 
some_data='text1') + self.assertEqual(ConnectionModel.objects(key=0)[0].some_data, 'text0') + + def test_execution_profile_insert_query(self): + cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)}) + session = cluster.connect() + connection.default() + connection.set_session(session) + self.assertEqual(connection.get_connection().cluster._config_mode, _ConfigMode.PROFILES) + + sync_table(ConnectionModel) + ConnectionModel.objects.create(key=0, some_data='text0') + ConnectionModel.objects.create(key=1, some_data='text1') + self.assertEqual(ConnectionModel.objects(key=0)[0].some_data, 'text0') diff --git a/tests/integration/cqlengine/management/__init__.py b/tests/integration/cqlengine/management/__init__.py new file mode 100644 index 0000000..386372e --- /dev/null +++ b/tests/integration/cqlengine/management/__init__.py @@ -0,0 +1,14 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/integration/cqlengine/management/test_compaction_settings.py b/tests/integration/cqlengine/management/test_compaction_settings.py new file mode 100644 index 0000000..d5dea12 --- /dev/null +++ b/tests/integration/cqlengine/management/test_compaction_settings.py @@ -0,0 +1,155 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
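These compaction tests revolve around the __options__ dict on a model: whatever is placed under 'compaction' is rendered into the CREATE TABLE or ALTER TABLE statement that sync_table generates. A minimal sketch of that pattern follows, with an invented model name and illustrative option values, assuming a default connection is already registered::

    from cassandra.cqlengine import columns
    from cassandra.cqlengine.models import Model
    from cassandra.cqlengine.management import sync_table

    class ExampleLCSTable(Model):
        __options__ = {'compaction': {'class': 'LeveledCompactionStrategy',
                                      'sstable_size_in_mb': '64'}}
        row_id = columns.UUID(primary_key=True)

    # Creates the table with "WITH compaction = {...}", or alters an existing
    # table when the declared options differ from what is already in the schema.
    sync_table(ExampleLCSTable)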
+ +import copy +from mock import patch +import six + +from cassandra.cqlengine import columns +from cassandra.cqlengine.management import drop_table, sync_table, _get_table_metadata, _update_options +from cassandra.cqlengine.models import Model + +from tests.integration.cqlengine.base import BaseCassEngTestCase + + +class LeveledCompactionTestTable(Model): + + __options__ = {'compaction': {'class': 'org.apache.cassandra.db.compaction.LeveledCompactionStrategy', + 'sstable_size_in_mb': '64'}} + + user_id = columns.UUID(primary_key=True) + name = columns.Text() + + +class AlterTableTest(BaseCassEngTestCase): + + def test_alter_is_called_table(self): + drop_table(LeveledCompactionTestTable) + sync_table(LeveledCompactionTestTable) + with patch('cassandra.cqlengine.management._update_options') as mock: + sync_table(LeveledCompactionTestTable) + assert mock.called == 1 + + def test_compaction_not_altered_without_changes_leveled(self): + + class LeveledCompactionChangesDetectionTest(Model): + + __options__ = {'compaction': {'class': 'org.apache.cassandra.db.compaction.LeveledCompactionStrategy', + 'sstable_size_in_mb': '160', + 'tombstone_threshold': '0.125', + 'tombstone_compaction_interval': '3600'}} + pk = columns.Integer(primary_key=True) + + drop_table(LeveledCompactionChangesDetectionTest) + sync_table(LeveledCompactionChangesDetectionTest) + + self.assertFalse(_update_options(LeveledCompactionChangesDetectionTest)) + + def test_compaction_not_altered_without_changes_sizetiered(self): + class SizeTieredCompactionChangesDetectionTest(Model): + + __options__ = {'compaction': {'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', + 'bucket_high': '20', + 'bucket_low': '10', + 'max_threshold': '200', + 'min_threshold': '100', + 'min_sstable_size': '1000', + 'tombstone_threshold': '0.125', + 'tombstone_compaction_interval': '3600'}} + pk = columns.Integer(primary_key=True) + + drop_table(SizeTieredCompactionChangesDetectionTest) + sync_table(SizeTieredCompactionChangesDetectionTest) + + self.assertFalse(_update_options(SizeTieredCompactionChangesDetectionTest)) + + def test_alter_actually_alters(self): + tmp = copy.deepcopy(LeveledCompactionTestTable) + drop_table(tmp) + sync_table(tmp) + tmp.__options__ = {'compaction': {'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy'}} + sync_table(tmp) + + table_meta = _get_table_metadata(tmp) + + self.assertRegexpMatches(table_meta.export_as_string(), '.*SizeTieredCompactionStrategy.*') + + def test_alter_options(self): + + class AlterTable(Model): + + __options__ = {'compaction': {'class': 'org.apache.cassandra.db.compaction.LeveledCompactionStrategy', + 'sstable_size_in_mb': '64'}} + user_id = columns.UUID(primary_key=True) + name = columns.Text() + + drop_table(AlterTable) + sync_table(AlterTable) + table_meta = _get_table_metadata(AlterTable) + self.assertRegexpMatches(table_meta.export_as_string(), ".*'sstable_size_in_mb': '64'.*") + AlterTable.__options__['compaction']['sstable_size_in_mb'] = '128' + sync_table(AlterTable) + table_meta = _get_table_metadata(AlterTable) + self.assertRegexpMatches(table_meta.export_as_string(), ".*'sstable_size_in_mb': '128'.*") + + +class OptionsTest(BaseCassEngTestCase): + + def _verify_options(self, table_meta, expected_options): + cql = table_meta.export_as_string() + + for name, value in expected_options.items(): + if isinstance(value, six.string_types): + self.assertIn("%s = '%s'" % (name, value), cql) + else: + start = cql.find("%s = {" % (name,)) + end = 
cql.find('}', start) + for subname, subvalue in value.items(): + attr = "'%s': '%s'" % (subname, subvalue) + found_at = cql.find(attr, start) + self.assertTrue(found_at > start) + self.assertTrue(found_at < end) + + def test_all_size_tiered_options(self): + class AllSizeTieredOptionsModel(Model): + __options__ = {'compaction': {'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', + 'bucket_low': '.3', + 'bucket_high': '2', + 'min_threshold': '2', + 'max_threshold': '64', + 'tombstone_compaction_interval': '86400'}} + + cid = columns.UUID(primary_key=True) + name = columns.Text() + + drop_table(AllSizeTieredOptionsModel) + sync_table(AllSizeTieredOptionsModel) + + table_meta = _get_table_metadata(AllSizeTieredOptionsModel) + self._verify_options(table_meta, AllSizeTieredOptionsModel.__options__) + + def test_all_leveled_options(self): + + class AllLeveledOptionsModel(Model): + __options__ = {'compaction': {'class': 'org.apache.cassandra.db.compaction.LeveledCompactionStrategy', + 'sstable_size_in_mb': '64'}} + + cid = columns.UUID(primary_key=True) + name = columns.Text() + + drop_table(AllLeveledOptionsModel) + sync_table(AllLeveledOptionsModel) + + table_meta = _get_table_metadata(AllLeveledOptionsModel) + self._verify_options(table_meta, AllLeveledOptionsModel.__options__) diff --git a/tests/integration/cqlengine/management/test_management.py b/tests/integration/cqlengine/management/test_management.py new file mode 100644 index 0000000..6b91760 --- /dev/null +++ b/tests/integration/cqlengine/management/test_management.py @@ -0,0 +1,480 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
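The management tests below walk the keyspace and table lifecycle helpers. As rough orientation, the full cycle looks like this sketch; 'demo_ks' and ExampleModel are placeholders, not names used by the tests::

    from cassandra.cqlengine import columns, management
    from cassandra.cqlengine.management import sync_table, drop_table
    from cassandra.cqlengine.models import Model

    class ExampleModel(Model):
        __keyspace__ = 'demo_ks'
        row_id = columns.Integer(primary_key=True)

    management.create_keyspace_simple('demo_ks', replication_factor=1)
    sync_table(ExampleModel)    # CREATE TABLE, or ALTER to add any new columns
    drop_table(ExampleModel)    # dropping twice is harmless, as tested below
    management.drop_keyspace('demo_ks')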
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import mock +import logging +from packaging.version import Version + +from cassandra.cqlengine.connection import get_session, get_cluster +from cassandra.cqlengine import CQLEngineException +from cassandra.cqlengine import management +from cassandra.cqlengine.management import _get_table_metadata, sync_table, drop_table, sync_type +from cassandra.cqlengine.models import Model +from cassandra.cqlengine import columns + +from tests.integration import PROTOCOL_VERSION, greaterthancass20, MockLoggingHandler, CASSANDRA_VERSION +from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine.query.test_queryset import TestModel +from cassandra.cqlengine.usertype import UserType +from tests.integration.cqlengine import DEFAULT_KEYSPACE + + +class KeyspaceManagementTest(BaseCassEngTestCase): + def test_create_drop_succeeeds(self): + cluster = get_cluster() + + keyspace_ss = 'test_ks_ss' + self.assertNotIn(keyspace_ss, cluster.metadata.keyspaces) + management.create_keyspace_simple(keyspace_ss, 2) + self.assertIn(keyspace_ss, cluster.metadata.keyspaces) + + management.drop_keyspace(keyspace_ss) + self.assertNotIn(keyspace_ss, cluster.metadata.keyspaces) + + keyspace_nts = 'test_ks_nts' + self.assertNotIn(keyspace_nts, cluster.metadata.keyspaces) + management.create_keyspace_network_topology(keyspace_nts, {'dc1': 1}) + self.assertIn(keyspace_nts, cluster.metadata.keyspaces) + + management.drop_keyspace(keyspace_nts) + self.assertNotIn(keyspace_nts, cluster.metadata.keyspaces) + + +class DropTableTest(BaseCassEngTestCase): + + def test_multiple_deletes_dont_fail(self): + sync_table(TestModel) + + drop_table(TestModel) + drop_table(TestModel) + + +class LowercaseKeyModel(Model): + + first_key = columns.Integer(primary_key=True) + second_key = columns.Integer(primary_key=True) + some_data = columns.Text() + + +class CapitalizedKeyModel(Model): + + firstKey = columns.Integer(primary_key=True) + secondKey = columns.Integer(primary_key=True) + someData = columns.Text() + + +class PrimaryKeysOnlyModel(Model): + + __table_name__ = "primary_keys_only" + __options__ = {'compaction': {'class': 'LeveledCompactionStrategy'}} + + first_key = columns.Integer(primary_key=True) + second_key = columns.Integer(primary_key=True) + + +class PrimaryKeysModelChanged(Model): + + __table_name__ = "primary_keys_only" + __options__ = {'compaction': {'class': 'LeveledCompactionStrategy'}} + + new_first_key = columns.Integer(primary_key=True) + second_key = columns.Integer(primary_key=True) + + +class PrimaryKeysModelTypeChanged(Model): + + __table_name__ = "primary_keys_only" + __options__ = {'compaction': {'class': 'LeveledCompactionStrategy'}} + + first_key = columns.Float(primary_key=True) + second_key = columns.Integer(primary_key=True) + + +class PrimaryKeysRemovedPk(Model): + + __table_name__ = "primary_keys_only" + __options__ = {'compaction': {'class': 'LeveledCompactionStrategy'}} + + second_key = columns.Integer(primary_key=True) + + +class PrimaryKeysAddedClusteringKey(Model): + + __table_name__ = "primary_keys_only" + __options__ = {'compaction': {'class': 'LeveledCompactionStrategy'}} + + new_first_key = columns.Float(primary_key=True) + second_key = columns.Integer(primary_key=True) + + +class CapitalizedKeyTest(BaseCassEngTestCase): + + def test_table_definition(self): + """ Tests that creating a table with capitalized column names succeeds """ + sync_table(LowercaseKeyModel) + 
sync_table(CapitalizedKeyModel) + + drop_table(LowercaseKeyModel) + drop_table(CapitalizedKeyModel) + + +class FirstModel(Model): + + __table_name__ = 'first_model' + first_key = columns.UUID(primary_key=True) + second_key = columns.UUID() + third_key = columns.Text() + + +class SecondModel(Model): + + __table_name__ = 'first_model' + first_key = columns.UUID(primary_key=True) + second_key = columns.UUID() + third_key = columns.Text() + fourth_key = columns.Text() + + +class ThirdModel(Model): + + __table_name__ = 'first_model' + first_key = columns.UUID(primary_key=True) + second_key = columns.UUID() + third_key = columns.Text() + # removed fourth key, but it should stay in the DB + blah = columns.Map(columns.Text, columns.Text) + + +class FourthModel(Model): + + __table_name__ = 'first_model' + first_key = columns.UUID(primary_key=True) + second_key = columns.UUID() + third_key = columns.Text() + # renamed model field, but map to existing column + renamed = columns.Map(columns.Text, columns.Text, db_field='blah') + + +class AddColumnTest(BaseCassEngTestCase): + def setUp(self): + drop_table(FirstModel) + + def test_add_column(self): + sync_table(FirstModel) + meta_columns = _get_table_metadata(FirstModel).columns + self.assertEqual(set(meta_columns), set(FirstModel._columns)) + + sync_table(SecondModel) + meta_columns = _get_table_metadata(FirstModel).columns + self.assertEqual(set(meta_columns), set(SecondModel._columns)) + + sync_table(ThirdModel) + meta_columns = _get_table_metadata(FirstModel).columns + self.assertEqual(len(meta_columns), 5) + self.assertEqual(len(ThirdModel._columns), 4) + self.assertIn('fourth_key', meta_columns) + self.assertNotIn('fourth_key', ThirdModel._columns) + self.assertIn('blah', ThirdModel._columns) + self.assertIn('blah', meta_columns) + + sync_table(FourthModel) + meta_columns = _get_table_metadata(FirstModel).columns + self.assertEqual(len(meta_columns), 5) + self.assertEqual(len(ThirdModel._columns), 4) + self.assertIn('fourth_key', meta_columns) + self.assertNotIn('fourth_key', FourthModel._columns) + self.assertIn('renamed', FourthModel._columns) + self.assertNotIn('renamed', meta_columns) + self.assertIn('blah', meta_columns) + + +class ModelWithTableProperties(Model): + + __options__ = {'bloom_filter_fp_chance': '0.76328', + 'comment': 'TxfguvBdzwROQALmQBOziRMbkqVGFjqcJfVhwGR', + 'gc_grace_seconds': '2063', + 'read_repair_chance': '0.17985', + 'dclocal_read_repair_chance': '0.50811'} + + key = columns.UUID(primary_key=True) + + +class TablePropertiesTests(BaseCassEngTestCase): + + def setUp(self): + drop_table(ModelWithTableProperties) + + def test_set_table_properties(self): + + sync_table(ModelWithTableProperties) + expected = {'bloom_filter_fp_chance': 0.76328, + 'comment': 'TxfguvBdzwROQALmQBOziRMbkqVGFjqcJfVhwGR', + 'gc_grace_seconds': 2063, + 'read_repair_chance': 0.17985, + # For some reason 'dclocal_read_repair_chance' in CQL is called + # just 'local_read_repair_chance' in the schema table. 
+ # Source: https://issues.apache.org/jira/browse/CASSANDRA-6717 + # TODO: due to a bug in the native driver i'm not seeing the local read repair chance show up + # 'local_read_repair_chance': 0.50811, + } + options = management._get_table_metadata(ModelWithTableProperties).options + self.assertEqual(dict([(k, options.get(k)) for k in expected.keys()]), + expected) + + def test_table_property_update(self): + ModelWithTableProperties.__options__['bloom_filter_fp_chance'] = 0.66778 + ModelWithTableProperties.__options__['comment'] = 'xirAkRWZVVvsmzRvXamiEcQkshkUIDINVJZgLYSdnGHweiBrAiJdLJkVohdRy' + ModelWithTableProperties.__options__['gc_grace_seconds'] = 96362 + + ModelWithTableProperties.__options__['read_repair_chance'] = 0.2989 + ModelWithTableProperties.__options__['dclocal_read_repair_chance'] = 0.12732 + + sync_table(ModelWithTableProperties) + + table_options = management._get_table_metadata(ModelWithTableProperties).options + + self.assertDictContainsSubset(ModelWithTableProperties.__options__, table_options) + + def test_bogus_option_update(self): + sync_table(ModelWithTableProperties) + option = 'no way will this ever be an option' + try: + ModelWithTableProperties.__options__[option] = 'what was I thinking?' + self.assertRaisesRegexp(KeyError, "Invalid table option.*%s.*" % option, sync_table, ModelWithTableProperties) + finally: + ModelWithTableProperties.__options__.pop(option, None) + + +class SyncTableTests(BaseCassEngTestCase): + + def setUp(self): + drop_table(PrimaryKeysOnlyModel) + + def test_sync_table_works_with_primary_keys_only_tables(self): + + sync_table(PrimaryKeysOnlyModel) + # blows up with DoesNotExist if table does not exist + table_meta = management._get_table_metadata(PrimaryKeysOnlyModel) + + self.assertIn('LeveledCompactionStrategy', table_meta.as_cql_query()) + + PrimaryKeysOnlyModel.__options__['compaction']['class'] = 'SizeTieredCompactionStrategy' + + sync_table(PrimaryKeysOnlyModel) + + table_meta = management._get_table_metadata(PrimaryKeysOnlyModel) + self.assertIn('SizeTieredCompactionStrategy', table_meta.as_cql_query()) + + def test_primary_key_validation(self): + """ + Test to ensure that changes to primary keys throw CQLEngineExceptions + + @since 3.2 + @jira_ticket PYTHON-532 + @expected_result Attempts to modify primary keys throw an exception + + @test_category object_mapper + """ + sync_table(PrimaryKeysOnlyModel) + self.assertRaises(CQLEngineException, sync_table, PrimaryKeysModelChanged) + self.assertRaises(CQLEngineException, sync_table, PrimaryKeysAddedClusteringKey) + self.assertRaises(CQLEngineException, sync_table, PrimaryKeysRemovedPk) + + +class IndexModel(Model): + + __table_name__ = 'index_model' + first_key = columns.UUID(primary_key=True) + second_key = columns.Text(index=True) + + +class IndexCaseSensitiveModel(Model): + + __table_name__ = 'IndexModel' + __table_name_case_sensitive__ = True + first_key = columns.UUID(primary_key=True) + second_key = columns.Text(index=True) + + +class BaseInconsistent(Model): + + __table_name__ = 'inconsistent' + first_key = columns.UUID(primary_key=True) + second_key = columns.Integer(index=True) + third_key = columns.Integer(index=True) + + +class ChangedInconsistent(Model): + + __table_name__ = 'inconsistent' + __table_name_case_sensitive__ = True + first_key = columns.UUID(primary_key=True) + second_key = columns.Text(index=True) + + +class BaseInconsistentType(UserType): + __type_name__ = 'type_inconsistent' + age = columns.Integer() + name = columns.Text() + + +class 
ChangedInconsistentType(UserType): + __type_name__ = 'type_inconsistent' + age = columns.Integer() + name = columns.Integer() + + +class InconsistentTable(BaseCassEngTestCase): + + def setUp(self): + drop_table(IndexModel) + + def test_sync_warnings(self): + """ + Test to insure when inconsistent changes are made to a table, or type as part of a sync call that the proper logging messages are surfaced + + @since 3.2 + @jira_ticket PYTHON-260 + @expected_result warnings are logged + + @test_category object_mapper + """ + mock_handler = MockLoggingHandler() + logger = logging.getLogger(management.__name__) + logger.addHandler(mock_handler) + sync_table(BaseInconsistent) + sync_table(ChangedInconsistent) + self.assertTrue('differing from the model type' in mock_handler.messages.get('warning')[0]) + if CASSANDRA_VERSION >= Version('2.1'): + sync_type(DEFAULT_KEYSPACE, BaseInconsistentType) + mock_handler.reset() + sync_type(DEFAULT_KEYSPACE, ChangedInconsistentType) + self.assertTrue('differing from the model user type' in mock_handler.messages.get('warning')[0]) + logger.removeHandler(mock_handler) + + +class TestIndexSetModel(Model): + partition = columns.UUID(primary_key=True) + int_set = columns.Set(columns.Integer, index=True) + int_list = columns.List(columns.Integer, index=True) + text_map = columns.Map(columns.Text, columns.DateTime, index=True) + mixed_tuple = columns.Tuple(columns.Text, columns.Integer, columns.Text, index=True) + + +class IndexTests(BaseCassEngTestCase): + + def setUp(self): + drop_table(IndexModel) + drop_table(IndexCaseSensitiveModel) + + def test_sync_index(self): + """ + Tests the default table creation, and ensures the table_name is created and surfaced correctly + in the table metadata + + @since 3.1 + @jira_ticket PYTHON-337 + @expected_result table_name is lower case + + @test_category object_mapper + """ + sync_table(IndexModel) + table_meta = management._get_table_metadata(IndexModel) + self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'second_key')) + + # index already exists + sync_table(IndexModel) + table_meta = management._get_table_metadata(IndexModel) + self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'second_key')) + + def test_sync_index_case_sensitive(self): + """ + Tests the default table creation, and ensures the table_name is created correctly and surfaced correctly + in table metadata + + @since 3.1 + @jira_ticket PYTHON-337 + @expected_result table_name is lower case + + @test_category object_mapper + """ + sync_table(IndexCaseSensitiveModel) + table_meta = management._get_table_metadata(IndexCaseSensitiveModel) + self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'second_key')) + + # index already exists + sync_table(IndexCaseSensitiveModel) + table_meta = management._get_table_metadata(IndexCaseSensitiveModel) + self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'second_key')) + + @greaterthancass20 + def test_sync_indexed_set(self): + """ + Tests that models that have container types with indices can be synced. + + @since 3.2 + @jira_ticket PYTHON-533 + @expected_result table_sync should complete without a server error. 
+ + @test_category object_mapper + """ + sync_table(TestIndexSetModel) + table_meta = management._get_table_metadata(TestIndexSetModel) + self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'int_set')) + self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'int_list')) + self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'text_map')) + self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'mixed_tuple')) + + +class NonModelFailureTest(BaseCassEngTestCase): + class FakeModel(object): + pass + + def test_failure(self): + with self.assertRaises(CQLEngineException): + sync_table(self.FakeModel) + + +class StaticColumnTests(BaseCassEngTestCase): + def test_static_columns(self): + if PROTOCOL_VERSION < 2: + raise unittest.SkipTest("Native protocol 2+ required, currently using: {0}".format(PROTOCOL_VERSION)) + + class StaticModel(Model): + id = columns.Integer(primary_key=True) + c = columns.Integer(primary_key=True) + name = columns.Text(static=True) + + drop_table(StaticModel) + + session = get_session() + + with mock.patch.object(session, "execute", wraps=session.execute) as m: + sync_table(StaticModel) + + self.assertGreater(m.call_count, 0) + statement = m.call_args[0][0].query_string + self.assertIn('"name" text static', statement) + + # if we sync again, we should not apply an alter w/ a static + sync_table(StaticModel) + + with mock.patch.object(session, "execute", wraps=session.execute) as m2: + sync_table(StaticModel) + + self.assertEqual(len(m2.call_args_list), 0) diff --git a/tests/integration/cqlengine/model/__init__.py b/tests/integration/cqlengine/model/__init__.py new file mode 100644 index 0000000..386372e --- /dev/null +++ b/tests/integration/cqlengine/model/__init__.py @@ -0,0 +1,14 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/integration/cqlengine/model/test_class_construction.py b/tests/integration/cqlengine/model/test_class_construction.py new file mode 100644 index 0000000..9c5afec --- /dev/null +++ b/tests/integration/cqlengine/model/test_class_construction.py @@ -0,0 +1,425 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
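The class-construction tests that follow assert on what the Model metaclass does at class-definition time. In short, and with invented names, the behaviour they check looks like this sketch (no connection is needed just to define a model)::

    from uuid import uuid4
    from cassandra.cqlengine import columns
    from cassandra.cqlengine.models import Model

    class Person(Model):
        person_id = columns.UUID(primary_key=True, default=uuid4)
        nickname = columns.Text(db_field='nick')

    # Column objects are collected into _columns (definition order preserved)
    # and mapped back from their db_field names via _db_map.
    assert list(Person._columns) == ['person_id', 'nickname']
    assert Person._db_map['nick'] == 'nickname'

    # On an instance, column attributes become plain values; defaults apply
    # immediately at instantiation, before any save.
    p = Person(nickname='kev')
    assert p.person_id is not None and p.nickname == 'kev'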
+ +from uuid import uuid4 +import warnings + +from cassandra.cqlengine import columns, CQLEngineException +from cassandra.cqlengine.models import Model, ModelException, ModelDefinitionException, ColumnQueryEvaluator +from cassandra.cqlengine.query import ModelQuerySet, DMLQuery + +from tests.integration.cqlengine.base import BaseCassEngTestCase + + +class TestModelClassFunction(BaseCassEngTestCase): + """ + Tests verifying the behavior of the Model metaclass + """ + + def test_column_attributes_handled_correctly(self): + """ + Tests that column attributes are moved to a _columns dict + and replaced with simple value attributes + """ + + class TestModel(Model): + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + text = columns.Text() + + # check class attibutes + self.assertHasAttr(TestModel, '_columns') + self.assertHasAttr(TestModel, 'id') + self.assertHasAttr(TestModel, 'text') + + # check instance attributes + inst = TestModel() + self.assertHasAttr(inst, 'id') + self.assertHasAttr(inst, 'text') + self.assertIsNotNone(inst.id) + self.assertIsNone(inst.text) + + def test_values_on_instantiation(self): + """ + Tests defaults and user-provided values on instantiation. + """ + + class TestPerson(Model): + first_name = columns.Text(primary_key=True, default='kevin') + last_name = columns.Text(default='deldycke') + + # Check that defaults are available at instantiation. + inst1 = TestPerson() + self.assertHasAttr(inst1, 'first_name') + self.assertHasAttr(inst1, 'last_name') + self.assertEqual(inst1.first_name, 'kevin') + self.assertEqual(inst1.last_name, 'deldycke') + + # Check that values on instantiation overrides defaults. + inst2 = TestPerson(first_name='bob', last_name='joe') + self.assertEqual(inst2.first_name, 'bob') + self.assertEqual(inst2.last_name, 'joe') + + def test_db_map(self): + """ + Tests that the db_map is properly defined + -the db_map allows columns + """ + class WildDBNames(Model): + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + content = columns.Text(db_field='words_and_whatnot') + numbers = columns.Integer(db_field='integers_etc') + + db_map = WildDBNames._db_map + self.assertEqual(db_map['words_and_whatnot'], 'content') + self.assertEqual(db_map['integers_etc'], 'numbers') + + def test_attempting_to_make_duplicate_column_names_fails(self): + """ + Tests that trying to create conflicting db column names will fail + """ + + with self.assertRaisesRegexp(ModelException, r".*more than once$"): + class BadNames(Model): + words = columns.Text(primary_key=True) + content = columns.Text(db_field='words') + + def test_column_ordering_is_preserved(self): + """ + Tests that the _columns dics retains the ordering of the class definition + """ + + class Stuff(Model): + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + words = columns.Text() + content = columns.Text() + numbers = columns.Integer() + + self.assertEqual([x for x in Stuff._columns.keys()], ['id', 'words', 'content', 'numbers']) + + def test_exception_raised_when_creating_class_without_pk(self): + with self.assertRaises(ModelDefinitionException): + class TestModel(Model): + + count = columns.Integer() + text = columns.Text(required=False) + + def test_value_managers_are_keeping_model_instances_isolated(self): + """ + Tests that instance value managers are isolated from other instances + """ + class Stuff(Model): + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + num = columns.Integer() + + inst1 = Stuff(num=5) + inst2 = Stuff(num=7) + + 
self.assertNotEqual(inst1.num, inst2.num) + self.assertEqual(inst1.num, 5) + self.assertEqual(inst2.num, 7) + + def test_superclass_fields_are_inherited(self): + """ + Tests that fields defined on the super class are inherited properly + """ + class TestModel(Model): + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + text = columns.Text() + + class InheritedModel(TestModel): + numbers = columns.Integer() + + assert 'text' in InheritedModel._columns + assert 'numbers' in InheritedModel._columns + + def test_column_family_name_generation(self): + """ Tests that auto column family name generation works as expected """ + class TestModel(Model): + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + text = columns.Text() + + assert TestModel.column_family_name(include_keyspace=False) == 'test_model' + + def test_partition_keys(self): + """ + Test compound partition key definition + """ + class ModelWithPartitionKeys(Model): + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + c1 = columns.Text(primary_key=True) + p1 = columns.Text(partition_key=True) + p2 = columns.Text(partition_key=True) + + cols = ModelWithPartitionKeys._columns + + self.assertTrue(cols['c1'].primary_key) + self.assertFalse(cols['c1'].partition_key) + + self.assertTrue(cols['p1'].primary_key) + self.assertTrue(cols['p1'].partition_key) + self.assertTrue(cols['p2'].primary_key) + self.assertTrue(cols['p2'].partition_key) + + obj = ModelWithPartitionKeys(p1='a', p2='b') + self.assertEqual(obj.pk, ('a', 'b')) + + def test_del_attribute_is_assigned_properly(self): + """ Tests that columns that can be deleted have the del attribute """ + class DelModel(Model): + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + key = columns.Integer(primary_key=True) + data = columns.Integer(required=False) + + model = DelModel(key=4, data=5) + del model.data + with self.assertRaises(AttributeError): + del model.key + + def test_does_not_exist_exceptions_are_not_shared_between_model(self): + """ Tests that DoesNotExist exceptions are not the same exception between models """ + + class Model1(Model): + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + + class Model2(Model): + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + + try: + raise Model1.DoesNotExist + except Model2.DoesNotExist: + assert False, "Model1 exception should not be caught by Model2" + except Model1.DoesNotExist: + # expected + pass + + def test_does_not_exist_inherits_from_superclass(self): + """ Tests that a DoesNotExist exception can be caught by it's parent class DoesNotExist """ + class Model1(Model): + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + + class Model2(Model1): + pass + + try: + raise Model2.DoesNotExist + except Model1.DoesNotExist: + # expected + pass + except Exception: + assert False, "Model2 exception should not be caught by Model1" + + def test_abstract_model_keyspace_warning_is_skipped(self): + with warnings.catch_warnings(record=True) as warn: + class NoKeyspace(Model): + __abstract__ = True + key = columns.UUID(primary_key=True) + + self.assertEqual(len(warn), 0) + + +class TestManualTableNaming(BaseCassEngTestCase): + + class RenamedTest(Model): + __keyspace__ = 'whatever' + __table_name__ = 'manual_name' + + id = columns.UUID(primary_key=True) + data = columns.Text() + + def test_proper_table_naming(self): + assert self.RenamedTest.column_family_name(include_keyspace=False) == 'manual_name' + assert 
self.RenamedTest.column_family_name(include_keyspace=True) == 'whatever.manual_name' + + +class TestManualTableNamingCaseSensitive(BaseCassEngTestCase): + + class RenamedCaseInsensitiveTest(Model): + __keyspace__ = 'whatever' + __table_name__ = 'Manual_Name' + + id = columns.UUID(primary_key=True) + + class RenamedCaseSensitiveTest(Model): + __keyspace__ = 'whatever' + __table_name__ = 'Manual_Name' + __table_name_case_sensitive__ = True + + id = columns.UUID(primary_key=True) + + def test_proper_table_naming_case_insensitive(self): + """ + Test to ensure case senstivity is not honored by default honored + + @since 3.1 + @jira_ticket PYTHON-337 + @expected_result table_names arel lowercase + + @test_category object_mapper + """ + self.assertEqual(self.RenamedCaseInsensitiveTest.column_family_name(include_keyspace=False), 'manual_name') + self.assertEqual(self.RenamedCaseInsensitiveTest.column_family_name(include_keyspace=True), 'whatever.manual_name') + + def test_proper_table_naming_case_sensitive(self): + """ + Test to ensure case is honored when the flag is correctly set. + + @since 3.1 + @jira_ticket PYTHON-337 + @expected_result table_name case is honored. + + @test_category object_mapper + """ + + self.assertEqual(self.RenamedCaseSensitiveTest.column_family_name(include_keyspace=False), '"Manual_Name"') + self.assertEqual(self.RenamedCaseSensitiveTest.column_family_name(include_keyspace=True), 'whatever."Manual_Name"') + + +class AbstractModel(Model): + __abstract__ = True + + +class ConcreteModel(AbstractModel): + pkey = columns.Integer(primary_key=True) + data = columns.Integer() + + +class AbstractModelWithCol(Model): + + __abstract__ = True + pkey = columns.Integer(primary_key=True) + + +class ConcreteModelWithCol(AbstractModelWithCol): + data = columns.Integer() + + +class AbstractModelWithFullCols(Model): + __abstract__ = True + + pkey = columns.Integer(primary_key=True) + data = columns.Integer() + + +class TestAbstractModelClasses(BaseCassEngTestCase): + + def test_id_field_is_not_created(self): + """ Tests that an id field is not automatically generated on abstract classes """ + assert not hasattr(AbstractModel, 'id') + assert not hasattr(AbstractModelWithCol, 'id') + + def test_id_field_is_not_created_on_subclass(self): + assert not hasattr(ConcreteModel, 'id') + + def test_abstract_attribute_is_not_inherited(self): + """ Tests that __abstract__ attribute is not inherited """ + assert not ConcreteModel.__abstract__ + assert not ConcreteModelWithCol.__abstract__ + + def test_attempting_to_save_abstract_model_fails(self): + """ Attempting to save a model from an abstract model should fail """ + with self.assertRaises(CQLEngineException): + AbstractModelWithFullCols.create(pkey=1, data=2) + + def test_attempting_to_create_abstract_table_fails(self): + """ Attempting to create a table from an abstract model should fail """ + from cassandra.cqlengine.management import sync_table + with self.assertRaises(CQLEngineException): + sync_table(AbstractModelWithFullCols) + + def test_attempting_query_on_abstract_model_fails(self): + """ Tests attempting to execute query with an abstract model fails """ + with self.assertRaises(CQLEngineException): + iter(AbstractModelWithFullCols.objects(pkey=5)).next() + + def test_abstract_columns_are_inherited(self): + """ Tests that columns defined in the abstract class are inherited into the concrete class """ + assert hasattr(ConcreteModelWithCol, 'pkey') + assert isinstance(ConcreteModelWithCol.pkey, ColumnQueryEvaluator) + assert 
isinstance(ConcreteModelWithCol._columns['pkey'], columns.Column) + + def test_concrete_class_table_creation_cycle(self): + """ Tests that models with inherited abstract classes can be created, and have io performed """ + from cassandra.cqlengine.management import sync_table, drop_table + sync_table(ConcreteModelWithCol) + + w1 = ConcreteModelWithCol.create(pkey=5, data=6) + w2 = ConcreteModelWithCol.create(pkey=6, data=7) + + r1 = ConcreteModelWithCol.get(pkey=5) + r2 = ConcreteModelWithCol.get(pkey=6) + + assert w1.pkey == r1.pkey + assert w1.data == r1.data + assert w2.pkey == r2.pkey + assert w2.data == r2.data + + drop_table(ConcreteModelWithCol) + + +class TestCustomQuerySet(BaseCassEngTestCase): + """ Tests overriding the default queryset class """ + + class TestException(Exception): pass + + def test_overriding_queryset(self): + + class QSet(ModelQuerySet): + def create(iself, **kwargs): + raise self.TestException + + class CQModel(Model): + __queryset__ = QSet + + part = columns.UUID(primary_key=True) + data = columns.Text() + + with self.assertRaises(self.TestException): + CQModel.create(part=uuid4(), data='s') + + def test_overriding_dmlqueryset(self): + + class DMLQ(DMLQuery): + def save(iself): + raise self.TestException + + class CDQModel(Model): + + __dmlquery__ = DMLQ + part = columns.UUID(primary_key=True) + data = columns.Text() + + with self.assertRaises(self.TestException): + CDQModel().save() + + +class TestCachedLengthIsNotCarriedToSubclasses(BaseCassEngTestCase): + def test_subclassing(self): + + length = len(ConcreteModelWithCol()) + + class AlreadyLoadedTest(ConcreteModelWithCol): + new_field = columns.Integer() + + self.assertGreater(len(AlreadyLoadedTest()), length) diff --git a/tests/integration/cqlengine/model/test_equality_operations.py b/tests/integration/cqlengine/model/test_equality_operations.py new file mode 100644 index 0000000..3b40ed4 --- /dev/null +++ b/tests/integration/cqlengine/model/test_equality_operations.py @@ -0,0 +1,68 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
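The abstract-model tests above rely on ``__abstract__`` base classes that contribute columns without ever mapping to a table themselves. A minimal sketch under assumed class names (not from this changeset), with an existing cqlengine connection taken as given::

    # Sketch only: class names are placeholders; assumes connection.setup() was already called.
    from cassandra.cqlengine import columns, CQLEngineException
    from cassandra.cqlengine.management import sync_table
    from cassandra.cqlengine.models import Model

    class AuditedBase(Model):
        __abstract__ = True                       # no table is created for this class
        pkey = columns.Integer(primary_key=True)  # still inherited by concrete subclasses

    class AuditedEvent(AuditedBase):              # concrete: __abstract__ is not inherited
        data = columns.Integer()

    sync_table(AuditedEvent)                      # creates the table with pkey + data
    AuditedEvent.create(pkey=1, data=2)

    try:
        sync_table(AuditedBase)                   # abstract models cannot be synced or saved
    except CQLEngineException:
        pass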
+ +from uuid import uuid4 +from tests.integration.cqlengine.base import BaseCassEngTestCase + +from cassandra.cqlengine.management import sync_table +from cassandra.cqlengine.management import drop_table +from cassandra.cqlengine.models import Model +from cassandra.cqlengine import columns + +class TestModel(Model): + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + count = columns.Integer() + text = columns.Text(required=False) + +class TestEqualityOperators(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestEqualityOperators, cls).setUpClass() + sync_table(TestModel) + + def setUp(self): + super(TestEqualityOperators, self).setUp() + self.t0 = TestModel.create(count=5, text='words') + self.t1 = TestModel.create(count=5, text='words') + + @classmethod + def tearDownClass(cls): + super(TestEqualityOperators, cls).tearDownClass() + drop_table(TestModel) + + def test_an_instance_evaluates_as_equal_to_itself(self): + """ + """ + assert self.t0 == self.t0 + + def test_two_instances_referencing_the_same_rows_and_different_values_evaluate_not_equal(self): + """ + """ + t0 = TestModel.get(id=self.t0.id) + t0.text = 'bleh' + assert t0 != self.t0 + + def test_two_instances_referencing_the_same_rows_and_values_evaluate_equal(self): + """ + """ + t0 = TestModel.get(id=self.t0.id) + assert t0 == self.t0 + + def test_two_instances_referencing_different_rows_evaluate_to_not_equal(self): + """ + """ + assert self.t0 != self.t1 + diff --git a/tests/integration/cqlengine/model/test_model.py b/tests/integration/cqlengine/model/test_model.py new file mode 100644 index 0000000..81de0ea --- /dev/null +++ b/tests/integration/cqlengine/model/test_model.py @@ -0,0 +1,268 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
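The ``TestEqualityOperators`` cases above compare model instances by their column values rather than by identity. A short sketch of that behavior with a hypothetical model (not one from this changeset); no live cluster is needed because nothing is saved::

    # Sketch only: the Note model is hypothetical; comparison happens purely in memory.
    from uuid import uuid4

    from cassandra.cqlengine import columns
    from cassandra.cqlengine.models import Model

    class Note(Model):
        id = columns.UUID(primary_key=True, default=uuid4)
        text = columns.Text()

    row_id = uuid4()
    a = Note(id=row_id, text='words')
    b = Note(id=row_id, text='words')
    assert a == b      # same key and same column values
    b.text = 'bleh'
    assert a != b      # values diverged, instances no longer compare equal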
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from mock import patch + +from cassandra.cqlengine import columns, CQLEngineException +from cassandra.cqlengine.management import sync_table, drop_table, create_keyspace_simple, drop_keyspace +from cassandra.cqlengine import models +from cassandra.cqlengine.models import Model, ModelDefinitionException +from uuid import uuid1 +from tests.integration import pypy +from tests.integration.cqlengine.base import TestQueryUpdateModel + +class TestModel(unittest.TestCase): + """ Tests the non-io functionality of models """ + + def test_instance_equality(self): + """ tests the model equality functionality """ + class EqualityModel(Model): + + pk = columns.Integer(primary_key=True) + + m0 = EqualityModel(pk=0) + m1 = EqualityModel(pk=1) + + self.assertEqual(m0, m0) + self.assertNotEqual(m0, m1) + + def test_model_equality(self): + """ tests the model equality functionality """ + class EqualityModel0(Model): + + pk = columns.Integer(primary_key=True) + + class EqualityModel1(Model): + + kk = columns.Integer(primary_key=True) + + m0 = EqualityModel0(pk=0) + m1 = EqualityModel1(kk=1) + + self.assertEqual(m0, m0) + self.assertNotEqual(m0, m1) + + def test_keywords_as_names(self): + """ + Test for CQL keywords as names + + test_keywords_as_names tests that CQL keywords are properly and automatically quoted in cqlengine. It creates + a keyspace, keyspace, which should be automatically quoted to "keyspace" in CQL. It then creates a table, table, + which should also be automatically quoted to "table". It then verfies that operations can be done on the + "keyspace"."table" which has been created. It also verifies that table alternations work and operations can be + performed on the altered table. + + @since 2.6.0 + @jira_ticket PYTHON-244 + @expected_result Cqlengine should quote CQL keywords properly when creating keyspaces and tables. 
+ + @test_category schema:generation + """ + + # If the keyspace exists, it will not be re-created + create_keyspace_simple('keyspace', 1) + + class table(Model): + __keyspace__ = 'keyspace' + select = columns.Integer(primary_key=True) + table = columns.Text() + + # In case the table already exists in keyspace + drop_table(table) + + # Create should work + sync_table(table) + + created = table.create(select=0, table='table') + selected = table.objects(select=0)[0] + self.assertEqual(created.select, selected.select) + self.assertEqual(created.table, selected.table) + + # Alter should work + class table(Model): + __keyspace__ = 'keyspace' + select = columns.Integer(primary_key=True) + table = columns.Text() + where = columns.Text() + + sync_table(table) + + created = table.create(select=1, table='table') + selected = table.objects(select=1)[0] + self.assertEqual(created.select, selected.select) + self.assertEqual(created.table, selected.table) + self.assertEqual(created.where, selected.where) + + drop_keyspace('keyspace') + + def test_column_family(self): + class TestModel(Model): + k = columns.Integer(primary_key=True) + + # no model keyspace uses default + self.assertEqual(TestModel.column_family_name(), "%s.test_model" % (models.DEFAULT_KEYSPACE,)) + + # model keyspace overrides + TestModel.__keyspace__ = "my_test_keyspace" + self.assertEqual(TestModel.column_family_name(), "%s.test_model" % (TestModel.__keyspace__,)) + + # neither set should raise CQLEngineException before failing or formatting an invalid name + del TestModel.__keyspace__ + with patch('cassandra.cqlengine.models.DEFAULT_KEYSPACE', None): + self.assertRaises(CQLEngineException, TestModel.column_family_name) + # .. but we can still get the bare CF name + self.assertEqual(TestModel.column_family_name(include_keyspace=False), "test_model") + + def test_column_family_case_sensitive(self): + """ + Test to ensure case sensitivity is honored when __table_name_case_sensitive__ flag is set + + @since 3.1 + @jira_ticket PYTHON-337 + @expected_result table_name case is respected + + @test_category object_mapper + """ + class TestModel(Model): + __table_name__ = 'TestModel' + __table_name_case_sensitive__ = True + + k = columns.Integer(primary_key=True) + + self.assertEqual(TestModel.column_family_name(), '%s."TestModel"' % (models.DEFAULT_KEYSPACE,)) + + TestModel.__keyspace__ = "my_test_keyspace" + self.assertEqual(TestModel.column_family_name(), '%s."TestModel"' % (TestModel.__keyspace__,)) + + del TestModel.__keyspace__ + with patch('cassandra.cqlengine.models.DEFAULT_KEYSPACE', None): + self.assertRaises(CQLEngineException, TestModel.column_family_name) + self.assertEqual(TestModel.column_family_name(include_keyspace=False), '"TestModel"') + + +class BuiltInAttributeConflictTest(unittest.TestCase): + """tests Model definitions that conflict with built-in attributes/methods""" + + def test_model_with_attribute_name_conflict(self): + """should raise exception when model defines column that conflicts with built-in attribute""" + with self.assertRaises(ModelDefinitionException): + class IllegalTimestampColumnModel(Model): + + my_primary_key = columns.Integer(primary_key=True) + timestamp = columns.BigInt() + + def test_model_with_method_name_conflict(self): + """should raise exception when model defines column that conflicts with built-in method""" + with self.assertRaises(ModelDefinitionException): + class IllegalFilterColumnModel(Model): + + my_primary_key = columns.Integer(primary_key=True) + filter = columns.Text() + + +class 
ModelOverWriteTest(unittest.TestCase): + + def test_model_over_write(self): + """ + Test to ensure overwriting of primary keys in model inheritance is allowed + + This is currently only an issue in PyPy. When PYTHON-504 is introduced this should + be updated error out and warn the user + + @since 3.6.0 + @jira_ticket PYTHON-576 + @expected_result primary keys can be overwritten via inheritance + + @test_category object_mapper + """ + class TimeModelBase(Model): + uuid = columns.TimeUUID(primary_key=True) + + class DerivedTimeModel(TimeModelBase): + __table_name__ = 'derived_time' + uuid = columns.TimeUUID(primary_key=True, partition_key=True) + value = columns.Text(required=False) + + # In case the table already exists in keyspace + drop_table(DerivedTimeModel) + + sync_table(DerivedTimeModel) + uuid_value = uuid1() + uuid_value2 = uuid1() + DerivedTimeModel.create(uuid=uuid_value, value="first") + DerivedTimeModel.create(uuid=uuid_value2, value="second") + DerivedTimeModel.objects.filter(uuid=uuid_value) + + +class TestColumnComparison(unittest.TestCase): + def test_comparison(self): + l = [TestQueryUpdateModel.partition.column, + TestQueryUpdateModel.cluster.column, + TestQueryUpdateModel.count.column, + TestQueryUpdateModel.text.column, + TestQueryUpdateModel.text_set.column, + TestQueryUpdateModel.text_list.column, + TestQueryUpdateModel.text_map.column] + + self.assertEqual(l, sorted(l)) + self.assertNotEqual(TestQueryUpdateModel.partition.column, TestQueryUpdateModel.cluster.column) + self.assertLessEqual(TestQueryUpdateModel.partition.column, TestQueryUpdateModel.cluster.column) + self.assertGreater(TestQueryUpdateModel.cluster.column, TestQueryUpdateModel.partition.column) + self.assertGreaterEqual(TestQueryUpdateModel.cluster.column, TestQueryUpdateModel.partition.column) + + +class TestDeprecationWarning(unittest.TestCase): + def test_deprecation_warnings(self): + """ + Test to some deprecation warning have been added. It tests warnings for + negative index, negative index slicing and table sensitive removal + + This test should be removed in 4.0, that's why the imports are in + this test, so it's easier to remove + + @since 3.13 + @jira_ticket PYTHON-877 + @expected_result the deprecation warnings are emitted + + @test_category logs + """ + import warnings + + class SensitiveModel(Model): + __table_name__ = 'SensitiveModel' + __table_name_case_sensitive__ = True + k = columns.Integer(primary_key=True) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + sync_table(SensitiveModel) + self.addCleanup(drop_table, SensitiveModel) + + SensitiveModel.create(k=0) + + rows = SensitiveModel.objects().all().allow_filtering() + rows[-1] + rows[-1:] + + self.assertEqual(len(w), 4) + self.assertIn("__table_name_case_sensitive__ will be removed in 4.0.", str(w[0].message)) + self.assertIn("__table_name_case_sensitive__ will be removed in 4.0.", str(w[1].message)) + self.assertIn("ModelQuerySet indexing with negative indices support will be removed in 4.0.", + str(w[2].message)) + self.assertIn("ModelQuerySet slicing with negative indices support will be removed in 4.0.", + str(w[3].message)) diff --git a/tests/integration/cqlengine/model/test_model_io.py b/tests/integration/cqlengine/model/test_model_io.py new file mode 100644 index 0000000..32ace53 --- /dev/null +++ b/tests/integration/cqlengine/model/test_model_io.py @@ -0,0 +1,914 @@ +# Copyright DataStax, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from uuid import uuid4, UUID +import random +from datetime import datetime, date, time +from decimal import Decimal +from operator import itemgetter + +import cassandra +from cassandra.cqlengine import columns +from cassandra.cqlengine import CQLEngineException +from cassandra.cqlengine.management import sync_table +from cassandra.cqlengine.management import drop_table +from cassandra.cqlengine.models import Model +from cassandra.query import SimpleStatement +from cassandra.util import Date, Time, Duration +from cassandra.cqlengine.statements import SelectStatement, DeleteStatement, WhereClause +from cassandra.cqlengine.operators import EqualsOperator + +from tests.integration import PROTOCOL_VERSION, greaterthanorequalcass3_10 +from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine import DEFAULT_KEYSPACE + + +class TestModel(Model): + + id = columns.UUID(primary_key=True, default=lambda: uuid4()) + count = columns.Integer() + text = columns.Text(required=False) + a_bool = columns.Boolean(default=False) + + +class TestModelSave(Model): + partition = columns.UUID(primary_key=True, default=uuid4) + cluster = columns.Integer(primary_key=True) + count = columns.Integer(required=False) + text = columns.Text(required=False, index=True) + text_set = columns.Set(columns.Text, required=False) + text_list = columns.List(columns.Text, required=False) + text_map = columns.Map(columns.Text, columns.Text, required=False) + + +class TestModelIO(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestModelIO, cls).setUpClass() + sync_table(TestModel) + + @classmethod + def tearDownClass(cls): + super(TestModelIO, cls).tearDownClass() + drop_table(TestModel) + + def test_model_save_and_load(self): + """ + Tests that models can be saved and retrieved, using the create method. + """ + tm = TestModel.create(count=8, text='123456789') + self.assertIsInstance(tm, TestModel) + + tm2 = TestModel.objects(id=tm.pk).first() + self.assertIsInstance(tm2, TestModel) + + for cname in tm._columns.keys(): + self.assertEqual(getattr(tm, cname), getattr(tm2, cname)) + + def test_model_instantiation_save_and_load(self): + """ + Tests that models can be saved and retrieved, this time using the + natural model instantiation. + """ + tm = TestModel(count=8, text='123456789') + # Tests that values are available on instantiation. + self.assertIsNotNone(tm['id']) + self.assertEqual(tm.count, 8) + self.assertEqual(tm.text, '123456789') + tm.save() + tm2 = TestModel.objects(id=tm.id).first() + + for cname in tm._columns.keys(): + self.assertEqual(getattr(tm, cname), getattr(tm2, cname)) + + def test_model_read_as_dict(self): + """ + Tests that columns of an instance can be read as a dict. 
+ """ + tm = TestModel.create(count=8, text='123456789', a_bool=True) + column_dict = { + 'id': tm.id, + 'count': tm.count, + 'text': tm.text, + 'a_bool': tm.a_bool, + } + self.assertEqual(sorted(tm.keys()), sorted(column_dict.keys())) + + self.assertSetEqual(set(tm.values()), set(column_dict.values())) + self.assertEqual( + sorted(tm.items(), key=itemgetter(0)), + sorted(column_dict.items(), key=itemgetter(0))) + self.assertEqual(len(tm), len(column_dict)) + for column_id in column_dict.keys(): + self.assertEqual(tm[column_id], column_dict[column_id]) + + tm['count'] = 6 + self.assertEqual(tm.count, 6) + + def test_model_updating_works_properly(self): + """ + Tests that subsequent saves after initial model creation work + """ + tm = TestModel.objects.create(count=8, text='123456789') + + tm.count = 100 + tm.a_bool = True + tm.save() + + tm2 = TestModel.objects(id=tm.pk).first() + self.assertEqual(tm.count, tm2.count) + self.assertEqual(tm.a_bool, tm2.a_bool) + + def test_model_deleting_works_properly(self): + """ + Tests that an instance's delete method deletes the instance + """ + tm = TestModel.create(count=8, text='123456789') + tm.delete() + tm2 = TestModel.objects(id=tm.pk).first() + self.assertIsNone(tm2) + + def test_column_deleting_works_properly(self): + """ + """ + tm = TestModel.create(count=8, text='123456789') + tm.text = None + tm.save() + + tm2 = TestModel.objects(id=tm.pk).first() + self.assertIsInstance(tm2, TestModel) + + self.assertTrue(tm2.text is None) + self.assertTrue(tm2._values['text'].previous_value is None) + + def test_a_sensical_error_is_raised_if_you_try_to_create_a_table_twice(self): + """ + """ + sync_table(TestModel) + sync_table(TestModel) + + @greaterthanorequalcass3_10 + def test_can_insert_model_with_all_column_types(self): + """ + Test for inserting all column types into a Model + + test_can_insert_model_with_all_column_types tests that each cqlengine column type can be inserted into a Model. + It first creates a Model that has each cqlengine column type. It then creates a Model instance where all the fields + have corresponding data, which performs the insert into the Cassandra table. + Finally, it verifies that each column read from the Model from Cassandra is the same as the input parameters. + + @since 2.6.0 + @jira_ticket PYTHON-246 + @expected_result The Model is inserted with each column type, and the resulting read yields proper data for each column. 
+ + @test_category data_types:primitive + """ + + class AllDatatypesModel(Model): + id = columns.Integer(primary_key=True) + a = columns.Ascii() + b = columns.BigInt() + c = columns.Blob() + d = columns.Boolean() + e = columns.DateTime() + f = columns.Decimal() + g = columns.Double() + h = columns.Float() + i = columns.Inet() + j = columns.Integer() + k = columns.Text() + l = columns.TimeUUID() + m = columns.UUID() + n = columns.VarInt() + o = columns.Duration() + + sync_table(AllDatatypesModel) + + input = ['ascii', 2 ** 63 - 1, bytearray(b'hello world'), True, datetime.utcfromtimestamp(872835240), + Decimal('12.3E+7'), 2.39, 3.4028234663852886e+38, '123.123.123.123', 2147483647, 'text', + UUID('FE2B4360-28C6-11E2-81C1-0800200C9A66'), UUID('067e6162-3b6f-4ae2-a171-2470b63dff00'), + int(str(2147483647) + '000')] + + AllDatatypesModel.create(id=0, a='ascii', b=2 ** 63 - 1, c=bytearray(b'hello world'), d=True, + e=datetime.utcfromtimestamp(872835240), f=Decimal('12.3E+7'), g=2.39, + h=3.4028234663852886e+38, i='123.123.123.123', j=2147483647, k='text', + l=UUID('FE2B4360-28C6-11E2-81C1-0800200C9A66'), + m=UUID('067e6162-3b6f-4ae2-a171-2470b63dff00'), n=int(str(2147483647) + '000'), + o=Duration(2, 3, 4)) + + self.assertEqual(1, AllDatatypesModel.objects.count()) + output = AllDatatypesModel.objects.first() + + for i, i_char in enumerate(range(ord('a'), ord('a') + 14)): + self.assertEqual(input[i], output[chr(i_char)]) + + def test_can_specify_none_instead_of_default(self): + self.assertIsNotNone(TestModel.a_bool.column.default) + + # override default + inst = TestModel.create(a_bool=None) + self.assertIsNone(inst.a_bool) + queried = TestModel.objects(id=inst.id).first() + self.assertIsNone(queried.a_bool) + + # letting default be set + inst = TestModel.create() + self.assertEqual(inst.a_bool, TestModel.a_bool.column.default) + queried = TestModel.objects(id=inst.id).first() + self.assertEqual(queried.a_bool, TestModel.a_bool.column.default) + + def test_can_insert_model_with_all_protocol_v4_column_types(self): + """ + Test for inserting all protocol v4 column types into a Model + + test_can_insert_model_with_all_protocol_v4_column_types tests that each cqlengine protocol v4 column type can be + inserted into a Model. It first creates a Model that has each cqlengine protocol v4 column type. It then creates + a Model instance where all the fields have corresponding data, which performs the insert into the Cassandra table. + Finally, it verifies that each column read from the Model from Cassandra is the same as the input parameters. + + @since 2.6.0 + @jira_ticket PYTHON-245 + @expected_result The Model is inserted with each protocol v4 column type, and the resulting read yields proper data for each column. 
+ + @test_category data_types:primitive + """ + + if PROTOCOL_VERSION < 4: + raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) + + class v4DatatypesModel(Model): + id = columns.Integer(primary_key=True) + a = columns.Date() + b = columns.SmallInt() + c = columns.Time() + d = columns.TinyInt() + + sync_table(v4DatatypesModel) + + input = [Date(date(1970, 1, 1)), 32523, Time(time(16, 47, 25, 7)), 123] + + v4DatatypesModel.create(id=0, a=date(1970, 1, 1), b=32523, c=time(16, 47, 25, 7), d=123) + + self.assertEqual(1, v4DatatypesModel.objects.count()) + output = v4DatatypesModel.objects.first() + + for i, i_char in enumerate(range(ord('a'), ord('a') + 3)): + self.assertEqual(input[i], output[chr(i_char)]) + + def test_can_insert_double_and_float(self): + """ + Test for inserting single-precision and double-precision values into a Float and Double columns + + @since 2.6.0 + @changed 3.0.0 removed deprecated Float(double_precision) parameter + @jira_ticket PYTHON-246 + @expected_result Each floating point column type is able to hold their respective precision values. + + @test_category data_types:primitive + """ + + class FloatingPointModel(Model): + id = columns.Integer(primary_key=True) + f = columns.Float() + d = columns.Double() + + sync_table(FloatingPointModel) + + FloatingPointModel.create(id=0, f=2.39) + output = FloatingPointModel.objects.first() + self.assertEqual(2.390000104904175, output.f) # float loses precision + + FloatingPointModel.create(id=0, f=3.4028234663852886e+38, d=2.39) + output = FloatingPointModel.objects.first() + self.assertEqual(3.4028234663852886e+38, output.f) + self.assertEqual(2.39, output.d) # double retains precision + + FloatingPointModel.create(id=0, d=3.4028234663852886e+38) + output = FloatingPointModel.objects.first() + self.assertEqual(3.4028234663852886e+38, output.d) + + +class TestMultiKeyModel(Model): + + partition = columns.Integer(primary_key=True) + cluster = columns.Integer(primary_key=True) + count = columns.Integer(required=False) + text = columns.Text(required=False) + + +class TestDeleting(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestDeleting, cls).setUpClass() + drop_table(TestMultiKeyModel) + sync_table(TestMultiKeyModel) + + @classmethod + def tearDownClass(cls): + super(TestDeleting, cls).tearDownClass() + drop_table(TestMultiKeyModel) + + def test_deleting_only_deletes_one_object(self): + partition = random.randint(0, 1000) + for i in range(5): + TestMultiKeyModel.create(partition=partition, cluster=i, count=i, text=str(i)) + + self.assertTrue(TestMultiKeyModel.filter(partition=partition).count() == 5) + + TestMultiKeyModel.get(partition=partition, cluster=0).delete() + + self.assertTrue(TestMultiKeyModel.filter(partition=partition).count() == 4) + + TestMultiKeyModel.filter(partition=partition).delete() + + +class TestUpdating(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestUpdating, cls).setUpClass() + drop_table(TestModelSave) + drop_table(TestMultiKeyModel) + sync_table(TestModelSave) + sync_table(TestMultiKeyModel) + + @classmethod + def tearDownClass(cls): + super(TestUpdating, cls).tearDownClass() + drop_table(TestMultiKeyModel) + drop_table(TestModelSave) + + def setUp(self): + super(TestUpdating, self).setUp() + self.instance = TestMultiKeyModel.create( + partition=random.randint(0, 1000), + cluster=random.randint(0, 1000), + count=0, + text='happy' + ) + + def test_vanilla_update(self): + 
self.instance.count = 5 + self.instance.save() + + check = TestMultiKeyModel.get(partition=self.instance.partition, cluster=self.instance.cluster) + self.assertTrue(check.count == 5) + self.assertTrue(check.text == 'happy') + + def test_deleting_only(self): + self.instance.count = None + self.instance.text = None + self.instance.save() + + check = TestMultiKeyModel.get(partition=self.instance.partition, cluster=self.instance.cluster) + self.assertTrue(check.count is None) + self.assertTrue(check.text is None) + + def test_get_changed_columns(self): + self.assertTrue(self.instance.get_changed_columns() == []) + self.instance.count = 1 + changes = self.instance.get_changed_columns() + self.assertTrue(len(changes) == 1) + self.assertTrue(changes == ['count']) + self.instance.save() + self.assertTrue(self.instance.get_changed_columns() == []) + + def test_previous_value_tracking_of_persisted_instance(self): + # Check initial internal states. + self.assertTrue(self.instance.get_changed_columns() == []) + self.assertTrue(self.instance._values['count'].previous_value == 0) + + # Change value and check internal states. + self.instance.count = 1 + self.assertTrue(self.instance.get_changed_columns() == ['count']) + self.assertTrue(self.instance._values['count'].previous_value == 0) + + # Internal states should be updated on save. + self.instance.save() + self.assertTrue(self.instance.get_changed_columns() == []) + self.assertTrue(self.instance._values['count'].previous_value == 1) + + # Change value twice. + self.instance.count = 2 + self.assertTrue(self.instance.get_changed_columns() == ['count']) + self.assertTrue(self.instance._values['count'].previous_value == 1) + self.instance.count = 3 + self.assertTrue(self.instance.get_changed_columns() == ['count']) + self.assertTrue(self.instance._values['count'].previous_value == 1) + + # Internal states updated on save. + self.instance.save() + self.assertTrue(self.instance.get_changed_columns() == []) + self.assertTrue(self.instance._values['count'].previous_value == 3) + + # Change value and reset it. + self.instance.count = 2 + self.assertTrue(self.instance.get_changed_columns() == ['count']) + self.assertTrue(self.instance._values['count'].previous_value == 3) + self.instance.count = 3 + self.assertTrue(self.instance.get_changed_columns() == []) + self.assertTrue(self.instance._values['count'].previous_value == 3) + + # Nothing to save: values in initial conditions. 
+ self.instance.save() + self.assertTrue(self.instance.get_changed_columns() == []) + self.assertTrue(self.instance._values['count'].previous_value == 3) + + # Change Multiple values + self.instance.count = 4 + self.instance.text = "changed" + self.assertTrue(len(self.instance.get_changed_columns()) == 2) + self.assertTrue('text' in self.instance.get_changed_columns()) + self.assertTrue('count' in self.instance.get_changed_columns()) + self.instance.save() + self.assertTrue(self.instance.get_changed_columns() == []) + + # Reset Multiple Values + self.instance.count = 5 + self.instance.text = "changed" + self.assertTrue(self.instance.get_changed_columns() == ['count']) + self.instance.text = "changed2" + self.assertTrue(len(self.instance.get_changed_columns()) == 2) + self.assertTrue('text' in self.instance.get_changed_columns()) + self.assertTrue('count' in self.instance.get_changed_columns()) + self.instance.count = 4 + self.instance.text = "changed" + self.assertTrue(self.instance.get_changed_columns() == []) + + def test_previous_value_tracking_on_instantiation(self): + self.instance = TestMultiKeyModel( + partition=random.randint(0, 1000), + cluster=random.randint(0, 1000), + count=0, + text='happy') + + # Columns of instances not persisted yet should be marked as changed. + self.assertTrue(set(self.instance.get_changed_columns()) == set([ + 'partition', 'cluster', 'count', 'text'])) + self.assertTrue(self.instance._values['partition'].previous_value is None) + self.assertTrue(self.instance._values['cluster'].previous_value is None) + self.assertTrue(self.instance._values['count'].previous_value is None) + self.assertTrue(self.instance._values['text'].previous_value is None) + + # Value changes doesn't affect internal states. + self.instance.count = 1 + self.assertTrue('count' in self.instance.get_changed_columns()) + self.assertTrue(self.instance._values['count'].previous_value is None) + self.instance.count = 2 + self.assertTrue('count' in self.instance.get_changed_columns()) + self.assertTrue(self.instance._values['count'].previous_value is None) + + # Value reset is properly tracked. + self.instance.count = None + self.assertTrue('count' not in self.instance.get_changed_columns()) + self.assertTrue(self.instance._values['count'].previous_value is None) + + self.instance.save() + self.assertTrue(self.instance.get_changed_columns() == []) + self.assertTrue(self.instance._values['count'].previous_value is None) + self.assertTrue(self.instance.count is None) + + def test_previous_value_tracking_on_instantiation_with_default(self): + + class TestDefaultValueTracking(Model): + id = columns.Integer(partition_key=True) + int1 = columns.Integer(default=123) + int2 = columns.Integer(default=456) + int3 = columns.Integer(default=lambda: random.randint(0, 1000)) + int4 = columns.Integer(default=lambda: random.randint(0, 1000)) + int5 = columns.Integer() + int6 = columns.Integer() + + instance = TestDefaultValueTracking( + id=1, + int1=9999, + int3=7777, + int5=5555) + + self.assertEqual(instance.id, 1) + self.assertEqual(instance.int1, 9999) + self.assertEqual(instance.int2, 456) + self.assertEqual(instance.int3, 7777) + self.assertIsNotNone(instance.int4) + self.assertIsInstance(instance.int4, int) + self.assertGreaterEqual(instance.int4, 0) + self.assertLessEqual(instance.int4, 1000) + self.assertEqual(instance.int5, 5555) + self.assertTrue(instance.int6 is None) + + # All previous values are unset as the object hasn't been persisted + # yet. 
+ self.assertTrue(instance._values['id'].previous_value is None) + self.assertTrue(instance._values['int1'].previous_value is None) + self.assertTrue(instance._values['int2'].previous_value is None) + self.assertTrue(instance._values['int3'].previous_value is None) + self.assertTrue(instance._values['int4'].previous_value is None) + self.assertTrue(instance._values['int5'].previous_value is None) + self.assertTrue(instance._values['int6'].previous_value is None) + + # All explicitely set columns, and those with default values are + # flagged has changed. + self.assertTrue(set(instance.get_changed_columns()) == set([ + 'id', 'int1', 'int3', 'int5'])) + + def test_save_to_none(self): + """ + Test update of column value of None with save() function. + + Under specific scenarios calling save on a None value wouldn't update + previous values. This issue only manifests with a new instantiation of the model, + if existing model is modified and updated the issue will not occur. + + @since 3.0.0 + @jira_ticket PYTHON-475 + @expected_result column value should be updated to None + + @test_category object_mapper + """ + + partition = uuid4() + cluster = 1 + text = 'set' + text_list = ['set'] + text_set = set(("set",)) + text_map = {"set": 'set'} + initial = TestModelSave(partition=partition, cluster=cluster, text=text, text_list=text_list, + text_set=text_set, text_map=text_map) + initial.save() + current = TestModelSave.objects.get(partition=partition, cluster=cluster) + self.assertEqual(current.text, text) + self.assertEqual(current.text_list, text_list) + self.assertEqual(current.text_set, text_set) + self.assertEqual(current.text_map, text_map) + + next = TestModelSave(partition=partition, cluster=cluster, text=None, text_list=None, + text_set=None, text_map=None) + + next.save() + current = TestModelSave.objects.get(partition=partition, cluster=cluster) + self.assertEqual(current.text, None) + self.assertEqual(current.text_list, []) + self.assertEqual(current.text_set, set()) + self.assertEqual(current.text_map, {}) + + +def test_none_filter_fails(): + class NoneFilterModel(Model): + + pk = columns.Integer(primary_key=True) + v = columns.Integer() + sync_table(NoneFilterModel) + + try: + NoneFilterModel.objects(pk=None) + raise Exception("fail") + except CQLEngineException as e: + pass + + +class TestCanUpdate(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestCanUpdate, cls).setUpClass() + drop_table(TestModel) + sync_table(TestModel) + + @classmethod + def tearDownClass(cls): + super(TestCanUpdate, cls).tearDownClass() + drop_table(TestModel) + + def test_success_case(self): + tm = TestModel(count=8, text='123456789') + + # object hasn't been saved, + # shouldn't be able to update + self.assertTrue(not tm._is_persisted) + self.assertTrue(not tm._can_update()) + + tm.save() + + # object has been saved, + # should be able to update + self.assertTrue(tm._is_persisted) + self.assertTrue(tm._can_update()) + + tm.count = 200 + + # primary keys haven't changed, + # should still be able to update + self.assertTrue(tm._can_update()) + tm.save() + + tm.id = uuid4() + + # primary keys have changed, + # should not be able to update + self.assertTrue(not tm._can_update()) + + +class IndexDefinitionModel(Model): + + key = columns.UUID(primary_key=True) + val = columns.Text(index=True) + + +class TestIndexedColumnDefinition(BaseCassEngTestCase): + + def test_exception_isnt_raised_if_an_index_is_defined_more_than_once(self): + sync_table(IndexDefinitionModel) + 
sync_table(IndexDefinitionModel) + + +class ReservedWordModel(Model): + + token = columns.Text(primary_key=True) + insert = columns.Integer(index=True) + + +class TestQueryQuoting(BaseCassEngTestCase): + + def test_reserved_cql_words_can_be_used_as_column_names(self): + """ + """ + sync_table(ReservedWordModel) + + model1 = ReservedWordModel.create(token='1', insert=5) + + model2 = ReservedWordModel.filter(token='1') + + self.assertTrue(len(model2) == 1) + self.assertTrue(model1.token == model2[0].token) + self.assertTrue(model1.insert == model2[0].insert) + + +class TestQueryModel(Model): + + test_id = columns.UUID(primary_key=True, default=uuid4) + date = columns.Date(primary_key=True) + description = columns.Text() + + +class TestQuerying(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + if PROTOCOL_VERSION < 4: + return + + super(TestQuerying, cls).setUpClass() + drop_table(TestQueryModel) + sync_table(TestQueryModel) + + @classmethod + def tearDownClass(cls): + if PROTOCOL_VERSION < 4: + return + + super(TestQuerying, cls).tearDownClass() + drop_table(TestQueryModel) + + def setUp(self): + if PROTOCOL_VERSION < 4: + raise unittest.SkipTest("Date query tests require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) + + def test_query_with_date(self): + uid = uuid4() + day = date(2013, 11, 26) + obj = TestQueryModel.create(test_id=uid, date=day, description=u'foo') + + self.assertEqual(obj.description, u'foo') + + inst = TestQueryModel.filter( + TestQueryModel.test_id == uid, + TestQueryModel.date == day).limit(1).first() + + self.assertTrue(inst.test_id == uid) + self.assertTrue(inst.date == day) + + +class BasicModelNoRouting(Model): + __table_name__ = 'basic_model_no_routing' + __compute_routing_key__ = False + k = columns.Integer(primary_key=True) + v = columns.Integer() + + +class BasicModel(Model): + __table_name__ = 'basic_model_routing' + k = columns.Integer(primary_key=True) + v = columns.Integer() + + +class BasicModelMulti(Model): + __table_name__ = 'basic_model_routing_multi' + k = columns.Integer(partition_key=True) + v = columns.Integer(partition_key=True) + + +class ComplexModelRouting(Model): + __table_name__ = 'complex_model_routing' + partition = columns.UUID(partition_key=True, default=uuid4) + cluster = columns.Integer(partition_key=True) + count = columns.Integer() + text = columns.Text(partition_key=True) + float = columns.Float(partition_key=True) + text_2 = columns.Text() + + +class TestModelRoutingKeys(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestModelRoutingKeys, cls).setUpClass() + sync_table(BasicModelNoRouting) + sync_table(BasicModel) + sync_table(BasicModelMulti) + sync_table(ComplexModelRouting) + + @classmethod + def tearDownClass(cls): + super(TestModelRoutingKeys, cls).tearDownClass() + drop_table(BasicModelNoRouting) + drop_table(BasicModel) + drop_table(BasicModelMulti) + drop_table(ComplexModelRouting) + + def test_routing_key_is_ignored(self): + """ + Compares the routing key generated by simple partition key using the model with the one generated by the equivalent + bound statement. It also verifies basic operations work with no routing key + @since 3.2 + @jira_ticket PYTHON-505 + @expected_result they shouldn't match + + @test_category object_mapper + """ + + prepared = self.session.prepare( + """ + INSERT INTO {0}.basic_model_no_routing (k, v) VALUES (?, ?) 
+ """.format(DEFAULT_KEYSPACE)) + bound = prepared.bind((1, 2)) + + mrk = BasicModelNoRouting._routing_key_from_values([1], self.session.cluster.protocol_version) + simple = SimpleStatement("") + simple.routing_key = mrk + self.assertNotEqual(bound.routing_key, simple.routing_key) + + # Verify that basic create, update and delete work with no routing key + t = BasicModelNoRouting.create(k=2, v=3) + t.update(v=4).save() + f = BasicModelNoRouting.objects.filter(k=2).first() + self.assertEqual(t, f) + + t.delete() + self.assertEqual(BasicModelNoRouting.objects.count(), 0) + + + def test_routing_key_generation_basic(self): + """ + Compares the routing key generated by simple partition key using the model with the one generated by the equivalent + bound statement + @since 3.2 + @jira_ticket PYTHON-535 + @expected_result they should match + + @test_category object_mapper + """ + + prepared = self.session.prepare( + """ + INSERT INTO {0}.basic_model_routing (k, v) VALUES (?, ?) + """.format(DEFAULT_KEYSPACE)) + bound = prepared.bind((1, 2)) + + mrk = BasicModel._routing_key_from_values([1], self.session.cluster.protocol_version) + simple = SimpleStatement("") + simple.routing_key = mrk + self.assertEqual(bound.routing_key, simple.routing_key) + + def test_routing_key_generation_multi(self): + """ + Compares the routing key generated by composite partition key using the model with the one generated by the equivalent + bound statement + @since 3.2 + @jira_ticket PYTHON-535 + @expected_result they should match + + @test_category object_mapper + """ + + prepared = self.session.prepare( + """ + INSERT INTO {0}.basic_model_routing_multi (k, v) VALUES (?, ?) + """.format(DEFAULT_KEYSPACE)) + bound = prepared.bind((1, 2)) + mrk = BasicModelMulti._routing_key_from_values([1, 2], self.session.cluster.protocol_version) + simple = SimpleStatement("") + simple.routing_key = mrk + self.assertEqual(bound.routing_key, simple.routing_key) + + def test_routing_key_generation_complex(self): + """ + Compares the routing key generated by complex composite partition key using the model with the one generated by the equivalent + bound statement + @since 3.2 + @jira_ticket PYTHON-535 + @expected_result they should match + + @test_category object_mapper + """ + prepared = self.session.prepare( + """ + INSERT INTO {0}.complex_model_routing (partition, cluster, count, text, float, text_2) VALUES (?, ?, ?, ?, ?, ?) + """.format(DEFAULT_KEYSPACE)) + partition = uuid4() + cluster = 1 + count = 2 + text = "text" + float = 1.2 + text_2 = "text_2" + bound = prepared.bind((partition, cluster, count, text, float, text_2)) + mrk = ComplexModelRouting._routing_key_from_values([partition, cluster, text, float], self.session.cluster.protocol_version) + simple = SimpleStatement("") + simple.routing_key = mrk + self.assertEqual(bound.routing_key, simple.routing_key) + + def test_partition_key_index(self): + """ + Test to ensure that statement partition key generation is in the correct order + @since 3.2 + @jira_ticket PYTHON-535 + @expected_result . 
+ + @test_category object_mapper + """ + self._check_partition_value_generation(BasicModel, SelectStatement(BasicModel.__table_name__)) + self._check_partition_value_generation(BasicModel, DeleteStatement(BasicModel.__table_name__)) + self._check_partition_value_generation(BasicModelMulti, SelectStatement(BasicModelMulti.__table_name__)) + self._check_partition_value_generation(BasicModelMulti, DeleteStatement(BasicModelMulti.__table_name__)) + self._check_partition_value_generation(ComplexModelRouting, SelectStatement(ComplexModelRouting.__table_name__)) + self._check_partition_value_generation(ComplexModelRouting, DeleteStatement(ComplexModelRouting.__table_name__)) + self._check_partition_value_generation(BasicModel, SelectStatement(BasicModel.__table_name__), reverse=True) + self._check_partition_value_generation(BasicModel, DeleteStatement(BasicModel.__table_name__), reverse=True) + self._check_partition_value_generation(BasicModelMulti, SelectStatement(BasicModelMulti.__table_name__), reverse=True) + self._check_partition_value_generation(BasicModelMulti, DeleteStatement(BasicModelMulti.__table_name__), reverse=True) + self._check_partition_value_generation(ComplexModelRouting, SelectStatement(ComplexModelRouting.__table_name__), reverse=True) + self._check_partition_value_generation(ComplexModelRouting, DeleteStatement(ComplexModelRouting.__table_name__), reverse=True) + + def _check_partition_value_generation(self, model, state, reverse=False): + """ + This generates a some statements based on the partition_key_index of the model. + It then validates that order of the partition key values in the statement matches the index + specified in the models partition_key_index + """ + # Setup some unique values for statement generation + uuid = uuid4() + values = {'k': 5, 'v': 3, 'partition': uuid, 'cluster': 6, 'count': 42, 'text': 'text', 'float': 3.1415, 'text_2': 'text_2'} + res = dict((v, k) for k, v in values.items()) + items = list(model._partition_key_index.items()) + if(reverse): + items.reverse() + # Add where clauses for each partition key + for partition_key, position in items: + wc = WhereClause(partition_key, EqualsOperator(), values.get(partition_key)) + state._add_where_clause(wc) + + # Iterate over the partition key values check to see that their index matches + # Those specified in the models partition field + for indx, value in enumerate(state.partition_key_values(model._partition_key_index)): + name = res.get(value) + self.assertEqual(indx, model._partition_key_index.get(name)) + + +def test_none_filter_fails(): + class NoneFilterModel(Model): + + pk = columns.Integer(primary_key=True) + v = columns.Integer() + sync_table(NoneFilterModel) + + try: + NoneFilterModel.objects(pk=None) + raise Exception("fail") + except CQLEngineException as e: + pass diff --git a/tests/integration/cqlengine/model/test_polymorphism.py b/tests/integration/cqlengine/model/test_polymorphism.py new file mode 100644 index 0000000..e78fef4 --- /dev/null +++ b/tests/integration/cqlengine/model/test_polymorphism.py @@ -0,0 +1,255 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid +import mock + +from cassandra.cqlengine import columns +from cassandra.cqlengine import models +from cassandra.cqlengine.connection import get_session +from tests.integration.cqlengine.base import BaseCassEngTestCase +from cassandra.cqlengine import management + + +class TestInheritanceClassConstruction(BaseCassEngTestCase): + + def test_multiple_discriminator_value_failure(self): + """ Tests that defining a model with more than one discriminator column fails """ + with self.assertRaises(models.ModelDefinitionException): + class M(models.Model): + partition = columns.Integer(primary_key=True) + type1 = columns.Integer(discriminator_column=True) + type2 = columns.Integer(discriminator_column=True) + + def test_no_discriminator_column_failure(self): + with self.assertRaises(models.ModelDefinitionException): + class M(models.Model): + __discriminator_value__ = 1 + + def test_discriminator_value_inheritance(self): + """ Tests that discriminator_column attribute is not inherited """ + class Base(models.Model): + + partition = columns.Integer(primary_key=True) + type1 = columns.Integer(discriminator_column=True) + + class M1(Base): + __discriminator_value__ = 1 + + class M2(M1): + pass + + assert M2.__discriminator_value__ is None + + def test_inheritance_metaclass(self): + """ Tests that the model meta class configures inherited models properly """ + class Base(models.Model): + + partition = columns.Integer(primary_key=True) + type1 = columns.Integer(discriminator_column=True) + + class M1(Base): + __discriminator_value__ = 1 + + assert Base._is_polymorphic + assert M1._is_polymorphic + + assert Base._is_polymorphic_base + assert not M1._is_polymorphic_base + + assert Base._discriminator_column is Base._columns['type1'] + assert M1._discriminator_column is M1._columns['type1'] + + assert Base._discriminator_column_name == 'type1' + assert M1._discriminator_column_name == 'type1' + + def test_table_names_are_inherited_from_base(self): + class Base(models.Model): + + partition = columns.Integer(primary_key=True) + type1 = columns.Integer(discriminator_column=True) + + class M1(Base): + __discriminator_value__ = 1 + + assert Base.column_family_name() == M1.column_family_name() + + def test_collection_columns_cant_be_discriminator_column(self): + with self.assertRaises(models.ModelDefinitionException): + class Base(models.Model): + + partition = columns.Integer(primary_key=True) + type1 = columns.Set(columns.Integer, discriminator_column=True) + + +class InheritBase(models.Model): + + partition = columns.UUID(primary_key=True, default=uuid.uuid4) + row_type = columns.Integer(discriminator_column=True) + + +class Inherit1(InheritBase): + __discriminator_value__ = 1 + data1 = columns.Text() + + +class Inherit2(InheritBase): + __discriminator_value__ = 2 + data2 = columns.Text() + + +class TestInheritanceModel(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestInheritanceModel, cls).setUpClass() + management.sync_table(Inherit1) + management.sync_table(Inherit2) + + @classmethod + def tearDownClass(cls): + super(TestInheritanceModel, cls).tearDownClass() 
+ management.drop_table(Inherit1) + management.drop_table(Inherit2) + + def test_saving_base_model_fails(self): + with self.assertRaises(models.PolymorphicModelException): + InheritBase.create() + + def test_saving_subclass_saves_disc_value(self): + p1 = Inherit1.create(data1='pickle') + p2 = Inherit2.create(data2='bacon') + + assert p1.row_type == Inherit1.__discriminator_value__ + assert p2.row_type == Inherit2.__discriminator_value__ + + def test_query_deserialization(self): + p1 = Inherit1.create(data1='pickle') + p2 = Inherit2.create(data2='bacon') + + p1r = InheritBase.get(partition=p1.partition) + p2r = InheritBase.get(partition=p2.partition) + + assert isinstance(p1r, Inherit1) + assert isinstance(p2r, Inherit2) + + def test_delete_on_subclass_does_not_include_disc_value(self): + p1 = Inherit1.create() + session = get_session() + with mock.patch.object(session, 'execute') as m: + Inherit1.objects(partition=p1.partition).delete() + + # make sure our discriminator value isn't in the CQL + # not sure how we would even get here if it was in there + # since the CQL would fail. + + self.assertNotIn("row_type", m.call_args[0][0].query_string) + + +class UnindexedInheritBase(models.Model): + + partition = columns.UUID(primary_key=True, default=uuid.uuid4) + cluster = columns.UUID(primary_key=True, default=uuid.uuid4) + row_type = columns.Integer(discriminator_column=True) + + +class UnindexedInherit1(UnindexedInheritBase): + __discriminator_value__ = 1 + data1 = columns.Text() + + +class UnindexedInherit2(UnindexedInheritBase): + __discriminator_value__ = 2 + data2 = columns.Text() + + +class UnindexedInherit3(UnindexedInherit2): + __discriminator_value__ = 3 + data3 = columns.Text() + + +class TestUnindexedInheritanceQuery(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestUnindexedInheritanceQuery, cls).setUpClass() + management.sync_table(UnindexedInherit1) + management.sync_table(UnindexedInherit2) + management.sync_table(UnindexedInherit3) + + cls.p1 = UnindexedInherit1.create(data1='pickle') + cls.p2 = UnindexedInherit2.create(partition=cls.p1.partition, data2='bacon') + cls.p3 = UnindexedInherit3.create(partition=cls.p1.partition, data3='turkey') + + @classmethod + def tearDownClass(cls): + super(TestUnindexedInheritanceQuery, cls).tearDownClass() + management.drop_table(UnindexedInherit1) + management.drop_table(UnindexedInherit2) + management.drop_table(UnindexedInherit3) + + def test_non_conflicting_type_results_work(self): + p1, p2, p3 = self.p1, self.p2, self.p3 + assert len(list(UnindexedInherit1.objects(partition=p1.partition, cluster=p1.cluster))) == 1 + assert len(list(UnindexedInherit2.objects(partition=p1.partition, cluster=p2.cluster))) == 1 + assert len(list(UnindexedInherit3.objects(partition=p1.partition, cluster=p3.cluster))) == 1 + + def test_subclassed_model_results_work_properly(self): + p1, p2, p3 = self.p1, self.p2, self.p3 + assert len(list(UnindexedInherit2.objects(partition=p1.partition, cluster__in=[p2.cluster, p3.cluster]))) == 2 + + def test_conflicting_type_results(self): + with self.assertRaises(models.PolymorphicModelException): + list(UnindexedInherit1.objects(partition=self.p1.partition)) + with self.assertRaises(models.PolymorphicModelException): + list(UnindexedInherit2.objects(partition=self.p1.partition)) + + +class IndexedInheritBase(models.Model): + + partition = columns.UUID(primary_key=True, default=uuid.uuid4) + cluster = columns.UUID(primary_key=True, default=uuid.uuid4) + row_type = 
columns.Integer(discriminator_column=True, index=True) + + +class IndexedInherit1(IndexedInheritBase): + __discriminator_value__ = 1 + data1 = columns.Text() + + +class IndexedInherit2(IndexedInheritBase): + __discriminator_value__ = 2 + data2 = columns.Text() + + +class TestIndexedInheritanceQuery(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestIndexedInheritanceQuery, cls).setUpClass() + management.sync_table(IndexedInherit1) + management.sync_table(IndexedInherit2) + + cls.p1 = IndexedInherit1.create(data1='pickle') + cls.p2 = IndexedInherit2.create(partition=cls.p1.partition, data2='bacon') + + @classmethod + def tearDownClass(cls): + super(TestIndexedInheritanceQuery, cls).tearDownClass() + management.drop_table(IndexedInherit1) + management.drop_table(IndexedInherit2) + + def test_success_case(self): + self.assertEqual(len(list(IndexedInherit1.objects(partition=self.p1.partition))), 1) + self.assertEqual(len(list(IndexedInherit2.objects(partition=self.p1.partition))), 1) diff --git a/tests/integration/cqlengine/model/test_udts.py b/tests/integration/cqlengine/model/test_udts.py new file mode 100644 index 0000000..8297343 --- /dev/null +++ b/tests/integration/cqlengine/model/test_udts.py @@ -0,0 +1,591 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
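The inheritance tests above exercise cqlengine's single-table polymorphism: a discriminator column on the base model plus a ``__discriminator_value__`` on each subclass. A minimal sketch of that pattern, using illustrative model names that are not part of this test suite and assuming a cqlengine connection and keyspace are already configured::

    from uuid import uuid4

    from cassandra.cqlengine import columns
    from cassandra.cqlengine.models import Model
    from cassandra.cqlengine.management import sync_table

    class Animal(Model):                        # base model; all subclasses share its table
        id = columns.UUID(primary_key=True, default=uuid4)
        kind = columns.Integer(discriminator_column=True)

    class Dog(Animal):
        __discriminator_value__ = 1             # stored in `kind` for Dog rows
        name = columns.Text()

    sync_table(Dog)                             # creates/updates the shared table
    dog = Dog.create(name='Rex')                # `kind` is filled in automatically
    assert isinstance(Animal.get(id=dog.id), Dog)   # base-model queries resolve the subclass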
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from datetime import datetime, date, time +from decimal import Decimal +from mock import Mock +from uuid import UUID, uuid4 + +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.usertype import UserType, UserTypeDefinitionException +from cassandra.cqlengine import columns, connection +from cassandra.cqlengine.management import sync_table, drop_table, sync_type, create_keyspace_simple, drop_keyspace +from cassandra.cqlengine import ValidationError +from cassandra.util import Date, Time + +from tests.integration import PROTOCOL_VERSION +from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine import DEFAULT_KEYSPACE + + +class User(UserType): + age = columns.Integer() + name = columns.Text() + + +class UserModel(Model): + id = columns.Integer(primary_key=True) + info = columns.UserDefinedType(User) + + +class AllDatatypes(UserType): + a = columns.Ascii() + b = columns.BigInt() + c = columns.Blob() + d = columns.Boolean() + e = columns.DateTime() + f = columns.Decimal() + g = columns.Double() + h = columns.Float() + i = columns.Inet() + j = columns.Integer() + k = columns.Text() + l = columns.TimeUUID() + m = columns.UUID() + n = columns.VarInt() + + +class AllDatatypesModel(Model): + id = columns.Integer(primary_key=True) + data = columns.UserDefinedType(AllDatatypes) + + +class UserDefinedTypeTests(BaseCassEngTestCase): + + def setUp(self): + if PROTOCOL_VERSION < 3: + raise unittest.SkipTest("UDTs require native protocol 3+, currently using: {0}".format(PROTOCOL_VERSION)) + + def test_can_create_udts(self): + class User(UserType): + age = columns.Integer() + name = columns.Text() + + sync_type(DEFAULT_KEYSPACE, User) + user = User(age=42, name="John") + self.assertEqual(42, user.age) + self.assertEqual("John", user.name) + + # Add a field + class User(UserType): + age = columns.Integer() + name = columns.Text() + gender = columns.Text() + + sync_type(DEFAULT_KEYSPACE, User) + user = User(age=42) + user["name"] = "John" + user["gender"] = "male" + self.assertEqual(42, user.age) + self.assertEqual("John", user.name) + self.assertEqual("male", user.gender) + + # Remove a field + class User(UserType): + age = columns.Integer() + name = columns.Text() + + sync_type(DEFAULT_KEYSPACE, User) + user = User(age=42, name="John", gender="male") + with self.assertRaises(AttributeError): + user.gender + + def test_can_insert_udts(self): + + sync_table(UserModel) + self.addCleanup(drop_table, UserModel) + + user = User(age=42, name="John") + UserModel.create(id=0, info=user) + + self.assertEqual(1, UserModel.objects.count()) + + john = UserModel.objects.first() + self.assertEqual(0, john.id) + self.assertTrue(type(john.info) is User) + self.assertEqual(42, john.info.age) + self.assertEqual("John", john.info.name) + + def test_can_update_udts(self): + sync_table(UserModel) + self.addCleanup(drop_table, UserModel) + + user = User(age=42, name="John") + created_user = UserModel.create(id=0, info=user) + + john_info = UserModel.objects.first().info + self.assertEqual(42, john_info.age) + self.assertEqual("John", john_info.name) + + created_user.info = User(age=22, name="Mary") + created_user.update() + + mary_info = UserModel.objects.first().info + self.assertEqual(22, mary_info["age"]) + self.assertEqual("Mary", mary_info["name"]) + + def test_can_update_udts_with_nones(self): + sync_table(UserModel) + self.addCleanup(drop_table, UserModel) + + user = 
User(age=42, name="John") + created_user = UserModel.create(id=0, info=user) + + john_info = UserModel.objects.first().info + self.assertEqual(42, john_info.age) + self.assertEqual("John", john_info.name) + + created_user.info = None + created_user.update() + + john_info = UserModel.objects.first().info + self.assertIsNone(john_info) + + def test_can_create_same_udt_different_keyspaces(self): + sync_type(DEFAULT_KEYSPACE, User) + + create_keyspace_simple("simplex", 1) + sync_type("simplex", User) + drop_keyspace("simplex") + + def test_can_insert_partial_udts(self): + class UserGender(UserType): + age = columns.Integer() + name = columns.Text() + gender = columns.Text() + + class UserModelGender(Model): + id = columns.Integer(primary_key=True) + info = columns.UserDefinedType(UserGender) + + sync_table(UserModelGender) + self.addCleanup(drop_table, UserModelGender) + + user = UserGender(age=42, name="John") + UserModelGender.create(id=0, info=user) + + john_info = UserModelGender.objects.first().info + self.assertEqual(42, john_info.age) + self.assertEqual("John", john_info.name) + self.assertIsNone(john_info.gender) + + user = UserGender(age=42) + UserModelGender.create(id=0, info=user) + + john_info = UserModelGender.objects.first().info + self.assertEqual(42, john_info.age) + self.assertIsNone(john_info.name) + self.assertIsNone(john_info.gender) + + def test_can_insert_nested_udts(self): + class Depth_0(UserType): + age = columns.Integer() + name = columns.Text() + + class Depth_1(UserType): + value = columns.UserDefinedType(Depth_0) + + class Depth_2(UserType): + value = columns.UserDefinedType(Depth_1) + + class Depth_3(UserType): + value = columns.UserDefinedType(Depth_2) + + class DepthModel(Model): + id = columns.Integer(primary_key=True) + v_0 = columns.UserDefinedType(Depth_0) + v_1 = columns.UserDefinedType(Depth_1) + v_2 = columns.UserDefinedType(Depth_2) + v_3 = columns.UserDefinedType(Depth_3) + + sync_table(DepthModel) + self.addCleanup(drop_table, DepthModel) + + udts = [Depth_0(age=42, name="John")] + udts.append(Depth_1(value=udts[0])) + udts.append(Depth_2(value=udts[1])) + udts.append(Depth_3(value=udts[2])) + + DepthModel.create(id=0, v_0=udts[0], v_1=udts[1], v_2=udts[2], v_3=udts[3]) + output = DepthModel.objects.first() + + self.assertEqual(udts[0], output.v_0) + self.assertEqual(udts[1], output.v_1) + self.assertEqual(udts[2], output.v_2) + self.assertEqual(udts[3], output.v_3) + + def test_can_insert_udts_with_nones(self): + """ + Test for inserting all column types as empty into a UserType as None's + + test_can_insert_udts_with_nones tests that each cqlengine column type can be inserted into a UserType as None's. + It first creates a UserType that has each cqlengine column type, and a corresponding table/Model. It then creates + a UserType instance where all the fields are None's and inserts the UserType as an instance of the Model. Finally, + it verifies that each column read from the UserType from Cassandra is None. + + @since 2.5.0 + @jira_ticket PYTHON-251 + @expected_result The UserType is inserted with each column type, and the resulting read yields None's for each column. 
+ + @test_category data_types:udt + """ + sync_table(AllDatatypesModel) + self.addCleanup(drop_table, AllDatatypesModel) + + input = AllDatatypes(a=None, b=None, c=None, d=None, e=None, f=None, g=None, h=None, i=None, j=None, k=None, + l=None, m=None, n=None) + AllDatatypesModel.create(id=0, data=input) + + self.assertEqual(1, AllDatatypesModel.objects.count()) + + output = AllDatatypesModel.objects.first().data + self.assertEqual(input, output) + + def test_can_insert_udts_with_all_datatypes(self): + """ + Test for inserting all column types into a UserType + + test_can_insert_udts_with_all_datatypes tests that each cqlengine column type can be inserted into a UserType. + It first creates a UserType that has each cqlengine column type, and a corresponding table/Model. It then creates + a UserType instance where all the fields have corresponding data, and inserts the UserType as an instance of the Model. + Finally, it verifies that each column read from the UserType from Cassandra is the same as the input parameters. + + @since 2.5.0 + @jira_ticket PYTHON-251 + @expected_result The UserType is inserted with each column type, and the resulting read yields proper data for each column. + + @test_category data_types:udt + """ + sync_table(AllDatatypesModel) + self.addCleanup(drop_table, AllDatatypesModel) + + input = AllDatatypes(a='ascii', b=2 ** 63 - 1, c=bytearray(b'hello world'), d=True, + e=datetime.utcfromtimestamp(872835240), f=Decimal('12.3E+7'), g=2.39, + h=3.4028234663852886e+38, i='123.123.123.123', j=2147483647, k='text', + l=UUID('FE2B4360-28C6-11E2-81C1-0800200C9A66'), + m=UUID('067e6162-3b6f-4ae2-a171-2470b63dff00'), n=int(str(2147483647) + '000')) + AllDatatypesModel.create(id=0, data=input) + + self.assertEqual(1, AllDatatypesModel.objects.count()) + output = AllDatatypesModel.objects.first().data + + for i in range(ord('a'), ord('a') + 14): + self.assertEqual(input[chr(i)], output[chr(i)]) + + def test_can_insert_udts_protocol_v4_datatypes(self): + """ + Test for inserting all protocol v4 column types into a UserType + + test_can_insert_udts_protocol_v4_datatypes tests that each protocol v4 cqlengine column type can be inserted + into a UserType. It first creates a UserType that has each protocol v4 cqlengine column type, and a corresponding + table/Model. It then creates a UserType instance where all the fields have corresponding data, and inserts the + UserType as an instance of the Model. Finally, it verifies that each column read from the UserType from Cassandra + is the same as the input parameters. + + @since 2.6.0 + @jira_ticket PYTHON-245 + @expected_result The UserType is inserted with each protocol v4 column type, and the resulting read yields proper data for each column. 
+ + @test_category data_types:udt + """ + + if PROTOCOL_VERSION < 4: + raise unittest.SkipTest("Protocol v4 datatypes in UDTs require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) + + class Allv4Datatypes(UserType): + a = columns.Date() + b = columns.SmallInt() + c = columns.Time() + d = columns.TinyInt() + + class Allv4DatatypesModel(Model): + id = columns.Integer(primary_key=True) + data = columns.UserDefinedType(Allv4Datatypes) + + sync_table(Allv4DatatypesModel) + self.addCleanup(drop_table, Allv4DatatypesModel) + + input = Allv4Datatypes(a=Date(date(1970, 1, 1)), b=32523, c=Time(time(16, 47, 25, 7)), d=123) + Allv4DatatypesModel.create(id=0, data=input) + + self.assertEqual(1, Allv4DatatypesModel.objects.count()) + output = Allv4DatatypesModel.objects.first().data + + for i in range(ord('a'), ord('a') + 3): + self.assertEqual(input[chr(i)], output[chr(i)]) + + def test_nested_udts_inserts(self): + """ + Test for inserting collections of user types using cql engine. + + test_nested_udts_inserts Constructs a model that contains a list of usertypes. It will then attempt to insert + them. The expectation is that no exception is thrown during insert. For sanity sake we also validate that our + input and output values match. This combination of model, and UT produces a syntax error in 2.5.1 due to + improper quoting around the names collection. + + @since 2.6.0 + @jira_ticket PYTHON-311 + @expected_result No syntax exception thrown + + @test_category data_types:udt + """ + + class Name(UserType): + type_name__ = "header" + + name = columns.Text() + value = columns.Text() + + class Container(Model): + id = columns.UUID(primary_key=True, default=uuid4) + names = columns.List(columns.UserDefinedType(Name)) + + # Construct the objects and insert them + names = [] + for i in range(0, 10): + names.append(Name(name="name{0}".format(i), value="value{0}".format(i))) + + # Create table, insert data + sync_table(Container) + self.addCleanup(drop_table, Container) + + Container.create(id=UUID('FE2B4360-28C6-11E2-81C1-0800200C9A66'), names=names) + + # Validate input and output matches + self.assertEqual(1, Container.objects.count()) + names_output = Container.objects.first().names + self.assertEqual(names_output, names) + + def test_udts_with_unicode(self): + """ + Test for inserting models with unicode and udt columns. + + test_udts_with_unicode constructs a model with a user defined type. It then attempts to insert that model with + a unicode primary key. It will also attempt to upsert a udt that contains unicode text. 
+ + @since 3.0.0 + @jira_ticket PYTHON-353 + @expected_result No exceptions thrown + + @test_category data_types:udt + """ + ascii_name = 'normal name' + unicode_name = u'Fran\u00E7ois' + + class UserModelText(Model): + id = columns.Text(primary_key=True) + info = columns.UserDefinedType(User) + + sync_table(UserModelText) + self.addCleanup(drop_table, UserModelText) + + # Two udt instances one with a unicode one with ascii + user_template_ascii = User(age=25, name=ascii_name) + user_template_unicode = User(age=25, name=unicode_name) + + UserModelText.create(id=ascii_name, info=user_template_unicode) + UserModelText.create(id=unicode_name, info=user_template_ascii) + UserModelText.create(id=unicode_name, info=user_template_unicode) + + def test_register_default_keyspace(self): + + from cassandra.cqlengine import models + from cassandra.cqlengine import connection + + # None emulating no model and no default keyspace before connecting + connection.udt_by_keyspace.clear() + User.register_for_keyspace(None) + self.assertEqual(len(connection.udt_by_keyspace), 1) + self.assertIn(None, connection.udt_by_keyspace) + + # register should be with default keyspace, not None + cluster = Mock() + connection._register_known_types(cluster) + cluster.register_user_type.assert_called_with(models.DEFAULT_KEYSPACE, User.type_name(), User) + + def test_db_field_override(self): + """ + Tests for db_field override + + Tests to ensure that udt's in models can specify db_field for a particular field and that it will be honored. + + @since 3.1.0 + @jira_ticket PYTHON-346 + @expected_result The actual cassandra column will use the db_field specified. + + @test_category data_types:udt + """ + class db_field_different(UserType): + age = columns.Integer(db_field='a') + name = columns.Text(db_field='n') + + class TheModel(Model): + id = columns.Integer(primary_key=True) + info = columns.UserDefinedType(db_field_different) + + sync_table(TheModel) + self.addCleanup(drop_table, TheModel) + + cluster = connection.get_cluster() + type_meta = cluster.metadata.keyspaces[TheModel._get_keyspace()].user_types[db_field_different.type_name()] + + type_fields = (db_field_different.age.column, db_field_different.name.column) + + self.assertEqual(len(type_meta.field_names), len(type_fields)) + for f in type_fields: + self.assertIn(f.db_field_name, type_meta.field_names) + + id = 0 + age = 42 + name = 'John' + info = db_field_different(age=age, name=name) + TheModel.create(id=id, info=info) + + self.assertEqual(1, TheModel.objects.count()) + + john = TheModel.objects.first() + self.assertEqual(john.id, id) + info = john.info + self.assertIsInstance(info, db_field_different) + self.assertEqual(info.age, age) + self.assertEqual(info.name, name) + # also excercise the db_Field mapping + self.assertEqual(info.a, age) + self.assertEqual(info.n, name) + + def test_db_field_overload(self): + """ + Tests for db_field UserTypeDefinitionException + + Test so that when we override a model's default field witha db_field that it errors appropriately + + @since 3.1.0 + @jira_ticket PYTHON-346 + @expected_result Setting a db_field to an existing field causes an exception to occur. 
+ + @test_category data_types:udt + """ + + with self.assertRaises(UserTypeDefinitionException): + class something_silly(UserType): + first_col = columns.Integer() + second_col = columns.Text(db_field='first_col') + + with self.assertRaises(UserTypeDefinitionException): + class something_silly_2(UserType): + first_col = columns.Integer(db_field="second_col") + second_col = columns.Text() + + def test_set_udt_fields(self): + # PYTHON-502 + + u = User() + u.age = 20 + self.assertEqual(20, u.age) + + def test_default_values(self): + """ + Test that default types are set on object creation for UDTs + + @since 3.7.0 + @jira_ticket PYTHON-606 + @expected_result Default values should be set. + + @test_category data_types:udt + """ + + class NestedUdt(UserType): + + test_id = columns.UUID(default=uuid4) + something = columns.Text() + default_text = columns.Text(default="default text") + + class OuterModel(Model): + + name = columns.Text(primary_key=True) + first_name = columns.Text() + nested = columns.List(columns.UserDefinedType(NestedUdt)) + simple = columns.UserDefinedType(NestedUdt) + + sync_table(OuterModel) + self.addCleanup(drop_table, OuterModel) + + t = OuterModel.create(name='test1') + t.nested = [NestedUdt(something='test')] + t.simple = NestedUdt(something="") + t.save() + self.assertIsNotNone(t.nested[0].test_id) + self.assertEqual(t.nested[0].default_text, "default text") + self.assertIsNotNone(t.simple.test_id) + self.assertEqual(t.simple.default_text, "default text") + + def test_udt_validate(self): + """ + Test to verify restrictions are honored and that validate is called + for each member of the UDT when an updated is attempted + + @since 3.10 + @jira_ticket PYTHON-505 + @expected_result a validation error is arisen due to the name being + too long + + @test_category data_types:object_mapper + """ + class UserValidate(UserType): + age = columns.Integer() + name = columns.Text(max_length=2) + + class UserModelValidate(Model): + id = columns.Integer(primary_key=True) + info = columns.UserDefinedType(UserValidate) + + sync_table(UserModelValidate) + self.addCleanup(drop_table, UserModelValidate) + + user = UserValidate(age=1, name="Robert") + item = UserModelValidate(id=1, info=user) + with self.assertRaises(ValidationError): + item.save() + + def test_udt_validate_with_default(self): + """ + Test to verify restrictions are honored and that validate is called + on the default value + + @since 3.10 + @jira_ticket PYTHON-505 + @expected_result a validation error is arisen due to the name being + too long + + @test_category data_types:object_mapper + """ + class UserValidateDefault(UserType): + age = columns.Integer() + name = columns.Text(max_length=2, default="Robert") + + class UserModelValidateDefault(Model): + id = columns.Integer(primary_key=True) + info = columns.UserDefinedType(UserValidateDefault) + + sync_table(UserModelValidateDefault) + self.addCleanup(drop_table, UserModelValidateDefault) + + user = UserValidateDefault(age=1) + item = UserModelValidateDefault(id=1, info=user) + with self.assertRaises(ValidationError): + item.save() diff --git a/tests/integration/cqlengine/model/test_updates.py b/tests/integration/cqlengine/model/test_updates.py new file mode 100644 index 0000000..17eed8d --- /dev/null +++ b/tests/integration/cqlengine/model/test_updates.py @@ -0,0 +1,373 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from uuid import uuid4
+
+from mock import patch
+from cassandra.cqlengine import ValidationError
+
+from tests.integration import greaterthancass21
+from tests.integration.cqlengine.base import BaseCassEngTestCase
+from cassandra.cqlengine.models import Model
+from cassandra.cqlengine import columns
+from cassandra.cqlengine.management import sync_table, drop_table
+from cassandra.cqlengine.usertype import UserType
+
+class TestUpdateModel(Model):
+
+    partition = columns.UUID(primary_key=True, default=uuid4)
+    cluster = columns.UUID(primary_key=True, default=uuid4)
+    count = columns.Integer(required=False)
+    text = columns.Text(required=False, index=True)
+
+
+class ModelUpdateTests(BaseCassEngTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        super(ModelUpdateTests, cls).setUpClass()
+        sync_table(TestUpdateModel)
+
+    @classmethod
+    def tearDownClass(cls):
+        super(ModelUpdateTests, cls).tearDownClass()
+        drop_table(TestUpdateModel)
+
+    def test_update_model(self):
+        """ tests calling update on models with no values passed in """
+        m0 = TestUpdateModel.create(count=5, text='monkey')
+
+        # independently save over a new count value, unknown to original instance
+        m1 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster)
+        m1.count = 6
+        m1.save()
+
+        # update the text, and call update
+        m0.text = 'monkey land'
+        m0.update()
+
+        # database should reflect both updates
+        m2 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster)
+        self.assertEqual(m2.count, m1.count)
+        self.assertEqual(m2.text, m0.text)
+
+        # This shouldn't raise a ValidationError as the primary key is not changing
+        m0.update(partition=m0.partition, cluster=m0.cluster)
+
+        # Assert a ValidationError is raised if the primary key changes
+        with self.assertRaises(ValidationError):
+            m0.update(partition=m0.partition, cluster=20)
+
+        # Assert a ValidationError is raised if the column doesn't exist
+        with self.assertRaises(ValidationError):
+            m0.update(invalid_column=20)
+
+    def test_update_values(self):
+        """ tests calling update on models with values passed in """
+        m0 = TestUpdateModel.create(count=5, text='monkey')
+
+        # independently save over a new count value, unknown to original instance
+        m1 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster)
+        m1.count = 6
+        m1.save()
+
+        # update the text, and call update
+        m0.update(text='monkey land')
+        self.assertEqual(m0.text, 'monkey land')
+
+        # database should reflect both updates
+        m2 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster)
+        self.assertEqual(m2.count, m1.count)
+        self.assertEqual(m2.text, m0.text)
+
+    def test_noop_model_direct_update(self):
+        """ Tests that calling update on a model with no changes will do nothing.
""" + m0 = TestUpdateModel.create(count=5, text='monkey') + + with patch.object(self.session, 'execute') as execute: + m0.update() + assert execute.call_count == 0 + + with patch.object(self.session, 'execute') as execute: + m0.update(count=5) + assert execute.call_count == 0 + + with patch.object(self.session, 'execute') as execute: + m0.update(partition=m0.partition) + + with patch.object(self.session, 'execute') as execute: + m0.update(cluster=m0.cluster) + + def test_noop_model_assignation_update(self): + """ Tests that assigning the same value on a model will do nothing. """ + # Create object and fetch it back to eliminate any hidden variable + # cache effect. + m0 = TestUpdateModel.create(count=5, text='monkey') + m1 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster) + + with patch.object(self.session, 'execute') as execute: + m1.save() + assert execute.call_count == 0 + + with patch.object(self.session, 'execute') as execute: + m1.count = 5 + m1.save() + assert execute.call_count == 0 + + with patch.object(self.session, 'execute') as execute: + m1.partition = m0.partition + m1.save() + assert execute.call_count == 0 + + with patch.object(self.session, 'execute') as execute: + m1.cluster = m0.cluster + m1.save() + assert execute.call_count == 0 + + def test_invalid_update_kwarg(self): + """ tests that passing in a kwarg to the update method that isn't a column will fail """ + m0 = TestUpdateModel.create(count=5, text='monkey') + with self.assertRaises(ValidationError): + m0.update(numbers=20) + + def test_primary_key_update_failure(self): + """ tests that attempting to update the value of a primary key will fail """ + m0 = TestUpdateModel.create(count=5, text='monkey') + with self.assertRaises(ValidationError): + m0.update(partition=uuid4()) + + +class UDT(UserType): + age = columns.Integer() + mf = columns.Map(columns.Integer, columns.Integer) + dummy_udt = columns.Integer(default=42) + time_col = columns.Time() + + +class ModelWithDefault(Model): + id = columns.Integer(primary_key=True) + mf = columns.Map(columns.Integer, columns.Integer) + dummy = columns.Integer(default=42) + udt = columns.UserDefinedType(UDT) + udt_default = columns.UserDefinedType(UDT, default=UDT(age=1, mf={2:2})) + + +class UDTWithDefault(UserType): + age = columns.Integer() + mf = columns.Map(columns.Integer, columns.Integer, default={2:2}) + dummy_udt = columns.Integer(default=42) + + +class ModelWithDefaultCollection(Model): + id = columns.Integer(primary_key=True) + mf = columns.Map(columns.Integer, columns.Integer, default={2:2}) + dummy = columns.Integer(default=42) + udt = columns.UserDefinedType(UDT) + udt_default = columns.UserDefinedType(UDT, default=UDT(age=1, mf={2: 2})) + +@greaterthancass21 +class ModelWithDefaultTests(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + cls.udt_default = UDT(age=1, mf={2:2}, dummy_udt=42) + + def setUp(self): + sync_table(ModelWithDefault) + sync_table(ModelWithDefaultCollection) + + def tearDown(self): + drop_table(ModelWithDefault) + drop_table(ModelWithDefaultCollection) + + def test_value_override_with_default(self): + """ + Updating a row with a new Model instance shouldn't set columns to defaults + + @since 3.9 + @jira_ticket PYTHON-657 + @expected_result column value should not change + + @test_category object_mapper + """ + first_udt = UDT(age=1, mf={2:2}, dummy_udt=0) + initial = ModelWithDefault(id=1, mf={0: 0}, dummy=0, udt=first_udt, udt_default=first_udt) + initial.save() + + 
self.assertEqual(ModelWithDefault.get()._as_dict(), + {'id': 1, 'dummy': 0, 'mf': {0: 0}, "udt": first_udt, "udt_default": first_udt}) + + second_udt = UDT(age=1, mf={3: 3}, dummy_udt=12) + second = ModelWithDefault(id=1) + second.update(mf={0: 1}, udt=second_udt) + + self.assertEqual(ModelWithDefault.get()._as_dict(), + {'id': 1, 'dummy': 0, 'mf': {0: 1}, "udt": second_udt, "udt_default": first_udt}) + + def test_value_is_written_if_is_default(self): + """ + Check if the we try to update with the default value, the update + happens correctly + @since 3.9 + @jira_ticket PYTHON-657 + @expected_result column value should be updated + :return: + """ + initial = ModelWithDefault(id=1) + initial.mf = {0: 0} + initial.dummy = 42 + initial.udt_default = self.udt_default + initial.update() + + self.assertEqual(ModelWithDefault.get()._as_dict(), + {'id': 1, 'dummy': 42, 'mf': {0: 0}, "udt": None, "udt_default": self.udt_default}) + + def test_null_update_is_respected(self): + """ + Check if the we try to update with None under particular + circumstances, it works correctly + @since 3.9 + @jira_ticket PYTHON-657 + @expected_result column value should be updated to None + + @test_category object_mapper + :return: + """ + ModelWithDefault.create(id=1, mf={0: 0}).save() + + q = ModelWithDefault.objects.all().allow_filtering() + obj = q.filter(id=1).get() + + updated_udt = UDT(age=1, mf={2:2}, dummy_udt=None) + obj.update(dummy=None, udt_default=updated_udt) + + self.assertEqual(ModelWithDefault.get()._as_dict(), + {'id': 1, 'dummy': None, 'mf': {0: 0}, "udt": None, "udt_default": updated_udt}) + + def test_only_set_values_is_updated(self): + """ + Test the updates work as expected when an object is deleted + @since 3.9 + @jira_ticket PYTHON-657 + @expected_result the non updated column is None and the + updated column has the set value + + @test_category object_mapper + """ + + ModelWithDefault.create(id=1, mf={1: 1}, dummy=1).save() + + item = ModelWithDefault.filter(id=1).first() + ModelWithDefault.objects(id=1).delete() + item.mf = {1: 2} + udt, udt_default = UDT(age=1, mf={2:3}), UDT(age=1, mf={2:3}) + item.udt, item.udt_default = udt, udt_default + item.save() + + self.assertEqual(ModelWithDefault.get()._as_dict(), + {'id': 1, 'dummy': None, 'mf': {1: 2}, "udt": udt, "udt_default": udt_default}) + + def test_collections(self): + """ + Test the updates work as expected on Map objects + @since 3.9 + @jira_ticket PYTHON-657 + @expected_result the row is updated when the Map object is + reduced + + @test_category object_mapper + """ + udt, udt_default = UDT(age=1, mf={1: 1, 2: 1}), UDT(age=1, mf={1: 1, 2: 1}) + + ModelWithDefault.create(id=1, mf={1: 1, 2: 1}, dummy=1, udt=udt, udt_default=udt_default).save() + item = ModelWithDefault.filter(id=1).first() + + udt, udt_default = UDT(age=1, mf={2: 1}), UDT(age=1, mf={2: 1}) + item.update(mf={2:1}, udt=udt, udt_default=udt_default) + self.assertEqual(ModelWithDefault.get()._as_dict(), + {'id': 1, 'dummy': 1, 'mf': {2: 1}, "udt": udt, "udt_default": udt_default}) + + def test_collection_with_default(self): + """ + Test the updates work as expected when an object is deleted + @since 3.9 + @jira_ticket PYTHON-657 + @expected_result the non updated column is None and the + updated column has the set value + + @test_category object_mapper + """ + sync_table(ModelWithDefaultCollection) + + udt, udt_default = UDT(age=1, mf={6: 6}), UDT(age=1, mf={6: 6}) + + item = ModelWithDefaultCollection.create(id=1, mf={1: 1}, dummy=1, udt=udt, 
udt_default=udt_default).save() + self.assertEqual(ModelWithDefaultCollection.objects.get(id=1)._as_dict(), + {'id': 1, 'dummy': 1, 'mf': {1: 1}, "udt": udt, "udt_default": udt_default}) + + udt, udt_default = UDT(age=1, mf={5: 5}), UDT(age=1, mf={5: 5}) + item.update(mf={2: 2}, udt=udt, udt_default=udt_default) + self.assertEqual(ModelWithDefaultCollection.objects.get(id=1)._as_dict(), + {'id': 1, 'dummy': 1, 'mf': {2: 2}, "udt": udt, "udt_default": udt_default}) + + udt, udt_default = UDT(age=1, mf=None), UDT(age=1, mf=None) + expected_udt, expected_udt_default = UDT(age=1, mf={}), UDT(age=1, mf={}) + item.update(mf=None, udt=udt, udt_default=udt_default) + self.assertEqual(ModelWithDefaultCollection.objects.get(id=1)._as_dict(), + {'id': 1, 'dummy': 1, 'mf': {}, "udt": expected_udt, "udt_default": expected_udt_default}) + + udt_default = UDT(age=1, mf={2:2}, dummy_udt=42) + item = ModelWithDefaultCollection.create(id=2, dummy=2) + self.assertEqual(ModelWithDefaultCollection.objects.get(id=2)._as_dict(), + {'id': 2, 'dummy': 2, 'mf': {2: 2}, "udt": None, "udt_default": udt_default}) + + udt, udt_default = UDT(age=1, mf={1: 1, 6: 6}), UDT(age=1, mf={1: 1, 6: 6}) + item.update(mf={1: 1, 4: 4}, udt=udt, udt_default=udt_default) + self.assertEqual(ModelWithDefaultCollection.objects.get(id=2)._as_dict(), + {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}, "udt": udt, "udt_default": udt_default}) + + item.update(udt_default=None) + self.assertEqual(ModelWithDefaultCollection.objects.get(id=2)._as_dict(), + {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}, "udt": udt, "udt_default": None}) + + udt_default = UDT(age=1, mf={2:2}) + item.update(udt_default=udt_default) + self.assertEqual(ModelWithDefaultCollection.objects.get(id=2)._as_dict(), + {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}, "udt": udt, "udt_default": udt_default}) + + + def test_udt_to_python(self): + """ + Test the to_python and to_database are correctly called on UDTs + @since 3.10 + @jira_ticket PYTHON-743 + @expected_result the int value is correctly converted to utils.Time + and written to C* + + @test_category object_mapper + """ + item = ModelWithDefault(id=1) + item.save() + + # We update time_col this way because we want to hit + # the to_python method from UserDefinedType, otherwise to_python + # would be called in UDT.__init__ + user_to_update = UDT() + user_to_update.time_col = 10 + + item.update(udt=user_to_update) + + udt, udt_default = UDT(time_col=10), UDT(age=1, mf={2:2}) + self.assertEqual(ModelWithDefault.objects.get(id=1)._as_dict(), + {'id': 1, 'dummy': 42, 'mf': {}, "udt": udt, "udt_default": udt_default}) diff --git a/tests/integration/cqlengine/model/test_value_lists.py b/tests/integration/cqlengine/model/test_value_lists.py new file mode 100644 index 0000000..0c91315 --- /dev/null +++ b/tests/integration/cqlengine/model/test_value_lists.py @@ -0,0 +1,75 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
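The update tests above rely on cqlengine only writing columns that were actually set or changed, and on primary keys being immutable. A rough sketch of that behaviour, with an illustrative model that is not part of this patch (a configured cqlengine connection and keyspace are assumed)::

    from uuid import uuid4

    from cassandra.cqlengine import columns, ValidationError
    from cassandra.cqlengine.models import Model
    from cassandra.cqlengine.management import sync_table

    class Note(Model):                      # illustrative model, not from the test suite
        id = columns.UUID(primary_key=True, default=uuid4)
        text = columns.Text()
        count = columns.Integer()

    sync_table(Note)
    note = Note.create(text='monkey', count=5)

    note.update(text='monkey land')         # issues an UPDATE for `text` only
    note.update()                           # nothing changed, so no query is sent

    try:
        note.update(id=uuid4())             # primary keys cannot be updated
    except ValidationError:
        pass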
+ +import random +from tests.integration.cqlengine.base import BaseCassEngTestCase + +from cassandra.cqlengine.management import sync_table +from cassandra.cqlengine.management import drop_table +from cassandra.cqlengine.models import Model +from cassandra.cqlengine import columns + + +class TestModel(Model): + + id = columns.Integer(primary_key=True) + clustering_key = columns.Integer(primary_key=True, clustering_order='desc') + +class TestClusteringComplexModel(Model): + + id = columns.Integer(primary_key=True) + clustering_key = columns.Integer(primary_key=True, clustering_order='desc') + some_value = columns.Integer() + +class TestClusteringOrder(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestClusteringOrder, cls).setUpClass() + sync_table(TestModel) + + @classmethod + def tearDownClass(cls): + super(TestClusteringOrder, cls).tearDownClass() + drop_table(TestModel) + + def test_clustering_order(self): + """ + Tests that models can be saved and retrieved + """ + items = list(range(20)) + random.shuffle(items) + for i in items: + TestModel.create(id=1, clustering_key=i) + + values = list(TestModel.objects.values_list('clustering_key', flat=True)) + # [19L, 18L, 17L, 16L, 15L, 14L, 13L, 12L, 11L, 10L, 9L, 8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L, 0L] + self.assertEqual(values, sorted(items, reverse=True)) + + def test_clustering_order_more_complex(self): + """ + Tests that models can be saved and retrieved + """ + sync_table(TestClusteringComplexModel) + + items = list(range(20)) + random.shuffle(items) + for i in items: + TestClusteringComplexModel.create(id=1, clustering_key=i, some_value=2) + + values = list(TestClusteringComplexModel.objects.values_list('some_value', flat=True)) + + self.assertEqual([2] * 20, values) + drop_table(TestClusteringComplexModel) + diff --git a/tests/integration/cqlengine/operators/__init__.py b/tests/integration/cqlengine/operators/__init__.py new file mode 100644 index 0000000..05a41c4 --- /dev/null +++ b/tests/integration/cqlengine/operators/__init__.py @@ -0,0 +1,20 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra.cqlengine.operators import BaseWhereOperator + + +def check_lookup(test_case, symbol, expected): + op = BaseWhereOperator.get_operator(symbol) + test_case.assertEqual(op, expected) diff --git a/tests/integration/cqlengine/operators/test_where_operators.py b/tests/integration/cqlengine/operators/test_where_operators.py new file mode 100644 index 0000000..fdfce1f --- /dev/null +++ b/tests/integration/cqlengine/operators/test_where_operators.py @@ -0,0 +1,114 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.cqlengine.operators import * + +from uuid import uuid4 + +from cassandra.cqlengine.management import sync_table, drop_table +from cassandra.cqlengine.operators import IsNotNullOperator +from cassandra.cqlengine.statements import IsNotNull +from cassandra import InvalidRequest + +from tests.integration.cqlengine.base import TestQueryUpdateModel, BaseCassEngTestCase +from tests.integration.cqlengine.operators import check_lookup +from tests.integration import greaterthanorequalcass30 + +import six + + +class TestWhereOperators(unittest.TestCase): + + def test_symbol_lookup(self): + """ tests where symbols are looked up properly """ + + check_lookup(self, 'EQ', EqualsOperator) + check_lookup(self, 'NE', NotEqualsOperator) + check_lookup(self, 'IN', InOperator) + check_lookup(self, 'GT', GreaterThanOperator) + check_lookup(self, 'GTE', GreaterThanOrEqualOperator) + check_lookup(self, 'LT', LessThanOperator) + check_lookup(self, 'LTE', LessThanOrEqualOperator) + check_lookup(self, 'CONTAINS', ContainsOperator) + check_lookup(self, 'LIKE', LikeOperator) + + def test_operator_rendering(self): + """ tests symbols are rendered properly """ + self.assertEqual("=", six.text_type(EqualsOperator())) + self.assertEqual("!=", six.text_type(NotEqualsOperator())) + self.assertEqual("IN", six.text_type(InOperator())) + self.assertEqual(">", six.text_type(GreaterThanOperator())) + self.assertEqual(">=", six.text_type(GreaterThanOrEqualOperator())) + self.assertEqual("<", six.text_type(LessThanOperator())) + self.assertEqual("<=", six.text_type(LessThanOrEqualOperator())) + self.assertEqual("CONTAINS", six.text_type(ContainsOperator())) + self.assertEqual("LIKE", six.text_type(LikeOperator())) + + +class TestIsNotNull(BaseCassEngTestCase): + def test_is_not_null_to_cql(self): + """ + Verify that IsNotNull is converted correctly to CQL + + @since 2.5 + @jira_ticket PYTHON-968 + @expected_result the strings match + + @test_category cqlengine + """ + + check_lookup(self, 'IS NOT NULL', IsNotNullOperator) + + # The * is not expanded because there are no referred fields + self.assertEqual( + str(TestQueryUpdateModel.filter(IsNotNull("text")).limit(2)), + 'SELECT * FROM cqlengine_test.test_query_update_model WHERE "text" IS NOT NULL LIMIT 2' + ) + + # We already know partition so cqlengine doesn't query for it + self.assertEqual( + str(TestQueryUpdateModel.filter(IsNotNull("text"), partition=uuid4())), + ('SELECT "cluster", "count", "text", "text_set", ' + '"text_list", "text_map" FROM cqlengine_test.test_query_update_model ' + 'WHERE "text" IS NOT NULL AND "partition" = %(0)s LIMIT 10000') + ) + + @greaterthanorequalcass30 + def test_is_not_null_execution(self): + """ + Verify that CQL statements have correct syntax when executed + If we wanted them to return something meaningful and not a InvalidRequest + we'd have to create an index in search for the column we are using + IsNotNull + + @since 2.5 + @jira_ticket PYTHON-968 + @expected_result InvalidRequest is arisen + + @test_category cqlengine + """ + 
sync_table(TestQueryUpdateModel) + self.addCleanup(drop_table, TestQueryUpdateModel) + + # Raises InvalidRequest instead of dse.protocol.SyntaxException + with self.assertRaises(InvalidRequest): + list(TestQueryUpdateModel.filter(IsNotNull("text"))) + + with self.assertRaises(InvalidRequest): + list(TestQueryUpdateModel.filter(IsNotNull("text"), partition=uuid4())) diff --git a/tests/integration/cqlengine/query/__init__.py b/tests/integration/cqlengine/query/__init__.py new file mode 100644 index 0000000..386372e --- /dev/null +++ b/tests/integration/cqlengine/query/__init__.py @@ -0,0 +1,14 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/integration/cqlengine/query/test_batch_query.py b/tests/integration/cqlengine/query/test_batch_query.py new file mode 100644 index 0000000..f0c9c43 --- /dev/null +++ b/tests/integration/cqlengine/query/test_batch_query.py @@ -0,0 +1,291 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import mock + +from cassandra.cqlengine import columns +from cassandra.cqlengine.connection import NOT_SET +from cassandra.cqlengine.management import drop_table, sync_table +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.query import BatchQuery, DMLQuery +from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine import execute_count +from cassandra.cluster import Session +from cassandra.query import BatchType as cassandra_BatchType +from cassandra.cqlengine.query import BatchType as cqlengine_BatchType + + +class TestMultiKeyModel(Model): + + partition = columns.Integer(primary_key=True) + cluster = columns.Integer(primary_key=True) + count = columns.Integer(required=False) + text = columns.Text(required=False) + +class BatchQueryLogModel(Model): + + # simple k/v table + k = columns.Integer(primary_key=True) + v = columns.Integer() + + +class CounterBatchQueryModel(Model): + k = columns.Integer(primary_key=True) + v = columns.Counter() + +class BatchQueryTests(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BatchQueryTests, cls).setUpClass() + drop_table(TestMultiKeyModel) + sync_table(TestMultiKeyModel) + + @classmethod + def tearDownClass(cls): + super(BatchQueryTests, cls).tearDownClass() + drop_table(TestMultiKeyModel) + + def setUp(self): + super(BatchQueryTests, self).setUp() + self.pkey = 1 + for obj in TestMultiKeyModel.filter(partition=self.pkey): + obj.delete() + + @execute_count(3) + def test_insert_success_case(self): + + b = BatchQuery() + inst = TestMultiKeyModel.batch(b).create(partition=self.pkey, cluster=2, count=3, text='4') + + with self.assertRaises(TestMultiKeyModel.DoesNotExist): + TestMultiKeyModel.get(partition=self.pkey, cluster=2) + + b.execute() + + TestMultiKeyModel.get(partition=self.pkey, cluster=2) + + @execute_count(4) + def test_update_success_case(self): + + inst = TestMultiKeyModel.create(partition=self.pkey, cluster=2, count=3, text='4') + + b = BatchQuery() + + inst.count = 4 + inst.batch(b).save() + + inst2 = TestMultiKeyModel.get(partition=self.pkey, cluster=2) + self.assertEqual(inst2.count, 3) + + b.execute() + + inst3 = TestMultiKeyModel.get(partition=self.pkey, cluster=2) + self.assertEqual(inst3.count, 4) + + @execute_count(4) + def test_delete_success_case(self): + + inst = TestMultiKeyModel.create(partition=self.pkey, cluster=2, count=3, text='4') + + b = BatchQuery() + + inst.batch(b).delete() + + TestMultiKeyModel.get(partition=self.pkey, cluster=2) + + b.execute() + + with self.assertRaises(TestMultiKeyModel.DoesNotExist): + TestMultiKeyModel.get(partition=self.pkey, cluster=2) + + @execute_count(11) + def test_context_manager(self): + + with BatchQuery() as b: + for i in range(5): + TestMultiKeyModel.batch(b).create(partition=self.pkey, cluster=i, count=3, text='4') + + for i in range(5): + with self.assertRaises(TestMultiKeyModel.DoesNotExist): + TestMultiKeyModel.get(partition=self.pkey, cluster=i) + + for i in range(5): + TestMultiKeyModel.get(partition=self.pkey, cluster=i) + + @execute_count(9) + def test_bulk_delete_success_case(self): + + for i in range(1): + for j in range(5): + TestMultiKeyModel.create(partition=i, cluster=j, count=i*j, text='{0}:{1}'.format(i,j)) + + with BatchQuery() as b: + TestMultiKeyModel.objects.batch(b).filter(partition=0).delete() + self.assertEqual(TestMultiKeyModel.filter(partition=0).count(), 5) + + self.assertEqual(TestMultiKeyModel.filter(partition=0).count(), 0) + #cleanup + for m in TestMultiKeyModel.all(): + 
m.delete() + + @execute_count(0) + def test_none_success_case(self): + """ Tests that passing None into the batch call clears any batch object """ + b = BatchQuery() + + q = TestMultiKeyModel.objects.batch(b) + self.assertEqual(q._batch, b) + + q = q.batch(None) + self.assertIsNone(q._batch) + + @execute_count(0) + def test_dml_none_success_case(self): + """ Tests that passing None into the batch call clears any batch object """ + b = BatchQuery() + + q = DMLQuery(TestMultiKeyModel, batch=b) + self.assertEqual(q._batch, b) + + q.batch(None) + self.assertIsNone(q._batch) + + @execute_count(3) + def test_batch_execute_on_exception_succeeds(self): + # makes sure if execute_on_exception == True we still apply the batch + drop_table(BatchQueryLogModel) + sync_table(BatchQueryLogModel) + + obj = BatchQueryLogModel.objects(k=1) + self.assertEqual(0, len(obj)) + + try: + with BatchQuery(execute_on_exception=True) as b: + BatchQueryLogModel.batch(b).create(k=1, v=1) + raise Exception("Blah") + except: + pass + + obj = BatchQueryLogModel.objects(k=1) + # should be 1 because the batch should execute + self.assertEqual(1, len(obj)) + + @execute_count(2) + def test_batch_execute_on_exception_skips_if_not_specified(self): + # makes sure if execute_on_exception == True we still apply the batch + drop_table(BatchQueryLogModel) + sync_table(BatchQueryLogModel) + + obj = BatchQueryLogModel.objects(k=2) + self.assertEqual(0, len(obj)) + + try: + with BatchQuery() as b: + BatchQueryLogModel.batch(b).create(k=2, v=2) + raise Exception("Blah") + except: + pass + + obj = BatchQueryLogModel.objects(k=2) + + # should be 0 because the batch should not execute + self.assertEqual(0, len(obj)) + + @execute_count(1) + def test_batch_execute_timeout(self): + with mock.patch.object(Session, 'execute') as mock_execute: + with BatchQuery(timeout=1) as b: + BatchQueryLogModel.batch(b).create(k=2, v=2) + self.assertEqual(mock_execute.call_args[-1]['timeout'], 1) + + @execute_count(1) + def test_batch_execute_no_timeout(self): + with mock.patch.object(Session, 'execute') as mock_execute: + with BatchQuery() as b: + BatchQueryLogModel.batch(b).create(k=2, v=2) + self.assertEqual(mock_execute.call_args[-1]['timeout'], NOT_SET) + + +class BatchTypeQueryTests(BaseCassEngTestCase): + def setUp(self): + sync_table(TestMultiKeyModel) + sync_table(CounterBatchQueryModel) + + def tearDown(self): + drop_table(TestMultiKeyModel) + drop_table(CounterBatchQueryModel) + + @execute_count(6) + def test_cassandra_batch_type(self): + """ + Tests the different types of `class: cassandra.query.BatchType` + + @since 3.13 + @jira_ticket PYTHON-88 + @expected_result batch query succeeds and the results + are correctly readen + + @test_category query + """ + with BatchQuery(batch_type=cassandra_BatchType.UNLOGGED) as b: + TestMultiKeyModel.batch(b).create(partition=1, cluster=1) + TestMultiKeyModel.batch(b).create(partition=1, cluster=2) + + obj = TestMultiKeyModel.objects(partition=1) + self.assertEqual(2, len(obj)) + + with BatchQuery(batch_type=cassandra_BatchType.COUNTER) as b: + CounterBatchQueryModel.batch(b).create(k=1, v=1) + CounterBatchQueryModel.batch(b).create(k=1, v=2) + CounterBatchQueryModel.batch(b).create(k=1, v=10) + + obj = CounterBatchQueryModel.objects(k=1) + self.assertEqual(1, len(obj)) + self.assertEqual(obj[0].v, 13) + + with BatchQuery(batch_type=cassandra_BatchType.LOGGED) as b: + TestMultiKeyModel.batch(b).create(partition=1, cluster=1) + TestMultiKeyModel.batch(b).create(partition=1, cluster=2) + + obj = 
TestMultiKeyModel.objects(partition=1)
+        self.assertEqual(2, len(obj))
+
+    @execute_count(4)
+    def test_cqlengine_batch_type(self):
+        """
+        Tests the different types of `class: cassandra.cqlengine.query.BatchType`
+
+        @since 3.13
+        @jira_ticket PYTHON-88
+        @expected_result batch query succeeds and the results
+        are correctly read
+
+        @test_category query
+        """
+        with BatchQuery(batch_type=cqlengine_BatchType.Unlogged) as b:
+            TestMultiKeyModel.batch(b).create(partition=1, cluster=1)
+            TestMultiKeyModel.batch(b).create(partition=1, cluster=2)
+
+        obj = TestMultiKeyModel.objects(partition=1)
+        self.assertEqual(2, len(obj))
+
+        with BatchQuery(batch_type=cqlengine_BatchType.Counter) as b:
+            CounterBatchQueryModel.batch(b).create(k=1, v=1)
+            CounterBatchQueryModel.batch(b).create(k=1, v=2)
+            CounterBatchQueryModel.batch(b).create(k=1, v=10)
+
+        obj = CounterBatchQueryModel.objects(k=1)
+        self.assertEqual(1, len(obj))
+        self.assertEqual(obj[0].v, 13)
diff --git a/tests/integration/cqlengine/query/test_datetime_queries.py b/tests/integration/cqlengine/query/test_datetime_queries.py
new file mode 100644
index 0000000..ba1c90b
--- /dev/null
+++ b/tests/integration/cqlengine/query/test_datetime_queries.py
@@ -0,0 +1,75 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
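The batch tests above drive cqlengine's ``BatchQuery`` both as a context manager and with an explicit ``execute()``. A minimal sketch of that usage, with an illustrative model name and an assumed cqlengine connection::

    from cassandra.cqlengine import columns
    from cassandra.cqlengine.models import Model
    from cassandra.cqlengine.query import BatchQuery, BatchType
    from cassandra.cqlengine.management import sync_table

    class Event(Model):                     # illustrative model, not from the test suite
        partition = columns.Integer(primary_key=True)
        cluster = columns.Integer(primary_key=True)

    sync_table(Event)

    with BatchQuery(batch_type=BatchType.Unlogged) as b:   # applied when the block exits
        Event.batch(b).create(partition=1, cluster=1)
        Event.batch(b).create(partition=1, cluster=2)

    b = BatchQuery()                        # or collect statements and execute later
    Event.batch(b).create(partition=1, cluster=3)
    b.execute()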
+ +from datetime import datetime, timedelta +from uuid import uuid4 +from cassandra.cqlengine.functions import get_total_seconds + +from tests.integration.cqlengine.base import BaseCassEngTestCase + +from cassandra.cqlengine.management import sync_table +from cassandra.cqlengine.management import drop_table +from cassandra.cqlengine.models import Model +from cassandra.cqlengine import columns +from tests.integration.cqlengine import execute_count + + +class DateTimeQueryTestModel(Model): + + user = columns.Integer(primary_key=True) + day = columns.DateTime(primary_key=True) + data = columns.Text() + + +class TestDateTimeQueries(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestDateTimeQueries, cls).setUpClass() + sync_table(DateTimeQueryTestModel) + + cls.base_date = datetime.now() - timedelta(days=10) + for x in range(7): + for y in range(10): + DateTimeQueryTestModel.create( + user=x, + day=(cls.base_date+timedelta(days=y)), + data=str(uuid4()) + ) + + @classmethod + def tearDownClass(cls): + super(TestDateTimeQueries, cls).tearDownClass() + drop_table(DateTimeQueryTestModel) + + @execute_count(1) + def test_range_query(self): + """ Tests that loading from a range of dates works properly """ + start = datetime(*self.base_date.timetuple()[:3]) + end = start + timedelta(days=3) + + results = DateTimeQueryTestModel.filter(user=0, day__gte=start, day__lt=end) + assert len(results) == 3 + + @execute_count(3) + def test_datetime_precision(self): + """ Tests that millisecond resolution is preserved when saving datetime objects """ + now = datetime.now() + pk = 1000 + obj = DateTimeQueryTestModel.create(user=pk, day=now, data='energy cheese') + load = DateTimeQueryTestModel.get(user=pk) + + self.assertAlmostEqual(get_total_seconds(now - load.day), 0, 2) + obj.delete() + diff --git a/tests/integration/cqlengine/query/test_named.py b/tests/integration/cqlengine/query/test_named.py new file mode 100644 index 0000000..4907c26 --- /dev/null +++ b/tests/integration/cqlengine/query/test_named.py @@ -0,0 +1,387 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
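The tests that follow exercise ``cassandra.cqlengine.named``, which wraps an existing keyspace and table without declaring a Model class. A minimal sketch, reusing the keyspace and table names the tests use and assuming an established cqlengine connection::

    from cassandra.cqlengine.named import NamedKeyspace

    ks = NamedKeyspace('cqlengine_test')        # existing keyspace, no Model needed
    table = ks.table('test_model')              # existing table, columns discovered at runtime

    total = table.objects.count()                               # count rows
    rows = table.objects(test_id=0)                             # keyword filtering
    rows = table.filter(table.column('expected_result') >= 1)   # column-expression filtering
    row = table.objects.get(test_id=0, attempt_id=0)            # returns a ResultObject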
+ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra import ConsistencyLevel +from cassandra.cqlengine import operators +from cassandra.cqlengine.named import NamedKeyspace +from cassandra.cqlengine.operators import EqualsOperator, GreaterThanOrEqualOperator +from cassandra.cqlengine.query import ResultObject +from cassandra.concurrent import execute_concurrent_with_args +from cassandra.cqlengine import models + +from tests.integration.cqlengine import setup_connection, execute_count +from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine.query.test_queryset import BaseQuerySetUsage + + +from tests.integration import BasicSharedKeyspaceUnitTestCase, greaterthanorequalcass30 + + +class TestQuerySetOperation(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestQuerySetOperation, cls).setUpClass() + cls.keyspace = NamedKeyspace('cqlengine_test') + cls.table = cls.keyspace.table('test_model') + + def test_query_filter_parsing(self): + """ + Tests the queryset filter method parses it's kwargs properly + """ + query1 = self.table.objects(test_id=5) + assert len(query1._where) == 1 + + op = query1._where[0] + assert isinstance(op.operator, operators.EqualsOperator) + assert op.value == 5 + + query2 = query1.filter(expected_result__gte=1) + assert len(query2._where) == 2 + + op = query2._where[1] + assert isinstance(op.operator, operators.GreaterThanOrEqualOperator) + assert op.value == 1 + + def test_query_expression_parsing(self): + """ Tests that query experessions are evaluated properly """ + query1 = self.table.filter(self.table.column('test_id') == 5) + assert len(query1._where) == 1 + + op = query1._where[0] + assert isinstance(op.operator, operators.EqualsOperator) + assert op.value == 5 + + query2 = query1.filter(self.table.column('expected_result') >= 1) + assert len(query2._where) == 2 + + op = query2._where[1] + assert isinstance(op.operator, operators.GreaterThanOrEqualOperator) + assert op.value == 1 + + def test_filter_method_where_clause_generation(self): + """ + Tests the where clause creation + """ + query1 = self.table.objects(test_id=5) + self.assertEqual(len(query1._where), 1) + where = query1._where[0] + self.assertEqual(where.field, 'test_id') + self.assertEqual(where.value, 5) + + query2 = query1.filter(expected_result__gte=1) + self.assertEqual(len(query2._where), 2) + + where = query2._where[0] + self.assertEqual(where.field, 'test_id') + self.assertIsInstance(where.operator, EqualsOperator) + self.assertEqual(where.value, 5) + + where = query2._where[1] + self.assertEqual(where.field, 'expected_result') + self.assertIsInstance(where.operator, GreaterThanOrEqualOperator) + self.assertEqual(where.value, 1) + + def test_query_expression_where_clause_generation(self): + """ + Tests the where clause creation + """ + query1 = self.table.objects(self.table.column('test_id') == 5) + self.assertEqual(len(query1._where), 1) + where = query1._where[0] + self.assertEqual(where.field, 'test_id') + self.assertEqual(where.value, 5) + + query2 = query1.filter(self.table.column('expected_result') >= 1) + self.assertEqual(len(query2._where), 2) + + where = query2._where[0] + self.assertEqual(where.field, 'test_id') + self.assertIsInstance(where.operator, EqualsOperator) + self.assertEqual(where.value, 5) + + where = query2._where[1] + self.assertEqual(where.field, 'expected_result') + self.assertIsInstance(where.operator, GreaterThanOrEqualOperator) + 
self.assertEqual(where.value, 1) + + +class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage): + + @classmethod + def setUpClass(cls): + super(TestQuerySetCountSelectionAndIteration, cls).setUpClass() + + from tests.integration.cqlengine.query.test_queryset import TestModel + + ks, tn = TestModel.column_family_name().split('.') + cls.keyspace = NamedKeyspace(ks) + cls.table = cls.keyspace.table(tn) + + @execute_count(2) + def test_count(self): + """ Tests that adding filtering statements affects the count query as expected """ + assert self.table.objects.count() == 12 + + q = self.table.objects(test_id=0) + assert q.count() == 4 + + @execute_count(2) + def test_query_expression_count(self): + """ Tests that adding query statements affects the count query as expected """ + assert self.table.objects.count() == 12 + + q = self.table.objects(self.table.column('test_id') == 0) + assert q.count() == 4 + + @execute_count(3) + def test_iteration(self): + """ Tests that iterating over a query set pulls back all of the expected results """ + q = self.table.objects(test_id=0) + # tuple of expected attempt_id, expected_result values + compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)]) + for t in q: + val = t.attempt_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + # test with regular filtering + q = self.table.objects(attempt_id=3).allow_filtering() + assert len(q) == 3 + # tuple of expected test_id, expected_result values + compare_set = set([(0, 20), (1, 20), (2, 75)]) + for t in q: + val = t.test_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + # test with query method + q = self.table.objects(self.table.column('attempt_id') == 3).allow_filtering() + assert len(q) == 3 + # tuple of expected test_id, expected_result values + compare_set = set([(0, 20), (1, 20), (2, 75)]) + for t in q: + val = t.test_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + @execute_count(2) + def test_multiple_iterations_work_properly(self): + """ Tests that iterating over a query set more than once works """ + # test with both the filtering method and the query method + for q in (self.table.objects(test_id=0), self.table.objects(self.table.column('test_id') == 0)): + # tuple of expected attempt_id, expected_result values + compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)]) + for t in q: + val = t.attempt_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + # try it again + compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)]) + for t in q: + val = t.attempt_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + @execute_count(2) + def test_multiple_iterators_are_isolated(self): + """ + tests that the use of one iterator does not affect the behavior of another + """ + for q in (self.table.objects(test_id=0), self.table.objects(self.table.column('test_id') == 0)): + q = q.order_by('attempt_id') + expected_order = [0, 1, 2, 3] + iter1 = iter(q) + iter2 = iter(q) + for attempt_id in expected_order: + assert next(iter1).attempt_id == attempt_id + assert next(iter2).attempt_id == attempt_id + + @execute_count(3) + def test_get_success_case(self): + """ + Tests that the .get() method works on new and existing querysets + """ + m = self.table.objects.get(test_id=0, attempt_id=0) + assert isinstance(m, ResultObject) + 
assert m.test_id == 0 + assert m.attempt_id == 0 + + q = self.table.objects(test_id=0, attempt_id=0) + m = q.get() + assert isinstance(m, ResultObject) + assert m.test_id == 0 + assert m.attempt_id == 0 + + q = self.table.objects(test_id=0) + m = q.get(attempt_id=0) + assert isinstance(m, ResultObject) + assert m.test_id == 0 + assert m.attempt_id == 0 + + @execute_count(3) + def test_query_expression_get_success_case(self): + """ + Tests that the .get() method works on new and existing querysets + """ + m = self.table.get(self.table.column('test_id') == 0, self.table.column('attempt_id') == 0) + assert isinstance(m, ResultObject) + assert m.test_id == 0 + assert m.attempt_id == 0 + + q = self.table.objects(self.table.column('test_id') == 0, self.table.column('attempt_id') == 0) + m = q.get() + assert isinstance(m, ResultObject) + assert m.test_id == 0 + assert m.attempt_id == 0 + + q = self.table.objects(self.table.column('test_id') == 0) + m = q.get(self.table.column('attempt_id') == 0) + assert isinstance(m, ResultObject) + assert m.test_id == 0 + assert m.attempt_id == 0 + + @execute_count(1) + def test_get_doesnotexist_exception(self): + """ + Tests that get calls that don't return a result raises a DoesNotExist error + """ + with self.assertRaises(self.table.DoesNotExist): + self.table.objects.get(test_id=100) + + @execute_count(1) + def test_get_multipleobjects_exception(self): + """ + Tests that get calls that return multiple results raise a MultipleObjectsReturned error + """ + with self.assertRaises(self.table.MultipleObjectsReturned): + self.table.objects.get(test_id=1) + + +class TestNamedWithMV(BasicSharedKeyspaceUnitTestCase): + + @classmethod + def setUpClass(cls): + super(TestNamedWithMV, cls).setUpClass() + cls.default_keyspace = models.DEFAULT_KEYSPACE + models.DEFAULT_KEYSPACE = cls.ks_name + + @classmethod + def tearDownClass(cls): + models.DEFAULT_KEYSPACE = cls.default_keyspace + super(TestNamedWithMV, cls).tearDownClass() + + @greaterthanorequalcass30 + @execute_count(5) + def test_named_table_with_mv(self): + """ + Test NamedTable access to materialized views + + Creates some materialized views using Traditional CQL. Then ensures we can access those materialized view using + the NamedKeyspace, and NamedTable interfaces. Tests basic filtering as well. 
+ + @since 3.0.0 + @jira_ticket PYTHON-406 + @expected_result Named Tables should have access to materialized views + + @test_category materialized_view + """ + ks = models.DEFAULT_KEYSPACE + self.session.execute("DROP MATERIALIZED VIEW IF EXISTS {0}.alltimehigh".format(ks)) + self.session.execute("DROP MATERIALIZED VIEW IF EXISTS {0}.monthlyhigh".format(ks)) + self.session.execute("DROP TABLE IF EXISTS {0}.scores".format(ks)) + create_table = """CREATE TABLE {0}.scores( + user TEXT, + game TEXT, + year INT, + month INT, + day INT, + score INT, + PRIMARY KEY (user, game, year, month, day) + )""".format(ks) + + self.session.execute(create_table) + create_mv = """CREATE MATERIALIZED VIEW {0}.monthlyhigh AS + SELECT game, year, month, score, user, day FROM {0}.scores + WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND day IS NOT NULL + PRIMARY KEY ((game, year, month), score, user, day) + WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format(ks) + + self.session.execute(create_mv) + + create_mv_alltime = """CREATE MATERIALIZED VIEW {0}.alltimehigh AS + SELECT * FROM {0}.scores + WHERE game IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND day IS NOT NULL + PRIMARY KEY (game, score, user, year, month, day) + WITH CLUSTERING ORDER BY (score DESC)""".format(ks) + + self.session.execute(create_mv_alltime) + + # Populate the base table with data + prepared_insert = self.session.prepare("""INSERT INTO {0}.scores (user, game, year, month, day, score) VALUES (?, ?, ? ,? ,?, ?)""".format(ks)) + parameters = (('pcmanus', 'Coup', 2015, 5, 1, 4000), + ('jbellis', 'Coup', 2015, 5, 3, 1750), + ('yukim', 'Coup', 2015, 5, 3, 2250), + ('tjake', 'Coup', 2015, 5, 3, 500), + ('iamaleksey', 'Coup', 2015, 6, 1, 2500), + ('tjake', 'Coup', 2015, 6, 2, 1000), + ('pcmanus', 'Coup', 2015, 6, 2, 2000), + ('jmckenzie', 'Coup', 2015, 6, 9, 2700), + ('jbellis', 'Coup', 2015, 6, 20, 3500), + ('jbellis', 'Checkers', 2015, 6, 20, 1200), + ('jbellis', 'Chess', 2015, 6, 21, 3500), + ('pcmanus', 'Chess', 2015, 1, 25, 3200)) + prepared_insert.consistency_level = ConsistencyLevel.ALL + execute_concurrent_with_args(self.session, prepared_insert, parameters) + + # Attempt to query the data using Named Table interface + # Also test filtering on mv's + key_space = NamedKeyspace(ks) + mv_monthly = key_space.table("monthlyhigh") + mv_all_time = key_space.table("alltimehigh") + self.assertTrue(self.check_table_size("scores", key_space, len(parameters))) + self.assertTrue(self.check_table_size("monthlyhigh", key_space, len(parameters))) + self.assertTrue(self.check_table_size("alltimehigh", key_space, len(parameters))) + + filtered_mv_monthly_objects = mv_monthly.objects.filter(game='Chess', year=2015, month=6) + self.assertEqual(len(filtered_mv_monthly_objects), 1) + self.assertEqual(filtered_mv_monthly_objects[0]['score'], 3500) + self.assertEqual(filtered_mv_monthly_objects[0]['user'], 'jbellis') + filtered_mv_alltime_objects = mv_all_time.objects.filter(game='Chess') + self.assertEqual(len(filtered_mv_alltime_objects), 2) + self.assertEqual(filtered_mv_alltime_objects[0]['score'], 3500) + + def check_table_size(self, table_name, key_space, expected_size): + table = key_space.table(table_name) + attempts = 0 + while attempts < 10: + attempts += 1 + table_size = len(table.objects.all()) + if(table_size is not expected_size): + print("Table {0} size was {1} and was expected to be {2}".format(table_name, 
table_size, expected_size)) + else: + return True + + return False diff --git a/tests/integration/cqlengine/query/test_queryoperators.py b/tests/integration/cqlengine/query/test_queryoperators.py new file mode 100644 index 0000000..fd148ba --- /dev/null +++ b/tests/integration/cqlengine/query/test_queryoperators.py @@ -0,0 +1,159 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime + +from cassandra.cqlengine import columns +from cassandra.cqlengine import functions +from cassandra.cqlengine import query +from cassandra.cqlengine.management import sync_table, drop_table +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.named import NamedTable +from cassandra.cqlengine.operators import EqualsOperator +from cassandra.cqlengine.statements import WhereClause +from tests.integration.cqlengine import DEFAULT_KEYSPACE +from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine import execute_count + + +class TestQuerySetOperation(BaseCassEngTestCase): + + def test_maxtimeuuid_function(self): + """ + Tests that queries with helper functions are generated properly + """ + now = datetime.now() + where = WhereClause('time', EqualsOperator(), functions.MaxTimeUUID(now)) + where.set_context_id(5) + + self.assertEqual(str(where), '"time" = MaxTimeUUID(%(5)s)') + ctx = {} + where.update_context(ctx) + self.assertEqual(ctx, {'5': columns.DateTime().to_database(now)}) + + def test_mintimeuuid_function(self): + """ + Tests that queries with helper functions are generated properly + """ + now = datetime.now() + where = WhereClause('time', EqualsOperator(), functions.MinTimeUUID(now)) + where.set_context_id(5) + + self.assertEqual(str(where), '"time" = MinTimeUUID(%(5)s)') + ctx = {} + where.update_context(ctx) + self.assertEqual(ctx, {'5': columns.DateTime().to_database(now)}) + + +class TokenTestModel(Model): + __table_name__ = "token_test_model" + key = columns.Integer(primary_key=True) + val = columns.Integer() + + +class TestTokenFunction(BaseCassEngTestCase): + + def setUp(self): + super(TestTokenFunction, self).setUp() + sync_table(TokenTestModel) + + def tearDown(self): + super(TestTokenFunction, self).tearDown() + drop_table(TokenTestModel) + + @execute_count(15) + def test_token_function(self): + """ Tests that token functions work properly """ + assert TokenTestModel.objects.count() == 0 + for i in range(10): + TokenTestModel.create(key=i, val=i) + assert TokenTestModel.objects.count() == 10 + seen_keys = set() + last_token = None + for instance in TokenTestModel.objects().limit(5): + last_token = instance.key + seen_keys.add(last_token) + assert len(seen_keys) == 5 + for instance in TokenTestModel.objects(pk__token__gt=functions.Token(last_token)): + seen_keys.add(instance.key) + + assert len(seen_keys) == 10 + assert all([i in seen_keys for i in range(10)]) + + # pk__token equality + r = TokenTestModel.objects(pk__token=functions.Token(last_token)) + self.assertEqual(len(r), 1) + 
r.all() # Attempt to obtain queryset for results. This has thrown an exception in the past + + def test_compound_pk_token_function(self): + + class TestModel(Model): + + p1 = columns.Text(partition_key=True) + p2 = columns.Text(partition_key=True) + + func = functions.Token('a', 'b') + + q = TestModel.objects.filter(pk__token__gt=func) + where = q._where[0] + where.set_context_id(1) + self.assertEqual(str(where), 'token("p1", "p2") > token(%({0})s, %({1})s)'.format(1, 2)) + + # Verify that a SELECT query can be successfully generated + str(q._select_query()) + + # Token(tuple()) is also possible for convenience + # (it allows for the Token(obj.pk) syntax) + func = functions.Token(('a', 'b')) + + q = TestModel.objects.filter(pk__token__gt=func) + where = q._where[0] + where.set_context_id(1) + self.assertEqual(str(where), 'token("p1", "p2") > token(%({0})s, %({1})s)'.format(1, 2)) + str(q._select_query()) + + # The 'pk__token' virtual column may only be compared to a Token + self.assertRaises(query.QueryException, TestModel.objects.filter, pk__token__gt=10) + + # A Token may only be compared to the 'pk__token' virtual column + func = functions.Token('a', 'b') + self.assertRaises(query.QueryException, TestModel.objects.filter, p1__gt=func) + + # The number of arguments to Token must match the number of partition keys + func = functions.Token('a') + self.assertRaises(query.QueryException, TestModel.objects.filter, pk__token__gt=func) + + @execute_count(7) + def test_named_table_pk_token_function(self): + """ + Test to ensure that token functions work with named tables. + + @since 3.2 + @jira_ticket PYTHON-272 + @expected_result partition key token functions should allow for pagination. Prior to PYTHON-272 + this would fail with an AttributeError + + @test_category object_mapper + """ + + for i in range(5): + TokenTestModel.create(key=i, val=i) + named = NamedTable(DEFAULT_KEYSPACE, TokenTestModel.__table_name__) + + query = named.all().limit(1) + first_page = list(query) + last = first_page[-1] + self.assertEqual(len(first_page), 1) + next_page = list(query.filter(pk__token__gt=functions.Token(last.key))) + self.assertEqual(len(next_page), 1) diff --git a/tests/integration/cqlengine/query/test_queryset.py b/tests/integration/cqlengine/query/test_queryset.py new file mode 100644 index 0000000..e5a15b7 --- /dev/null +++ b/tests/integration/cqlengine/query/test_queryset.py @@ -0,0 +1,1456 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
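+
+# Usage sketch: the queryset tests below lean on two equivalent filtering styles.
+# Assuming a synced model such as the TestModel defined further down and a configured
+# cqlengine connection:
+#
+#     q = TestModel.objects(test_id=0)                    # kwarg style
+#     q = q.filter(expected_result__gte=1)                # suffix __gte maps to >=
+#
+#     q = TestModel.filter(TestModel.test_id == 0)        # column-expression style
+#     q = q.filter(TestModel.expected_result >= 1)
+#
+# Querysets are immutable and lazy: filter() returns a new queryset, and no query runs
+# until the set is iterated, sliced, or len()/count() is called.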
+from __future__ import absolute_import + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from datetime import datetime +from uuid import uuid4 +from packaging.version import Version +import uuid + +from cassandra.cluster import Cluster, Session +from cassandra import InvalidRequest +from tests.integration.cqlengine.base import BaseCassEngTestCase +from cassandra.cqlengine.connection import NOT_SET +import mock +from cassandra.cqlengine import functions +from cassandra.cqlengine.management import sync_table, drop_table +from cassandra.cqlengine.models import Model +from cassandra.cqlengine import columns +from cassandra.cqlengine import query +from cassandra.cqlengine.query import QueryException, BatchQuery +from datetime import timedelta +from datetime import tzinfo + +from cassandra.cqlengine import statements +from cassandra.cqlengine import operators +from cassandra.util import uuid_from_time +from cassandra.cqlengine.connection import get_session +from tests.integration import PROTOCOL_VERSION, CASSANDRA_VERSION, greaterthancass20, greaterthancass21, \ + greaterthanorequalcass30 +from tests.integration.cqlengine import execute_count, DEFAULT_KEYSPACE + + +class TzOffset(tzinfo): + """Minimal implementation of a timezone offset to help testing with timezone + aware datetimes. + """ + + def __init__(self, offset): + self._offset = timedelta(hours=offset) + + def utcoffset(self, dt): + return self._offset + + def tzname(self, dt): + return 'TzOffset: {}'.format(self._offset.hours) + + def dst(self, dt): + return timedelta(0) + + +class TestModel(Model): + + test_id = columns.Integer(primary_key=True) + attempt_id = columns.Integer(primary_key=True) + description = columns.Text() + expected_result = columns.Integer() + test_result = columns.Integer() + + +class IndexedTestModel(Model): + + test_id = columns.Integer(primary_key=True) + attempt_id = columns.Integer(index=True) + description = columns.Text() + expected_result = columns.Integer() + test_result = columns.Integer(index=True) + + +class CustomIndexedTestModel(Model): + + test_id = columns.Integer(primary_key=True) + description = columns.Text(custom_index=True) + indexed = columns.Text(index=True) + data = columns.Text() + + +class IndexedCollectionsTestModel(Model): + + test_id = columns.Integer(primary_key=True) + attempt_id = columns.Integer(index=True) + description = columns.Text() + expected_result = columns.Integer() + test_result = columns.Integer(index=True) + test_list = columns.List(columns.Integer, index=True) + test_set = columns.Set(columns.Integer, index=True) + test_map = columns.Map(columns.Text, columns.Integer, index=True) + + test_list_no_index = columns.List(columns.Integer, index=False) + test_set_no_index = columns.Set(columns.Integer, index=False) + test_map_no_index = columns.Map(columns.Text, columns.Integer, index=False) + + +class TestMultiClusteringModel(Model): + + one = columns.Integer(primary_key=True) + two = columns.Integer(primary_key=True) + three = columns.Integer(primary_key=True) + + +class TestQuerySetOperation(BaseCassEngTestCase): + + def test_query_filter_parsing(self): + """ + Tests the queryset filter method parses it's kwargs properly + """ + query1 = TestModel.objects(test_id=5) + assert len(query1._where) == 1 + + op = query1._where[0] + + assert isinstance(op, statements.WhereClause) + assert isinstance(op.operator, operators.EqualsOperator) + assert op.value == 5 + + query2 = query1.filter(expected_result__gte=1) + assert len(query2._where) 
== 2 + + op = query2._where[1] + self.assertIsInstance(op, statements.WhereClause) + self.assertIsInstance(op.operator, operators.GreaterThanOrEqualOperator) + assert op.value == 1 + + def test_query_expression_parsing(self): + """ Tests that query experessions are evaluated properly """ + query1 = TestModel.filter(TestModel.test_id == 5) + assert len(query1._where) == 1 + + op = query1._where[0] + assert isinstance(op, statements.WhereClause) + assert isinstance(op.operator, operators.EqualsOperator) + assert op.value == 5 + + query2 = query1.filter(TestModel.expected_result >= 1) + assert len(query2._where) == 2 + + op = query2._where[1] + self.assertIsInstance(op, statements.WhereClause) + self.assertIsInstance(op.operator, operators.GreaterThanOrEqualOperator) + assert op.value == 1 + + def test_using_invalid_column_names_in_filter_kwargs_raises_error(self): + """ + Tests that using invalid or nonexistant column names for filter args raises an error + """ + with self.assertRaises(query.QueryException): + TestModel.objects(nonsense=5) + + def test_using_nonexistant_column_names_in_query_args_raises_error(self): + """ + Tests that using invalid or nonexistant columns for query args raises an error + """ + with self.assertRaises(AttributeError): + TestModel.objects(TestModel.nonsense == 5) + + def test_using_non_query_operators_in_query_args_raises_error(self): + """ + Tests that providing query args that are not query operator instances raises an error + """ + with self.assertRaises(query.QueryException): + TestModel.objects(5) + + def test_queryset_is_immutable(self): + """ + Tests that calling a queryset function that changes it's state returns a new queryset + """ + query1 = TestModel.objects(test_id=5) + assert len(query1._where) == 1 + + query2 = query1.filter(expected_result__gte=1) + assert len(query2._where) == 2 + assert len(query1._where) == 1 + + def test_queryset_limit_immutability(self): + """ + Tests that calling a queryset function that changes it's state returns a new queryset with same limit + """ + query1 = TestModel.objects(test_id=5).limit(1) + assert query1._limit == 1 + + query2 = query1.filter(expected_result__gte=1) + assert query2._limit == 1 + + query3 = query1.filter(expected_result__gte=1).limit(2) + assert query1._limit == 1 + assert query3._limit == 2 + + def test_the_all_method_duplicates_queryset(self): + """ + Tests that calling all on a queryset with previously defined filters duplicates queryset + """ + query1 = TestModel.objects(test_id=5) + assert len(query1._where) == 1 + + query2 = query1.filter(expected_result__gte=1) + assert len(query2._where) == 2 + + query3 = query2.all() + assert query3 == query2 + + def test_queryset_with_distinct(self): + """ + Tests that calling distinct on a queryset w/without parameter are evaluated properly. 
+ """ + + query1 = TestModel.objects.distinct() + self.assertEqual(len(query1._distinct_fields), 1) + + query2 = TestModel.objects.distinct(['test_id']) + self.assertEqual(len(query2._distinct_fields), 1) + + query3 = TestModel.objects.distinct(['test_id', 'attempt_id']) + self.assertEqual(len(query3._distinct_fields), 2) + + def test_defining_only_fields(self): + """ + Tests defining only fields + + @since 3.5 + @jira_ticket PYTHON-560 + @expected_result deferred fields should not be returned + + @test_category object_mapper + """ + # simple only definition + q = TestModel.objects.only(['attempt_id', 'description']) + self.assertEqual(q._select_fields(), ['attempt_id', 'description']) + + with self.assertRaises(query.QueryException): + TestModel.objects.only(['nonexistent_field']) + + # Cannot define more than once only fields + with self.assertRaises(query.QueryException): + TestModel.objects.only(['description']).only(['attempt_id']) + + # only with defer fields + q = TestModel.objects.only(['attempt_id', 'description']) + q = q.defer(['description']) + self.assertEqual(q._select_fields(), ['attempt_id']) + + # Eliminate all results confirm exception is thrown + q = TestModel.objects.only(['description']) + q = q.defer(['description']) + with self.assertRaises(query.QueryException): + q._select_fields() + + q = TestModel.objects.filter(test_id=0).only(['test_id', 'attempt_id', 'description']) + self.assertEqual(q._select_fields(), ['attempt_id', 'description']) + + # no fields to select + with self.assertRaises(query.QueryException): + q = TestModel.objects.only(['test_id']).defer(['test_id']) + q._select_fields() + + with self.assertRaises(query.QueryException): + q = TestModel.objects.filter(test_id=0).only(['test_id']) + q._select_fields() + + def test_defining_defer_fields(self): + """ + Tests defining defer fields + + @since 3.5 + @jira_ticket PYTHON-560 + @jira_ticket PYTHON-599 + @expected_result deferred fields should not be returned + + @test_category object_mapper + """ + + # simple defer definition + q = TestModel.objects.defer(['attempt_id', 'description']) + self.assertEqual(q._select_fields(), ['test_id', 'expected_result', 'test_result']) + + with self.assertRaises(query.QueryException): + TestModel.objects.defer(['nonexistent_field']) + + # defer more than one + q = TestModel.objects.defer(['attempt_id', 'description']) + q = q.defer(['expected_result']) + self.assertEqual(q._select_fields(), ['test_id', 'test_result']) + + # defer with only + q = TestModel.objects.defer(['description', 'attempt_id']) + q = q.only(['description', 'test_id']) + self.assertEqual(q._select_fields(), ['test_id']) + + # Eliminate all results confirm exception is thrown + q = TestModel.objects.defer(['description', 'attempt_id']) + q = q.only(['description']) + with self.assertRaises(query.QueryException): + q._select_fields() + + # implicit defer + q = TestModel.objects.filter(test_id=0) + self.assertEqual(q._select_fields(), ['attempt_id', 'description', 'expected_result', 'test_result']) + + # when all fields are defered, it fallbacks select the partition keys + q = TestModel.objects.defer(['test_id', 'attempt_id', 'description', 'expected_result', 'test_result']) + self.assertEqual(q._select_fields(), ['test_id']) + + +class BaseQuerySetUsage(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BaseQuerySetUsage, cls).setUpClass() + drop_table(TestModel) + drop_table(IndexedTestModel) + drop_table(CustomIndexedTestModel) + + sync_table(TestModel) + 
sync_table(IndexedTestModel) + sync_table(CustomIndexedTestModel) + sync_table(TestMultiClusteringModel) + + TestModel.objects.create(test_id=0, attempt_id=0, description='try1', expected_result=5, test_result=30) + TestModel.objects.create(test_id=0, attempt_id=1, description='try2', expected_result=10, test_result=30) + TestModel.objects.create(test_id=0, attempt_id=2, description='try3', expected_result=15, test_result=30) + TestModel.objects.create(test_id=0, attempt_id=3, description='try4', expected_result=20, test_result=25) + + TestModel.objects.create(test_id=1, attempt_id=0, description='try5', expected_result=5, test_result=25) + TestModel.objects.create(test_id=1, attempt_id=1, description='try6', expected_result=10, test_result=25) + TestModel.objects.create(test_id=1, attempt_id=2, description='try7', expected_result=15, test_result=25) + TestModel.objects.create(test_id=1, attempt_id=3, description='try8', expected_result=20, test_result=20) + + TestModel.objects.create(test_id=2, attempt_id=0, description='try9', expected_result=50, test_result=40) + TestModel.objects.create(test_id=2, attempt_id=1, description='try10', expected_result=60, test_result=40) + TestModel.objects.create(test_id=2, attempt_id=2, description='try11', expected_result=70, test_result=45) + TestModel.objects.create(test_id=2, attempt_id=3, description='try12', expected_result=75, test_result=45) + + IndexedTestModel.objects.create(test_id=0, attempt_id=0, description='try1', expected_result=5, test_result=30) + IndexedTestModel.objects.create(test_id=1, attempt_id=1, description='try2', expected_result=10, test_result=30) + IndexedTestModel.objects.create(test_id=2, attempt_id=2, description='try3', expected_result=15, test_result=30) + IndexedTestModel.objects.create(test_id=3, attempt_id=3, description='try4', expected_result=20, test_result=25) + + IndexedTestModel.objects.create(test_id=4, attempt_id=0, description='try5', expected_result=5, test_result=25) + IndexedTestModel.objects.create(test_id=5, attempt_id=1, description='try6', expected_result=10, test_result=25) + IndexedTestModel.objects.create(test_id=6, attempt_id=2, description='try7', expected_result=15, test_result=25) + IndexedTestModel.objects.create(test_id=7, attempt_id=3, description='try8', expected_result=20, test_result=20) + + IndexedTestModel.objects.create(test_id=8, attempt_id=0, description='try9', expected_result=50, test_result=40) + IndexedTestModel.objects.create(test_id=9, attempt_id=1, description='try10', expected_result=60, + test_result=40) + IndexedTestModel.objects.create(test_id=10, attempt_id=2, description='try11', expected_result=70, + test_result=45) + IndexedTestModel.objects.create(test_id=11, attempt_id=3, description='try12', expected_result=75, + test_result=45) + + if CASSANDRA_VERSION >= Version('2.1'): + drop_table(IndexedCollectionsTestModel) + sync_table(IndexedCollectionsTestModel) + IndexedCollectionsTestModel.objects.create(test_id=12, attempt_id=3, description='list12', expected_result=75, + test_result=45, test_list=[1, 2, 42], test_set=set([1, 2, 3]), + test_map={'1': 1, '2': 2, '3': 3}) + IndexedCollectionsTestModel.objects.create(test_id=13, attempt_id=3, description='list13', expected_result=75, + test_result=45, test_list=[3, 4, 5], test_set=set([4, 5, 42]), + test_map={'1': 5, '2': 6, '3': 7}) + IndexedCollectionsTestModel.objects.create(test_id=14, attempt_id=3, description='list14', expected_result=75, + test_result=45, test_list=[1, 2, 3], test_set=set([1, 2, 3]), + test_map={'1': 
1, '2': 2, '3': 42}) + + IndexedCollectionsTestModel.objects.create(test_id=15, attempt_id=4, description='list14', expected_result=75, + test_result=45, test_list_no_index=[1, 2, 3], test_set_no_index=set([1, 2, 3]), + test_map_no_index={'1': 1, '2': 2, '3': 42}) + + @classmethod + def tearDownClass(cls): + super(BaseQuerySetUsage, cls).tearDownClass() + drop_table(TestModel) + drop_table(IndexedTestModel) + drop_table(CustomIndexedTestModel) + drop_table(TestMultiClusteringModel) + + +class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage): + + @execute_count(2) + def test_count(self): + """ Tests that adding filtering statements affects the count query as expected """ + assert TestModel.objects.count() == 12 + + q = TestModel.objects(test_id=0) + assert q.count() == 4 + + @execute_count(2) + def test_query_expression_count(self): + """ Tests that adding query statements affects the count query as expected """ + assert TestModel.objects.count() == 12 + + q = TestModel.objects(TestModel.test_id == 0) + assert q.count() == 4 + + @execute_count(3) + def test_iteration(self): + """ Tests that iterating over a query set pulls back all of the expected results """ + q = TestModel.objects(test_id=0) + # tuple of expected attempt_id, expected_result values + compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)]) + for t in q: + val = t.attempt_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + # test with regular filtering + q = TestModel.objects(attempt_id=3).allow_filtering() + assert len(q) == 3 + # tuple of expected test_id, expected_result values + compare_set = set([(0, 20), (1, 20), (2, 75)]) + for t in q: + val = t.test_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + # test with query method + q = TestModel.objects(TestModel.attempt_id == 3).allow_filtering() + assert len(q) == 3 + # tuple of expected test_id, expected_result values + compare_set = set([(0, 20), (1, 20), (2, 75)]) + for t in q: + val = t.test_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + @execute_count(2) + def test_multiple_iterations_work_properly(self): + """ Tests that iterating over a query set more than once works """ + # test with both the filtering method and the query method + for q in (TestModel.objects(test_id=0), TestModel.objects(TestModel.test_id == 0)): + # tuple of expected attempt_id, expected_result values + compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)]) + for t in q: + val = t.attempt_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + # try it again + compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)]) + for t in q: + val = t.attempt_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + @execute_count(2) + def test_multiple_iterators_are_isolated(self): + """ + tests that the use of one iterator does not affect the behavior of another + """ + for q in (TestModel.objects(test_id=0), TestModel.objects(TestModel.test_id == 0)): + q = q.order_by('attempt_id') + expected_order = [0, 1, 2, 3] + iter1 = iter(q) + iter2 = iter(q) + for attempt_id in expected_order: + assert next(iter1).attempt_id == attempt_id + assert next(iter2).attempt_id == attempt_id + + @execute_count(3) + def test_get_success_case(self): + """ + Tests that the .get() method works on new and existing querysets + """ + m 
= TestModel.objects.get(test_id=0, attempt_id=0) + assert isinstance(m, TestModel) + assert m.test_id == 0 + assert m.attempt_id == 0 + + q = TestModel.objects(test_id=0, attempt_id=0) + m = q.get() + assert isinstance(m, TestModel) + assert m.test_id == 0 + assert m.attempt_id == 0 + + q = TestModel.objects(test_id=0) + m = q.get(attempt_id=0) + assert isinstance(m, TestModel) + assert m.test_id == 0 + assert m.attempt_id == 0 + + @execute_count(3) + def test_query_expression_get_success_case(self): + """ + Tests that the .get() method works on new and existing querysets + """ + m = TestModel.get(TestModel.test_id == 0, TestModel.attempt_id == 0) + assert isinstance(m, TestModel) + assert m.test_id == 0 + assert m.attempt_id == 0 + + q = TestModel.objects(TestModel.test_id == 0, TestModel.attempt_id == 0) + m = q.get() + assert isinstance(m, TestModel) + assert m.test_id == 0 + assert m.attempt_id == 0 + + q = TestModel.objects(TestModel.test_id == 0) + m = q.get(TestModel.attempt_id == 0) + assert isinstance(m, TestModel) + assert m.test_id == 0 + assert m.attempt_id == 0 + + @execute_count(1) + def test_get_doesnotexist_exception(self): + """ + Tests that get calls that don't return a result raises a DoesNotExist error + """ + with self.assertRaises(TestModel.DoesNotExist): + TestModel.objects.get(test_id=100) + + @execute_count(1) + def test_get_multipleobjects_exception(self): + """ + Tests that get calls that return multiple results raise a MultipleObjectsReturned error + """ + with self.assertRaises(TestModel.MultipleObjectsReturned): + TestModel.objects.get(test_id=1) + + def test_allow_filtering_flag(self): + """ + """ + +@execute_count(4) +def test_non_quality_filtering(): + class NonEqualityFilteringModel(Model): + + example_id = columns.UUID(primary_key=True, default=uuid.uuid4) + sequence_id = columns.Integer(primary_key=True) # sequence_id is a clustering key + example_type = columns.Integer(index=True) + created_at = columns.DateTime() + + drop_table(NonEqualityFilteringModel) + sync_table(NonEqualityFilteringModel) + + # setup table, etc. 
+ + NonEqualityFilteringModel.create(sequence_id=1, example_type=0, created_at=datetime.now()) + NonEqualityFilteringModel.create(sequence_id=3, example_type=0, created_at=datetime.now()) + NonEqualityFilteringModel.create(sequence_id=5, example_type=1, created_at=datetime.now()) + + qa = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering() + num = qa.count() + assert num == 1, num + + +class TestQuerySetDistinct(BaseQuerySetUsage): + + @execute_count(1) + def test_distinct_without_parameter(self): + q = TestModel.objects.distinct() + self.assertEqual(len(q), 3) + + @execute_count(1) + def test_distinct_with_parameter(self): + q = TestModel.objects.distinct(['test_id']) + self.assertEqual(len(q), 3) + + @execute_count(1) + def test_distinct_with_filter(self): + q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2]) + self.assertEqual(len(q), 2) + + @execute_count(1) + def test_distinct_with_non_partition(self): + with self.assertRaises(InvalidRequest): + q = TestModel.objects.distinct(['description']).filter(test_id__in=[1, 2]) + len(q) + + @execute_count(1) + def test_zero_result(self): + q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[52]) + self.assertEqual(len(q), 0) + + @greaterthancass21 + @execute_count(2) + def test_distinct_with_explicit_count(self): + q = TestModel.objects.distinct(['test_id']) + self.assertEqual(q.count(), 3) + + q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2]) + self.assertEqual(q.count(), 2) + + +class TestQuerySetOrdering(BaseQuerySetUsage): + @execute_count(2) + def test_order_by_success_case(self): + q = TestModel.objects(test_id=0).order_by('attempt_id') + expected_order = [0, 1, 2, 3] + for model, expect in zip(q, expected_order): + assert model.attempt_id == expect + + q = q.order_by('-attempt_id') + expected_order.reverse() + for model, expect in zip(q, expected_order): + assert model.attempt_id == expect + + def test_ordering_by_non_second_primary_keys_fail(self): + # kwarg filtering + with self.assertRaises(query.QueryException): + TestModel.objects(test_id=0).order_by('test_id') + + # kwarg filtering + with self.assertRaises(query.QueryException): + TestModel.objects(TestModel.test_id == 0).order_by('test_id') + + def test_ordering_by_non_primary_keys_fails(self): + with self.assertRaises(query.QueryException): + TestModel.objects(test_id=0).order_by('description') + + def test_ordering_on_indexed_columns_fails(self): + with self.assertRaises(query.QueryException): + IndexedTestModel.objects(test_id=0).order_by('attempt_id') + + @execute_count(8) + def test_ordering_on_multiple_clustering_columns(self): + TestMultiClusteringModel.create(one=1, two=1, three=4) + TestMultiClusteringModel.create(one=1, two=1, three=2) + TestMultiClusteringModel.create(one=1, two=1, three=5) + TestMultiClusteringModel.create(one=1, two=1, three=1) + TestMultiClusteringModel.create(one=1, two=1, three=3) + + results = TestMultiClusteringModel.objects.filter(one=1, two=1).order_by('-two', '-three') + assert [r.three for r in results] == [5, 4, 3, 2, 1] + + results = TestMultiClusteringModel.objects.filter(one=1, two=1).order_by('two', 'three') + assert [r.three for r in results] == [1, 2, 3, 4, 5] + + results = TestMultiClusteringModel.objects.filter(one=1, two=1).order_by('two').order_by('three') + assert [r.three for r in results] == [1, 2, 3, 4, 5] + + +class TestQuerySetSlicing(BaseQuerySetUsage): + + @execute_count(1) + def test_out_of_range_index_raises_error(self): + q = 
TestModel.objects(test_id=0).order_by('attempt_id') + with self.assertRaises(IndexError): + q[10] + + @execute_count(1) + def test_array_indexing_works_properly(self): + q = TestModel.objects(test_id=0).order_by('attempt_id') + expected_order = [0, 1, 2, 3] + for i in range(len(q)): + assert q[i].attempt_id == expected_order[i] + + @execute_count(1) + def test_negative_indexing_works_properly(self): + q = TestModel.objects(test_id=0).order_by('attempt_id') + expected_order = [0, 1, 2, 3] + assert q[-1].attempt_id == expected_order[-1] + assert q[-2].attempt_id == expected_order[-2] + + @execute_count(1) + def test_slicing_works_properly(self): + q = TestModel.objects(test_id=0).order_by('attempt_id') + expected_order = [0, 1, 2, 3] + + for model, expect in zip(q[1:3], expected_order[1:3]): + self.assertEqual(model.attempt_id, expect) + + for model, expect in zip(q[0:3:2], expected_order[0:3:2]): + self.assertEqual(model.attempt_id, expect) + + @execute_count(1) + def test_negative_slicing(self): + q = TestModel.objects(test_id=0).order_by('attempt_id') + expected_order = [0, 1, 2, 3] + + for model, expect in zip(q[-3:], expected_order[-3:]): + self.assertEqual(model.attempt_id, expect) + + for model, expect in zip(q[:-1], expected_order[:-1]): + self.assertEqual(model.attempt_id, expect) + + for model, expect in zip(q[1:-1], expected_order[1:-1]): + self.assertEqual(model.attempt_id, expect) + + for model, expect in zip(q[-3:-1], expected_order[-3:-1]): + self.assertEqual(model.attempt_id, expect) + + for model, expect in zip(q[-3:-1:2], expected_order[-3:-1:2]): + self.assertEqual(model.attempt_id, expect) + + +class TestQuerySetValidation(BaseQuerySetUsage): + + def test_primary_key_or_index_must_be_specified(self): + """ + Tests that queries that don't have an equals relation to a primary key or indexed field fail + """ + with self.assertRaises(query.QueryException): + q = TestModel.objects(test_result=25) + list([i for i in q]) + + def test_primary_key_or_index_must_have_equal_relation_filter(self): + """ + Tests that queries that don't have non equal (>,<, etc) relation to a primary key or indexed field fail + """ + with self.assertRaises(query.QueryException): + q = TestModel.objects(test_id__gt=0) + list([i for i in q]) + + @greaterthancass20 + @execute_count(7) + def test_indexed_field_can_be_queried(self): + """ + Tests that queries on an indexed field will work without any primary key relations specified + """ + q = IndexedTestModel.objects(test_result=25) + self.assertEqual(q.count(), 4) + + q = IndexedCollectionsTestModel.objects.filter(test_list__contains=42) + self.assertEqual(q.count(), 1) + + q = IndexedCollectionsTestModel.objects.filter(test_list__contains=13) + self.assertEqual(q.count(), 0) + + q = IndexedCollectionsTestModel.objects.filter(test_set__contains=42) + self.assertEqual(q.count(), 1) + + q = IndexedCollectionsTestModel.objects.filter(test_set__contains=13) + self.assertEqual(q.count(), 0) + + q = IndexedCollectionsTestModel.objects.filter(test_map__contains=42) + self.assertEqual(q.count(), 1) + + q = IndexedCollectionsTestModel.objects.filter(test_map__contains=13) + self.assertEqual(q.count(), 0) + + def test_custom_indexed_field_can_be_queried(self): + """ + Tests that queries on an custom indexed field will work without any primary key relations specified + """ + + with self.assertRaises(query.QueryException): + list(CustomIndexedTestModel.objects.filter(data='test')) # not custom indexed + + # It should return InvalidRequest if target an indexed columns 
+ with self.assertRaises(InvalidRequest): + list(CustomIndexedTestModel.objects.filter(indexed='test', data='test')) + + # It should return InvalidRequest if target an indexed columns + with self.assertRaises(InvalidRequest): + list(CustomIndexedTestModel.objects.filter(description='test', data='test')) + + # equals operator, server error since there is no real index, but it passes + with self.assertRaises(InvalidRequest): + list(CustomIndexedTestModel.objects.filter(description='test')) + + with self.assertRaises(InvalidRequest): + list(CustomIndexedTestModel.objects.filter(test_id=1, description='test')) + + # gte operator, server error since there is no real index, but it passes + # this can't work with a secondary index + with self.assertRaises(InvalidRequest): + list(CustomIndexedTestModel.objects.filter(description__gte='test')) + + with Cluster().connect() as session: + session.execute("CREATE INDEX custom_index_cqlengine ON {}.{} (description)". + format(DEFAULT_KEYSPACE, CustomIndexedTestModel._table_name)) + + list(CustomIndexedTestModel.objects.filter(description='test')) + list(CustomIndexedTestModel.objects.filter(test_id=1, description='test')) + + +class TestQuerySetDelete(BaseQuerySetUsage): + + @execute_count(9) + def test_delete(self): + TestModel.objects.create(test_id=3, attempt_id=0, description='try9', expected_result=50, test_result=40) + TestModel.objects.create(test_id=3, attempt_id=1, description='try10', expected_result=60, test_result=40) + TestModel.objects.create(test_id=3, attempt_id=2, description='try11', expected_result=70, test_result=45) + TestModel.objects.create(test_id=3, attempt_id=3, description='try12', expected_result=75, test_result=45) + + assert TestModel.objects.count() == 16 + assert TestModel.objects(test_id=3).count() == 4 + + TestModel.objects(test_id=3).delete() + + assert TestModel.objects.count() == 12 + assert TestModel.objects(test_id=3).count() == 0 + + def test_delete_without_partition_key(self): + """ Tests that attempting to delete a model without defining a partition key fails """ + with self.assertRaises(query.QueryException): + TestModel.objects(attempt_id=0).delete() + + def test_delete_without_any_where_args(self): + """ Tests that attempting to delete a whole table without any arguments will fail """ + with self.assertRaises(query.QueryException): + TestModel.objects(attempt_id=0).delete() + + @greaterthanorequalcass30 + @execute_count(18) + def test_range_deletion(self): + """ + Tests that range deletion work as expected + """ + + for i in range(10): + TestMultiClusteringModel.objects().create(one=1, two=i, three=i) + + TestMultiClusteringModel.objects(one=1, two__gte=0, two__lte=3).delete() + self.assertEqual(6, len(TestMultiClusteringModel.objects.all())) + + TestMultiClusteringModel.objects(one=1, two__gt=3, two__lt=5).delete() + self.assertEqual(5, len(TestMultiClusteringModel.objects.all())) + + TestMultiClusteringModel.objects(one=1, two__in=[8, 9]).delete() + self.assertEqual(3, len(TestMultiClusteringModel.objects.all())) + + TestMultiClusteringModel.objects(one__in=[1], two__gte=0).delete() + self.assertEqual(0, len(TestMultiClusteringModel.objects.all())) + + +class TimeUUIDQueryModel(Model): + + partition = columns.UUID(primary_key=True) + time = columns.TimeUUID(primary_key=True) + data = columns.Text(required=False) + + +class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + super(TestMinMaxTimeUUIDFunctions, cls).setUpClass() + sync_table(TimeUUIDQueryModel) + + 
@classmethod + def tearDownClass(cls): + super(TestMinMaxTimeUUIDFunctions, cls).tearDownClass() + drop_table(TimeUUIDQueryModel) + + @execute_count(7) + def test_tzaware_datetime_support(self): + """Test that using timezone aware datetime instances works with the + MinTimeUUID/MaxTimeUUID functions. + """ + pk = uuid4() + midpoint_utc = datetime.utcnow().replace(tzinfo=TzOffset(0)) + midpoint_helsinki = midpoint_utc.astimezone(TzOffset(3)) + + # Assert pre-condition that we have the same logical point in time + assert midpoint_utc.utctimetuple() == midpoint_helsinki.utctimetuple() + assert midpoint_utc.timetuple() != midpoint_helsinki.timetuple() + + TimeUUIDQueryModel.create( + partition=pk, + time=uuid_from_time(midpoint_utc - timedelta(minutes=1)), + data='1') + + TimeUUIDQueryModel.create( + partition=pk, + time=uuid_from_time(midpoint_utc), + data='2') + + TimeUUIDQueryModel.create( + partition=pk, + time=uuid_from_time(midpoint_utc + timedelta(minutes=1)), + data='3') + + assert ['1', '2'] == [o.data for o in TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint_utc))] + + assert ['1', '2'] == [o.data for o in TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint_helsinki))] + + assert ['2', '3'] == [o.data for o in TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint_utc))] + + assert ['2', '3'] == [o.data for o in TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint_helsinki))] + + @execute_count(8) + def test_success_case(self): + """ Test that the min and max time uuid functions work as expected """ + pk = uuid4() + startpoint = datetime.utcnow() + TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=1)), data='1') + TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=2)), data='2') + midpoint = startpoint + timedelta(seconds=3) + TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=4)), data='3') + TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=5)), data='4') + + # test kwarg filtering + q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint)) + q = [d for d in q] + self.assertEqual(len(q), 2, msg="Got: %s" % q) + datas = [d.data for d in q] + assert '1' in datas + assert '2' in datas + + q = TimeUUIDQueryModel.filter(partition=pk, time__gte=functions.MinTimeUUID(midpoint)) + assert len(q) == 2 + datas = [d.data for d in q] + assert '3' in datas + assert '4' in datas + + # test query expression filtering + q = TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint) + ) + q = [d for d in q] + assert len(q) == 2 + datas = [d.data for d in q] + assert '1' in datas + assert '2' in datas + + q = TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint) + ) + assert len(q) == 2 + datas = [d.data for d in q] + assert '3' in datas + assert '4' in datas + + +class TestInOperator(BaseQuerySetUsage): + @execute_count(1) + def test_kwarg_success_case(self): + """ Tests the in operator works with the kwarg query method """ + q = TestModel.filter(test_id__in=[0, 1]) + assert q.count() == 8 
+ + @execute_count(1) + def test_query_expression_success_case(self): + """ Tests the in operator works with the query expression query method """ + q = TestModel.filter(TestModel.test_id.in_([0, 1])) + assert q.count() == 8 + + @execute_count(5) + def test_bool(self): + """ + Adding coverage to cqlengine for bool types. + + @since 3.6 + @jira_ticket PYTHON-596 + @expected_result bool results should be filtered appropriately + + @test_category object_mapper + """ + class bool_model(Model): + k = columns.Integer(primary_key=True) + b = columns.Boolean(primary_key=True) + v = columns.Integer(default=3) + sync_table(bool_model) + + bool_model.create(k=0, b=True) + bool_model.create(k=0, b=False) + self.assertEqual(len(bool_model.objects.all()), 2) + self.assertEqual(len(bool_model.objects.filter(k=0, b=True)), 1) + self.assertEqual(len(bool_model.objects.filter(k=0, b=False)), 1) + + @execute_count(3) + def test_bool_filter(self): + """ + Test to ensure that we don't translate boolean objects to String unnecessarily in filter clauses + + @since 3.6 + @jira_ticket PYTHON-596 + @expected_result We should not receive a server error + + @test_category object_mapper + """ + class bool_model2(Model): + k = columns.Boolean(primary_key=True) + b = columns.Integer(primary_key=True) + v = columns.Text() + drop_table(bool_model2) + sync_table(bool_model2) + + bool_model2.create(k=True, b=1, v='a') + bool_model2.create(k=False, b=1, v='b') + self.assertEqual(len(list(bool_model2.objects(k__in=(True, False)))), 2) + + +@greaterthancass20 +class TestContainsOperator(BaseQuerySetUsage): + + @execute_count(6) + def test_kwarg_success_case(self): + """ Tests the CONTAINS operator works with the kwarg query method """ + q = IndexedCollectionsTestModel.filter(test_list__contains=1) + self.assertEqual(q.count(), 2) + + q = IndexedCollectionsTestModel.filter(test_list__contains=13) + self.assertEqual(q.count(), 0) + + q = IndexedCollectionsTestModel.filter(test_set__contains=3) + self.assertEqual(q.count(), 2) + + q = IndexedCollectionsTestModel.filter(test_set__contains=13) + self.assertEqual(q.count(), 0) + + q = IndexedCollectionsTestModel.filter(test_map__contains=42) + self.assertEqual(q.count(), 1) + + q = IndexedCollectionsTestModel.filter(test_map__contains=13) + self.assertEqual(q.count(), 0) + + with self.assertRaises(QueryException): + q = IndexedCollectionsTestModel.filter(test_list_no_index__contains=1) + self.assertEqual(q.count(), 0) + with self.assertRaises(QueryException): + q = IndexedCollectionsTestModel.filter(test_set_no_index__contains=1) + self.assertEqual(q.count(), 0) + with self.assertRaises(QueryException): + q = IndexedCollectionsTestModel.filter(test_map_no_index__contains=1) + self.assertEqual(q.count(), 0) + + @execute_count(6) + def test_query_expression_success_case(self): + """ Tests the CONTAINS operator works with the query expression query method """ + q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_list.contains_(1)) + self.assertEqual(q.count(), 2) + + q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_list.contains_(13)) + self.assertEqual(q.count(), 0) + + q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_set.contains_(3)) + self.assertEqual(q.count(), 2) + + q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_set.contains_(13)) + self.assertEqual(q.count(), 0) + + q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_map.contains_(42)) + self.assertEqual(q.count(), 1) 
+ + q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_map.contains_(13)) + self.assertEqual(q.count(), 0) + + with self.assertRaises(QueryException): + q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_map_no_index.contains_(1)) + self.assertEqual(q.count(), 0) + with self.assertRaises(QueryException): + q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_map_no_index.contains_(1)) + self.assertEqual(q.count(), 0) + with self.assertRaises(QueryException): + q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_map_no_index.contains_(1)) + self.assertEqual(q.count(), 0) + + +class TestValuesList(BaseQuerySetUsage): + + @execute_count(2) + def test_values_list(self): + q = TestModel.objects.filter(test_id=0, attempt_id=1) + item = q.values_list('test_id', 'attempt_id', 'description', 'expected_result', 'test_result').first() + assert item == [0, 1, 'try2', 10, 30] + + item = q.values_list('expected_result', flat=True).first() + assert item == 10 + + +class TestObjectsProperty(BaseQuerySetUsage): + @execute_count(1) + def test_objects_property_returns_fresh_queryset(self): + assert TestModel.objects._result_cache is None + len(TestModel.objects) # evaluate queryset + assert TestModel.objects._result_cache is None + + +class PageQueryTests(BaseCassEngTestCase): + @execute_count(3) + def test_paged_result_handling(self): + if PROTOCOL_VERSION < 2: + raise unittest.SkipTest("Paging requires native protocol 2+, currently using: {0}".format(PROTOCOL_VERSION)) + + # addresses #225 + class PagingTest(Model): + id = columns.Integer(primary_key=True) + val = columns.Integer() + sync_table(PagingTest) + + PagingTest.create(id=1, val=1) + PagingTest.create(id=2, val=2) + + session = get_session() + with mock.patch.object(session, 'default_fetch_size', 1): + results = PagingTest.objects()[:] + + assert len(results) == 2 + + +class ModelQuerySetTimeoutTestCase(BaseQuerySetUsage): + def test_default_timeout(self): + with mock.patch.object(Session, 'execute') as mock_execute: + list(TestModel.objects()) + self.assertEqual(mock_execute.call_args[-1]['timeout'], NOT_SET) + + def test_float_timeout(self): + with mock.patch.object(Session, 'execute') as mock_execute: + list(TestModel.objects().timeout(0.5)) + self.assertEqual(mock_execute.call_args[-1]['timeout'], 0.5) + + def test_none_timeout(self): + with mock.patch.object(Session, 'execute') as mock_execute: + list(TestModel.objects().timeout(None)) + self.assertEqual(mock_execute.call_args[-1]['timeout'], None) + + +class DMLQueryTimeoutTestCase(BaseQuerySetUsage): + def setUp(self): + self.model = TestModel(test_id=1, attempt_id=1, description='timeout test') + super(DMLQueryTimeoutTestCase, self).setUp() + + def test_default_timeout(self): + with mock.patch.object(Session, 'execute') as mock_execute: + self.model.save() + self.assertEqual(mock_execute.call_args[-1]['timeout'], NOT_SET) + + def test_float_timeout(self): + with mock.patch.object(Session, 'execute') as mock_execute: + self.model.timeout(0.5).save() + self.assertEqual(mock_execute.call_args[-1]['timeout'], 0.5) + + def test_none_timeout(self): + with mock.patch.object(Session, 'execute') as mock_execute: + self.model.timeout(None).save() + self.assertEqual(mock_execute.call_args[-1]['timeout'], None) + + def test_timeout_then_batch(self): + b = query.BatchQuery() + m = self.model.timeout(None) + with self.assertRaises(AssertionError): + m.batch(b) + + def test_batch_then_timeout(self): + b = query.BatchQuery() 
+ m = self.model.batch(b) + with self.assertRaises(AssertionError): + m.timeout(0.5) + + +class DBFieldModel(Model): + k0 = columns.Integer(partition_key=True, db_field='a') + k1 = columns.Integer(partition_key=True, db_field='b') + c0 = columns.Integer(primary_key=True, db_field='c') + v0 = columns.Integer(db_field='d') + v1 = columns.Integer(db_field='e', index=True) + + +class DBFieldModelMixed1(Model): + k0 = columns.Integer(partition_key=True, db_field='a') + k1 = columns.Integer(partition_key=True) + c0 = columns.Integer(primary_key=True, db_field='c') + v0 = columns.Integer(db_field='d') + v1 = columns.Integer(index=True) + + +class DBFieldModelMixed2(Model): + k0 = columns.Integer(partition_key=True) + k1 = columns.Integer(partition_key=True, db_field='b') + c0 = columns.Integer(primary_key=True) + v0 = columns.Integer(db_field='d') + v1 = columns.Integer(index=True, db_field='e') + + +class TestModelQueryWithDBField(BaseCassEngTestCase): + + def setUp(cls): + super(TestModelQueryWithDBField, cls).setUpClass() + cls.model_list = [DBFieldModel, DBFieldModelMixed1, DBFieldModelMixed2] + for model in cls.model_list: + sync_table(model) + + def tearDown(cls): + super(TestModelQueryWithDBField, cls).tearDownClass() + for model in cls.model_list: + drop_table(model) + + @execute_count(33) + def test_basic_crud(self): + """ + Tests creation update and delete of object model queries that are using db_field mappings. + + @since 3.1 + @jira_ticket PYTHON-351 + @expected_result results are properly retrieved without errors + + @test_category object_mapper + """ + for model in self.model_list: + values = {'k0': 1, 'k1': 2, 'c0': 3, 'v0': 4, 'v1': 5} + # create + i = model.create(**values) + i = model.objects(k0=i.k0, k1=i.k1).first() + self.assertEqual(i, model(**values)) + + # create + values['v0'] = 101 + i.update(v0=values['v0']) + i = model.objects(k0=i.k0, k1=i.k1).first() + self.assertEqual(i, model(**values)) + + # delete + model.objects(k0=i.k0, k1=i.k1).delete() + i = model.objects(k0=i.k0, k1=i.k1).first() + self.assertIsNone(i) + + i = model.create(**values) + i = model.objects(k0=i.k0, k1=i.k1).first() + self.assertEqual(i, model(**values)) + i.delete() + model.objects(k0=i.k0, k1=i.k1).delete() + i = model.objects(k0=i.k0, k1=i.k1).first() + self.assertIsNone(i) + + @execute_count(21) + def test_slice(self): + """ + Tests slice queries for object models that are using db_field mapping + + @since 3.1 + @jira_ticket PYTHON-351 + @expected_result results are properly retrieved without errors + + @test_category object_mapper + """ + for model in self.model_list: + values = {'k0': 1, 'k1': 3, 'c0': 3, 'v0': 4, 'v1': 5} + clustering_values = range(3) + for c in clustering_values: + values['c0'] = c + i = model.create(**values) + + self.assertEqual(model.objects(k0=i.k0, k1=i.k1).count(), len(clustering_values)) + self.assertEqual(model.objects(k0=i.k0, k1=i.k1, c0=i.c0).count(), 1) + self.assertEqual(model.objects(k0=i.k0, k1=i.k1, c0__lt=i.c0).count(), len(clustering_values[:-1])) + self.assertEqual(model.objects(k0=i.k0, k1=i.k1, c0__gt=0).count(), len(clustering_values[1:])) + + @execute_count(15) + def test_order(self): + """ + Tests order by queries for object models that are using db_field mapping + + @since 3.1 + @jira_ticket PYTHON-351 + @expected_result results are properly retrieved without errors + + @test_category object_mapper + """ + for model in self.model_list: + values = {'k0': 1, 'k1': 4, 'c0': 3, 'v0': 4, 'v1': 5} + clustering_values = range(3) + for c in 
clustering_values: + values['c0'] = c + i = model.create(**values) + self.assertEqual(model.objects(k0=i.k0, k1=i.k1).order_by('c0').first().c0, clustering_values[0]) + self.assertEqual(model.objects(k0=i.k0, k1=i.k1).order_by('-c0').first().c0, clustering_values[-1]) + + @execute_count(15) + def test_index(self): + """ + Tests queries using index fields for object models using db_field mapping + + @since 3.1 + @jira_ticket PYTHON-351 + @expected_result results are properly retrieved without errors + + @test_category object_mapper + """ + for model in self.model_list: + values = {'k0': 1, 'k1': 5, 'c0': 3, 'v0': 4, 'v1': 5} + clustering_values = range(3) + for c in clustering_values: + values['c0'] = c + values['v1'] = c + i = model.create(**values) + self.assertEqual(model.objects(k0=i.k0, k1=i.k1).count(), len(clustering_values)) + self.assertEqual(model.objects(k0=i.k0, k1=i.k1, v1=0).count(), 1) + + @execute_count(1) + def test_db_field_names_used(self): + """ + Tests to ensure that with generated cql update statements correctly utilize the db_field values. + + @since 3.2 + @jira_ticket PYTHON-530 + @expected_result resulting cql_statements will use the db_field values + + @test_category object_mapper + """ + + values = ('k0', 'k1', 'c0', 'v0', 'v1') + # Test QuerySet Path + b = BatchQuery() + DBFieldModel.objects(k0=1).batch(b).update( + v0=0, + v1=9, + ) + for value in values: + self.assertTrue(value not in str(b.queries[0])) + + # Test DML path + b2 = BatchQuery() + dml_field_model = DBFieldModel.create(k0=1, k1=5, c0=3, v0=4, v1=5) + dml_field_model.batch(b2).update( + v0=0, + v1=9, + ) + for value in values: + self.assertTrue(value not in str(b2.queries[0])) + + def test_db_field_value_list(self): + DBFieldModel.create(k0=0, k1=0, c0=0, v0=4, v1=5) + + self.assertEqual(DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list('c0', 'v0')._defer_fields, + {'a', 'c', 'b'}) + self.assertEqual(DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list('c0', 'v0')._only_fields, + ['c', 'd']) + + list(DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list('c0', 'v0')) + +class TestModelSmall(Model): + + test_id = columns.Integer(primary_key=True) + + +class TestModelQueryWithFetchSize(BaseCassEngTestCase): + """ + Test FetchSize, and ensure that results are returned correctly + regardless of the paging size + + @since 3.1 + @jira_ticket PYTHON-324 + @expected_result results are properly retrieved and the correct size + + @test_category object_mapper + """ + + @classmethod + def setUpClass(cls): + super(TestModelQueryWithFetchSize, cls).setUpClass() + sync_table(TestModelSmall) + + @classmethod + def tearDownClass(cls): + super(TestModelQueryWithFetchSize, cls).tearDownClass() + drop_table(TestModelSmall) + + @execute_count(9) + def test_defaultFetchSize(self): + with BatchQuery() as b: + for i in range(5100): + TestModelSmall.batch(b).create(test_id=i) + self.assertEqual(len(TestModelSmall.objects.fetch_size(1)), 5100) + self.assertEqual(len(TestModelSmall.objects.fetch_size(500)), 5100) + self.assertEqual(len(TestModelSmall.objects.fetch_size(4999)), 5100) + self.assertEqual(len(TestModelSmall.objects.fetch_size(5000)), 5100) + self.assertEqual(len(TestModelSmall.objects.fetch_size(5001)), 5100) + self.assertEqual(len(TestModelSmall.objects.fetch_size(5100)), 5100) + self.assertEqual(len(TestModelSmall.objects.fetch_size(5101)), 5100) + self.assertEqual(len(TestModelSmall.objects.fetch_size(1)), 5100) + + with self.assertRaises(QueryException): + TestModelSmall.objects.fetch_size(0) 
+ with self.assertRaises(QueryException): + TestModelSmall.objects.fetch_size(-1) + + +class People(Model): + __table_name__ = "people" + last_name = columns.Text(primary_key=True, partition_key=True) + first_name = columns.Text(primary_key=True) + birthday = columns.DateTime() + + +class People2(Model): + __table_name__ = "people" + last_name = columns.Text(primary_key=True, partition_key=True) + first_name = columns.Text(primary_key=True) + middle_name = columns.Text() + birthday = columns.DateTime() + + +class TestModelQueryWithDifferedFeld(BaseCassEngTestCase): + """ + Tests that selects with filter will defer population of known values until after the results are returned. + I.e. instead of generating SELECT * FROM People WHERE last_name="Smith", it will generate + SELECT first_name, birthday FROM People WHERE last_name="Smith", + where last_name 'Smith' will be populated post-query. + + @since 3.2 + @jira_ticket PYTHON-520 + @expected_result only needed fields are included in the query + + @test_category object_mapper + """ + @classmethod + def setUpClass(cls): + super(TestModelQueryWithDifferedFeld, cls).setUpClass() + sync_table(People) + + @classmethod + def tearDownClass(cls): + super(TestModelQueryWithDifferedFeld, cls).tearDownClass() + drop_table(People) + + @execute_count(8) + def test_defaultFetchSize(self): + # Populate Table + People.objects.create(last_name="Smith", first_name="John", birthday=datetime.now()) + People.objects.create(last_name="Bestwater", first_name="Alan", birthday=datetime.now()) + People.objects.create(last_name="Smith", first_name="Greg", birthday=datetime.now()) + People.objects.create(last_name="Smith", first_name="Adam", birthday=datetime.now()) + + # Check query construction + expected_fields = ['first_name', 'birthday'] + self.assertEqual(People.filter(last_name="Smith")._select_fields(), expected_fields) + # Validate correct fields are fetched + smiths = list(People.filter(last_name="Smith")) + self.assertEqual(len(smiths), 3) + self.assertTrue(smiths[0].last_name is not None) + + # Modify table with new column + sync_table(People2) + + # populate new format + People2.objects.create(last_name="Smith", first_name="Chris", middle_name="Raymond", birthday=datetime.now()) + People2.objects.create(last_name="Smith", first_name="Andrew", middle_name="Micheal", birthday=datetime.now()) + + # validate query construction + expected_fields = ['first_name', 'middle_name', 'birthday'] + self.assertEqual(People2.filter(last_name="Smith")._select_fields(), expected_fields) + + # validate correct items are returned + smiths = list(People2.filter(last_name="Smith")) + self.assertEqual(len(smiths), 5) + self.assertTrue(smiths[0].last_name is not None) diff --git a/tests/integration/cqlengine/query/test_updates.py b/tests/integration/cqlengine/query/test_updates.py new file mode 100644 index 0000000..fb6082b --- /dev/null +++ b/tests/integration/cqlengine/query/test_updates.py @@ -0,0 +1,347 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
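# Editor's note (illustrative sketch, not part of the patch): the tests in this
# file exercise queryset-level updates. Assuming a configured cqlengine
# connection and a synced model, the pattern they cover looks like:
#
#     from uuid import uuid4
#     pk = uuid4()
#     TestQueryUpdateModel.create(partition=pk, cluster=3, count=1, text='a')
#     TestQueryUpdateModel.objects(partition=pk, cluster=3).update(count=6)   # UPDATE ... SET count = 6
#     TestQueryUpdateModel.objects(partition=pk, cluster=3).update(text=None) # a null value issues a DELETE of that column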
+ +from uuid import uuid4 +from cassandra.cqlengine import ValidationError + +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.management import sync_table, drop_table +from cassandra.cqlengine import columns + +from tests.integration.cqlengine import is_prepend_reversed +from tests.integration.cqlengine.base import BaseCassEngTestCase, TestQueryUpdateModel +from tests.integration.cqlengine import execute_count +from tests.integration import greaterthancass20 + + +class QueryUpdateTests(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(QueryUpdateTests, cls).setUpClass() + sync_table(TestQueryUpdateModel) + + @classmethod + def tearDownClass(cls): + super(QueryUpdateTests, cls).tearDownClass() + drop_table(TestQueryUpdateModel) + + @execute_count(8) + def test_update_values(self): + """ tests calling update on a queryset """ + partition = uuid4() + for i in range(5): + TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i)) + + # sanity check + for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): + self.assertEqual(row.cluster, i) + self.assertEqual(row.count, i) + self.assertEqual(row.text, str(i)) + + # perform update + TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6) + + for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): + self.assertEqual(row.cluster, i) + self.assertEqual(row.count, 6 if i == 3 else i) + self.assertEqual(row.text, str(i)) + + @execute_count(6) + def test_update_values_validation(self): + """ tests calling update on models with values passed in """ + partition = uuid4() + for i in range(5): + TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i)) + + # sanity check + for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): + self.assertEqual(row.cluster, i) + self.assertEqual(row.count, i) + self.assertEqual(row.text, str(i)) + + # perform update + with self.assertRaises(ValidationError): + TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count='asdf') + + def test_invalid_update_kwarg(self): + """ tests that passing in a kwarg to the update method that isn't a column will fail """ + with self.assertRaises(ValidationError): + TestQueryUpdateModel.objects(partition=uuid4(), cluster=3).update(bacon=5000) + + def test_primary_key_update_failure(self): + """ tests that attempting to update the value of a primary key will fail """ + with self.assertRaises(ValidationError): + TestQueryUpdateModel.objects(partition=uuid4(), cluster=3).update(cluster=5000) + + @execute_count(8) + def test_null_update_deletes_column(self): + """ setting a field to null in the update should issue a delete statement """ + partition = uuid4() + for i in range(5): + TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i)) + + # sanity check + for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): + self.assertEqual(row.cluster, i) + self.assertEqual(row.count, i) + self.assertEqual(row.text, str(i)) + + # perform update + TestQueryUpdateModel.objects(partition=partition, cluster=3).update(text=None) + + for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): + self.assertEqual(row.cluster, i) + self.assertEqual(row.count, i) + self.assertEqual(row.text, None if i == 3 else str(i)) + + @execute_count(9) + def test_mixed_value_and_null_update(self): + """ tests that updating a column's value, and removing another works properly """
+ partition = uuid4() + for i in range(5): + TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i)) + + # sanity check + for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): + self.assertEqual(row.cluster, i) + self.assertEqual(row.count, i) + self.assertEqual(row.text, str(i)) + + # perform update + TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6, text=None) + + for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): + self.assertEqual(row.cluster, i) + self.assertEqual(row.count, 6 if i == 3 else i) + self.assertEqual(row.text, None if i == 3 else str(i)) + + @execute_count(3) + def test_set_add_updates(self): + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects.create( + partition=partition, cluster=cluster, text_set=set(("foo",))) + TestQueryUpdateModel.objects( + partition=partition, cluster=cluster).update(text_set__add=set(('bar',))) + obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + self.assertEqual(obj.text_set, set(("foo", "bar"))) + + @execute_count(2) + def test_set_add_updates_new_record(self): + """ If the key doesn't exist yet, an update creates the record + """ + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects( + partition=partition, cluster=cluster).update(text_set__add=set(('bar',))) + obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + self.assertEqual(obj.text_set, set(("bar",))) + + @execute_count(3) + def test_set_remove_updates(self): + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects.create( + partition=partition, cluster=cluster, text_set=set(("foo", "baz"))) + TestQueryUpdateModel.objects( + partition=partition, cluster=cluster).update( + text_set__remove=set(('foo',))) + obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + self.assertEqual(obj.text_set, set(("baz",))) + + @execute_count(3) + def test_set_remove_new_record(self): + """ Removing something not in the set should silently do nothing + """ + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects.create( + partition=partition, cluster=cluster, text_set=set(("foo",))) + TestQueryUpdateModel.objects( + partition=partition, cluster=cluster).update( + text_set__remove=set(('afsd',))) + obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + self.assertEqual(obj.text_set, set(("foo",))) + + @execute_count(3) + def test_list_append_updates(self): + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects.create( + partition=partition, cluster=cluster, text_list=["foo"]) + TestQueryUpdateModel.objects( + partition=partition, cluster=cluster).update( + text_list__append=['bar']) + obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + self.assertEqual(obj.text_list, ["foo", "bar"]) + + @execute_count(3) + def test_list_prepend_updates(self): + """ Prepend two things since order is reversed by default by CQL """ + partition = uuid4() + cluster = 1 + original = ["foo"] + TestQueryUpdateModel.objects.create( + partition=partition, cluster=cluster, text_list=original) + prepended = ['bar', 'baz'] + TestQueryUpdateModel.objects( + partition=partition, cluster=cluster).update( + text_list__prepend=prepended) + obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + expected = (prepended[::-1] if is_prepend_reversed() else prepended) + original + self.assertEqual(obj.text_list, expected) + + 
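    # Editor's note (not part of the patch): older Cassandra releases applied a
    # multi-element prepend in reverse order, which is what is_prepend_reversed()
    # detects. For example, prepending ['bar', 'baz'] to ['foo'] yields
    # ['baz', 'bar', 'foo'] on those versions and ['bar', 'baz', 'foo'] on newer
    # ones; the expected value computed above covers both behaviours.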
@execute_count(3) + def test_map_update_updates(self): + """ Merge a dictionary into existing value """ + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects.create( + partition=partition, cluster=cluster, + text_map={"foo": '1', "bar": '2'}) + TestQueryUpdateModel.objects( + partition=partition, cluster=cluster).update( + text_map__update={"bar": '3', "baz": '4'}) + obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + self.assertEqual(obj.text_map, {"foo": '1', "bar": '3', "baz": '4'}) + + @execute_count(3) + def test_map_update_none_deletes_key(self): + """ The CQL behavior is if you set a key in a map to null it deletes + that key from the map. Test that this works with __update. + """ + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects.create( + partition=partition, cluster=cluster, + text_map={"foo": '1', "bar": '2'}) + TestQueryUpdateModel.objects( + partition=partition, cluster=cluster).update( + text_map__update={"bar": None}) + obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + self.assertEqual(obj.text_map, {"foo": '1'}) + + @greaterthancass20 + @execute_count(5) + def test_map_update_remove(self): + """ + Test that map item removal with update(__remove=...) works + + @jira_ticket PYTHON-688 + """ + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects.create( + partition=partition, + cluster=cluster, + text_map={"foo": '1', "bar": '2'} + ) + TestQueryUpdateModel.objects(partition=partition, cluster=cluster).update( + text_map__remove={"bar"}, + text_map__update={"foz": '4', "foo": '2'} + ) + obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + self.assertEqual(obj.text_map, {"foo": '2', "foz": '4'}) + + TestQueryUpdateModel.objects(partition=partition, cluster=cluster).update( + text_map__remove={"foo", "foz"} + ) + self.assertEqual( + TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster).text_map, + {} + ) + + def test_map_remove_rejects_non_sets(self): + """ + Map item removal requires a set to match the CQL API + + @jira_ticket PYTHON-688 + """ + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects.create( + partition=partition, + cluster=cluster, + text_map={"foo": '1', "bar": '2'} + ) + with self.assertRaises(ValidationError): + TestQueryUpdateModel.objects(partition=partition, cluster=cluster).update( + text_map__remove=["bar"] + ) + + @execute_count(3) + def test_an_extra_delete_is_not_sent(self): + """ + Test to ensure that an extra DELETE is not sent if an object is read + from the DB with a None value + + @since 3.9 + @jira_ticket PYTHON-719 + @expected_result only three queries are executed, the first one for + inserting the object, the second one for reading it, and the third + one for updating it + + @test_category object_mapper + """ + partition = uuid4() + cluster = 1 + + TestQueryUpdateModel.objects.create( + partition=partition, cluster=cluster) + + obj = TestQueryUpdateModel.objects( + partition=partition, cluster=cluster).first() + + self.assertFalse({k: v for (k, v) in obj._values.items() if v.deleted}) + + obj.text = 'foo' + obj.save() + #execute_count will check the execution count and + #assert no more calls than necessary where made + +class StaticDeleteModel(Model): + example_id = columns.Integer(partition_key=True, primary_key=True, default=uuid4) + example_static1 = columns.Integer(static=True) + example_static2 = columns.Integer(static=True) + example_clust = columns.Integer(primary_key=True) + + +class 
StaticDeleteTests(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(StaticDeleteTests, cls).setUpClass() + sync_table(StaticDeleteModel) + + @classmethod + def tearDownClass(cls): + super(StaticDeleteTests, cls).tearDownClass() + drop_table(StaticDeleteModel) + + def test_static_deletion(self): + """ + Test to ensure that cluster keys are not included when removing only static columns + + @since 3.6 + @jira_ticket PYTHON-608 + @expected_result Server should not throw an exception, and the static column should be deleted + + @test_category object_mapper + """ + StaticDeleteModel.create(example_id=5, example_clust=5, example_static2=1) + sdm = StaticDeleteModel.filter(example_id=5).first() + self.assertEqual(1, sdm.example_static2) + sdm.update(example_static2=None) + self.assertIsNone(sdm.example_static2) diff --git a/tests/integration/cqlengine/statements/__init__.py b/tests/integration/cqlengine/statements/__init__.py new file mode 100644 index 0000000..2c9ca17 --- /dev/null +++ b/tests/integration/cqlengine/statements/__init__.py @@ -0,0 +1,13 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/integration/cqlengine/statements/test_assignment_clauses.py b/tests/integration/cqlengine/statements/test_assignment_clauses.py new file mode 100644 index 0000000..594224d --- /dev/null +++ b/tests/integration/cqlengine/statements/test_assignment_clauses.py @@ -0,0 +1,360 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
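# Editor's note (illustrative sketch, not part of the patch): the clauses tested
# below render CQL fragments with numbered placeholders and bind their values
# through a context dict, e.g.:
#
#     from cassandra.cqlengine.statements import SetUpdateClause
#     c = SetUpdateClause('s', set((1, 2, 3)), previous=set((1, 2)))
#     c._analyze()          # splits the new value into additions/removals
#     c.set_context_id(0)   # assigns the placeholder numbers used by str(c)
#     str(c)                # '"s" = "s" + %(0)s'
#     ctx = {}
#     c.update_context(ctx) # ctx == {'0': set((3,))}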
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.cqlengine.statements import AssignmentClause, SetUpdateClause, ListUpdateClause, MapUpdateClause, MapDeleteClause, FieldDeleteClause, CounterUpdateClause + + +class AssignmentClauseTests(unittest.TestCase): + + def test_rendering(self): + pass + + def test_insert_tuple(self): + ac = AssignmentClause('a', 'b') + ac.set_context_id(10) + self.assertEqual(ac.insert_tuple(), ('a', 10)) + + +class SetUpdateClauseTests(unittest.TestCase): + + def test_update_from_none(self): + c = SetUpdateClause('s', set((1, 2)), previous=None) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._assignments, set((1, 2))) + self.assertIsNone(c._additions) + self.assertIsNone(c._removals) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': set((1, 2))}) + + def test_null_update(self): + """ tests setting a set to None creates an empty update statement """ + c = SetUpdateClause('s', None, previous=set((1, 2))) + c._analyze() + c.set_context_id(0) + + self.assertIsNone(c._assignments) + self.assertIsNone(c._additions) + self.assertIsNone(c._removals) + + self.assertEqual(c.get_context_size(), 0) + self.assertEqual(str(c), '') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {}) + + def test_no_update(self): + """ tests an unchanged value creates an empty update statement """ + c = SetUpdateClause('s', set((1, 2)), previous=set((1, 2))) + c._analyze() + c.set_context_id(0) + + self.assertIsNone(c._assignments) + self.assertIsNone(c._additions) + self.assertIsNone(c._removals) + + self.assertEqual(c.get_context_size(), 0) + self.assertEqual(str(c), '') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {}) + + def test_update_empty_set(self): + """tests assigning a set to an empty set creates a nonempty + update statement and nonzero context size.""" + c = SetUpdateClause(field='s', value=set()) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._assignments, set()) + self.assertIsNone(c._additions) + self.assertIsNone(c._removals) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': set()}) + + def test_additions(self): + c = SetUpdateClause('s', set((1, 2, 3)), previous=set((1, 2))) + c._analyze() + c.set_context_id(0) + + self.assertIsNone(c._assignments) + self.assertEqual(c._additions, set((3,))) + self.assertIsNone(c._removals) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = "s" + %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': set((3,))}) + + def test_removals(self): + c = SetUpdateClause('s', set((1, 2)), previous=set((1, 2, 3))) + c._analyze() + c.set_context_id(0) + + self.assertIsNone(c._assignments) + self.assertIsNone(c._additions) + self.assertEqual(c._removals, set((3,))) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = "s" - %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': set((3,))}) + + def test_additions_and_removals(self): + c = SetUpdateClause('s', set((2, 3)), previous=set((1, 2))) + c._analyze() + c.set_context_id(0) + + self.assertIsNone(c._assignments) + self.assertEqual(c._additions, set((3,))) + self.assertEqual(c._removals, set((1,))) + + self.assertEqual(c.get_context_size(), 2) + self.assertEqual(str(c), '"s" = "s" + %(0)s, "s" = "s" - 
%(1)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': set((3,)), '1': set((1,))}) + + +class ListUpdateClauseTests(unittest.TestCase): + + def test_update_from_none(self): + c = ListUpdateClause('s', [1, 2, 3]) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._assignments, [1, 2, 3]) + self.assertIsNone(c._append) + self.assertIsNone(c._prepend) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': [1, 2, 3]}) + + def test_update_from_empty(self): + c = ListUpdateClause('s', [1, 2, 3], previous=[]) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._assignments, [1, 2, 3]) + self.assertIsNone(c._append) + self.assertIsNone(c._prepend) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': [1, 2, 3]}) + + def test_update_from_different_list(self): + c = ListUpdateClause('s', [1, 2, 3], previous=[3, 2, 1]) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._assignments, [1, 2, 3]) + self.assertIsNone(c._append) + self.assertIsNone(c._prepend) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': [1, 2, 3]}) + + def test_append(self): + c = ListUpdateClause('s', [1, 2, 3, 4], previous=[1, 2]) + c._analyze() + c.set_context_id(0) + + self.assertIsNone(c._assignments) + self.assertEqual(c._append, [3, 4]) + self.assertIsNone(c._prepend) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = "s" + %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': [3, 4]}) + + def test_prepend(self): + c = ListUpdateClause('s', [1, 2, 3, 4], previous=[3, 4]) + c._analyze() + c.set_context_id(0) + + self.assertIsNone(c._assignments) + self.assertIsNone(c._append) + self.assertEqual(c._prepend, [1, 2]) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = %(0)s + "s"') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': [1, 2]}) + + def test_append_and_prepend(self): + c = ListUpdateClause('s', [1, 2, 3, 4, 5, 6], previous=[3, 4]) + c._analyze() + c.set_context_id(0) + + self.assertIsNone(c._assignments) + self.assertEqual(c._append, [5, 6]) + self.assertEqual(c._prepend, [1, 2]) + + self.assertEqual(c.get_context_size(), 2) + self.assertEqual(str(c), '"s" = %(0)s + "s", "s" = "s" + %(1)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': [1, 2], '1': [5, 6]}) + + def test_shrinking_list_update(self): + """ tests that updating to a smaller list results in an insert statement """ + c = ListUpdateClause('s', [1, 2, 3], previous=[1, 2, 3, 4]) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._assignments, [1, 2, 3]) + self.assertIsNone(c._append) + self.assertIsNone(c._prepend) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': [1, 2, 3]}) + + +class MapUpdateTests(unittest.TestCase): + + def test_update(self): + c = MapUpdateClause('s', {3: 0, 5: 6}, previous={5: 0, 3: 4}) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._updates, [3, 5]) + self.assertEqual(c.get_context_size(), 4) + self.assertEqual(str(c), '"s"[%(0)s] = %(1)s, "s"[%(2)s] = %(3)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': 3, "1": 0, '2': 5, '3': 6}) 
+ + def test_update_from_null(self): + c = MapUpdateClause('s', {3: 0, 5: 6}) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._updates, [3, 5]) + self.assertEqual(c.get_context_size(), 4) + self.assertEqual(str(c), '"s"[%(0)s] = %(1)s, "s"[%(2)s] = %(3)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': 3, "1": 0, '2': 5, '3': 6}) + + def test_nulled_columns_arent_included(self): + c = MapUpdateClause('s', {3: 0}, {1: 2, 3: 4}) + c._analyze() + c.set_context_id(0) + + self.assertNotIn(1, c._updates) + + +class CounterUpdateTests(unittest.TestCase): + + def test_positive_update(self): + c = CounterUpdateClause('a', 5, 3) + c.set_context_id(5) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"a" = "a" + %(5)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'5': 2}) + + def test_negative_update(self): + c = CounterUpdateClause('a', 4, 7) + c.set_context_id(3) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"a" = "a" - %(3)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'3': 3}) + + def noop_update(self): + c = CounterUpdateClause('a', 5, 5) + c.set_context_id(5) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"a" = "a" + %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'5': 0}) + + +class MapDeleteTests(unittest.TestCase): + + def test_update(self): + c = MapDeleteClause('s', {3: 0}, {1: 2, 3: 4, 5: 6}) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._removals, [1, 5]) + self.assertEqual(c.get_context_size(), 2) + self.assertEqual(str(c), '"s"[%(0)s], "s"[%(1)s]') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': 1, '1': 5}) + + +class FieldDeleteTests(unittest.TestCase): + + def test_str(self): + f = FieldDeleteClause("blake") + assert str(f) == '"blake"' diff --git a/tests/integration/cqlengine/statements/test_base_clause.py b/tests/integration/cqlengine/statements/test_base_clause.py new file mode 100644 index 0000000..3519838 --- /dev/null +++ b/tests/integration/cqlengine/statements/test_base_clause.py @@ -0,0 +1,30 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import TestCase +from cassandra.cqlengine.statements import BaseClause + + +class BaseClauseTests(TestCase): + + def test_context_updating(self): + ss = BaseClause('a', 'b') + assert ss.get_context_size() == 1 + + ctx = {} + ss.set_context_id(10) + ss.update_context(ctx) + assert ctx == {'10': 'b'} + + diff --git a/tests/integration/cqlengine/statements/test_base_statement.py b/tests/integration/cqlengine/statements/test_base_statement.py new file mode 100644 index 0000000..db7d1eb --- /dev/null +++ b/tests/integration/cqlengine/statements/test_base_statement.py @@ -0,0 +1,159 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from uuid import uuid4 +import six + +from cassandra.query import FETCH_SIZE_UNSET +from cassandra.cluster import Cluster, ConsistencyLevel +from cassandra.cqlengine.statements import BaseCQLStatement +from cassandra.cqlengine.management import sync_table, drop_table +from cassandra.cqlengine.statements import InsertStatement, UpdateStatement, SelectStatement, DeleteStatement, \ + WhereClause +from cassandra.cqlengine.operators import EqualsOperator, LikeOperator +from cassandra.cqlengine.columns import Column + +from tests.integration.cqlengine.base import BaseCassEngTestCase, TestQueryUpdateModel +from tests.integration.cqlengine import DEFAULT_KEYSPACE +from tests.integration import greaterthanorequalcass3_10 + +from cassandra.cqlengine.connection import execute + + +class BaseStatementTest(unittest.TestCase): + + def test_fetch_size(self): + """ tests that fetch_size is correctly set """ + stmt = BaseCQLStatement('table', None, fetch_size=1000) + self.assertEqual(stmt.fetch_size, 1000) + + stmt = BaseCQLStatement('table', None, fetch_size=None) + self.assertEqual(stmt.fetch_size, FETCH_SIZE_UNSET) + + stmt = BaseCQLStatement('table', None) + self.assertEqual(stmt.fetch_size, FETCH_SIZE_UNSET) + + +class ExecuteStatementTest(BaseCassEngTestCase): + text = "text_for_db" + + @classmethod + def setUpClass(cls): + super(ExecuteStatementTest, cls).setUpClass() + sync_table(TestQueryUpdateModel) + cls.table_name = '{0}.test_query_update_model'.format(DEFAULT_KEYSPACE) + + @classmethod + def tearDownClass(cls): + super(ExecuteStatementTest, cls).tearDownClass() + drop_table(TestQueryUpdateModel) + + def _verify_statement(self, original): + st = SelectStatement(self.table_name) + result = execute(st) + response = result[0] + + for assignment in original.assignments: + self.assertEqual(response[assignment.field], assignment.value) + self.assertEqual(len(response), 7) + + def test_insert_statement_execute(self): + """ + Test to verify the execution of BaseCQLStatements using connection.execute + + @since 3.10 + @jira_ticket PYTHON-505 + @expected_result inserts a row in C*, updates the rows and then deletes + all the rows using BaseCQLStatements + + @test_category data_types:object_mapper + """ + partition = uuid4() + cluster = 1 + self._insert_statement(partition, cluster) + + # Verifying update statement + where = [WhereClause('partition', EqualsOperator(), partition), + WhereClause('cluster', EqualsOperator(), cluster)] + + st = UpdateStatement(self.table_name, where=where) + st.add_assignment(Column(db_field='count'), 2) + st.add_assignment(Column(db_field='text'), "text_for_db_update") + st.add_assignment(Column(db_field='text_set'), set(("foo_update", "bar_update"))) + st.add_assignment(Column(db_field='text_list'), ["foo_update", "bar_update"]) + st.add_assignment(Column(db_field='text_map'), {"foo": '3', "bar": '4'}) + + execute(st) + self._verify_statement(st) + + # Verifying delete statement + execute(DeleteStatement(self.table_name, where=where)) + 
self.assertEqual(TestQueryUpdateModel.objects.count(), 0) + + @greaterthanorequalcass3_10 + def test_like_operator(self): + """ + Test to verify the like operator works appropriately + + @since 3.13 + @jira_ticket PYTHON-512 + @expected_result the expected row is read using LIKE + + @test_category data_types:object_mapper + """ + cluster = Cluster() + session = cluster.connect() + self.addCleanup(cluster.shutdown) + + session.execute("""CREATE CUSTOM INDEX text_index ON {} (text) + USING 'org.apache.cassandra.index.sasi.SASIIndex';""".format(self.table_name)) + self.addCleanup(session.execute, "DROP INDEX {}.text_index".format(DEFAULT_KEYSPACE)) + + partition = uuid4() + cluster = 1 + self._insert_statement(partition, cluster) + + ss = SelectStatement(self.table_name) + like_clause = "text_for_%" + ss.add_where(Column(db_field='text'), LikeOperator(), like_clause) + self.assertEqual(six.text_type(ss), + 'SELECT * FROM {} WHERE "text" LIKE %(0)s'.format(self.table_name)) + + result = execute(ss) + self.assertEqual(result[0]["text"], self.text) + + q = TestQueryUpdateModel.objects.filter(text__like=like_clause).allow_filtering() + self.assertEqual(q[0].text, self.text) + + q = TestQueryUpdateModel.objects.filter(text__like=like_clause) + self.assertEqual(q[0].text, self.text) + + def _insert_statement(self, partition, cluster): + # Verifying insert statement + st = InsertStatement(self.table_name) + st.add_assignment(Column(db_field='partition'), partition) + st.add_assignment(Column(db_field='cluster'), cluster) + + st.add_assignment(Column(db_field='count'), 1) + st.add_assignment(Column(db_field='text'), self.text) + st.add_assignment(Column(db_field='text_set'), set(("foo", "bar"))) + st.add_assignment(Column(db_field='text_list'), ["foo", "bar"]) + st.add_assignment(Column(db_field='text_map'), {"foo": '1', "bar": '2'}) + + execute(st) + self._verify_statement(st) diff --git a/tests/integration/cqlengine/statements/test_delete_statement.py b/tests/integration/cqlengine/statements/test_delete_statement.py new file mode 100644 index 0000000..5e2894a --- /dev/null +++ b/tests/integration/cqlengine/statements/test_delete_statement.py @@ -0,0 +1,91 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
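# Editor's note (illustrative sketch, not part of the patch): statement objects
# keep values out of the generated CQL text and bind them via a numbered
# context, e.g.:
#
#     from cassandra.cqlengine.statements import DeleteStatement
#     from cassandra.cqlengine.operators import EqualsOperator
#     from cassandra.cqlengine.columns import Column
#     ds = DeleteStatement('table', None)
#     ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
#     str(ds)           # 'DELETE FROM table WHERE "a" = %(0)s'
#     ds.get_context()  # {'0': 'b'}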
+ +from unittest import TestCase + +from cassandra.cqlengine.columns import Column +from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause, ConditionalClause +from cassandra.cqlengine.operators import * +import six + + +class DeleteStatementTests(TestCase): + + def test_single_field_is_listified(self): + """ tests that passing a string field into the constructor puts it into a list """ + ds = DeleteStatement('table', 'field') + self.assertEqual(len(ds.fields), 1) + self.assertEqual(ds.fields[0].field, 'field') + + def test_field_rendering(self): + """ tests that fields are properly added to the select statement """ + ds = DeleteStatement('table', ['f1', 'f2']) + self.assertTrue(six.text_type(ds).startswith('DELETE "f1", "f2"'), six.text_type(ds)) + self.assertTrue(str(ds).startswith('DELETE "f1", "f2"'), str(ds)) + + def test_none_fields_rendering(self): + """ tests that a '*' is added if no fields are passed in """ + ds = DeleteStatement('table', None) + self.assertTrue(six.text_type(ds).startswith('DELETE FROM'), six.text_type(ds)) + self.assertTrue(str(ds).startswith('DELETE FROM'), str(ds)) + + def test_table_rendering(self): + ds = DeleteStatement('table', None) + self.assertTrue(six.text_type(ds).startswith('DELETE FROM table'), six.text_type(ds)) + self.assertTrue(str(ds).startswith('DELETE FROM table'), str(ds)) + + def test_where_clause_rendering(self): + ds = DeleteStatement('table', None) + ds.add_where(Column(db_field='a'), EqualsOperator(), 'b') + self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s', six.text_type(ds)) + + def test_context_update(self): + ds = DeleteStatement('table', None) + ds.add_field(MapDeleteClause('d', {1: 2}, {1: 2, 3: 4})) + ds.add_where(Column(db_field='a'), EqualsOperator(), 'b') + + ds.update_context_id(7) + self.assertEqual(six.text_type(ds), 'DELETE "d"[%(8)s] FROM table WHERE "a" = %(7)s') + self.assertEqual(ds.get_context(), {'7': 'b', '8': 3}) + + def test_context(self): + ds = DeleteStatement('table', None) + ds.add_where(Column(db_field='a'), EqualsOperator(), 'b') + self.assertEqual(ds.get_context(), {'0': 'b'}) + + def test_range_deletion_rendering(self): + ds = DeleteStatement('table', None) + ds.add_where(Column(db_field='a'), EqualsOperator(), 'b') + ds.add_where(Column(db_field='created_at'), GreaterThanOrEqualOperator(), '0') + ds.add_where(Column(db_field='created_at'), LessThanOrEqualOperator(), '10') + self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" >= %(1)s AND "created_at" <= %(2)s', six.text_type(ds)) + + ds = DeleteStatement('table', None) + ds.add_where(Column(db_field='a'), EqualsOperator(), 'b') + ds.add_where(Column(db_field='created_at'), InOperator(), ['0', '10', '20']) + self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" IN %(1)s', six.text_type(ds)) + + ds = DeleteStatement('table', None) + ds.add_where(Column(db_field='a'), NotEqualsOperator(), 'b') + self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" != %(0)s', six.text_type(ds)) + + def test_delete_conditional(self): + where = [WhereClause('id', EqualsOperator(), 1)] + conditionals = [ConditionalClause('f0', 'value0'), ConditionalClause('f1', 'value1')] + ds = DeleteStatement('table', where=where, conditionals=conditionals) + self.assertEqual(len(ds.conditionals), len(conditionals)) + self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', six.text_type(ds)) + fields 
= ['one', 'two'] + ds = DeleteStatement('table', fields=fields, where=where, conditionals=conditionals) + self.assertEqual(six.text_type(ds), 'DELETE "one", "two" FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', six.text_type(ds)) diff --git a/tests/integration/cqlengine/statements/test_insert_statement.py b/tests/integration/cqlengine/statements/test_insert_statement.py new file mode 100644 index 0000000..3bf90ec --- /dev/null +++ b/tests/integration/cqlengine/statements/test_insert_statement.py @@ -0,0 +1,54 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import six + +from cassandra.cqlengine.columns import Column +from cassandra.cqlengine.statements import InsertStatement + + +class InsertStatementTests(unittest.TestCase): + + def test_statement(self): + ist = InsertStatement('table', None) + ist.add_assignment(Column(db_field='a'), 'b') + ist.add_assignment(Column(db_field='c'), 'd') + + self.assertEqual( + six.text_type(ist), + 'INSERT INTO table ("a", "c") VALUES (%(0)s, %(1)s)' + ) + + def test_context_update(self): + ist = InsertStatement('table', None) + ist.add_assignment(Column(db_field='a'), 'b') + ist.add_assignment(Column(db_field='c'), 'd') + + ist.update_context_id(4) + self.assertEqual( + six.text_type(ist), + 'INSERT INTO table ("a", "c") VALUES (%(4)s, %(5)s)' + ) + ctx = ist.get_context() + self.assertEqual(ctx, {'4': 'b', '5': 'd'}) + + def test_additional_rendering(self): + ist = InsertStatement('table', ttl=60) + ist.add_assignment(Column(db_field='a'), 'b') + ist.add_assignment(Column(db_field='c'), 'd') + self.assertIn('USING TTL 60', six.text_type(ist)) diff --git a/tests/integration/cqlengine/statements/test_select_statement.py b/tests/integration/cqlengine/statements/test_select_statement.py new file mode 100644 index 0000000..90c14bc --- /dev/null +++ b/tests/integration/cqlengine/statements/test_select_statement.py @@ -0,0 +1,111 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
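# Editor's note (illustrative sketch, not part of the patch): update_context_id()
# renumbers a statement's placeholders, which matters when several statements
# share one parameter context (e.g. when grouped into a batch):
#
#     from cassandra.cqlengine.statements import SelectStatement
#     from cassandra.cqlengine.operators import EqualsOperator
#     from cassandra.cqlengine.columns import Column
#     ss = SelectStatement('table')
#     ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
#     str(ss)                  # 'SELECT * FROM table WHERE "a" = %(0)s'
#     ss.update_context_id(5)
#     str(ss)                  # 'SELECT * FROM table WHERE "a" = %(5)s'
#     ss.get_context()         # {'5': 'b'}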
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.cqlengine.columns import Column +from cassandra.cqlengine.statements import SelectStatement, WhereClause +from cassandra.cqlengine.operators import * +import six + +class SelectStatementTests(unittest.TestCase): + + def test_single_field_is_listified(self): + """ tests that passing a string field into the constructor puts it into a list """ + ss = SelectStatement('table', 'field') + self.assertEqual(ss.fields, ['field']) + + def test_field_rendering(self): + """ tests that fields are properly added to the select statement """ + ss = SelectStatement('table', ['f1', 'f2']) + self.assertTrue(six.text_type(ss).startswith('SELECT "f1", "f2"'), six.text_type(ss)) + self.assertTrue(str(ss).startswith('SELECT "f1", "f2"'), str(ss)) + + def test_none_fields_rendering(self): + """ tests that a '*' is added if no fields are passed in """ + ss = SelectStatement('table') + self.assertTrue(six.text_type(ss).startswith('SELECT *'), six.text_type(ss)) + self.assertTrue(str(ss).startswith('SELECT *'), str(ss)) + + def test_table_rendering(self): + ss = SelectStatement('table') + self.assertTrue(six.text_type(ss).startswith('SELECT * FROM table'), six.text_type(ss)) + self.assertTrue(str(ss).startswith('SELECT * FROM table'), str(ss)) + + def test_where_clause_rendering(self): + ss = SelectStatement('table') + ss.add_where(Column(db_field='a'), EqualsOperator(), 'b') + self.assertEqual(six.text_type(ss), 'SELECT * FROM table WHERE "a" = %(0)s', six.text_type(ss)) + + def test_count(self): + ss = SelectStatement('table', count=True, limit=10, order_by='d') + ss.add_where(Column(db_field='a'), EqualsOperator(), 'b') + self.assertEqual(six.text_type(ss), 'SELECT COUNT(*) FROM table WHERE "a" = %(0)s LIMIT 10', six.text_type(ss)) + self.assertIn('LIMIT', six.text_type(ss)) + self.assertNotIn('ORDER', six.text_type(ss)) + + def test_distinct(self): + ss = SelectStatement('table', distinct_fields=['field2']) + ss.add_where(Column(db_field='field1'), EqualsOperator(), 'b') + self.assertEqual(six.text_type(ss), 'SELECT DISTINCT "field2" FROM table WHERE "field1" = %(0)s', six.text_type(ss)) + + ss = SelectStatement('table', distinct_fields=['field1', 'field2']) + self.assertEqual(six.text_type(ss), 'SELECT DISTINCT "field1", "field2" FROM table') + + ss = SelectStatement('table', distinct_fields=['field1'], count=True) + self.assertEqual(six.text_type(ss), 'SELECT DISTINCT COUNT("field1") FROM table') + + def test_context(self): + ss = SelectStatement('table') + ss.add_where(Column(db_field='a'), EqualsOperator(), 'b') + self.assertEqual(ss.get_context(), {'0': 'b'}) + + def test_context_id_update(self): + """ tests that the right things happen the the context id """ + ss = SelectStatement('table') + ss.add_where(Column(db_field='a'), EqualsOperator(), 'b') + self.assertEqual(ss.get_context(), {'0': 'b'}) + self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(0)s') + + ss.update_context_id(5) + self.assertEqual(ss.get_context(), {'5': 'b'}) + self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(5)s') + + def test_additional_rendering(self): + ss = SelectStatement( + 'table', + None, + order_by=['x', 'y'], + limit=15, + allow_filtering=True + ) + qstr = six.text_type(ss) + self.assertIn('LIMIT 15', qstr) + self.assertIn('ORDER BY x, y', qstr) + self.assertIn('ALLOW FILTERING', qstr) + + def test_limit_rendering(self): + ss = SelectStatement('table', None, limit=10) + qstr = six.text_type(ss) + 
self.assertIn('LIMIT 10', qstr) + + ss = SelectStatement('table', None, limit=0) + qstr = six.text_type(ss) + self.assertNotIn('LIMIT', qstr) + + ss = SelectStatement('table', None, limit=None) + qstr = six.text_type(ss) + self.assertNotIn('LIMIT', qstr) diff --git a/tests/integration/cqlengine/statements/test_update_statement.py b/tests/integration/cqlengine/statements/test_update_statement.py new file mode 100644 index 0000000..c6ed228 --- /dev/null +++ b/tests/integration/cqlengine/statements/test_update_statement.py @@ -0,0 +1,90 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.cqlengine.columns import Column, Set, List, Text +from cassandra.cqlengine.operators import * +from cassandra.cqlengine.statements import (UpdateStatement, WhereClause, + AssignmentClause, SetUpdateClause, + ListUpdateClause) +import six + + +class UpdateStatementTests(unittest.TestCase): + + def test_table_rendering(self): + """ tests that fields are properly added to the select statement """ + us = UpdateStatement('table') + self.assertTrue(six.text_type(us).startswith('UPDATE table SET'), six.text_type(us)) + self.assertTrue(str(us).startswith('UPDATE table SET'), str(us)) + + def test_rendering(self): + us = UpdateStatement('table') + us.add_assignment(Column(db_field='a'), 'b') + us.add_assignment(Column(db_field='c'), 'd') + us.add_where(Column(db_field='a'), EqualsOperator(), 'x') + self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s', six.text_type(us)) + + us.add_where(Column(db_field='a'), NotEqualsOperator(), 'y') + self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s AND "a" != %(3)s', six.text_type(us)) + + def test_context(self): + us = UpdateStatement('table') + us.add_assignment(Column(db_field='a'), 'b') + us.add_assignment(Column(db_field='c'), 'd') + us.add_where(Column(db_field='a'), EqualsOperator(), 'x') + self.assertEqual(us.get_context(), {'0': 'b', '1': 'd', '2': 'x'}) + + def test_context_update(self): + us = UpdateStatement('table') + us.add_assignment(Column(db_field='a'), 'b') + us.add_assignment(Column(db_field='c'), 'd') + us.add_where(Column(db_field='a'), EqualsOperator(), 'x') + us.update_context_id(3) + self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(4)s, "c" = %(5)s WHERE "a" = %(3)s') + self.assertEqual(us.get_context(), {'4': 'b', '5': 'd', '3': 'x'}) + + def test_additional_rendering(self): + us = UpdateStatement('table', ttl=60) + us.add_assignment(Column(db_field='a'), 'b') + us.add_where(Column(db_field='a'), EqualsOperator(), 'x') + self.assertIn('USING TTL 60', six.text_type(us)) + + def test_update_set_add(self): + us = UpdateStatement('table') + us.add_update(Set(Text, db_field='a'), set((1,)), 'add') + self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s') + + def test_update_empty_set_add_does_not_assign(self): + us = 
UpdateStatement('table') + us.add_update(Set(Text, db_field='a'), set(), 'add') + self.assertFalse(us.assignments) + + def test_update_empty_set_removal_does_not_assign(self): + us = UpdateStatement('table') + us.add_update(Set(Text, db_field='a'), set(), 'remove') + self.assertFalse(us.assignments) + + def test_update_list_prepend_with_empty_list(self): + us = UpdateStatement('table') + us.add_update(List(Text, db_field='a'), [], 'prepend') + self.assertFalse(us.assignments) + + def test_update_list_append_with_empty_list(self): + us = UpdateStatement('table') + us.add_update(List(Text, db_field='a'), [], 'append') + self.assertFalse(us.assignments) diff --git a/tests/integration/cqlengine/statements/test_where_clause.py b/tests/integration/cqlengine/statements/test_where_clause.py new file mode 100644 index 0000000..3173320 --- /dev/null +++ b/tests/integration/cqlengine/statements/test_where_clause.py @@ -0,0 +1,43 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import six +from cassandra.cqlengine.operators import EqualsOperator +from cassandra.cqlengine.statements import StatementException, WhereClause + + +class TestWhereClause(unittest.TestCase): + + def test_operator_check(self): + """ tests that creating a where statement with a non BaseWhereOperator object fails """ + with self.assertRaises(StatementException): + WhereClause('a', 'b', 'c') + + def test_where_clause_rendering(self): + """ tests that where clauses are rendered properly """ + wc = WhereClause('a', EqualsOperator(), 'c') + wc.set_context_id(5) + + self.assertEqual('"a" = %(5)s', six.text_type(wc), six.text_type(wc)) + self.assertEqual('"a" = %(5)s', str(wc), type(wc)) + + def test_equality_method(self): + """ tests that 2 identical where clauses evaluate as == """ + wc1 = WhereClause('a', EqualsOperator(), 'c') + wc2 = WhereClause('a', EqualsOperator(), 'c') + assert wc1 == wc2 diff --git a/tests/integration/cqlengine/test_batch_query.py b/tests/integration/cqlengine/test_batch_query.py new file mode 100644 index 0000000..7b78fa9 --- /dev/null +++ b/tests/integration/cqlengine/test_batch_query.py @@ -0,0 +1,250 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
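# Editor's note (illustrative sketch, not part of the patch): the pattern these
# tests exercise, assuming a synced model and a configured cqlengine connection:
#
#     from cassandra.cqlengine.query import BatchQuery
#     with BatchQuery() as b:                              # executes on clean exit
#         TestMultiKeyModel.batch(b).create(partition=1, cluster=1)
#         b.add_callback(my_callback, 'extra', 'args')     # my_callback: any callable, run after execute()
#     # nothing is written until the batch executes, so reads issued inside
#     # the with-block will not see the queued rows yet.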
+import warnings + +import sure + +from cassandra.cqlengine import columns +from cassandra.cqlengine.management import drop_table, sync_table +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.query import BatchQuery +from tests.integration.cqlengine.base import BaseCassEngTestCase + +from mock import patch + +class TestMultiKeyModel(Model): + partition = columns.Integer(primary_key=True) + cluster = columns.Integer(primary_key=True) + count = columns.Integer(required=False) + text = columns.Text(required=False) + + +class BatchQueryTests(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BatchQueryTests, cls).setUpClass() + drop_table(TestMultiKeyModel) + sync_table(TestMultiKeyModel) + + @classmethod + def tearDownClass(cls): + super(BatchQueryTests, cls).tearDownClass() + drop_table(TestMultiKeyModel) + + def setUp(self): + super(BatchQueryTests, self).setUp() + self.pkey = 1 + for obj in TestMultiKeyModel.filter(partition=self.pkey): + obj.delete() + + def test_insert_success_case(self): + + b = BatchQuery() + TestMultiKeyModel.batch(b).create(partition=self.pkey, cluster=2, count=3, text='4') + + with self.assertRaises(TestMultiKeyModel.DoesNotExist): + TestMultiKeyModel.get(partition=self.pkey, cluster=2) + + b.execute() + + TestMultiKeyModel.get(partition=self.pkey, cluster=2) + + def test_update_success_case(self): + + inst = TestMultiKeyModel.create(partition=self.pkey, cluster=2, count=3, text='4') + + b = BatchQuery() + + inst.count = 4 + inst.batch(b).save() + + inst2 = TestMultiKeyModel.get(partition=self.pkey, cluster=2) + self.assertEqual(inst2.count, 3) + + b.execute() + + inst3 = TestMultiKeyModel.get(partition=self.pkey, cluster=2) + self.assertEqual(inst3.count, 4) + + def test_delete_success_case(self): + + inst = TestMultiKeyModel.create(partition=self.pkey, cluster=2, count=3, text='4') + + b = BatchQuery() + + inst.batch(b).delete() + + TestMultiKeyModel.get(partition=self.pkey, cluster=2) + + b.execute() + + with self.assertRaises(TestMultiKeyModel.DoesNotExist): + TestMultiKeyModel.get(partition=self.pkey, cluster=2) + + def test_context_manager(self): + + with BatchQuery() as b: + for i in range(5): + TestMultiKeyModel.batch(b).create(partition=self.pkey, cluster=i, count=3, text='4') + + for i in range(5): + with self.assertRaises(TestMultiKeyModel.DoesNotExist): + TestMultiKeyModel.get(partition=self.pkey, cluster=i) + + for i in range(5): + TestMultiKeyModel.get(partition=self.pkey, cluster=i) + + def test_bulk_delete_success_case(self): + + for i in range(1): + for j in range(5): + TestMultiKeyModel.create(partition=i, cluster=j, count=i*j, text='{0}:{1}'.format(i,j)) + + with BatchQuery() as b: + TestMultiKeyModel.objects.batch(b).filter(partition=0).delete() + self.assertEqual(TestMultiKeyModel.filter(partition=0).count(), 5) + + self.assertEqual(TestMultiKeyModel.filter(partition=0).count(), 0) + #cleanup + for m in TestMultiKeyModel.all(): + m.delete() + + def test_empty_batch(self): + b = BatchQuery() + b.execute() + + with BatchQuery() as b: + pass + +class BatchQueryCallbacksTests(BaseCassEngTestCase): + + def test_API_managing_callbacks(self): + + # Callbacks can be added at init and after + + def my_callback(*args, **kwargs): + pass + + # adding on init: + batch = BatchQuery() + + batch.add_callback(my_callback) + batch.add_callback(my_callback, 2, named_arg='value') + batch.add_callback(my_callback, 1, 3) + + self.assertEqual(batch._callbacks, [ + (my_callback, (), {}), + (my_callback, (2,), 
{'named_arg':'value'}), + (my_callback, (1, 3), {}) + ]) + + def test_callbacks_properly_execute_callables_and_tuples(self): + + call_history = [] + def my_callback(*args, **kwargs): + call_history.append(args) + + # adding on init: + batch = BatchQuery() + + batch.add_callback(my_callback) + batch.add_callback(my_callback, 'more', 'args') + + batch.execute() + + self.assertEqual(len(call_history), 2) + self.assertEqual([(), ('more', 'args')], call_history) + + def test_callbacks_tied_to_execute(self): + """Batch callbacks should NOT fire if batch is not executed in context manager mode""" + + call_history = [] + def my_callback(*args, **kwargs): + call_history.append(args) + + with BatchQuery() as batch: + batch.add_callback(my_callback) + + self.assertEqual(len(call_history), 1) + + class SomeError(Exception): + pass + + with self.assertRaises(SomeError): + with BatchQuery() as batch: + batch.add_callback(my_callback) + # this error bubbling up through context manager + # should prevent callback runs (along with b.execute()) + raise SomeError + + # still same call history. Nothing added + self.assertEqual(len(call_history), 1) + + # but if execute ran, even with an error bubbling through + # the callbacks also would have fired + with self.assertRaises(SomeError): + with BatchQuery(execute_on_exception=True) as batch: + batch.add_callback(my_callback) + raise SomeError + + # updated call history + self.assertEqual(len(call_history), 2) + + def test_callbacks_work_multiple_times(self): + """ + Tests that multiple executions of execute on a batch statement + logs a warning, and that we don't encounter an attribute error. + @since 3.1 + @jira_ticket PYTHON-445 + @expected_result warning message is logged + + @test_category object_mapper + """ + call_history = [] + + def my_callback(*args, **kwargs): + call_history.append(args) + + with warnings.catch_warnings(record=True) as w: + with BatchQuery() as batch: + batch.add_callback(my_callback) + batch.execute() + batch.execute() + self.assertEqual(len(w), 2) # package filter setup to warn always + self.assertRegexpMatches(str(w[0].message), r"^Batch.*multiple.*") + + def test_disable_multiple_callback_warning(self): + """ + Tests that multiple executions of a batch statement + don't log a warning when warn_multiple_exec flag is set, and + that we don't encounter an attribute error. + @since 3.1 + @jira_ticket PYTHON-445 + @expected_result warning message is logged + + @test_category object_mapper + """ + call_history = [] + + def my_callback(*args, **kwargs): + call_history.append(args) + + with patch('cassandra.cqlengine.query.BatchQuery.warn_multiple_exec', False): + with warnings.catch_warnings(record=True) as w: + with BatchQuery() as batch: + batch.add_callback(my_callback) + batch.execute() + batch.execute() + self.assertFalse(w) diff --git a/tests/integration/cqlengine/test_connections.py b/tests/integration/cqlengine/test_connections.py new file mode 100644 index 0000000..10dee66 --- /dev/null +++ b/tests/integration/cqlengine/test_connections.py @@ -0,0 +1,671 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra import InvalidRequest +from cassandra.cluster import Cluster +from cassandra.cluster import NoHostAvailable +from cassandra.cqlengine import columns, CQLEngineException +from cassandra.cqlengine import connection as conn +from cassandra.cqlengine.management import drop_keyspace, sync_table, drop_table, create_keyspace_simple +from cassandra.cqlengine.models import Model, QuerySetDescriptor +from cassandra.cqlengine.query import ContextQuery, BatchQuery, ModelQuerySet +from tests.integration.cqlengine import setup_connection, DEFAULT_KEYSPACE +from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine.query import test_queryset +from tests.integration import local, CASSANDRA_IP + + +class TestModel(Model): + + __keyspace__ = 'ks1' + + partition = columns.Integer(primary_key=True) + cluster = columns.Integer(primary_key=True) + count = columns.Integer() + text = columns.Text() + + +class AnotherTestModel(Model): + + __keyspace__ = 'ks1' + + partition = columns.Integer(primary_key=True) + cluster = columns.Integer(primary_key=True) + count = columns.Integer() + text = columns.Text() + +class ContextQueryConnectionTests(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(ContextQueryConnectionTests, cls).setUpClass() + create_keyspace_simple('ks1', 1) + + conn.unregister_connection('default') + conn.register_connection('fake_cluster', ['1.2.3.4'], lazy_connect=True, retry_connect=True, default=True) + conn.register_connection('cluster', [CASSANDRA_IP]) + + with ContextQuery(TestModel, connection='cluster') as tm: + sync_table(tm) + + @classmethod + def tearDownClass(cls): + super(ContextQueryConnectionTests, cls).tearDownClass() + + with ContextQuery(TestModel, connection='cluster') as tm: + drop_table(tm) + drop_keyspace('ks1', connections=['cluster']) + + + # reset the default connection + conn.unregister_connection('fake_cluster') + conn.unregister_connection('cluster') + setup_connection(DEFAULT_KEYSPACE) + + def setUp(self): + super(BaseCassEngTestCase, self).setUp() + + def test_context_connection_priority(self): + """ + Tests to ensure the proper connection priority is honored. + + Explicit connection should have higest priority, + Followed by context query connection + Default connection should be honored last. 
+ + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result priorities should be honored + + @test_category object_mapper + """ + # model keyspace write/read + + # Set the default connection on the Model + TestModel.__connection__ = 'cluster' + with ContextQuery(TestModel) as tm: + tm.objects.create(partition=1, cluster=1) + + # ContextQuery connection should have priority over default one + with ContextQuery(TestModel, connection='fake_cluster') as tm: + with self.assertRaises(NoHostAvailable): + tm.objects.create(partition=1, cluster=1) + + # Explicit connection should have priority over ContextQuery one + with ContextQuery(TestModel, connection='fake_cluster') as tm: + tm.objects.using(connection='cluster').create(partition=1, cluster=1) + + # Reset the default conn of the model + TestModel.__connection__ = None + + # No model connection and an invalid default connection + with ContextQuery(TestModel) as tm: + with self.assertRaises(NoHostAvailable): + tm.objects.create(partition=1, cluster=1) + + def test_context_connection_with_keyspace(self): + """ + Tests to ensure keyspace param is honored + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result Invalid request is thrown + + @test_category object_mapper + """ + + # ks2 doesn't exist + with ContextQuery(TestModel, connection='cluster', keyspace='ks2') as tm: + with self.assertRaises(InvalidRequest): + tm.objects.create(partition=1, cluster=1) + + +class ManagementConnectionTests(BaseCassEngTestCase): + + keyspaces = ['ks1', 'ks2'] + conns = ['cluster'] + + @classmethod + def setUpClass(cls): + super(ManagementConnectionTests, cls).setUpClass() + conn.unregister_connection('default') + conn.register_connection('fake_cluster', ['127.0.0.100'], lazy_connect=True, retry_connect=True, default=True) + conn.register_connection('cluster', [CASSANDRA_IP]) + + @classmethod + def tearDownClass(cls): + super(ManagementConnectionTests, cls).tearDownClass() + + # reset the default connection + conn.unregister_connection('fake_cluster') + conn.unregister_connection('cluster') + setup_connection(DEFAULT_KEYSPACE) + + def setUp(self): + super(BaseCassEngTestCase, self).setUp() + + def test_create_drop_keyspace(self): + """ + Tests drop and create keyspace with connections explicitly set + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result keyspaces should be created and dropped + + @test_category object_mapper + """ + + # No connection (default is fake) + with self.assertRaises(NoHostAvailable): + create_keyspace_simple(self.keyspaces[0], 1) + + # Explicit connections + for ks in self.keyspaces: + create_keyspace_simple(ks, 1, connections=self.conns) + + for ks in self.keyspaces: + drop_keyspace(ks, connections=self.conns) + + def test_create_drop_table(self): + """ + Tests drop and create Table with connections explicitly set + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result Tables should be created and dropped + + @test_category object_mapper + """ + for ks in self.keyspaces: + create_keyspace_simple(ks, 1, connections=self.conns) + + # No connection (default is fake) + with self.assertRaises(NoHostAvailable): + sync_table(TestModel) + + # Explicit connections + sync_table(TestModel, connections=self.conns) + + # Explicit drop + drop_table(TestModel, connections=self.conns) + + # Model connection + TestModel.__connection__ = 'cluster' + sync_table(TestModel) + TestModel.__connection__ = None + + # No connection (default is fake) + with self.assertRaises(NoHostAvailable): + drop_table(TestModel) + + # Model connection + 
TestModel.__connection__ = 'cluster' + drop_table(TestModel) + TestModel.__connection__ = None + + # Model connection + for ks in self.keyspaces: + drop_keyspace(ks, connections=self.conns) + + def test_connection_creation_from_session(self): + """ + Test to ensure that you can register a connection from a session + @since 3.8 + @jira_ticket PYTHON-649 + @expected_result queries should execute appropriately + + @test_category object_mapper + """ + cluster = Cluster([CASSANDRA_IP]) + session = cluster.connect() + connection_name = 'from_session' + conn.register_connection(connection_name, session=session) + self.assertIsNotNone(conn.get_connection(connection_name).cluster.metadata.get_host(CASSANDRA_IP)) + self.addCleanup(conn.unregister_connection, connection_name) + cluster.shutdown() + + def test_connection_from_hosts(self): + """ + Test to ensure that you can register a connection from a list of hosts + @since 3.8 + @jira_ticket PYTHON-692 + @expected_result queries should execute appropriately + + @test_category object_mapper + """ + connection_name = 'from_hosts' + conn.register_connection(connection_name, hosts=[CASSANDRA_IP]) + self.assertIsNotNone(conn.get_connection(connection_name).cluster.metadata.get_host(CASSANDRA_IP)) + self.addCleanup(conn.unregister_connection, connection_name) + + def test_connection_param_validation(self): + """ + Test to validate that invalid parameter combinations for registering connections via session are not tolerated + @since 3.8 + @jira_ticket PYTHON-649 + @expected_result queries should execute appropriately + + @test_category object_mapper + """ + cluster = Cluster([CASSANDRA_IP]) + session = cluster.connect() + with self.assertRaises(CQLEngineException): + conn.register_connection("bad_coonection1", session=session, consistency="not_null") + with self.assertRaises(CQLEngineException): + conn.register_connection("bad_coonection2", session=session, lazy_connect="not_null") + with self.assertRaises(CQLEngineException): + conn.register_connection("bad_coonection3", session=session, retry_connect="not_null") + with self.assertRaises(CQLEngineException): + conn.register_connection("bad_coonection4", session=session, cluster_options="not_null") + with self.assertRaises(CQLEngineException): + conn.register_connection("bad_coonection5", hosts="not_null", session=session) + cluster.shutdown() + + cluster.shutdown() + + + cluster.shutdown() + +class BatchQueryConnectionTests(BaseCassEngTestCase): + + conns = ['cluster'] + + @classmethod + def setUpClass(cls): + super(BatchQueryConnectionTests, cls).setUpClass() + + create_keyspace_simple('ks1', 1) + sync_table(TestModel) + sync_table(AnotherTestModel) + + conn.unregister_connection('default') + conn.register_connection('fake_cluster', ['127.0.0.100'], lazy_connect=True, retry_connect=True, default=True) + conn.register_connection('cluster', [CASSANDRA_IP]) + + @classmethod + def tearDownClass(cls): + super(BatchQueryConnectionTests, cls).tearDownClass() + + # reset the default connection + conn.unregister_connection('fake_cluster') + conn.unregister_connection('cluster') + setup_connection(DEFAULT_KEYSPACE) + + drop_keyspace('ks1') + + def setUp(self): + super(BaseCassEngTestCase, self).setUp() + + def test_basic_batch_query(self): + """ + Test Batch queries with connections explicitly set + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result queries should execute appropriately + + @test_category object_mapper + """ + + # No connection with a QuerySet (default is a fake one) + with 
self.assertRaises(NoHostAvailable): + with BatchQuery() as b: + TestModel.objects.batch(b).create(partition=1, cluster=1) + + # Explicit connection with a QuerySet + with BatchQuery(connection='cluster') as b: + TestModel.objects.batch(b).create(partition=1, cluster=1) + + # Get an object from the BD + with ContextQuery(TestModel, connection='cluster') as tm: + obj = tm.objects.get(partition=1, cluster=1) + obj.__connection__ = None + + # No connection with a model (default is a fake one) + with self.assertRaises(NoHostAvailable): + with BatchQuery() as b: + obj.count = 2 + obj.batch(b).save() + + # Explicit connection with a model + with BatchQuery(connection='cluster') as b: + obj.count = 2 + obj.batch(b).save() + + def test_batch_query_different_connection(self): + """ + Test BatchQuery with Models that have a different connection + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result queries should execute appropriately + + @test_category object_mapper + """ + + # Testing on a model class + TestModel.__connection__ = 'cluster' + AnotherTestModel.__connection__ = 'cluster2' + + with self.assertRaises(CQLEngineException): + with BatchQuery() as b: + TestModel.objects.batch(b).create(partition=1, cluster=1) + AnotherTestModel.objects.batch(b).create(partition=1, cluster=1) + + TestModel.__connection__ = None + AnotherTestModel.__connection__ = None + + with BatchQuery(connection='cluster') as b: + TestModel.objects.batch(b).create(partition=1, cluster=1) + AnotherTestModel.objects.batch(b).create(partition=1, cluster=1) + + # Testing on a model instance + with ContextQuery(TestModel, AnotherTestModel, connection='cluster') as (tm, atm): + obj1 = tm.objects.get(partition=1, cluster=1) + obj2 = atm.objects.get(partition=1, cluster=1) + + obj1.__connection__ = 'cluster' + obj2.__connection__ = 'cluster2' + + obj1.count = 4 + obj2.count = 4 + + with self.assertRaises(CQLEngineException): + with BatchQuery() as b: + obj1.batch(b).save() + obj2.batch(b).save() + + def test_batch_query_connection_override(self): + """ + Test that we cannot override a BatchQuery connection per model + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result Proper exceptions should be raised + + @test_category object_mapper + """ + + with self.assertRaises(CQLEngineException): + with BatchQuery(connection='cluster') as b: + TestModel.batch(b).using(connection='test').save() + + with self.assertRaises(CQLEngineException): + with BatchQuery(connection='cluster') as b: + TestModel.using(connection='test').batch(b).save() + + with ContextQuery(TestModel, AnotherTestModel, connection='cluster') as (tm, atm): + obj1 = tm.objects.get(partition=1, cluster=1) + obj1.__connection__ = None + + with self.assertRaises(CQLEngineException): + with BatchQuery(connection='cluster') as b: + obj1.using(connection='test').batch(b).save() + + with self.assertRaises(CQLEngineException): + with BatchQuery(connection='cluster') as b: + obj1.batch(b).using(connection='test').save() + +class UsingDescriptorTests(BaseCassEngTestCase): + + conns = ['cluster'] + keyspaces = ['ks1', 'ks2'] + + @classmethod + def setUpClass(cls): + super(UsingDescriptorTests, cls).setUpClass() + + conn.unregister_connection('default') + conn.register_connection('fake_cluster', ['127.0.0.100'], lazy_connect=True, retry_connect=True, default=True) + conn.register_connection('cluster', [CASSANDRA_IP]) + + @classmethod + def tearDownClass(cls): + super(UsingDescriptorTests, cls).tearDownClass() + + # reset the default connection + 
conn.unregister_connection('fake_cluster') + conn.unregister_connection('cluster') + setup_connection(DEFAULT_KEYSPACE) + + for ks in cls.keyspaces: + drop_keyspace(ks) + + def setUp(self): + super(BaseCassEngTestCase, self).setUp() + + def _reset_data(self): + + for ks in self.keyspaces: + drop_keyspace(ks, connections=self.conns) + for ks in self.keyspaces: + create_keyspace_simple(ks, 1, connections=self.conns) + sync_table(TestModel, keyspaces=self.keyspaces, connections=self.conns) + + def test_keyspace(self): + """ + Test keyspace segregation when same connection is used + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result Keyspace segration is honored + + @test_category object_mapper + """ + self._reset_data() + + with ContextQuery(TestModel, connection='cluster') as tm: + + # keyspace Model class + tm.objects.using(keyspace='ks2').create(partition=1, cluster=1) + tm.objects.using(keyspace='ks2').create(partition=2, cluster=2) + + with self.assertRaises(TestModel.DoesNotExist): + tm.objects.get(partition=1, cluster=1) # default keyspace ks1 + obj1 = tm.objects.using(keyspace='ks2').get(partition=1, cluster=1) + + obj1.count = 2 + obj1.save() + + with self.assertRaises(NoHostAvailable): + TestModel.objects.using(keyspace='ks2').get(partition=1, cluster=1) + + obj2 = TestModel.objects.using(connection='cluster', keyspace='ks2').get(partition=1, cluster=1) + self.assertEqual(obj2.count, 2) + + # Update test + TestModel.objects(partition=2, cluster=2).using(connection='cluster', keyspace='ks2').update(count=5) + obj3 = TestModel.objects.using(connection='cluster', keyspace='ks2').get(partition=2, cluster=2) + self.assertEqual(obj3.count, 5) + + TestModel.objects(partition=2, cluster=2).using(connection='cluster', keyspace='ks2').delete() + with self.assertRaises(TestModel.DoesNotExist): + TestModel.objects.using(connection='cluster', keyspace='ks2').get(partition=2, cluster=2) + + def test_connection(self): + """ + Test basic connection functionality + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result proper connection should be used + + @test_category object_mapper + """ + self._reset_data() + + # Model class + with self.assertRaises(NoHostAvailable): + TestModel.objects.create(partition=1, cluster=1) + + TestModel.objects.using(connection='cluster').create(partition=1, cluster=1) + TestModel.objects(partition=1, cluster=1).using(connection='cluster').update(count=2) + obj1 = TestModel.objects.using(connection='cluster').get(partition=1, cluster=1) + self.assertEqual(obj1.count, 2) + + obj1.using(connection='cluster').update(count=5) + obj1 = TestModel.objects.using(connection='cluster').get(partition=1, cluster=1) + self.assertEqual(obj1.count, 5) + + obj1.using(connection='cluster').delete() + with self.assertRaises(TestModel.DoesNotExist): + TestModel.objects.using(connection='cluster').get(partition=1, cluster=1) + + +class ModelQuerySetNew(ModelQuerySet): + def __init__(self, *args, **kwargs): + super(ModelQuerySetNew, self).__init__(*args, **kwargs) + self._connection = "cluster" + +class BaseConnectionTestNoDefault(object): + conns = ['cluster'] + + @classmethod + def setUpClass(cls): + conn.register_connection('cluster', [CASSANDRA_IP]) + test_queryset.TestModel.__queryset__ = ModelQuerySetNew + test_queryset.IndexedTestModel.__queryset__ = ModelQuerySetNew + test_queryset.CustomIndexedTestModel.__queryset__ = ModelQuerySetNew + test_queryset.IndexedCollectionsTestModel.__queryset__ = ModelQuerySetNew + test_queryset.TestMultiClusteringModel.__queryset__ = 
ModelQuerySetNew
+
+        super(BaseConnectionTestNoDefault, cls).setUpClass()
+        conn.unregister_connection('default')
+
+    @classmethod
+    def tearDownClass(cls):
+        conn.unregister_connection('cluster')
+        setup_connection(DEFAULT_KEYSPACE)
+        super(BaseConnectionTestNoDefault, cls).tearDownClass()
+        # reset the default connection
+
+    def setUp(self):
+        super(BaseCassEngTestCase, self).setUp()
+
+
+class TestQuerySetOperationConnection(BaseConnectionTestNoDefault, test_queryset.TestQuerySetOperation):
+    """
+    Execute test_queryset.TestQuerySetOperation using non default connection
+
+    @since 3.7
+    @jira_ticket PYTHON-613
+    @expected_result proper connection should be used
+
+    @test_category object_mapper
+    """
+    pass
+
+
+class TestQuerySetDistinctNoDefault(BaseConnectionTestNoDefault, test_queryset.TestQuerySetDistinct):
+    """
+    Execute test_queryset.TestQuerySetDistinct using non default connection
+
+    @since 3.7
+    @jira_ticket PYTHON-613
+    @expected_result proper connection should be used
+
+    @test_category object_mapper
+    """
+    pass
+
+
+class TestQuerySetOrderingNoDefault(BaseConnectionTestNoDefault, test_queryset.TestQuerySetOrdering):
+    """
+    Execute test_queryset.TestQuerySetOrdering using non default connection
+
+    @since 3.7
+    @jira_ticket PYTHON-613
+    @expected_result proper connection should be used
+
+    @test_category object_mapper
+    """
+    pass
+
+
+class TestQuerySetCountSelectionAndIterationNoDefault(BaseConnectionTestNoDefault, test_queryset.TestQuerySetCountSelectionAndIteration):
+    """
+    Execute test_queryset.TestQuerySetCountSelectionAndIteration using non default connection
+
+    @since 3.7
+    @jira_ticket PYTHON-613
+    @expected_result proper connection should be used
+
+    @test_category object_mapper
+    """
+    pass
+
+
+class TestQuerySetSlicingNoDefault(BaseConnectionTestNoDefault, test_queryset.TestQuerySetSlicing):
+    """
+    Execute test_queryset.TestQuerySetSlicing using non default connection
+
+    @since 3.7
+    @jira_ticket PYTHON-613
+    @expected_result proper connection should be used
+
+    @test_category object_mapper
+    """
+    pass
+
+
+class TestQuerySetValidationNoDefault(BaseConnectionTestNoDefault, test_queryset.TestQuerySetValidation):
+    """
+    Execute test_queryset.TestQuerySetValidation using non default connection
+
+    @since 3.7
+    @jira_ticket PYTHON-613
+    @expected_result proper connection should be used
+
+    @test_category object_mapper
+    """
+    pass
+
+
+class TestQuerySetDeleteNoDefault(BaseConnectionTestNoDefault, test_queryset.TestQuerySetDelete):
+    """
+    Execute test_queryset.TestQuerySetDelete using non default connection
+
+    @since 3.7
+    @jira_ticket PYTHON-613
+    @expected_result proper connection should be used
+
+    @test_category object_mapper
+    """
+    pass
+
+
+class TestValuesListNoDefault(BaseConnectionTestNoDefault, test_queryset.TestValuesList):
+    """
+    Execute test_queryset.TestValuesList using non default connection
+
+    @since 3.7
+    @jira_ticket PYTHON-613
+    @expected_result proper connection should be used
+
+    @test_category object_mapper
+    """
+    pass
+
+
+class TestObjectsPropertyNoDefault(BaseConnectionTestNoDefault, test_queryset.TestObjectsProperty):
+    """
+    Execute test_queryset.TestObjectsProperty using non default connection
+
+    @since 3.7
+    @jira_ticket PYTHON-613
+    @expected_result proper connection should be used
+
+    @test_category object_mapper
+    """
+    pass
diff --git a/tests/integration/cqlengine/test_consistency.py b/tests/integration/cqlengine/test_consistency.py
new file mode 100644
index 0000000..dc0aa32
--- /dev/null
+++
b/tests/integration/cqlengine/test_consistency.py @@ -0,0 +1,121 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +from uuid import uuid4 + +from cassandra import ConsistencyLevel as CL, ConsistencyLevel +from cassandra.cluster import Session +from cassandra.cqlengine import columns +from cassandra.cqlengine import connection +from cassandra.cqlengine.management import sync_table, drop_table +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.query import BatchQuery + +from tests.integration.cqlengine.base import BaseCassEngTestCase + +class TestConsistencyModel(Model): + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + count = columns.Integer() + text = columns.Text(required=False) + +class BaseConsistencyTest(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BaseConsistencyTest, cls).setUpClass() + sync_table(TestConsistencyModel) + + @classmethod + def tearDownClass(cls): + super(BaseConsistencyTest, cls).tearDownClass() + drop_table(TestConsistencyModel) + + +class TestConsistency(BaseConsistencyTest): + def test_create_uses_consistency(self): + + qs = TestConsistencyModel.consistency(CL.ALL) + with mock.patch.object(self.session, 'execute') as m: + qs.create(text="i am not fault tolerant this way") + + args = m.call_args + self.assertEqual(CL.ALL, args[0][0].consistency_level) + + def test_queryset_is_returned_on_create(self): + qs = TestConsistencyModel.consistency(CL.ALL) + self.assertTrue(isinstance(qs, TestConsistencyModel.__queryset__), type(qs)) + + def test_update_uses_consistency(self): + t = TestConsistencyModel.create(text="bacon and eggs") + t.text = "ham sandwich" + + with mock.patch.object(self.session, 'execute') as m: + t.consistency(CL.ALL).save() + + args = m.call_args + self.assertEqual(CL.ALL, args[0][0].consistency_level) + + def test_batch_consistency(self): + + with mock.patch.object(self.session, 'execute') as m: + with BatchQuery(consistency=CL.ALL) as b: + TestConsistencyModel.batch(b).create(text="monkey") + + args = m.call_args + + self.assertEqual(CL.ALL, args[0][0].consistency_level) + + with mock.patch.object(self.session, 'execute') as m: + with BatchQuery() as b: + TestConsistencyModel.batch(b).create(text="monkey") + + args = m.call_args + self.assertNotEqual(CL.ALL, args[0][0].consistency_level) + + def test_blind_update(self): + t = TestConsistencyModel.create(text="bacon and eggs") + t.text = "ham sandwich" + uid = t.id + + with mock.patch.object(self.session, 'execute') as m: + TestConsistencyModel.objects(id=uid).consistency(CL.ALL).update(text="grilled cheese") + + args = m.call_args + self.assertEqual(CL.ALL, args[0][0].consistency_level) + + def test_delete(self): + # ensures we always carry consistency through on delete statements + t = TestConsistencyModel.create(text="bacon and eggs") + t.text = "ham and cheese sandwich" + uid = t.id + + with mock.patch.object(self.session, 'execute') as m: + t.consistency(CL.ALL).delete() + + with 
mock.patch.object(self.session, 'execute') as m: + TestConsistencyModel.objects(id=uid).consistency(CL.ALL).delete() + + args = m.call_args + self.assertEqual(CL.ALL, args[0][0].consistency_level) + + def test_default_consistency(self): + # verify global assumed default + self.assertEqual(Session._default_consistency_level, ConsistencyLevel.LOCAL_ONE) + + # verify that this session default is set according to connection.setup + # assumes tests/cqlengine/__init__ setup uses CL.ONE + session = connection.get_session() + self.assertEqual(session.default_consistency_level, ConsistencyLevel.ONE) diff --git a/tests/integration/cqlengine/test_context_query.py b/tests/integration/cqlengine/test_context_query.py new file mode 100644 index 0000000..6f2a161 --- /dev/null +++ b/tests/integration/cqlengine/test_context_query.py @@ -0,0 +1,175 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra.cqlengine import columns +from cassandra.cqlengine.management import drop_keyspace, sync_table, create_keyspace_simple +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.query import ContextQuery +from tests.integration.cqlengine.base import BaseCassEngTestCase + + +class TestModel(Model): + + __keyspace__ = 'ks1' + + partition = columns.Integer(primary_key=True) + cluster = columns.Integer(primary_key=True) + count = columns.Integer() + text = columns.Text() + + +class ContextQueryTests(BaseCassEngTestCase): + + KEYSPACES = ('ks1', 'ks2', 'ks3', 'ks4') + + @classmethod + def setUpClass(cls): + super(ContextQueryTests, cls).setUpClass() + for ks in cls.KEYSPACES: + create_keyspace_simple(ks, 1) + sync_table(TestModel, keyspaces=cls.KEYSPACES) + + @classmethod + def tearDownClass(cls): + super(ContextQueryTests, cls).tearDownClass() + for ks in cls.KEYSPACES: + drop_keyspace(ks) + + + def setUp(self): + super(ContextQueryTests, self).setUp() + for ks in self.KEYSPACES: + with ContextQuery(TestModel, keyspace=ks) as tm: + for obj in tm.all(): + obj.delete() + + def test_context_manager(self): + """ + Validates that when a context query is constructed that the + keyspace of the returned model is toggled appropriately + + @since 3.6 + @jira_ticket PYTHON-598 + @expected_result default keyspace should be used + + @test_category query + """ + # model keyspace write/read + for ks in self.KEYSPACES: + with ContextQuery(TestModel, keyspace=ks) as tm: + self.assertEqual(tm.__keyspace__, ks) + + self.assertEqual(TestModel._get_keyspace(), 'ks1') + + def test_default_keyspace(self): + """ + Tests the use of context queries with the default model keyspsace + + @since 3.6 + @jira_ticket PYTHON-598 + @expected_result default keyspace should be used + + @test_category query + """ + # model keyspace write/read + for i in range(5): + TestModel.objects.create(partition=i, cluster=i) + + with ContextQuery(TestModel) as tm: + self.assertEqual(5, len(tm.objects.all())) + + with ContextQuery(TestModel, keyspace='ks1') as tm: + self.assertEqual(5, len(tm.objects.all())) + + 
for ks in self.KEYSPACES[1:]: + with ContextQuery(TestModel, keyspace=ks) as tm: + self.assertEqual(0, len(tm.objects.all())) + + def test_context_keyspace(self): + """ + Tests the use of context queries with non default keyspaces + + @since 3.6 + @jira_ticket PYTHON-598 + @expected_result queries should be routed to appropriate keyspaces + + @test_category query + """ + for i in range(5): + with ContextQuery(TestModel, keyspace='ks4') as tm: + tm.objects.create(partition=i, cluster=i) + + with ContextQuery(TestModel, keyspace='ks4') as tm: + self.assertEqual(5, len(tm.objects.all())) + + self.assertEqual(0, len(TestModel.objects.all())) + + for ks in self.KEYSPACES[:2]: + with ContextQuery(TestModel, keyspace=ks) as tm: + self.assertEqual(0, len(tm.objects.all())) + + # simple data update + with ContextQuery(TestModel, keyspace='ks4') as tm: + obj = tm.objects.get(partition=1) + obj.update(count=42) + + self.assertEqual(42, tm.objects.get(partition=1).count) + + def test_context_multiple_models(self): + """ + Tests the use of multiple models with the context manager + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result all models are properly updated with the context + + @test_category query + """ + + with ContextQuery(TestModel, TestModel, keyspace='ks4') as (tm1, tm2): + + self.assertNotEqual(tm1, tm2) + self.assertEqual(tm1.__keyspace__, 'ks4') + self.assertEqual(tm2.__keyspace__, 'ks4') + + def test_context_invalid_parameters(self): + """ + Tests that invalid parameters are raised by the context manager + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result a ValueError is raised when passing invalid parameters + + @test_category query + """ + + with self.assertRaises(ValueError): + with ContextQuery(keyspace='ks2'): + pass + + with self.assertRaises(ValueError): + with ContextQuery(42) as tm: + pass + + with self.assertRaises(ValueError): + with ContextQuery(TestModel, 42): + pass + + with self.assertRaises(ValueError): + with ContextQuery(TestModel, unknown_param=42): + pass + + with self.assertRaises(ValueError): + with ContextQuery(TestModel, keyspace='ks2', unknown_param=42): + pass \ No newline at end of file diff --git a/tests/integration/cqlengine/test_ifexists.py b/tests/integration/cqlengine/test_ifexists.py new file mode 100644 index 0000000..2797edd --- /dev/null +++ b/tests/integration/cqlengine/test_ifexists.py @@ -0,0 +1,316 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
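+# Integration tests for the if_exists() lightweight-transaction helper:
+# updates and deletes gain an IF EXISTS clause, and an operation that is
+# not applied raises LWTException whose `existing` dict contains
+# '[applied]': False.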
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import mock +from uuid import uuid4 + +from cassandra.cqlengine import columns +from cassandra.cqlengine.management import sync_table, drop_table +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.query import BatchQuery, BatchType, LWTException, IfExistsWithCounterColumn + +from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration import PROTOCOL_VERSION + + +class TestIfExistsModel(Model): + + id = columns.UUID(primary_key=True, default=lambda: uuid4()) + count = columns.Integer() + text = columns.Text(required=False) + + +class TestIfExistsModel2(Model): + + id = columns.Integer(primary_key=True) + count = columns.Integer(primary_key=True, required=False) + text = columns.Text(required=False) + + +class TestIfExistsWithCounterModel(Model): + + id = columns.UUID(primary_key=True, default=lambda: uuid4()) + likes = columns.Counter() + + +class BaseIfExistsTest(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BaseIfExistsTest, cls).setUpClass() + sync_table(TestIfExistsModel) + sync_table(TestIfExistsModel2) + + @classmethod + def tearDownClass(cls): + super(BaseIfExistsTest, cls).tearDownClass() + drop_table(TestIfExistsModel) + drop_table(TestIfExistsModel2) + + +class BaseIfExistsWithCounterTest(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BaseIfExistsWithCounterTest, cls).setUpClass() + sync_table(TestIfExistsWithCounterModel) + + @classmethod + def tearDownClass(cls): + super(BaseIfExistsWithCounterTest, cls).tearDownClass() + drop_table(TestIfExistsWithCounterModel) + + +class IfExistsUpdateTests(BaseIfExistsTest): + + @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") + def test_update_if_exists(self): + """ + Tests that update with if_exists work as expected + + @since 3.1 + @jira_ticket PYTHON-432 + @expected_result updates to be applied when primary key exists, otherwise LWT exception to be thrown + + @test_category object_mapper + """ + + id = uuid4() + + m = TestIfExistsModel.create(id=id, count=8, text='123456789') + m.text = 'changed' + m.if_exists().update() + m = TestIfExistsModel.get(id=id) + self.assertEqual(m.text, 'changed') + + # save() + m.text = 'changed_again' + m.if_exists().save() + m = TestIfExistsModel.get(id=id) + self.assertEqual(m.text, 'changed_again') + + m = TestIfExistsModel(id=uuid4(), count=44) # do not exists + with self.assertRaises(LWTException) as assertion: + m.if_exists().update() + + self.assertEqual(assertion.exception.existing, { + '[applied]': False, + }) + + # queryset update + with self.assertRaises(LWTException) as assertion: + TestIfExistsModel.objects(id=uuid4()).if_exists().update(count=8) + + self.assertEqual(assertion.exception.existing, { + '[applied]': False, + }) + + @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") + def test_batch_update_if_exists_success(self): + """ + Tests that batch update with if_exists work as expected + + @since 3.1 + @jira_ticket PYTHON-432 + @expected_result + + @test_category object_mapper + """ + + id = uuid4() + + m = TestIfExistsModel.create(id=id, count=8, text='123456789') + + with BatchQuery() as b: + m.text = '111111111' + m.batch(b).if_exists().update() + + with self.assertRaises(LWTException) as assertion: + with BatchQuery() as b: + m = TestIfExistsModel(id=uuid4(), count=42) # Doesn't exist + m.batch(b).if_exists().update() + + 
self.assertEqual(assertion.exception.existing, { + '[applied]': False, + }) + + q = TestIfExistsModel.objects(id=id) + self.assertEqual(len(q), 1) + + tm = q.first() + self.assertEqual(tm.count, 8) + self.assertEqual(tm.text, '111111111') + + @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") + def test_batch_mixed_update_if_exists_success(self): + """ + Tests that batch update with with one bad query will still fail with LWTException + + @since 3.1 + @jira_ticket PYTHON-432 + @expected_result + + @test_category object_mapper + """ + + m = TestIfExistsModel2.create(id=1, count=8, text='123456789') + with self.assertRaises(LWTException) as assertion: + with BatchQuery() as b: + m.text = '111111112' + m.batch(b).if_exists().update() # Does exist + n = TestIfExistsModel2(id=1, count=10, text="Failure") # Doesn't exist + n.batch(b).if_exists().update() + + self.assertEqual(assertion.exception.existing.get('[applied]'), False) + + @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") + def test_delete_if_exists(self): + """ + Tests that delete with if_exists work, and throw proper LWT exception when they are are not applied + + @since 3.1 + @jira_ticket PYTHON-432 + @expected_result Deletes will be preformed if they exist, otherwise throw LWT exception + + @test_category object_mapper + """ + + id = uuid4() + + m = TestIfExistsModel.create(id=id, count=8, text='123456789') + m.if_exists().delete() + q = TestIfExistsModel.objects(id=id) + self.assertEqual(len(q), 0) + + m = TestIfExistsModel(id=uuid4(), count=44) # do not exists + with self.assertRaises(LWTException) as assertion: + m.if_exists().delete() + + self.assertEqual(assertion.exception.existing, { + '[applied]': False, + }) + + # queryset delete + with self.assertRaises(LWTException) as assertion: + TestIfExistsModel.objects(id=uuid4()).if_exists().delete() + + self.assertEqual(assertion.exception.existing, { + '[applied]': False, + }) + + @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") + def test_batch_delete_if_exists_success(self): + """ + Tests that batch deletes with if_exists work, and throw proper LWTException when they are are not applied + + @since 3.1 + @jira_ticket PYTHON-432 + @expected_result Deletes will be preformed if they exist, otherwise throw LWTException + + @test_category object_mapper + """ + + id = uuid4() + + m = TestIfExistsModel.create(id=id, count=8, text='123456789') + + with BatchQuery() as b: + m.batch(b).if_exists().delete() + + q = TestIfExistsModel.objects(id=id) + self.assertEqual(len(q), 0) + + with self.assertRaises(LWTException) as assertion: + with BatchQuery() as b: + m = TestIfExistsModel(id=uuid4(), count=42) # Doesn't exist + m.batch(b).if_exists().delete() + + self.assertEqual(assertion.exception.existing, { + '[applied]': False, + }) + + @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") + def test_batch_delete_mixed(self): + """ + Tests that batch deletes with multiple queries and throw proper LWTException when they are are not all applicable + + @since 3.1 + @jira_ticket PYTHON-432 + @expected_result If one delete clause doesn't exist all should fail. 
+
+        @test_category object_mapper
+        """
+
+        m = TestIfExistsModel2.create(id=3, count=8, text='123456789')
+
+        with self.assertRaises(LWTException) as assertion:
+            with BatchQuery() as b:
+                m.batch(b).if_exists().delete()  # Does exist
+                n = TestIfExistsModel2(id=3, count=42, text='1111111')  # Doesn't exist
+                n.batch(b).if_exists().delete()
+
+        self.assertEqual(assertion.exception.existing.get('[applied]'), False)
+        q = TestIfExistsModel2.objects(id=3, count=8)
+        self.assertEqual(len(q), 1)
+
+
+class IfExistsQueryTest(BaseIfExistsTest):
+
+    def test_if_exists_included_on_queryset_update(self):
+
+        with mock.patch.object(self.session, 'execute') as m:
+            TestIfExistsModel.objects(id=uuid4()).if_exists().update(count=42)
+
+        query = m.call_args[0][0].query_string
+        self.assertIn("IF EXISTS", query)
+
+    def test_if_exists_included_on_update(self):
+        """ tests that if_exists on model updates works as expected """
+
+        with mock.patch.object(self.session, 'execute') as m:
+            TestIfExistsModel(id=uuid4()).if_exists().update(count=8)
+
+        query = m.call_args[0][0].query_string
+        self.assertIn("IF EXISTS", query)
+
+    def test_if_exists_included_on_delete(self):
+        """ tests that if_exists on model deletes works as expected """
+
+        with mock.patch.object(self.session, 'execute') as m:
+            TestIfExistsModel(id=uuid4()).if_exists().delete()
+
+        query = m.call_args[0][0].query_string
+        self.assertIn("IF EXISTS", query)
+
+
+class IfExistWithCounterTest(BaseIfExistsWithCounterTest):
+
+    def test_instance_raise_exception(self):
+        """
+        Tests that an exception is thrown when if_exists is used with a counter column model
+
+        @since 3.1
+        @jira_ticket PYTHON-432
+        @expected_result IfExistsWithCounterColumn is raised
+
+        @test_category object_mapper
+        """
+        id = uuid4()
+        with self.assertRaises(IfExistsWithCounterColumn):
+            TestIfExistsWithCounterModel.if_exists()
+
diff --git a/tests/integration/cqlengine/test_ifnotexists.py b/tests/integration/cqlengine/test_ifnotexists.py
new file mode 100644
index 0000000..206101f
--- /dev/null
+++ b/tests/integration/cqlengine/test_ifnotexists.py
@@ -0,0 +1,205 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
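+# Integration tests for if_not_exists(): inserts gain an IF NOT EXISTS
+# clause, a conflicting insert raises LWTException carrying the existing
+# row, and calling if_not_exists() on a counter-column model raises
+# IfNotExistsWithCounterColumn.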
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import mock +from uuid import uuid4 + +from cassandra.cqlengine import columns +from cassandra.cqlengine.management import sync_table, drop_table +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.query import BatchQuery, LWTException, IfNotExistsWithCounterColumn + +from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration import PROTOCOL_VERSION + +class TestIfNotExistsModel(Model): + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + count = columns.Integer() + text = columns.Text(required=False) + + +class TestIfNotExistsWithCounterModel(Model): + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + likes = columns.Counter() + + +class BaseIfNotExistsTest(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BaseIfNotExistsTest, cls).setUpClass() + """ + when receiving an insert statement with 'if not exist', cassandra would + perform a read with QUORUM level. Unittest would be failed if replica_factor + is 3 and one node only. Therefore I have create a new keyspace with + replica_factor:1. + """ + sync_table(TestIfNotExistsModel) + + @classmethod + def tearDownClass(cls): + super(BaseIfNotExistsTest, cls).tearDownClass() + drop_table(TestIfNotExistsModel) + + +class BaseIfNotExistsWithCounterTest(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BaseIfNotExistsWithCounterTest, cls).setUpClass() + sync_table(TestIfNotExistsWithCounterModel) + + @classmethod + def tearDownClass(cls): + super(BaseIfNotExistsWithCounterTest, cls).tearDownClass() + drop_table(TestIfNotExistsWithCounterModel) + + +class IfNotExistsInsertTests(BaseIfNotExistsTest): + + @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") + def test_insert_if_not_exists(self): + """ tests that insertion with if_not_exists work as expected """ + + id = uuid4() + + TestIfNotExistsModel.create(id=id, count=8, text='123456789') + + with self.assertRaises(LWTException) as assertion: + TestIfNotExistsModel.if_not_exists().create(id=id, count=9, text='111111111111') + + with self.assertRaises(LWTException) as assertion: + TestIfNotExistsModel.objects(count=9, text='111111111111').if_not_exists().create(id=id) + + self.assertEqual(assertion.exception.existing, { + 'count': 8, + 'id': id, + 'text': '123456789', + '[applied]': False, + }) + + q = TestIfNotExistsModel.objects(id=id) + self.assertEqual(len(q), 1) + + tm = q.first() + self.assertEqual(tm.count, 8) + self.assertEqual(tm.text, '123456789') + + @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") + def test_batch_insert_if_not_exists(self): + """ tests that batch insertion with if_not_exists work as expected """ + + id = uuid4() + + with BatchQuery() as b: + TestIfNotExistsModel.batch(b).if_not_exists().create(id=id, count=8, text='123456789') + + b = BatchQuery() + TestIfNotExistsModel.batch(b).if_not_exists().create(id=id, count=9, text='111111111111') + with self.assertRaises(LWTException) as assertion: + b.execute() + + self.assertEqual(assertion.exception.existing, { + 'count': 8, + 'id': id, + 'text': '123456789', + '[applied]': False, + }) + + q = TestIfNotExistsModel.objects(id=id) + self.assertEqual(len(q), 1) + + tm = q.first() + self.assertEqual(tm.count, 8) + self.assertEqual(tm.text, '123456789') + + +class IfNotExistsModelTest(BaseIfNotExistsTest): + + def 
test_if_not_exists_included_on_create(self): + """ tests that if_not_exists on models works as expected """ + + with mock.patch.object(self.session, 'execute') as m: + TestIfNotExistsModel.if_not_exists().create(count=8) + + query = m.call_args[0][0].query_string + self.assertIn("IF NOT EXISTS", query) + + def test_if_not_exists_included_on_save(self): + """ tests if we correctly put 'IF NOT EXISTS' for insert statement """ + + with mock.patch.object(self.session, 'execute') as m: + tm = TestIfNotExistsModel(count=8) + tm.if_not_exists().save() + + query = m.call_args[0][0].query_string + self.assertIn("IF NOT EXISTS", query) + + def test_queryset_is_returned_on_class(self): + """ ensure we get a queryset description back """ + qs = TestIfNotExistsModel.if_not_exists() + self.assertTrue(isinstance(qs, TestIfNotExistsModel.__queryset__), type(qs)) + + def test_batch_if_not_exists(self): + """ ensure 'IF NOT EXISTS' exists in statement when in batch """ + with mock.patch.object(self.session, 'execute') as m: + with BatchQuery() as b: + TestIfNotExistsModel.batch(b).if_not_exists().create(count=8) + + self.assertIn("IF NOT EXISTS", m.call_args[0][0].query_string) + + +class IfNotExistsInstanceTest(BaseIfNotExistsTest): + + def test_instance_is_returned(self): + """ + ensures that we properly handle the instance.if_not_exists().save() + scenario + """ + o = TestIfNotExistsModel.create(text="whatever") + o.text = "new stuff" + o = o.if_not_exists() + self.assertEqual(True, o._if_not_exists) + + def test_if_not_exists_is_not_include_with_query_on_update(self): + """ + make sure we don't put 'IF NOT EXIST' in update statements + """ + o = TestIfNotExistsModel.create(text="whatever") + o.text = "new stuff" + o = o.if_not_exists() + + with mock.patch.object(self.session, 'execute') as m: + o.save() + + query = m.call_args[0][0].query_string + self.assertNotIn("IF NOT EXIST", query) + + +class IfNotExistWithCounterTest(BaseIfNotExistsWithCounterTest): + + def test_instance_raise_exception(self): + """ make sure exception is raised when calling + if_not_exists on table with counter column + """ + id = uuid4() + with self.assertRaises(IfNotExistsWithCounterColumn): + TestIfNotExistsWithCounterModel.if_not_exists() + diff --git a/tests/integration/cqlengine/test_lwt_conditional.py b/tests/integration/cqlengine/test_lwt_conditional.py new file mode 100644 index 0000000..1c418ae --- /dev/null +++ b/tests/integration/cqlengine/test_lwt_conditional.py @@ -0,0 +1,299 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
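+# Integration tests for conditional (IF "<col>" = ...) updates and deletes
+# via iff(), including the __ne operator, batch usage, and the LWTException
+# raised when a condition is not met.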
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import mock +import six +from uuid import uuid4 + +from cassandra.cqlengine import columns +from cassandra.cqlengine.management import sync_table, drop_table +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.query import BatchQuery, LWTException +from cassandra.cqlengine.statements import ConditionalClause + +from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration import greaterthancass20 + + +class TestConditionalModel(Model): + id = columns.UUID(primary_key=True, default=uuid4) + count = columns.Integer() + text = columns.Text(required=False) + + +class TestUpdateModel(Model): + partition = columns.Integer(primary_key=True) + cluster = columns.Integer(primary_key=True) + value = columns.Integer(required=False) + text = columns.Text(required=False, index=True) + + +@greaterthancass20 +class TestConditional(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestConditional, cls).setUpClass() + sync_table(TestConditionalModel) + + @classmethod + def tearDownClass(cls): + super(TestConditional, cls).tearDownClass() + drop_table(TestConditionalModel) + + def test_update_using_conditional(self): + t = TestConditionalModel.if_not_exists().create(text='blah blah') + t.text = 'new blah' + with mock.patch.object(self.session, 'execute') as m: + t.iff(text='blah blah').save() + + args = m.call_args + self.assertIn('IF "text" = %(0)s', args[0][0].query_string) + + def test_update_conditional_success(self): + t = TestConditionalModel.if_not_exists().create(text='blah blah', count=5) + id = t.id + t.text = 'new blah' + t.iff(text='blah blah').save() + + updated = TestConditionalModel.objects(id=id).first() + self.assertEqual(updated.count, 5) + self.assertEqual(updated.text, 'new blah') + + def test_update_failure(self): + t = TestConditionalModel.if_not_exists().create(text='blah blah') + t.text = 'new blah' + t = t.iff(text='something wrong') + + with self.assertRaises(LWTException) as assertion: + t.save() + + self.assertEqual(assertion.exception.existing, { + 'text': 'blah blah', + '[applied]': False, + }) + + def test_blind_update(self): + t = TestConditionalModel.if_not_exists().create(text='blah blah') + t.text = 'something else' + uid = t.id + + with mock.patch.object(self.session, 'execute') as m: + TestConditionalModel.objects(id=uid).iff(text='blah blah').update(text='oh hey der') + + args = m.call_args + self.assertIn('IF "text" = %(1)s', args[0][0].query_string) + + def test_blind_update_fail(self): + t = TestConditionalModel.if_not_exists().create(text='blah blah') + t.text = 'something else' + uid = t.id + qs = TestConditionalModel.objects(id=uid).iff(text='Not dis!') + with self.assertRaises(LWTException) as assertion: + qs.update(text='this will never work') + + self.assertEqual(assertion.exception.existing, { + 'text': 'blah blah', + '[applied]': False, + }) + + def test_conditional_clause(self): + tc = ConditionalClause('some_value', 23) + tc.set_context_id(3) + + self.assertEqual('"some_value" = %(3)s', six.text_type(tc)) + self.assertEqual('"some_value" = %(3)s', str(tc)) + + def test_batch_update_conditional(self): + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + id = t.id + with BatchQuery() as b: + t.batch(b).iff(count=5).update(text='something else') + + updated = TestConditionalModel.objects(id=id).first() + self.assertEqual(updated.text, 'something else') + + b = BatchQuery() + 
updated.batch(b).iff(count=6).update(text='and another thing') + with self.assertRaises(LWTException) as assertion: + b.execute() + + self.assertEqual(assertion.exception.existing, { + 'id': id, + 'count': 5, + '[applied]': False, + }) + + updated = TestConditionalModel.objects(id=id).first() + self.assertEqual(updated.text, 'something else') + + @unittest.skip("Skipping until PYTHON-943 is resolved") + def test_batch_update_conditional_several_rows(self): + sync_table(TestUpdateModel) + self.addCleanup(drop_table, TestUpdateModel) + + first_row = TestUpdateModel.create(partition=1, cluster=1, value=5, text="something") + second_row = TestUpdateModel.create(partition=1, cluster=2, value=5, text="something") + + b = BatchQuery() + TestUpdateModel.batch(b).if_not_exists().create(partition=1, cluster=1, value=5, text='something else') + TestUpdateModel.batch(b).if_not_exists().create(partition=1, cluster=2, value=5, text='something else') + TestUpdateModel.batch(b).if_not_exists().create(partition=1, cluster=3, value=5, text='something else') + + # The response will be more than two rows because two of the inserts will fail + with self.assertRaises(LWTException): + b.execute() + + first_row.delete() + second_row.delete() + b.execute() + + + def test_delete_conditional(self): + # DML path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + with self.assertRaises(LWTException): + t.iff(count=9999).delete() + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + t.iff(count=5).delete() + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0) + + # QuerySet path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + with self.assertRaises(LWTException): + TestConditionalModel.objects(id=t.id).iff(count=9999).delete() + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + TestConditionalModel.objects(id=t.id).iff(count=5).delete() + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0) + + def test_delete_lwt_ne(self): + """ + Test to ensure that deletes using IF and not equals are honored correctly + + @since 3.2 + @jira_ticket PYTHON-328 + @expected_result Delete conditional with NE should be honored + + @test_category object_mapper + """ + + # DML path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + with self.assertRaises(LWTException): + t.iff(count__ne=5).delete() + t.iff(count__ne=2).delete() + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0) + + # QuerySet path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + with self.assertRaises(LWTException): + TestConditionalModel.objects(id=t.id).iff(count__ne=5).delete() + TestConditionalModel.objects(id=t.id).iff(count__ne=2).delete() + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0) + + def test_update_lwt_ne(self): + """ + Test to ensure that update using IF and not equals are honored correctly + + @since 3.2 + @jira_ticket PYTHON-328 + @expected_result update conditional with NE should be honored + + @test_category object_mapper + """ + + # DML path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + 
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + with self.assertRaises(LWTException): + t.iff(count__ne=5).update(text='nothing') + t.iff(count__ne=2).update(text='nothing') + self.assertEqual(TestConditionalModel.objects(id=t.id).first().text, 'nothing') + t.delete() + + # QuerySet path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + with self.assertRaises(LWTException): + TestConditionalModel.objects(id=t.id).iff(count__ne=5).update(text='nothing') + TestConditionalModel.objects(id=t.id).iff(count__ne=2).update(text='nothing') + self.assertEqual(TestConditionalModel.objects(id=t.id).first().text, 'nothing') + t.delete() + + def test_update_to_none(self): + # This test is done because updates to none are split into deletes + # for old versions of cassandra. Can be removed when we drop that code + # https://github.com/datastax/python-driver/blob/3.1.1/cassandra/cqlengine/query.py#L1197-L1200 + + # DML path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + with self.assertRaises(LWTException): + t.iff(count=9999).update(text=None) + self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text) + t.iff(count=5).update(text=None) + self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) + + # QuerySet path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + with self.assertRaises(LWTException): + TestConditionalModel.objects(id=t.id).iff(count=9999).update(text=None) + self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text) + TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None) + self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) + + def test_column_delete_after_update(self): + # DML path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + t.iff(count=5).update(text=None, count=6) + + self.assertIsNone(t.text) + self.assertEqual(t.count, 6) + + # QuerySet path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None, count=6) + + self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) + self.assertEqual(TestConditionalModel.objects(id=t.id).first().count, 6) + + def test_conditional_without_instance(self): + """ + Test to ensure that the iff method is honored if it's called + directly from the Model class + + @jira_ticket PYTHON-505 + @expected_result the value is updated + + @test_category object_mapper + """ + uuid = uuid4() + TestConditionalModel.if_not_exists().create(id=uuid, text='test_for_cassandra', count=5) + + # This uses the iff method directly from the model class without + # an instance having been created + TestConditionalModel.iff(count=5).filter(id=uuid).update(text=None, count=6) + + t = TestConditionalModel.filter(id=uuid).first() + self.assertIsNone(t.text) + self.assertEqual(t.count, 6) diff --git a/tests/integration/cqlengine/test_timestamp.py b/tests/integration/cqlengine/test_timestamp.py new file mode 100644 index 0000000..abf751e --- /dev/null +++ b/tests/integration/cqlengine/test_timestamp.py @@ -0,0 +1,207 @@ +# Copyright DataStax, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import timedelta, datetime +import mock +import sure +from uuid import uuid4 + +from cassandra.cqlengine import columns +from cassandra.cqlengine.management import sync_table +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.query import BatchQuery +from tests.integration.cqlengine.base import BaseCassEngTestCase + + +class TestTimestampModel(Model): + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + count = columns.Integer() + + +class BaseTimestampTest(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BaseTimestampTest, cls).setUpClass() + sync_table(TestTimestampModel) + + +class BatchTest(BaseTimestampTest): + + def test_batch_is_included(self): + with mock.patch.object(self.session, "execute") as m: + with BatchQuery(timestamp=timedelta(seconds=30)) as b: + TestTimestampModel.batch(b).create(count=1) + + "USING TIMESTAMP".should.be.within(m.call_args[0][0].query_string) + + +class CreateWithTimestampTest(BaseTimestampTest): + + def test_batch(self): + with mock.patch.object(self.session, "execute") as m: + with BatchQuery() as b: + TestTimestampModel.timestamp(timedelta(seconds=10)).batch(b).create(count=1) + + query = m.call_args[0][0].query_string + + query.should.match(r"INSERT.*USING TIMESTAMP") + query.should_not.match(r"TIMESTAMP.*INSERT") + + def test_timestamp_not_included_on_normal_create(self): + with mock.patch.object(self.session, "execute") as m: + TestTimestampModel.create(count=2) + + "USING TIMESTAMP".shouldnt.be.within(m.call_args[0][0].query_string) + + def test_timestamp_is_set_on_model_queryset(self): + delta = timedelta(seconds=30) + tmp = TestTimestampModel.timestamp(delta) + tmp._timestamp.should.equal(delta) + + def test_non_batch_syntax_integration(self): + tmp = TestTimestampModel.timestamp(timedelta(seconds=30)).create(count=1) + tmp.should.be.ok + + def test_non_batch_syntax_with_tll_integration(self): + tmp = TestTimestampModel.timestamp(timedelta(seconds=30)).ttl(30).create(count=1) + tmp.should.be.ok + + def test_non_batch_syntax_unit(self): + + with mock.patch.object(self.session, "execute") as m: + TestTimestampModel.timestamp(timedelta(seconds=30)).create(count=1) + + query = m.call_args[0][0].query_string + + "USING TIMESTAMP".should.be.within(query) + + def test_non_batch_syntax_with_ttl_unit(self): + + with mock.patch.object(self.session, "execute") as m: + TestTimestampModel.timestamp(timedelta(seconds=30)).ttl(30).create( + count=1) + + query = m.call_args[0][0].query_string + + query.should.match(r"USING TTL \d* AND TIMESTAMP") + + +class UpdateWithTimestampTest(BaseTimestampTest): + def setUp(self): + self.instance = TestTimestampModel.create(count=1) + super(UpdateWithTimestampTest, self).setUp() + + def test_instance_update_includes_timestamp_in_query(self): + # not a batch + + with mock.patch.object(self.session, "execute") as m: + self.instance.timestamp(timedelta(seconds=30)).update(count=2) + + "USING 
TIMESTAMP".should.be.within(m.call_args[0][0].query_string) + + def test_instance_update_in_batch(self): + with mock.patch.object(self.session, "execute") as m: + with BatchQuery() as b: + self.instance.batch(b).timestamp(timedelta(seconds=30)).update(count=2) + + query = m.call_args[0][0].query_string + "USING TIMESTAMP".should.be.within(query) + + +class DeleteWithTimestampTest(BaseTimestampTest): + + def test_non_batch(self): + """ + we don't expect the model to come back at the end because the deletion timestamp should be in the future + """ + uid = uuid4() + tmp = TestTimestampModel.create(id=uid, count=1) + + TestTimestampModel.get(id=uid).should.be.ok + + tmp.timestamp(timedelta(seconds=5)).delete() + + with self.assertRaises(TestTimestampModel.DoesNotExist): + TestTimestampModel.get(id=uid) + + tmp = TestTimestampModel.create(id=uid, count=1) + + with self.assertRaises(TestTimestampModel.DoesNotExist): + TestTimestampModel.get(id=uid) + + # calling .timestamp sets the TS on the model + tmp.timestamp(timedelta(seconds=5)) + tmp._timestamp.should.be.ok + + # calling save clears the set timestamp + tmp.save() + tmp._timestamp.shouldnt.be.ok + + tmp.timestamp(timedelta(seconds=5)) + tmp.update() + tmp._timestamp.shouldnt.be.ok + + def test_blind_delete(self): + """ + we don't expect the model to come back at the end because the deletion timestamp should be in the future + """ + uid = uuid4() + tmp = TestTimestampModel.create(id=uid, count=1) + + TestTimestampModel.get(id=uid).should.be.ok + + TestTimestampModel.objects(id=uid).timestamp(timedelta(seconds=5)).delete() + + with self.assertRaises(TestTimestampModel.DoesNotExist): + TestTimestampModel.get(id=uid) + + tmp = TestTimestampModel.create(id=uid, count=1) + + with self.assertRaises(TestTimestampModel.DoesNotExist): + TestTimestampModel.get(id=uid) + + def test_blind_delete_with_datetime(self): + """ + we don't expect the model to come back at the end because the deletion timestamp should be in the future + """ + uid = uuid4() + tmp = TestTimestampModel.create(id=uid, count=1) + + TestTimestampModel.get(id=uid).should.be.ok + + plus_five_seconds = datetime.now() + timedelta(seconds=5) + + TestTimestampModel.objects(id=uid).timestamp(plus_five_seconds).delete() + + with self.assertRaises(TestTimestampModel.DoesNotExist): + TestTimestampModel.get(id=uid) + + tmp = TestTimestampModel.create(id=uid, count=1) + + with self.assertRaises(TestTimestampModel.DoesNotExist): + TestTimestampModel.get(id=uid) + + def test_delete_in_the_past(self): + uid = uuid4() + tmp = TestTimestampModel.create(id=uid, count=1) + + TestTimestampModel.get(id=uid).should.be.ok + + # delete the in past, should not affect the object created above + TestTimestampModel.objects(id=uid).timestamp(timedelta(seconds=-60)).delete() + + TestTimestampModel.get(id=uid) + + diff --git a/tests/integration/cqlengine/test_ttl.py b/tests/integration/cqlengine/test_ttl.py new file mode 100644 index 0000000..a9aa32d --- /dev/null +++ b/tests/integration/cqlengine/test_ttl.py @@ -0,0 +1,237 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from packaging.version import Version + +from cassandra import InvalidRequest +from cassandra.cqlengine.management import sync_table, drop_table +from tests.integration.cqlengine.base import BaseCassEngTestCase +from cassandra.cqlengine.models import Model +from uuid import uuid4 +from cassandra.cqlengine import columns +import mock +from cassandra.cqlengine.connection import get_session +from tests.integration import CASSANDRA_VERSION, greaterthancass20 + + +class TestTTLModel(Model): + id = columns.UUID(primary_key=True, default=lambda: uuid4()) + count = columns.Integer() + text = columns.Text(required=False) + + +class BaseTTLTest(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BaseTTLTest, cls).setUpClass() + sync_table(TestTTLModel) + + @classmethod + def tearDownClass(cls): + super(BaseTTLTest, cls).tearDownClass() + drop_table(TestTTLModel) + + +class TestDefaultTTLModel(Model): + __options__ = {'default_time_to_live': 20} + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + count = columns.Integer() + text = columns.Text(required=False) + + +class BaseDefaultTTLTest(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + if CASSANDRA_VERSION >= Version('2.0'): + super(BaseDefaultTTLTest, cls).setUpClass() + sync_table(TestDefaultTTLModel) + sync_table(TestTTLModel) + + @classmethod + def tearDownClass(cls): + if CASSANDRA_VERSION >= Version('2.0'): + super(BaseDefaultTTLTest, cls).tearDownClass() + drop_table(TestDefaultTTLModel) + drop_table(TestTTLModel) + + +class TTLQueryTests(BaseTTLTest): + + def test_update_queryset_ttl_success_case(self): + """ tests that ttls on querysets work as expected """ + + def test_select_ttl_failure(self): + """ tests that ttls on select queries raise an exception """ + + +class TTLModelTests(BaseTTLTest): + + def test_ttl_included_on_create(self): + """ tests that ttls on models work as expected """ + session = get_session() + + with mock.patch.object(session, 'execute') as m: + TestTTLModel.ttl(60).create(text="hello blake") + + query = m.call_args[0][0].query_string + self.assertIn("USING TTL", query) + + def test_queryset_is_returned_on_class(self): + """ + ensures we get a queryset descriptor back + """ + qs = TestTTLModel.ttl(60) + self.assertTrue(isinstance(qs, TestTTLModel.__queryset__), type(qs)) + + +class TTLInstanceUpdateTest(BaseTTLTest): + def test_update_includes_ttl(self): + session = get_session() + + model = TestTTLModel.create(text="goodbye blake") + with mock.patch.object(session, 'execute') as m: + model.ttl(60).update(text="goodbye forever") + + query = m.call_args[0][0].query_string + self.assertIn("USING TTL", query) + + def test_update_syntax_valid(self): + # sanity test that ensures the TTL syntax is accepted by cassandra + model = TestTTLModel.create(text="goodbye blake") + model.ttl(60).update(text="goodbye forever") + + +class TTLInstanceTest(BaseTTLTest): + def test_instance_is_returned(self): + """ + ensures that we properly handle the instance.ttl(60).save() scenario + :return: + """ + o = TestTTLModel.create(text="whatever") + o.text = "new stuff" + o = o.ttl(60) + self.assertEqual(60, o._ttl) + + def test_ttl_is_include_with_query_on_update(self): + session = get_session() + + o = TestTTLModel.create(text="whatever") + o.text = "new stuff" + o = o.ttl(60) + + with 
mock.patch.object(session, 'execute') as m: + o.save() + + query = m.call_args[0][0].query_string + self.assertIn("USING TTL", query) + + +class TTLBlindUpdateTest(BaseTTLTest): + def test_ttl_included_with_blind_update(self): + session = get_session() + + o = TestTTLModel.create(text="whatever") + tid = o.id + + with mock.patch.object(session, 'execute') as m: + TestTTLModel.objects(id=tid).ttl(60).update(text="bacon") + + query = m.call_args[0][0].query_string + self.assertIn("USING TTL", query) + + +class TTLDefaultTest(BaseDefaultTTLTest): + def get_default_ttl(self, table_name): + session = get_session() + try: + default_ttl = session.execute("SELECT default_time_to_live FROM system_schema.tables " + "WHERE keyspace_name = 'cqlengine_test' AND table_name = '{0}'".format(table_name)) + except InvalidRequest: + default_ttl = session.execute("SELECT default_time_to_live FROM system.schema_columnfamilies " + "WHERE keyspace_name = 'cqlengine_test' AND columnfamily_name = '{0}'".format(table_name)) + return default_ttl[0]['default_time_to_live'] + + def test_default_ttl_not_set(self): + session = get_session() + + o = TestTTLModel.create(text="some text") + tid = o.id + + self.assertIsNone(o._ttl) + + default_ttl = self.get_default_ttl('test_ttlmodel') + self.assertEqual(default_ttl, 0) + + with mock.patch.object(session, 'execute') as m: + TestTTLModel.objects(id=tid).update(text="aligators") + + query = m.call_args[0][0].query_string + self.assertNotIn("USING TTL", query) + + def test_default_ttl_set(self): + session = get_session() + + o = TestDefaultTTLModel.create(text="some text on ttl") + tid = o.id + + # Should not be set, it's handled by Cassandra + self.assertIsNone(o._ttl) + + default_ttl = self.get_default_ttl('test_default_ttlmodel') + self.assertEqual(default_ttl, 20) + + with mock.patch.object(session, 'execute') as m: + TestTTLModel.objects(id=tid).update(text="aligators expired") + + # Should not be set either + query = m.call_args[0][0].query_string + self.assertNotIn("USING TTL", query) + + def test_default_ttl_modify(self): + session = get_session() + + default_ttl = self.get_default_ttl('test_default_ttlmodel') + self.assertEqual(default_ttl, 20) + + TestDefaultTTLModel.__options__ = {'default_time_to_live': 10} + sync_table(TestDefaultTTLModel) + + default_ttl = self.get_default_ttl('test_default_ttlmodel') + self.assertEqual(default_ttl, 10) + + # Restore default TTL + TestDefaultTTLModel.__options__ = {'default_time_to_live': 20} + sync_table(TestDefaultTTLModel) + + def test_override_default_ttl(self): + session = get_session() + o = TestDefaultTTLModel.create(text="some text on ttl") + tid = o.id + + o.ttl(3600) + self.assertEqual(o._ttl, 3600) + + with mock.patch.object(session, 'execute') as m: + TestDefaultTTLModel.objects(id=tid).ttl(None).update(text="aligators expired") + + query = m.call_args[0][0].query_string + self.assertNotIn("USING TTL", query) diff --git a/tests/integration/datatype_utils.py b/tests/integration/datatype_utils.py new file mode 100644 index 0000000..8a1c813 --- /dev/null +++ b/tests/integration/datatype_utils.py @@ -0,0 +1,176 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from decimal import Decimal +from datetime import datetime, date, time +from uuid import uuid1, uuid4 +import six + +from cassandra.util import OrderedMap, Date, Time, sortedset, Duration + +from tests.integration import get_server_versions + + +PRIMITIVE_DATATYPES = sortedset([ + 'ascii', + 'bigint', + 'blob', + 'boolean', + 'decimal', + 'double', + 'float', + 'inet', + 'int', + 'text', + 'timestamp', + 'timeuuid', + 'uuid', + 'varchar', + 'varint', +]) + +PRIMITIVE_DATATYPES_KEYS = PRIMITIVE_DATATYPES.copy() + +COLLECTION_TYPES = sortedset([ + 'list', + 'set', + 'map', +]) + + +def update_datatypes(): + _cass_version, _cql_version = get_server_versions() + + if _cass_version >= (2, 1, 0): + COLLECTION_TYPES.add('tuple') + + if _cass_version >= (2, 2, 0): + PRIMITIVE_DATATYPES.update(['date', 'time', 'smallint', 'tinyint']) + PRIMITIVE_DATATYPES_KEYS.update(['date', 'time', 'smallint', 'tinyint']) + if _cass_version >= (3, 10): + PRIMITIVE_DATATYPES.add('duration') + + global SAMPLE_DATA + SAMPLE_DATA = get_sample_data() + + +def get_sample_data(): + sample_data = {} + + for datatype in PRIMITIVE_DATATYPES: + if datatype == 'ascii': + sample_data[datatype] = 'ascii' + + elif datatype == 'bigint': + sample_data[datatype] = 2 ** 63 - 1 + + elif datatype == 'blob': + sample_data[datatype] = bytearray(b'hello world') + + elif datatype == 'boolean': + sample_data[datatype] = True + + elif datatype == 'decimal': + sample_data[datatype] = Decimal('12.3E+7') + + elif datatype == 'double': + sample_data[datatype] = 1.23E+8 + + elif datatype == 'float': + sample_data[datatype] = 3.4028234663852886e+38 + + elif datatype == 'inet': + sample_data[datatype] = ('123.123.123.123', '2001:db8:85a3:8d3:1319:8a2e:370:7348') + if six.PY3: + import ipaddress + sample_data[datatype] += (ipaddress.IPv4Address("123.123.123.123"), + ipaddress.IPv6Address('2001:db8:85a3:8d3:1319:8a2e:370:7348')) + + elif datatype == 'int': + sample_data[datatype] = 2147483647 + + elif datatype == 'text': + sample_data[datatype] = 'text' + + elif datatype == 'timestamp': + sample_data[datatype] = datetime(2013, 12, 31, 23, 59, 59, 999000) + + elif datatype == 'timeuuid': + sample_data[datatype] = uuid1() + + elif datatype == 'uuid': + sample_data[datatype] = uuid4() + + elif datatype == 'varchar': + sample_data[datatype] = 'varchar' + + elif datatype == 'varint': + sample_data[datatype] = int(str(2147483647) + '000') + + elif datatype == 'date': + sample_data[datatype] = Date(date(2015, 1, 15)) + + elif datatype == 'time': + sample_data[datatype] = Time(time(16, 47, 25, 7)) + + elif datatype == 'tinyint': + sample_data[datatype] = 123 + + elif datatype == 'smallint': + sample_data[datatype] = 32523 + + elif datatype == 'duration': + sample_data[datatype] = Duration(months=2, days=12, nanoseconds=21231) + + else: + raise Exception("Missing handling of {0}".format(datatype)) + + return sample_data + +SAMPLE_DATA = get_sample_data() + + +def get_sample(datatype): + """ + Helper method to access created sample data for primitive types + """ + if isinstance(SAMPLE_DATA[datatype], tuple): + return SAMPLE_DATA[datatype][0] + 
return SAMPLE_DATA[datatype] + + +def get_all_samples(datatype): + """ + Helper method to access created sample data for primitive types + """ + if isinstance(SAMPLE_DATA[datatype], tuple): + return SAMPLE_DATA[datatype] + return SAMPLE_DATA[datatype], + + +def get_collection_sample(collection_type, datatype): + """ + Helper method to access created sample data for collection types + """ + + if collection_type == 'list': + return [get_sample(datatype), get_sample(datatype)] + elif collection_type == 'set': + return sortedset([get_sample(datatype)]) + elif collection_type == 'map': + return OrderedMap([(get_sample(datatype), get_sample(datatype))]) + elif collection_type == 'tuple': + return (get_sample(datatype),) + else: + raise Exception('Missing handling of non-primitive type {0}.'.format(collection_type)) diff --git a/tests/integration/long/__init__.py b/tests/integration/long/__init__.py new file mode 100644 index 0000000..447f488 --- /dev/null +++ b/tests/integration/long/__init__.py @@ -0,0 +1,22 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +try: + from ccmlib import common +except ImportError as e: + raise unittest.SkipTest('ccm is a dependency for integration tests:', e) diff --git a/tests/integration/long/test_consistency.py b/tests/integration/long/test_consistency.py new file mode 100644 index 0000000..bb6828a --- /dev/null +++ b/tests/integration/long/test_consistency.py @@ -0,0 +1,367 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
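The consistency tests that follow issue reads and writes at explicit consistency levels while individual nodes are stopped and restarted. A minimal sketch of the underlying driver calls, assuming a local node and a pre-created ``cf`` table in a ``test_ks`` keyspace (names are illustrative only)::

    from cassandra import ConsistencyLevel
    from cassandra.cluster import Cluster
    from cassandra.query import SimpleStatement

    cluster = Cluster(['127.0.0.1'])          # assumes a locally running node
    session = cluster.connect('test_ks')      # assumes the keyspace and table exist

    # write at QUORUM
    write = SimpleStatement(
        "INSERT INTO cf (k, i) VALUES (0, 0)",
        consistency_level=ConsistencyLevel.QUORUM)
    session.execute(write)

    # read back at ONE
    read = SimpleStatement(
        "SELECT * FROM cf WHERE k = 0",
        consistency_level=ConsistencyLevel.ONE)
    row = session.execute(read).one()

    cluster.shutdown()
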
+ +import struct, time, traceback, sys, logging + +from random import randint +from cassandra import ConsistencyLevel, OperationTimedOut, ReadTimeout, WriteTimeout, Unavailable +from cassandra.cluster import Cluster +from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy +from cassandra.query import SimpleStatement +from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass + +from tests.integration.long.utils import (force_stop, create_schema, wait_for_down, wait_for_up, + start, CoordinatorStats) + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +ALL_CONSISTENCY_LEVELS = set([ + ConsistencyLevel.ANY, ConsistencyLevel.ONE, ConsistencyLevel.TWO, + ConsistencyLevel.QUORUM, ConsistencyLevel.THREE, + ConsistencyLevel.ALL, ConsistencyLevel.LOCAL_QUORUM, + ConsistencyLevel.EACH_QUORUM]) + +MULTI_DC_CONSISTENCY_LEVELS = set([ + ConsistencyLevel.LOCAL_QUORUM, ConsistencyLevel.EACH_QUORUM]) + +SINGLE_DC_CONSISTENCY_LEVELS = ALL_CONSISTENCY_LEVELS - MULTI_DC_CONSISTENCY_LEVELS + +log = logging.getLogger(__name__) + + +def setup_module(): + use_singledc() + + +class ConsistencyTests(unittest.TestCase): + + def setUp(self): + self.coordinator_stats = CoordinatorStats() + + def _cl_failure(self, consistency_level, e): + self.fail('Instead of success, saw %s for CL.%s:\n\n%s' % ( + e, ConsistencyLevel.value_to_name[consistency_level], + traceback.format_exc())) + + def _cl_expected_failure(self, cl): + self.fail('Test passed at ConsistencyLevel.%s:\n\n%s' % ( + ConsistencyLevel.value_to_name[cl], traceback.format_exc())) + + def _insert(self, session, keyspace, count, consistency_level=ConsistencyLevel.ONE): + session.execute('USE %s' % keyspace) + for i in range(count): + ss = SimpleStatement('INSERT INTO cf(k, i) VALUES (0, 0)', + consistency_level=consistency_level) + execute_until_pass(session, ss) + + def _query(self, session, keyspace, count, consistency_level=ConsistencyLevel.ONE): + routing_key = struct.pack('>i', 0) + for i in range(count): + ss = SimpleStatement('SELECT * FROM cf WHERE k = 0', + consistency_level=consistency_level, + routing_key=routing_key) + tries = 0 + while True: + if tries > 100: + raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(ss)) + try: + self.coordinator_stats.add_coordinator(session.execute_async(ss)) + break + except (OperationTimedOut, ReadTimeout): + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + tries += 1 + time.sleep(1) + + def _assert_writes_succeed(self, session, keyspace, consistency_levels): + for cl in consistency_levels: + self.coordinator_stats.reset_counts() + try: + self._insert(session, keyspace, 1, cl) + except Exception as e: + self._cl_failure(cl, e) + + def _assert_reads_succeed(self, session, keyspace, consistency_levels, expected_reader=3): + for cl in consistency_levels: + self.coordinator_stats.reset_counts() + try: + self._query(session, keyspace, 1, cl) + for i in range(3): + if i == expected_reader: + self.coordinator_stats.assert_query_count_equals(self, i, 1) + else: + self.coordinator_stats.assert_query_count_equals(self, i, 0) + except Exception as e: + self._cl_failure(cl, e) + + def _assert_writes_fail(self, session, keyspace, consistency_levels): + for cl in consistency_levels: + self.coordinator_stats.reset_counts() + try: + self._insert(session, keyspace, 1, cl) + self._cl_expected_failure(cl) + except 
(Unavailable, WriteTimeout): + pass + + def _assert_reads_fail(self, session, keyspace, consistency_levels): + for cl in consistency_levels: + self.coordinator_stats.reset_counts() + try: + self._query(session, keyspace, 1, cl) + self._cl_expected_failure(cl) + except (Unavailable, ReadTimeout): + pass + + def _test_tokenaware_one_node_down(self, keyspace, rf, accepted): + cluster = Cluster( + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), + protocol_version=PROTOCOL_VERSION) + session = cluster.connect(wait_for_all_pools=True) + wait_for_up(cluster, 1) + wait_for_up(cluster, 2) + + create_schema(cluster, session, keyspace, replication_factor=rf) + self._insert(session, keyspace, count=1) + self._query(session, keyspace, count=1) + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 1) + self.coordinator_stats.assert_query_count_equals(self, 3, 0) + + try: + force_stop(2) + wait_for_down(cluster, 2) + + self._assert_writes_succeed(session, keyspace, accepted) + self._assert_reads_succeed(session, keyspace, + accepted - set([ConsistencyLevel.ANY])) + self._assert_writes_fail(session, keyspace, + SINGLE_DC_CONSISTENCY_LEVELS - accepted) + self._assert_reads_fail(session, keyspace, + SINGLE_DC_CONSISTENCY_LEVELS - accepted) + finally: + start(2) + wait_for_up(cluster, 2) + + cluster.shutdown() + + def test_rfone_tokenaware_one_node_down(self): + self._test_tokenaware_one_node_down( + keyspace='test_rfone_tokenaware', + rf=1, + accepted=set([ConsistencyLevel.ANY])) + + def test_rftwo_tokenaware_one_node_down(self): + self._test_tokenaware_one_node_down( + keyspace='test_rftwo_tokenaware', + rf=2, + accepted=set([ConsistencyLevel.ANY, ConsistencyLevel.ONE])) + + def test_rfthree_tokenaware_one_node_down(self): + self._test_tokenaware_one_node_down( + keyspace='test_rfthree_tokenaware', + rf=3, + accepted=set([ConsistencyLevel.ANY, ConsistencyLevel.ONE, + ConsistencyLevel.TWO, ConsistencyLevel.QUORUM])) + + def test_rfthree_tokenaware_none_down(self): + keyspace = 'test_rfthree_tokenaware_none_down' + cluster = Cluster( + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), + protocol_version=PROTOCOL_VERSION) + session = cluster.connect(wait_for_all_pools=True) + wait_for_up(cluster, 1) + wait_for_up(cluster, 2) + + create_schema(cluster, session, keyspace, replication_factor=3) + self._insert(session, keyspace, count=1) + self._query(session, keyspace, count=1) + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 1) + self.coordinator_stats.assert_query_count_equals(self, 3, 0) + + self.coordinator_stats.reset_counts() + + self._assert_writes_succeed(session, keyspace, SINGLE_DC_CONSISTENCY_LEVELS) + self._assert_reads_succeed(session, keyspace, + SINGLE_DC_CONSISTENCY_LEVELS - set([ConsistencyLevel.ANY]), + expected_reader=2) + + cluster.shutdown() + + def _test_downgrading_cl(self, keyspace, rf, accepted): + cluster = Cluster( + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), + default_retry_policy=DowngradingConsistencyRetryPolicy(), + protocol_version=PROTOCOL_VERSION) + session = cluster.connect(wait_for_all_pools=True) + + create_schema(cluster, session, keyspace, replication_factor=rf) + self._insert(session, keyspace, 1) + self._query(session, keyspace, 1) + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 1) + 
self.coordinator_stats.assert_query_count_equals(self, 3, 0) + + try: + force_stop(2) + wait_for_down(cluster, 2) + + self._assert_writes_succeed(session, keyspace, accepted) + self._assert_reads_succeed(session, keyspace, + accepted - set([ConsistencyLevel.ANY])) + self._assert_writes_fail(session, keyspace, + SINGLE_DC_CONSISTENCY_LEVELS - accepted) + self._assert_reads_fail(session, keyspace, + SINGLE_DC_CONSISTENCY_LEVELS - accepted) + finally: + start(2) + wait_for_up(cluster, 2) + + cluster.shutdown() + + def test_rfone_downgradingcl(self): + self._test_downgrading_cl( + keyspace='test_rfone_downgradingcl', + rf=1, + accepted=set([ConsistencyLevel.ANY])) + + def test_rftwo_downgradingcl(self): + self._test_downgrading_cl( + keyspace='test_rftwo_downgradingcl', + rf=2, + accepted=SINGLE_DC_CONSISTENCY_LEVELS) + + def test_rfthree_roundrobin_downgradingcl(self): + keyspace = 'test_rfthree_roundrobin_downgradingcl' + cluster = Cluster( + load_balancing_policy=RoundRobinPolicy(), + default_retry_policy=DowngradingConsistencyRetryPolicy(), + protocol_version=PROTOCOL_VERSION) + self.rfthree_downgradingcl(cluster, keyspace, True) + + def test_rfthree_tokenaware_downgradingcl(self): + keyspace = 'test_rfthree_tokenaware_downgradingcl' + cluster = Cluster( + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), + default_retry_policy=DowngradingConsistencyRetryPolicy(), + protocol_version=PROTOCOL_VERSION) + self.rfthree_downgradingcl(cluster, keyspace, False) + + def rfthree_downgradingcl(self, cluster, keyspace, roundrobin): + session = cluster.connect(wait_for_all_pools=True) + + create_schema(cluster, session, keyspace, replication_factor=2) + self._insert(session, keyspace, count=12) + self._query(session, keyspace, count=12) + + if roundrobin: + self.coordinator_stats.assert_query_count_equals(self, 1, 4) + self.coordinator_stats.assert_query_count_equals(self, 2, 4) + self.coordinator_stats.assert_query_count_equals(self, 3, 4) + else: + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 12) + self.coordinator_stats.assert_query_count_equals(self, 3, 0) + + try: + self.coordinator_stats.reset_counts() + force_stop(2) + wait_for_down(cluster, 2) + + self._assert_writes_succeed(session, keyspace, SINGLE_DC_CONSISTENCY_LEVELS) + + # Test reads that expected to complete successfully + for cl in SINGLE_DC_CONSISTENCY_LEVELS - set([ConsistencyLevel.ANY]): + self.coordinator_stats.reset_counts() + self._query(session, keyspace, 12, consistency_level=cl) + if roundrobin: + self.coordinator_stats.assert_query_count_equals(self, 1, 6) + self.coordinator_stats.assert_query_count_equals(self, 2, 0) + self.coordinator_stats.assert_query_count_equals(self, 3, 6) + else: + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 0) + self.coordinator_stats.assert_query_count_equals(self, 3, 12) + finally: + start(2) + wait_for_up(cluster, 2) + + session.cluster.shutdown() + + # TODO: can't be done in this class since we reuse the ccm cluster + # instead we should create these elsewhere + # def test_rfthree_downgradingcl_twodcs(self): + # def test_rfthree_downgradingcl_twodcs_dcaware(self): + + +class ConnectivityTest(unittest.TestCase): + + def get_node_not_x(self, node_to_stop): + nodes = [1, 2, 3] + for num in nodes: + if num is not node_to_stop: + return num + + def test_pool_with_host_down(self): + """ + Test to ensure that cluster.connect() doesn't return 
prior to pools being initialized. + + This test will figure out which host our pool logic will connect to first. It then shuts that server down. + Previously the cluster.connect() would return prior to the pools being initialized, and the first queries would + return a no host exception + + @since 3.7.0 + @jira_ticket PYTHON-617 + @expected_result query should complete successfully + + @test_category connection + """ + + # find the first node, we will try create connections to, shut it down. + + # We will be shuting down a random house, so we need a complete contact list + all_contact_points = ["127.0.0.1", "127.0.0.2", "127.0.0.3"] + + # Connect up and find out which host will bet queries routed to to first + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster.connect(wait_for_all_pools=True) + hosts = cluster.metadata.all_hosts() + address = hosts[0].address + node_to_stop = int(address.split('.')[-1:][0]) + cluster.shutdown() + + # We now register a cluster that has it's Control Connection NOT on the node that we are shutting down. + # We do this so we don't miss the event + contact_point = '127.0.0.{0}'.format(self.get_node_not_x(node_to_stop)) + cluster = Cluster(contact_points=[contact_point], protocol_version=PROTOCOL_VERSION) + cluster.connect(wait_for_all_pools=True) + try: + force_stop(node_to_stop) + wait_for_down(cluster, node_to_stop) + # Attempt a query against that node. It should complete + cluster2 = Cluster(contact_points=all_contact_points, protocol_version=PROTOCOL_VERSION) + session2 = cluster2.connect() + session2.execute("SELECT * FROM system.local") + finally: + cluster2.shutdown() + start(node_to_stop) + wait_for_up(cluster, node_to_stop) + cluster.shutdown() + + + diff --git a/tests/integration/long/test_failure_types.py b/tests/integration/long/test_failure_types.py new file mode 100644 index 0000000..a67c05a --- /dev/null +++ b/tests/integration/long/test_failure_types.py @@ -0,0 +1,395 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys,logging, traceback, time, re + +from cassandra import (ConsistencyLevel, OperationTimedOut, ReadTimeout, WriteTimeout, ReadFailure, WriteFailure, + FunctionFailure, ProtocolVersion) +from cassandra.cluster import Cluster, NoHostAvailable, ExecutionProfile, EXEC_PROFILE_DEFAULT +from cassandra.policies import HostFilterPolicy, RoundRobinPolicy +from cassandra.concurrent import execute_concurrent_with_args +from cassandra.query import SimpleStatement +from tests.integration import use_singledc, PROTOCOL_VERSION, get_cluster, setup_keyspace, remove_cluster, get_node, \ + requiresmallclockgranularity +from mock import Mock + +try: + import unittest2 as unittest +except ImportError: + import unittest + +log = logging.getLogger(__name__) + + +def setup_module(): + """ + We need some custom setup for this module. All unit tests in this module + require protocol >=4. We won't bother going through the setup required unless that is the + protocol version we are using. 
+ """ + + # If we aren't at protocol v 4 or greater don't waste time setting anything up, all tests will be skipped + if PROTOCOL_VERSION >= 4: + use_singledc(start=False) + ccm_cluster = get_cluster() + ccm_cluster.stop() + config_options = {'tombstone_failure_threshold': 2000, 'tombstone_warn_threshold': 1000} + ccm_cluster.set_configuration_options(config_options) + ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True) + setup_keyspace() + + +def teardown_module(): + """ + The rest of the tests don't need custom tombstones + remove the cluster so as to not interfere with other tests. + """ + if PROTOCOL_VERSION >= 4: + remove_cluster() + + +class ClientExceptionTests(unittest.TestCase): + + def setUp(self): + """ + Test is skipped if run with native protocol version <4 + """ + if PROTOCOL_VERSION < 4: + raise unittest.SkipTest( + "Native protocol 4,0+ is required for custom payloads, currently using %r" + % (PROTOCOL_VERSION,)) + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.session = self.cluster.connect() + self.nodes_currently_failing = [] + self.node1, self.node2, self.node3 = get_cluster().nodes.values() + + def tearDown(self): + + self.cluster.shutdown() + failing_nodes = [] + + # Restart the nodes to fully functional again + self.setFailingNodes(failing_nodes, "testksfail") + + def execute_helper(self, session, query): + tries = 0 + while tries < 100: + try: + return session.execute(query) + except OperationTimedOut: + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + tries += 1 + + raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(query)) + + def execute_concurrent_args_helper(self, session, query, params): + tries = 0 + while tries < 100: + try: + return execute_concurrent_with_args(session, query, params, concurrency=50) + except (ReadTimeout, WriteTimeout, OperationTimedOut, ReadFailure, WriteFailure): + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + tries += 1 + + raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(query)) + + def setFailingNodes(self, failing_nodes, keyspace): + """ + This method will take in a set of failing nodes, and toggle all of the nodes in the provided list to fail + writes. 
+ @param failing_nodes A definitive list of nodes that should fail writes + @param keyspace The keyspace to enable failures on + + """ + + # Ensure all of the nodes on the list have failures enabled + for node in failing_nodes: + if node not in self.nodes_currently_failing: + node.stop(wait_other_notice=True, gently=False) + node.start(jvm_args=[" -Dcassandra.test.fail_writes_ks=" + keyspace], wait_for_binary_proto=True, + wait_other_notice=True) + self.nodes_currently_failing.append(node) + + # Ensure all nodes not on the list, but that are currently set to failing are enabled + for node in self.nodes_currently_failing: + if node not in failing_nodes: + node.stop(wait_other_notice=True, gently=True) + node.start(wait_for_binary_proto=True, wait_other_notice=True) + self.nodes_currently_failing.remove(node) + + def _perform_cql_statement(self, text, consistency_level, expected_exception, session=None): + """ + Simple helper method to preform cql statements and check for expected exception + @param text CQl statement to execute + @param consistency_level Consistency level at which it is to be executed + @param expected_exception Exception expected to be throw or none + """ + if session is None: + session = self.session + statement = SimpleStatement(text) + statement.consistency_level = consistency_level + + if expected_exception is None: + self.execute_helper(session, statement) + else: + with self.assertRaises(expected_exception) as cm: + self.execute_helper(session, statement) + if ProtocolVersion.uses_error_code_map(PROTOCOL_VERSION): + if isinstance(cm.exception, ReadFailure): + self.assertEqual(list(cm.exception.error_code_map.values())[0], 1) + if isinstance(cm.exception, WriteFailure): + self.assertEqual(list(cm.exception.error_code_map.values())[0], 0) + + def test_write_failures_from_coordinator(self): + """ + Test to validate that write failures from the coordinator are surfaced appropriately. + + test_write_failures_from_coordinator Enable write failures on the various nodes using a custom jvm flag, + cassandra.test.fail_writes_ks. This will cause writes to fail on that specific node. Depending on the replication + factor of the keyspace, and the consistency level, we will expect the coordinator to send WriteFailure, or not. + + + @since 2.6.0, 3.7.0 + @jira_ticket PYTHON-238, PYTHON-619 + @expected_result Appropriate write failures from the coordinator + + @test_category queries:basic + """ + + # Setup temporary keyspace. 
+ self._perform_cql_statement( + """ + CREATE KEYSPACE testksfail + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} + """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + + # create table + self._perform_cql_statement( + """ + CREATE TABLE testksfail.test ( + k int PRIMARY KEY, + v int ) + """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + + # Disable one node + failing_nodes = [self.node1] + self.setFailingNodes(failing_nodes, "testksfail") + + # With one node disabled we would expect a write failure with ConsistencyLevel of all + self._perform_cql_statement( + """ + INSERT INTO testksfail.test (k, v) VALUES (1, 0 ) + """, consistency_level=ConsistencyLevel.ALL, expected_exception=WriteFailure) + + # We have two nodes left so a write with consistency level of QUORUM should complete as expected + self._perform_cql_statement( + """ + INSERT INTO testksfail.test (k, v) VALUES (1, 0 ) + """, consistency_level=ConsistencyLevel.QUORUM, expected_exception=None) + + failing_nodes = [] + + # Restart the nodes to fully functional again + self.setFailingNodes(failing_nodes, "testksfail") + + # Drop temporary keyspace + self._perform_cql_statement( + """ + DROP KEYSPACE testksfail + """, consistency_level=ConsistencyLevel.ANY, expected_exception=None) + + def test_tombstone_overflow_read_failure(self): + """ + Test to validate that a ReadFailure is returned from the node when a specified threshold of tombstombs is + reached. + + test_tombstomb_overflow_read_failure First sets the tombstone failure threshold down to a level that allows it + to be more easily encountered. We then create some wide rows and ensure they are deleted appropriately. This + produces the correct amount of tombstombs. Upon making a simple query we expect to get a read failure back + from the coordinator. + + + @since 2.6.0, 3.7.0 + @jira_ticket PYTHON-238, PYTHON-619 + @expected_result Appropriate write failures from the coordinator + + @test_category queries:basic + """ + + # Setup table for "wide row" + self._perform_cql_statement( + """ + CREATE TABLE test3rf.test2 ( + k int, + v0 int, + v1 int, PRIMARY KEY (k,v0)) + """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + + statement = self.session.prepare("INSERT INTO test3rf.test2 (k, v0,v1) VALUES (1,?,1)") + parameters = [(x,) for x in range(3000)] + self.execute_concurrent_args_helper(self.session, statement, parameters) + + statement = self.session.prepare("DELETE v1 FROM test3rf.test2 WHERE k = 1 AND v0 =?") + parameters = [(x,) for x in range(2001)] + self.execute_concurrent_args_helper(self.session, statement, parameters) + + self._perform_cql_statement( + """ + SELECT * FROM test3rf.test2 WHERE k = 1 + """, consistency_level=ConsistencyLevel.ALL, expected_exception=ReadFailure) + + self._perform_cql_statement( + """ + DROP TABLE test3rf.test2; + """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + + def test_user_function_failure(self): + """ + Test to validate that exceptions in user defined function are correctly surfaced by the driver to us. + + test_user_function_failure First creates a table to use for testing. Then creates a function that will throw an + exception when invoked. It then invokes the function and expects a FunctionException. Finally it preforms + cleanup operations. 
+ + @since 2.6.0 + @jira_ticket PYTHON-238 + @expected_result Function failures when UDF throws exception + + @test_category queries:basic + """ + + # create UDF that throws an exception + self._perform_cql_statement( + """ + CREATE FUNCTION test3rf.test_failure(d double) + RETURNS NULL ON NULL INPUT + RETURNS double + LANGUAGE java AS 'throw new RuntimeException("failure");'; + """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + + # Create test table + self._perform_cql_statement( + """ + CREATE TABLE test3rf.d (k int PRIMARY KEY , d double); + """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + + # Insert some values + self._perform_cql_statement( + """ + INSERT INTO test3rf.d (k,d) VALUES (0, 5.12); + """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + + # Run the function expect a function failure exception + self._perform_cql_statement( + """ + SELECT test_failure(d) FROM test3rf.d WHERE k = 0; + """, consistency_level=ConsistencyLevel.ALL, expected_exception=FunctionFailure) + + self._perform_cql_statement( + """ + DROP FUNCTION test3rf.test_failure; + """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + + self._perform_cql_statement( + """ + DROP TABLE test3rf.d; + """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) + + +@requiresmallclockgranularity +class TimeoutTimerTest(unittest.TestCase): + def setUp(self): + """ + Setup sessions and pause node1 + """ + + # self.node1, self.node2, self.node3 = get_cluster().nodes.values() + + node1 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == "127.0.0.1" + ) + ) + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, execution_profiles={EXEC_PROFILE_DEFAULT: node1}) + self.session = self.cluster.connect(wait_for_all_pools=True) + + self.control_connection_host_number = 1 + self.node_to_stop = get_node(self.control_connection_host_number) + + ddl = ''' + CREATE TABLE test3rf.timeout ( + k int PRIMARY KEY, + v int )''' + self.session.execute(ddl) + self.node_to_stop.pause() + + def tearDown(self): + """ + Shutdown cluster and resume node1 + """ + self.node_to_stop.resume() + self.session.execute("DROP TABLE test3rf.timeout") + self.cluster.shutdown() + + def test_async_timeouts(self): + """ + Test to validate that timeouts are honored + + + Exercise the underlying timeouts, by attempting a query that will timeout. Ensure the default timeout is still + honored. Make sure that user timeouts are also honored. 
+ + @since 2.7.0 + @jira_ticket PYTHON-108 + @expected_result timeouts should be honored + + @test_category + + """ + + # Because node1 is stopped these statements will all timeout + ss = SimpleStatement('SELECT * FROM test3rf.test', consistency_level=ConsistencyLevel.ALL) + + # Test with default timeout (should be 10) + start_time = time.time() + future = self.session.execute_async(ss) + with self.assertRaises(OperationTimedOut): + future.result() + end_time = time.time() + total_time = end_time-start_time + expected_time = self.session.default_timeout + # check timeout and ensure it's within a reasonable range + self.assertAlmostEqual(expected_time, total_time, delta=.05) + + # Test with user defined timeout (Should be 1) + start_time = time.time() + future = self.session.execute_async(ss, timeout=1) + mock_callback = Mock(return_value=None) + mock_errorback = Mock(return_value=None) + future.add_callback(mock_callback) + future.add_errback(mock_errorback) + + with self.assertRaises(OperationTimedOut): + future.result() + end_time = time.time() + total_time = end_time-start_time + expected_time = 1 + # check timeout and ensure it's within a reasonable range + self.assertAlmostEqual(expected_time, total_time, delta=.05) + self.assertTrue(mock_errorback.called) + self.assertFalse(mock_callback.called) diff --git a/tests/integration/long/test_ipv6.py b/tests/integration/long/test_ipv6.py new file mode 100644 index 0000000..5f2bdbd --- /dev/null +++ b/tests/integration/long/test_ipv6.py @@ -0,0 +1,124 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os, socket, errno +from ccmlib import common + +from cassandra.cluster import Cluster, NoHostAvailable +from cassandra.io.asyncorereactor import AsyncoreConnection + +from tests import is_monkey_patched +from tests.integration import use_cluster, remove_cluster, PROTOCOL_VERSION + +if is_monkey_patched(): + LibevConnection = -1 + AsyncoreConnection = -1 +else: + try: + from cassandra.io.libevreactor import LibevConnection + except ImportError: + LibevConnection = None + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + + +# If more modules do IPV6 testing, this can be moved down to integration.__init__. 
+# For now, just keeping the clutter here +IPV6_CLUSTER_NAME = 'ipv6_test_cluster' + + +def setup_module(module): + if os.name != "nt": + validate_host_viable() + # We use a dedicated cluster (instead of common singledc, as in other tests) because + # it's most likely that the test host will only have one local ipv6 address (::1) + # singledc has three + use_cluster(IPV6_CLUSTER_NAME, [1], ipformat='::%d') + + +def teardown_module(): + remove_cluster() + + +def validate_ccm_viable(): + try: + common.normalize_interface(('::1', 0)) + except: + raise unittest.SkipTest('this version of ccm does not support ipv6') + + +def validate_host_viable(): + # this is something ccm does when starting, but preemptively check to avoid + # spinning up the cluster if it's not going to work + try: + common.assert_socket_available(('::1', 9042)) + except: + raise unittest.SkipTest('failed binding ipv6 loopback ::1 on 9042') + + +class IPV6ConnectionTest(object): + + connection_class = None + + def test_connect(self): + cluster = Cluster(connection_class=self.connection_class, contact_points=['::1'], connect_timeout=10, + protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + future = session.execute_async("SELECT * FROM system.local") + future.result() + self.assertEqual(future._current_host.address, '::1') + cluster.shutdown() + + def test_error(self): + cluster = Cluster(connection_class=self.connection_class, contact_points=['::1'], port=9043, + connect_timeout=10, protocol_version=PROTOCOL_VERSION) + self.assertRaisesRegexp(NoHostAvailable, '\(\'Unable to connect.*%s.*::1\', 9043.*Connection refused.*' + % errno.ECONNREFUSED, cluster.connect) + + def test_error_multiple(self): + if len(socket.getaddrinfo('localhost', 9043, socket.AF_UNSPEC, socket.SOCK_STREAM)) < 2: + raise unittest.SkipTest('localhost only resolves one address') + cluster = Cluster(connection_class=self.connection_class, contact_points=['localhost'], port=9043, + connect_timeout=10, protocol_version=PROTOCOL_VERSION) + self.assertRaisesRegexp(NoHostAvailable, '\(\'Unable to connect.*Tried connecting to \[\(.*\(.*\].*Last error', + cluster.connect) + + +class LibevConnectionTests(IPV6ConnectionTest, unittest.TestCase): + + connection_class = LibevConnection + + def setUp(self): + if os.name == "nt": + raise unittest.SkipTest("IPv6 is currently not supported under Windows") + + if LibevConnection == -1: + raise unittest.SkipTest("Can't test libev with monkey patching") + elif LibevConnection is None: + raise unittest.SkipTest("Libev does not appear to be installed properly") + + +class AsyncoreConnectionTests(IPV6ConnectionTest, unittest.TestCase): + + connection_class = AsyncoreConnection + + def setUp(self): + if os.name == "nt": + raise unittest.SkipTest("IPv6 is currently not supported under Windows") + + if AsyncoreConnection == -1: + raise unittest.SkipTest("Can't test asyncore with monkey patching") diff --git a/tests/integration/long/test_large_data.py b/tests/integration/long/test_large_data.py new file mode 100644 index 0000000..76cafa0 --- /dev/null +++ b/tests/integration/long/test_large_data.py @@ -0,0 +1,275 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + from Queue import Queue, Empty +except ImportError: + from queue import Queue, Empty # noqa + +from struct import pack +import logging, sys, traceback, time + +from cassandra import ConsistencyLevel, OperationTimedOut, WriteTimeout +from cassandra.cluster import Cluster +from cassandra.query import dict_factory +from cassandra.query import SimpleStatement +from tests.integration import use_singledc, PROTOCOL_VERSION +from tests.integration.long.utils import create_schema + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +log = logging.getLogger(__name__) + + +def setup_module(): + use_singledc() + + +# Converts an integer to an string of letters +def create_column_name(i): + letters = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] + + column_name = '' + while True: + column_name += letters[i % 10] + i = i // 10 + if not i: + break + + if column_name == 'if': + column_name = 'special_case' + return column_name + + +class LargeDataTests(unittest.TestCase): + + def setUp(self): + self.keyspace = 'large_data' + + def make_session_and_keyspace(self): + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + session.default_timeout = 20.0 # increase the default timeout + session.row_factory = dict_factory + + create_schema(cluster, session, self.keyspace) + return session + + def batch_futures(self, session, statement_generator): + concurrency = 10 + futures = Queue(maxsize=concurrency) + number_of_timeouts = 0 + for i, statement in enumerate(statement_generator): + if i > 0 and i % (concurrency - 1) == 0: + # clear the existing queue + while True: + try: + futures.get_nowait().result() + except (OperationTimedOut, WriteTimeout): + ex_type, ex, tb = sys.exc_info() + number_of_timeouts += 1 + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + time.sleep(1) + except Empty: + break + + future = session.execute_async(statement) + futures.put_nowait(future) + + while True: + try: + futures.get_nowait().result() + except (OperationTimedOut, WriteTimeout): + ex_type, ex, tb = sys.exc_info() + number_of_timeouts += 1 + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + time.sleep(1) + except Empty: + break + return number_of_timeouts + + def test_wide_rows(self): + table = 'wide_rows' + session = self.make_session_and_keyspace() + session.execute('CREATE TABLE %s (k INT, i INT, PRIMARY KEY(k, i))' % table) + + prepared = session.prepare('INSERT INTO %s (k, i) VALUES (0, ?)' % (table, )) + + # Write via async futures + self.batch_futures(session, (prepared.bind((i, )) for i in range(100000))) + + # Read + results = session.execute('SELECT i FROM %s WHERE k=0' % (table, )) + + # Verify + for i, row in enumerate(results): + self.assertAlmostEqual(row['i'], i, delta=3) + + session.cluster.shutdown() + + def test_wide_batch_rows(self): + """ + Test for inserting wide rows with batching + + test_wide_batch_rows tests inserting a wide row of data using batching. 
It will then attempt to query + that data and ensure that all of it has been inserted appropriately. + + @expected_result all items should be inserted, and verified. + + @test_category queries:batch + """ + + # Table Creation + table = 'wide_batch_rows' + session = self.make_session_and_keyspace() + session.execute('CREATE TABLE %s (k INT, i INT, PRIMARY KEY(k, i))' % table) + + # Run batch insert + statement = 'BEGIN BATCH ' + to_insert = 2000 + for i in range(to_insert): + statement += 'INSERT INTO %s (k, i) VALUES (%s, %s) ' % (table, 0, i) + statement += 'APPLY BATCH' + statement = SimpleStatement(statement, consistency_level=ConsistencyLevel.QUORUM) + + # Execute insert with larger timeout, since it's a wide row + try: + session.execute(statement,timeout=30.0) + + except OperationTimedOut: + #If we timeout on insertion that's bad but it could be just slow underlying c* + #Attempt to validate anyway, we will fail if we don't get the right data back. + ex_type, ex, tb = sys.exc_info() + log.warning("Batch wide row insertion timed out, this may require additional investigation") + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + + # Verify + results = session.execute('SELECT i FROM %s WHERE k=%s' % (table, 0)) + lastvalue = 0 + for j, row in enumerate(results): + lastValue=row['i'] + self.assertEqual(lastValue, j) + + #check the last value make sure it's what we expect + index_value = to_insert-1 + self.assertEqual(lastValue,index_value,"Verification failed only found {0} inserted we were expecting {1}".format(j,index_value)) + + session.cluster.shutdown() + + def test_wide_byte_rows(self): + """ + Test for inserting wide row of bytes + + test_wide_batch_rows tests inserting a wide row of data bytes. It will then attempt to query + that data and ensure that all of it has been inserted appropriately. + + @expected_result all items should be inserted, and verified. + + @test_category queries + """ + + # Table creation + table = 'wide_byte_rows' + session = self.make_session_and_keyspace() + session.execute('CREATE TABLE %s (k INT, i INT, v BLOB, PRIMARY KEY(k, i))' % table) + + # Prepare statement and run insertions + to_insert = 100000 + prepared = session.prepare('INSERT INTO %s (k, i, v) VALUES (0, ?, 0xCAFE)' % (table, )) + timeouts = self.batch_futures(session, (prepared.bind((i, )) for i in range(to_insert))) + + # Read + results = session.execute('SELECT i, v FROM %s WHERE k=0' % (table, )) + + # number of expected results + expected_results = to_insert-timeouts-1 + + # Verify + bb = pack('>H', 0xCAFE) + for i, row in enumerate(results): + self.assertEqual(row['v'], bb) + + self.assertGreaterEqual(i, expected_results, "Verification failed only found {0} inserted we were expecting {1}".format(i,expected_results)) + + session.cluster.shutdown() + + def test_large_text(self): + """ + Test for inserting a large text field + + test_large_text tests inserting a large text field into a row. + + @expected_result the large text value should be inserted. 
When the row is queried it should match the original + value that was inserted + + @test_category queries + """ + table = 'large_text' + session = self.make_session_and_keyspace() + session.execute('CREATE TABLE %s (k int PRIMARY KEY, txt text)' % table) + + # Create ultra-long text + text = 'a' * 1000000 + + # Write + session.execute(SimpleStatement("INSERT INTO %s (k, txt) VALUES (%s, '%s')" + % (table, 0, text), + consistency_level=ConsistencyLevel.QUORUM)) + + # Read + result = session.execute('SELECT * FROM %s WHERE k=%s' % (table, 0)) + + # Verify + found_result = False + for i, row in enumerate(result): + self.assertEqual(row['txt'], text) + found_result = True + self.assertTrue(found_result, "No results were found") + + session.cluster.shutdown() + + def test_wide_table(self): + table = 'wide_table' + table_width = 330 + session = self.make_session_and_keyspace() + table_declaration = 'CREATE TABLE %s (key INT PRIMARY KEY, ' + table_declaration += ' INT, '.join(create_column_name(i) for i in range(table_width)) + table_declaration += ' INT)' + session.execute(table_declaration % table) + + # Write + insert_statement = 'INSERT INTO %s (key, ' + insert_statement += ', '.join(create_column_name(i) for i in range(table_width)) + insert_statement += ') VALUES (%s, ' + insert_statement += ', '.join(str(i) for i in range(table_width)) + insert_statement += ')' + insert_statement = insert_statement % (table, 0) + + session.execute(SimpleStatement(insert_statement, consistency_level=ConsistencyLevel.QUORUM)) + + # Read + result = session.execute('SELECT * FROM %s WHERE key=%s' % (table, 0)) + + # Verify + for row in result: + for i in range(table_width): + self.assertEqual(row[create_column_name(i)], i) + + session.cluster.shutdown() diff --git a/tests/integration/long/test_loadbalancingpolicies.py b/tests/integration/long/test_loadbalancingpolicies.py new file mode 100644 index 0000000..2cc3d8e --- /dev/null +++ b/tests/integration/long/test_loadbalancingpolicies.py @@ -0,0 +1,737 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
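The batch_futures helper in test_large_data.py above throttles asynchronous writes by holding at most `concurrency` in-flight futures in a bounded queue. A minimal standalone sketch of that pattern (not part of the patch; the contact point, keyspace and table names are placeholders, and the error handling of the real helper is omitted)::

    # Sketch only -- caps the number of outstanding execute_async() calls,
    # draining the queue as it fills (Python 3 module names).
    from queue import Queue, Empty

    from cassandra.cluster import Cluster

    def throttled_insert(session, statements, concurrency=10):
        futures = Queue(maxsize=concurrency)
        for i, statement in enumerate(statements):
            if i > 0 and i % (concurrency - 1) == 0:
                while True:  # drain finished work before queueing more
                    try:
                        futures.get_nowait().result()
                    except Empty:
                        break
            futures.put_nowait(session.execute_async(statement))
        while True:  # wait for the tail end of the writes
            try:
                futures.get_nowait().result()
            except Empty:
                break

    cluster = Cluster(['127.0.0.1'])          # placeholder contact point
    session = cluster.connect('large_data')   # assumes the keyspace already exists
    prepared = session.prepare('INSERT INTO wide_rows (k, i) VALUES (0, ?)')
    throttled_insert(session, (prepared.bind((i,)) for i in range(10000)))
    cluster.shutdown()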
+ +import logging +import struct +import sys +import traceback + +from cassandra import ConsistencyLevel, Unavailable, OperationTimedOut, ReadTimeout, ReadFailure, \ + WriteTimeout, WriteFailure +from cassandra.cluster import Cluster, NoHostAvailable +from cassandra.concurrent import execute_concurrent_with_args +from cassandra.metadata import murmur3 +from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy, + TokenAwarePolicy, WhiteListRoundRobinPolicy, + HostFilterPolicy) +from cassandra.query import SimpleStatement + +from tests.integration import use_singledc, use_multidc, remove_cluster, PROTOCOL_VERSION +from tests.integration.long.utils import (wait_for_up, create_schema, + CoordinatorStats, force_stop, + wait_for_down, decommission, start, + bootstrap, stop, IP_FORMAT) + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +log = logging.getLogger(__name__) + + +class LoadBalancingPolicyTests(unittest.TestCase): + + def setUp(self): + remove_cluster() # clear ahead of test so it doesn't use one left in unknown state + self.coordinator_stats = CoordinatorStats() + self.prepared = None + self.probe_cluster = None + + def tearDown(self): + if self.probe_cluster: + self.probe_cluster.shutdown() + + @classmethod + def teardown_class(cls): + remove_cluster() + + def _connect_probe_cluster(self): + if not self.probe_cluster: + # distinct cluster so we can see the status of nodes ignored by the LBP being tested + self.probe_cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), + schema_metadata_enabled=False, token_metadata_enabled=False) + self.probe_session = self.probe_cluster.connect() + + def _wait_for_nodes_up(self, nodes, cluster=None): + log.debug('entered: _wait_for_nodes_up(nodes={ns}, ' + 'cluster={cs})'.format(ns=nodes, + cs=cluster)) + if not cluster: + log.debug('connecting to cluster') + self._connect_probe_cluster() + cluster = self.probe_cluster + for n in nodes: + wait_for_up(cluster, n) + + def _wait_for_nodes_down(self, nodes, cluster=None): + log.debug('entered: _wait_for_nodes_down(nodes={ns}, ' + 'cluster={cs})'.format(ns=nodes, + cs=cluster)) + if not cluster: + self._connect_probe_cluster() + cluster = self.probe_cluster + for n in nodes: + wait_for_down(cluster, n) + + def _cluster_session_with_lbp(self, lbp): + # create a cluster with no delay on events + cluster = Cluster(load_balancing_policy=lbp, protocol_version=PROTOCOL_VERSION, + topology_event_refresh_window=0, status_event_refresh_window=0) + session = cluster.connect() + return cluster, session + + def _insert(self, session, keyspace, count=12, + consistency_level=ConsistencyLevel.ONE): + log.debug('entered _insert(' + 'session={session}, keyspace={keyspace}, ' + 'count={count}, consistency_level={consistency_level}' + ')'.format(session=session, keyspace=keyspace, count=count, + consistency_level=consistency_level)) + session.execute('USE %s' % keyspace) + ss = SimpleStatement('INSERT INTO cf(k, i) VALUES (0, 0)', consistency_level=consistency_level) + + tries = 0 + while tries < 100: + try: + execute_concurrent_with_args(session, ss, [None] * count) + log.debug('Completed _insert on try #{}'.format(tries + 1)) + return + except (OperationTimedOut, WriteTimeout, WriteFailure): + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + tries += 1 + + raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(ss)) + + def _query(self, session, 
keyspace, count=12, + consistency_level=ConsistencyLevel.ONE, use_prepared=False): + log.debug('entered _query(' + 'session={session}, keyspace={keyspace}, ' + 'count={count}, consistency_level={consistency_level}, ' + 'use_prepared={use_prepared}' + ')'.format(session=session, keyspace=keyspace, count=count, + consistency_level=consistency_level, + use_prepared=use_prepared)) + if use_prepared: + query_string = 'SELECT * FROM %s.cf WHERE k = ?' % keyspace + if not self.prepared or self.prepared.query_string != query_string: + self.prepared = session.prepare(query_string) + self.prepared.consistency_level = consistency_level + for i in range(count): + tries = 0 + while True: + if tries > 100: + raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(self.prepared)) + try: + self.coordinator_stats.add_coordinator(session.execute_async(self.prepared.bind((0,)))) + break + except (OperationTimedOut, ReadTimeout, ReadFailure): + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + tries += 1 + else: + routing_key = struct.pack('>i', 0) + for i in range(count): + ss = SimpleStatement('SELECT * FROM %s.cf WHERE k = 0' % keyspace, + consistency_level=consistency_level, + routing_key=routing_key) + tries = 0 + while True: + if tries > 100: + raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(ss)) + try: + self.coordinator_stats.add_coordinator(session.execute_async(ss)) + break + except (OperationTimedOut, ReadTimeout, ReadFailure): + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + tries += 1 + + def test_token_aware_is_used_by_default(self): + """ + Test for default loadbalacing policy + + test_token_aware_is_used_by_default tests that the default loadbalancing policy is policies.TokenAwarePolicy. + It creates a simple Cluster and verifies that the default loadbalancing policy is TokenAwarePolicy if the + murmur3 C extension is found. Otherwise, the default loadbalancing policy is DCAwareRoundRobinPolicy. + + @since 2.6.0 + @jira_ticket PYTHON-160 + @expected_result TokenAwarePolicy should be the default loadbalancing policy. 
+ + @test_category load_balancing:token_aware + """ + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + + if murmur3 is not None: + self.assertTrue(isinstance(cluster.load_balancing_policy, TokenAwarePolicy)) + else: + self.assertTrue(isinstance(cluster.load_balancing_policy, DCAwareRoundRobinPolicy)) + + cluster.shutdown() + + def test_roundrobin(self): + use_singledc() + keyspace = 'test_roundrobin' + cluster, session = self._cluster_session_with_lbp(RoundRobinPolicy()) + self._wait_for_nodes_up(range(1, 4), cluster) + create_schema(cluster, session, keyspace, replication_factor=3) + self._insert(session, keyspace) + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 4) + self.coordinator_stats.assert_query_count_equals(self, 2, 4) + self.coordinator_stats.assert_query_count_equals(self, 3, 4) + + force_stop(3) + self._wait_for_nodes_down([3], cluster) + + self.coordinator_stats.reset_counts() + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 6) + self.coordinator_stats.assert_query_count_equals(self, 2, 6) + self.coordinator_stats.assert_query_count_equals(self, 3, 0) + + decommission(1) + start(3) + self._wait_for_nodes_down([1], cluster) + self._wait_for_nodes_up([3], cluster) + + self.coordinator_stats.reset_counts() + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 6) + self.coordinator_stats.assert_query_count_equals(self, 3, 6) + cluster.shutdown() + + def test_roundrobin_two_dcs(self): + use_multidc([2, 2]) + keyspace = 'test_roundrobin_two_dcs' + cluster, session = self._cluster_session_with_lbp(RoundRobinPolicy()) + self._wait_for_nodes_up(range(1, 5), cluster) + + create_schema(cluster, session, keyspace, replication_strategy=[2, 2]) + self._insert(session, keyspace) + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 3) + self.coordinator_stats.assert_query_count_equals(self, 2, 3) + self.coordinator_stats.assert_query_count_equals(self, 3, 3) + self.coordinator_stats.assert_query_count_equals(self, 4, 3) + + force_stop(1) + bootstrap(5, 'dc3') + + # reset control connection + self._insert(session, keyspace, count=1000) + + self._wait_for_nodes_up([5], cluster) + + self.coordinator_stats.reset_counts() + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 3) + self.coordinator_stats.assert_query_count_equals(self, 3, 3) + self.coordinator_stats.assert_query_count_equals(self, 4, 3) + self.coordinator_stats.assert_query_count_equals(self, 5, 3) + + cluster.shutdown() + + def test_roundrobin_two_dcs_2(self): + use_multidc([2, 2]) + keyspace = 'test_roundrobin_two_dcs_2' + cluster, session = self._cluster_session_with_lbp(RoundRobinPolicy()) + self._wait_for_nodes_up(range(1, 5), cluster) + + create_schema(cluster, session, keyspace, replication_strategy=[2, 2]) + self._insert(session, keyspace) + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 3) + self.coordinator_stats.assert_query_count_equals(self, 2, 3) + self.coordinator_stats.assert_query_count_equals(self, 3, 3) + self.coordinator_stats.assert_query_count_equals(self, 4, 3) + + force_stop(1) + bootstrap(5, 'dc1') + + # reset control connection + self._insert(session, keyspace, count=1000) + + self._wait_for_nodes_up([5], cluster) + + 
self.coordinator_stats.reset_counts() + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 3) + self.coordinator_stats.assert_query_count_equals(self, 3, 3) + self.coordinator_stats.assert_query_count_equals(self, 4, 3) + self.coordinator_stats.assert_query_count_equals(self, 5, 3) + + cluster.shutdown() + + def test_dc_aware_roundrobin_two_dcs(self): + use_multidc([3, 2]) + keyspace = 'test_dc_aware_roundrobin_two_dcs' + cluster, session = self._cluster_session_with_lbp(DCAwareRoundRobinPolicy('dc1')) + self._wait_for_nodes_up(range(1, 6)) + + create_schema(cluster, session, keyspace, replication_strategy=[2, 2]) + self._insert(session, keyspace) + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 4) + self.coordinator_stats.assert_query_count_equals(self, 2, 4) + self.coordinator_stats.assert_query_count_equals(self, 3, 4) + self.coordinator_stats.assert_query_count_equals(self, 4, 0) + self.coordinator_stats.assert_query_count_equals(self, 5, 0) + + cluster.shutdown() + + def test_dc_aware_roundrobin_two_dcs_2(self): + use_multidc([3, 2]) + keyspace = 'test_dc_aware_roundrobin_two_dcs_2' + cluster, session = self._cluster_session_with_lbp(DCAwareRoundRobinPolicy('dc2')) + self._wait_for_nodes_up(range(1, 6)) + + create_schema(cluster, session, keyspace, replication_strategy=[2, 2]) + self._insert(session, keyspace) + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 0) + self.coordinator_stats.assert_query_count_equals(self, 3, 0) + self.coordinator_stats.assert_query_count_equals(self, 4, 6) + self.coordinator_stats.assert_query_count_equals(self, 5, 6) + + cluster.shutdown() + + def test_dc_aware_roundrobin_one_remote_host(self): + use_multidc([2, 2]) + keyspace = 'test_dc_aware_roundrobin_one_remote_host' + cluster, session = self._cluster_session_with_lbp(DCAwareRoundRobinPolicy('dc2', used_hosts_per_remote_dc=1)) + self._wait_for_nodes_up(range(1, 5)) + + create_schema(cluster, session, keyspace, replication_strategy=[2, 2]) + self._insert(session, keyspace) + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 0) + self.coordinator_stats.assert_query_count_equals(self, 3, 6) + self.coordinator_stats.assert_query_count_equals(self, 4, 6) + + self.coordinator_stats.reset_counts() + bootstrap(5, 'dc1') + self._wait_for_nodes_up([5]) + + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 0) + self.coordinator_stats.assert_query_count_equals(self, 3, 6) + self.coordinator_stats.assert_query_count_equals(self, 4, 6) + self.coordinator_stats.assert_query_count_equals(self, 5, 0) + + self.coordinator_stats.reset_counts() + decommission(3) + decommission(4) + self._wait_for_nodes_down([3, 4]) + + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 3, 0) + self.coordinator_stats.assert_query_count_equals(self, 4, 0) + responses = set() + for node in [1, 2, 5]: + responses.add(self.coordinator_stats.get_query_count(node)) + self.assertEqual(set([0, 0, 12]), responses) + + self.coordinator_stats.reset_counts() + decommission(5) + self._wait_for_nodes_down([5]) + + self._query(session, keyspace) + + 
self.coordinator_stats.assert_query_count_equals(self, 3, 0) + self.coordinator_stats.assert_query_count_equals(self, 4, 0) + self.coordinator_stats.assert_query_count_equals(self, 5, 0) + responses = set() + for node in [1, 2]: + responses.add(self.coordinator_stats.get_query_count(node)) + self.assertEqual(set([0, 12]), responses) + + self.coordinator_stats.reset_counts() + decommission(1) + self._wait_for_nodes_down([1]) + + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 12) + self.coordinator_stats.assert_query_count_equals(self, 3, 0) + self.coordinator_stats.assert_query_count_equals(self, 4, 0) + self.coordinator_stats.assert_query_count_equals(self, 5, 0) + + self.coordinator_stats.reset_counts() + force_stop(2) + + try: + self._query(session, keyspace) + self.fail() + except NoHostAvailable: + pass + + cluster.shutdown() + + def test_token_aware(self): + keyspace = 'test_token_aware' + self.token_aware(keyspace) + + def test_token_aware_prepared(self): + keyspace = 'test_token_aware_prepared' + self.token_aware(keyspace, True) + + def token_aware(self, keyspace, use_prepared=False): + use_singledc() + cluster, session = self._cluster_session_with_lbp(TokenAwarePolicy(RoundRobinPolicy())) + self._wait_for_nodes_up(range(1, 4), cluster) + + create_schema(cluster, session, keyspace, replication_factor=1) + self._insert(session, keyspace) + self._query(session, keyspace, use_prepared=use_prepared) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 12) + self.coordinator_stats.assert_query_count_equals(self, 3, 0) + + self.coordinator_stats.reset_counts() + self._query(session, keyspace, use_prepared=use_prepared) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 12) + self.coordinator_stats.assert_query_count_equals(self, 3, 0) + + self.coordinator_stats.reset_counts() + force_stop(2) + self._wait_for_nodes_down([2], cluster) + + try: + self._query(session, keyspace, use_prepared=use_prepared) + self.fail() + except Unavailable as e: + self.assertEqual(e.consistency, 1) + self.assertEqual(e.required_replicas, 1) + self.assertEqual(e.alive_replicas, 0) + + self.coordinator_stats.reset_counts() + start(2) + self._wait_for_nodes_up([2], cluster) + + self._query(session, keyspace, use_prepared=use_prepared) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 12) + self.coordinator_stats.assert_query_count_equals(self, 3, 0) + + self.coordinator_stats.reset_counts() + stop(2) + self._wait_for_nodes_down([2], cluster) + + try: + self._query(session, keyspace, use_prepared=use_prepared) + self.fail() + except Unavailable: + pass + + self.coordinator_stats.reset_counts() + start(2) + self._wait_for_nodes_up([2], cluster) + decommission(2) + self._wait_for_nodes_down([2], cluster) + + self._query(session, keyspace, use_prepared=use_prepared) + + results = set([ + self.coordinator_stats.get_query_count(1), + self.coordinator_stats.get_query_count(3) + ]) + self.assertEqual(results, set([0, 12])) + self.coordinator_stats.assert_query_count_equals(self, 2, 0) + + cluster.shutdown() + + def test_token_aware_composite_key(self): + use_singledc() + keyspace = 'test_token_aware_composite_key' + table = 'composite' + cluster, session = 
self._cluster_session_with_lbp(TokenAwarePolicy(RoundRobinPolicy())) + self._wait_for_nodes_up(range(1, 4), cluster) + + create_schema(cluster, session, keyspace, replication_factor=2) + session.execute('CREATE TABLE %s (' + 'k1 int, ' + 'k2 int, ' + 'i int, ' + 'PRIMARY KEY ((k1, k2)))' % table) + + prepared = session.prepare('INSERT INTO %s ' + '(k1, k2, i) ' + 'VALUES ' + '(?, ?, ?)' % table) + bound = prepared.bind((1, 2, 3)) + result = session.execute(bound) + self.assertIn(result.response_future.attempted_hosts[0], + cluster.metadata.get_replicas(keyspace, bound.routing_key)) + + # There could be race condition with querying a node + # which doesn't yet have the data so we query one of + # the replicas + results = session.execute(SimpleStatement('SELECT * FROM %s WHERE k1 = 1 AND k2 = 2' % table, + routing_key=bound.routing_key)) + self.assertIn(results.response_future.attempted_hosts[0], + cluster.metadata.get_replicas(keyspace, bound.routing_key)) + + self.assertTrue(results[0].i) + + cluster.shutdown() + + def test_token_aware_with_rf_2(self, use_prepared=False): + use_singledc() + keyspace = 'test_token_aware_with_rf_2' + cluster, session = self._cluster_session_with_lbp(TokenAwarePolicy(RoundRobinPolicy())) + self._wait_for_nodes_up(range(1, 4), cluster) + + create_schema(cluster, session, keyspace, replication_factor=2) + self._insert(session, keyspace) + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 12) + self.coordinator_stats.assert_query_count_equals(self, 3, 0) + + self.coordinator_stats.reset_counts() + stop(2) + self._wait_for_nodes_down([2], cluster) + + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 0) + self.coordinator_stats.assert_query_count_equals(self, 3, 12) + + cluster.shutdown() + + def test_token_aware_with_local_table(self): + use_singledc() + cluster, session = self._cluster_session_with_lbp(TokenAwarePolicy(RoundRobinPolicy())) + self._wait_for_nodes_up(range(1, 4), cluster) + + p = session.prepare("SELECT * FROM system.local WHERE key=?") + # this would blow up prior to 61b4fad + r = session.execute(p, ('local',)) + self.assertEqual(r[0].key, 'local') + + cluster.shutdown() + + def test_token_aware_with_shuffle_rf2(self): + """ + Test to validate the hosts are shuffled when the `shuffle_replicas` is truthy + @since 3.8 + @jira_ticket PYTHON-676 + @expected_result the request are spread across the replicas, + when one of them is down, the requests target the available one + + @test_category policy + """ + keyspace = 'test_token_aware_with_rf_2' + cluster, session = self._set_up_shuffle_test(keyspace, replication_factor=2) + + self._check_query_order_changes(session=session, keyspace=keyspace) + + # check TokenAwarePolicy still return the remaining replicas when one goes down + self.coordinator_stats.reset_counts() + stop(2) + self._wait_for_nodes_down([2], cluster) + + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 0) + self.coordinator_stats.assert_query_count_equals(self, 3, 12) + + cluster.shutdown() + + def test_token_aware_with_shuffle_rf3(self): + """ + Test to validate the hosts are shuffled when the `shuffle_replicas` is truthy + @since 3.8 + @jira_ticket PYTHON-676 + @expected_result the request are spread across the 
replicas, + when one of them is down, the requests target the other available ones + + @test_category policy + """ + keyspace = 'test_token_aware_with_rf_3' + cluster, session = self._set_up_shuffle_test(keyspace, replication_factor=3) + + self._check_query_order_changes(session=session, keyspace=keyspace) + + # check TokenAwarePolicy still return the remaining replicas when one goes down + self.coordinator_stats.reset_counts() + stop(1) + self._wait_for_nodes_down([1], cluster) + + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + query_count_two = self.coordinator_stats.get_query_count(2) + query_count_three = self.coordinator_stats.get_query_count(3) + self.assertEqual(query_count_two + query_count_three, 12) + + self.coordinator_stats.reset_counts() + stop(2) + self._wait_for_nodes_down([2], cluster) + + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 0) + self.coordinator_stats.assert_query_count_equals(self, 3, 12) + + cluster.shutdown() + + def _set_up_shuffle_test(self, keyspace, replication_factor): + use_singledc() + cluster, session = self._cluster_session_with_lbp( + TokenAwarePolicy(RoundRobinPolicy(), shuffle_replicas=True) + ) + self._wait_for_nodes_up(range(1, 4), cluster) + + create_schema(cluster, session, keyspace, replication_factor=replication_factor) + return cluster, session + + def _check_query_order_changes(self, session, keyspace): + LIMIT_TRIES, tried, query_counts = 20, 0, set() + + while len(query_counts) <= 1: + tried += 1 + if tried >= LIMIT_TRIES: + raise Exception("After {0} tries shuffle returned the same output".format(LIMIT_TRIES)) + + self._insert(session, keyspace) + self._query(session, keyspace) + + loop_qcs = (self.coordinator_stats.get_query_count(1), + self.coordinator_stats.get_query_count(2), + self.coordinator_stats.get_query_count(3)) + + query_counts.add(loop_qcs) + self.assertEqual(sum(loop_qcs), 12) + + # end the loop if we get more than one query ordering + self.coordinator_stats.reset_counts() + + def test_white_list(self): + use_singledc() + keyspace = 'test_white_list' + + cluster = Cluster(('127.0.0.2',), load_balancing_policy=WhiteListRoundRobinPolicy((IP_FORMAT % 2,)), + protocol_version=PROTOCOL_VERSION, topology_event_refresh_window=0, + status_event_refresh_window=0) + session = cluster.connect() + self._wait_for_nodes_up([1, 2, 3]) + + create_schema(cluster, session, keyspace) + self._insert(session, keyspace) + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 12) + self.coordinator_stats.assert_query_count_equals(self, 3, 0) + + # white list policy should not allow reconnecting to ignored hosts + force_stop(3) + self._wait_for_nodes_down([3]) + self.assertFalse(cluster.metadata.get_host(IP_FORMAT % 3).is_currently_reconnecting()) + + self.coordinator_stats.reset_counts() + force_stop(2) + self._wait_for_nodes_down([2]) + + try: + self._query(session, keyspace) + self.fail() + except NoHostAvailable: + pass + finally: + cluster.shutdown() + + def test_black_list_with_host_filter_policy(self): + """ + Test to validate removing certain hosts from the query plan with + HostFilterPolicy + @since 3.8 + @jira_ticket PYTHON-961 + @expected_result the excluded hosts are ignored + + @test_category policy + """ + use_singledc() + keyspace = 'test_black_list_with_hfp' + 
ignored_address = (IP_FORMAT % 2) + hfp = HostFilterPolicy( + child_policy=RoundRobinPolicy(), + predicate=lambda host: host.address != ignored_address + ) + cluster = Cluster( + (IP_FORMAT % 1,), + load_balancing_policy=hfp, + protocol_version=PROTOCOL_VERSION, + topology_event_refresh_window=0, + status_event_refresh_window=0 + ) + self.addCleanup(cluster.shutdown) + session = cluster.connect() + self._wait_for_nodes_up([1, 2, 3]) + + self.assertNotIn(ignored_address, [h.address for h in hfp.make_query_plan()]) + + create_schema(cluster, session, keyspace) + self._insert(session, keyspace) + self._query(session, keyspace) + + # RoundRobin doesn't provide a gurantee on the order of the hosts + # so we will have that for 127.0.0.1 and 127.0.0.3 the count for one + # will be 4 and for the other 8 + first_node_count = self.coordinator_stats.get_query_count(1) + third_node_count = self.coordinator_stats.get_query_count(3) + self.assertEqual(first_node_count + third_node_count, 12) + self.assertTrue(first_node_count == 8 or first_node_count == 4) + + self.coordinator_stats.assert_query_count_equals(self, 2, 0) + + # policy should not allow reconnecting to ignored host + force_stop(2) + self._wait_for_nodes_down([2]) + self.assertFalse(cluster.metadata.get_host(ignored_address).is_currently_reconnecting()) diff --git a/tests/integration/long/test_schema.py b/tests/integration/long/test_schema.py new file mode 100644 index 0000000..5163066 --- /dev/null +++ b/tests/integration/long/test_schema.py @@ -0,0 +1,161 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
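The blacklist test above builds its query plan with a HostFilterPolicy wrapping a RoundRobinPolicy. Reduced to its essentials (a sketch, not part of the patch; the contact point and ignored address are placeholders), the setup looks like this::

    from cassandra.cluster import Cluster
    from cassandra.policies import HostFilterPolicy, RoundRobinPolicy

    ignored_address = '127.0.0.2'
    policy = HostFilterPolicy(
        child_policy=RoundRobinPolicy(),
        # only hosts passing the predicate are eligible coordinators
        predicate=lambda host: host.address != ignored_address,
    )
    cluster = Cluster(['127.0.0.1'], load_balancing_policy=policy)
    session = cluster.connect()
    # the ignored host never shows up in the generated query plan
    assert ignored_address not in [h.address for h in policy.make_query_plan()]
    cluster.shutdown()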
+ +import logging + +from cassandra import ConsistencyLevel, AlreadyExists +from cassandra.cluster import Cluster +from cassandra.query import SimpleStatement + +from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +log = logging.getLogger(__name__) + + +def setup_module(): + use_singledc() + + +class SchemaTests(unittest.TestCase): + + @classmethod + def setup_class(cls): + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.session = cls.cluster.connect(wait_for_all_pools=True) + + @classmethod + def teardown_class(cls): + cls.cluster.shutdown() + + def test_recreates(self): + """ + Basic test for repeated schema creation and use, using many different keyspaces + """ + + session = self.session + + for i in range(2): + for keyspace_number in range(5): + keyspace = "ks_{0}".format(keyspace_number) + + if keyspace in self.cluster.metadata.keyspaces.keys(): + drop = "DROP KEYSPACE {0}".format(keyspace) + log.debug(drop) + execute_until_pass(session, drop) + + create = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 3}}".format(keyspace) + log.debug(create) + execute_until_pass(session, create) + + create = "CREATE TABLE {0}.cf (k int PRIMARY KEY, i int)".format(keyspace) + log.debug(create) + execute_until_pass(session, create) + + use = "USE {0}".format(keyspace) + log.debug(use) + execute_until_pass(session, use) + + insert = "INSERT INTO cf (k, i) VALUES (0, 0)" + log.debug(insert) + ss = SimpleStatement(insert, consistency_level=ConsistencyLevel.QUORUM) + execute_until_pass(session, ss) + + def test_for_schema_disagreements_different_keyspaces(self): + """ + Tests for any schema disagreements using many different keyspaces + """ + + session = self.session + + for i in range(30): + execute_until_pass(session, "CREATE KEYSPACE test_{0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}".format(i)) + execute_until_pass(session, "CREATE TABLE test_{0}.cf (key int PRIMARY KEY, value int)".format(i)) + + for j in range(100): + execute_until_pass(session, "INSERT INTO test_{0}.cf (key, value) VALUES ({1}, {1})".format(i, j)) + + execute_until_pass(session, "DROP KEYSPACE test_{0}".format(i)) + + def test_for_schema_disagreements_same_keyspace(self): + """ + Tests for any schema disagreements using the same keyspace multiple times + """ + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect(wait_for_all_pools=True) + + for i in range(30): + try: + execute_until_pass(session, "CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}") + except AlreadyExists: + execute_until_pass(session, "DROP KEYSPACE test") + execute_until_pass(session, "CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}") + + execute_until_pass(session, "CREATE TABLE test.cf (key int PRIMARY KEY, value int)") + + for j in range(100): + execute_until_pass(session, "INSERT INTO test.cf (key, value) VALUES ({0}, {0})".format(j)) + + execute_until_pass(session, "DROP KEYSPACE test") + cluster.shutdown() + + def test_for_schema_disagreement_attribute(self): + """ + Tests to ensure that schema disagreement is properly surfaced on the response future. + + Creates and destroys keyspaces/tables with various schema agreement timeouts set. 
+        First part runs cql create/drop cmds with schema agreement set in such a way that it will be impossible for agreement to occur during the timeout.
+        It then validates that the correct value is set on the result.
+        Second part ensures that when schema agreement occurs, the result set reflects that appropriately.
+
+        @since 3.1.0
+        @jira_ticket PYTHON-458
+        @expected_result is_schema_agreed is set appropriately on the response future
+
+        @test_category schema
+        """
+        # This should yield a schema disagreement
+        cluster = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=0.001)
+        session = cluster.connect(wait_for_all_pools=True)
+
+        rs = session.execute("CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}")
+        self.check_and_wait_for_agreement(session, rs, False)
+        rs = session.execute(SimpleStatement("CREATE TABLE test_schema_disagreement.cf (key int PRIMARY KEY, value int)",
+                                             consistency_level=ConsistencyLevel.ALL))
+        self.check_and_wait_for_agreement(session, rs, False)
+        rs = session.execute("DROP KEYSPACE test_schema_disagreement")
+        self.check_and_wait_for_agreement(session, rs, False)
+        cluster.shutdown()
+
+        # These should have schema agreement
+        cluster = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=100)
+        session = cluster.connect()
+        rs = session.execute("CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}")
+        self.check_and_wait_for_agreement(session, rs, True)
+        rs = session.execute(SimpleStatement("CREATE TABLE test_schema_disagreement.cf (key int PRIMARY KEY, value int)",
+                                             consistency_level=ConsistencyLevel.ALL))
+        self.check_and_wait_for_agreement(session, rs, True)
+        rs = session.execute("DROP KEYSPACE test_schema_disagreement")
+        self.check_and_wait_for_agreement(session, rs, True)
+        cluster.shutdown()
+
+    def check_and_wait_for_agreement(self, session, rs, expected):
+        self.assertEqual(rs.response_future.is_schema_agreed, expected)
+        if not rs.response_future.is_schema_agreed:
+            session.cluster.control_connection.wait_for_schema_agreement(wait_time=1000)
diff --git a/tests/integration/long/test_ssl.py b/tests/integration/long/test_ssl.py
new file mode 100644
index 0000000..7f0a870
--- /dev/null
+++ b/tests/integration/long/test_ssl.py
@@ -0,0 +1,413 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
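test_for_schema_disagreement_attribute above relies on ResponseFuture.is_schema_agreed and the control connection's wait_for_schema_agreement(). Outside the test harness the same check is only a few lines; this sketch is not part of the patch, and the contact point and keyspace name are illustrative::

    from cassandra.cluster import Cluster

    cluster = Cluster(['127.0.0.1'], max_schema_agreement_wait=10)
    session = cluster.connect()
    rs = session.execute(
        "CREATE KEYSPACE IF NOT EXISTS agreement_demo WITH replication = "
        "{'class': 'SimpleStrategy', 'replication_factor': 1}")
    if not rs.response_future.is_schema_agreed:
        # block until all live hosts report the same schema version
        cluster.control_connection.wait_for_schema_agreement()
    session.execute("DROP KEYSPACE agreement_demo")
    cluster.shutdown()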
+ +try: + import unittest2 as unittest +except ImportError: + import unittest + +import os, sys, traceback, logging, ssl, time, math, uuid +from cassandra.cluster import Cluster, NoHostAvailable +from cassandra import ConsistencyLevel +from cassandra.query import SimpleStatement +from tests.integration import PROTOCOL_VERSION, get_cluster, remove_cluster, use_single_node, EVENT_LOOP_MANAGER + +log = logging.getLogger(__name__) + +DEFAULT_PASSWORD = "pythondriver" + +# Server keystore trust store locations +SERVER_KEYSTORE_PATH = "tests/integration/long/ssl/.keystore" +SERVER_TRUSTSTORE_PATH = "tests/integration/long/ssl/.truststore" + +# Client specific keys/certs +CLIENT_CA_CERTS = 'tests/integration/long/ssl/cassandra.pem' +DRIVER_KEYFILE = "tests/integration/long/ssl/driver.key" +DRIVER_KEYFILE_ENCRYPTED = "tests/integration/long/ssl/driver_encrypted.key" +DRIVER_CERTFILE = "tests/integration/long/ssl/driver.pem" +DRIVER_CERTFILE_BAD = "tests/integration/long/ssl/python_driver_bad.pem" + +if "twisted" in EVENT_LOOP_MANAGER: + import OpenSSL + ssl_version = OpenSSL.SSL.TLSv1_METHOD + verify_certs = {'cert_reqs': OpenSSL.SSL.VERIFY_PEER, + 'check_hostname': True} + +else: + ssl_version = ssl.PROTOCOL_TLSv1 + verify_certs = {'cert_reqs': ssl.CERT_REQUIRED, + 'check_hostname': True} + + +def setup_cluster_ssl(client_auth=False): + """ + We need some custom setup for this module. This will start the ccm cluster with basic + ssl connectivity, and client authentication if needed. + """ + + use_single_node(start=False) + ccm_cluster = get_cluster() + ccm_cluster.stop() + + # Fetch the absolute path to the keystore for ccm. + abs_path_server_keystore_path = os.path.abspath(SERVER_KEYSTORE_PATH) + + # Configure ccm to use ssl. + + config_options = {'client_encryption_options': {'enabled': True, + 'keystore': abs_path_server_keystore_path, + 'keystore_password': DEFAULT_PASSWORD}} + + if(client_auth): + abs_path_server_truststore_path = os.path.abspath(SERVER_TRUSTSTORE_PATH) + client_encyrption_options = config_options['client_encryption_options'] + client_encyrption_options['require_client_auth'] = True + client_encyrption_options['truststore'] = abs_path_server_truststore_path + client_encyrption_options['truststore_password'] = DEFAULT_PASSWORD + + ccm_cluster.set_configuration_options(config_options) + ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True) + + +def validate_ssl_options(**kwargs): + ssl_options = kwargs.get('ssl_options', None) + ssl_context = kwargs.get('ssl_context', None) + + # find absolute path to client CA_CERTS + tries = 0 + while True: + if tries > 5: + raise RuntimeError("Failed to connect to SSL cluster after 5 attempts") + try: + cluster = Cluster(protocol_version=PROTOCOL_VERSION, + ssl_options=ssl_options, ssl_context=ssl_context) + session = cluster.connect(wait_for_all_pools=True) + break + except Exception: + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + tries += 1 + + # attempt a few simple commands. 
+ insert_keyspace = """CREATE KEYSPACE ssltest + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} + """ + statement = SimpleStatement(insert_keyspace) + statement.consistency_level = 3 + session.execute(statement) + + drop_keyspace = "DROP KEYSPACE ssltest" + statement = SimpleStatement(drop_keyspace) + statement.consistency_level = ConsistencyLevel.ANY + session.execute(statement) + + cluster.shutdown() + + +class SSLConnectionTests(unittest.TestCase): + + @classmethod + def setUpClass(cls): + setup_cluster_ssl() + + @classmethod + def tearDownClass(cls): + ccm_cluster = get_cluster() + ccm_cluster.stop() + remove_cluster() + + def test_can_connect_with_ssl_ca(self): + """ + Test to validate that we are able to connect to a cluster using ssl. + + test_can_connect_with_ssl_ca performs a simple sanity check to ensure that we can connect to a cluster with ssl + authentication via simple server-side shared certificate authority. The client is able to validate the identity + of the server, however by using this method the server can't trust the client unless additional authentication + has been provided. + + @since 2.6.0 + @jira_ticket PYTHON-332 + @expected_result The client can connect via SSL and preform some basic operations + + @test_category connection:ssl + """ + + # find absolute path to client CA_CERTS + abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) + ssl_options = {'ca_certs': abs_path_ca_cert_path,'ssl_version': ssl_version} + validate_ssl_options(ssl_options=ssl_options) + + def test_can_connect_with_ssl_long_running(self): + """ + Test to validate that long running ssl connections continue to function past thier timeout window + + @since 3.6.0 + @jira_ticket PYTHON-600 + @expected_result The client can connect via SSL and preform some basic operations over a period of longer then a minute + + @test_category connection:ssl + """ + + # find absolute path to client CA_CERTS + abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) + ssl_options = {'ca_certs': abs_path_ca_cert_path, + 'ssl_version': ssl_version} + tries = 0 + while True: + if tries > 5: + raise RuntimeError("Failed to connect to SSL cluster after 5 attempts") + try: + cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options=ssl_options) + session = cluster.connect(wait_for_all_pools=True) + break + except Exception: + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + tries += 1 + + # attempt a few simple commands. + + for i in range(8): + rs = session.execute("SELECT * FROM system.local") + time.sleep(10) + + cluster.shutdown() + + def test_can_connect_with_ssl_ca_host_match(self): + """ + Test to validate that we are able to connect to a cluster using ssl, and host matching + + test_can_connect_with_ssl_ca_host_match performs a simple sanity check to ensure that we can connect to a cluster with ssl + authentication via simple server-side shared certificate authority. 
It also validates that the host ip matches what is expected + + @since 3.3 + @jira_ticket PYTHON-296 + @expected_result The client can connect via SSL and preform some basic operations, with check_hostname specified + + @test_category connection:ssl + """ + + # find absolute path to client CA_CERTS + abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) + ssl_options = {'ca_certs': abs_path_ca_cert_path, + 'ssl_version': ssl_version} + ssl_options.update(verify_certs) + + validate_ssl_options(ssl_options=ssl_options) + + +class SSLConnectionAuthTests(unittest.TestCase): + + @classmethod + def setUpClass(cls): + setup_cluster_ssl(client_auth=True) + + @classmethod + def tearDownClass(cls): + ccm_cluster = get_cluster() + ccm_cluster.stop() + remove_cluster() + + def test_can_connect_with_ssl_client_auth(self): + """ + Test to validate that we can connect to a C* cluster that has client_auth enabled. + + This test will setup and use a c* cluster that has client authentication enabled. It will then attempt + to connect using valid client keys, and certs (that are in the server's truststore), and attempt to preform some + basic operations + @since 2.7.0 + + @expected_result The client can connect via SSL and preform some basic operations + + @test_category connection:ssl + """ + + # Need to get absolute paths for certs/key + abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) + abs_driver_keyfile = os.path.abspath(DRIVER_KEYFILE) + abs_driver_certfile = os.path.abspath(DRIVER_CERTFILE) + ssl_options = {'ca_certs': abs_path_ca_cert_path, + 'ssl_version': ssl_version, + 'keyfile': abs_driver_keyfile, + 'certfile': abs_driver_certfile} + validate_ssl_options(ssl_options=ssl_options) + + def test_can_connect_with_ssl_client_auth_host_name(self): + """ + Test to validate that we can connect to a C* cluster that has client_auth enabled, and hostmatching + + This test will setup and use a c* cluster that has client authentication enabled. It will then attempt + to connect using valid client keys, and certs (that are in the server's truststore), and attempt to preform some + basic operations, with check_hostname specified + @jira_ticket PYTHON-296 + @since 3.3 + + @expected_result The client can connect via SSL and preform some basic operations + + @test_category connection:ssl + """ + + # Need to get absolute paths for certs/key + abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) + abs_driver_keyfile = os.path.abspath(DRIVER_KEYFILE) + abs_driver_certfile = os.path.abspath(DRIVER_CERTFILE) + + ssl_options = {'ca_certs': abs_path_ca_cert_path, + 'ssl_version': ssl_version, + 'keyfile': abs_driver_keyfile, + 'certfile': abs_driver_certfile} + ssl_options.update(verify_certs) + + validate_ssl_options(ssl_options=ssl_options) + + def test_cannot_connect_without_client_auth(self): + """ + Test to validate that we cannot connect without client auth. + + This test will omit the keys/certs needed to preform client authentication. It will then attempt to connect + to a server that has client authentication enabled. 
+ + @since 2.7.0 + @expected_result The client will throw an exception on connect + + @test_category connection:ssl + """ + + abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) + cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options={'ca_certs': abs_path_ca_cert_path, + 'ssl_version': ssl_version}) + # attempt to connect and expect an exception + + with self.assertRaises(NoHostAvailable) as context: + cluster.connect() + cluster.shutdown() + + def test_cannot_connect_with_bad_client_auth(self): + """ + Test to validate that we cannot connect with invalid client auth. + + This test will use bad keys/certs to preform client authentication. It will then attempt to connect + to a server that has client authentication enabled. + + + @since 2.7.0 + @expected_result The client will throw an exception on connect + + @test_category connection:ssl + """ + + # Setup absolute paths to key/cert files + abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) + abs_driver_keyfile = os.path.abspath(DRIVER_KEYFILE) + abs_driver_certfile = os.path.abspath(DRIVER_CERTFILE_BAD) + + cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options={'ca_certs': abs_path_ca_cert_path, + 'ssl_version': ssl_version, + 'keyfile': abs_driver_keyfile, + 'certfile': abs_driver_certfile}) + with self.assertRaises(NoHostAvailable) as context: + cluster.connect() + cluster.shutdown() + + +class SSLSocketErrorTests(unittest.TestCase): + + @classmethod + def setUpClass(cls): + setup_cluster_ssl() + + @classmethod + def tearDownClass(cls): + ccm_cluster = get_cluster() + ccm_cluster.stop() + remove_cluster() + + def test_ssl_want_write_errors_are_retried(self): + """ + Test that when a socket receives a WANT_WRITE error, the message chunk sending is retried. + + @since 3.17.0 + @jira_ticket PYTHON-891 + @expected_result The query is executed successfully + + @test_category connection:ssl + """ + abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) + ssl_options = {'ca_certs': abs_path_ca_cert_path, + 'ssl_version': ssl_version} + cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options=ssl_options) + session = cluster.connect(wait_for_all_pools=True) + try: + session.execute('drop keyspace ssl_error_test') + except: + pass + session.execute( + "CREATE KEYSPACE ssl_error_test WITH replication = {'class':'SimpleStrategy','replication_factor':1};") + session.execute("CREATE TABLE ssl_error_test.big_text (id uuid PRIMARY KEY, data text);") + + params = { + '0': uuid.uuid4(), + '1': "0" * int(math.pow(10, 7)) + } + + session.execute('INSERT INTO ssl_error_test.big_text ("id", "data") VALUES (%(0)s, %(1)s)', params) + + +class SSLConnectionWithSSLContextTests(unittest.TestCase): + + @classmethod + def setUpClass(cls): + setup_cluster_ssl() + + @classmethod + def tearDownClass(cls): + ccm_cluster = get_cluster() + ccm_cluster.stop() + remove_cluster() + + def test_can_connect_with_sslcontext_certificate(self): + """ + Test to validate that we are able to connect to a cluster using a SSLContext. 
+ + @since 3.17.0 + @jira_ticket PYTHON-995 + @expected_result The client can connect via SSL and preform some basic operations + + @test_category connection:ssl + """ + abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) + ssl_context = ssl.SSLContext(ssl_version) + ssl_context.load_verify_locations(abs_path_ca_cert_path) + validate_ssl_options(ssl_context=ssl_context) + + def test_can_connect_with_ssl_client_auth_password_private_key(self): + """ + Identical test to SSLConnectionAuthTests.test_can_connect_with_ssl_client_auth, + the only difference is that the DRIVER_KEYFILE is encrypted with a password. + + @since 3.17.0 + @jira_ticket PYTHON-995 + @expected_result The client can connect via SSL and preform some basic operations + + @test_category connection:ssl + """ + abs_driver_keyfile = os.path.abspath(DRIVER_KEYFILE_ENCRYPTED) + abs_driver_certfile = os.path.abspath(DRIVER_CERTFILE) + ssl_context = ssl.SSLContext(ssl_version) + ssl_context.load_cert_chain(certfile=abs_driver_certfile, + keyfile=abs_driver_keyfile, + password='cassandra') + validate_ssl_options(ssl_context=ssl_context) diff --git a/tests/integration/long/utils.py b/tests/integration/long/utils.py new file mode 100644 index 0000000..07652b6 --- /dev/null +++ b/tests/integration/long/utils.py @@ -0,0 +1,183 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import logging +import time + +from collections import defaultdict +from ccmlib.node import Node, ToolError + +from nose.tools import assert_in +from cassandra.query import named_tuple_factory +from cassandra.cluster import ConsistencyLevel + +from tests.integration import get_node, get_cluster, wait_for_node_socket + +IP_FORMAT = '127.0.0.%s' + +log = logging.getLogger(__name__) + + +class CoordinatorStats(): + + def __init__(self): + self.coordinator_counts = defaultdict(int) + + def add_coordinator(self, future): + log.debug('adding coordinator from {}'.format(future)) + future.result() + coordinator = future._current_host.address + self.coordinator_counts[coordinator] += 1 + + if future._errors: + log.error('future._errors: %s', future._errors) + + def reset_counts(self): + self.coordinator_counts = defaultdict(int) + + def get_query_count(self, node): + ip = '127.0.0.%d' % node + return self.coordinator_counts[ip] + + def assert_query_count_equals(self, testcase, node, expected): + ip = '127.0.0.%d' % node + if self.get_query_count(node) != expected: + testcase.fail('Expected %d queries to %s, but got %d. 
Query counts: %s' % ( + expected, ip, self.coordinator_counts[ip], dict(self.coordinator_counts))) + + +def create_schema(cluster, session, keyspace, simple_strategy=True, + replication_factor=1, replication_strategy=None): + row_factory = session.row_factory + session.row_factory = named_tuple_factory + session.default_consistency_level = ConsistencyLevel.QUORUM + + if keyspace in cluster.metadata.keyspaces.keys(): + session.execute('DROP KEYSPACE %s' % keyspace, timeout=20) + + if simple_strategy: + ddl = "CREATE KEYSPACE %s WITH replication" \ + " = {'class': 'SimpleStrategy', 'replication_factor': '%s'}" + session.execute(ddl % (keyspace, replication_factor), timeout=10) + else: + if not replication_strategy: + raise Exception('replication_strategy is not set') + + ddl = "CREATE KEYSPACE %s" \ + " WITH replication = { 'class' : 'NetworkTopologyStrategy', %s }" + session.execute(ddl % (keyspace, str(replication_strategy)[1:-1]), timeout=10) + + ddl = 'CREATE TABLE %s.cf (k int PRIMARY KEY, i int)' + session.execute(ddl % keyspace, timeout=10) + session.execute('USE %s' % keyspace) + + session.row_factory = row_factory + session.default_consistency_level = ConsistencyLevel.LOCAL_ONE + + +def start(node): + get_node(node).start() + + +def stop(node): + get_node(node).stop() + + +def force_stop(node): + log.debug("Forcing stop of node %s", node) + get_node(node).stop(wait=False, gently=False) + log.debug("Node %s was stopped", node) + + +def decommission(node): + try: + get_node(node).decommission() + except ToolError as e: + expected_errs = (('Not enough live nodes to maintain replication ' + 'factor in keyspace system_distributed'), + 'Perform a forceful decommission to ignore.') + for err in expected_errs: + assert_in(err, e.stdout) + # in this case, we're running against a C* version with CASSANDRA-12510 + # applied and need to decommission with `--force` + get_node(node).decommission(force=True) + get_node(node).stop() + + +def bootstrap(node, data_center=None, token=None): + log.debug('called bootstrap(' + 'node={node}, data_center={data_center}, ' + 'token={token})') + node_instance = Node('node%s' % node, + get_cluster(), + auto_bootstrap=False, + thrift_interface=(IP_FORMAT % node, 9160), + storage_interface=(IP_FORMAT % node, 7000), + binary_interface=(IP_FORMAT % node, 9042), + jmx_port=str(7000 + 100 * node), + remote_debug_port=0, + initial_token=token if token else node * 10) + get_cluster().add(node_instance, is_seed=False, data_center=data_center) + + try: + start(node) + except Exception as e0: + log.debug('failed 1st bootstrap attempt with: \n{}'.format(e0)) + # Try only twice + try: + start(node) + except Exception as e1: + log.debug('failed 2nd bootstrap attempt with: \n{}'.format(e1)) + log.error('Added node failed to start twice.') + raise e1 + + +def ring(node): + get_node(node).nodetool('ring') + + +def wait_for_up(cluster, node): + tries = 0 + addr = IP_FORMAT % node + while tries < 100: + host = cluster.metadata.get_host(addr) + if host and host.is_up: + wait_for_node_socket(get_node(node), 60) + log.debug("Done waiting for node %s to be up", node) + return + else: + log.debug("Host {} is still marked down, waiting".format(addr)) + tries += 1 + time.sleep(1) + + # todo: don't mix string interpolation methods in the same package + raise RuntimeError("Host {0} is not up after {1} attempts".format(addr, tries)) + + +def wait_for_down(cluster, node): + log.debug("Waiting for node %s to be down", node) + tries = 0 + addr = IP_FORMAT % node + while tries < 100: + host = 
cluster.metadata.get_host(IP_FORMAT % node) + if not host or not host.is_up: + log.debug("Done waiting for node %s to be down", node) + return + else: + log.debug("Host is still marked up, waiting") + tries += 1 + time.sleep(1) + + raise RuntimeError("Host {0} is not down after {1} attempts".format(addr, tries)) diff --git a/tests/integration/simulacron/__init__.py b/tests/integration/simulacron/__init__.py new file mode 100644 index 0000000..665d6b3 --- /dev/null +++ b/tests/integration/simulacron/__init__.py @@ -0,0 +1,58 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from tests.integration.simulacron.utils import stop_simulacron, clear_queries +from tests.integration import PROTOCOL_VERSION, SIMULACRON_JAR, CASSANDRA_VERSION +from tests.integration.simulacron.utils import start_and_prime_singledc + +from cassandra.cluster import Cluster + +from packaging.version import Version + +def teardown_package(): + stop_simulacron() + + +class SimulacronBase(unittest.TestCase): + def tearDown(self): + clear_queries() + stop_simulacron() + + +class SimulacronCluster(SimulacronBase): + + cluster, connect = None, True + + @classmethod + def setUpClass(cls): + if SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"): + return + + start_and_prime_singledc() + if cls.connect: + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) + cls.session = cls.cluster.connect(wait_for_all_pools=True) + + @classmethod + def tearDownClass(cls): + if SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"): + return + + if cls.cluster: + cls.cluster.shutdown() + stop_simulacron() diff --git a/tests/integration/simulacron/test_cluster.py b/tests/integration/simulacron/test_cluster.py new file mode 100644 index 0000000..ec20c10 --- /dev/null +++ b/tests/integration/simulacron/test_cluster.py @@ -0,0 +1,80 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
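For orientation, a minimal sketch of the priming pattern the cluster tests below rely on: the write_timeout fields primed through the simulacron helpers (added in tests/integration/simulacron/utils.py later in this patch) come back as the corresponding WriteTimeout attributes. The values are illustrative and assume a local simulacron instance is available:

    # Sketch only: prime a write_timeout and observe the resulting exception.
    from cassandra import ConsistencyLevel, WriteTimeout, WriteType
    from cassandra.cluster import Cluster
    from tests.integration.simulacron.utils import prime_query, start_and_prime_singledc

    start_and_prime_singledc()                      # one datacenter, three nodes
    query = "SELECT * from simulacron_keyspace.simple"
    prime_query(query, rows=None, column_types=None, then={
        "result": "write_timeout",
        "delay_in_ms": 0,
        "consistency_level": "LOCAL_QUORUM",        # -> WriteTimeout.consistency
        "received": 1,                              # -> WriteTimeout.received_responses
        "block_for": 4,                             # -> WriteTimeout.required_responses
        "write_type": "UNLOGGED_BATCH",             # -> WriteTimeout.write_type
        "ignore_on_prepare": True,
    })

    cluster = Cluster(compression=False)            # simulacron does not support compression
    session = cluster.connect()
    try:
        session.execute(query)
    except WriteTimeout as wt:
        assert wt.consistency == ConsistencyLevel.name_to_value["LOCAL_QUORUM"]
        assert wt.write_type == WriteType.name_to_value["UNLOGGED_BATCH"]
    finally:
        cluster.shutdown()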
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from tests.integration.simulacron import SimulacronCluster +from tests.integration import (requiressimulacron, PROTOCOL_VERSION) +from tests.integration.simulacron.utils import prime_query + +from cassandra import (WriteTimeout, WriteType, + ConsistencyLevel, UnresolvableContactPoints) +from cassandra.cluster import Cluster + + +@requiressimulacron +class ClusterTests(SimulacronCluster): + def test_writetimeout(self): + write_type = "UNLOGGED_BATCH" + consistency = "LOCAL_QUORUM" + received_responses = 1 + required_responses = 4 + + query_to_prime_simple = "SELECT * from simulacron_keyspace.simple" + then = { + "result": "write_timeout", + "delay_in_ms": 0, + "consistency_level": consistency, + "received": received_responses, + "block_for": required_responses, + "write_type": write_type, + "ignore_on_prepare": True + } + prime_query(query_to_prime_simple, then=then, rows=None, column_types=None) + + with self.assertRaises(WriteTimeout) as assert_raised_context: + self.session.execute(query_to_prime_simple) + wt = assert_raised_context.exception + self.assertEqual(wt.write_type, WriteType.name_to_value[write_type]) + self.assertEqual(wt.consistency, ConsistencyLevel.name_to_value[consistency]) + self.assertEqual(wt.received_responses, received_responses) + self.assertEqual(wt.required_responses, required_responses) + self.assertIn(write_type, str(wt)) + self.assertIn(consistency, str(wt)) + self.assertIn(str(received_responses), str(wt)) + self.assertIn(str(required_responses), str(wt)) + + +@requiressimulacron +class ClusterDNSResolutionTests(SimulacronCluster): + + connect = False + + def tearDown(self): + if self.cluster: + self.cluster.shutdown() + + def test_connection_with_one_unresolvable_contact_point(self): + # shouldn't raise anything due to name resolution failures + self.cluster = Cluster(['127.0.0.1', 'dns.invalid'], + protocol_version=PROTOCOL_VERSION, + compression=False) + + def test_connection_with_only_unresolvable_contact_points(self): + with self.assertRaises(UnresolvableContactPoints): + self.cluster = Cluster(['dns.invalid'], + protocol_version=PROTOCOL_VERSION, + compression=False) diff --git a/tests/integration/simulacron/test_connection.py b/tests/integration/simulacron/test_connection.py new file mode 100644 index 0000000..c19b616 --- /dev/null +++ b/tests/integration/simulacron/test_connection.py @@ -0,0 +1,468 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
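The connection tests in this file track host state through a HostStateListener and tune the idle heartbeat. A minimal sketch of that wiring, mirroring the TrackDownListener defined below; the heartbeat values are illustrative:

    from cassandra.cluster import Cluster
    from cassandra.policies import HostStateListener

    class DownRecorder(HostStateListener):
        """Records every host the driver marks down."""
        def __init__(self):
            self.hosts_marked_down = []

        def on_down(self, host):
            self.hosts_marked_down.append(host)

        # Required by the interface, unused in this sketch.
        def on_up(self, host):
            pass

        def on_add(self, host):
            pass

        def on_remove(self, host):
            pass

    # Heartbeats are sent every second; a host that does not answer within
    # 5 seconds has its connection defuncted and is reported to registered
    # listeners via on_down().
    cluster = Cluster(compression=False,
                      idle_heartbeat_interval=1,
                      idle_heartbeat_timeout=5)
    session = cluster.connect(wait_for_all_pools=True)
    cluster.register_listener(DownRecorder())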
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import logging +import time + +from mock import Mock, patch + +from cassandra import OperationTimedOut +from cassandra.cluster import (EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile, + _Scheduler, NoHostAvailable) +from cassandra.policies import HostStateListener, RoundRobinPolicy +from cassandra.io.asyncorereactor import AsyncoreConnection +from cassandra.connection import DEFAULT_CQL_VERSION +from tests import connection_class, thread_pool_executor_class +from tests.unit.cython.utils import cythontest +from tests.integration import (PROTOCOL_VERSION, requiressimulacron) +from tests.integration.util import assert_quiescent_pool_state, late +from tests.integration.simulacron import SimulacronBase +from tests.integration.simulacron.utils import (NO_THEN, PrimeOptions, + prime_query, prime_request, + start_and_prime_cluster_defaults, + start_and_prime_singledc, + clear_queries, RejectConnections, + RejectType, AcceptConnections) + + +class TrackDownListener(HostStateListener): + def __init__(self): + self.hosts_marked_down = [] + + def on_down(self, host): + self.hosts_marked_down.append(host) + + def on_up(self, host): + pass + + def on_add(self, host): + pass + + def on_remove(self, host): + pass + +class ThreadTracker(thread_pool_executor_class): + called_functions = [] + + def submit(self, fn, *args, **kwargs): + self.called_functions.append(fn.__name__) + return super(ThreadTracker, self).submit(fn, *args, **kwargs) + + +class OrderedRoundRobinPolicy(RoundRobinPolicy): + + def make_query_plan(self, working_keyspace=None, query=None): + self._position += 1 + + hosts = [] + for _ in range(10): + hosts.extend(sorted(self._live_hosts, key=lambda x : x.address)) + + return hosts + + +def _send_options_message(self): + """ + Mock that doesn't the OptionMessage. It is required for the heart_beat_timeout + test to avoid a condition where the CC tries to reconnect in the executor but can't + since we prime that message.""" + self._compressor = None + self.cql_version = DEFAULT_CQL_VERSION + self._send_startup_message(no_compact=self.no_compact) + + +@requiressimulacron +class ConnectionTests(SimulacronBase): + + @patch('cassandra.connection.Connection._send_options_message', _send_options_message) + def test_heart_beat_timeout(self): + """ + Test to ensure the hosts are marked as down after a OTO is received. 
+ Also to ensure this happens within the expected timeout + @since 3.10 + @jira_ticket PYTHON-762 + @expected_result all the hosts have been marked as down at some point + + @test_category metadata + """ + number_of_dcs = 3 + nodes_per_dc = 20 + + query_to_prime = "INSERT INTO test3rf.test (k, v) VALUES (0, 1);" + + idle_heartbeat_timeout = 5 + idle_heartbeat_interval = 1 + + start_and_prime_cluster_defaults(number_of_dcs, nodes_per_dc) + + listener = TrackDownListener() + executor = ThreadTracker(max_workers=8) + + # We need to disable compression since it's not supported in simulacron + cluster = Cluster(compression=False, + idle_heartbeat_interval=idle_heartbeat_interval, + idle_heartbeat_timeout=idle_heartbeat_timeout, + executor_threads=8, + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=RoundRobinPolicy())}) + self.addCleanup(cluster.shutdown) + + cluster.scheduler.shutdown() + cluster.executor = executor + cluster.scheduler = _Scheduler(executor) + + session = cluster.connect(wait_for_all_pools=True) + cluster.register_listener(listener) + + log = logging.getLogger() + log.setLevel('CRITICAL') + self.addCleanup(log.setLevel, "DEBUG") + + prime_query(query_to_prime, then=NO_THEN) + + futures = [] + for _ in range(number_of_dcs * nodes_per_dc): + future = session.execute_async(query_to_prime) + futures.append(future) + + for f in futures: + f._event.wait() + self.assertIsInstance(f._final_exception, OperationTimedOut) + + prime_request(PrimeOptions(then=NO_THEN)) + + # We allow from some extra time for all the hosts to be to on_down + # The callbacks should start happening after idle_heartbeat_timeout + idle_heartbeat_interval + time.sleep((idle_heartbeat_timeout + idle_heartbeat_interval) * 2.5) + + for host in cluster.metadata.all_hosts(): + self.assertIn(host, listener.hosts_marked_down) + + # In this case HostConnection._replace shouldn't be called + self.assertNotIn("_replace", executor.called_functions) + + def test_callbacks_and_pool_when_oto(self): + """ + Test to ensure the callbacks are correcltly called and the connection + is returned when there is an OTO + @since 3.12 + @jira_ticket PYTHON-630 + @expected_result the connection is correctly returned to the pool + after an OTO, also the only the errback is called and not the callback + when the message finally arrives. + + @test_category metadata + """ + start_and_prime_singledc() + + cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) + session = cluster.connect() + self.addCleanup(cluster.shutdown) + + query_to_prime = "SELECT * from testkesypace.testtable" + + server_delay = 2 # seconds + prime_query(query_to_prime, then={"delay_in_ms": server_delay * 1000}) + + future = session.execute_async(query_to_prime, timeout=1) + callback, errback = Mock(name='callback'), Mock(name='errback') + future.add_callbacks(callback, errback) + self.assertRaises(OperationTimedOut, future.result) + + assert_quiescent_pool_state(self, cluster) + + time.sleep(server_delay + 1) + # PYTHON-630 -- only the errback should be called + errback.assert_called_once() + callback.assert_not_called() + + @cythontest + def test_heartbeat_defunct_deadlock(self): + """ + Ensure that there is no deadlock when request is in-flight and heartbeat defuncts connection + @since 3.16 + @jira_ticket PYTHON-1044 + @expected_result an OperationTimeout is raised and no deadlock occurs + + @test_category connection + """ + start_and_prime_singledc() + + # This is all about timing. 
We will need the QUERY response future to time out and the heartbeat to defunct + # at the same moment. The latter will schedule a QUERY retry to another node in case the pool is not + # already shut down. If and only if the response future timeout falls in between the retry scheduling and + # its execution the deadlock occurs. The odds are low, so we need to help fate a bit: + # 1) Make one heartbeat messages be sent to every node + # 2) Our QUERY goes always to the same host + # 3) This host needs to defunct first + # 4) Open a small time window for the response future timeout, i.e. block executor threads for retry + # execution and last connection to defunct + query_to_prime = "SELECT * from testkesypace.testtable" + query_host = "127.0.0.2" + heartbeat_interval = 1 + heartbeat_timeout = 1 + lag = 0.05 + never = 9999 + + class PatchedRoundRobinPolicy(RoundRobinPolicy): + # Send always to same host + def make_query_plan(self, working_keyspace=None, query=None): + if query and query.query_string == query_to_prime: + return filter(lambda h: h == query_host, self._live_hosts) + else: + return super(PatchedRoundRobinPolicy, self).make_query_plan() + + class PatchedCluster(Cluster): + # Make sure that QUERY connection will timeout first + def get_connection_holders(self): + holders = super(PatchedCluster, self).get_connection_holders() + return sorted(holders, reverse=True, key=lambda v: int(v._connection.host == query_host)) + + # Block executor thread like closing a dead socket could do + def connection_factory(self, *args, **kwargs): + conn = super(PatchedCluster, self).connection_factory(*args, **kwargs) + conn.defunct = late(seconds=2*lag)(conn.defunct) + return conn + + cluster = PatchedCluster( + protocol_version=PROTOCOL_VERSION, + compression=False, + idle_heartbeat_interval=heartbeat_interval, + idle_heartbeat_timeout=heartbeat_timeout, + load_balancing_policy=PatchedRoundRobinPolicy() + ) + session = cluster.connect() + self.addCleanup(cluster.shutdown) + + prime_query(query_to_prime, then={"delay_in_ms": never}) + + # Make heartbeat due + time.sleep(heartbeat_interval) + + future = session.execute_async(query_to_prime, timeout=heartbeat_interval+heartbeat_timeout+3*lag) + # Delay thread execution like kernel could do + future._retry_task = late(seconds=4*lag)(future._retry_task) + + prime_request(PrimeOptions(then={"result": "no_result", "delay_in_ms": never})) + prime_request(RejectConnections("unbind")) + + self.assertRaisesRegexp(OperationTimedOut, "Connection defunct by heartbeat", future.result) + + def test_close_when_query(self): + """ + Test to ensure the driver behaves correctly if the connection is closed + just when querying + @since 3.12 + @expected_result NoHostAvailable is risen + + @test_category connection + """ + start_and_prime_singledc() + + cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) + session = cluster.connect() + self.addCleanup(cluster.shutdown) + + query_to_prime = "SELECT * from testkesypace.testtable" + + for close_type in ("disconnect", "shutdown_read", "shutdown_write"): + then = { + "result": "close_connection", + "delay_in_ms": 0, + "close_type": close_type, + "scope": "connection" + } + + prime_query(query_to_prime, then=then, rows=None, column_types=None) + self.assertRaises(NoHostAvailable, session.execute, query_to_prime) + + def test_retry_after_defunct(self): + """ + We test cluster._retry is called if an the connection is defunct + in the middle of a query + + Finally we verify the driver recovers correctly in the 
event + of a network partition + + @since 3.12 + @expected_result the driver is able to query even if a host is marked + as down in the middle of the query, it will go to the next one if the timeout + hasn't expired + + @test_category connection + """ + number_of_dcs = 3 + nodes_per_dc = 2 + + query_to_prime = "INSERT INTO test3rf.test (k, v) VALUES (0, 1);" + + idle_heartbeat_timeout = 1 + idle_heartbeat_interval = 5 + + simulacron_cluster = start_and_prime_cluster_defaults(number_of_dcs, nodes_per_dc) + + dc_ids = sorted(simulacron_cluster.data_center_ids) + last_host = dc_ids.pop() + prime_query(query_to_prime, + cluster_name="{}/{}".format(simulacron_cluster.cluster_name, last_host)) + + roundrobin_lbp = OrderedRoundRobinPolicy() + cluster = Cluster(compression=False, + idle_heartbeat_interval=idle_heartbeat_interval, + idle_heartbeat_timeout=idle_heartbeat_timeout, + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=roundrobin_lbp)}) + + session = cluster.connect(wait_for_all_pools=True) + self.addCleanup(cluster.shutdown) + + # This simulates we only have access to one DC + for dc_id in dc_ids: + datacenter_path = "{}/{}".format(simulacron_cluster.cluster_name, dc_id) + prime_query(query_to_prime, then=NO_THEN, cluster_name=datacenter_path) + prime_request(PrimeOptions(then=NO_THEN, cluster_name=datacenter_path)) + + # Only the last datacenter will respond, therefore the first host won't + # We want to make sure the returned hosts are 127.0.0.1, 127.0.0.2, ... 127.0.0.8 + roundrobin_lbp._position = 0 + + # After 3 + 1 seconds the connection should be marked and down and another host retried + response_future = session.execute_async(query_to_prime, timeout=4 * idle_heartbeat_interval + + idle_heartbeat_timeout) + response_future.result() + self.assertGreater(len(response_future.attempted_hosts), 1) + + # No error should be raised here since the hosts have been marked + # as down and there's still 1 DC available + for _ in range(10): + session.execute(query_to_prime) + + # Might take some time to close the previous connections and reconnect + time.sleep(10) + assert_quiescent_pool_state(self, cluster) + clear_queries() + + time.sleep(10) + assert_quiescent_pool_state(self, cluster) + + def test_idle_connection_is_not_closed(self): + """ + Test to ensure that the connections aren't closed if they are idle + @since 3.12 + @jira_ticket PYTHON-573 + @expected_result the connections aren't closed nor the hosts are + set to down if the connection is idle + + @test_category connection + """ + start_and_prime_singledc() + + idle_heartbeat_timeout = 1 + idle_heartbeat_interval = 1 + + listener = TrackDownListener() + cluster = Cluster(compression=False, + idle_heartbeat_interval=idle_heartbeat_interval, + idle_heartbeat_timeout=idle_heartbeat_timeout) + session = cluster.connect(wait_for_all_pools=True) + cluster.register_listener(listener) + + self.addCleanup(cluster.shutdown) + + time.sleep(20) + + self.assertEqual(listener.hosts_marked_down, []) + + def test_host_is_not_set_to_down_after_query_oto(self): + """ + Test to ensure that the connections aren't closed if there's an + OperationTimedOut in a normal query. 
This should only happen from the + heart beat thread (in the case of a OperationTimedOut) with the default + configuration + @since 3.12 + @expected_result the connections aren't closed nor the hosts are + set to down + + @test_category connection + """ + start_and_prime_singledc() + + query_to_prime = "SELECT * FROM madeup_keyspace.madeup_table" + + prime_query(query_to_prime, then=NO_THEN) + + listener = TrackDownListener() + cluster = Cluster(compression=False) + session = cluster.connect(wait_for_all_pools=True) + cluster.register_listener(listener) + + futures = [] + for _ in range(10): + future = session.execute_async(query_to_prime) + futures.append(future) + + for f in futures: + f._event.wait() + self.assertIsInstance(f._final_exception, OperationTimedOut) + + self.assertEqual(listener.hosts_marked_down, []) + assert_quiescent_pool_state(self, cluster) + + def test_can_shutdown_connection_subclass(self): + start_and_prime_singledc() + class ExtendedConnection(connection_class): + pass + + cluster = Cluster(contact_points=["127.0.0.2"], + connection_class=ExtendedConnection) + cluster.connect() + cluster.shutdown() + + def test_driver_recovers_nework_isolation(self): + start_and_prime_singledc() + + idle_heartbeat_timeout = 3 + idle_heartbeat_interval = 1 + + listener = TrackDownListener() + + cluster = Cluster(['127.0.0.1'], + load_balancing_policy=RoundRobinPolicy(), + idle_heartbeat_timeout=idle_heartbeat_timeout, + idle_heartbeat_interval=idle_heartbeat_interval, + executor_threads=16) + session = cluster.connect(wait_for_all_pools=True) + + cluster.register_listener(listener) + + prime_request(PrimeOptions(then=NO_THEN)) + prime_request(RejectConnections(RejectType.REJECT_STARTUP)) + + time.sleep((idle_heartbeat_timeout + idle_heartbeat_interval) * 2) + + for host in cluster.metadata.all_hosts(): + self.assertIn(host, listener.hosts_marked_down) + + self.assertRaises(NoHostAvailable, session.execute, "SELECT * from system.local") + + clear_queries() + prime_request(AcceptConnections()) + + time.sleep(idle_heartbeat_timeout + idle_heartbeat_interval + 2) + + self.assertIsNotNone(session.execute("SELECT * from system.local")) diff --git a/tests/integration/simulacron/test_policies.py b/tests/integration/simulacron/test_policies.py new file mode 100644 index 0000000..d7a6775 --- /dev/null +++ b/tests/integration/simulacron/test_policies.py @@ -0,0 +1,451 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
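A minimal sketch of the speculative-execution setup exercised below: a ConstantSpeculativeExecutionPolicy attached to an ExecutionProfile, with illustrative delay/attempt values. Speculative executions only apply to statements marked idempotent:

    from cassandra.cluster import Cluster, ExecutionProfile
    from cassandra.policies import ConstantSpeculativeExecutionPolicy, RoundRobinPolicy
    from cassandra.query import SimpleStatement

    # Fire a speculative execution every 0.5 s, up to 10 attempts, and give
    # the whole request 12 seconds before it times out.
    spec_profile = ExecutionProfile(
        load_balancing_policy=RoundRobinPolicy(),
        speculative_execution_policy=ConstantSpeculativeExecutionPolicy(0.5, 10),
        request_timeout=12)

    cluster = Cluster()
    cluster.add_execution_profile("spec", spec_profile)
    session = cluster.connect(wait_for_all_pools=True)

    # Only idempotent statements are retried speculatively.
    stmt = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (0, 1)",
                           is_idempotent=True)
    session.execute(stmt, execution_profile="spec")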
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra import OperationTimedOut, WriteTimeout +from cassandra.cluster import Cluster, ExecutionProfile, ResponseFuture +from cassandra.query import SimpleStatement +from cassandra.policies import ConstantSpeculativeExecutionPolicy, RoundRobinPolicy, RetryPolicy, WriteType + +from tests.integration import PROTOCOL_VERSION, greaterthancass21, requiressimulacron, SIMULACRON_JAR, \ + CASSANDRA_VERSION +from tests.integration.simulacron.utils import start_and_prime_singledc, prime_query, \ + stop_simulacron, NO_THEN, clear_queries + +from itertools import count +from packaging.version import Version + + +class BadRoundRobinPolicy(RoundRobinPolicy): + def make_query_plan(self, working_keyspace=None, query=None): + pos = self._position + self._position += 1 + + hosts = [] + for _ in range(10): + hosts.extend(self._live_hosts) + + return hosts + + +# This doesn't work well with Windows clock granularity +@requiressimulacron +class SpecExecTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + if SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"): + return + + start_and_prime_singledc() + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) + cls.session = cls.cluster.connect(wait_for_all_pools=True) + + spec_ep_brr = ExecutionProfile(load_balancing_policy=BadRoundRobinPolicy(), + speculative_execution_policy=ConstantSpeculativeExecutionPolicy(1, 6), + request_timeout=12) + spec_ep_rr = ExecutionProfile(speculative_execution_policy=ConstantSpeculativeExecutionPolicy(.5, 10), + request_timeout=12) + spec_ep_rr_lim = ExecutionProfile(load_balancing_policy=BadRoundRobinPolicy(), + speculative_execution_policy=ConstantSpeculativeExecutionPolicy(0.5, 1), + request_timeout=12) + spec_ep_brr_lim = ExecutionProfile(load_balancing_policy=BadRoundRobinPolicy(), + speculative_execution_policy=ConstantSpeculativeExecutionPolicy(4, 10)) + + cls.cluster.add_execution_profile("spec_ep_brr", spec_ep_brr) + cls.cluster.add_execution_profile("spec_ep_rr", spec_ep_rr) + cls.cluster.add_execution_profile("spec_ep_rr_lim", spec_ep_rr_lim) + cls.cluster.add_execution_profile("spec_ep_brr_lim", spec_ep_brr_lim) + + @classmethod + def tearDownClass(cls): + if SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"): + return + + cls.cluster.shutdown() + stop_simulacron() + + def tearDown(self): + clear_queries() + + @greaterthancass21 + def test_speculative_execution(self): + """ + Test to ensure that speculative execution honors LBP, and that they retry appropriately. + + This test will use various LBP, and ConstantSpeculativeExecutionPolicy settings and ensure the proper number of hosts are queried + @since 3.7.0 + @jira_ticket PYTHON-218 + @expected_result speculative retries should honor max retries, idempotent state of queries, and underlying lbp. 
+ + @test_category metadata + """ + query_to_prime = "INSERT INTO test3rf.test (k, v) VALUES (0, 1);" + prime_query(query_to_prime, then={"delay_in_ms": 10000}) + + statement = SimpleStatement(query_to_prime, is_idempotent=True) + statement_non_idem = SimpleStatement(query_to_prime, is_idempotent=False) + + # This LBP should repeat hosts up to around 30 + result = self.session.execute(statement, execution_profile='spec_ep_brr') + self.assertEqual(7, len(result.response_future.attempted_hosts)) + + # This LBP should keep host list to 3 + result = self.session.execute(statement, execution_profile='spec_ep_rr') + self.assertEqual(3, len(result.response_future.attempted_hosts)) + # Spec_execution policy should limit retries to 1 + result = self.session.execute(statement, execution_profile='spec_ep_rr_lim') + + self.assertEqual(2, len(result.response_future.attempted_hosts)) + + # Spec_execution policy should not be used if the query is not idempotent + result = self.session.execute(statement_non_idem, execution_profile='spec_ep_brr') + self.assertEqual(1, len(result.response_future.attempted_hosts)) + + # Default policy with non_idem query + result = self.session.execute(statement_non_idem, timeout=12) + self.assertEqual(1, len(result.response_future.attempted_hosts)) + + # Should be able to run an idempotent query against default execution policy with no speculative_execution_policy + result = self.session.execute(statement, timeout=12) + self.assertEqual(1, len(result.response_future.attempted_hosts)) + + # Test timeout with spec_ex + with self.assertRaises(OperationTimedOut): + self.session.execute(statement, execution_profile='spec_ep_rr', timeout=.5) + + prepared_query_to_prime = "SELECT * FROM test3rf.test where k = ?" + when = {"params": {"k": "0"}, "param_types": {"k": "ascii"}} + prime_query(prepared_query_to_prime, when=when, then={"delay_in_ms": 4000}) + + # PYTHON-736 Test speculation policy works with a prepared statement + prepared_statement = self.session.prepare(prepared_query_to_prime) + # non-idempotent + result = self.session.execute(prepared_statement, ("0",), execution_profile='spec_ep_brr') + self.assertEqual(1, len(result.response_future.attempted_hosts)) + # idempotent + prepared_statement.is_idempotent = True + result = self.session.execute(prepared_statement, ("0",), execution_profile='spec_ep_brr') + self.assertLess(1, len(result.response_future.attempted_hosts)) + + def test_speculative_and_timeout(self): + """ + Test to ensure the timeout is honored when using speculative execution + @since 3.10 + @jira_ticket PYTHON-750 + @expected_result speculative retries be schedule every fixed period, during the maximum + period of the timeout. 
+ + @test_category metadata + """ + query_to_prime = "INSERT INTO testkeyspace.testtable (k, v) VALUES (0, 1);" + prime_query(query_to_prime, then=NO_THEN) + + statement = SimpleStatement(query_to_prime, is_idempotent=True) + + # An OperationTimedOut is placed here in response_future, + # that's why we can't call session.execute,which would raise it, but + # we have to directly wait for the event + response_future = self.session.execute_async(statement, execution_profile='spec_ep_brr_lim', + timeout=14) + response_future._event.wait(16) + self.assertIsInstance(response_future._final_exception, OperationTimedOut) + + # This is because 14 / 4 + 1 = 4 + self.assertEqual(len(response_future.attempted_hosts), 4) + + def test_delay_can_be_0(self): + """ + Test to validate that the delay can be zero for the ConstantSpeculativeExecutionPolicy + @since 3.13 + @jira_ticket PYTHON-836 + @expected_result all the queries are executed immediately + @test_category policy + """ + query_to_prime = "INSERT INTO madeup_keyspace.madeup_table(k, v) VALUES (1, 2)" + prime_query(query_to_prime, then={"delay_in_ms": 5000}) + number_of_requests = 4 + spec = ExecutionProfile(load_balancing_policy=RoundRobinPolicy(), + speculative_execution_policy=ConstantSpeculativeExecutionPolicy(0, number_of_requests)) + + cluster = Cluster() + cluster.add_execution_profile("spec", spec) + session = cluster.connect(wait_for_all_pools=True) + self.addCleanup(cluster.shutdown) + + counter = count() + + def patch_and_count(f): + def patched(*args, **kwargs): + next(counter) + print("patched") + f(*args, **kwargs) + return patched + + self.addCleanup(setattr, ResponseFuture, "send_request", ResponseFuture.send_request) + ResponseFuture.send_request = patch_and_count(ResponseFuture.send_request) + stmt = SimpleStatement(query_to_prime) + stmt.is_idempotent = True + results = session.execute(stmt, execution_profile="spec") + self.assertEqual(len(results.response_future.attempted_hosts), 3) + + # send_request is called number_of_requests times for the speculative request + # plus one for the call from the main thread. 
+ self.assertEqual(next(counter), number_of_requests + 1) + + +class CustomRetryPolicy(RetryPolicy): + def on_write_timeout(self, query, consistency, write_type, + required_responses, received_responses, retry_num): + if retry_num != 0: + return self.RETHROW, None + elif write_type == WriteType.SIMPLE: + return self.RETHROW, None + elif write_type == WriteType.CDC: + return self.IGNORE, None + + +class CounterRetryPolicy(RetryPolicy): + def __init__(self): + self.write_timeout = count() + self.read_timeout = count() + self.unavailable = count() + self.request_error = count() + + def on_read_timeout(self, query, consistency, required_responses, + received_responses, data_retrieved, retry_num): + next(self.read_timeout) + return self.IGNORE, None + + def on_write_timeout(self, query, consistency, write_type, + required_responses, received_responses, retry_num): + next(self.write_timeout) + return self.IGNORE, None + + def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num): + next(self.unavailable) + return self.IGNORE, None + + def on_request_error(self, query, consistency, error, retry_num): + next(self.request_error) + return self.IGNORE, None + + def reset_counters(self): + self.write_timeout = count() + self.read_timeout = count() + self.unavailable = count() + self.request_error = count() + + +@requiressimulacron +class RetryPolicyTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + if SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"): + return + start_and_prime_singledc() + + @classmethod + def tearDownClass(cls): + if SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"): + return + stop_simulacron() + + def tearDown(self): + clear_queries() + + def set_cluster(self, retry_policy): + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False, + default_retry_policy=retry_policy) + self.session = self.cluster.connect(wait_for_all_pools=True) + self.addCleanup(self.cluster.shutdown) + + def test_retry_policy_ignores_and_rethrows(self): + """ + Test to verify :class:`~cassandra.protocol.WriteTimeoutErrorMessage` is decoded correctly and that + :attr:`.~cassandra.policies.RetryPolicy.RETHROW` and + :attr:`.~cassandra.policies.RetryPolicy.IGNORE` are respected + to localhost + + @since 3.12 + @jira_ticket PYTHON-812 + @expected_result the retry policy functions as expected + + @test_category connection + """ + self.set_cluster(CustomRetryPolicy()) + query_to_prime_simple = "SELECT * from simulacron_keyspace.simple" + query_to_prime_cdc = "SELECT * from simulacron_keyspace.cdc" + then = { + "result": "write_timeout", + "delay_in_ms": 0, + "consistency_level": "LOCAL_QUORUM", + "received": 1, + "block_for": 2, + "write_type": "SIMPLE", + "ignore_on_prepare": True + } + prime_query(query_to_prime_simple, then=then, rows=None, column_types=None) + then["write_type"] = "CDC" + prime_query(query_to_prime_cdc, then=then, rows=None, column_types=None) + + with self.assertRaises(WriteTimeout): + self.session.execute(query_to_prime_simple) + + #CDC should be ignored + self.session.execute(query_to_prime_cdc) + + def test_retry_policy_with_prepared(self): + """ + Test to verify that the retry policy is called as expected + for bound and prepared statements when set at the cluster level + + @since 3.13 + @jira_ticket PYTHON-861 + @expected_result the appropriate retry policy is called + + @test_category connection + """ + counter_policy = CounterRetryPolicy() + self.set_cluster(counter_policy) + query_to_prime = 
"SELECT * from simulacron_keyspace.simulacron_table" + then = { + "result": "write_timeout", + "delay_in_ms": 0, + "consistency_level": "LOCAL_QUORUM", + "received": 1, + "block_for": 2, + "write_type": "SIMPLE", + "ignore_on_prepare": True + } + prime_query(query_to_prime, then=then, rows=None, column_types=None) + self.session.execute(query_to_prime) + self.assertEqual(next(counter_policy.write_timeout), 1) + counter_policy.reset_counters() + + query_to_prime_prepared = "SELECT * from simulacron_keyspace.simulacron_table WHERE key = :key" + when = {"params": {"key": "0"}, "param_types": {"key": "ascii"}} + + prime_query(query_to_prime_prepared, when=when, then=then, rows=None, column_types=None) + + prepared_stmt = self.session.prepare(query_to_prime_prepared) + + bound_stm = prepared_stmt.bind({"key": "0"}) + self.session.execute(bound_stm) + self.assertEqual(next(counter_policy.write_timeout), 1) + + counter_policy.reset_counters() + self.session.execute(prepared_stmt, ("0",)) + self.assertEqual(next(counter_policy.write_timeout), 1) + + def test_setting_retry_policy_to_statement(self): + """ + Test to verify that the retry policy is called as expected + for bound and prepared statements when set to the prepared statement + + @since 3.13 + @jira_ticket PYTHON-861 + @expected_result the appropriate retry policy is called + + @test_category connection + """ + retry_policy = RetryPolicy() + self.set_cluster(retry_policy) + then = { + "result": "write_timeout", + "delay_in_ms": 0, + "consistency_level": "LOCAL_QUORUM", + "received": 1, + "block_for": 2, + "write_type": "SIMPLE", + "ignore_on_prepare": True + } + query_to_prime_prepared = "SELECT * from simulacron_keyspace.simulacron_table WHERE key = :key" + when = {"params": {"key": "0"}, "param_types": {"key": "ascii"}} + prime_query(query_to_prime_prepared, when=when, then=then, rows=None, column_types=None) + + counter_policy = CounterRetryPolicy() + prepared_stmt = self.session.prepare(query_to_prime_prepared) + prepared_stmt.retry_policy = counter_policy + self.session.execute(prepared_stmt, ("0",)) + self.assertEqual(next(counter_policy.write_timeout), 1) + + counter_policy.reset_counters() + bound_stmt = prepared_stmt.bind({"key": "0"}) + bound_stmt.retry_policy = counter_policy + self.session.execute(bound_stmt) + self.assertEqual(next(counter_policy.write_timeout), 1) + + def test_retry_policy_on_request_error(self): + """ + Test to verify that on_request_error is called properly. 
+ + @since 3.18 + @jira_ticket PYTHON-1064 + @expected_result the appropriate retry policy is called + + @test_category connection + """ + overloaded_error = { + "result": "overloaded", + "message": "overloaded" + } + + bootstrapping_error = { + "result": "is_bootstrapping", + "message": "isbootstrapping" + } + + truncate_error = { + "result": "truncate_error", + "message": "truncate_error" + } + + server_error = { + "result": "server_error", + "message": "server_error" + } + + # Test the on_request_error call + retry_policy = CounterRetryPolicy() + self.set_cluster(retry_policy) + + for e in [overloaded_error, bootstrapping_error, truncate_error, server_error]: + query_to_prime = "SELECT * from simulacron_keyspace.simulacron_table;" + prime_query(query_to_prime, then=e, rows=None, column_types=None) + rf = self.session.execute_async(query_to_prime) + try: + rf.result() + except: + pass + self.assertEqual(len(rf.attempted_hosts), 1) # no retry + + self.assertEqual(next(retry_policy.request_error), 4) + + # Test that by default, retry on next host + retry_policy = RetryPolicy() + self.set_cluster(retry_policy) + + for e in [overloaded_error, bootstrapping_error, truncate_error, server_error]: + query_to_prime = "SELECT * from simulacron_keyspace.simulacron_table;" + prime_query(query_to_prime, then=e, rows=None, column_types=None) + rf = self.session.execute_async(query_to_prime) + try: + rf.result() + except: + pass + self.assertEqual(len(rf.attempted_hosts), 3) # all 3 nodes failed diff --git a/tests/integration/simulacron/utils.py b/tests/integration/simulacron/utils.py new file mode 100644 index 0000000..5ec5383 --- /dev/null +++ b/tests/integration/simulacron/utils.py @@ -0,0 +1,380 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +from six.moves.urllib.request import build_opener, Request, HTTPHandler +import json +from tests.integration import CASSANDRA_VERSION, SIMULACRON_JAR +import subprocess +import time +import nose + +DEFAULT_CLUSTER = "python_simulacron_cluster" + + +class SimulacronCluster(object): + """ + Represents a Cluster object as returned by simulacron + """ + def __init__(self, json_text): + self.json = json_text + self.o = json.loads(json_text) + + @property + def cluster_id(self): + return self.o["id"] + + @property + def cluster_name(self): + return self.o["name"] + + @property + def data_center_ids(self): + return [dc["id"] for dc in self.o["data_centers"]] + + @property + def data_centers_names(self): + return [dc["name"] for dc in self.o["data_centers"]] + + def get_node_ids(self, datacenter_id): + datacenter = list(filter(lambda x: x["id"] == datacenter_id, self.o["data_centers"])).pop() + return [node["id"] for node in datacenter["nodes"]] + + +class SimulacronServer(object): + """ + Class for starting and stopping the server from within the tests + """ + def __init__(self, jar_path): + self.jar_path = jar_path + self.running = False + self.proc = None + + def start(self): + self.proc = subprocess.Popen(['java', '-jar', self.jar_path, "--loglevel", "ERROR"], shell=False) + self.running = True + + def stop(self): + if self.proc: + self.proc.terminate() + self.running = False + + def is_running(self): + # We could check self.proc.poll here instead + return self.running + + +SERVER_SIMULACRON = SimulacronServer(SIMULACRON_JAR) + + +def start_simulacron(): + """ + Starts and waits for simulacron to run + """ + if SERVER_SIMULACRON.is_running(): + SERVER_SIMULACRON.stop() + + SERVER_SIMULACRON.start() + + #TODO improve this sleep, maybe check the logs like ccm + time.sleep(5) + +def stop_simulacron(): + SERVER_SIMULACRON.stop() + + +class SimulacronClient(object): + def __init__(self, admin_addr="127.0.0.1:8187"): + self.admin_addr = admin_addr + + def submit_request(self, query): + opener = build_opener(HTTPHandler) + data = json.dumps(query.fetch_json()).encode('utf8') + + request = Request("http://{}/{}{}".format( + self.admin_addr, query.path, query.fetch_url_params()), data=data) + request.get_method = lambda: query.method + request.add_header("Content-Type", 'application/json') + request.add_header("Content-Length", len(data)) + + connection = opener.open(request) + return connection.read().decode('utf-8') + + def prime_server_versions(self): + """ + This information has to be primed for the test harness to run + """ + system_local_row = {} + system_local_row["cql_version"] = CASSANDRA_VERSION.base_version + system_local_row["release_version"] = CASSANDRA_VERSION.base_version + "-SNAPSHOT" + column_types = {"cql_version": "ascii", "release_version": "ascii"} + system_local = PrimeQuery("SELECT cql_version, release_version FROM system.local", + rows=[system_local_row], + column_types=column_types) + + self.submit_request(system_local) + + def clear_all_queries(self, cluster_name=DEFAULT_CLUSTER): + """ + Clear all the primed queries from a particular cluster + :param cluster_name: cluster to clear queries from + """ + opener = build_opener(HTTPHandler) + request = Request("http://{0}/{1}/{2}".format( + self.admin_addr, "prime", cluster_name)) + request.get_method = lambda: 'DELETE' + connection = opener.open(request) + return connection.read() + + +NO_THEN = object() + + +class 
SimulacronRequest(object): + def fetch_json(self): + return {} + + def fetch_url_params(self): + return "" + + @property + def method(self): + raise NotImplementedError() + + +class PrimeOptions(SimulacronRequest): + """ + Class used for specifying how should simulacron respond to an OptionsMessage + """ + def __init__(self, then=None, cluster_name=DEFAULT_CLUSTER): + self.path = "prime/{}".format(cluster_name) + self.then = then + + def fetch_json(self): + json_dict = {} + then = {} + when = {} + + when['request'] = "options" + + if self.then is not None and self.then is not NO_THEN: + then.update(self.then) + + json_dict['when'] = when + if self.then is not NO_THEN: + json_dict['then'] = then + + return json_dict + + def fetch_url_params(self): + return "" + + @property + def method(self): + return "POST" + + +class RejectType(): + UNBIND = "UNBIND" + STOP = "STOP" + REJECT_STARTUP = "REJECT_STARTUP" + + +class RejectConnections(SimulacronRequest): + """ + Class used for making simulacron reject new connections + """ + def __init__(self, reject_type, cluster_name=DEFAULT_CLUSTER): + self.path = "listener/{}".format(cluster_name) + self.reject_type = reject_type + + def fetch_url_params(self): + return "?type={0}".format(self.reject_type) + + @property + def method(self): + return "DELETE" + + +class AcceptConnections(SimulacronRequest): + """ + Class used for making simulacron reject new connections + """ + def __init__(self, cluster_name=DEFAULT_CLUSTER): + self.path = "listener/{}".format(cluster_name) + + @property + def method(self): + return "PUT" + + +class PrimeQuery(SimulacronRequest): + """ + Class used for specifying how should simulacron respond to particular query + """ + def __init__(self, expected_query, result="success", rows=None, + column_types=None, when=None, then=None, cluster_name=DEFAULT_CLUSTER): + self.expected_query = expected_query + self.rows = rows + self.result = result + self.column_types = column_types + self.path = "prime/{}".format(cluster_name) + self.then = then + self.when = when + + def fetch_json(self): + json_dict = {} + then = {} + when = {} + + when['query'] = self.expected_query + then['result'] = self.result + if self.rows is not None: + then['rows'] = self.rows + + if self.column_types is not None: + then['column_types'] = self.column_types + + if self.then is not None and self.then is not NO_THEN: + then.update(self.then) + + if self.then is not NO_THEN: + json_dict['then'] = then + + if self.when is not None: + when.update(self.when) + + json_dict['when'] = when + + return json_dict + + def set_node(self, cluster_id, datacenter_id, node_id): + self.cluster_id = cluster_id + self.datacenter_id = datacenter_id + self.node_id = node_id + + self.path += '/'.join([component for component in + (self.cluster_id, self.datacenter_id, self.node_id) + if component is not None]) + + def fetch_url_params(self): + return "" + + @property + def method(self): + return "POST" + +class ClusterQuery(SimulacronRequest): + """ + Class used for creating a cluster + """ + def __init__(self, cluster_name, cassandra_version, data_centers="3", json_dict=None): + self.cluster_name = cluster_name + self.cassandra_version = cassandra_version + self.data_centers = data_centers + if json_dict is None: + self.json_dict = {} + else: + self.json_dict = json_dict + + self.path = "cluster" + + def fetch_json(self): + return self.json_dict + + def fetch_url_params(self): + return "?cassandra_version={0}&data_centers={1}&name={2}".\ + format(self.cassandra_version, 
self.data_centers, self.cluster_name) + + @property + def method(self): + return "POST" + +def prime_driver_defaults(): + """ + Function to prime the necessary queries so the test harness can run + """ + client_simulacron = SimulacronClient() + client_simulacron.prime_server_versions() + + +def prime_cluster(data_centers="3", version=CASSANDRA_VERSION, cluster_name=DEFAULT_CLUSTER): + """ + Creates a new cluster in the simulacron server + :param cluster_name: name of the cluster + :param data_centers: string describing the datacenter, e.g. 2/3 would be two + datacenters of 2 nodes and three nodes + :param version: C* version + """ + version = version or CASSANDRA_VERSION.base_version + cluster_query = ClusterQuery(cluster_name, version, data_centers) + client_simulacron = SimulacronClient() + response = client_simulacron.submit_request(cluster_query) + return SimulacronCluster(response) + + +def start_and_prime_singledc(cluster_name=DEFAULT_CLUSTER): + """ + Starts simulacron and creates a cluster with a single datacenter + :param cluster_name: name of the cluster to start and prime + :return: + """ + return start_and_prime_cluster_defaults(number_of_dc=1, nodes_per_dc=3, cluster_name=cluster_name) + + +def start_and_prime_cluster_defaults(number_of_dc=1, nodes_per_dc=3, version=CASSANDRA_VERSION, cluster_name=DEFAULT_CLUSTER): + """ + :param number_of_dc: number of datacentes + :param nodes_per_dc: number of nodes per datacenter + :param version: C* version + """ + start_simulacron() + data_centers = ",".join([str(nodes_per_dc)] * number_of_dc) + simulacron_cluster = prime_cluster(data_centers=data_centers, version=version, cluster_name=cluster_name) + prime_driver_defaults() + + return simulacron_cluster + + +default_column_types = { + "key": "bigint", + "value": "ascii" +} + +default_row = {"key": 2, "value": "value"} +default_rows = [default_row] + + +def prime_request(request): + """ + :param request: It could be PrimeQuery class or an PrimeOptions class + """ + return SimulacronClient().submit_request(request) + + +def prime_query(query, rows=default_rows, column_types=default_column_types, when=None, then=None, cluster_name=DEFAULT_CLUSTER): + """ + Shortcut function for priming a query + :return: + """ + # If then is set, then rows and column_types should not + query = PrimeQuery(query, rows=rows, column_types=column_types, when=when, then=then, cluster_name=cluster_name) + response = prime_request(query) + return response + + +def clear_queries(): + """ + Clears all the queries that have been primed to simulacron + """ + SimulacronClient().clear_all_queries() diff --git a/tests/integration/standard/__init__.py b/tests/integration/standard/__init__.py new file mode 100644 index 0000000..e54b6fd --- /dev/null +++ b/tests/integration/standard/__init__.py @@ -0,0 +1,23 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
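Before the standard integration tests, a short usage sketch of the simulacron priming helpers added above in utils.py. The query, rows, and column types are illustrative, and a simulacron jar must be available via SIMULACRON_JAR:

    from tests.integration.simulacron.utils import (NO_THEN, PrimeOptions,
                                                    clear_queries, prime_query,
                                                    prime_request,
                                                    start_and_prime_singledc)

    # Boot simulacron with one datacenter of three nodes and prime the
    # system.local rows the driver needs during startup.
    simulacron_cluster = start_and_prime_singledc()

    # Prime a canned result set for a query...
    prime_query("SELECT * FROM ks.tbl",
                rows=[{"key": 1, "value": "a"}],
                column_types={"key": "bigint", "value": "ascii"})

    # ...or prime "no response at all" for OPTIONS messages, which the
    # connection tests use to simulate unresponsive nodes.
    prime_request(PrimeOptions(then=NO_THEN))

    # Remove every primed query between tests.
    clear_queries()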
+ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +try: + from ccmlib import common +except ImportError as e: + raise unittest.SkipTest('ccm is a dependency for integration tests:', e) diff --git a/tests/integration/standard/test_authentication.py b/tests/integration/standard/test_authentication.py new file mode 100644 index 0000000..f8ce423 --- /dev/null +++ b/tests/integration/standard/test_authentication.py @@ -0,0 +1,192 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +from cassandra.cluster import Cluster, NoHostAvailable +from cassandra.auth import PlainTextAuthProvider, SASLClient, SaslAuthProvider + +from tests.integration import use_singledc, get_cluster, remove_cluster, PROTOCOL_VERSION, CASSANDRA_IP, \ + set_default_cass_ip, USE_CASS_EXTERNAL +from tests.integration.util import assert_quiescent_pool_state + +try: + import unittest2 as unittest +except ImportError: + import unittest + +log = logging.getLogger(__name__) + + +#This can be tested for remote hosts, but the cluster has to be configured accordingly +#@local + + +def setup_module(): + if CASSANDRA_IP.startswith("127.0.0.") and not USE_CASS_EXTERNAL: + use_singledc(start=False) + ccm_cluster = get_cluster() + ccm_cluster.stop() + config_options = {'authenticator': 'PasswordAuthenticator', + 'authorizer': 'CassandraAuthorizer'} + ccm_cluster.set_configuration_options(config_options) + log.debug("Starting ccm test cluster with %s", config_options) + ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True) + # there seems to be some race, with some versions of C* taking longer to + # get the auth (and default user) setup. Sleep here to give it a chance + time.sleep(10) + else: + set_default_cass_ip() + + +def teardown_module(): + remove_cluster() # this test messes with config + + +class AuthenticationTests(unittest.TestCase): + """ + Tests to cover basic authentication functionality + """ + def get_authentication_provider(self, username, password): + """ + Return correct authentication provider based on protocol version. + There is a difference in the semantics of authentication provider argument with protocol versions 1 and 2 + For protocol version 2 and higher it should be a PlainTextAuthProvider object. + For protocol version 1 it should be a function taking hostname as an argument and returning a dictionary + containing username and password. 
+ :param username: authentication username + :param password: authentication password + :return: authentication object suitable for Cluster.connect() + """ + if PROTOCOL_VERSION < 2: + return lambda hostname: dict(username=username, password=password) + else: + return PlainTextAuthProvider(username=username, password=password) + + def cluster_as(self, usr, pwd): + # test we can connect at least once with creds + # to ensure the role manager is setup + for _ in range(5): + try: + cluster = Cluster( + protocol_version=PROTOCOL_VERSION, + idle_heartbeat_interval=0, + auth_provider=self.get_authentication_provider(username='cassandra', password='cassandra')) + cluster.connect(wait_for_all_pools=True) + + return Cluster( + protocol_version=PROTOCOL_VERSION, + idle_heartbeat_interval=0, + auth_provider=self.get_authentication_provider(username=usr, password=pwd)) + except Exception as e: + time.sleep(5) + + raise Exception('Unable to connect with creds: {}/{}'.format(usr, pwd)) + + def test_auth_connect(self): + user = 'u' + passwd = 'password' + + root_session = self.cluster_as('cassandra', 'cassandra').connect() + root_session.execute('CREATE USER %s WITH PASSWORD %s', (user, passwd)) + + try: + cluster = self.cluster_as(user, passwd) + session = cluster.connect() + try: + self.assertTrue(session.execute('SELECT release_version FROM system.local')) + assert_quiescent_pool_state(self, cluster, wait=1) + for pool in session.get_pools(): + connection, _ = pool.borrow_connection(timeout=0) + self.assertEqual(connection.authenticator.server_authenticator_class, 'org.apache.cassandra.auth.PasswordAuthenticator') + pool.return_connection(connection) + finally: + cluster.shutdown() + finally: + root_session.execute('DROP USER %s', user) + assert_quiescent_pool_state(self, root_session.cluster, wait=1) + root_session.cluster.shutdown() + + def test_connect_wrong_pwd(self): + cluster = self.cluster_as('cassandra', 'wrong_pass') + try: + self.assertRaisesRegexp(NoHostAvailable, + '.*AuthenticationFailed.', + cluster.connect) + assert_quiescent_pool_state(self, cluster) + finally: + cluster.shutdown() + + def test_connect_wrong_username(self): + cluster = self.cluster_as('wrong_user', 'cassandra') + try: + self.assertRaisesRegexp(NoHostAvailable, + '.*AuthenticationFailed.*', + cluster.connect) + assert_quiescent_pool_state(self, cluster) + finally: + cluster.shutdown() + + def test_connect_empty_pwd(self): + cluster = self.cluster_as('Cassandra', '') + try: + self.assertRaisesRegexp(NoHostAvailable, + '.*AuthenticationFailed.*', + cluster.connect) + assert_quiescent_pool_state(self, cluster) + finally: + cluster.shutdown() + + def test_connect_no_auth_provider(self): + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + try: + self.assertRaisesRegexp(NoHostAvailable, + '.*AuthenticationFailed.*', + cluster.connect) + assert_quiescent_pool_state(self, cluster) + finally: + cluster.shutdown() + + +class SaslAuthenticatorTests(AuthenticationTests): + """ + Test SaslAuthProvider as PlainText + """ + def setUp(self): + if PROTOCOL_VERSION < 2: + raise unittest.SkipTest('Sasl authentication not available for protocol v1') + if SASLClient is None: + raise unittest.SkipTest('pure-sasl is not installed') + + def get_authentication_provider(self, username, password): + sasl_kwargs = {'service': 'cassandra', + 'mechanism': 'PLAIN', + 'qops': ['auth'], + 'username': username, + 'password': password} + return SaslAuthProvider(**sasl_kwargs) + + # these could equally be unit tests + def test_host_passthrough(self): 
+ sasl_kwargs = {'service': 'cassandra', + 'mechanism': 'PLAIN'} + provider = SaslAuthProvider(**sasl_kwargs) + host = 'thehostname' + authenticator = provider.new_authenticator(host) + self.assertEqual(authenticator.sasl.host, host) + + def test_host_rejected(self): + sasl_kwargs = {'host': 'something'} + self.assertRaises(ValueError, SaslAuthProvider, **sasl_kwargs) diff --git a/tests/integration/standard/test_client_warnings.py b/tests/integration/standard/test_client_warnings.py new file mode 100644 index 0000000..1092af7 --- /dev/null +++ b/tests/integration/standard/test_client_warnings.py @@ -0,0 +1,133 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +try: + import unittest2 as unittest +except ImportError: + import unittest + +from cassandra.query import BatchStatement +from cassandra.cluster import Cluster + +from tests.integration import use_singledc, PROTOCOL_VERSION, local + + +def setup_module(): + use_singledc() + + +class ClientWarningTests(unittest.TestCase): + + @classmethod + def setUpClass(cls): + if PROTOCOL_VERSION < 4: + return + + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.session = cls.cluster.connect() + + cls.session.execute("CREATE TABLE IF NOT EXISTS test1rf.client_warning (k int, v0 int, v1 int, PRIMARY KEY (k, v0))") + cls.prepared = cls.session.prepare("INSERT INTO test1rf.client_warning (k, v0, v1) VALUES (?, ?, ?)") + + cls.warn_batch = BatchStatement() + # 213 = 5 * 1024 / (4+4 + 4+4 + 4+4) + # thresh_kb/ (min param size) + for x in range(214): + cls.warn_batch.add(cls.prepared, (x, x, 1)) + + @classmethod + def tearDownClass(cls): + if PROTOCOL_VERSION < 4: + return + + cls.cluster.shutdown() + + def setUp(self): + if PROTOCOL_VERSION < 4: + raise unittest.SkipTest( + "Native protocol 4,0+ is required for client warnings, currently using %r" + % (PROTOCOL_VERSION,)) + + def test_warning_basic(self): + """ + Test to validate that client warnings can be surfaced + + @since 2.6.0 + @jira_ticket PYTHON-315 + @expected_result valid warnings returned + @test_assumptions + - batch_size_warn_threshold_in_kb: 5 + @test_category queries:client_warning + """ + future = self.session.execute_async(self.warn_batch) + future.result() + self.assertEqual(len(future.warnings), 1) + self.assertRegexpMatches(future.warnings[0], 'Batch.*exceeding.*') + + def test_warning_with_trace(self): + """ + Test to validate client warning with tracing + + @since 2.6.0 + @jira_ticket PYTHON-315 + @expected_result valid warnings returned + @test_assumptions + - batch_size_warn_threshold_in_kb: 5 + @test_category queries:client_warning + """ + future = self.session.execute_async(self.warn_batch, trace=True) + future.result() + self.assertEqual(len(future.warnings), 1) + self.assertRegexpMatches(future.warnings[0], 'Batch.*exceeding.*') + self.assertIsNotNone(future.get_query_trace()) + + @local + def test_warning_with_custom_payload(self): + """ + Test to validate client warning with custom payload + + @since 2.6.0 + @jira_ticket 
PYTHON-315 + @expected_result valid warnings returned + @test_assumptions + - batch_size_warn_threshold_in_kb: 5 + @test_category queries:client_warning + """ + payload = {'key': b'value'} + future = self.session.execute_async(self.warn_batch, custom_payload=payload) + future.result() + self.assertEqual(len(future.warnings), 1) + self.assertRegexpMatches(future.warnings[0], 'Batch.*exceeding.*') + self.assertDictEqual(future.custom_payload, payload) + + @local + def test_warning_with_trace_and_custom_payload(self): + """ + Test to validate client warning with tracing and client warning + + @since 2.6.0 + @jira_ticket PYTHON-315 + @expected_result valid warnings returned + @test_assumptions + - batch_size_warn_threshold_in_kb: 5 + @test_category queries:client_warning + """ + payload = {'key': b'value'} + future = self.session.execute_async(self.warn_batch, trace=True, custom_payload=payload) + future.result() + self.assertEqual(len(future.warnings), 1) + self.assertRegexpMatches(future.warnings[0], 'Batch.*exceeding.*') + self.assertIsNotNone(future.get_query_trace()) + self.assertDictEqual(future.custom_payload, payload) diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py new file mode 100644 index 0000000..768fa77 --- /dev/null +++ b/tests/integration/standard/test_cluster.py @@ -0,0 +1,1564 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
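For reference, the warning and payload attributes exercised by the client-warning tests above all live on the ResponseFuture returned from ``Session.execute_async``. The following is a minimal sketch of reading them outside the test harness, assuming a reachable local node started with ``batch_size_warn_threshold_in_kb: 5`` and the ``test1rf.client_warning`` table created in the fixture above; client warnings require native protocol v4 or later::

    from cassandra.cluster import Cluster
    from cassandra.query import BatchStatement

    cluster = Cluster()          # assumes a node listening on the default 127.0.0.1 contact point
    session = cluster.connect()

    prepared = session.prepare(
        "INSERT INTO test1rf.client_warning (k, v0, v1) VALUES (?, ?, ?)")
    batch = BatchStatement()
    for x in range(214):         # enough statements to cross the assumed 5 kb warn threshold
        batch.add(prepared, (x, x, 1))

    future = session.execute_async(batch, trace=True, custom_payload={'key': b'value'})
    future.result()              # wait for the response before reading warnings/trace/payload

    print(future.warnings)          # e.g. ["Batch ... is of size ..., exceeding ..."]
    print(future.custom_payload)    # payload returned by the server (echoed back in these tests)
    print(future.get_query_trace()) # populated because trace=True was requested
    cluster.shutdown()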
+ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from collections import deque +from copy import copy +from mock import Mock, call, patch +import time +from uuid import uuid4 +import logging +import warnings +from packaging.version import Version + +import cassandra +from cassandra.cluster import Cluster, NoHostAvailable, ExecutionProfile, EXEC_PROFILE_DEFAULT +from cassandra.concurrent import execute_concurrent +from cassandra.policies import (RoundRobinPolicy, ExponentialReconnectionPolicy, + RetryPolicy, SimpleConvictionPolicy, HostDistance, + AddressTranslator, TokenAwarePolicy, HostFilterPolicy) +from cassandra import ConsistencyLevel + +from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory +from cassandra.auth import PlainTextAuthProvider, SaslAuthProvider +from cassandra import connection +from cassandra.connection import DefaultEndPoint + +from tests import notwindows +from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions, CASSANDRA_VERSION, \ + execute_until_pass, execute_with_long_wait_retry, get_node, MockLoggingHandler, get_unsupported_lower_protocol, \ + get_unsupported_upper_protocol, protocolv5, local, CASSANDRA_IP, greaterthanorequalcass30, lessthanorequalcass40 +from tests.integration.util import assert_quiescent_pool_state +import sys + + +def setup_module(): + use_singledc() + warnings.simplefilter("always") + + +class IgnoredHostPolicy(RoundRobinPolicy): + + def __init__(self, ignored_hosts): + self.ignored_hosts = ignored_hosts + RoundRobinPolicy.__init__(self) + + def distance(self, host): + if(host.address in self.ignored_hosts): + return HostDistance.IGNORED + else: + return HostDistance.LOCAL + + +class ClusterTests(unittest.TestCase): + @local + def test_ignored_host_up(self): + """ + Test to ensure that is_up is not set by default on ignored hosts + + @since 3.6 + @jira_ticket PYTHON-551 + @expected_result ignored hosts should have None set for is_up + + @test_category connection + """ + ingored_host_policy = IgnoredHostPolicy(["127.0.0.2", "127.0.0.3"]) + cluster = Cluster(protocol_version=PROTOCOL_VERSION, load_balancing_policy=ingored_host_policy) + session = cluster.connect() + for host in cluster.metadata.all_hosts(): + if str(host) == "127.0.0.1:9042": + self.assertTrue(host.is_up) + else: + self.assertIsNone(host.is_up) + cluster.shutdown() + + @local + def test_host_resolution(self): + """ + Test to insure A records are resolved appropriately. 
+ + @since 3.3 + @jira_ticket PYTHON-415 + @expected_result hostname will be transformed into IP + + @test_category connection + """ + cluster = Cluster(contact_points=["localhost"], protocol_version=PROTOCOL_VERSION, connect_timeout=1) + self.assertTrue(DefaultEndPoint('127.0.0.1') in cluster.endpoints_resolved) + + @local + def test_host_duplication(self): + """ + Ensure that duplicate hosts in the contact points are surfaced in the cluster metadata + + @since 3.3 + @jira_ticket PYTHON-103 + @expected_result duplicate hosts aren't surfaced in cluster.metadata + + @test_category connection + """ + cluster = Cluster(contact_points=["localhost", "127.0.0.1", "localhost", "localhost", "localhost"], protocol_version=PROTOCOL_VERSION, connect_timeout=1) + cluster.connect(wait_for_all_pools=True) + self.assertEqual(len(cluster.metadata.all_hosts()), 3) + cluster.shutdown() + cluster = Cluster(contact_points=["127.0.0.1", "localhost"], protocol_version=PROTOCOL_VERSION, connect_timeout=1) + cluster.connect(wait_for_all_pools=True) + self.assertEqual(len(cluster.metadata.all_hosts()), 3) + cluster.shutdown() + + @local + def test_raise_error_on_control_connection_timeout(self): + """ + Test for initial control connection timeout + + test_raise_error_on_control_connection_timeout tests that the driver times out after the set initial connection + timeout. It first pauses node1, essentially making it unreachable. It then attempts to create a Cluster object + via connecting to node1 with a timeout of 1 second, and ensures that a NoHostAvailable is raised, along with + an OperationTimedOut for 1 second. + + @expected_errors NoHostAvailable When node1 is paused, and a connection attempt is made. + @since 2.6.0 + @jira_ticket PYTHON-206 + @expected_result NoHostAvailable exception should be raised after 1 second. 
+ + @test_category connection + """ + + get_node(1).pause() + cluster = Cluster(contact_points=['127.0.0.1'], protocol_version=PROTOCOL_VERSION, connect_timeout=1) + + with self.assertRaisesRegexp(NoHostAvailable, "OperationTimedOut\('errors=Timed out creating connection \(1 seconds\)"): + cluster.connect() + cluster.shutdown() + + get_node(1).resume() + + def test_basic(self): + """ + Test basic connection and usage + """ + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + result = execute_until_pass(session, + """ + CREATE KEYSPACE clustertests + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + """) + self.assertFalse(result) + + result = execute_with_long_wait_retry(session, + """ + CREATE TABLE clustertests.cf0 ( + a text, + b text, + c text, + PRIMARY KEY (a, b) + ) + """) + self.assertFalse(result) + + result = session.execute( + """ + INSERT INTO clustertests.cf0 (a, b, c) VALUES ('a', 'b', 'c') + """) + self.assertFalse(result) + + result = session.execute("SELECT * FROM clustertests.cf0") + self.assertEqual([('a', 'b', 'c')], result) + + execute_with_long_wait_retry(session, "DROP KEYSPACE clustertests") + + cluster.shutdown() + + def test_session_host_parameter(self): + """ + Test for protocol negotiation + + Very that NoHostAvailable is risen in Session.__init__ when there are no valid connections and that + no error is arisen otherwise, despite maybe being some invalid hosts + + @since 3.9 + @jira_ticket PYTHON-665 + @expected_result NoHostAvailable when the driver is unable to connect to a valid host, + no exception otherwise + + @test_category connection + """ + def cleanup(): + """ + When this test fails, the inline .shutdown() calls don't get + called, so we register this as a cleanup. + """ + self.cluster_to_shutdown.shutdown() + self.addCleanup(cleanup) + + # Test with empty list + self.cluster_to_shutdown = Cluster([], protocol_version=PROTOCOL_VERSION) + with self.assertRaises(NoHostAvailable): + self.cluster_to_shutdown.connect() + self.cluster_to_shutdown.shutdown() + + # Test with only invalid + self.cluster_to_shutdown = Cluster(('1.2.3.4',), protocol_version=PROTOCOL_VERSION) + with self.assertRaises(NoHostAvailable): + self.cluster_to_shutdown.connect() + self.cluster_to_shutdown.shutdown() + + # Test with valid and invalid hosts + self.cluster_to_shutdown = Cluster(("127.0.0.1", "127.0.0.2", "1.2.3.4"), + protocol_version=PROTOCOL_VERSION) + self.cluster_to_shutdown.connect() + self.cluster_to_shutdown.shutdown() + + def test_protocol_negotiation(self): + """ + Test for protocol negotiation + + test_protocol_negotiation tests that the driver will select the correct protocol version to match + the correct cassandra version. Please note that 2.1.5 has a + bug https://issues.apache.org/jira/browse/CASSANDRA-9451 that will cause this test to fail + that will cause this to not pass. 
It was rectified in 2.1.6 + + @since 2.6.0 + @jira_ticket PYTHON-240 + @expected_result the correct protocol version should be selected + + @test_category connection + """ + + cluster = Cluster() + self.assertLessEqual(cluster.protocol_version, cassandra.ProtocolVersion.MAX_SUPPORTED) + session = cluster.connect() + updated_protocol_version = session._protocol_version + updated_cluster_version = cluster.protocol_version + # Make sure the correct protocol was selected by default + if CASSANDRA_VERSION >= Version('2.2'): + self.assertEqual(updated_protocol_version, 4) + self.assertEqual(updated_cluster_version, 4) + elif CASSANDRA_VERSION >= Version('2.1'): + self.assertEqual(updated_protocol_version, 3) + self.assertEqual(updated_cluster_version, 3) + elif CASSANDRA_VERSION >= Version('2.0'): + self.assertEqual(updated_protocol_version, 2) + self.assertEqual(updated_cluster_version, 2) + else: + self.assertEqual(updated_protocol_version, 1) + self.assertEqual(updated_cluster_version, 1) + + cluster.shutdown() + + def test_invalid_protocol_negotation(self): + """ + Test for protocol negotiation when explicit versions are set + + If an explicit protocol version that is not compatible with the server version is set + an exception should be thrown. It should not attempt to negotiate + + for reference supported protocol version to server versions is as follows/ + + 1.2 -> 1 + 2.0 -> 2, 1 + 2.1 -> 3, 2, 1 + 2.2 -> 4, 3, 2, 1 + 3.X -> 4, 3 + + @since 3.6.0 + @jira_ticket PYTHON-537 + @expected_result downgrading should not be allowed when explicit protocol versions are set. + + @test_category connection + """ + + upper_bound = get_unsupported_upper_protocol() + if upper_bound is not None: + cluster = Cluster(protocol_version=upper_bound) + with self.assertRaises(NoHostAvailable): + cluster.connect() + cluster.shutdown() + + lower_bound = get_unsupported_lower_protocol() + if lower_bound is not None: + cluster = Cluster(protocol_version=lower_bound) + with self.assertRaises(NoHostAvailable): + cluster.connect() + cluster.shutdown() + + def test_connect_on_keyspace(self): + """ + Ensure clusters that connect on a keyspace, do + """ + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + result = session.execute( + """ + INSERT INTO test1rf.test (k, v) VALUES (8889, 8889) + """) + self.assertFalse(result) + + result = session.execute("SELECT * FROM test1rf.test") + self.assertEqual([(8889, 8889)], result, "Rows in ResultSet are {0}".format(result.current_rows)) + + # test_connect_on_keyspace + session2 = cluster.connect('test1rf') + result2 = session2.execute("SELECT * FROM test") + self.assertEqual(result, result2) + cluster.shutdown() + + def test_set_keyspace_twice(self): + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + session.execute("USE system") + session.execute("USE system") + cluster.shutdown() + + def test_default_connections(self): + """ + Ensure errors are not thrown when using non-default policies + """ + + Cluster( + load_balancing_policy=RoundRobinPolicy(), + reconnection_policy=ExponentialReconnectionPolicy(1.0, 600.0), + default_retry_policy=RetryPolicy(), + conviction_policy_factory=SimpleConvictionPolicy, + protocol_version=PROTOCOL_VERSION + ) + + def test_connect_to_already_shutdown_cluster(self): + """ + Ensure you cannot connect to a cluster that's been shutdown + """ + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster.shutdown() + self.assertRaises(Exception, cluster.connect) + + def 
test_auth_provider_is_callable(self): + """ + Ensure that auth_providers are always callable + """ + self.assertRaises(TypeError, Cluster, auth_provider=1, protocol_version=1) + c = Cluster(protocol_version=1) + self.assertRaises(TypeError, setattr, c, 'auth_provider', 1) + + def test_v2_auth_provider(self): + """ + Check for v2 auth_provider compliance + """ + bad_auth_provider = lambda x: {'username': 'foo', 'password': 'bar'} + self.assertRaises(TypeError, Cluster, auth_provider=bad_auth_provider, protocol_version=2) + c = Cluster(protocol_version=2) + self.assertRaises(TypeError, setattr, c, 'auth_provider', bad_auth_provider) + + def test_conviction_policy_factory_is_callable(self): + """ + Ensure that conviction_policy_factory are always callable + """ + + self.assertRaises(ValueError, Cluster, conviction_policy_factory=1) + + def test_connect_to_bad_hosts(self): + """ + Ensure that a NoHostAvailable Exception is thrown + when a cluster cannot connect to given hosts + """ + + cluster = Cluster(['127.1.2.9', '127.1.2.10'], + protocol_version=PROTOCOL_VERSION) + self.assertRaises(NoHostAvailable, cluster.connect) + + def test_cluster_settings(self): + """ + Test connection setting getters and setters + """ + if PROTOCOL_VERSION >= 3: + raise unittest.SkipTest("min/max requests and core/max conns aren't used with v3 protocol") + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + + min_requests_per_connection = cluster.get_min_requests_per_connection(HostDistance.LOCAL) + self.assertEqual(cassandra.cluster.DEFAULT_MIN_REQUESTS, min_requests_per_connection) + cluster.set_min_requests_per_connection(HostDistance.LOCAL, min_requests_per_connection + 1) + self.assertEqual(cluster.get_min_requests_per_connection(HostDistance.LOCAL), min_requests_per_connection + 1) + + max_requests_per_connection = cluster.get_max_requests_per_connection(HostDistance.LOCAL) + self.assertEqual(cassandra.cluster.DEFAULT_MAX_REQUESTS, max_requests_per_connection) + cluster.set_max_requests_per_connection(HostDistance.LOCAL, max_requests_per_connection + 1) + self.assertEqual(cluster.get_max_requests_per_connection(HostDistance.LOCAL), max_requests_per_connection + 1) + + core_connections_per_host = cluster.get_core_connections_per_host(HostDistance.LOCAL) + self.assertEqual(cassandra.cluster.DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST, core_connections_per_host) + cluster.set_core_connections_per_host(HostDistance.LOCAL, core_connections_per_host + 1) + self.assertEqual(cluster.get_core_connections_per_host(HostDistance.LOCAL), core_connections_per_host + 1) + + max_connections_per_host = cluster.get_max_connections_per_host(HostDistance.LOCAL) + self.assertEqual(cassandra.cluster.DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST, max_connections_per_host) + cluster.set_max_connections_per_host(HostDistance.LOCAL, max_connections_per_host + 1) + self.assertEqual(cluster.get_max_connections_per_host(HostDistance.LOCAL), max_connections_per_host + 1) + + def test_refresh_schema(self): + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + + original_meta = cluster.metadata.keyspaces + # full schema refresh, with wait + cluster.refresh_schema_metadata() + self.assertIsNot(original_meta, cluster.metadata.keyspaces) + self.assertEqual(original_meta, cluster.metadata.keyspaces) + + cluster.shutdown() + + def test_refresh_schema_keyspace(self): + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + + original_meta = cluster.metadata.keyspaces + original_system_meta 
= original_meta['system'] + + # only refresh one keyspace + cluster.refresh_keyspace_metadata('system') + current_meta = cluster.metadata.keyspaces + self.assertIs(original_meta, current_meta) + current_system_meta = current_meta['system'] + self.assertIsNot(original_system_meta, current_system_meta) + self.assertEqual(original_system_meta.as_cql_query(), current_system_meta.as_cql_query()) + cluster.shutdown() + + def test_refresh_schema_table(self): + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + + original_meta = cluster.metadata.keyspaces + original_system_meta = original_meta['system'] + original_system_schema_meta = original_system_meta.tables['local'] + + # only refresh one table + cluster.refresh_table_metadata('system', 'local') + current_meta = cluster.metadata.keyspaces + current_system_meta = current_meta['system'] + current_system_schema_meta = current_system_meta.tables['local'] + self.assertIs(original_meta, current_meta) + self.assertIs(original_system_meta, current_system_meta) + self.assertIsNot(original_system_schema_meta, current_system_schema_meta) + self.assertEqual(original_system_schema_meta.as_cql_query(), current_system_schema_meta.as_cql_query()) + cluster.shutdown() + + def test_refresh_schema_type(self): + if get_server_versions()[0] < (2, 1, 0): + raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1') + + if PROTOCOL_VERSION < 3: + raise unittest.SkipTest('UDTs are not specified in change events for protocol v2') + # We may want to refresh types on keyspace change events in that case(?) + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + + keyspace_name = 'test1rf' + type_name = self._testMethodName + + execute_until_pass(session, 'CREATE TYPE IF NOT EXISTS %s.%s (one int, two text)' % (keyspace_name, type_name)) + original_meta = cluster.metadata.keyspaces + original_test1rf_meta = original_meta[keyspace_name] + original_type_meta = original_test1rf_meta.user_types[type_name] + + # only refresh one type + cluster.refresh_user_type_metadata('test1rf', type_name) + current_meta = cluster.metadata.keyspaces + current_test1rf_meta = current_meta[keyspace_name] + current_type_meta = current_test1rf_meta.user_types[type_name] + self.assertIs(original_meta, current_meta) + self.assertEqual(original_test1rf_meta.export_as_string(), current_test1rf_meta.export_as_string()) + self.assertIsNot(original_type_meta, current_type_meta) + self.assertEqual(original_type_meta.as_cql_query(), current_type_meta.as_cql_query()) + cluster.shutdown() + + @local + @notwindows + def test_refresh_schema_no_wait(self): + contact_points = [CASSANDRA_IP] + cluster = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=10, + contact_points=contact_points, + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == CASSANDRA_IP + )) + session = cluster.connect() + + schema_ver = session.execute("SELECT schema_version FROM system.local WHERE key='local'")[0][0] + new_schema_ver = uuid4() + session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (new_schema_ver,)) + + + try: + agreement_timeout = 1 + + # cluster agreement wait exceeded + c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=agreement_timeout) + c.connect() + self.assertTrue(c.metadata.keyspaces) + + # cluster agreement wait used for refresh + original_meta = c.metadata.keyspaces + start_time = time.time() + self.assertRaisesRegexp(Exception, r"Schema 
metadata was not refreshed.*", c.refresh_schema_metadata) + end_time = time.time() + self.assertGreaterEqual(end_time - start_time, agreement_timeout) + self.assertIs(original_meta, c.metadata.keyspaces) + + # refresh wait overrides cluster value + original_meta = c.metadata.keyspaces + start_time = time.time() + c.refresh_schema_metadata(max_schema_agreement_wait=0) + end_time = time.time() + self.assertLess(end_time - start_time, agreement_timeout) + self.assertIsNot(original_meta, c.metadata.keyspaces) + self.assertEqual(original_meta, c.metadata.keyspaces) + + c.shutdown() + + refresh_threshold = 0.5 + # cluster agreement bypass + c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=0) + start_time = time.time() + s = c.connect() + end_time = time.time() + self.assertLess(end_time - start_time, refresh_threshold) + self.assertTrue(c.metadata.keyspaces) + + # cluster agreement wait used for refresh + original_meta = c.metadata.keyspaces + start_time = time.time() + c.refresh_schema_metadata() + end_time = time.time() + self.assertLess(end_time - start_time, refresh_threshold) + self.assertIsNot(original_meta, c.metadata.keyspaces) + self.assertEqual(original_meta, c.metadata.keyspaces) + + # refresh wait overrides cluster value + original_meta = c.metadata.keyspaces + start_time = time.time() + self.assertRaisesRegexp(Exception, r"Schema metadata was not refreshed.*", c.refresh_schema_metadata, + max_schema_agreement_wait=agreement_timeout) + end_time = time.time() + self.assertGreaterEqual(end_time - start_time, agreement_timeout) + self.assertIs(original_meta, c.metadata.keyspaces) + c.shutdown() + finally: + # TODO once fixed this connect call + session = cluster.connect() + session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (schema_ver,)) + + cluster.shutdown() + + def test_trace(self): + """ + Ensure trace can be requested for async and non-async queries + """ + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + + result = session.execute( "SELECT * FROM system.local", trace=True) + self._check_trace(result.get_query_trace()) + + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + result = session.execute(statement, trace=True) + self._check_trace(result.get_query_trace()) + + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + result = session.execute(statement) + self.assertIsNone(result.get_query_trace()) + + statement2 = SimpleStatement(query) + future = session.execute_async(statement2, trace=True) + future.result() + self._check_trace(future.get_query_trace()) + + statement2 = SimpleStatement(query) + future = session.execute_async(statement2) + future.result() + self.assertIsNone(future.get_query_trace()) + + prepared = session.prepare("SELECT * FROM system.local") + future = session.execute_async(prepared, parameters=(), trace=True) + future.result() + self._check_trace(future.get_query_trace()) + cluster.shutdown() + + def test_trace_unavailable(self): + """ + First checks that TraceUnavailable is arisen if the + max_wait parameter is negative + + Then checks that TraceUnavailable is arisen if the + result hasn't been set yet + + @since 3.10 + @jira_ticket PYTHON-196 + @expected_result TraceUnavailable is arisen in both cases + + @test_category query + """ + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.addCleanup(cluster.shutdown) + session = cluster.connect() + + query = "SELECT * FROM system.local" + statement = 
SimpleStatement(query) + + max_retry_count = 10 + for i in range(max_retry_count): + future = session.execute_async(statement, trace=True) + future.result() + try: + result = future.get_query_trace(-1.0) + # In case the result has time to come back before this timeout due to a race condition + self._check_trace(result) + except TraceUnavailable: + break + else: + raise Exception("get_query_trace didn't raise TraceUnavailable after {} tries".format(max_retry_count)) + + + for i in range(max_retry_count): + future = session.execute_async(statement, trace=True) + try: + result = future.get_query_trace(max_wait=120) + # In case the result has been set check the trace + self._check_trace(result) + except TraceUnavailable: + break + else: + raise Exception("get_query_trace didn't raise TraceUnavailable after {} tries".format(max_retry_count)) + + def test_one_returns_none(self): + """ + Test ResulSet.one returns None if no rows where found + + @since 3.14 + @jira_ticket PYTHON-947 + @expected_result ResulSet.one is None + + @test_category query + """ + with Cluster() as cluster: + session = cluster.connect() + self.assertIsNone(session.execute("SELECT * from system.local WHERE key='madeup_key'").one()) + + def test_string_coverage(self): + """ + Ensure str(future) returns without error + """ + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + future = session.execute_async(statement) + + self.assertIn(query, str(future)) + future.result() + + self.assertIn(query, str(future)) + self.assertIn('result', str(future)) + cluster.shutdown() + + def test_can_connect_with_plainauth(self): + """ + Verify that we can connect setting PlainTextAuthProvider against a + C* server without authentication set. We also verify a warning is + issued per connection. This test is here instead of in test_authentication.py + because the C* server running in that module has auth set. + + @since 3.14 + @jira_ticket PYTHON-940 + @expected_result we can connect, query C* and warning are issued + + @test_category auth + """ + auth_provider = PlainTextAuthProvider( + username="made_up_username", + password="made_up_password" + ) + self._warning_are_issued_when_auth(auth_provider) + + def test_can_connect_with_sslauth(self): + """ + Verify that we can connect setting SaslAuthProvider against a + C* server without authentication set. We also verify a warning is + issued per connection. This test is here instead of in test_authentication.py + because the C* server running in that module has auth set. 
+ + @since 3.14 + @jira_ticket PYTHON-940 + @expected_result we can connect, query C* and warning are issued + + @test_category auth + """ + sasl_kwargs = {'service': 'cassandra', + 'mechanism': 'PLAIN', + 'qops': ['auth'], + 'username': "made_up_username", + 'password': "made_up_password"} + + auth_provider = SaslAuthProvider(**sasl_kwargs) + self._warning_are_issued_when_auth(auth_provider) + + def _warning_are_issued_when_auth(self, auth_provider): + with MockLoggingHandler().set_module_name(connection.__name__) as mock_handler: + with Cluster(auth_provider=auth_provider) as cluster: + session = cluster.connect() + self.assertIsNotNone(session.execute("SELECT * from system.local")) + + # Three conenctions to nodes plus the control connection + self.assertEqual(4, mock_handler.get_message_count('warning', + "An authentication challenge was not sent")) + + def test_idle_heartbeat(self): + interval = 2 + cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=interval) + if PROTOCOL_VERSION < 3: + cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) + session = cluster.connect(wait_for_all_pools=True) + + # This test relies on impl details of connection req id management to see if heartbeats + # are being sent. May need update if impl is changed + connection_request_ids = {} + for h in cluster.get_connection_holders(): + for c in h.get_connections(): + # make sure none are idle (should have startup messages + self.assertFalse(c.is_idle) + with c.lock: + connection_request_ids[id(c)] = deque(c.request_ids) # copy of request ids + + # let two heatbeat intervals pass (first one had startup messages in it) + time.sleep(2 * interval + interval/2) + + connections = [c for holders in cluster.get_connection_holders() for c in holders.get_connections()] + + # make sure requests were sent on all connections + for c in connections: + expected_ids = connection_request_ids[id(c)] + expected_ids.rotate(-1) + with c.lock: + self.assertListEqual(list(c.request_ids), list(expected_ids)) + + # assert idle status + self.assertTrue(all(c.is_idle for c in connections)) + + # send messages on all connections + statements_and_params = [("SELECT release_version FROM system.local", ())] * len(cluster.metadata.all_hosts()) + results = execute_concurrent(session, statements_and_params) + for success, result in results: + self.assertTrue(success) + + # assert not idle status + self.assertFalse(any(c.is_idle if not c.is_control_connection else False for c in connections)) + + # holders include session pools and cc + holders = cluster.get_connection_holders() + self.assertIn(cluster.control_connection, holders) + self.assertEqual(len(holders), len(cluster.metadata.all_hosts()) + 1) # hosts pools, 1 for cc + + # include additional sessions + session2 = cluster.connect(wait_for_all_pools=True) + + holders = cluster.get_connection_holders() + self.assertIn(cluster.control_connection, holders) + self.assertEqual(len(holders), 2 * len(cluster.metadata.all_hosts()) + 1) # 2 sessions' hosts pools, 1 for cc + + cluster._idle_heartbeat.stop() + cluster._idle_heartbeat.join() + assert_quiescent_pool_state(self, cluster) + + cluster.shutdown() + + @patch('cassandra.cluster.Cluster.idle_heartbeat_interval', new=0.1) + def test_idle_heartbeat_disabled(self): + self.assertTrue(Cluster.idle_heartbeat_interval) + + # heartbeat disabled with '0' + cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=0) + self.assertEqual(cluster.idle_heartbeat_interval, 0) + session = 
cluster.connect() + + # let two heatbeat intervals pass (first one had startup messages in it) + time.sleep(2 * Cluster.idle_heartbeat_interval) + + connections = [c for holders in cluster.get_connection_holders() for c in holders.get_connections()] + + # assert not idle status (should never get reset because there is not heartbeat) + self.assertFalse(any(c.is_idle for c in connections)) + + cluster.shutdown() + + def test_pool_management(self): + # Ensure that in_flight and request_ids quiesce after cluster operations + cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=0) # no idle heartbeat here, pool management is tested in test_idle_heartbeat + session = cluster.connect() + session2 = cluster.connect() + + # prepare + p = session.prepare("SELECT * FROM system.local WHERE key=?") + self.assertTrue(session.execute(p, ('local',))) + + # simple + self.assertTrue(session.execute("SELECT * FROM system.local WHERE key='local'")) + + # set keyspace + session.set_keyspace('system') + session.set_keyspace('system_traces') + + # use keyspace + session.execute('USE system') + session.execute('USE system_traces') + + # refresh schema + cluster.refresh_schema_metadata() + cluster.refresh_schema_metadata(max_schema_agreement_wait=0) + + assert_quiescent_pool_state(self, cluster) + + cluster.shutdown() + + @local + def test_profile_load_balancing(self): + """ + Tests that profile load balancing policies are honored. + + @since 3.5 + @jira_ticket PYTHON-569 + @expected_result Execution Policy should be used when applicable. + + @test_category config_profiles + """ + query = "select release_version from system.local" + node1 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == CASSANDRA_IP + ) + ) + with Cluster(execution_profiles={'node1': node1}) as cluster: + session = cluster.connect(wait_for_all_pools=True) + + # default is DCA RR for all hosts + expected_hosts = set(cluster.metadata.all_hosts()) + queried_hosts = set() + for _ in expected_hosts: + rs = session.execute(query) + queried_hosts.add(rs.response_future._current_host) + self.assertEqual(queried_hosts, expected_hosts) + + # by name we should only hit the one + expected_hosts = set(h for h in cluster.metadata.all_hosts() if h.address == CASSANDRA_IP) + queried_hosts = set() + for _ in cluster.metadata.all_hosts(): + rs = session.execute(query, execution_profile='node1') + queried_hosts.add(rs.response_future._current_host) + self.assertEqual(queried_hosts, expected_hosts) + + # use a copied instance and override the row factory + # assert last returned value can be accessed as a namedtuple so we can prove something different + named_tuple_row = rs[0] + self.assertIsInstance(named_tuple_row, tuple) + self.assertTrue(named_tuple_row.release_version) + + tmp_profile = copy(node1) + tmp_profile.row_factory = tuple_factory + queried_hosts = set() + for _ in cluster.metadata.all_hosts(): + rs = session.execute(query, execution_profile=tmp_profile) + queried_hosts.add(rs.response_future._current_host) + self.assertEqual(queried_hosts, expected_hosts) + tuple_row = rs[0] + self.assertIsInstance(tuple_row, tuple) + with self.assertRaises(AttributeError): + tuple_row.release_version + + # make sure original profile is not impacted + self.assertTrue(session.execute(query, execution_profile='node1')[0].release_version) + + def test_setting_lbp_legacy(self): + cluster = Cluster() + self.addCleanup(cluster.shutdown) + cluster.load_balancing_policy = RoundRobinPolicy() 
+ self.assertEqual( + list(cluster.load_balancing_policy.make_query_plan()), [] + ) + cluster.connect() + self.assertNotEqual( + list(cluster.load_balancing_policy.make_query_plan()), [] + ) + + def test_profile_lb_swap(self): + """ + Tests that profile load balancing policies are not shared + + Creates two LBP, runs a few queries, and validates that each LBP is execised + seperately between EP's + + @since 3.5 + @jira_ticket PYTHON-569 + @expected_result LBP should not be shared. + + @test_category config_profiles + """ + query = "select release_version from system.local" + rr1 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) + rr2 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) + exec_profiles = {'rr1': rr1, 'rr2': rr2} + with Cluster(execution_profiles=exec_profiles) as cluster: + session = cluster.connect(wait_for_all_pools=True) + + # default is DCA RR for all hosts + expected_hosts = set(cluster.metadata.all_hosts()) + rr1_queried_hosts = set() + rr2_queried_hosts = set() + + rs = session.execute(query, execution_profile='rr1') + rr1_queried_hosts.add(rs.response_future._current_host) + rs = session.execute(query, execution_profile='rr2') + rr2_queried_hosts.add(rs.response_future._current_host) + + self.assertEqual(rr2_queried_hosts, rr1_queried_hosts) + + def test_ta_lbp(self): + """ + Test that execution profiles containing token aware LBP can be added + + @since 3.5 + @jira_ticket PYTHON-569 + @expected_result Queries can run + + @test_category config_profiles + """ + query = "select release_version from system.local" + ta1 = ExecutionProfile() + with Cluster() as cluster: + session = cluster.connect() + cluster.add_execution_profile("ta1", ta1) + rs = session.execute(query, execution_profile='ta1') + + def test_clone_shared_lbp(self): + """ + Tests that profile load balancing policies are shared on clone + + Creates one LBP clones it, and ensures that the LBP is shared between + the two EP's + + @since 3.5 + @jira_ticket PYTHON-569 + @expected_result LBP is shared + + @test_category config_profiles + """ + query = "select release_version from system.local" + rr1 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) + exec_profiles = {'rr1': rr1} + with Cluster(execution_profiles=exec_profiles) as cluster: + session = cluster.connect(wait_for_all_pools=True) + self.assertGreater(len(cluster.metadata.all_hosts()), 1, "We only have one host connected at this point") + + rr1_clone = session.execution_profile_clone_update('rr1', row_factory=tuple_factory) + cluster.add_execution_profile("rr1_clone", rr1_clone) + rr1_queried_hosts = set() + rr1_clone_queried_hosts = set() + rs = session.execute(query, execution_profile='rr1') + rr1_queried_hosts.add(rs.response_future._current_host) + rs = session.execute(query, execution_profile='rr1_clone') + rr1_clone_queried_hosts.add(rs.response_future._current_host) + self.assertNotEqual(rr1_clone_queried_hosts, rr1_queried_hosts) + + def test_missing_exec_prof(self): + """ + Tests to verify that using an unknown profile raises a ValueError + + @since 3.5 + @jira_ticket PYTHON-569 + @expected_result ValueError + + @test_category config_profiles + """ + query = "select release_version from system.local" + rr1 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) + rr2 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) + exec_profiles = {'rr1': rr1, 'rr2': rr2} + with Cluster(execution_profiles=exec_profiles) as cluster: + session = cluster.connect() + with self.assertRaises(ValueError): + 
session.execute(query, execution_profile='rr3') + + @local + def test_profile_pool_management(self): + """ + Tests that changes to execution profiles correctly impact our cluster's pooling + + @since 3.5 + @jira_ticket PYTHON-569 + @expected_result pools should be correctly updated as EP's are added and removed + + @test_category config_profiles + """ + + node1 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == "127.0.0.1" + ) + ) + node2 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == "127.0.0.2" + ) + ) + with Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: node1, 'node2': node2}) as cluster: + session = cluster.connect(wait_for_all_pools=True) + pools = session.get_pool_state() + # there are more hosts, but we connected to the ones in the lbp aggregate + self.assertGreater(len(cluster.metadata.all_hosts()), 2) + self.assertEqual(set(h.address for h in pools), set(('127.0.0.1', '127.0.0.2'))) + + # dynamically update pools on add + node3 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == "127.0.0.3" + ) + ) + cluster.add_execution_profile('node3', node3) + pools = session.get_pool_state() + self.assertEqual(set(h.address for h in pools), set(('127.0.0.1', '127.0.0.2', '127.0.0.3'))) + + @local + def test_add_profile_timeout(self): + """ + Tests that EP Timeouts are honored. + + @since 3.5 + @jira_ticket PYTHON-569 + @expected_result EP timeouts should override defaults + + @test_category config_profiles + """ + max_retry_count = 10 + for i in range(max_retry_count): + node1 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == "127.0.0.1" + ) + ) + with Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: node1}) as cluster: + session = cluster.connect(wait_for_all_pools=True) + pools = session.get_pool_state() + self.assertGreater(len(cluster.metadata.all_hosts()), 2) + self.assertEqual(set(h.address for h in pools), set(('127.0.0.1',))) + + node2 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address in ["127.0.0.2", "127.0.0.3"] + ) + ) + + start = time.time() + try: + self.assertRaises(cassandra.OperationTimedOut, cluster.add_execution_profile, + 'profile_{0}'.format(i), + node2, pool_wait_timeout=sys.float_info.min) + break + except AssertionError: + end = time.time() + self.assertAlmostEqual(start, end, 1) + else: + raise Exception("add_execution_profile didn't timeout after {0} retries".format(max_retry_count)) + + def test_replicas_are_queried(self): + """ + Test that replicas are queried first for TokenAwarePolicy. A table with RF 1 + is created. All the queries should go to that replica when TokenAwarePolicy + is used. + Then using HostFilterPolicy the replica is excluded from the considered hosts. + By checking the trace we verify that there are no more replicas. 
+ + @since 3.5 + @jira_ticket PYTHON-653 + @expected_result the replicas are queried for HostFilterPolicy + + @test_category metadata + """ + queried_hosts = set() + with Cluster(protocol_version=PROTOCOL_VERSION, + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy())) as cluster: + session = cluster.connect(wait_for_all_pools=True) + session.execute(''' + CREATE TABLE test1rf.table_with_big_key ( + k1 int, + k2 int, + k3 int, + k4 int, + PRIMARY KEY((k1, k2, k3), k4))''') + prepared = session.prepare("""SELECT * from test1rf.table_with_big_key + WHERE k1 = ? AND k2 = ? AND k3 = ? AND k4 = ?""") + for i in range(10): + result = session.execute(prepared, (i, i, i, i), trace=True) + trace = result.response_future.get_query_trace(query_cl=ConsistencyLevel.ALL) + queried_hosts = self._assert_replica_queried(trace, only_replicas=True) + last_i = i + + only_replica = queried_hosts.pop() + log = logging.getLogger(__name__) + log.info("The only replica found was: {}".format(only_replica)) + available_hosts = [host for host in ["127.0.0.1", "127.0.0.2", "127.0.0.3"] if host != only_replica] + with Cluster(contact_points=available_hosts, + protocol_version=PROTOCOL_VERSION, + load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(), + predicate=lambda host: host.address != only_replica)) as cluster: + + session = cluster.connect(wait_for_all_pools=True) + prepared = session.prepare("""SELECT * from test1rf.table_with_big_key + WHERE k1 = ? AND k2 = ? AND k3 = ? AND k4 = ?""") + for _ in range(10): + result = session.execute(prepared, (last_i, last_i, last_i, last_i), trace=True) + trace = result.response_future.get_query_trace(query_cl=ConsistencyLevel.ALL) + self._assert_replica_queried(trace, only_replicas=False) + + session.execute('''DROP TABLE test1rf.table_with_big_key''') + + @unittest.skip + @greaterthanorequalcass30 + @lessthanorequalcass40 + def test_compact_option(self): + """ + Test the driver can connect with the no_compact option and the results + are as expected. 
This test is very similar to the corresponding dtest + + @since 3.12 + @jira_ticket PYTHON-366 + @expected_result only one hosts' metadata will be populated + + @test_category connection + """ + nc_cluster = Cluster(protocol_version=PROTOCOL_VERSION, no_compact=True) + nc_session = nc_cluster.connect() + + cluster = Cluster(protocol_version=PROTOCOL_VERSION, no_compact=False) + session = cluster.connect() + + self.addCleanup(cluster.shutdown) + self.addCleanup(nc_cluster.shutdown) + + nc_session.set_keyspace("test3rf") + session.set_keyspace("test3rf") + + nc_session.execute( + "CREATE TABLE IF NOT EXISTS compact_table (k int PRIMARY KEY, v1 int, v2 int) WITH COMPACT STORAGE;") + + for i in range(1, 5): + nc_session.execute( + "INSERT INTO compact_table (k, column1, v1, v2, value) VALUES " + "({i}, 'a{i}', {i}, {i}, textAsBlob('b{i}'))".format(i=i)) + nc_session.execute( + "INSERT INTO compact_table (k, column1, v1, v2, value) VALUES " + "({i}, 'a{i}{i}', {i}{i}, {i}{i}, textAsBlob('b{i}{i}'))".format(i=i)) + + nc_results = nc_session.execute("SELECT * FROM compact_table") + self.assertEqual( + set(nc_results.current_rows), + {(1, u'a1', 11, 11, 'b1'), + (1, u'a11', 11, 11, 'b11'), + (2, u'a2', 22, 22, 'b2'), + (2, u'a22', 22, 22, 'b22'), + (3, u'a3', 33, 33, 'b3'), + (3, u'a33', 33, 33, 'b33'), + (4, u'a4', 44, 44, 'b4'), + (4, u'a44', 44, 44, 'b44')}) + + results = session.execute("SELECT * FROM compact_table") + self.assertEqual( + set(results.current_rows), + {(1, 11, 11), + (2, 22, 22), + (3, 33, 33), + (4, 44, 44)}) + + def _assert_replica_queried(self, trace, only_replicas=True): + queried_hosts = set() + for row in trace.events: + queried_hosts.add(row.source) + if only_replicas: + self.assertEqual(len(queried_hosts), 1, "The hosts queried where {}".format(queried_hosts)) + else: + self.assertGreater(len(queried_hosts), 1, "The host queried was {}".format(queried_hosts)) + return queried_hosts + + def _check_trace(self, trace): + self.assertIsNotNone(trace.request_type) + self.assertIsNotNone(trace.duration) + self.assertIsNotNone(trace.started_at) + self.assertIsNotNone(trace.coordinator) + self.assertIsNotNone(trace.events) + + +class LocalHostAdressTranslator(AddressTranslator): + + def __init__(self, addr_map=None): + self.addr_map = addr_map + + def translate(self, addr): + new_addr = self.addr_map.get(addr) + return new_addr + +@local +class TestAddressTranslation(unittest.TestCase): + + def test_address_translator_basic(self): + """ + Test host address translation + + Uses a custom Address Translator to map all ip back to one. 
+ Validates AddressTranslator invocation by ensuring that only meta data associated with single + host is populated + + @since 3.3 + @jira_ticket PYTHON-69 + @expected_result only one hosts' metadata will be populated + + @test_category metadata + """ + lh_ad = LocalHostAdressTranslator({'127.0.0.1': '127.0.0.1', '127.0.0.2': '127.0.0.1', '127.0.0.3': '127.0.0.1'}) + c = Cluster(address_translator=lh_ad) + c.connect() + self.assertEqual(len(c.metadata.all_hosts()), 1) + c.shutdown() + + def test_address_translator_with_mixed_nodes(self): + """ + Test host address translation + + Uses a custom Address Translator to map ip's of non control_connection nodes to each other + Validates AddressTranslator invocation by ensuring that metadata for mapped hosts is also mapped + + @since 3.3 + @jira_ticket PYTHON-69 + @expected_result metadata for crossed hosts will also be crossed + + @test_category metadata + """ + adder_map = {'127.0.0.1': '127.0.0.1', '127.0.0.2': '127.0.0.3', '127.0.0.3': '127.0.0.2'} + lh_ad = LocalHostAdressTranslator(adder_map) + c = Cluster(address_translator=lh_ad) + c.connect() + for host in c.metadata.all_hosts(): + self.assertEqual(adder_map.get(host.address), host.broadcast_address) + c.shutdown() + +@local +class ContextManagementTest(unittest.TestCase): + load_balancing_policy = HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == CASSANDRA_IP + ) + cluster_kwargs = {'execution_profiles': {EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy= + load_balancing_policy)}, + 'schema_metadata_enabled': False, + 'token_metadata_enabled': False} + + def test_no_connect(self): + """ + Test cluster context without connecting. + + @since 3.4 + @jira_ticket PYTHON-521 + @expected_result context should still be valid + + @test_category configuration + """ + with Cluster() as cluster: + self.assertFalse(cluster.is_shutdown) + self.assertTrue(cluster.is_shutdown) + + def test_simple_nested(self): + """ + Test cluster and session contexts nested in one another. + + @since 3.4 + @jira_ticket PYTHON-521 + @expected_result cluster/session should be crated and shutdown appropriately. + + @test_category configuration + """ + with Cluster(**self.cluster_kwargs) as cluster: + with cluster.connect() as session: + self.assertFalse(cluster.is_shutdown) + self.assertFalse(session.is_shutdown) + self.assertTrue(session.execute('select release_version from system.local')[0]) + self.assertTrue(session.is_shutdown) + self.assertTrue(cluster.is_shutdown) + + def test_cluster_no_session(self): + """ + Test cluster context without session context. + + @since 3.4 + @jira_ticket PYTHON-521 + @expected_result Session should be created correctly. Cluster should shutdown outside of context + + @test_category configuration + """ + with Cluster(**self.cluster_kwargs) as cluster: + session = cluster.connect() + self.assertFalse(cluster.is_shutdown) + self.assertFalse(session.is_shutdown) + self.assertTrue(session.execute('select release_version from system.local')[0]) + self.assertTrue(session.is_shutdown) + self.assertTrue(cluster.is_shutdown) + + def test_session_no_cluster(self): + """ + Test session context without cluster context. + + @since 3.4 + @jira_ticket PYTHON-521 + @expected_result session should be created correctly. 
Session should shutdown correctly outside of context + + @test_category configuration + """ + cluster = Cluster(**self.cluster_kwargs) + unmanaged_session = cluster.connect() + with cluster.connect() as session: + self.assertFalse(cluster.is_shutdown) + self.assertFalse(session.is_shutdown) + self.assertFalse(unmanaged_session.is_shutdown) + self.assertTrue(session.execute('select release_version from system.local')[0]) + self.assertTrue(session.is_shutdown) + self.assertFalse(cluster.is_shutdown) + self.assertFalse(unmanaged_session.is_shutdown) + unmanaged_session.shutdown() + self.assertTrue(unmanaged_session.is_shutdown) + self.assertFalse(cluster.is_shutdown) + cluster.shutdown() + self.assertTrue(cluster.is_shutdown) + + +class HostStateTest(unittest.TestCase): + + def test_down_event_with_active_connection(self): + """ + Test to ensure that on down calls to clusters with connections still active don't result in + a host being marked down. The second part of the test kills the connection then invokes + on_down, and ensures the state changes for host's metadata. + + @since 3.7 + @jira_ticket PYTHON-498 + @expected_result host should never be toggled down while a connection is active. + + @test_category connection + """ + with Cluster(protocol_version=PROTOCOL_VERSION) as cluster: + session = cluster.connect(wait_for_all_pools=True) + random_host = cluster.metadata.all_hosts()[0] + cluster.on_down(random_host, False) + for _ in range(10): + new_host = cluster.metadata.all_hosts()[0] + self.assertTrue(new_host.is_up, "Host was not up on iteration {0}".format(_)) + time.sleep(.01) + + pool = session._pools.get(random_host) + pool.shutdown() + cluster.on_down(random_host, False) + was_marked_down = False + for _ in range(20): + new_host = cluster.metadata.all_hosts()[0] + if not new_host.is_up: + was_marked_down = True + break + time.sleep(.01) + self.assertTrue(was_marked_down) + + +@local +class DontPrepareOnIgnoredHostsTest(unittest.TestCase): + ignored_addresses = ['127.0.0.3'] + ignore_node_3_policy = IgnoredHostPolicy(ignored_addresses) + + def test_prepare_on_ignored_hosts(self): + + cluster = Cluster(protocol_version=PROTOCOL_VERSION, + load_balancing_policy=self.ignore_node_3_policy) + session = cluster.connect() + cluster.reprepare_on_up, cluster.prepare_on_all_hosts = True, False + + hosts = cluster.metadata.all_hosts() + session.execute("CREATE KEYSPACE clustertests " + "WITH replication = " + "{'class': 'SimpleStrategy', 'replication_factor': '1'}") + session.execute("CREATE TABLE clustertests.tab (a text, PRIMARY KEY (a))") + # assign to an unused variable so cluster._prepared_statements retains + # reference + _ = session.prepare("INSERT INTO clustertests.tab (a) VALUES ('a')") # noqa + + cluster.connection_factory = Mock(wraps=cluster.connection_factory) + + unignored_address = '127.0.0.1' + unignored_host = next(h for h in hosts if h.address == unignored_address) + ignored_host = next(h for h in hosts if h.address in self.ignored_addresses) + unignored_host.is_up = ignored_host.is_up = False + + cluster.on_up(unignored_host) + cluster.on_up(ignored_host) + + # the length of mock_calls will vary, but all should use the unignored + # address + for c in cluster.connection_factory.mock_calls: + self.assertEqual(call(DefaultEndPoint(unignored_address)), c) + cluster.shutdown() + + +@local +class DuplicateRpcTest(unittest.TestCase): + + load_balancing_policy = HostFilterPolicy(RoundRobinPolicy(), + lambda host: host.address == "127.0.0.1") + + def setUp(self): + self.cluster = 
Cluster(protocol_version=PROTOCOL_VERSION, load_balancing_policy=self.load_balancing_policy) + self.session = self.cluster.connect() + self.session.execute("UPDATE system.peers SET rpc_address = '127.0.0.1' WHERE peer='127.0.0.2'") + + def tearDown(self): + self.session.execute("UPDATE system.peers SET rpc_address = '127.0.0.2' WHERE peer='127.0.0.2'") + self.cluster.shutdown() + + def test_duplicate(self): + """ + Test duplicate RPC addresses. + + Modifies the system.peers table to make hosts have the same rpc address. Ensures such hosts are filtered out and a message is logged + + @since 3.4 + @jira_ticket PYTHON-366 + @expected_result only one hosts' metadata will be populated + + @test_category metadata + """ + mock_handler = MockLoggingHandler() + logger = logging.getLogger(cassandra.cluster.__name__) + logger.addHandler(mock_handler) + test_cluster = Cluster(protocol_version=PROTOCOL_VERSION, load_balancing_policy=self.load_balancing_policy) + test_cluster.connect() + warnings = mock_handler.messages.get("warning") + self.assertEqual(len(warnings), 1) + self.assertTrue('multiple' in warnings[0]) + logger.removeHandler(mock_handler) + test_cluster.shutdown() + + +@protocolv5 +class BetaProtocolTest(unittest.TestCase): + + @protocolv5 + def test_invalid_protocol_version_beta_option(self): + """ + Test cluster connection with protocol v5 and beta flag not set + + @since 3.7.0 + @jira_ticket PYTHON-614 + @expected_result client shouldn't connect with V5 and no beta flag set + + @test_category connection + """ + + cluster = Cluster(protocol_version=cassandra.ProtocolVersion.MAX_SUPPORTED, allow_beta_protocol_version=False) + try: + with self.assertRaises(NoHostAvailable): + cluster.connect() + except Exception as e: + self.fail("Unexpected error encountered {0}".format(e.message)) + + @protocolv5 + def test_valid_protocol_version_beta_options_connect(self): + """ + Test cluster connection with protocol version 5 and beta flag set + + @since 3.7.0 + @jira_ticket PYTHON-614 + @expected_result client should connect with protocol v5 and beta flag set. + + @test_category connection + """ + cluster = Cluster(protocol_version=cassandra.ProtocolVersion.MAX_SUPPORTED, allow_beta_protocol_version=True) + session = cluster.connect() + self.assertEqual(cluster.protocol_version, cassandra.ProtocolVersion.MAX_SUPPORTED) + self.assertTrue(session.execute("select release_version from system.local")[0]) + cluster.shutdown() + + +class DeprecationWarningTest(unittest.TestCase): + def test_deprecation_warnings_legacy_parameters(self): + """ + Tests the deprecation warning has been added when using + legacy parameters + + @since 3.13 + @jira_ticket PYTHON-877 + @expected_result the deprecation warning is emitted + + @test_category logs + """ + with warnings.catch_warnings(record=True) as w: + Cluster(load_balancing_policy=RoundRobinPolicy()) + self.assertEqual(len(w), 1) + self.assertIn("Legacy execution parameters will be removed in 4.0. 
Consider using execution profiles.", + str(w[0].message)) + + def test_deprecation_warnings_meta_refreshed(self): + """ + Tests the deprecation warning has been added when enabling + metadata refreshment + + @since 3.13 + @jira_ticket PYTHON-890 + @expected_result the deprecation warning is emitted + + @test_category logs + """ + with warnings.catch_warnings(record=True) as w: + cluster = Cluster() + cluster.set_meta_refresh_enabled(True) + self.assertEqual(len(w), 1) + self.assertIn("Cluster.set_meta_refresh_enabled is deprecated and will be removed in 4.0.", + str(w[0].message)) + + def test_deprecation_warning_default_consistency_level(self): + """ + Tests the deprecation warning has been added when enabling + session the default consistency level to session + + @since 3.14 + @jira_ticket PYTHON-935 + @expected_result the deprecation warning is emitted + + @test_category logs + """ + with warnings.catch_warnings(record=True) as w: + cluster = Cluster() + session = cluster.connect() + session.default_consistency_level = ConsistencyLevel.ONE + self.assertEqual(len(w), 1) + self.assertIn("Setting the consistency level at the session level will be removed in 4.0", + str(w[0].message)) diff --git a/tests/integration/standard/test_concurrent.py b/tests/integration/standard/test_concurrent.py new file mode 100644 index 0000000..c85bb64 --- /dev/null +++ b/tests/integration/standard/test_concurrent.py @@ -0,0 +1,309 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
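Before the concurrent-execution tests, a minimal sketch of the migration that the deprecation tests above point to: moving a legacy ``load_balancing_policy`` keyword onto an execution profile. It uses only classes already imported in ``test_cluster.py``; the particular policy chosen here is illustrative, not prescriptive::

    from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
    from cassandra.policies import RoundRobinPolicy, TokenAwarePolicy

    # Legacy form -- emits the DeprecationWarning asserted in DeprecationWarningTest:
    #   cluster = Cluster(load_balancing_policy=RoundRobinPolicy())

    # Profile form -- the same policy is attached to the default execution profile instead.
    profile = ExecutionProfile(
        load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()))
    cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: profile})
    session = cluster.connect()
    print(session.execute("SELECT release_version FROM system.local")[0])
    cluster.shutdown()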
+ +from itertools import cycle +from six import next +import sys, logging, traceback + +from cassandra import InvalidRequest, ConsistencyLevel, ReadTimeout, WriteTimeout, OperationTimedOut, \ + ReadFailure, WriteFailure +from cassandra.cluster import Cluster +from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args, ExecutionResult +from cassandra.policies import HostDistance +from cassandra.query import tuple_factory, SimpleStatement + +from tests.integration import use_singledc, PROTOCOL_VERSION + +from six import next + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +log = logging.getLogger(__name__) + + +def setup_module(): + use_singledc() + + +class ClusterTests(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + if PROTOCOL_VERSION < 3: + cls.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) + cls.session = cls.cluster.connect() + cls.session.row_factory = tuple_factory + + @classmethod + def tearDownClass(cls): + cls.cluster.shutdown() + + def execute_concurrent_helper(self, session, query, results_generator=False): + count = 0 + while count < 100: + try: + return execute_concurrent(session, query, results_generator=False) + except (ReadTimeout, WriteTimeout, OperationTimedOut, ReadFailure, WriteFailure): + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + count += 1 + + raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(query)) + + def execute_concurrent_args_helper(self, session, query, params, results_generator=False): + count = 0 + while count < 100: + try: + return execute_concurrent_with_args(session, query, params, results_generator=results_generator) + except (ReadTimeout, WriteTimeout, OperationTimedOut, ReadFailure, WriteFailure): + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + + raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(query)) + + def test_execute_concurrent(self): + for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201): + # write + statement = SimpleStatement( + "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", + consistency_level=ConsistencyLevel.QUORUM) + statements = cycle((statement, )) + parameters = [(i, i) for i in range(num_statements)] + + results = self.execute_concurrent_helper(self.session, list(zip(statements, parameters))) + self.assertEqual(num_statements, len(results)) + for success, result in results: + self.assertTrue(success) + self.assertFalse(result) + + # read + statement = SimpleStatement( + "SELECT v FROM test3rf.test WHERE k=%s", + consistency_level=ConsistencyLevel.QUORUM) + statements = cycle((statement, )) + parameters = [(i, ) for i in range(num_statements)] + + results = self.execute_concurrent_helper(self.session, list(zip(statements, parameters))) + self.assertEqual(num_statements, len(results)) + self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results) + + def test_execute_concurrent_with_args(self): + for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201): + statement = SimpleStatement( + "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", + consistency_level=ConsistencyLevel.QUORUM) + parameters = [(i, i) for i in range(num_statements)] + + results = self.execute_concurrent_args_helper(self.session, statement, 
parameters) + self.assertEqual(num_statements, len(results)) + for success, result in results: + self.assertTrue(success) + self.assertFalse(result) + + # read + statement = SimpleStatement( + "SELECT v FROM test3rf.test WHERE k=%s", + consistency_level=ConsistencyLevel.QUORUM) + parameters = [(i, ) for i in range(num_statements)] + + results = self.execute_concurrent_args_helper(self.session, statement, parameters) + self.assertEqual(num_statements, len(results)) + self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results) + + def test_execute_concurrent_with_args_generator(self): + """ + Test to validate that generator based results are surfaced correctly + + Repeatedly inserts data into a a table and attempts to query it. It then validates that the + results are returned in the order expected + + @since 2.7.0 + @jira_ticket PYTHON-123 + @expected_result all data should be returned in order. + + @test_category queries:async + """ + for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201): + statement = SimpleStatement( + "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", + consistency_level=ConsistencyLevel.QUORUM) + parameters = [(i, i) for i in range(num_statements)] + + results = self.execute_concurrent_args_helper(self.session, statement, parameters, results_generator=True) + for success, result in results: + self.assertTrue(success) + self.assertFalse(result) + + results = self.execute_concurrent_args_helper(self.session, statement, parameters, results_generator=True) + for result in results: + self.assertTrue(isinstance(result, ExecutionResult)) + self.assertTrue(result.success) + self.assertFalse(result.result_or_exc) + + # read + statement = SimpleStatement( + "SELECT v FROM test3rf.test WHERE k=%s", + consistency_level=ConsistencyLevel.QUORUM) + parameters = [(i, ) for i in range(num_statements)] + + results = self.execute_concurrent_args_helper(self.session, statement, parameters, results_generator=True) + + for i in range(num_statements): + result = next(results) + self.assertEqual((True, [(i,)]), result) + self.assertRaises(StopIteration, next, results) + + def test_execute_concurrent_paged_result(self): + if PROTOCOL_VERSION < 2: + raise unittest.SkipTest( + "Protocol 2+ is required for Paging, currently testing against %r" + % (PROTOCOL_VERSION,)) + + num_statements = 201 + statement = SimpleStatement( + "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", + consistency_level=ConsistencyLevel.QUORUM) + parameters = [(i, i) for i in range(num_statements)] + + results = self.execute_concurrent_args_helper(self.session, statement, parameters) + self.assertEqual(num_statements, len(results)) + for success, result in results: + self.assertTrue(success) + self.assertFalse(result) + + # read + statement = SimpleStatement( + "SELECT * FROM test3rf.test LIMIT %s", + consistency_level=ConsistencyLevel.QUORUM, + fetch_size=int(num_statements / 2)) + + results = self.execute_concurrent_args_helper(self.session, statement, [(num_statements,)]) + self.assertEqual(1, len(results)) + self.assertTrue(results[0][0]) + result = results[0][1] + self.assertTrue(result.has_more_pages) + self.assertEqual(num_statements, sum(1 for _ in result)) + + def test_execute_concurrent_paged_result_generator(self): + """ + Test to validate that generator based results are surfaced correctly when paging is used + + Inserts data into a a table and attempts to query it. 
It then validates that the + results are returned as expected (no order specified) + + @since 2.7.0 + @jira_ticket PYTHON-123 + @expected_result all data should be returned in order. + + @test_category paging + """ + if PROTOCOL_VERSION < 2: + raise unittest.SkipTest( + "Protocol 2+ is required for Paging, currently testing against %r" + % (PROTOCOL_VERSION,)) + + num_statements = 201 + statement = SimpleStatement( + "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", + consistency_level=ConsistencyLevel.QUORUM) + parameters = [(i, i) for i in range(num_statements)] + + results = self.execute_concurrent_args_helper(self.session, statement, parameters, results_generator=True) + self.assertEqual(num_statements, sum(1 for _ in results)) + + # read + statement = SimpleStatement( + "SELECT * FROM test3rf.test LIMIT %s", + consistency_level=ConsistencyLevel.QUORUM, + fetch_size=int(num_statements / 2)) + + paged_results_gen = self.execute_concurrent_args_helper(self.session, statement, [(num_statements,)], results_generator=True) + + # iterate over all the result and make sure we find the correct number. + found_results = 0 + for result_tuple in paged_results_gen: + paged_result = result_tuple[1] + for _ in paged_result: + found_results += 1 + + self.assertEqual(found_results, num_statements) + + def test_first_failure(self): + statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", )) + parameters = [(i, i) for i in range(100)] + + # we'll get an error back from the server + parameters[57] = ('efefef', 'awefawefawef') + + self.assertRaises( + InvalidRequest, + execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True) + + def test_first_failure_client_side(self): + statement = SimpleStatement( + "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", + consistency_level=ConsistencyLevel.QUORUM) + statements = cycle((statement, )) + parameters = [(i, i) for i in range(100)] + + # the driver will raise an error when binding the params + parameters[57] = 1 + + self.assertRaises( + TypeError, + execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True) + + def test_no_raise_on_first_failure(self): + statement = SimpleStatement( + "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", + consistency_level=ConsistencyLevel.QUORUM) + statements = cycle((statement, )) + parameters = [(i, i) for i in range(100)] + + # we'll get an error back from the server + parameters[57] = ('efefef', 'awefawefawef') + + results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False) + for i, (success, result) in enumerate(results): + if i == 57: + self.assertFalse(success) + self.assertIsInstance(result, InvalidRequest) + else: + self.assertTrue(success) + self.assertFalse(result) + + def test_no_raise_on_first_failure_client_side(self): + statement = SimpleStatement( + "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", + consistency_level=ConsistencyLevel.QUORUM) + statements = cycle((statement, )) + parameters = [(i, i) for i in range(100)] + + # the driver will raise an error when binding the params + parameters[57] = 1 + + results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False) + for i, (success, result) in enumerate(results): + if i == 57: + self.assertFalse(success) + self.assertIsInstance(result, TypeError) + else: + self.assertTrue(success) + self.assertFalse(result) diff --git a/tests/integration/standard/test_connection.py 
b/tests/integration/standard/test_connection.py new file mode 100644 index 0000000..595fc12 --- /dev/null +++ b/tests/integration/standard/test_connection.py @@ -0,0 +1,462 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from functools import partial +import logging +from six.moves import range +import sys +import threading +from threading import Thread, Event +import time +from unittest import SkipTest + +from cassandra import ConsistencyLevel, OperationTimedOut +from cassandra.cluster import NoHostAvailable, ConnectionShutdown, Cluster +import cassandra.io.asyncorereactor +from cassandra.io.asyncorereactor import AsyncoreConnection +from cassandra.protocol import QueryMessage +from cassandra.connection import Connection +from cassandra.policies import HostFilterPolicy, RoundRobinPolicy, HostStateListener +from cassandra.pool import HostConnectionPool + +from tests import is_monkey_patched +from tests.integration import use_singledc, PROTOCOL_VERSION, get_node, CASSANDRA_IP, local, \ + requiresmallclockgranularity, greaterthancass20 +try: + from cassandra.io.libevreactor import LibevConnection + import cassandra.io.libevreactor +except ImportError: + LibevConnection = None + + +log = logging.getLogger(__name__) + + +def setup_module(): + use_singledc() + + +class ConnectionTimeoutTest(unittest.TestCase): + + def setUp(self): + self.defaultInFlight = Connection.max_in_flight + Connection.max_in_flight = 2 + self.cluster = Cluster( + protocol_version=PROTOCOL_VERSION, + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), predicate=lambda host: host.address == CASSANDRA_IP + ) + ) + self.session = self.cluster.connect() + + def tearDown(self): + Connection.max_in_flight = self.defaultInFlight + self.cluster.shutdown() + + def test_in_flight_timeout(self): + """ + Test to ensure that connection id fetching will block when max_id is reached/ + + In previous versions of the driver this test will cause a + NoHostAvailable exception to be thrown, when the max_id is restricted + + @since 3.3 + @jira_ticket PYTHON-514 + @expected_result When many requests are run on a single node connection acquisition should block + until connection is available or the request times out. + + @test_category connection timeout + """ + futures = [] + query = '''SELECT * FROM system.local''' + for i in range(100): + futures.append(self.session.execute_async(query)) + + for future in futures: + future.result() + + +class TestHostListener(HostStateListener): + host_down = None + + def on_down(self, host): + self.host_down = True + + def on_up(self, host): + self.host_down = False + + +class HeartbeatTest(unittest.TestCase): + """ + Test to validate failing a heartbeat check doesn't mark a host as down + + @since 3.3 + @jira_ticket PYTHON-286 + @expected_result host should be marked down when heartbeat fails. 
This + happens after PYTHON-734 + + @test_category connection heartbeat + """ + + def setUp(self): + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=1) + self.session = self.cluster.connect(wait_for_all_pools=True) + + def tearDown(self): + self.cluster.shutdown() + + @local + @greaterthancass20 + def test_heart_beat_timeout(self): + # Setup a host listener to ensure the nodes don't go down + test_listener = TestHostListener() + host = "127.0.0.1:9042" + node = get_node(1) + initial_connections = self.fetch_connections(host, self.cluster) + self.assertNotEqual(len(initial_connections), 0) + self.cluster.register_listener(test_listener) + # Pause the node + try: + node.pause() + # Wait for connections associated with this host go away + self.wait_for_no_connections(host, self.cluster) + + # Wait to seconds for the driver to be notified + time.sleep(2) + self.assertTrue(test_listener.host_down) + # Resume paused node + finally: + node.resume() + # Run a query to ensure connections are re-established + current_host = "" + count = 0 + while current_host != host and count < 100: + rs = self.session.execute_async("SELECT * FROM system.local", trace=False) + rs.result() + current_host = str(rs._current_host) + count += 1 + time.sleep(.1) + self.assertLess(count, 100, "Never connected to the first node") + new_connections = self.wait_for_connections(host, self.cluster) + self.assertFalse(test_listener.host_down) + # Make sure underlying new connections don't match previous ones + for connection in initial_connections: + self.assertFalse(connection in new_connections) + + def fetch_connections(self, host, cluster): + # Given a cluster object and host grab all connection associated with that host + connections = [] + holders = cluster.get_connection_holders() + for conn in holders: + if host == str(getattr(conn, 'host', '')): + if isinstance(conn, HostConnectionPool): + if conn._connections is not None and len(conn._connections) > 0: + connections.append(conn._connections) + else: + if conn._connection is not None: + connections.append(conn._connection) + return connections + + def wait_for_connections(self, host, cluster): + retry = 0 + while(retry < 300): + retry += 1 + connections = self.fetch_connections(host, cluster) + if len(connections) is not 0: + return connections + time.sleep(.1) + self.fail("No new connections found") + + def wait_for_no_connections(self, host, cluster): + retry = 0 + while(retry < 100): + retry += 1 + connections = self.fetch_connections(host, cluster) + if len(connections) is 0: + return + time.sleep(.5) + self.fail("Connections never cleared") + + +class ConnectionTests(object): + + klass = None + + def setUp(self): + self.klass.initialize_reactor() + + def get_connection(self, timeout=5): + """ + Helper method to solve automated testing issues within Jenkins. + Officially patched under the 2.0 branch through + 17998ef72a2fe2e67d27dd602b6ced33a58ad8ef, but left as is for the + 1.0 branch due to possible regressions for fixing an + automated testing edge-case. + """ + conn = None + e = None + for i in range(5): + try: + contact_point = CASSANDRA_IP + conn = self.klass.factory(endpoint=contact_point, timeout=timeout, protocol_version=PROTOCOL_VERSION) + break + except (OperationTimedOut, NoHostAvailable, ConnectionShutdown) as e: + continue + + if conn: + return conn + else: + raise e + + def test_single_connection(self): + """ + Test a single connection with sequential requests. 
+ """ + conn = self.get_connection() + query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1" + event = Event() + + def cb(count, *args, **kwargs): + count += 1 + if count >= 10: + conn.close() + event.set() + else: + conn.send_msg( + QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE), + request_id=0, + cb=partial(cb, count)) + + conn.send_msg( + QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE), + request_id=0, + cb=partial(cb, 0)) + event.wait() + + def test_single_connection_pipelined_requests(self): + """ + Test a single connection with pipelined requests. + """ + conn = self.get_connection() + query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1" + responses = [False] * 100 + event = Event() + + def cb(response_list, request_num, *args, **kwargs): + response_list[request_num] = True + if all(response_list): + conn.close() + event.set() + + for i in range(100): + conn.send_msg( + QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE), + request_id=i, + cb=partial(cb, responses, i)) + + event.wait() + + def test_multiple_connections(self): + """ + Test multiple connections with pipelined requests. + """ + conns = [self.get_connection() for i in range(5)] + events = [Event() for i in range(5)] + query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1" + + def cb(event, conn, count, *args, **kwargs): + count += 1 + if count >= 10: + conn.close() + event.set() + else: + conn.send_msg( + QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE), + request_id=count, + cb=partial(cb, event, conn, count)) + + for event, conn in zip(events, conns): + conn.send_msg( + QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE), + request_id=0, + cb=partial(cb, event, conn, 0)) + + for event in events: + event.wait() + + def test_multiple_threads_shared_connection(self): + """ + Test sharing a single connections across multiple threads, + which will result in pipelined requests. + """ + num_requests_per_conn = 25 + num_threads = 5 + event = Event() + + conn = self.get_connection() + query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1" + + def cb(all_responses, thread_responses, request_num, *args, **kwargs): + thread_responses[request_num] = True + if all(map(all, all_responses)): + conn.close() + event.set() + + def send_msgs(all_responses, thread_responses): + for i in range(num_requests_per_conn): + qmsg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + with conn.lock: + request_id = conn.get_request_id() + conn.send_msg(qmsg, request_id, cb=partial(cb, all_responses, thread_responses, i)) + + all_responses = [] + threads = [] + for i in range(num_threads): + thread_responses = [False] * num_requests_per_conn + all_responses.append(thread_responses) + t = Thread(target=send_msgs, args=(all_responses, thread_responses)) + threads.append(t) + + for t in threads: + t.start() + + for t in threads: + t.join() + + event.wait() + + def test_multiple_threads_multiple_connections(self): + """ + Test several threads, each with their own Connection and pipelined + requests. 
+ """ + num_requests_per_conn = 25 + num_conns = 5 + events = [Event() for i in range(5)] + + query = "SELECT keyspace_name FROM system.schema_keyspaces LIMIT 1" + + def cb(conn, event, thread_responses, request_num, *args, **kwargs): + thread_responses[request_num] = True + if all(thread_responses): + conn.close() + event.set() + + def send_msgs(conn, event): + thread_responses = [False] * num_requests_per_conn + for i in range(num_requests_per_conn): + qmsg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + with conn.lock: + request_id = conn.get_request_id() + conn.send_msg(qmsg, request_id, cb=partial(cb, conn, event, thread_responses, i)) + + event.wait() + + threads = [] + for i in range(num_conns): + conn = self.get_connection() + t = Thread(target=send_msgs, args=(conn, events[i])) + threads.append(t) + + for t in threads: + t.start() + + for t in threads: + t.join() + + @requiresmallclockgranularity + def test_connect_timeout(self): + # Underlying socket implementations don't always throw a socket timeout even with min float + # This can be timing sensitive, added retry to ensure failure occurs if it can + max_retry_count = 10 + exception_thrown = False + for i in range(max_retry_count): + start = time.time() + try: + conn = self.get_connection(timeout=sys.float_info.min) + conn.close() + except Exception as e: + end = time.time() + self.assertAlmostEqual(start, end, 1) + exception_thrown = True + break + self.assertTrue(exception_thrown) + + def test_subclasses_share_loop(self): + + if self.klass not in (AsyncoreConnection, LibevConnection): + raise SkipTest + + class C1(self.klass): + pass + + class C2(self.klass): + pass + + clusterC1 = Cluster(connection_class=C1) + clusterC1.connect(wait_for_all_pools=True) + + clusterC2 = Cluster(connection_class=C2) + clusterC2.connect(wait_for_all_pools=True) + self.addCleanup(clusterC1.shutdown) + self.addCleanup(clusterC2.shutdown) + + self.assertEqual(len(get_eventloop_threads(self.event_loop_name)), 1) + + +def get_eventloop_threads(name): + all_threads = list(threading.enumerate()) + log.debug('all threads: {}'.format(all_threads)) + log.debug('all names: {}'.format([thread.name for thread in all_threads])) + event_loops_threads = [thread for thread in all_threads if name == thread.name] + + return event_loops_threads + + +class AsyncoreConnectionTests(ConnectionTests, unittest.TestCase): + + klass = AsyncoreConnection + event_loop_name = "asyncore_cassandra_driver_event_loop" + + def setUp(self): + if is_monkey_patched(): + raise unittest.SkipTest("Can't test asyncore with monkey patching") + ConnectionTests.setUp(self) + + def clean_global_loop(self): + cassandra.io.asyncorereactor._global_loop._cleanup() + cassandra.io.asyncorereactor._global_loop = None + + +class LibevConnectionTests(ConnectionTests, unittest.TestCase): + + klass = LibevConnection + event_loop_name = "event_loop" + + def setUp(self): + if is_monkey_patched(): + raise unittest.SkipTest("Can't test libev with monkey patching") + if LibevConnection is None: + raise unittest.SkipTest( + 'libev does not appear to be installed properly') + ConnectionTests.setUp(self) + + def clean_global_loop(self): + cassandra.io.libevreactor._global_loop._cleanup() + cassandra.io.libevreactor._global_loop = None diff --git a/tests/integration/standard/test_control_connection.py b/tests/integration/standard/test_control_connection.py new file mode 100644 index 0000000..b928cd2 --- /dev/null +++ b/tests/integration/standard/test_control_connection.py @@ -0,0 +1,105 
@@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + + +from cassandra.cluster import Cluster +from cassandra.protocol import ConfigurationException +from tests.integration import use_singledc, PROTOCOL_VERSION +from tests.integration.datatype_utils import update_datatypes + + +def setup_module(): + use_singledc() + update_datatypes() + + +class ControlConnectionTests(unittest.TestCase): + def setUp(self): + if PROTOCOL_VERSION < 3: + raise unittest.SkipTest( + "Native protocol 3.0+ is required for UDTs, currently using %r" + % (PROTOCOL_VERSION,)) + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + + def tearDown(self): + try: + self.session.execute("DROP KEYSPACE keyspacetodrop") + except ConfigurationException: + # we already removed the keyspace. + pass + self.cluster.shutdown() + + def test_drop_keyspace(self): + """ + Test to validate that dropping a keyspace with user defined types doesn't kill the control connection. + + + Creates a keyspace, and populates it with a user defined type. It then records the control_connection's id. It + will then drop the keyspace and get the id of the control_connection again. They should be the same. If they are + not, dropping the keyspace likely caused the control connection to be rebuilt. + + @since 2.7.0 + @jira_ticket PYTHON-358 + @expected_result the control connection is not killed + + @test_category connection + """ + + self.session = self.cluster.connect() + self.session.execute(""" + CREATE KEYSPACE keyspacetodrop + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + """) + self.session.set_keyspace("keyspacetodrop") + self.session.execute("CREATE TYPE user (age int, name text)") + self.session.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)") + cc_id_pre_drop = id(self.cluster.control_connection._connection) + self.session.execute("DROP KEYSPACE keyspacetodrop") + cc_id_post_drop = id(self.cluster.control_connection._connection) + self.assertEqual(cc_id_post_drop, cc_id_pre_drop) + + def test_get_control_connection_host(self): + """ + Test to validate Cluster.get_control_connection_host() metadata + + @since 3.5.0 + @jira_ticket PYTHON-583 + @expected_result the control connection metadata should accurately reflect cluster state.
+ + @test_category metadata + """ + + host = self.cluster.get_control_connection_host() + self.assertEqual(host, None) + + self.session = self.cluster.connect() + cc_host = self.cluster.control_connection._connection.host + + host = self.cluster.get_control_connection_host() + self.assertEqual(host.address, cc_host) + self.assertEqual(host.is_up, True) + + # reconnect and make sure that the new host is reflected correctly + self.cluster.control_connection._reconnect() + new_host = self.cluster.get_control_connection_host() + self.assertNotEqual(host, new_host) + diff --git a/tests/integration/standard/test_custom_payload.py b/tests/integration/standard/test_custom_payload.py new file mode 100644 index 0000000..c68e9ef --- /dev/null +++ b/tests/integration/standard/test_custom_payload.py @@ -0,0 +1,170 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +try: + import unittest2 as unittest +except ImportError: + import unittest + +import six + +from cassandra.query import (SimpleStatement, BatchStatement, BatchType) +from cassandra.cluster import Cluster + +from tests.integration import use_singledc, PROTOCOL_VERSION, local + +def setup_module(): + use_singledc() + +#These test rely on the custom payload being returned but by default C* +#ignores all the payloads. +@local +class CustomPayloadTests(unittest.TestCase): + + def setUp(self): + if PROTOCOL_VERSION < 4: + raise unittest.SkipTest( + "Native protocol 4,0+ is required for custom payloads, currently using %r" + % (PROTOCOL_VERSION,)) + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.session = self.cluster.connect() + + def tearDown(self): + + self.cluster.shutdown() + + def test_custom_query_basic(self): + """ + Test to validate that custom payloads work with simple queries + + creates a simple query and ensures that custom payloads are passed to C*. A custom + query provider is used with C* so we can validate that same custom payloads are sent back + with the results + + + @since 2.6 + @jira_ticket PYTHON-280 + @expected_result valid custom payloads should be sent and received + + @test_category queries:custom_payload + """ + + # Create a simple query statement a + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + # Validate that various types of custom payloads are sent and received okay + self.validate_various_custom_payloads(statement=statement) + + def test_custom_query_batching(self): + """ + Test to validate that custom payloads work with batch queries + + creates a batch query and ensures that custom payloads are passed to C*. 
A custom + query provider is used with C* so we can validate that same custom payloads are sent back + with the results + + + @since 2.6 + @jira_ticket PYTHON-280 + @expected_result valid custom payloads should be sent and received + + @test_category queries:custom_payload + """ + + # Construct Batch Statement + batch = BatchStatement(BatchType.LOGGED) + for i in range(10): + batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"), (i, i)) + + # Validate that various types of custom payloads are sent and received okay + self.validate_various_custom_payloads(statement=batch) + + def test_custom_query_prepared(self): + """ + Test to validate that custom payloads work with prepared queries + + creates a batch query and ensures that custom payloads are passed to C*. A custom + query provider is used with C* so we can validate that same custom payloads are sent back + with the results + + + @since 2.6 + @jira_ticket PYTHON-280 + @expected_result valid custom payloads should be sent and received + + @test_category queries:custom_payload + """ + + # Construct prepared statement + prepared = self.session.prepare( + """ + INSERT INTO test3rf.test (k, v) VALUES (?, ?) + """) + + bound = prepared.bind((1, None)) + + # Validate that various custom payloads are validated correctly + self.validate_various_custom_payloads(statement=bound) + + def validate_various_custom_payloads(self, statement): + """ + This is a utility method that given a statement will attempt + to submit the statement with various custom payloads. It will + validate that the custom payloads are sent and received correctly. + + @param statement The statement to validate the custom queries in conjunction with + """ + + # Simple key value + custom_payload = {'test': b'test_return'} + self.execute_async_validate_custom_payload(statement=statement, custom_payload=custom_payload) + + # no key value + custom_payload = {'': b''} + self.execute_async_validate_custom_payload(statement=statement, custom_payload=custom_payload) + + # Space value + custom_payload = {' ': b' '} + self.execute_async_validate_custom_payload(statement=statement, custom_payload=custom_payload) + + # Long key value pair + key_value = "x" * 10 + custom_payload = {key_value: six.b(key_value)} + self.execute_async_validate_custom_payload(statement=statement, custom_payload=custom_payload) + + # Max supported value key pairs according C* binary protocol v4 should be 65534 (unsigned short max value) + for i in range(65534): + custom_payload[str(i)] = six.b('x') + self.execute_async_validate_custom_payload(statement=statement, custom_payload=custom_payload) + + # Add one custom payload to this is too many key value pairs and should fail + custom_payload[str(65535)] = six.b('x') + with self.assertRaises(ValueError): + self.execute_async_validate_custom_payload(statement=statement, custom_payload=custom_payload) + + def execute_async_validate_custom_payload(self, statement, custom_payload): + """ + This is just a simple method that submits a statement with a payload, and validates + that the custom payload we submitted matches the one that we got back + @param statement The statement to execute + @param custom_payload The custom payload to submit with + """ + + # Submit the statement with our custom payload. 
Validate the one + # we receive from the server matches + response_future = self.session.execute_async(statement, custom_payload=custom_payload) + response_future.result() + returned_custom_payload = response_future.custom_payload + self.assertEqual(custom_payload, returned_custom_payload) diff --git a/tests/integration/standard/test_custom_protocol_handler.py b/tests/integration/standard/test_custom_protocol_handler.py new file mode 100644 index 0000000..e76972b --- /dev/null +++ b/tests/integration/standard/test_custom_protocol_handler.py @@ -0,0 +1,234 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.protocol import ProtocolHandler, ResultMessage, QueryMessage, UUIDType, read_int +from cassandra.query import tuple_factory, SimpleStatement +from cassandra.cluster import Cluster, ResponseFuture +from cassandra import ProtocolVersion, ConsistencyLevel + +from tests.integration import use_singledc, PROTOCOL_VERSION, drop_keyspace_shutdown_cluster, \ + greaterthanorequalcass30, execute_with_long_wait_retry +from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES +from tests.integration.standard.utils import create_table_with_all_types, get_all_primitive_params +from six import binary_type + +import uuid +import mock + + +def setup_module(): + use_singledc() + update_datatypes() + + +class CustomProtocolHandlerTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.session = cls.cluster.connect() + cls.session.execute("CREATE KEYSPACE custserdes WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}") + cls.session.set_keyspace("custserdes") + + @classmethod + def tearDownClass(cls): + drop_keyspace_shutdown_cluster("custserdes", cls.session, cls.cluster) + + def test_custom_raw_uuid_row_results(self): + """ + Test to validate that custom protocol handlers work with raw row results + + Connect and validate that the normal protocol handler is used. + Re-Connect and validate that the custom protocol handler is used. + Re-Connect and validate that the normal protocol handler is used. + + @since 2.7 + @jira_ticket PYTHON-313 + @expected_result custom protocol handler is invoked appropriately. 
+ + @test_category data_types:serialization + """ + + # Ensure that we get normal uuid back first + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect(keyspace="custserdes") + session.row_factory = tuple_factory + + result = session.execute("SELECT schema_version FROM system.local") + uuid_type = result[0][0] + self.assertEqual(type(uuid_type), uuid.UUID) + + # use our custom protocol handlder + + session.client_protocol_handler = CustomTestRawRowType + session.row_factory = tuple_factory + result_set = session.execute("SELECT schema_version FROM system.local") + raw_value = result_set[0][0] + self.assertTrue(isinstance(raw_value, binary_type)) + self.assertEqual(len(raw_value), 16) + + # Ensure that we get normal uuid back when we re-connect + session.client_protocol_handler = ProtocolHandler + result_set = session.execute("SELECT schema_version FROM system.local") + uuid_type = result_set[0][0] + self.assertEqual(type(uuid_type), uuid.UUID) + cluster.shutdown() + + def test_custom_raw_row_results_all_types(self): + """ + Test to validate that custom protocol handlers work with varying types of + results + + Connect, create a table with all sorts of data. Query the data, make the sure the custom results handler is + used correctly. + + @since 2.7 + @jira_ticket PYTHON-313 + @expected_result custom protocol handler is invoked with various result types + + @test_category data_types:serialization + """ + # Connect using a custom protocol handler that tracks the various types the result message is used with. + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect(keyspace="custserdes") + session.client_protocol_handler = CustomProtocolHandlerResultMessageTracked + session.row_factory = tuple_factory + + colnames = create_table_with_all_types("alltypes", session, 1) + columns_string = ", ".join(colnames) + + # verify data + params = get_all_primitive_params(0) + results = session.execute("SELECT {0} FROM alltypes WHERE primkey=0".format(columns_string))[0] + for expected, actual in zip(params, results): + self.assertEqual(actual, expected) + # Ensure we have covered the various primitive types + self.assertEqual(len(CustomResultMessageTracked.checked_rev_row_set), len(PRIMITIVE_DATATYPES)-1) + cluster.shutdown() + + @greaterthanorequalcass30 + def test_protocol_divergence_v4_fail_by_flag_uses_int(self): + """ + Test to validate that the _PAGE_SIZE_FLAG is not treated correctly in V4 if the flags are + written using write_uint instead of write_int + + @since 3.9 + @jira_ticket PYTHON-713 + @expected_result the fetch_size=1 parameter will be ignored + + @test_category connection + """ + self._protocol_divergence_fail_by_flag_uses_int(ProtocolVersion.V4, uses_int_query_flag=False, + int_flag=True) + + + def _send_query_message(self, session, timeout, **kwargs): + query = "SELECT * FROM test3rf.test" + message = QueryMessage(query=query, **kwargs) + future = ResponseFuture(session, message, query=None, timeout=timeout) + future.send_request() + return future + + def _protocol_divergence_fail_by_flag_uses_int(self, version, uses_int_query_flag, int_flag = True, beta=False): + cluster = Cluster(protocol_version=version, allow_beta_protocol_version=beta) + session = cluster.connect() + + query_one = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (1, 1)") + query_two = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (2, 2)") + + execute_with_long_wait_retry(session, query_one) + execute_with_long_wait_retry(session, query_two) + 
+ with mock.patch('cassandra.protocol.ProtocolVersion.uses_int_query_flags', new=mock.Mock(return_value=int_flag)): + future = self._send_query_message(session, 10, + consistency_level=ConsistencyLevel.ONE, fetch_size=1) + + response = future.result() + + # This means the flag are not handled as they are meant by the server if uses_int=False + self.assertEqual(response.has_more_pages, uses_int_query_flag) + + execute_with_long_wait_retry(session, SimpleStatement("TRUNCATE test3rf.test")) + cluster.shutdown() + + +class CustomResultMessageRaw(ResultMessage): + """ + This is a custom Result Message that is used to return raw results, rather then + results which contain objects. + """ + my_type_codes = ResultMessage.type_codes.copy() + my_type_codes[0xc] = UUIDType + type_codes = my_type_codes + + @classmethod + def recv_results_rows(cls, f, protocol_version, user_type_map, result_metadata): + paging_state, column_metadata, result_metadata_id = cls.recv_results_metadata(f, user_type_map) + rowcount = read_int(f) + rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)] + colnames = [c[2] for c in column_metadata] + coltypes = [c[3] for c in column_metadata] + return paging_state, coltypes, (colnames, rows), result_metadata_id + + +class CustomTestRawRowType(ProtocolHandler): + """ + This is the a custom protocol handler that will substitute the the + customResultMesageRowRaw Result message for our own implementation + """ + my_opcodes = ProtocolHandler.message_types_by_opcode.copy() + my_opcodes[CustomResultMessageRaw.opcode] = CustomResultMessageRaw + message_types_by_opcode = my_opcodes + + +class CustomResultMessageTracked(ResultMessage): + """ + This is a custom Result Message that is use to track what primitive types + have been processed when it receives results + """ + my_type_codes = ResultMessage.type_codes.copy() + my_type_codes[0xc] = UUIDType + type_codes = my_type_codes + checked_rev_row_set = set() + + @classmethod + def recv_results_rows(cls, f, protocol_version, user_type_map, result_metadata): + paging_state, column_metadata, result_metadata_id = cls.recv_results_metadata(f, user_type_map) + rowcount = read_int(f) + rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)] + colnames = [c[2] for c in column_metadata] + coltypes = [c[3] for c in column_metadata] + cls.checked_rev_row_set.update(coltypes) + parsed_rows = [ + tuple(ctype.from_binary(val, protocol_version) + for ctype, val in zip(coltypes, row)) + for row in rows] + return paging_state, coltypes, (colnames, parsed_rows), result_metadata_id + + +class CustomProtocolHandlerResultMessageTracked(ProtocolHandler): + """ + This is the a custom protocol handler that will substitute the the + CustomTestRawRowTypeTracked Result message for our own implementation + """ + my_opcodes = ProtocolHandler.message_types_by_opcode.copy() + my_opcodes[CustomResultMessageTracked.opcode] = CustomResultMessageTracked + message_types_by_opcode = my_opcodes + + diff --git a/tests/integration/standard/test_cython_protocol_handlers.py b/tests/integration/standard/test_cython_protocol_handlers.py new file mode 100644 index 0000000..593dcba --- /dev/null +++ b/tests/integration/standard/test_cython_protocol_handlers.py @@ -0,0 +1,258 @@ +"""Test the various Cython-based message deserializers""" + +# Based on test_custom_protocol_handler.py + +try: + import unittest2 as unittest +except ImportError: + import unittest + +from itertools import count + +from cassandra.query import tuple_factory +from 
cassandra.cluster import Cluster, NoHostAvailable +from cassandra.concurrent import execute_concurrent_with_args +from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler +from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY +from tests import VERIFY_CYTHON +from tests.integration import use_singledc, PROTOCOL_VERSION, notprotocolv1, drop_keyspace_shutdown_cluster, BasicSharedKeyspaceUnitTestCase, execute_with_retry_tolerant, greaterthancass21 +from tests.integration.datatype_utils import update_datatypes +from tests.integration.standard.utils import ( + create_table_with_all_types, get_all_primitive_params, get_primitive_datatypes) + +from tests.unit.cython.utils import cythontest, numpytest + + +def setup_module(): + use_singledc() + update_datatypes() + + +class CythonProtocolHandlerTest(unittest.TestCase): + + N_ITEMS = 10 + + @classmethod + def setUpClass(cls): + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.session = cls.cluster.connect() + cls.session.execute("CREATE KEYSPACE testspace WITH replication = " + "{ 'class' : 'SimpleStrategy', 'replication_factor': '1'}") + cls.session.set_keyspace("testspace") + cls.colnames = create_table_with_all_types("test_table", cls.session, cls.N_ITEMS) + + @classmethod + def tearDownClass(cls): + drop_keyspace_shutdown_cluster("testspace", cls.session, cls.cluster) + + @cythontest + def test_cython_parser(self): + """ + Test Cython-based parser that returns a list of tuples + """ + verify_iterator_data(self.assertEqual, get_data(ProtocolHandler)) + + @cythontest + def test_cython_lazy_parser(self): + """ + Test Cython-based parser that returns an iterator of tuples + """ + verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler)) + + @notprotocolv1 + @numpytest + def test_cython_lazy_results_paged(self): + """ + Test Cython-based parser that returns an iterator, over multiple pages + """ + # arrays = { 'a': arr1, 'b': arr2, ... } + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect(keyspace="testspace") + session.row_factory = tuple_factory + session.client_protocol_handler = LazyProtocolHandler + session.default_fetch_size = 2 + + self.assertLess(session.default_fetch_size, self.N_ITEMS) + + results = session.execute("SELECT * FROM test_table") + + self.assertTrue(results.has_more_pages) + self.assertEqual(verify_iterator_data(self.assertEqual, results), self.N_ITEMS) # make sure we see all rows + + cluster.shutdown() + + @notprotocolv1 + @numpytest + def test_numpy_parser(self): + """ + Test Numpy-based parser that returns a NumPy array + """ + # arrays = { 'a': arr1, 'b': arr2, ... } + result = get_data(NumpyProtocolHandler) + self.assertFalse(result.has_more_pages) + self._verify_numpy_page(result[0]) + + @notprotocolv1 + @numpytest + def test_numpy_results_paged(self): + """ + Test Numpy-based parser that returns a NumPy array + """ + # arrays = { 'a': arr1, 'b': arr2, ... 
} + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect(keyspace="testspace") + session.row_factory = tuple_factory + session.client_protocol_handler = NumpyProtocolHandler + session.default_fetch_size = 2 + + expected_pages = (self.N_ITEMS + session.default_fetch_size - 1) // session.default_fetch_size + + self.assertLess(session.default_fetch_size, self.N_ITEMS) + + results = session.execute("SELECT * FROM test_table") + + self.assertTrue(results.has_more_pages) + for count, page in enumerate(results, 1): + self.assertIsInstance(page, dict) + for colname, arr in page.items(): + if count <= expected_pages: + self.assertGreater(len(arr), 0, "page count: %d" % (count,)) + self.assertLessEqual(len(arr), session.default_fetch_size) + else: + # we get one extra item out of this iteration because of the way NumpyParser returns results + # The last page is returned as a dict with zero-length arrays + self.assertEqual(len(arr), 0) + self.assertEqual(self._verify_numpy_page(page), len(arr)) + self.assertEqual(count, expected_pages + 1) # see note about extra 'page' above + + cluster.shutdown() + + @numpytest + def test_cython_numpy_are_installed_valid(self): + """ + Test to validate that cython and numpy are installed correctly + @since 3.3.0 + @jira_ticket PYTHON-543 + @expected_result Cython and Numpy should be present + + @test_category configuration + """ + if VERIFY_CYTHON: + self.assertTrue(HAVE_CYTHON) + self.assertTrue(HAVE_NUMPY) + + def _verify_numpy_page(self, page): + colnames = self.colnames + datatypes = get_primitive_datatypes() + for colname, datatype in zip(colnames, datatypes): + arr = page[colname] + self.match_dtype(datatype, arr.dtype) + + return verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(page, colnames)) + + def match_dtype(self, datatype, dtype): + """Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype""" + if datatype == 'smallint': + self.match_dtype_props(dtype, 'i', 2) + elif datatype == 'int': + self.match_dtype_props(dtype, 'i', 4) + elif datatype in ('bigint', 'counter'): + self.match_dtype_props(dtype, 'i', 8) + elif datatype == 'float': + self.match_dtype_props(dtype, 'f', 4) + elif datatype == 'double': + self.match_dtype_props(dtype, 'f', 8) + else: + self.assertEqual(dtype.kind, 'O', msg=(dtype, datatype)) + + def match_dtype_props(self, dtype, kind, size, signed=None): + self.assertEqual(dtype.kind, kind, msg=dtype) + self.assertEqual(dtype.itemsize, size, msg=dtype) + + +def arrays_to_list_of_tuples(arrays, colnames): + """Convert a dict of arrays (as given by the numpy protocol handler) to a list of tuples""" + first_array = arrays[colnames[0]] + return [tuple(arrays[colname][i] for colname in colnames) + for i in range(len(first_array))] + + +def get_data(protocol_handler): + """ + Get data from the test table. 
+ """ + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect(keyspace="testspace") + + # use our custom protocol handler + session.client_protocol_handler = protocol_handler + session.row_factory = tuple_factory + + results = session.execute("SELECT * FROM test_table") + cluster.shutdown() + return results + + +def verify_iterator_data(assertEqual, results): + """ + Check the result of get_data() when this is a list or + iterator of tuples + """ + count = 0 + for count, result in enumerate(results, 1): + params = get_all_primitive_params(result[0]) + assertEqual(len(params), len(result), + msg="Not the right number of columns?") + for expected, actual in zip(params, result): + assertEqual(actual, expected) + return count + + +class NumpyNullTest(BasicSharedKeyspaceUnitTestCase): + + @numpytest + @greaterthancass21 + def test_null_types(self): + """ + Test to validate that the numpy protocol handler can deal with null values. + @since 3.3.0 + - updated 3.6.0: now numeric types used masked array + @jira_ticket PYTHON-550 + @expected_result Numpy can handle non mapped types' null values. + + @test_category data_types:serialization + """ + s = self.session + s.row_factory = tuple_factory + s.client_protocol_handler = NumpyProtocolHandler + + table = "%s.%s" % (self.keyspace_name, self.function_table_name) + create_table_with_all_types(table, s, 10) + + begin_unset = max(s.execute('select primkey from %s' % (table,))[0]['primkey']) + 1 + keys_null = range(begin_unset, begin_unset + 10) + + # scatter some emptry rows in here + insert = "insert into %s (primkey) values (%%s)" % (table,) + execute_concurrent_with_args(s, insert, ((k,) for k in keys_null)) + + result = s.execute("select * from %s" % (table,))[0] + + from numpy.ma import masked, MaskedArray + result_keys = result.pop('primkey') + mapped_index = [v[1] for v in sorted(zip(result_keys, count()))] + + had_masked = had_none = False + for col_array in result.values(): + # these have to be different branches (as opposed to comparing against an 'unset value') + # because None and `masked` have different identity and equals semantics + if isinstance(col_array, MaskedArray): + had_masked = True + [self.assertIsNot(col_array[i], masked) for i in mapped_index[:begin_unset]] + [self.assertIs(col_array[i], masked) for i in mapped_index[begin_unset:]] + else: + had_none = True + [self.assertIsNotNone(col_array[i]) for i in mapped_index[:begin_unset]] + [self.assertIsNone(col_array[i]) for i in mapped_index[begin_unset:]] + self.assertTrue(had_masked) + self.assertTrue(had_none) diff --git a/tests/integration/standard/test_dse.py b/tests/integration/standard/test_dse.py new file mode 100644 index 0000000..a8a3d64 --- /dev/null +++ b/tests/integration/standard/test_dse.py @@ -0,0 +1,98 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
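The preceding Cython/NumPy tests rely on swapping the session's protocol handler. A minimal sketch of that switch, assuming the driver was built with its Cython extensions and that numpy is installed; ``testspace.test_table`` is the fixture those tests create::

    from cassandra.cluster import Cluster
    from cassandra.protocol import NumpyProtocolHandler
    from cassandra.query import tuple_factory

    cluster = Cluster()
    session = cluster.connect("testspace")
    session.row_factory = tuple_factory                 # paired with NumpyProtocolHandler in the tests
    session.client_protocol_handler = NumpyProtocolHandler

    # Each page of the result is a dict mapping column name -> NumPy array.
    result = session.execute("SELECT * FROM test_table")
    page = result[0]
    for colname, array in page.items():
        print(colname, array.dtype, len(array))

    cluster.shutdown()

With ``NumpyProtocolHandler`` each page comes back as a dict keyed by column name, which is why the tests above index ``result[0]`` and then iterate over the per-column arrays.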
+ +import os + +from packaging.version import Version + +from cassandra.cluster import Cluster +from tests import notwindows +from tests.unit.cython.utils import notcython +from tests.integration import (execute_until_pass, + execute_with_long_wait_retry, use_cluster) + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + + +CCM_IS_DSE = (os.environ.get('CCM_IS_DSE', None) == 'true') + + +@unittest.skipIf(os.environ.get('CCM_ARGS', None), 'environment has custom CCM_ARGS; skipping') +@notwindows +@notcython # no need to double up on this test; also __default__ setting doesn't work +class DseCCMClusterTest(unittest.TestCase): + """ + This class can be executed setting the DSE_VERSION variable, for example: + DSE_VERSION=5.1.4 python2.7 -m nose tests/integration/standard/test_dse.py + If CASSANDRA_VERSION is set instead, it will be converted to the corresponding DSE_VERSION + """ + + def test_dse_5x(self): + self._test_basic(Version('5.1.10')) + + def test_dse_60(self): + self._test_basic(Version('6.0.2')) + + @unittest.skipUnless(CCM_IS_DSE, 'DSE version unavailable') + def test_dse_67(self): + self._test_basic(Version('6.7.0')) + + def _test_basic(self, dse_version): + """ + Test basic connection and usage + """ + cluster_name = '{}-{}'.format( + self.__class__.__name__, dse_version.base_version.replace('.', '_') + ) + use_cluster(cluster_name=cluster_name, nodes=[3], + dse_cluster=True, dse_options={}, dse_version=dse_version) + + cluster = Cluster( + allow_beta_protocol_version=(dse_version >= Version('6.7.0'))) + session = cluster.connect() + result = execute_until_pass( + session, + """ + CREATE KEYSPACE clustertests + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + """) + self.assertFalse(result) + + result = execute_with_long_wait_retry( + session, + """ + CREATE TABLE clustertests.cf0 ( + a text, + b text, + c text, + PRIMARY KEY (a, b) + ) + """) + self.assertFalse(result) + + result = session.execute( + """ + INSERT INTO clustertests.cf0 (a, b, c) VALUES ('a', 'b', 'c') + """) + self.assertFalse(result) + + result = session.execute("SELECT * FROM clustertests.cf0") + self.assertEqual([('a', 'b', 'c')], result) + + execute_with_long_wait_retry(session, "DROP KEYSPACE clustertests") + + cluster.shutdown() diff --git a/tests/integration/standard/test_metadata.py b/tests/integration/standard/test_metadata.py new file mode 100644 index 0000000..4fb7ebf --- /dev/null +++ b/tests/integration/standard/test_metadata.py @@ -0,0 +1,2481 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
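The metadata tests added below assert against the schema and host metadata the driver maintains on ``cluster.metadata``. A rough sketch of that surface, with illustrative keyspace and table names only::

    from cassandra.cluster import Cluster

    cluster = Cluster()
    session = cluster.connect()

    cluster.refresh_schema_metadata()                   # force a schema refresh
    ks_meta = cluster.metadata.keyspaces["system"]      # KeyspaceMetadata
    table_meta = ks_meta.tables["local"]                # TableMetadata

    print([c.name for c in table_meta.partition_key])   # partition key columns
    print(sorted(table_meta.columns.keys()))            # all column names
    print(table_meta.as_cql_query(formatted=True))      # reconstructed CREATE TABLE

    for host in cluster.metadata.all_hosts():
        print(host.address, host.release_version)

    cluster.shutdown()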
+ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from collections import defaultdict +import difflib +import logging +import six +import sys +import time +import os +from packaging.version import Version +from mock import Mock, patch + +from cassandra import AlreadyExists, SignatureDescriptor, UserFunctionDescriptor, UserAggregateDescriptor + +from cassandra.cluster import Cluster +from cassandra.encoder import Encoder +from cassandra.metadata import (IndexMetadata, Token, murmur3, Function, Aggregate, protect_name, protect_names, + RegisteredTableExtension, _RegisteredExtensionType, get_schema_parser, + group_keys_by_replica, NO_VALID_REPLICA) + +from tests.integration import (get_cluster, use_singledc, PROTOCOL_VERSION, execute_until_pass, + BasicSegregatedKeyspaceUnitTestCase, BasicSharedKeyspaceUnitTestCase, + BasicExistingKeyspaceUnitTestCase, drop_keyspace_shutdown_cluster, CASSANDRA_VERSION, + get_supported_protocol_versions, greaterthanorequalcass30, lessthancass30, local, + greaterthancass20, greaterthanorequalcass40) + +from tests.integration import greaterthancass21 + + +log = logging.getLogger(__name__) + + +def setup_module(): + use_singledc() + + +class HostMetatDataTests(BasicExistingKeyspaceUnitTestCase): + @local + def test_broadcast_listen_address(self): + """ + Check to ensure that the broadcast, rpc_address, listen adresss and host are is populated correctly + + @since 3.3 + @jira_ticket PYTHON-332 + @expected_result They are populated for C*> 2.1.6, 2.2.0 + + @test_category metadata + """ + # All nodes should have the broadcast_address, rpc_address and host_id set + for host in self.cluster.metadata.all_hosts(): + self.assertIsNotNone(host.broadcast_address) + self.assertIsNotNone(host.broadcast_rpc_address) + self.assertIsNotNone(host.host_id) + con = self.cluster.control_connection.get_connections()[0] + local_host = con.host + + # The control connection node should have the listen address set. + listen_addrs = [host.listen_address for host in self.cluster.metadata.all_hosts()] + self.assertTrue(local_host in listen_addrs) + + # The control connection node should have the broadcast_rpc_address set. + rpc_addrs = [host.broadcast_rpc_address for host in self.cluster.metadata.all_hosts()] + self.assertTrue(local_host in rpc_addrs) + + @unittest.skipUnless( + os.getenv('MAPPED_CASSANDRA_VERSION', None) is None, + "Don't check the host version for test-dse") + def test_host_release_version(self): + """ + Checks the hosts release version and validates that it is equal to the + Cassandra version we are using in our test harness. + + @since 3.3 + @jira_ticket PYTHON-301 + @expected_result host.release version should match our specified Cassandra version. + + @test_category metadata + """ + for host in self.cluster.metadata.all_hosts(): + self.assertTrue(host.release_version.startswith(CASSANDRA_VERSION.base_version)) + + +@local +class MetaDataRemovalTest(unittest.TestCase): + + def setUp(self): + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, contact_points=['127.0.0.1', '127.0.0.2', '127.0.0.3', '126.0.0.186']) + self.cluster.connect() + + def tearDown(self): + self.cluster.shutdown() + + def test_bad_contact_point(self): + """ + Checks to ensure that hosts that are not resolvable are excluded from the contact point list. 
+ + @since 3.6 + @jira_ticket PYTHON-549 + @expected_result Invalid hosts on the contact list should be excluded + + @test_category metadata + """ + self.assertEqual(len(self.cluster.metadata.all_hosts()), 3) + + +class SchemaMetadataTests(BasicSegregatedKeyspaceUnitTestCase): + + def test_schema_metadata_disable(self): + """ + Checks to ensure that schema metadata_enabled, and token_metadata_enabled + flags work correctly. + + @since 3.3 + @jira_ticket PYTHON-327 + @expected_result schema metadata will not be populated when schema_metadata_enabled is fause + token_metadata will be missing when token_metadata is set to false + + @test_category metadata + """ + # Validate metadata is missing where appropriate + no_schema = Cluster(schema_metadata_enabled=False) + no_schema_session = no_schema.connect() + self.assertEqual(len(no_schema.metadata.keyspaces), 0) + self.assertEqual(no_schema.metadata.export_schema_as_string(), '') + no_token = Cluster(token_metadata_enabled=False) + no_token_session = no_token.connect() + self.assertEqual(len(no_token.metadata.token_map.token_to_host_owner), 0) + + # Do a simple query to ensure queries are working + query = "SELECT * FROM system.local" + no_schema_rs = no_schema_session.execute(query) + no_token_rs = no_token_session.execute(query) + self.assertIsNotNone(no_schema_rs[0]) + self.assertIsNotNone(no_token_rs[0]) + no_schema.shutdown() + no_token.shutdown() + + def make_create_statement(self, partition_cols, clustering_cols=None, other_cols=None): + clustering_cols = clustering_cols or [] + other_cols = other_cols or [] + + statement = "CREATE TABLE %s.%s (" % (self.keyspace_name, self.function_table_name) + if len(partition_cols) == 1 and not clustering_cols: + statement += "%s text PRIMARY KEY, " % protect_name(partition_cols[0]) + else: + statement += ", ".join("%s text" % protect_name(col) for col in partition_cols) + statement += ", " + + statement += ", ".join("%s text" % protect_name(col) for col in clustering_cols + other_cols) + + if len(partition_cols) != 1 or clustering_cols: + statement += ", PRIMARY KEY (" + + if len(partition_cols) > 1: + statement += "(" + ", ".join(protect_names(partition_cols)) + ")" + else: + statement += protect_name(partition_cols[0]) + + if clustering_cols: + statement += ", " + statement += ", ".join(protect_names(clustering_cols)) + + statement += ")" + + statement += ")" + + return statement + + def check_create_statement(self, tablemeta, original): + recreate = tablemeta.as_cql_query(formatted=False) + self.assertEqual(original, recreate[:len(original)]) + execute_until_pass(self.session, "DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name)) + execute_until_pass(self.session, recreate) + + # create the table again, but with formatting enabled + execute_until_pass(self.session, "DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name)) + recreate = tablemeta.as_cql_query(formatted=True) + execute_until_pass(self.session, recreate) + + def get_table_metadata(self): + self.cluster.refresh_table_metadata(self.keyspace_name, self.function_table_name) + return self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name] + + def test_basic_table_meta_properties(self): + create_statement = self.make_create_statement(["a"], [], ["b", "c"]) + self.session.execute(create_statement) + + self.cluster.refresh_schema_metadata() + + meta = self.cluster.metadata + self.assertNotEqual(meta.cluster_name, None) + self.assertTrue(self.keyspace_name in meta.keyspaces) + 
ksmeta = meta.keyspaces[self.keyspace_name] + + self.assertEqual(ksmeta.name, self.keyspace_name) + self.assertTrue(ksmeta.durable_writes) + self.assertEqual(ksmeta.replication_strategy.name, 'SimpleStrategy') + self.assertEqual(ksmeta.replication_strategy.replication_factor, 1) + + self.assertTrue(self.function_table_name in ksmeta.tables) + tablemeta = ksmeta.tables[self.function_table_name] + self.assertEqual(tablemeta.keyspace_name, ksmeta.name) + self.assertEqual(tablemeta.name, self.function_table_name) + self.assertEqual(tablemeta.name, self.function_table_name) + + self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key]) + self.assertEqual([], tablemeta.clustering_key) + self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys())) + + cc = self.cluster.control_connection._connection + parser = get_schema_parser(cc, CASSANDRA_VERSION.base_version, 1) + + for option in tablemeta.options: + self.assertIn(option, parser.recognized_table_options) + + self.check_create_statement(tablemeta, create_statement) + + def test_compound_primary_keys(self): + create_statement = self.make_create_statement(["a"], ["b"], ["c"]) + create_statement += " WITH CLUSTERING ORDER BY (b ASC)" + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + + self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key]) + self.assertEqual([u'b'], [c.name for c in tablemeta.clustering_key]) + self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys())) + + self.check_create_statement(tablemeta, create_statement) + + def test_compound_primary_keys_protected(self): + create_statement = self.make_create_statement(["Aa"], ["Bb"], ["Cc"]) + create_statement += ' WITH CLUSTERING ORDER BY ("Bb" ASC)' + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + + self.assertEqual([u'Aa'], [c.name for c in tablemeta.partition_key]) + self.assertEqual([u'Bb'], [c.name for c in tablemeta.clustering_key]) + self.assertEqual([u'Aa', u'Bb', u'Cc'], sorted(tablemeta.columns.keys())) + + self.check_create_statement(tablemeta, create_statement) + + def test_compound_primary_keys_more_columns(self): + create_statement = self.make_create_statement(["a"], ["b", "c"], ["d", "e", "f"]) + create_statement += " WITH CLUSTERING ORDER BY (b ASC, c ASC)" + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + + self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key]) + self.assertEqual([u'b', u'c'], [c.name for c in tablemeta.clustering_key]) + self.assertEqual( + [u'a', u'b', u'c', u'd', u'e', u'f'], + sorted(tablemeta.columns.keys())) + + self.check_create_statement(tablemeta, create_statement) + + def test_composite_primary_key(self): + create_statement = self.make_create_statement(["a", "b"], [], ["c"]) + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + + self.assertEqual([u'a', u'b'], [c.name for c in tablemeta.partition_key]) + self.assertEqual([], tablemeta.clustering_key) + self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys())) + + self.check_create_statement(tablemeta, create_statement) + + def test_composite_in_compound_primary_key(self): + create_statement = self.make_create_statement(["a", "b"], ["c"], ["d", "e"]) + create_statement += " WITH CLUSTERING ORDER BY (c ASC)" + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + + self.assertEqual([u'a', u'b'], [c.name for c in tablemeta.partition_key]) + self.assertEqual([u'c'], [c.name 
for c in tablemeta.clustering_key]) + self.assertEqual([u'a', u'b', u'c', u'd', u'e'], sorted(tablemeta.columns.keys())) + + self.check_create_statement(tablemeta, create_statement) + + def test_compound_primary_keys_compact(self): + create_statement = self.make_create_statement(["a"], ["b"], ["c"]) + create_statement += " WITH CLUSTERING ORDER BY (b ASC)" + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + + self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key]) + self.assertEqual([u'b'], [c.name for c in tablemeta.clustering_key]) + self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys())) + + self.check_create_statement(tablemeta, create_statement) + + def test_cluster_column_ordering_reversed_metadata(self): + """ + Simple test to ensure that the metadata associated with clustering order is surfaced correctly. + + Creates a table with a few clustering keys. Then checks the clustering order associated with clustering columns + and ensures it is set correctly. + @since 3.0.0 + @jira_ticket PYTHON-402 + @expected_result is_reversed is set on DESC order, and is False on ASC + + @test_category metadata + """ + + create_statement = self.make_create_statement(["a"], ["b", "c"], ["d"]) + create_statement += " WITH CLUSTERING ORDER BY (b ASC, c DESC)" + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + b_column = tablemeta.columns['b'] + self.assertFalse(b_column.is_reversed) + c_column = tablemeta.columns['c'] + self.assertTrue(c_column.is_reversed) + + def test_compound_primary_keys_more_columns_compact(self): + create_statement = self.make_create_statement(["a"], ["b", "c"], ["d"]) + create_statement += " WITH CLUSTERING ORDER BY (b ASC, c ASC)" + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + + self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key]) + self.assertEqual([u'b', u'c'], [c.name for c in tablemeta.clustering_key]) + self.assertEqual([u'a', u'b', u'c', u'd'], sorted(tablemeta.columns.keys())) + + self.check_create_statement(tablemeta, create_statement) + + def test_composite_primary_key_compact(self): + create_statement = self.make_create_statement(["a", "b"], [], ["c"]) + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + + self.assertEqual([u'a', u'b'], [c.name for c in tablemeta.partition_key]) + self.assertEqual([], tablemeta.clustering_key) + self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys())) + + self.check_create_statement(tablemeta, create_statement) + + def test_composite_in_compound_primary_key_compact(self): + create_statement = self.make_create_statement(["a", "b"], ["c"], ["d"]) + create_statement += " WITH CLUSTERING ORDER BY (c ASC)" + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + + self.assertEqual([u'a', u'b'], [c.name for c in tablemeta.partition_key]) + self.assertEqual([u'c'], [c.name for c in tablemeta.clustering_key]) + self.assertEqual([u'a', u'b', u'c', u'd'], sorted(tablemeta.columns.keys())) + + self.check_create_statement(tablemeta, create_statement) + + @lessthancass30 + def test_cql_compatibility(self): + + # having more than one non-PK column is okay if there aren't any + # clustering columns + create_statement = self.make_create_statement(["a"], [], ["b", "c", "d"]) + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + + self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key]) + 
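The clustering-order test above reads the is_reversed flag straight off the column metadata; a small sketch of doing the same outside the test harness (assuming an existing keyspace ks and table t created WITH CLUSTERING ORDER BY (b ASC, c DESC))::

    from cassandra.cluster import Cluster

    cluster = Cluster()
    cluster.connect()
    table_meta = cluster.metadata.keyspaces['ks'].tables['t']
    for col in table_meta.clustering_key:
        # is_reversed is True for DESC clustering columns and False for ASC
        print(col.name, 'DESC' if col.is_reversed else 'ASC')
    cluster.shutdown()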
self.assertEqual([], tablemeta.clustering_key) + self.assertEqual([u'a', u'b', u'c', u'd'], sorted(tablemeta.columns.keys())) + + self.assertTrue(tablemeta.is_cql_compatible) + + # It will be cql compatible after CASSANDRA-10857 + # since compact storage is being dropped + tablemeta.clustering_key = ["foo", "bar"] + tablemeta.columns["foo"] = None + tablemeta.columns["bar"] = None + self.assertTrue(tablemeta.is_cql_compatible) + + def test_compound_primary_keys_ordering(self): + create_statement = self.make_create_statement(["a"], ["b"], ["c"]) + create_statement += " WITH CLUSTERING ORDER BY (b DESC)" + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + self.check_create_statement(tablemeta, create_statement) + + def test_compound_primary_keys_more_columns_ordering(self): + create_statement = self.make_create_statement(["a"], ["b", "c"], ["d", "e", "f"]) + create_statement += " WITH CLUSTERING ORDER BY (b DESC, c ASC)" + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + self.check_create_statement(tablemeta, create_statement) + + def test_composite_in_compound_primary_key_ordering(self): + create_statement = self.make_create_statement(["a", "b"], ["c"], ["d", "e"]) + create_statement += " WITH CLUSTERING ORDER BY (c DESC)" + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + self.check_create_statement(tablemeta, create_statement) + + def test_compact_storage(self): + create_statement = self.make_create_statement(["a"], [], ["b"]) + create_statement += " WITH COMPACT STORAGE" + + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + self.check_create_statement(tablemeta, create_statement) + + def test_dense_compact_storage(self): + create_statement = self.make_create_statement(["a"], ["b"], ["c"]) + create_statement += " WITH COMPACT STORAGE" + + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + self.check_create_statement(tablemeta, create_statement) + + def test_counter(self): + create_statement = ( + "CREATE TABLE {keyspace}.{table} (" + "key text PRIMARY KEY, a1 counter)" + ).format(keyspace=self.keyspace_name, table=self.function_table_name) + + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + self.check_create_statement(tablemeta, create_statement) + + def test_counter_with_compact_storage(self): + """ PYTHON-1100 """ + create_statement = ( + "CREATE TABLE {keyspace}.{table} (" + "key text PRIMARY KEY, a1 counter) WITH COMPACT STORAGE" + ).format(keyspace=self.keyspace_name, table=self.function_table_name) + + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + self.check_create_statement(tablemeta, create_statement) + + def test_counter_with_dense_compact_storage(self): + create_statement = ( + "CREATE TABLE {keyspace}.{table} (" + "key text, c1 text, a1 counter, PRIMARY KEY (key, c1)) WITH COMPACT STORAGE" + ).format(keyspace=self.keyspace_name, table=self.function_table_name) + + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + self.check_create_statement(tablemeta, create_statement) + + def test_indexes(self): + create_statement = self.make_create_statement(["a"], ["b", "c"], ["d", "e", "f"]) + create_statement += " WITH CLUSTERING ORDER BY (b ASC, c ASC)" + execute_until_pass(self.session, create_statement) + + d_index = "CREATE INDEX d_index ON %s.%s (d)" % (self.keyspace_name, self.function_table_name) + e_index = "CREATE INDEX e_index ON 
%s.%s (e)" % (self.keyspace_name, self.function_table_name) + execute_until_pass(self.session, d_index) + execute_until_pass(self.session, e_index) + + tablemeta = self.get_table_metadata() + statements = tablemeta.export_as_string().strip() + statements = [s.strip() for s in statements.split(';')] + statements = list(filter(bool, statements)) + self.assertEqual(3, len(statements)) + self.assertIn(d_index, statements) + self.assertIn(e_index, statements) + + # make sure indexes are included in KeyspaceMetadata.export_as_string() + ksmeta = self.cluster.metadata.keyspaces[self.keyspace_name] + statement = ksmeta.export_as_string() + self.assertIn('CREATE INDEX d_index', statement) + self.assertIn('CREATE INDEX e_index', statement) + + @greaterthancass21 + def test_collection_indexes(self): + + self.session.execute("CREATE TABLE %s.%s (a int PRIMARY KEY, b map)" + % (self.keyspace_name, self.function_table_name)) + self.session.execute("CREATE INDEX index1 ON %s.%s (keys(b))" + % (self.keyspace_name, self.function_table_name)) + + tablemeta = self.get_table_metadata() + self.assertIn('(keys(b))', tablemeta.export_as_string()) + + self.session.execute("DROP INDEX %s.index1" % (self.keyspace_name,)) + self.session.execute("CREATE INDEX index2 ON %s.%s (b)" + % (self.keyspace_name, self.function_table_name)) + + tablemeta = self.get_table_metadata() + target = ' (b)' if CASSANDRA_VERSION < Version("3.0") else 'values(b))' # explicit values in C* 3+ + self.assertIn(target, tablemeta.export_as_string()) + + # test full indexes on frozen collections, if available + if CASSANDRA_VERSION >= Version("2.1.3"): + self.session.execute("DROP TABLE %s.%s" % (self.keyspace_name, self.function_table_name)) + self.session.execute("CREATE TABLE %s.%s (a int PRIMARY KEY, b frozen>)" + % (self.keyspace_name, self.function_table_name)) + self.session.execute("CREATE INDEX index3 ON %s.%s (full(b))" + % (self.keyspace_name, self.function_table_name)) + + tablemeta = self.get_table_metadata() + self.assertIn('(full(b))', tablemeta.export_as_string()) + + def test_compression_disabled(self): + create_statement = self.make_create_statement(["a"], ["b"], ["c"]) + create_statement += " WITH compression = {}" + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + expected = "compression = {}" if CASSANDRA_VERSION < Version("3.0") else "compression = {'enabled': 'false'}" + self.assertIn(expected, tablemeta.export_as_string()) + + def test_non_size_tiered_compaction(self): + """ + test options for non-size-tiered compaction strategy + + Creates a table with LeveledCompactionStrategy, specifying one non-default option. Verifies that the option is + present in generated CQL, and that other legacy table parameters (min_threshold, max_threshold) are not included. 
+ + @since 2.6.0 + @jira_ticket PYTHON-352 + @expected_result the options map for LeveledCompactionStrategy does not contain min_threshold, max_threshold + + @test_category metadata + """ + create_statement = self.make_create_statement(["a"], [], ["b", "c"]) + create_statement += "WITH COMPACTION = {'class': 'LeveledCompactionStrategy', 'tombstone_threshold': '0.3'}" + self.session.execute(create_statement) + + table_meta = self.get_table_metadata() + cql = table_meta.export_as_string() + self.assertIn("'tombstone_threshold': '0.3'", cql) + self.assertIn("LeveledCompactionStrategy", cql) + # formerly legacy options; reintroduced in 4.0 + if CASSANDRA_VERSION < Version('4.0'): + self.assertNotIn("min_threshold", cql) + self.assertNotIn("max_threshold", cql) + + def test_refresh_schema_metadata(self): + """ + test for synchronously refreshing all cluster metadata + + test_refresh_schema_metadata tests all cluster metadata is refreshed when calling refresh_schema_metadata(). + It creates a second cluster object with schema_event_refresh_window=-1 such that schema refreshes are disabled + for schema change push events. It then alters the cluster, creating a new keyspace, using the first cluster + object, and verifies that the cluster metadata has not changed in the second cluster object. It then calls + refresh_schema_metadata() and verifies that the cluster metadata is updated in the second cluster object. + Similarly, it then proceeds to altering keyspace, table, UDT, UDF, and UDA metadata and subsequently verifies + that this metadata is updated when refresh_schema_metadata() is called. + + @since 2.6.0 + @jira_ticket PYTHON-291 + @expected_result Cluster, keyspace, table, UDT, UDF, and UDA metadata should be refreshed when refresh_schema_metadata() is called. 
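The refresh tests that follow all build on the same pattern; as a hedged sketch in isolation: a Cluster created with schema_event_refresh_window=-1 ignores schema-change push events, so its metadata only moves when one of the refresh_*_metadata() calls is issued explicitly::

    from cassandra.cluster import Cluster

    # This Cluster will not react to schema change events pushed by the server.
    observer = Cluster(schema_event_refresh_window=-1)
    observer.connect()

    # ... DDL happens elsewhere (another session or process) ...

    # Pull the whole schema synchronously; keyspace/table/UDT/UDF/UDA metadata
    # in observer.metadata is rebuilt once this returns.
    observer.refresh_schema_metadata()
    print(sorted(observer.metadata.keyspaces))
    observer.shutdown()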
+ + @test_category metadata + """ + + cluster2 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster2.connect() + + self.assertNotIn("new_keyspace", cluster2.metadata.keyspaces) + + # Cluster metadata modification + self.session.execute("CREATE KEYSPACE new_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}") + self.assertNotIn("new_keyspace", cluster2.metadata.keyspaces) + + cluster2.refresh_schema_metadata() + self.assertIn("new_keyspace", cluster2.metadata.keyspaces) + + # Keyspace metadata modification + self.session.execute("ALTER KEYSPACE {0} WITH durable_writes = false".format(self.keyspace_name)) + self.assertTrue(cluster2.metadata.keyspaces[self.keyspace_name].durable_writes) + cluster2.refresh_schema_metadata() + self.assertFalse(cluster2.metadata.keyspaces[self.keyspace_name].durable_writes) + + # Table metadata modification + table_name = "test" + self.session.execute("CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format(self.keyspace_name, table_name)) + cluster2.refresh_schema_metadata() + + self.session.execute("ALTER TABLE {0}.{1} ADD c double".format(self.keyspace_name, table_name)) + self.assertNotIn("c", cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns) + cluster2.refresh_schema_metadata() + self.assertIn("c", cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns) + + if PROTOCOL_VERSION >= 3: + # UDT metadata modification + self.session.execute("CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name)) + self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].user_types, {}) + cluster2.refresh_schema_metadata() + self.assertIn("user", cluster2.metadata.keyspaces[self.keyspace_name].user_types) + + if PROTOCOL_VERSION >= 4: + # UDF metadata modification + self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) + RETURNS NULL ON NULL INPUT + RETURNS int + LANGUAGE javascript AS 'key + val';""".format(self.keyspace_name)) + + self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].functions, {}) + cluster2.refresh_schema_metadata() + self.assertIn("sum_int(int,int)", cluster2.metadata.keyspaces[self.keyspace_name].functions) + + # UDA metadata modification + self.session.execute("""CREATE AGGREGATE {0}.sum_agg(int) + SFUNC sum_int + STYPE int + INITCOND 0""" + .format(self.keyspace_name)) + + self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].aggregates, {}) + cluster2.refresh_schema_metadata() + self.assertIn("sum_agg(int)", cluster2.metadata.keyspaces[self.keyspace_name].aggregates) + + # Cluster metadata modification + self.session.execute("DROP KEYSPACE new_keyspace") + self.assertIn("new_keyspace", cluster2.metadata.keyspaces) + + cluster2.refresh_schema_metadata() + self.assertNotIn("new_keyspace", cluster2.metadata.keyspaces) + + cluster2.shutdown() + + def test_refresh_keyspace_metadata(self): + """ + test for synchronously refreshing keyspace metadata + + test_refresh_keyspace_metadata tests that keyspace metadata is refreshed when calling refresh_keyspace_metadata(). + It creates a second cluster object with schema_event_refresh_window=-1 such that schema refreshes are disabled + for schema change push events. It then alters the keyspace, disabling durable_writes, using the first cluster + object, and verifies that the keyspace metadata has not changed in the second cluster object. 
Finally, it calls + refresh_keyspace_metadata() and verifies that the keyspace metadata is updated in the second cluster object. + + @since 2.6.0 + @jira_ticket PYTHON-291 + @expected_result Keyspace metadata should be refreshed when refresh_keyspace_metadata() is called. + + @test_category metadata + """ + + cluster2 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster2.connect() + + self.assertTrue(cluster2.metadata.keyspaces[self.keyspace_name].durable_writes) + self.session.execute("ALTER KEYSPACE {0} WITH durable_writes = false".format(self.keyspace_name)) + self.assertTrue(cluster2.metadata.keyspaces[self.keyspace_name].durable_writes) + cluster2.refresh_keyspace_metadata(self.keyspace_name) + self.assertFalse(cluster2.metadata.keyspaces[self.keyspace_name].durable_writes) + + cluster2.shutdown() + + def test_refresh_table_metadata(self): + """ + test for synchronously refreshing table metadata + + test_refresh_table_metadata tests that table metadata is refreshed when calling refresh_table_metadata(). + It creates a second cluster object with schema_event_refresh_window=-1 such that schema refreshes are disabled + for schema change push events. It then alters the table, adding a new column, using the first cluster + object, and verifies that the table metadata has not changed in the second cluster object. Finally, it calls + refresh_table_metadata() and verifies that the table metadata is updated in the second cluster object. + + @since 2.6.0 + @jira_ticket PYTHON-291 + @expected_result Table metadata should be refreshed when refresh_table_metadata() is called. + + @test_category metadata + """ + + table_name = "test" + self.session.execute("CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format(self.keyspace_name, table_name)) + + cluster2 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster2.connect() + + self.assertNotIn("c", cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns) + self.session.execute("ALTER TABLE {0}.{1} ADD c double".format(self.keyspace_name, table_name)) + self.assertNotIn("c", cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns) + + cluster2.refresh_table_metadata(self.keyspace_name, table_name) + self.assertIn("c", cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns) + + cluster2.shutdown() + + @greaterthanorequalcass30 + def test_refresh_metadata_for_mv(self): + """ + test for synchronously refreshing materialized view metadata + + test_refresh_metadata_for_mv tests that materialized view metadata is refreshed when calling + refresh_table_metadata() with the materialized view name as the table. It creates a second cluster object + with schema_event_refresh_window=-1 such that schema refreshes are disabled for schema change push events. + It then creates a new materialized view, using the first cluster object, and verifies that the materialized view + metadata has not changed in the second cluster object. Finally, it calls refresh_table_metadata() with the + materialized view name as the table name, and verifies that the materialized view metadata is updated in the + second cluster object. + + @since 3.0.0 + @jira_ticket PYTHON-371 + @expected_result Materialized view metadata should be refreshed when refresh_table_metadata() is called. 
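For materialized views there is also the narrower refresh_materialized_view_metadata() call used later in this test; a minimal sketch (assuming keyspace ks and an existing view mv1)::

    from cassandra.cluster import Cluster

    cluster = Cluster(schema_event_refresh_window=-1)
    cluster.connect()

    # Refresh only the named view rather than the whole schema.
    cluster.refresh_materialized_view_metadata('ks', 'mv1')
    view_meta = cluster.metadata.keyspaces['ks'].views['mv1']
    print(view_meta.as_cql_query())
    cluster.shutdown()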
+ + @test_category metadata + """ + + self.session.execute("CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format(self.keyspace_name, self.function_table_name)) + + cluster2 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster2.connect() + + try: + self.assertNotIn("mv1", cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) + self.session.execute("CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT b FROM {0}.{1} WHERE b IS NOT NULL PRIMARY KEY (a, b)" + .format(self.keyspace_name, self.function_table_name)) + self.assertNotIn("mv1", cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) + + cluster2.refresh_table_metadata(self.keyspace_name, "mv1") + self.assertIn("mv1", cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) + finally: + cluster2.shutdown() + + original_meta = self.cluster.metadata.keyspaces[self.keyspace_name].views['mv1'] + self.assertIs(original_meta, self.session.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views['mv1']) + self.cluster.refresh_materialized_view_metadata(self.keyspace_name, 'mv1') + + current_meta = self.cluster.metadata.keyspaces[self.keyspace_name].views['mv1'] + self.assertIsNot(current_meta, original_meta) + self.assertIsNot(original_meta, self.session.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views['mv1']) + self.assertEqual(original_meta.as_cql_query(), current_meta.as_cql_query()) + + cluster3 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster3.connect() + try: + self.assertNotIn("mv2", cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) + self.session.execute("CREATE MATERIALIZED VIEW {0}.mv2 AS SELECT b FROM {0}.{1} WHERE b IS NOT NULL PRIMARY KEY (a, b)" + .format(self.keyspace_name, self.function_table_name)) + self.assertNotIn("mv2", cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) + cluster3.refresh_materialized_view_metadata(self.keyspace_name, 'mv2') + self.assertIn("mv2", cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) + finally: + cluster3.shutdown() + + def test_refresh_user_type_metadata(self): + """ + test for synchronously refreshing UDT metadata in keyspace + + test_refresh_user_type_metadata tests that UDT metadata in a keyspace is refreshed when calling refresh_user_type_metadata(). + It creates a second cluster object with schema_event_refresh_window=-1 such that schema refreshes are disabled + for schema change push events. It then alters the keyspace, creating a new UDT, using the first cluster + object, and verifies that the UDT metadata has not changed in the second cluster object. Finally, it calls + refresh_user_type_metadata() and verifies that the UDT metadata in the keyspace is updated in the second cluster object. + + @since 2.6.0 + @jira_ticket PYTHON-291 + @expected_result UDT metadata in the keyspace should be refreshed when refresh_user_type_metadata() is called. 
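A standalone version of the UDT refresh exercised below, assuming a type ks.user has already been created from another session::

    from cassandra.cluster import Cluster

    cluster = Cluster(schema_event_refresh_window=-1)
    cluster.connect()

    # Refresh a single user-defined type by keyspace and type name.
    cluster.refresh_user_type_metadata('ks', 'user')
    user_type = cluster.metadata.keyspaces['ks'].user_types['user']
    print(user_type.field_names)
    cluster.shutdown()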
+ + @test_category metadata + """ + + if PROTOCOL_VERSION < 3: + raise unittest.SkipTest("Protocol 3+ is required for UDTs, currently testing against {0}".format(PROTOCOL_VERSION)) + + cluster2 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster2.connect() + + self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].user_types, {}) + self.session.execute("CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name)) + self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].user_types, {}) + + cluster2.refresh_user_type_metadata(self.keyspace_name, "user") + self.assertIn("user", cluster2.metadata.keyspaces[self.keyspace_name].user_types) + + cluster2.shutdown() + + @greaterthancass20 + def test_refresh_user_type_metadata_proto_2(self): + """ + Test to ensure that protocol v1/v2 surface UDT metadata changes + + @since 3.7.0 + @jira_ticket PYTHON-106 + @expected_result UDT metadata in the keyspace should be updated regardless of protocol version + + @test_category metadata + """ + supported_versions = get_supported_protocol_versions() + if 2 not in supported_versions: # 1 and 2 were dropped in the same version + raise unittest.SkipTest("Protocol versions 1 and 2 are not supported in Cassandra version {0}".format(CASSANDRA_VERSION)) + + for protocol_version in (1, 2): + cluster = Cluster(protocol_version=protocol_version) + session = cluster.connect() + self.assertEqual(cluster.metadata.keyspaces[self.keyspace_name].user_types, {}) + + session.execute("CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name)) + self.assertIn("user", cluster.metadata.keyspaces[self.keyspace_name].user_types) + self.assertIn("age", cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names) + self.assertIn("name", cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names) + + session.execute("ALTER TYPE {0}.user ADD flag boolean".format(self.keyspace_name)) + self.assertIn("flag", cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names) + + session.execute("ALTER TYPE {0}.user RENAME flag TO something".format(self.keyspace_name)) + self.assertIn("something", cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names) + + session.execute("DROP TYPE {0}.user".format(self.keyspace_name)) + self.assertEqual(cluster.metadata.keyspaces[self.keyspace_name].user_types, {}) + cluster.shutdown() + + def test_refresh_user_function_metadata(self): + """ + test for synchronously refreshing UDF metadata in keyspace + + test_refresh_user_function_metadata tests that UDF metadata in a keyspace is refreshed when calling + refresh_user_function_metadata(). It creates a second cluster object with schema_event_refresh_window=-1 such + that schema refreshes are disabled for schema change push events. It then alters the keyspace, creating a new + UDF, using the first cluster object, and verifies that the UDF metadata has not changed in the second cluster + object. Finally, it calls refresh_user_function_metadata() and verifies that the UDF metadata in the keyspace + is updated in the second cluster object. + + @since 2.6.0 + @jira_ticket PYTHON-291 + @expected_result UDF metadata in the keyspace should be refreshed when refresh_user_function_metadata() is called. 
+ + @test_category metadata + """ + + if PROTOCOL_VERSION < 4: + raise unittest.SkipTest("Protocol 4+ is required for UDFs, currently testing against {0}".format(PROTOCOL_VERSION)) + + cluster2 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster2.connect() + + self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].functions, {}) + self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) + RETURNS NULL ON NULL INPUT + RETURNS int + LANGUAGE javascript AS 'key + val';""".format(self.keyspace_name)) + + self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].functions, {}) + cluster2.refresh_user_function_metadata(self.keyspace_name, UserFunctionDescriptor("sum_int", ["int", "int"])) + self.assertIn("sum_int(int,int)", cluster2.metadata.keyspaces[self.keyspace_name].functions) + + cluster2.shutdown() + + def test_refresh_user_aggregate_metadata(self): + """ + test for synchronously refreshing UDA metadata in keyspace + + test_refresh_user_aggregate_metadata tests that UDA metadata in a keyspace is refreshed when calling + refresh_user_aggregate_metadata(). It creates a second cluster object with schema_event_refresh_window=-1 such + that schema refreshes are disabled for schema change push events. It then alters the keyspace, creating a new + UDA, using the first cluster object, and verifies that the UDA metadata has not changed in the second cluster + object. Finally, it calls refresh_user_aggregate_metadata() and verifies that the UDA metadata in the keyspace + is updated in the second cluster object. + + @since 2.6.0 + @jira_ticket PYTHON-291 + @expected_result UDA metadata in the keyspace should be refreshed when refresh_user_aggregate_metadata() is called. + + @test_category metadata + """ + + if PROTOCOL_VERSION < 4: + raise unittest.SkipTest("Protocol 4+ is required for UDAs, currently testing against {0}".format(PROTOCOL_VERSION)) + + cluster2 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster2.connect() + + self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].aggregates, {}) + self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) + RETURNS NULL ON NULL INPUT + RETURNS int + LANGUAGE javascript AS 'key + val';""".format(self.keyspace_name)) + + self.session.execute("""CREATE AGGREGATE {0}.sum_agg(int) + SFUNC sum_int + STYPE int + INITCOND 0""" + .format(self.keyspace_name)) + + self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].aggregates, {}) + cluster2.refresh_user_aggregate_metadata(self.keyspace_name, UserAggregateDescriptor("sum_agg", ["int"])) + self.assertIn("sum_agg(int)", cluster2.metadata.keyspaces[self.keyspace_name].aggregates) + + cluster2.shutdown() + + @greaterthanorequalcass30 + def test_multiple_indices(self): + """ + test multiple indices on the same column. + + Creates a table and two indices. Ensures that metadata for both indices is surfaced appropriately. 
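Index metadata for a table is exposed as a dict of IndexMetadata keyed by index name, which is what the assertions below walk over; a small sketch (assuming keyspace ks and table t with at least one secondary index)::

    from cassandra.cluster import Cluster

    cluster = Cluster()
    cluster.connect()
    table_meta = cluster.metadata.keyspaces['ks'].tables['t']
    for name, index in table_meta.indexes.items():
        # index_options['target'] distinguishes keys(...)/values(...) targets
        # for indexes on collection columns.
        print(name, index.kind, index.index_options.get('target'))
    cluster.shutdown()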
+ + @since 3.0.0 + @jira_ticket PYTHON-276 + @expected_result IndexMetadata is appropriately surfaced + + @test_category metadata + """ + + self.session.execute("CREATE TABLE {0}.{1} (a int PRIMARY KEY, b map)".format(self.keyspace_name, self.function_table_name)) + self.session.execute("CREATE INDEX index_1 ON {0}.{1}(b)".format(self.keyspace_name, self.function_table_name)) + self.session.execute("CREATE INDEX index_2 ON {0}.{1}(keys(b))".format(self.keyspace_name, self.function_table_name)) + + indices = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].indexes + self.assertEqual(len(indices), 2) + index_1 = indices["index_1"] + index_2 = indices['index_2'] + self.assertEqual(index_1.table_name, "test_multiple_indices") + self.assertEqual(index_1.name, "index_1") + self.assertEqual(index_1.kind, "COMPOSITES") + self.assertEqual(index_1.index_options["target"], "values(b)") + self.assertEqual(index_1.keyspace_name, "schemametadatatests") + self.assertEqual(index_2.table_name, "test_multiple_indices") + self.assertEqual(index_2.name, "index_2") + self.assertEqual(index_2.kind, "COMPOSITES") + self.assertEqual(index_2.index_options["target"], "keys(b)") + self.assertEqual(index_2.keyspace_name, "schemametadatatests") + + @greaterthanorequalcass30 + def test_table_extensions(self): + s = self.session + ks = self.keyspace_name + ks_meta = s.cluster.metadata.keyspaces[ks] + t = self.function_table_name + v = t + 'view' + + s.execute("CREATE TABLE %s.%s (k text PRIMARY KEY, v int)" % (ks, t)) + s.execute("CREATE MATERIALIZED VIEW %s.%s AS SELECT * FROM %s.%s WHERE v IS NOT NULL PRIMARY KEY (v, k)" % (ks, v, ks, t)) + + table_meta = ks_meta.tables[t] + view_meta = table_meta.views[v] + + self.assertFalse(table_meta.extensions) + self.assertFalse(view_meta.extensions) + + original_table_cql = table_meta.export_as_string() + original_view_cql = view_meta.export_as_string() + + # extensions registered, not present + # -------------------------------------- + class Ext0(RegisteredTableExtension): + name = t + + @classmethod + def after_table_cql(cls, table_meta, ext_key, ext_blob): + return "%s %s %s %s" % (cls.name, table_meta.name, ext_key, ext_blob) + + class Ext1(Ext0): + name = t + '##' + + self.assertFalse(table_meta.extensions) + self.assertFalse(view_meta.extensions) + self.assertIn(Ext0.name, _RegisteredExtensionType._extension_registry) + self.assertIn(Ext1.name, _RegisteredExtensionType._extension_registry) + self.assertEqual(len(_RegisteredExtensionType._extension_registry), 2) + + self.cluster.refresh_table_metadata(ks, t) + table_meta = ks_meta.tables[t] + view_meta = table_meta.views[v] + + self.assertEqual(table_meta.export_as_string(), original_table_cql) + self.assertEqual(view_meta.export_as_string(), original_view_cql) + + update_t = s.prepare('UPDATE system_schema.tables SET extensions=? WHERE keyspace_name=? AND table_name=?') # for blob type coercing + update_v = s.prepare('UPDATE system_schema.views SET extensions=? WHERE keyspace_name=? 
AND view_name=?') + # extensions registered, one present + # -------------------------------------- + ext_map = {Ext0.name: six.b("THA VALUE")} + [(s.execute(update_t, (ext_map, ks, t)), s.execute(update_v, (ext_map, ks, v))) + for _ in self.cluster.metadata.all_hosts()] # we're manipulating metadata - do it on all hosts + self.cluster.refresh_table_metadata(ks, t) + self.cluster.refresh_materialized_view_metadata(ks, v) + table_meta = ks_meta.tables[t] + view_meta = table_meta.views[v] + + self.assertIn(Ext0.name, table_meta.extensions) + new_cql = table_meta.export_as_string() + self.assertNotEqual(new_cql, original_table_cql) + self.assertIn(Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]), new_cql) + self.assertNotIn(Ext1.name, new_cql) + + self.assertIn(Ext0.name, view_meta.extensions) + new_cql = view_meta.export_as_string() + self.assertNotEqual(new_cql, original_view_cql) + self.assertIn(Ext0.after_table_cql(view_meta, Ext0.name, ext_map[Ext0.name]), new_cql) + self.assertNotIn(Ext1.name, new_cql) + + # extensions registered, one present + # -------------------------------------- + ext_map = {Ext0.name: six.b("THA VALUE"), + Ext1.name: six.b("OTHA VALUE")} + [(s.execute(update_t, (ext_map, ks, t)), s.execute(update_v, (ext_map, ks, v))) + for _ in self.cluster.metadata.all_hosts()] # we're manipulating metadata - do it on all hosts + self.cluster.refresh_table_metadata(ks, t) + self.cluster.refresh_materialized_view_metadata(ks, v) + table_meta = ks_meta.tables[t] + view_meta = table_meta.views[v] + + self.assertIn(Ext0.name, table_meta.extensions) + self.assertIn(Ext1.name, table_meta.extensions) + new_cql = table_meta.export_as_string() + self.assertNotEqual(new_cql, original_table_cql) + self.assertIn(Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]), new_cql) + self.assertIn(Ext1.after_table_cql(table_meta, Ext1.name, ext_map[Ext1.name]), new_cql) + + self.assertIn(Ext0.name, view_meta.extensions) + self.assertIn(Ext1.name, view_meta.extensions) + new_cql = view_meta.export_as_string() + self.assertNotEqual(new_cql, original_view_cql) + self.assertIn(Ext0.after_table_cql(view_meta, Ext0.name, ext_map[Ext0.name]), new_cql) + self.assertIn(Ext1.after_table_cql(view_meta, Ext1.name, ext_map[Ext1.name]), new_cql) + + +class TestCodeCoverage(unittest.TestCase): + + def test_export_schema(self): + """ + Test export schema functionality + """ + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster.connect() + + self.assertIsInstance(cluster.metadata.export_schema_as_string(), six.string_types) + cluster.shutdown() + + def test_export_keyspace_schema(self): + """ + Test export keyspace schema functionality + """ + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster.connect() + + for keyspace in cluster.metadata.keyspaces: + keyspace_metadata = cluster.metadata.keyspaces[keyspace] + self.assertIsInstance(keyspace_metadata.export_as_string(), six.string_types) + self.assertIsInstance(keyspace_metadata.as_cql_query(), six.string_types) + cluster.shutdown() + + def assert_equal_diff(self, received, expected): + if received != expected: + diff_string = '\n'.join(difflib.unified_diff(expected.split('\n'), + received.split('\n'), + 'EXPECTED', 'RECEIVED', + lineterm='')) + self.fail(diff_string) + + def assert_startswith_diff(self, received, prefix): + if not received.startswith(prefix): + prefix_lines = prefix.split('\n') + diff_string = '\n'.join(difflib.unified_diff(prefix_lines, + received.split('\n')[:len(prefix_lines)], + 'EXPECTED', 
'RECEIVED', + lineterm='')) + self.fail(diff_string) + + @greaterthancass20 + def test_export_keyspace_schema_udts(self): + """ + Test udt exports + """ + + if PROTOCOL_VERSION < 3: + raise unittest.SkipTest( + "Protocol 3.0+ is required for UDT change events, currently testing against %r" + % (PROTOCOL_VERSION,)) + + if sys.version_info[0:2] != (2, 7): + raise unittest.SkipTest('This test compares static strings generated from dict items, which may change orders. Test with 2.7.') + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + + session.execute(""" + CREATE KEYSPACE export_udts + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + AND durable_writes = true; + """) + session.execute(""" + CREATE TYPE export_udts.street ( + street_number int, + street_name text) + """) + session.execute(""" + CREATE TYPE export_udts.zip ( + zipcode int, + zip_plus_4 int) + """) + session.execute(""" + CREATE TYPE export_udts.address ( + street_address frozen<street>, + zip_code frozen<zip>) + """) + session.execute(""" + CREATE TABLE export_udts.users ( + user text PRIMARY KEY, + addresses map<text, frozen<address>>) + """) + + expected_prefix = """CREATE KEYSPACE export_udts WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} AND durable_writes = true; + +CREATE TYPE export_udts.street ( + street_number int, + street_name text +); + +CREATE TYPE export_udts.zip ( + zipcode int, + zip_plus_4 int +); + +CREATE TYPE export_udts.address ( + street_address frozen<street>, + zip_code frozen<zip> +); + +CREATE TABLE export_udts.users ( + user text PRIMARY KEY, + addresses map<text, frozen<address>>""" + + self.assert_startswith_diff(cluster.metadata.keyspaces['export_udts'].export_as_string(), expected_prefix) + + table_meta = cluster.metadata.keyspaces['export_udts'].tables['users'] + + expected_prefix = """CREATE TABLE export_udts.users ( + user text PRIMARY KEY, + addresses map<text, frozen<address>>""" + + self.assert_startswith_diff(table_meta.export_as_string(), expected_prefix) + + cluster.shutdown() + + @greaterthancass21 + def test_case_sensitivity(self): + """ + Test that names that need to be escaped in CREATE statements are escaped + """ + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + + ksname = 'AnInterestingKeyspace' + cfname = 'AnInterestingTable' + + session.execute("DROP KEYSPACE IF EXISTS {0}".format(ksname)) + session.execute(""" + CREATE KEYSPACE "%s" + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + """ % (ksname,)) + session.execute(""" + CREATE TABLE "%s"."%s" ( + k int, + "A" int, + "B" int, + "MyColumn" int, + PRIMARY KEY (k, "A")) + WITH CLUSTERING ORDER BY ("A" DESC) + """ % (ksname, cfname)) + session.execute(""" + CREATE INDEX myindex ON "%s"."%s" ("MyColumn") + """ % (ksname, cfname)) + session.execute(""" + CREATE INDEX "AnotherIndex" ON "%s"."%s" ("B") + """ % (ksname, cfname)) + + ksmeta = cluster.metadata.keyspaces[ksname] + schema = ksmeta.export_as_string() + self.assertIn('CREATE KEYSPACE "AnInterestingKeyspace"', schema) + self.assertIn('CREATE TABLE "AnInterestingKeyspace"."AnInterestingTable"', schema) + self.assertIn('"A" int', schema) + self.assertIn('"B" int', schema) + self.assertIn('"MyColumn" int', schema) + self.assertIn('PRIMARY KEY (k, "A")', schema) + self.assertIn('WITH CLUSTERING ORDER BY ("A" DESC)', schema) + self.assertIn('CREATE INDEX myindex ON "AnInterestingKeyspace"."AnInterestingTable" ("MyColumn")', schema) + self.assertIn('CREATE INDEX "AnotherIndex" ON "AnInterestingKeyspace"."AnInterestingTable" 
("B")', schema) + cluster.shutdown() + + def test_already_exists_exceptions(self): + """ + Ensure AlreadyExists exception is thrown when hit + """ + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + + ksname = 'test3rf' + cfname = 'test' + + ddl = ''' + CREATE KEYSPACE %s + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}''' + self.assertRaises(AlreadyExists, session.execute, ddl % ksname) + + ddl = ''' + CREATE TABLE %s.%s ( + k int PRIMARY KEY, + v int )''' + self.assertRaises(AlreadyExists, session.execute, ddl % (ksname, cfname)) + cluster.shutdown() + + @local + def test_replicas(self): + """ + Ensure cluster.metadata.get_replicas return correctly when not attached to keyspace + """ + if murmur3 is None: + raise unittest.SkipTest('the murmur3 extension is not available') + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.assertEqual(cluster.metadata.get_replicas('test3rf', 'key'), []) + + cluster.connect('test3rf') + + self.assertNotEqual(list(cluster.metadata.get_replicas('test3rf', six.b('key'))), []) + host = list(cluster.metadata.get_replicas('test3rf', six.b('key')))[0] + self.assertEqual(host.datacenter, 'dc1') + self.assertEqual(host.rack, 'r1') + cluster.shutdown() + + def test_token_map(self): + """ + Test token mappings + """ + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster.connect('test3rf') + ring = cluster.metadata.token_map.ring + owners = list(cluster.metadata.token_map.token_to_host_owner[token] for token in ring) + get_replicas = cluster.metadata.token_map.get_replicas + + for ksname in ('test1rf', 'test2rf', 'test3rf'): + self.assertNotEqual(list(get_replicas(ksname, ring[0])), []) + + for i, token in enumerate(ring): + self.assertEqual(set(get_replicas('test3rf', token)), set(owners)) + self.assertEqual(set(get_replicas('test2rf', token)), set([owners[i], owners[(i + 1) % 3]])) + self.assertEqual(set(get_replicas('test1rf', token)), set([owners[i]])) + cluster.shutdown() + + +class TokenMetadataTest(unittest.TestCase): + """ + Test of TokenMap creation and other behavior. 
+ """ + @local + def test_token(self): + expected_node_count = len(get_cluster().nodes) + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster.connect() + tmap = cluster.metadata.token_map + self.assertTrue(issubclass(tmap.token_class, Token)) + self.assertEqual(expected_node_count, len(tmap.ring)) + cluster.shutdown() + + +class KeyspaceAlterMetadata(unittest.TestCase): + """ + Test verifies that table metadata is preserved on keyspace alter + """ + def setUp(self): + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.session = self.cluster.connect() + name = self._testMethodName.lower() + crt_ks = ''' + CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1} AND durable_writes = true''' % name + self.session.execute(crt_ks) + + def tearDown(self): + name = self._testMethodName.lower() + self.session.execute('DROP KEYSPACE %s' % name) + self.cluster.shutdown() + + def test_keyspace_alter(self): + """ + Table info is preserved upon keyspace alter: + Create table + Verify schema + Alter ks + Verify that table metadata is still present + + PYTHON-173 + """ + name = self._testMethodName.lower() + + self.session.execute('CREATE TABLE %s.d (d INT PRIMARY KEY)' % name) + original_keyspace_meta = self.cluster.metadata.keyspaces[name] + self.assertEqual(original_keyspace_meta.durable_writes, True) + self.assertEqual(len(original_keyspace_meta.tables), 1) + + self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % name) + new_keyspace_meta = self.cluster.metadata.keyspaces[name] + self.assertNotEqual(original_keyspace_meta, new_keyspace_meta) + self.assertEqual(new_keyspace_meta.durable_writes, False) + + +class IndexMapTests(unittest.TestCase): + + keyspace_name = 'index_map_tests' + + @property + def table_name(self): + return self._testMethodName.lower() + + @classmethod + def setup_class(cls): + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.session = cls.cluster.connect() + try: + if cls.keyspace_name in cls.cluster.metadata.keyspaces: + cls.session.execute("DROP KEYSPACE %s" % cls.keyspace_name) + + cls.session.execute( + """ + CREATE KEYSPACE %s + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}; + """ % cls.keyspace_name) + cls.session.set_keyspace(cls.keyspace_name) + except Exception: + cls.cluster.shutdown() + raise + + @classmethod + def teardown_class(cls): + try: + cls.session.execute("DROP KEYSPACE %s" % cls.keyspace_name) + finally: + cls.cluster.shutdown() + + def create_basic_table(self): + self.session.execute("CREATE TABLE %s (k int PRIMARY KEY, a int)" % self.table_name) + + def drop_basic_table(self): + self.session.execute("DROP TABLE %s" % self.table_name) + + def test_index_updates(self): + self.create_basic_table() + + ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] + table_meta = ks_meta.tables[self.table_name] + self.assertNotIn('a_idx', ks_meta.indexes) + self.assertNotIn('b_idx', ks_meta.indexes) + self.assertNotIn('a_idx', table_meta.indexes) + self.assertNotIn('b_idx', table_meta.indexes) + + self.session.execute("CREATE INDEX a_idx ON %s (a)" % self.table_name) + self.session.execute("ALTER TABLE %s ADD b int" % self.table_name) + self.session.execute("CREATE INDEX b_idx ON %s (b)" % self.table_name) + + ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] + table_meta = ks_meta.tables[self.table_name] + self.assertIsInstance(ks_meta.indexes['a_idx'], IndexMetadata) + self.assertIsInstance(ks_meta.indexes['b_idx'], 
IndexMetadata) + self.assertIsInstance(table_meta.indexes['a_idx'], IndexMetadata) + self.assertIsInstance(table_meta.indexes['b_idx'], IndexMetadata) + + # both indexes updated when index dropped + self.session.execute("DROP INDEX a_idx") + + # temporarily synchronously refresh the schema metadata, until CASSANDRA-9391 is merged in + self.cluster.refresh_table_metadata(self.keyspace_name, self.table_name) + + ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] + table_meta = ks_meta.tables[self.table_name] + self.assertNotIn('a_idx', ks_meta.indexes) + self.assertIsInstance(ks_meta.indexes['b_idx'], IndexMetadata) + self.assertNotIn('a_idx', table_meta.indexes) + self.assertIsInstance(table_meta.indexes['b_idx'], IndexMetadata) + + # keyspace index updated when table dropped + self.drop_basic_table() + ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] + self.assertNotIn(self.table_name, ks_meta.tables) + self.assertNotIn('a_idx', ks_meta.indexes) + self.assertNotIn('b_idx', ks_meta.indexes) + + def test_index_follows_alter(self): + self.create_basic_table() + + idx = self.table_name + '_idx' + self.session.execute("CREATE INDEX %s ON %s (a)" % (idx, self.table_name)) + ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] + table_meta = ks_meta.tables[self.table_name] + self.assertIsInstance(ks_meta.indexes[idx], IndexMetadata) + self.assertIsInstance(table_meta.indexes[idx], IndexMetadata) + self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % self.keyspace_name) + old_meta = ks_meta + ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] + self.assertIsNot(ks_meta, old_meta) + table_meta = ks_meta.tables[self.table_name] + self.assertIsInstance(ks_meta.indexes[idx], IndexMetadata) + self.assertIsInstance(table_meta.indexes[idx], IndexMetadata) + self.drop_basic_table() + + +class FunctionTest(unittest.TestCase): + """ + Base functionality for Function and Aggregate metadata test classes + """ + + def setUp(self): + """ + Tests are skipped if run with native protocol version < 4 + """ + + if PROTOCOL_VERSION < 4: + raise unittest.SkipTest("Function metadata requires native protocol version 4+") + + @property + def function_name(self): + return self._testMethodName.lower() + + @classmethod + def setup_class(cls): + if PROTOCOL_VERSION >= 4: + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.keyspace_name = cls.__name__.lower() + cls.session = cls.cluster.connect() + cls.session.execute("CREATE KEYSPACE IF NOT EXISTS %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" % cls.keyspace_name) + cls.session.set_keyspace(cls.keyspace_name) + cls.keyspace_function_meta = cls.cluster.metadata.keyspaces[cls.keyspace_name].functions + cls.keyspace_aggregate_meta = cls.cluster.metadata.keyspaces[cls.keyspace_name].aggregates + + @classmethod + def teardown_class(cls): + if PROTOCOL_VERSION >= 4: + cls.session.execute("DROP KEYSPACE IF EXISTS %s" % cls.keyspace_name) + cls.cluster.shutdown() + + class Verified(object): + + def __init__(self, test_case, meta_class, element_meta, **function_kwargs): + self.test_case = test_case + self.function_kwargs = dict(function_kwargs) + self.meta_class = meta_class + self.element_meta = element_meta + + def __enter__(self): + tc = self.test_case + expected_meta = self.meta_class(**self.function_kwargs) + tc.assertNotIn(expected_meta.signature, self.element_meta) + tc.session.execute(expected_meta.as_cql_query()) + tc.assertIn(expected_meta.signature, self.element_meta) + 
+ generated_meta = self.element_meta[expected_meta.signature] + self.test_case.assertEqual(generated_meta.as_cql_query(), expected_meta.as_cql_query()) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + tc = self.test_case + tc.session.execute("DROP %s %s.%s" % (self.meta_class.__name__, tc.keyspace_name, self.signature)) + tc.assertNotIn(self.signature, self.element_meta) + + @property + def signature(self): + return SignatureDescriptor.format_signature(self.function_kwargs['name'], + self.function_kwargs['argument_types']) + + class VerifiedFunction(Verified): + def __init__(self, test_case, **kwargs): + super(FunctionTest.VerifiedFunction, self).__init__(test_case, Function, test_case.keyspace_function_meta, **kwargs) + + class VerifiedAggregate(Verified): + def __init__(self, test_case, **kwargs): + super(FunctionTest.VerifiedAggregate, self).__init__(test_case, Aggregate, test_case.keyspace_aggregate_meta, **kwargs) + + +class FunctionMetadata(FunctionTest): + + def make_function_kwargs(self, called_on_null=True): + return {'keyspace': self.keyspace_name, + 'name': self.function_name, + 'argument_types': ['double', 'int'], + 'argument_names': ['d', 'i'], + 'return_type': 'double', + 'language': 'java', + 'body': 'return new Double(0.0);', + 'called_on_null_input': called_on_null} + + def test_functions_after_udt(self): + """ + Test to to ensure functions come after UDTs in in keyspace dump + + test_functions_after_udt creates a basic function. Then queries that function and make sure that in the results + that UDT's are listed before any corresponding functions, when we dump the keyspace + + Ideally we would make a function that takes a udt type, but this presently fails because C* c059a56 requires + udt to be frozen to create, but does not store meta indicating frozen + SEE https://issues.apache.org/jira/browse/CASSANDRA-9186 + Maybe update this after release + kwargs = self.make_function_kwargs() + kwargs['argument_types'][0] = "frozen<%s>" % udt_name + expected_meta = Function(**kwargs) + with self.VerifiedFunction(self, **kwargs): + + @since 2.6.0 + @jira_ticket PYTHON-211 + @expected_result UDT's should come before any functions + @test_category function + """ + + self.assertNotIn(self.function_name, self.keyspace_function_meta) + + udt_name = 'udtx' + self.session.execute("CREATE TYPE %s (x int)" % udt_name) + + with self.VerifiedFunction(self, **self.make_function_kwargs()): + # udts must come before functions in keyspace dump + keyspace_cql = self.cluster.metadata.keyspaces[self.keyspace_name].export_as_string() + type_idx = keyspace_cql.rfind("CREATE TYPE") + func_idx = keyspace_cql.find("CREATE FUNCTION") + self.assertNotIn(-1, (type_idx, func_idx), "TYPE or FUNCTION not found in keyspace_cql: " + keyspace_cql) + self.assertGreater(func_idx, type_idx) + + def test_function_same_name_diff_types(self): + """ + Test to verify to that functions with different signatures are differentiated in metadata + + test_function_same_name_diff_types Creates two functions. One with the same name but a slightly different + signature. Then ensures that both are surfaced separately in our metadata. + + @since 2.6.0 + @jira_ticket PYTHON-211 + @expected_result function with the same name but different signatures should be surfaced separately + @test_category function + """ + + # Create a function + kwargs = self.make_function_kwargs() + with self.VerifiedFunction(self, **kwargs): + + # another function: same name, different type sig. 
+ self.assertGreater(len(kwargs['argument_types']), 1) + self.assertGreater(len(kwargs['argument_names']), 1) + kwargs['argument_types'] = kwargs['argument_types'][:1] + kwargs['argument_names'] = kwargs['argument_names'][:1] + + # Ensure they are surfaced separately + with self.VerifiedFunction(self, **kwargs): + functions = [f for f in self.keyspace_function_meta.values() if f.name == self.function_name] + self.assertEqual(len(functions), 2) + self.assertNotEqual(functions[0].argument_types, functions[1].argument_types) + + def test_function_no_parameters(self): + """ + Test to verify CQL output for functions with zero parameters + + Creates a function with no input parameters, verify that CQL output is correct. + + @since 2.7.1 + @jira_ticket PYTHON-392 + @expected_result function with no parameters should generate proper CQL + @test_category function + """ + kwargs = self.make_function_kwargs() + kwargs['argument_types'] = [] + kwargs['argument_names'] = [] + kwargs['return_type'] = 'bigint' + kwargs['body'] = 'return System.currentTimeMillis() / 1000L;' + + with self.VerifiedFunction(self, **kwargs) as vf: + fn_meta = self.keyspace_function_meta[vf.signature] + self.assertRegexpMatches(fn_meta.as_cql_query(), "CREATE FUNCTION.*%s\(\) .*" % kwargs['name']) + + def test_functions_follow_keyspace_alter(self): + """ + Test to verify to that functions maintain equality after a keyspace is altered + + test_functions_follow_keyspace_alter creates a function then alters a the keyspace associated with that function. + After the alter we validate that the function maintains the same metadata + + + @since 2.6.0 + @jira_ticket PYTHON-211 + @expected_result functions are the same after parent keyspace is altered + @test_category function + """ + + # Create function + with self.VerifiedFunction(self, **self.make_function_kwargs()): + original_keyspace_meta = self.cluster.metadata.keyspaces[self.keyspace_name] + self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % self.keyspace_name) + + # After keyspace alter ensure that we maintain function equality. + try: + new_keyspace_meta = self.cluster.metadata.keyspaces[self.keyspace_name] + self.assertNotEqual(original_keyspace_meta, new_keyspace_meta) + self.assertIs(original_keyspace_meta.functions, new_keyspace_meta.functions) + finally: + self.session.execute('ALTER KEYSPACE %s WITH durable_writes = true' % self.keyspace_name) + + def test_function_cql_called_on_null(self): + """ + Test to verify to that that called on null argument is honored on function creation. + + test_functions_follow_keyspace_alter create two functions. One with the called_on_null_input set to true, + the other with it set to false. We then verify that the metadata constructed from those function is correctly + reflected + + + @since 2.6.0 + @jira_ticket PYTHON-211 + @expected_result functions metadata correctly reflects called_on_null_input flag. 
+ @test_category function + """ + + kwargs = self.make_function_kwargs() + kwargs['called_on_null_input'] = True + with self.VerifiedFunction(self, **kwargs) as vf: + fn_meta = self.keyspace_function_meta[vf.signature] + self.assertRegexpMatches(fn_meta.as_cql_query(), "CREATE FUNCTION.*\) CALLED ON NULL INPUT RETURNS .*") + + kwargs['called_on_null_input'] = False + with self.VerifiedFunction(self, **kwargs) as vf: + fn_meta = self.keyspace_function_meta[vf.signature] + self.assertRegexpMatches(fn_meta.as_cql_query(), "CREATE FUNCTION.*\) RETURNS NULL ON NULL INPUT RETURNS .*") + + +class AggregateMetadata(FunctionTest): + + @classmethod + def setup_class(cls): + if PROTOCOL_VERSION >= 4: + super(AggregateMetadata, cls).setup_class() + + cls.session.execute("""CREATE OR REPLACE FUNCTION sum_int(s int, i int) + RETURNS NULL ON NULL INPUT + RETURNS int + LANGUAGE javascript AS 's + i';""") + cls.session.execute("""CREATE OR REPLACE FUNCTION sum_int_two(s int, i int, j int) + RETURNS NULL ON NULL INPUT + RETURNS int + LANGUAGE javascript AS 's + i + j';""") + cls.session.execute("""CREATE OR REPLACE FUNCTION "List_As_String"(l list) + RETURNS NULL ON NULL INPUT + RETURNS int + LANGUAGE javascript AS ''''' + l';""") + cls.session.execute("""CREATE OR REPLACE FUNCTION extend_list(s list, i int) + CALLED ON NULL INPUT + RETURNS list + LANGUAGE java AS 'if (i != null) s.add(i.toString()); return s;';""") + cls.session.execute("""CREATE OR REPLACE FUNCTION update_map(s map, i int) + RETURNS NULL ON NULL INPUT + RETURNS map + LANGUAGE java AS 's.put(new Integer(i), new Integer(i)); return s;';""") + cls.session.execute("""CREATE TABLE IF NOT EXISTS t + (k int PRIMARY KEY, v int)""") + for x in range(4): + cls.session.execute("INSERT INTO t (k,v) VALUES (%s, %s)", (x, x)) + cls.session.execute("INSERT INTO t (k) VALUES (%s)", (4,)) + + def make_aggregate_kwargs(self, state_func, state_type, final_func=None, init_cond=None): + return {'keyspace': self.keyspace_name, + 'name': self.function_name + '_aggregate', + 'argument_types': ['int'], + 'state_func': state_func, + 'state_type': state_type, + 'final_func': final_func, + 'initial_condition': init_cond, + 'return_type': "does not matter for creation"} + + def test_return_type_meta(self): + """ + Test to verify to that the return type of a an aggregate is honored in the metadata + + test_return_type_meta creates an aggregate then ensures the return type of the created + aggregate is correctly surfaced in the metadata + + + @since 2.6.0 + @jira_ticket PYTHON-211 + @expected_result aggregate has the correct return typ in the metadata + @test_category aggregate + """ + + with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('sum_int', 'int', init_cond='1')) as va: + self.assertEqual(self.keyspace_aggregate_meta[va.signature].return_type, 'int') + + def test_init_cond(self): + """ + Test to verify that various initial conditions are correctly surfaced in various aggregate functions + + test_init_cond creates several different types of aggregates, and given various initial conditions it verifies that + they correctly impact the aggregate's execution + + + @since 2.6.0 + @jira_ticket PYTHON-211 + @expected_result initial conditions are correctly evaluated as part of the aggregates + @test_category aggregate + """ + + # This is required until the java driver bundled with C* is updated to support v4 + c = Cluster(protocol_version=3) + s = c.connect(self.keyspace_name) + + encoder = Encoder() + + expected_values = range(4) + + # int32 + for 
init_cond in (-1, 0, 1):
+            cql_init = encoder.cql_encode_all_types(init_cond)
+            with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('sum_int', 'int', init_cond=cql_init)) as va:
+                sum_res = s.execute("SELECT %s(v) AS sum FROM t" % va.function_kwargs['name'])[0].sum
+                self.assertEqual(sum_res, int(init_cond) + sum(expected_values))
+
+        # list<text>
+        for init_cond in ([], ['1', '2']):
+            cql_init = encoder.cql_encode_all_types(init_cond)
+            with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('extend_list', 'list<text>', init_cond=cql_init)) as va:
+                list_res = s.execute("SELECT %s(v) AS list_res FROM t" % va.function_kwargs['name'])[0].list_res
+                self.assertListEqual(list_res[:len(init_cond)], init_cond)
+                self.assertEqual(set(i for i in list_res[len(init_cond):]),
+                                 set(str(i) for i in expected_values))
+
+        # map<int, int>
+        expected_map_values = dict((i, i) for i in expected_values)
+        expected_key_set = set(expected_values)
+        for init_cond in ({}, {1: 2, 3: 4}, {5: 5}):
+            cql_init = encoder.cql_encode_all_types(init_cond)
+            with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('update_map', 'map<int, int>', init_cond=cql_init)) as va:
+                map_res = s.execute("SELECT %s(v) AS map_res FROM t" % va.function_kwargs['name'])[0].map_res
+                self.assertDictContainsSubset(expected_map_values, map_res)
+                init_not_updated = dict((k, init_cond[k]) for k in set(init_cond) - expected_key_set)
+                self.assertDictContainsSubset(init_not_updated, map_res)
+        c.shutdown()
+
+    def test_aggregates_after_functions(self):
+        """
+        Test to verify that aggregates are listed after functions in metadata
+
+        test_aggregates_after_functions creates an aggregate, and then verifies that it is listed
+        after any function creations when the keyspace dump is performed.
+
+        @since 2.6.0
+        @jira_ticket PYTHON-211
+        @expected_result aggregates are declared after any functions
+        @test_category aggregate
+        """
+
+        # aggregates must come after functions in the keyspace dump
+        with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('extend_list', 'list<text>')):
+            keyspace_cql = self.cluster.metadata.keyspaces[self.keyspace_name].export_as_string()
+            func_idx = keyspace_cql.find("CREATE FUNCTION")
+            aggregate_idx = keyspace_cql.rfind("CREATE AGGREGATE")
+            self.assertNotIn(-1, (aggregate_idx, func_idx), "AGGREGATE or FUNCTION not found in keyspace_cql: " + keyspace_cql)
+            self.assertGreater(aggregate_idx, func_idx)
+
+    def test_same_name_diff_types(self):
+        """
+        Test to verify that aggregates with different signatures are differentiated in metadata
+
+        test_same_name_diff_types creates two aggregates with the same name but slightly different
+        signatures, and then ensures that both are surfaced separately in the metadata.
+
+        @since 2.6.0
+        @jira_ticket PYTHON-211
+        @expected_result aggregates with the same name but different signatures should be surfaced separately
+        @test_category function
+        """
+
+        kwargs = self.make_aggregate_kwargs('sum_int', 'int', init_cond='0')
+        with self.VerifiedAggregate(self, **kwargs):
+            kwargs['state_func'] = 'sum_int_two'
+            kwargs['argument_types'] = ['int', 'int']
+            with self.VerifiedAggregate(self, **kwargs):
+                aggregates = [a for a in self.keyspace_aggregate_meta.values() if a.name == kwargs['name']]
+                self.assertEqual(len(aggregates), 2)
+                self.assertNotEqual(aggregates[0].argument_types, aggregates[1].argument_types)
+
+    def test_aggregates_follow_keyspace_alter(self):
+        """
+        Test to verify that aggregates maintain equality after a keyspace is altered
+
+        test_aggregates_follow_keyspace_alter creates an aggregate, then alters the keyspace associated with that
+        aggregate. After the alter we validate that the aggregate maintains the same metadata.
+
+        @since 2.6.0
+        @jira_ticket PYTHON-211
+        @expected_result aggregates are the same after parent keyspace is altered
+        @test_category function
+        """
+
+        with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('sum_int', 'int', init_cond='0')):
+            original_keyspace_meta = self.cluster.metadata.keyspaces[self.keyspace_name]
+            self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % self.keyspace_name)
+            try:
+                new_keyspace_meta = self.cluster.metadata.keyspaces[self.keyspace_name]
+                self.assertNotEqual(original_keyspace_meta, new_keyspace_meta)
+                self.assertIs(original_keyspace_meta.aggregates, new_keyspace_meta.aggregates)
+            finally:
+                self.session.execute('ALTER KEYSPACE %s WITH durable_writes = true' % self.keyspace_name)
+
+    def test_cql_optional_params(self):
+        """
+        Test to verify that the initial_condition and final_func parameters are correctly honored
+
+        test_cql_optional_params creates various aggregates with different combinations of the initial_condition
+        and final_func parameters set. It then ensures they are correctly honored.
+ + + @since 2.6.0 + @jira_ticket PYTHON-211 + @expected_result initial_condition and final_func parameters are honored correctly + @test_category function + """ + + kwargs = self.make_aggregate_kwargs('extend_list', 'list') + encoder = Encoder() + + # no initial condition, final func + self.assertIsNone(kwargs['initial_condition']) + self.assertIsNone(kwargs['final_func']) + with self.VerifiedAggregate(self, **kwargs) as va: + meta = self.keyspace_aggregate_meta[va.signature] + self.assertIsNone(meta.initial_condition) + self.assertIsNone(meta.final_func) + cql = meta.as_cql_query() + self.assertEqual(cql.find('INITCOND'), -1) + self.assertEqual(cql.find('FINALFUNC'), -1) + + # initial condition, no final func + kwargs['initial_condition'] = encoder.cql_encode_all_types(['init', 'cond']) + with self.VerifiedAggregate(self, **kwargs) as va: + meta = self.keyspace_aggregate_meta[va.signature] + self.assertEqual(meta.initial_condition, kwargs['initial_condition']) + self.assertIsNone(meta.final_func) + cql = meta.as_cql_query() + search_string = "INITCOND %s" % kwargs['initial_condition'] + self.assertGreater(cql.find(search_string), 0, '"%s" search string not found in cql:\n%s' % (search_string, cql)) + self.assertEqual(cql.find('FINALFUNC'), -1) + + # no initial condition, final func + kwargs['initial_condition'] = None + kwargs['final_func'] = 'List_As_String' + with self.VerifiedAggregate(self, **kwargs) as va: + meta = self.keyspace_aggregate_meta[va.signature] + self.assertIsNone(meta.initial_condition) + self.assertEqual(meta.final_func, kwargs['final_func']) + cql = meta.as_cql_query() + self.assertEqual(cql.find('INITCOND'), -1) + search_string = 'FINALFUNC "%s"' % kwargs['final_func'] + self.assertGreater(cql.find(search_string), 0, '"%s" search string not found in cql:\n%s' % (search_string, cql)) + + # both + kwargs['initial_condition'] = encoder.cql_encode_all_types(['init', 'cond']) + kwargs['final_func'] = 'List_As_String' + with self.VerifiedAggregate(self, **kwargs) as va: + meta = self.keyspace_aggregate_meta[va.signature] + self.assertEqual(meta.initial_condition, kwargs['initial_condition']) + self.assertEqual(meta.final_func, kwargs['final_func']) + cql = meta.as_cql_query() + init_cond_idx = cql.find("INITCOND %s" % kwargs['initial_condition']) + final_func_idx = cql.find('FINALFUNC "%s"' % kwargs['final_func']) + self.assertNotIn(-1, (init_cond_idx, final_func_idx)) + self.assertGreater(init_cond_idx, final_func_idx) + + +class BadMetaTest(unittest.TestCase): + """ + Test behavior when metadata has unexpected form + Verify that new cluster/session can still connect, and the CQL output indicates the exception with a warning. 
+ PYTHON-370 + """ + + class BadMetaException(Exception): + pass + + @property + def function_name(self): + return self._testMethodName.lower() + + @classmethod + def setup_class(cls): + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.keyspace_name = cls.__name__.lower() + cls.session = cls.cluster.connect() + cls.session.execute("CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" % cls.keyspace_name) + cls.session.set_keyspace(cls.keyspace_name) + connection = cls.cluster.control_connection._connection + cls.parser_class = get_schema_parser(connection, CASSANDRA_VERSION.base_version, timeout=20).__class__ + cls.cluster.control_connection.reconnect = Mock() + + @classmethod + def teardown_class(cls): + drop_keyspace_shutdown_cluster(cls.keyspace_name, cls.session, cls.cluster) + + def test_bad_keyspace(self): + with patch.object(self.parser_class, '_build_keyspace_metadata_internal', side_effect=self.BadMetaException): + self.cluster.refresh_keyspace_metadata(self.keyspace_name) + m = self.cluster.metadata.keyspaces[self.keyspace_name] + self.assertIs(m._exc_info[0], self.BadMetaException) + self.assertIn("/*\nWarning:", m.export_as_string()) + + def test_bad_table(self): + self.session.execute('CREATE TABLE %s (k int PRIMARY KEY, v int)' % self.function_name) + with patch.object(self.parser_class, '_build_column_metadata', side_effect=self.BadMetaException): + self.cluster.refresh_table_metadata(self.keyspace_name, self.function_name) + m = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_name] + self.assertIs(m._exc_info[0], self.BadMetaException) + self.assertIn("/*\nWarning:", m.export_as_string()) + + def test_bad_index(self): + self.session.execute('CREATE TABLE %s (k int PRIMARY KEY, v int)' % self.function_name) + self.session.execute('CREATE INDEX ON %s(v)' % self.function_name) + with patch.object(self.parser_class, '_build_index_metadata', side_effect=self.BadMetaException): + self.cluster.refresh_table_metadata(self.keyspace_name, self.function_name) + m = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_name] + self.assertIs(m._exc_info[0], self.BadMetaException) + self.assertIn("/*\nWarning:", m.export_as_string()) + + @greaterthancass20 + def test_bad_user_type(self): + self.session.execute('CREATE TYPE %s (i int, d double)' % self.function_name) + with patch.object(self.parser_class, '_build_user_type', side_effect=self.BadMetaException): + self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh + m = self.cluster.metadata.keyspaces[self.keyspace_name] + self.assertIs(m._exc_info[0], self.BadMetaException) + self.assertIn("/*\nWarning:", m.export_as_string()) + + @greaterthancass21 + def test_bad_user_function(self): + self.session.execute("""CREATE FUNCTION IF NOT EXISTS %s (key int, val int) + RETURNS NULL ON NULL INPUT + RETURNS int + LANGUAGE javascript AS 'key + val';""" % self.function_name) + with patch.object(self.parser_class, '_build_function', side_effect=self.BadMetaException): + self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh + m = self.cluster.metadata.keyspaces[self.keyspace_name] + self.assertIs(m._exc_info[0], self.BadMetaException) + self.assertIn("/*\nWarning:", m.export_as_string()) + + @greaterthancass21 + def test_bad_user_aggregate(self): + 
self.session.execute("""CREATE FUNCTION IF NOT EXISTS sum_int (key int, val int) + RETURNS NULL ON NULL INPUT + RETURNS int + LANGUAGE javascript AS 'key + val';""") + self.session.execute("""CREATE AGGREGATE %s(int) + SFUNC sum_int + STYPE int + INITCOND 0""" % self.function_name) + with patch.object(self.parser_class, '_build_aggregate', side_effect=self.BadMetaException): + self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh + m = self.cluster.metadata.keyspaces[self.keyspace_name] + self.assertIs(m._exc_info[0], self.BadMetaException) + self.assertIn("/*\nWarning:", m.export_as_string()) + + +class DynamicCompositeTypeTest(BasicSharedKeyspaceUnitTestCase): + + def test_dct_alias(self): + """ + Tests to make sure DCT's have correct string formatting + + Constructs a DCT and check the format as generated. To insure it matches what is expected + + @since 3.6.0 + @jira_ticket PYTHON-579 + @expected_result DCT subtypes should always have fully qualified names + + @test_category metadata + """ + self.session.execute("CREATE TABLE {0}.{1} (" + "k int PRIMARY KEY," + "c1 'DynamicCompositeType(s => UTF8Type, i => Int32Type)'," + "c2 Text)".format(self.ks_name, self.function_table_name)) + dct_table = self.cluster.metadata.keyspaces.get(self.ks_name).tables.get(self.function_table_name) + + # Format can very slightly between versions, strip out whitespace for consistency sake + self.assertTrue("c1'org.apache.cassandra.db.marshal.DynamicCompositeType(" + "s=>org.apache.cassandra.db.marshal.UTF8Type," + "i=>org.apache.cassandra.db.marshal.Int32Type)'" + in dct_table.as_cql_query().replace(" ", "")) + + +@greaterthanorequalcass30 +class Materia3lizedViewMetadataTestSimple(BasicSharedKeyspaceUnitTestCase): + + def setUp(self): + self.session.execute("CREATE TABLE {0}.{1} (pk int PRIMARY KEY, c int)".format(self.keyspace_name, self.function_table_name)) + self.session.execute("CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT c FROM {0}.{1} WHERE c IS NOT NULL PRIMARY KEY (pk, c)".format(self.keyspace_name, self.function_table_name)) + + def tearDown(self): + self.session.execute("DROP MATERIALIZED VIEW {0}.mv1".format(self.keyspace_name)) + self.session.execute("DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name)) + + def test_materialized_view_metadata_creation(self): + """ + test for materialized view metadata creation + + test_materialized_view_metadata_creation tests that materialized view metadata properly created implicitly in + both keyspace and table metadata under "views". It creates a simple base table and then creates a view based + on that table. It then checks that the materialized view metadata is contained in the keyspace and table + metadata. Finally, it checks that the keyspace_name and the base_table_name in the view metadata is properly set. + + @since 3.0.0 + @jira_ticket PYTHON-371 + @expected_result Materialized view metadata in both the ks and table should be created with a new view is created. 
+
+        @test_category metadata
+        """
+
+        self.assertIn("mv1", self.cluster.metadata.keyspaces[self.keyspace_name].views)
+        self.assertIn("mv1", self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views)
+
+        self.assertEqual(self.keyspace_name, self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].keyspace_name)
+        self.assertEqual(self.function_table_name, self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].base_table_name)
+
+    def test_materialized_view_metadata_alter(self):
+        """
+        test for materialized view metadata alteration
+
+        test_materialized_view_metadata_alter tests that materialized view metadata is properly updated implicitly in the
+        table metadata once that view is updated. It creates a simple base table and then creates a view based
+        on that table. It then alters that materialized view and checks that the materialized view metadata is altered in
+        the table metadata.
+
+        @since 3.0.0
+        @jira_ticket PYTHON-371
+        @expected_result Materialized view metadata should be updated when the view is altered.
+
+        @test_category metadata
+        """
+        self.assertIn("SizeTieredCompactionStrategy", self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].options["compaction"]["class"])
+
+        self.session.execute("ALTER MATERIALIZED VIEW {0}.mv1 WITH compaction = {{ 'class' : 'LeveledCompactionStrategy' }}".format(self.keyspace_name))
+        self.assertIn("LeveledCompactionStrategy", self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].options["compaction"]["class"])
+
+    def test_materialized_view_metadata_drop(self):
+        """
+        test for materialized view metadata dropping
+
+        test_materialized_view_metadata_drop tests that materialized view metadata is properly removed implicitly from
+        both keyspace and table metadata once that view is dropped. It creates a simple base table and then creates a view
+        based on that table. It then drops that materialized view and checks that the materialized view metadata is removed
+        from the keyspace and table metadata.
+
+        @since 3.0.0
+        @jira_ticket PYTHON-371
+        @expected_result Materialized view metadata in both the ks and table should be removed when the view is dropped.
+
+        @test_category metadata
+        """
+
+        self.session.execute("DROP MATERIALIZED VIEW {0}.mv1".format(self.keyspace_name))
+
+        self.assertNotIn("mv1", self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views)
+        self.assertNotIn("mv1", self.cluster.metadata.keyspaces[self.keyspace_name].views)
+        self.assertDictEqual({}, self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views)
+        self.assertDictEqual({}, self.cluster.metadata.keyspaces[self.keyspace_name].views)
+
+        self.session.execute("CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT c FROM {0}.{1} WHERE c IS NOT NULL PRIMARY KEY (pk, c)".format(self.keyspace_name, self.function_table_name))
+
+
+@greaterthanorequalcass30
+class MaterializedViewMetadataTestComplex(BasicSegregatedKeyspaceUnitTestCase):
+    def test_create_view_metadata(self):
+        """
+        test to ensure that materialized view metadata is properly constructed
+
+        test_create_view_metadata tests that materialized view metadata is properly constructed. It runs a simple
+        query to construct a materialized view, then proceeds to inspect the metadata associated with that MV.
+ Columns are inspected to insure that all are of the proper type, and in the proper type. + + @since 3.0.0 + @jira_ticket PYTHON-371 + @expected_result Materialized view metadata should be constructed appropriately. + + @test_category metadata + """ + create_table = """CREATE TABLE {0}.scores( + user TEXT, + game TEXT, + year INT, + month INT, + day INT, + score INT, + PRIMARY KEY (user, game, year, month, day) + )""".format(self.keyspace_name) + + self.session.execute(create_table) + + create_mv = """CREATE MATERIALIZED VIEW {0}.monthlyhigh AS + SELECT game, year, month, score, user, day FROM {0}.scores + WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND day IS NOT NULL + PRIMARY KEY ((game, year, month), score, user, day) + WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format(self.keyspace_name) + + self.session.execute(create_mv) + score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'] + mv = self.cluster.metadata.keyspaces[self.keyspace_name].views['monthlyhigh'] + + self.assertIsNotNone(score_table.views["monthlyhigh"]) + self.assertIsNotNone(len(score_table.views), 1) + + # Make sure user is a partition key, and not null + self.assertEqual(len(score_table.partition_key), 1) + self.assertIsNotNone(score_table.columns['user']) + self.assertTrue(score_table.columns['user'], score_table.partition_key[0]) + + # Validate clustering keys + self.assertEqual(len(score_table.clustering_key), 4) + + self.assertIsNotNone(score_table.columns['game']) + self.assertTrue(score_table.columns['game'], score_table.clustering_key[0]) + + self.assertIsNotNone(score_table.columns['year']) + self.assertTrue(score_table.columns['year'], score_table.clustering_key[1]) + + self.assertIsNotNone(score_table.columns['month']) + self.assertTrue(score_table.columns['month'], score_table.clustering_key[2]) + + self.assertIsNotNone(score_table.columns['day']) + self.assertTrue(score_table.columns['day'], score_table.clustering_key[3]) + + self.assertIsNotNone(score_table.columns['score']) + + # Validate basic mv information + self.assertEqual(mv.keyspace_name, self.keyspace_name) + self.assertEqual(mv.name, "monthlyhigh") + self.assertEqual(mv.base_table_name, "scores") + self.assertFalse(mv.include_all_columns) + + # Validate that all columns are preset and correct + mv_columns = list(mv.columns.values()) + self.assertEqual(len(mv_columns), 6) + + game_column = mv_columns[0] + self.assertIsNotNone(game_column) + self.assertEqual(game_column.name, 'game') + self.assertEqual(game_column, mv.partition_key[0]) + + year_column = mv_columns[1] + self.assertIsNotNone(year_column) + self.assertEqual(year_column.name, 'year') + self.assertEqual(year_column, mv.partition_key[1]) + + month_column = mv_columns[2] + self.assertIsNotNone(month_column) + self.assertEqual(month_column.name, 'month') + self.assertEqual(month_column, mv.partition_key[2]) + + def compare_columns(a, b, name): + self.assertEqual(a.name, name) + self.assertEqual(a.name, b.name) + self.assertEqual(a.table, b.table) + self.assertEqual(a.cql_type, b.cql_type) + self.assertEqual(a.is_static, b.is_static) + self.assertEqual(a.is_reversed, b.is_reversed) + + score_column = mv_columns[3] + compare_columns(score_column, mv.clustering_key[0], 'score') + + user_column = mv_columns[4] + compare_columns(user_column, mv.clustering_key[1], 'user') + + day_column = mv_columns[5] + compare_columns(day_column, mv.clustering_key[2], 'day') + + def 
test_base_table_column_addition_mv(self): + """ + test to ensure that materialized view metadata is properly updated with base columns are added + + test_create_view_metadata tests that materialized views metadata is properly updated when columns are added to + the base table. + + @since 3.0.0 + @jira_ticket PYTHON-419 + @expected_result Materialized view metadata should be updated correctly + + @test_category metadata + """ + create_table = """CREATE TABLE {0}.scores( + user TEXT, + game TEXT, + year INT, + month INT, + day INT, + score TEXT, + PRIMARY KEY (user, game, year, month, day) + )""".format(self.keyspace_name) + + self.session.execute(create_table) + + create_mv = """CREATE MATERIALIZED VIEW {0}.monthlyhigh AS + SELECT game, year, month, score, user, day FROM {0}.scores + WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND day IS NOT NULL + PRIMARY KEY ((game, year, month), score, user, day) + WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format(self.keyspace_name) + + create_mv_alltime = """CREATE MATERIALIZED VIEW {0}.alltimehigh AS + SELECT * FROM {0}.scores + WHERE game IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND day IS NOT NULL + PRIMARY KEY (game, score, user, year, month, day) + WITH CLUSTERING ORDER BY (score DESC)""".format(self.keyspace_name) + + self.session.execute(create_mv) + + self.session.execute(create_mv_alltime) + + score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'] + + self.assertIsNotNone(score_table.views["monthlyhigh"]) + self.assertIsNotNone(score_table.views["alltimehigh"]) + self.assertEqual(len(self.cluster.metadata.keyspaces[self.keyspace_name].views), 2) + + insert_fouls = """ALTER TABLE {0}.scores ADD fouls INT""".format((self.keyspace_name)) + + self.session.execute(insert_fouls) + self.assertEqual(len(self.cluster.metadata.keyspaces[self.keyspace_name].views), 2) + + score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'] + self.assertIn("fouls", score_table.columns) + + # This is a workaround for mv notifications being separate from base table schema responses. + # This maybe fixed with future protocol changes + for i in range(10): + mv_alltime = self.cluster.metadata.keyspaces[self.keyspace_name].views["alltimehigh"] + if("fouls" in mv_alltime.columns): + break + time.sleep(.2) + + self.assertIn("fouls", mv_alltime.columns) + + mv_alltime_fouls_comumn = self.cluster.metadata.keyspaces[self.keyspace_name].views["alltimehigh"].columns['fouls'] + self.assertEqual(mv_alltime_fouls_comumn.cql_type, 'int') + + @lessthancass30 + def test_base_table_type_alter_mv(self): + """ + test to ensure that materialized view metadata is properly updated when a type in the base table + is updated. + + test_create_view_metadata tests that materialized views metadata is properly updated when the type of base table + column is changed. 
+ + @since 3.0.0 + @jira_ticket CASSANDRA-10424 + @expected_result Materialized view metadata should be updated correctly + + @test_category metadata + """ + create_table = """CREATE TABLE {0}.scores( + user TEXT, + game TEXT, + year INT, + month INT, + day INT, + score TEXT, + PRIMARY KEY (user, game, year, month, day) + )""".format(self.keyspace_name) + + self.session.execute(create_table) + + create_mv = """CREATE MATERIALIZED VIEW {0}.monthlyhigh AS + SELECT game, year, month, score, user, day FROM {0}.scores + WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND day IS NOT NULL + PRIMARY KEY ((game, year, month), score, user, day) + WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format(self.keyspace_name) + + self.session.execute(create_mv) + self.assertEqual(len(self.cluster.metadata.keyspaces[self.keyspace_name].views), 1) + alter_scores = """ALTER TABLE {0}.scores ALTER score TYPE blob""".format((self.keyspace_name)) + self.session.execute(alter_scores) + self.assertEqual(len(self.cluster.metadata.keyspaces[self.keyspace_name].views), 1) + + score_column = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'].columns['score'] + self.assertEqual(score_column.cql_type, 'blob') + + # until CASSANDRA-9920+CASSANDRA-10500 MV updates are only available later with an async event + for i in range(10): + score_mv_column = self.cluster.metadata.keyspaces[self.keyspace_name].views["monthlyhigh"].columns['score'] + if "blob" == score_mv_column.cql_type: + break + time.sleep(.2) + + self.assertEqual(score_mv_column.cql_type, 'blob') + + def test_metadata_with_quoted_identifiers(self): + """ + test to ensure that materialized view metadata is properly constructed when quoted identifiers are used + + test_metadata_with_quoted_identifiers tests that materialized views metadata is properly constructed. + It runs a simple query to construct a materialized view, then proceeds to inspect the metadata associated with + that MV. The caveat here is that the tables and the materialized view both have quoted identifiers + Columns are inspected to insure that all are of the proper type, and in the proper type. + + @since 3.0.0 + @jira_ticket PYTHON-371 + @expected_result Materialized view metadata should be constructed appropriately even with quoted identifiers. 
+ + @test_category metadata + """ + create_table = """CREATE TABLE {0}.t1 ( + "theKey" int, + "the;Clustering" int, + "the Value" int, + PRIMARY KEY ("theKey", "the;Clustering"))""".format(self.keyspace_name) + + self.session.execute(create_table) + + create_mv = """CREATE MATERIALIZED VIEW {0}.mv1 AS + SELECT "theKey", "the;Clustering", "the Value" + FROM {0}.t1 + WHERE "theKey" IS NOT NULL AND "the;Clustering" IS NOT NULL AND "the Value" IS NOT NULL + PRIMARY KEY ("theKey", "the;Clustering")""".format(self.keyspace_name) + + self.session.execute(create_mv) + + t1_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['t1'] + mv = self.cluster.metadata.keyspaces[self.keyspace_name].views['mv1'] + + self.assertIsNotNone(t1_table.views["mv1"]) + self.assertIsNotNone(len(t1_table.views), 1) + + # Validate partition key, and not null + self.assertEqual(len(t1_table.partition_key), 1) + self.assertIsNotNone(t1_table.columns['theKey']) + self.assertTrue(t1_table.columns['theKey'], t1_table.partition_key[0]) + + # Validate clustering key column + self.assertEqual(len(t1_table.clustering_key), 1) + self.assertIsNotNone(t1_table.columns['the;Clustering']) + self.assertTrue(t1_table.columns['the;Clustering'], t1_table.clustering_key[0]) + + # Validate regular column + self.assertIsNotNone(t1_table.columns['the Value']) + + # Validate basic mv information + self.assertEqual(mv.keyspace_name, self.keyspace_name) + self.assertEqual(mv.name, "mv1") + self.assertEqual(mv.base_table_name, "t1") + self.assertFalse(mv.include_all_columns) + + # Validate that all columns are preset and correct + mv_columns = list(mv.columns.values()) + self.assertEqual(len(mv_columns), 3) + + theKey_column = mv_columns[0] + self.assertIsNotNone(theKey_column) + self.assertEqual(theKey_column.name, 'theKey') + self.assertEqual(theKey_column, mv.partition_key[0]) + + cluster_column = mv_columns[1] + self.assertIsNotNone(cluster_column) + self.assertEqual(cluster_column.name, 'the;Clustering') + self.assertEqual(cluster_column.name, mv.clustering_key[0].name) + self.assertEqual(cluster_column.table, mv.clustering_key[0].table) + self.assertEqual(cluster_column.is_static, mv.clustering_key[0].is_static) + self.assertEqual(cluster_column.is_reversed, mv.clustering_key[0].is_reversed) + + value_column = mv_columns[2] + self.assertIsNotNone(value_column) + self.assertEqual(value_column.name, 'the Value') + + +class GroupPerHost(BasicSharedKeyspaceUnitTestCase): + @classmethod + def setUpClass(cls): + cls.common_setup(rf=1, create_class_table=True) + cls.table_two_pk = "table_with_two_pk" + cls.session.execute( + ''' + CREATE TABLE {0}.{1} ( + k_one int, + k_two int, + v int, + PRIMARY KEY ((k_one, k_two)) + )'''.format(cls.ks_name, cls.table_two_pk) + ) + + def test_group_keys_by_host(self): + """ + Test to ensure group_keys_by_host functions as expected. It is tried + with a table with a single field for the partition key and a table + with two fields for the partition key + @since 3.13 + @jira_ticket PYTHON-647 + @expected_result group_keys_by_host return the expected value + + @test_category metadata + """ + stmt = """SELECT * FROM {}.{} + WHERE k_one = ? AND k_two = ? """.format(self.ks_name, self.table_two_pk) + keys = ((1, 2), (2, 2), (2, 3), (3, 4)) + self._assert_group_keys_by_host(keys, self.table_two_pk, stmt) + + stmt = """SELECT * FROM {}.{} + WHERE k = ? 
""".format(self.ks_name, self.ks_name) + keys = ((1,), (2,), (2,), (3,)) + self._assert_group_keys_by_host(keys, self.ks_name, stmt) + + def _assert_group_keys_by_host(self, keys, table_name, stmt): + keys_per_host = group_keys_by_replica(self.session, self.ks_name, table_name, keys) + self.assertNotIn(NO_VALID_REPLICA, keys_per_host) + + prepared_stmt = self.session.prepare(stmt) + for key in keys: + routing_key = prepared_stmt.bind(key).routing_key + hosts = self.cluster.metadata.get_replicas(self.ks_name, routing_key) + self.assertEqual(1, len(hosts)) # RF is 1 for this keyspace + self.assertIn(key, keys_per_host[hosts[0]]) + + +class VirtualKeypaceTest(BasicSharedKeyspaceUnitTestCase): + virtual_ks_names = ('system_virtual_schema', 'system_views') + + virtual_ks_structure = { + 'system_views': { + # map from table names to sets of column names for unordered + # comparison + 'caches': {'capacity_bytes', 'entry_count', 'hit_count', + 'hit_ratio', 'name', 'recent_hit_rate_per_second', + 'recent_request_rate_per_second', 'request_count', + 'size_bytes'}, + 'clients': {'address', 'connection_stage', 'driver_name', + 'driver_version', 'hostname', 'port', + 'protocol_version', 'request_count', + 'ssl_cipher_suite', 'ssl_enabled', 'ssl_protocol', + 'username'}, + 'sstable_tasks': {'keyspace_name', 'kind', 'progress', + 'table_name', 'task_id', 'total', 'unit'}, + 'thread_pools': {'active_tasks', 'active_tasks_limit', + 'blocked_tasks', 'blocked_tasks_all_time', + 'completed_tasks', 'name', 'pending_tasks'} + }, + 'system_virtual_schema': { + 'columns': {'clustering_order', 'column_name', + 'column_name_bytes', 'keyspace_name', 'kind', + 'position', 'table_name', 'type'}, + 'keyspaces': {'keyspace_name'}, + 'tables': {'comment', 'keyspace_name', 'table_name'} + } + } + + def test_existing_keyspaces_have_correct_virtual_tags(self): + for name, ks in self.cluster.metadata.keyspaces.items(): + if name in self.virtual_ks_names: + self.assertTrue( + ks.virtual, + 'incorrect .virtual value for {}'.format(name) + ) + else: + self.assertFalse( + ks.virtual, + 'incorrect .virtual value for {}'.format(name) + ) + + @greaterthanorequalcass40 + def test_expected_keyspaces_exist_and_are_virtual(self): + for name in self.virtual_ks_names: + self.assertTrue( + self.cluster.metadata.keyspaces[name].virtual, + 'incorrect .virtual value for {}'.format(name) + ) + + @greaterthanorequalcass40 + def test_virtual_keyspaces_have_expected_schema_structure(self): + self.maxDiff = None + + ingested_virtual_ks_structure = defaultdict(dict) + for ks_name, ks in self.cluster.metadata.keyspaces.items(): + if not ks.virtual: + continue + for tab_name, tab in ks.tables.items(): + ingested_virtual_ks_structure[ks_name][tab_name] = set( + tab.columns.keys() + ) + + self.assertDictEqual(ingested_virtual_ks_structure, + self.virtual_ks_structure) diff --git a/tests/integration/standard/test_metrics.py b/tests/integration/standard/test_metrics.py new file mode 100644 index 0000000..d40a66f --- /dev/null +++ b/tests/integration/standard/test_metrics.py @@ -0,0 +1,386 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +from cassandra.connection import ConnectionShutdown +from cassandra.policies import HostFilterPolicy, RoundRobinPolicy, FallthroughRetryPolicy + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.query import SimpleStatement +from cassandra import ConsistencyLevel, WriteTimeout, Unavailable, ReadTimeout +from cassandra.protocol import SyntaxException + +from cassandra.cluster import Cluster, NoHostAvailable +from tests.integration import get_cluster, get_node, use_singledc, PROTOCOL_VERSION, execute_until_pass +from greplin import scales +from tests.integration import BasicSharedKeyspaceUnitTestCaseRF3WM, BasicExistingKeyspaceUnitTestCase, local + +def setup_module(): + use_singledc() + +@local +class MetricsTests(unittest.TestCase): + + def setUp(self): + contact_point = ['127.0.0.2'] + self.cluster = Cluster(contact_points=contact_point, metrics_enabled=True, protocol_version=PROTOCOL_VERSION, + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address in contact_point + ), + default_retry_policy=FallthroughRetryPolicy()) + self.session = self.cluster.connect("test3rf", wait_for_all_pools=True) + + def tearDown(self): + self.cluster.shutdown() + + def test_connection_error(self): + """ + Trigger and ensure connection_errors are counted + Stop all node with the driver knowing about the "DOWN" states. + """ + # Test writes + for i in range(0, 100): + self.session.execute_async("INSERT INTO test (k, v) VALUES ({0}, {1})".format(i, i)) + + # Stop the cluster + get_cluster().stop(wait=True, gently=False) + + try: + # Ensure the nodes are actually down + query = SimpleStatement("SELECT * FROM test", consistency_level=ConsistencyLevel.ALL) + # both exceptions can happen depending on when the connection has been detected as defunct + with self.assertRaises((NoHostAvailable, ConnectionShutdown)): + self.session.execute(query) + finally: + get_cluster().start(wait_for_binary_proto=True, wait_other_notice=True) + # Give some time for the cluster to come back up, for the next test + time.sleep(5) + + self.assertGreater(self.cluster.metrics.stats.connection_errors, 0) + + def test_write_timeout(self): + """ + Trigger and ensure write_timeouts are counted + Write a key, value pair. Pause a node without the coordinator node knowing about the "DOWN" state. + Attempt a write at cl.ALL and receive a WriteTimeout. 
+ """ + + # Test write + self.session.execute("INSERT INTO test (k, v) VALUES (1, 1)") + + # Assert read + query = SimpleStatement("SELECT * FROM test WHERE k=1", consistency_level=ConsistencyLevel.ALL) + results = execute_until_pass(self.session, query) + self.assertTrue(results) + + # Pause node so it shows as unreachable to coordinator + get_node(1).pause() + + try: + # Test write + query = SimpleStatement("INSERT INTO test (k, v) VALUES (2, 2)", consistency_level=ConsistencyLevel.ALL) + with self.assertRaises(WriteTimeout): + self.session.execute(query, timeout=None) + self.assertEqual(1, self.cluster.metrics.stats.write_timeouts) + + finally: + get_node(1).resume() + + def test_read_timeout(self): + """ + Trigger and ensure read_timeouts are counted + Write a key, value pair. Pause a node without the coordinator node knowing about the "DOWN" state. + Attempt a read at cl.ALL and receive a ReadTimeout. + """ + + + # Test write + self.session.execute("INSERT INTO test (k, v) VALUES (1, 1)") + + # Assert read + query = SimpleStatement("SELECT * FROM test WHERE k=1", consistency_level=ConsistencyLevel.ALL) + results = execute_until_pass(self.session, query) + self.assertTrue(results) + + # Pause node so it shows as unreachable to coordinator + get_node(1).pause() + + try: + # Test read + query = SimpleStatement("SELECT * FROM test", consistency_level=ConsistencyLevel.ALL) + with self.assertRaises(ReadTimeout): + self.session.execute(query, timeout=None) + self.assertEqual(1, self.cluster.metrics.stats.read_timeouts) + + finally: + get_node(1).resume() + + def test_unavailable(self): + """ + Trigger and ensure unavailables are counted + Write a key, value pair. Stop a node with the coordinator node knowing about the "DOWN" state. + Attempt an insert/read at cl.ALL and receive a Unavailable Exception. 
+ """ + + # Test write + self.session.execute("INSERT INTO test (k, v) VALUES (1, 1)") + + # Assert read + query = SimpleStatement("SELECT * FROM test WHERE k=1", consistency_level=ConsistencyLevel.ALL) + results = execute_until_pass(self.session, query) + self.assertTrue(results) + + # Stop node gracefully + # Sometimes this commands continues with the other nodes having not noticed + # 1 is down, and a Timeout error is returned instead of an Unavailable + get_node(1).stop(wait=True, wait_other_notice=True) + time.sleep(5) + try: + # Test write + query = SimpleStatement("INSERT INTO test (k, v) VALUES (2, 2)", consistency_level=ConsistencyLevel.ALL) + with self.assertRaises(Unavailable): + self.session.execute(query) + self.assertEqual(self.cluster.metrics.stats.unavailables, 1) + + # Test write + query = SimpleStatement("SELECT * FROM test", consistency_level=ConsistencyLevel.ALL) + with self.assertRaises(Unavailable): + self.session.execute(query, timeout=None) + self.assertEqual(self.cluster.metrics.stats.unavailables, 2) + finally: + get_node(1).start(wait_other_notice=True, wait_for_binary_proto=True) + # Give some time for the cluster to come back up, for the next test + time.sleep(5) + + self.cluster.shutdown() + + # def test_other_error(self): + # # TODO: Bootstrapping or Overloaded cases + # pass + # + # def test_ignore(self): + # # TODO: Look for ways to generate ignores + # pass + # + # def test_retry(self): + # # TODO: Look for ways to generate retries + # pass + + +class MetricsNamespaceTest(BasicSharedKeyspaceUnitTestCaseRF3WM): + @local + def test_metrics_per_cluster(self): + """ + Test to validate that metrics can be scopped to invdividual clusters + @since 3.6.0 + @jira_ticket PYTHON-561 + @expected_result metrics should be scopped to a cluster level + + @test_category metrics + """ + + cluster2 = Cluster(metrics_enabled=True, protocol_version=PROTOCOL_VERSION, + default_retry_policy=FallthroughRetryPolicy()) + cluster2.connect(self.ks_name, wait_for_all_pools=True) + + self.assertEqual(len(cluster2.metadata.all_hosts()), 3) + + query = SimpleStatement("SELECT * FROM {0}.{0}".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) + self.session.execute(query) + + # Pause node so it shows as unreachable to coordinator + get_node(1).pause() + + try: + # Test write + query = SimpleStatement("INSERT INTO {0}.{0} (k, v) VALUES (2, 2)".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) + with self.assertRaises(WriteTimeout): + self.session.execute(query, timeout=None) + finally: + get_node(1).resume() + + # Change the scales stats_name of the cluster2 + cluster2.metrics.set_stats_name('cluster2-metrics') + + stats_cluster1 = self.cluster.metrics.get_stats() + stats_cluster2 = cluster2.metrics.get_stats() + + # Test direct access to stats + self.assertEqual(1, self.cluster.metrics.stats.write_timeouts) + self.assertEqual(0, cluster2.metrics.stats.write_timeouts) + + # Test direct access to a child stats + self.assertNotEqual(0.0, self.cluster.metrics.request_timer['mean']) + self.assertEqual(0.0, cluster2.metrics.request_timer['mean']) + + # Test access via metrics.get_stats() + self.assertNotEqual(0.0, stats_cluster1['request_timer']['mean']) + self.assertEqual(0.0, stats_cluster2['request_timer']['mean']) + + # Test access by stats_name + self.assertEqual(0.0, scales.getStats()['cluster2-metrics']['request_timer']['mean']) + + cluster2.shutdown() + + def test_duplicate_metrics_per_cluster(self): + """ + Test to validate that cluster metrics names can't 
overlap. + @since 3.6.0 + @jira_ticket PYTHON-561 + @expected_result metric names should not be allowed to be same. + + @test_category metrics + """ + cluster2 = Cluster(metrics_enabled=True, protocol_version=PROTOCOL_VERSION, + default_retry_policy=FallthroughRetryPolicy()) + + cluster3 = Cluster(metrics_enabled=True, protocol_version=PROTOCOL_VERSION, + default_retry_policy=FallthroughRetryPolicy()) + + # Ensure duplicate metric names are not allowed + cluster2.metrics.set_stats_name("appcluster") + cluster2.metrics.set_stats_name("appcluster") + with self.assertRaises(ValueError): + cluster3.metrics.set_stats_name("appcluster") + cluster3.metrics.set_stats_name("devops") + + session2 = cluster2.connect(self.ks_name, wait_for_all_pools=True) + session3 = cluster3.connect(self.ks_name, wait_for_all_pools=True) + + # Basic validation that naming metrics doesn't impact their segration or accuracy + for i in range(10): + query = SimpleStatement("SELECT * FROM {0}.{0}".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) + session2.execute(query) + + for i in range(5): + query = SimpleStatement("SELECT * FROM {0}.{0}".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) + session3.execute(query) + + self.assertEqual(cluster2.metrics.get_stats()['request_timer']['count'], 10) + self.assertEqual(cluster3.metrics.get_stats()['request_timer']['count'], 5) + + # Check scales to ensure they are appropriately named + self.assertTrue("appcluster" in scales._Stats.stats.keys()) + self.assertTrue("devops" in scales._Stats.stats.keys()) + + cluster2.shutdown() + cluster3.shutdown() + + +class RequestAnalyzer(object): + """ + Class used to track request and error counts for a Session. + Also computes statistics on encoded request size. + """ + + requests = scales.PmfStat('request size') + errors = scales.IntStat('errors') + successful = scales.IntStat("success") + # Throw exceptions when invoked. + throw_on_success = False + throw_on_fail = False + + def __init__(self, session, throw_on_success=False, throw_on_fail=False): + scales.init(self, '/request') + # each instance will be registered with a session, and receive a callback for each request generated + session.add_request_init_listener(self.on_request) + self.throw_on_fail = throw_on_fail + self.throw_on_success = throw_on_success + + def on_request(self, rf): + # This callback is invoked each time a request is created, on the thread creating the request. 
+ # We can use this to count events, or add callbacks + rf.add_callbacks(self.on_success, self.on_error, callback_args=(rf,), errback_args=(rf,)) + + def on_success(self, _, response_future): + # future callback on a successful request; just record the size + self.requests.addValue(response_future.request_encoded_size) + self.successful += 1 + if self.throw_on_success: + raise AttributeError + + def on_error(self, _, response_future): + # future callback for failed; record size and increment errors + self.requests.addValue(response_future.request_encoded_size) + self.errors += 1 + if self.throw_on_fail: + raise AttributeError + + def remove_ra(self, session): + session.remove_request_init_listener(self.on_request) + + def __str__(self): + # just extracting request count from the size stats (which are recorded on all requests) + request_sizes = dict(self.requests) + count = request_sizes.pop('count') + return "%d requests (%d errors)\nRequest size statistics:\n%s" % (count, self.errors, pp.pformat(request_sizes)) + + +class MetricsRequestSize(BasicExistingKeyspaceUnitTestCase): + + def wait_for_count(self, ra, expected_count, error=False): + for _ in range(10): + if not error: + if ra.successful is expected_count: + return True + else: + if ra.errors is expected_count: + return True + time.sleep(.01) + return False + + def test_metrics_per_cluster(self): + """ + Test to validate that requests listeners. + + This test creates a simple metrics based request listener to track request size, it then + check to ensure that on_success and on_error methods are invoked appropriately. + @since 3.7.0 + @jira_ticket PYTHON-284 + @expected_result in_error, and on_success should be invoked apropriately + + @test_category metrics + """ + + ra = RequestAnalyzer(self.session) + for _ in range(10): + self.session.execute("SELECT release_version FROM system.local") + + for _ in range(3): + try: + self.session.execute("nonesense") + except SyntaxException: + continue + + self.assertTrue(self.wait_for_count(ra, 10)) + self.assertTrue(self.wait_for_count(ra, 3, error=True)) + + ra.remove_ra(self.session) + + # Make sure a poorly coded RA doesn't cause issues + ra = RequestAnalyzer(self.session, throw_on_success=False, throw_on_fail=True) + self.session.execute("SELECT release_version FROM system.local") + + ra.remove_ra(self.session) + + RequestAnalyzer(self.session, throw_on_success=True) + try: + self.session.execute("nonesense") + except SyntaxException: + pass diff --git a/tests/integration/standard/test_policies.py b/tests/integration/standard/test_policies.py new file mode 100644 index 0000000..5a985e5 --- /dev/null +++ b/tests/integration/standard/test_policies.py @@ -0,0 +1,92 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.cluster import Cluster, ExecutionProfile, ResponseFuture +from cassandra.policies import HostFilterPolicy, RoundRobinPolicy, SimpleConvictionPolicy, \ + WhiteListRoundRobinPolicy +from cassandra.pool import Host +from cassandra.connection import DefaultEndPoint + +from tests.integration import PROTOCOL_VERSION, local, use_singledc + +from concurrent.futures import wait as wait_futures + +def setup_module(): + use_singledc() + + +class HostFilterPolicyTests(unittest.TestCase): + + def test_predicate_changes(self): + """ + Test to validate host filter reacts correctly when the predicate return + a different subset of the hosts + HostFilterPolicy + @since 3.8 + @jira_ticket PYTHON-961 + @expected_result the excluded hosts are ignored + + @test_category policy + """ + external_event = True + contact_point = DefaultEndPoint("127.0.0.1") + + single_host = {Host(contact_point, SimpleConvictionPolicy)} + all_hosts = {Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in (1, 2, 3)} + + predicate = lambda host: host.endpoint == contact_point if external_event else True + cluster = Cluster((contact_point,), load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(), + predicate=predicate), + protocol_version=PROTOCOL_VERSION, topology_event_refresh_window=0, + status_event_refresh_window=0) + session = cluster.connect(wait_for_all_pools=True) + + queried_hosts = set() + for _ in range(10): + response = session.execute("SELECT * from system.local") + queried_hosts.update(response.response_future.attempted_hosts) + + self.assertEqual(queried_hosts, single_host) + + external_event = False + futures = session.update_created_pools() + wait_futures(futures, timeout=cluster.connect_timeout) + + queried_hosts = set() + for _ in range(10): + response = session.execute("SELECT * from system.local") + queried_hosts.update(response.response_future.attempted_hosts) + self.assertEqual(queried_hosts, all_hosts) + + +class WhiteListRoundRobinPolicyTests(unittest.TestCase): + + @local + def test_only_connects_to_subset(self): + only_connect_hosts = {"127.0.0.1", "127.0.0.2"} + white_list = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(only_connect_hosts)) + cluster = Cluster(execution_profiles={"white_list": white_list}) + #cluster = Cluster(load_balancing_policy=WhiteListRoundRobinPolicy(only_connect_hosts)) + session = cluster.connect(wait_for_all_pools=True) + queried_hosts = set() + for _ in range(10): + response = session.execute('SELECT * from system.local', execution_profile="white_list") + queried_hosts.update(response.response_future.attempted_hosts) + queried_hosts = set(host.address for host in queried_hosts) + self.assertEqual(queried_hosts, only_connect_hosts) diff --git a/tests/integration/standard/test_prepared_statements.py b/tests/integration/standard/test_prepared_statements.py new file mode 100644 index 0000000..76073d7 --- /dev/null +++ b/tests/integration/standard/test_prepared_statements.py @@ -0,0 +1,587 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests.integration import use_singledc, PROTOCOL_VERSION + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa +from cassandra import InvalidRequest + +from cassandra import ConsistencyLevel, ProtocolVersion +from cassandra.cluster import Cluster +from cassandra.query import PreparedStatement, UNSET_VALUE, tuple_factory +from tests.integration import (get_server_versions, greaterthanorequalcass40, + set_default_beta_flag_true, + BasicSharedKeyspaceUnitTestCase) + +import logging + + +LOG = logging.getLogger(__name__) + +def setup_module(): + use_singledc() + + +class PreparedStatementTests(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.cass_version = get_server_versions() + + def setUp(self): + self.cluster = Cluster(metrics_enabled=True, protocol_version=PROTOCOL_VERSION, + allow_beta_protocol_version=True) + self.session = self.cluster.connect() + + def tearDown(self): + self.cluster.shutdown() + + def test_basic(self): + """ + Test basic PreparedStatement usage + """ + self.session.execute( + """ + DROP KEYSPACE IF EXISTS preparedtests + """ + ) + self.session.execute( + """ + CREATE KEYSPACE preparedtests + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + """) + + self.session.set_keyspace("preparedtests") + self.session.execute( + """ + CREATE TABLE cf0 ( + a text, + b text, + c text, + PRIMARY KEY (a, b) + ) + """) + + prepared = self.session.prepare( + """ + INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?) + """) + + self.assertIsInstance(prepared, PreparedStatement) + bound = prepared.bind(('a', 'b', 'c')) + + self.session.execute(bound) + + prepared = self.session.prepare( + """ + SELECT * FROM cf0 WHERE a=? + """) + self.assertIsInstance(prepared, PreparedStatement) + + bound = prepared.bind(('a')) + results = self.session.execute(bound) + self.assertEqual(results, [('a', 'b', 'c')]) + + # test with new dict binding + prepared = self.session.prepare( + """ + INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?) + """) + + self.assertIsInstance(prepared, PreparedStatement) + bound = prepared.bind({ + 'a': 'x', + 'b': 'y', + 'c': 'z' + }) + + self.session.execute(bound) + + prepared = self.session.prepare( + """ + SELECT * FROM cf0 WHERE a=? 
+ """) + + self.assertIsInstance(prepared, PreparedStatement) + + bound = prepared.bind({'a': 'x'}) + results = self.session.execute(bound) + self.assertEqual(results, [('x', 'y', 'z')]) + + def test_missing_primary_key(self): + """ + Ensure an InvalidRequest is thrown + when prepared statements are missing the primary key + """ + + self._run_missing_primary_key(self.session) + + def _run_missing_primary_key(self, session): + statement_to_prepare = """INSERT INTO test3rf.test (v) VALUES (?)""" + # logic needed work with changes in CASSANDRA-6237 + if self.cass_version[0] >= (3, 0, 0): + self.assertRaises(InvalidRequest, session.prepare, statement_to_prepare) + else: + prepared = session.prepare(statement_to_prepare) + self.assertIsInstance(prepared, PreparedStatement) + bound = prepared.bind((1,)) + self.assertRaises(InvalidRequest, session.execute, bound) + + def test_missing_primary_key_dicts(self): + """ + Ensure an InvalidRequest is thrown + when prepared statements are missing the primary key + with dict bindings + """ + self._run_missing_primary_key_dicts(self.session) + + def _run_missing_primary_key_dicts(self, session): + statement_to_prepare = """ INSERT INTO test3rf.test (v) VALUES (?)""" + # logic needed work with changes in CASSANDRA-6237 + if self.cass_version[0] >= (3, 0, 0): + self.assertRaises(InvalidRequest, session.prepare, statement_to_prepare) + else: + prepared = session.prepare(statement_to_prepare) + self.assertIsInstance(prepared, PreparedStatement) + bound = prepared.bind({'v': 1}) + self.assertRaises(InvalidRequest, session.execute, bound) + + def test_too_many_bind_values(self): + """ + Ensure a ValueError is thrown when attempting to bind too many variables + """ + self._run_too_many_bind_values(self.session) + + def _run_too_many_bind_values(self, session): + statement_to_prepare = """ INSERT INTO test3rf.test (v) VALUES (?)""" + # logic needed work with changes in CASSANDRA-6237 + if self.cass_version[0] >= (3, 0, 0): + self.assertRaises(InvalidRequest, session.prepare, statement_to_prepare) + else: + prepared = session.prepare(statement_to_prepare) + self.assertIsInstance(prepared, PreparedStatement) + self.assertRaises(ValueError, prepared.bind, (1, 2)) + + def test_imprecise_bind_values_dicts(self): + """ + Ensure an error is thrown when attempting to bind the wrong values + with dict bindings + """ + + prepared = self.session.prepare( + """ + INSERT INTO test3rf.test (k, v) VALUES (?, ?) + """) + + self.assertIsInstance(prepared, PreparedStatement) + + # too many values is ok - others are ignored + prepared.bind({'k': 1, 'v': 2, 'v2': 3}) + + # right number, but one does not belong + if PROTOCOL_VERSION < 4: + # pre v4, the driver bails with key error when 'v' is found missing + self.assertRaises(KeyError, prepared.bind, {'k': 1, 'v2': 3}) + else: + # post v4, the driver uses UNSET_VALUE for 'v' and 'v2' is ignored + prepared.bind({'k': 1, 'v2': 3}) + + # also catch too few variables with dicts + self.assertIsInstance(prepared, PreparedStatement) + if PROTOCOL_VERSION < 4: + self.assertRaises(KeyError, prepared.bind, {}) + else: + # post v4, the driver attempts to use UNSET_VALUE for unspecified keys + self.assertRaises(ValueError, prepared.bind, {}) + + def test_none_values(self): + """ + Ensure binding None is handled correctly + """ + + prepared = self.session.prepare( + """ + INSERT INTO test3rf.test (k, v) VALUES (?, ?) 
+ """) + + self.assertIsInstance(prepared, PreparedStatement) + bound = prepared.bind((1, None)) + self.session.execute(bound) + + prepared = self.session.prepare( + """ + SELECT * FROM test3rf.test WHERE k=? + """) + self.assertIsInstance(prepared, PreparedStatement) + + bound = prepared.bind((1,)) + results = self.session.execute(bound) + self.assertEqual(results[0].v, None) + + def test_unset_values(self): + """ + Test to validate that UNSET_VALUEs are bound, and have the expected effect + + Prepare a statement and insert all values. Then follow with execute excluding + parameters. Verify that the original values are unaffected. + + @since 2.6.0 + + @jira_ticket PYTHON-317 + @expected_result UNSET_VALUE is implicitly added to bind parameters, and properly encoded, leving unset values unaffected. + + @test_category prepared_statements:binding + """ + if PROTOCOL_VERSION < 4: + raise unittest.SkipTest("Binding UNSET values is not supported in protocol version < 4") + + # table with at least two values so one can be used as a marker + self.session.execute("CREATE TABLE IF NOT EXISTS test1rf.test_unset_values (k int PRIMARY KEY, v0 int, v1 int)") + insert = self.session.prepare("INSERT INTO test1rf.test_unset_values (k, v0, v1) VALUES (?, ?, ?)") + select = self.session.prepare("SELECT * FROM test1rf.test_unset_values WHERE k=?") + + bind_expected = [ + # initial condition + ((0, 0, 0), (0, 0, 0)), + # unset implicit + ((0, 1,), (0, 1, 0)), + ({'k': 0, 'v0': 2}, (0, 2, 0)), + ({'k': 0, 'v1': 1}, (0, 2, 1)), + # unset explicit + ((0, 3, UNSET_VALUE), (0, 3, 1)), + ((0, UNSET_VALUE, 2), (0, 3, 2)), + ({'k': 0, 'v0': 4, 'v1': UNSET_VALUE}, (0, 4, 2)), + ({'k': 0, 'v0': UNSET_VALUE, 'v1': 3}, (0, 4, 3)), + # nulls still work + ((0, None, None), (0, None, None)), + ] + + for params, expected in bind_expected: + self.session.execute(insert, params) + results = self.session.execute(select, (0,)) + self.assertEqual(results[0], expected) + + self.assertRaises(ValueError, self.session.execute, select, (UNSET_VALUE, 0, 0)) + + def test_no_meta(self): + + prepared = self.session.prepare( + """ + INSERT INTO test3rf.test (k, v) VALUES (0, 0) + """) + + self.assertIsInstance(prepared, PreparedStatement) + bound = prepared.bind(None) + bound.consistency_level = ConsistencyLevel.ALL + self.session.execute(bound) + + prepared = self.session.prepare( + """ + SELECT * FROM test3rf.test WHERE k=0 + """) + self.assertIsInstance(prepared, PreparedStatement) + + bound = prepared.bind(None) + bound.consistency_level = ConsistencyLevel.ALL + results = self.session.execute(bound) + self.assertEqual(results[0].v, 0) + + def test_none_values_dicts(self): + """ + Ensure binding None is handled correctly with dict bindings + """ + + # test with new dict binding + prepared = self.session.prepare( + """ + INSERT INTO test3rf.test (k, v) VALUES (?, ?) + """) + + self.assertIsInstance(prepared, PreparedStatement) + bound = prepared.bind({'k': 1, 'v': None}) + self.session.execute(bound) + + prepared = self.session.prepare( + """ + SELECT * FROM test3rf.test WHERE k=? + """) + self.assertIsInstance(prepared, PreparedStatement) + + bound = prepared.bind({'k': 1}) + results = self.session.execute(bound) + self.assertEqual(results[0].v, None) + + def test_async_binding(self): + """ + Ensure None binding over async queries + """ + + prepared = self.session.prepare( + """ + INSERT INTO test3rf.test (k, v) VALUES (?, ?) 
+ """) + + self.assertIsInstance(prepared, PreparedStatement) + future = self.session.execute_async(prepared, (873, None)) + future.result() + + prepared = self.session.prepare( + """ + SELECT * FROM test3rf.test WHERE k=? + """) + self.assertIsInstance(prepared, PreparedStatement) + + future = self.session.execute_async(prepared, (873,)) + results = future.result() + self.assertEqual(results[0].v, None) + + def test_async_binding_dicts(self): + """ + Ensure None binding over async queries with dict bindings + """ + prepared = self.session.prepare( + """ + INSERT INTO test3rf.test (k, v) VALUES (?, ?) + """) + + self.assertIsInstance(prepared, PreparedStatement) + future = self.session.execute_async(prepared, {'k': 873, 'v': None}) + future.result() + + prepared = self.session.prepare( + """ + SELECT * FROM test3rf.test WHERE k=? + """) + self.assertIsInstance(prepared, PreparedStatement) + + future = self.session.execute_async(prepared, {'k': 873}) + results = future.result() + self.assertEqual(results[0].v, None) + + def test_raise_error_on_prepared_statement_execution_dropped_table(self): + """ + test for error in executing prepared statement on a dropped table + + test_raise_error_on_execute_prepared_statement_dropped_table tests that an InvalidRequest is raised when a + prepared statement is executed after its corresponding table is dropped. This happens because if a prepared + statement is invalid, the driver attempts to automatically re-prepare it on a non-existing table. + + @expected_errors InvalidRequest If a prepared statement is executed on a dropped table + + @since 2.6.0 + @jira_ticket PYTHON-207 + @expected_result InvalidRequest error should be raised upon prepared statement execution. + + @test_category prepared_statements + """ + + self.session.execute("CREATE TABLE test3rf.error_test (k int PRIMARY KEY, v int)") + prepared = self.session.prepare("SELECT * FROM test3rf.error_test WHERE k=?") + self.session.execute("DROP TABLE test3rf.error_test") + + with self.assertRaises(InvalidRequest): + self.session.execute(prepared, [0]) + + +@greaterthanorequalcass40 +class PreparedStatementInvalidationTest(BasicSharedKeyspaceUnitTestCase): + + def setUp(self): + self.table_name = "{}.prepared_statement_invalidation_test".format(self.keyspace_name) + self.session.execute("CREATE TABLE {} (a int PRIMARY KEY, b int, d int);".format(self.table_name)) + self.session.execute("INSERT INTO {} (a, b, d) VALUES (1, 1, 1);".format(self.table_name)) + self.session.execute("INSERT INTO {} (a, b, d) VALUES (2, 2, 2);".format(self.table_name)) + self.session.execute("INSERT INTO {} (a, b, d) VALUES (3, 3, 3);".format(self.table_name)) + self.session.execute("INSERT INTO {} (a, b, d) VALUES (4, 4, 4);".format(self.table_name)) + + def tearDown(self): + self.session.execute("DROP TABLE {}".format(self.table_name)) + + def test_invalidated_result_metadata(self): + """ + Tests to make sure cached metadata is updated when an invalidated prepared statement is reprepared. + + @since 2.7.0 + @jira_ticket PYTHON-621 + + Prior to this fix, the request would blow up with a protocol error when the result was decoded expecting a different + number of columns. 
+ """ + wildcard_prepared = self.session.prepare("SELECT * FROM {}".format(self.table_name)) + original_result_metadata = wildcard_prepared.result_metadata + self.assertEqual(len(original_result_metadata), 3) + + r = self.session.execute(wildcard_prepared) + self.assertEqual(r[0], (1, 1, 1)) + + self.session.execute("ALTER TABLE {} DROP d".format(self.table_name)) + + # Get a bunch of requests in the pipeline with varying states of result_meta, reprepare, resolved + futures = set(self.session.execute_async(wildcard_prepared.bind(None)) for _ in range(200)) + for f in futures: + self.assertEqual(f.result()[0], (1, 1)) + + self.assertIsNot(wildcard_prepared.result_metadata, original_result_metadata) + + def test_prepared_id_is_update(self): + """ + Tests that checks the query id from the prepared statement + is updated properly if the table that the prepared statement is querying + is altered. + + @since 3.12 + @jira_ticket PYTHON-808 + + The query id from the prepared statment must have changed + """ + prepared_statement = self.session.prepare("SELECT * from {} WHERE a = ?".format(self.table_name)) + id_before = prepared_statement.result_metadata_id + self.assertEqual(len(prepared_statement.result_metadata), 3) + + self.session.execute("ALTER TABLE {} ADD c int".format(self.table_name)) + bound_statement = prepared_statement.bind((1, )) + self.session.execute(bound_statement, timeout=1) + + id_after = prepared_statement.result_metadata_id + + self.assertNotEqual(id_before, id_after) + self.assertEqual(len(prepared_statement.result_metadata), 4) + + def test_prepared_id_is_updated_across_pages(self): + """ + Test that checks that the query id from the prepared statement + is updated if the table hat the prepared statement is querying + is altered while fetching pages in a single query. + Then it checks that the updated rows have the expected result. 
+ + @since 3.12 + @jira_ticket PYTHON-808 + """ + prepared_statement = self.session.prepare("SELECT * from {}".format(self.table_name)) + id_before = prepared_statement.result_metadata_id + self.assertEqual(len(prepared_statement.result_metadata), 3) + + prepared_statement.fetch_size = 2 + result = self.session.execute(prepared_statement.bind((None))) + + self.assertTrue(result.has_more_pages) + + self.session.execute("ALTER TABLE {} ADD c int".format(self.table_name)) + + result_set = set(x for x in ((1, 1, 1), (2, 2, 2), (3, 3, None, 3), (4, 4, None, 4))) + expected_result_set = set(row for row in result) + + id_after = prepared_statement.result_metadata_id + + self.assertEqual(result_set, expected_result_set) + self.assertNotEqual(id_before, id_after) + self.assertEqual(len(prepared_statement.result_metadata), 4) + + def test_prepare_id_is_updated_across_session(self): + """ + Test that checks that the query id from the prepared statement + is updated if the table that the prepared statement is querying + is altered by a different session + + @since 3.12 + @jira_ticket PYTHON-808 + """ + one_cluster = Cluster(metrics_enabled=True, protocol_version=PROTOCOL_VERSION) + one_session = one_cluster.connect() + self.addCleanup(one_cluster.shutdown) + + stm = "SELECT * from {} WHERE a = ?".format(self.table_name) + one_prepared_stm = one_session.prepare(stm) + self.assertEqual(len(one_prepared_stm.result_metadata), 3) + + one_id_before = one_prepared_stm.result_metadata_id + + self.session.execute("ALTER TABLE {} ADD c int".format(self.table_name)) + one_session.execute(one_prepared_stm, (1, )) + + one_id_after = one_prepared_stm.result_metadata_id + self.assertNotEqual(one_id_before, one_id_after) + self.assertEqual(len(one_prepared_stm.result_metadata), 4) + + def test_not_reprepare_invalid_statements(self): + """ + Test that checks that an InvalidRequest is raised if a column + expected by the prepared statement is dropped. + + @since 3.12 + @jira_ticket PYTHON-808 + """ + prepared_statement = self.session.prepare( + "SELECT a, b, d FROM {} WHERE a = ?".format(self.table_name)) + self.session.execute("ALTER TABLE {} DROP d".format(self.table_name)) + with self.assertRaises(InvalidRequest): + self.session.execute(prepared_statement.bind((1, ))) + + def test_id_is_not_updated_conditional_v4(self): + """ + Test that verifies that the result_metadata and the + result_metadata_id are updated correctly in conditional statements + in protocol V4 + + @since 3.13 + @jira_ticket PYTHON-847 + """ + cluster = Cluster(protocol_version=ProtocolVersion.V4) + session = cluster.connect() + self.addCleanup(cluster.shutdown) + self._test_updated_conditional(session, 9) + + def test_id_is_not_updated_conditional_v5(self): + """ + Test that verifies that the result_metadata and the + result_metadata_id are updated correctly in conditional statements + in protocol V5 + + @since 3.13 + @jira_ticket PYTHON-847 + """ + cluster = Cluster(protocol_version=ProtocolVersion.V5) + session = cluster.connect() + self.addCleanup(cluster.shutdown) + self._test_updated_conditional(session, 10) + + def _test_updated_conditional(self, session, value): + prepared_statement = session.prepare( + "INSERT INTO {}(a, b, d) VALUES " + "(?, ? , ?)
IF NOT EXISTS".format(self.table_name)) + first_id = prepared_statement.result_metadata_id + LOG.debug('initial result_metadata_id: {}'.format(first_id)) + + def check_result_and_metadata(expected): + self.assertEqual( + session.execute(prepared_statement, (value, value, value))[0], + expected + ) + self.assertEqual(prepared_statement.result_metadata_id, first_id) + self.assertEqual(prepared_statement.result_metadata, []) + + # Successful conditional update + check_result_and_metadata((True,)) + + # Failed conditional update + check_result_and_metadata((False, value, value, value)) + + session.execute("ALTER TABLE {} ADD c int".format(self.table_name)) + + # Failed conditional update + check_result_and_metadata((False, value, value, None, value)) diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py new file mode 100644 index 0000000..97a1e68 --- /dev/null +++ b/tests/integration/standard/test_query.py @@ -0,0 +1,1653 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from cassandra.concurrent import execute_concurrent +from cassandra import DriverException + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa +import logging +from cassandra import ProtocolVersion +from cassandra import ConsistencyLevel, Unavailable, InvalidRequest, cluster +from cassandra.query import (PreparedStatement, BoundStatement, SimpleStatement, + BatchStatement, BatchType, dict_factory, TraceUnavailable) +from cassandra.cluster import Cluster, NoHostAvailable, ExecutionProfile, EXEC_PROFILE_DEFAULT +from cassandra.policies import HostDistance, RoundRobinPolicy, WhiteListRoundRobinPolicy +from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, \ + greaterthanprotocolv3, MockLoggingHandler, get_supported_protocol_versions, local, get_cluster, setup_keyspace, \ + USE_CASS_EXTERNAL, greaterthanorequalcass40 +from tests import notwindows +from tests.integration import greaterthanorequalcass30, get_node + +import time +import random +import re + +import mock +import six + + +log = logging.getLogger(__name__) + + +def setup_module(): + if not USE_CASS_EXTERNAL: + use_singledc(start=False) + ccm_cluster = get_cluster() + ccm_cluster.stop() + # This is necessary because test_too_many_statements may + # timeout otherwise + config_options = {'write_request_timeout_in_ms': '20000'} + ccm_cluster.set_configuration_options(config_options) + ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True) + + setup_keyspace() + + +class QueryTests(BasicSharedKeyspaceUnitTestCase): + + def test_query(self): + + prepared = self.session.prepare( + """ + INSERT INTO test3rf.test (k, v) VALUES (?, ?) 
+ """.format(self.keyspace_name)) + + self.assertIsInstance(prepared, PreparedStatement) + bound = prepared.bind((1, None)) + self.assertIsInstance(bound, BoundStatement) + self.assertEqual(2, len(bound.values)) + self.session.execute(bound) + self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01') + + def test_trace_prints_okay(self): + """ + Code coverage to ensure trace prints to string without error + """ + + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + rs = self.session.execute(statement, trace=True) + + # Ensure this does not throw an exception + trace = rs.get_query_trace() + self.assertTrue(trace.events) + str(trace) + for event in trace.events: + str(event) + + def test_row_error_message(self): + """ + Test to validate, new column deserialization message + @since 3.7.0 + @jira_ticket PYTHON-361 + @expected_result Special failed decoding message should be present + + @test_category tracing + """ + self.session.execute("CREATE TABLE {0}.{1} (k int PRIMARY KEY, v timestamp)".format(self.keyspace_name,self.function_table_name)) + ss = SimpleStatement("INSERT INTO {0}.{1} (k, v) VALUES (1, 1000000000000000)".format(self.keyspace_name, self.function_table_name)) + self.session.execute(ss) + with self.assertRaises(DriverException) as context: + self.session.execute("SELECT * FROM {0}.{1}".format(self.keyspace_name, self.function_table_name)) + self.assertIn("Failed decoding result column", str(context.exception)) + + def test_trace_id_to_resultset(self): + + future = self.session.execute_async("SELECT * FROM system.local", trace=True) + + # future should have the current trace + rs = future.result() + future_trace = future.get_query_trace() + self.assertIsNotNone(future_trace) + + rs_trace = rs.get_query_trace() + self.assertEqual(rs_trace, future_trace) + self.assertTrue(rs_trace.events) + self.assertEqual(len(rs_trace.events), len(future_trace.events)) + + self.assertListEqual([rs_trace], rs.get_all_query_traces()) + + def test_trace_ignores_row_factory(self): + self.session.row_factory = dict_factory + + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + rs = self.session.execute(statement, trace=True) + + # Ensure this does not throw an exception + trace = rs.get_query_trace() + self.assertTrue(trace.events) + str(trace) + for event in trace.events: + str(event) + + @local + @greaterthanprotocolv3 + def test_client_ip_in_trace(self): + """ + Test to validate that client trace contains client ip information. + + creates a simple query and ensures that the client trace information is present. This will + only be the case if the c* version is 2.2 or greater + + @since 2.6.0 + @jira_ticket PYTHON-435 + @expected_result client address should be present in C* >= 2.2, otherwise should be none. + + @test_category tracing + #The current version on the trunk doesn't have the version set to 2.2 yet. + #For now we will use the protocol version. Once they update the version on C* trunk + #we can use the C*. See below + #self._cass_version, self._cql_version = get_server_versions() + #if self._cass_version < (2, 2): + # raise unittest.SkipTest("Client IP was not present in trace until C* 2.2") + """ + + # Make simple query with trace enabled + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + response_future = self.session.execute_async(statement, trace=True) + response_future.result() + + # Fetch the client_ip from the trace. 
+ trace = response_future.get_query_trace(max_wait=10.0) + client_ip = trace.client + + # Ip address should be in the local_host range + pat = re.compile("127.0.0.\d{1,3}") + + # Ensure that ip is set + self.assertIsNotNone(client_ip, "Client IP was not set in trace with C* >= 2.2") + self.assertTrue(pat.match(client_ip), "Client IP from trace did not match the expected value") + + def test_trace_cl(self): + """ + Test to ensure that CL is set correctly honored when executing trace queries. + + @since 3.3 + @jira_ticket PYTHON-435 + @expected_result Consistency Levels set on get_query_trace should be honored + """ + # Execute a query + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + response_future = self.session.execute_async(statement, trace=True) + response_future.result() + with self.assertRaises(Unavailable): + response_future.get_query_trace(query_cl=ConsistencyLevel.THREE) + # Try again with a smattering of other CL's + self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.TWO).trace_id) + response_future = self.session.execute_async(statement, trace=True) + response_future.result() + self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.ONE).trace_id) + response_future = self.session.execute_async(statement, trace=True) + response_future.result() + with self.assertRaises(InvalidRequest): + self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.ANY).trace_id) + self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.QUORUM).trace_id) + + @notwindows + def test_incomplete_query_trace(self): + """ + Tests to ensure that partial tracing works. + + Creates a table and runs an insert. Then attempt a query with tracing enabled. After the query is run we delete the + duration information associated with the trace, and attempt to populate the tracing information. + Should fail with wait_for_complete=True, succeed for False. 
+ + @since 3.0.0 + @jira_ticket PYTHON-438 + @expected_result tracing comes back sans duration + + @test_category tracing + """ + + # Create table and run insert, then select + self.session.execute("CREATE TABLE {0} (k INT, i INT, PRIMARY KEY(k, i))".format(self.keyspace_table_name)) + self.session.execute("INSERT INTO {0} (k, i) VALUES (0, 1)".format(self.keyspace_table_name)) + + response_future = self.session.execute_async("SELECT i FROM {0} WHERE k=0".format(self.keyspace_table_name), trace=True) + response_future.result() + + self.assertEqual(len(response_future._query_traces), 1) + trace = response_future._query_traces[0] + self.assertTrue(self._wait_for_trace_to_populate(trace.trace_id)) + + # Delete trace duration from the session (this is what the driver polls for "complete") + delete_statement = SimpleStatement("DELETE duration FROM system_traces.sessions WHERE session_id = {0}".format(trace.trace_id), consistency_level=ConsistencyLevel.ALL) + self.session.execute(delete_statement) + self.assertTrue(self._wait_for_trace_to_delete(trace.trace_id)) + + # should raise because duration is not set + self.assertRaises(TraceUnavailable, trace.populate, max_wait=0.2, wait_for_complete=True) + self.assertFalse(trace.events) + + # should get the events with wait False + trace.populate(wait_for_complete=False) + self.assertIsNone(trace.duration) + self.assertIsNotNone(trace.trace_id) + self.assertIsNotNone(trace.request_type) + self.assertIsNotNone(trace.parameters) + self.assertTrue(trace.events) # non-zero list len + self.assertIsNotNone(trace.started_at) + + def _wait_for_trace_to_populate(self, trace_id): + count = 0 + retry_max = 10 + while(not self._is_trace_present(trace_id) and count < retry_max): + time.sleep(.2) + count += 1 + return count != retry_max + + def _wait_for_trace_to_delete(self, trace_id): + count = 0 + retry_max = 10 + while(self._is_trace_present(trace_id) and count < retry_max): + time.sleep(.2) + count += 1 + return count != retry_max + + def _is_trace_present(self, trace_id): + select_statement = SimpleStatement("SElECT duration FROM system_traces.sessions WHERE session_id = {0}".format(trace_id), consistency_level=ConsistencyLevel.ALL) + ssrs = self.session.execute(select_statement) + if not len(ssrs.current_rows) or ssrs[0].duration is None: + return False + return True + + def test_query_by_id(self): + """ + Test to ensure column_types are set as part of the result set + + @since 3.8 + @jira_ticket PYTHON-648 + @expected_result column_names should be preset. 
+ + @test_category queries basic + """ + create_table = "CREATE TABLE {0}.{1} (id int primary key, m map<int, text>)".format(self.keyspace_name, self.function_table_name) + self.session.execute(create_table) + + self.session.execute("insert into "+self.keyspace_name+"."+self.function_table_name+" (id, m) VALUES ( 1, {1: 'one', 2: 'two', 3:'three'})") + results1 = self.session.execute("select id, m from {0}.{1}".format(self.keyspace_name, self.function_table_name)) + + self.assertIsNotNone(results1.column_types) + self.assertEqual(results1.column_types[0].typename, 'int') + self.assertEqual(results1.column_types[1].typename, 'map') + self.assertEqual(results1.column_types[0].cassname, 'Int32Type') + self.assertEqual(results1.column_types[1].cassname, 'MapType') + self.assertEqual(len(results1.column_types[0].subtypes), 0) + self.assertEqual(len(results1.column_types[1].subtypes), 2) + self.assertEqual(results1.column_types[1].subtypes[0].typename, "int") + self.assertEqual(results1.column_types[1].subtypes[1].typename, "varchar") + self.assertEqual(results1.column_types[1].subtypes[0].cassname, "Int32Type") + self.assertEqual(results1.column_types[1].subtypes[1].cassname, "VarcharType") + + def test_column_names(self): + """ + Test to validate the columns are present on the result set. + Performs a simple query against a table, then checks to ensure the column names are present and correct. + + @since 3.0.0 + @jira_ticket PYTHON-439 + @expected_result column_names should be present. + + @test_category queries basic + """ + create_table = """CREATE TABLE {0}.{1}( + user TEXT, + game TEXT, + year INT, + month INT, + day INT, + score INT, + PRIMARY KEY (user, game, year, month, day) + )""".format(self.keyspace_name, self.function_table_name) + + + self.session.execute(create_table) + result_set = self.session.execute("SELECT * FROM {0}.{1}".format(self.keyspace_name, self.function_table_name)) + self.assertIsNotNone(result_set.column_types) + + self.assertEqual(result_set.column_names, [u'user', u'game', u'year', u'month', u'day', u'score']) + + @greaterthanorequalcass30 + def test_basic_json_query(self): + insert_query = SimpleStatement("INSERT INTO test3rf.test(k, v) values (1, 1)", consistency_level = ConsistencyLevel.QUORUM) + json_query = SimpleStatement("SELECT JSON * FROM test3rf.test where k=1", consistency_level = ConsistencyLevel.QUORUM) + + self.session.execute(insert_query) + results = self.session.execute(json_query) + self.assertEqual(results.column_names, ["[json]"]) + self.assertEqual(results[0][0], '{"k": 1, "v": 1}') + + def test_host_targeting_query(self): + """ + Test to validate that single host targeting works.
+ + @since 3.17.0 + @jira_ticket PYTHON-933 + @expected_result the coordinator host is always the one set + """ + + default_ep = self.cluster.profile_manager.default + # copy of default EP with checkable LBP + checkable_ep = self.session.execution_profile_clone_update( + ep=default_ep, + load_balancing_policy=mock.Mock(wraps=default_ep.load_balancing_policy) + ) + query = SimpleStatement("INSERT INTO test3rf.test(k, v) values (1, 1)") + + for i in range(10): + host = random.choice(self.cluster.metadata.all_hosts()) + log.debug('targeting {}'.format(host)) + future = self.session.execute_async(query, host=host, execution_profile=checkable_ep) + future.result() + # check we're using the selected host + self.assertEqual(host, future.coordinator_host) + # check that this bypasses the LBP + self.assertFalse(checkable_ep.load_balancing_policy.make_query_plan.called) + + +class PreparedStatementTests(unittest.TestCase): + + def setUp(self): + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.session = self.cluster.connect() + + def tearDown(self): + self.cluster.shutdown() + + def test_routing_key(self): + """ + Simple code coverage to ensure routing_keys can be accessed + """ + prepared = self.session.prepare( + """ + INSERT INTO test3rf.test (k, v) VALUES (?, ?) + """) + + self.assertIsInstance(prepared, PreparedStatement) + bound = prepared.bind((1, None)) + self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01') + + def test_empty_routing_key_indexes(self): + """ + Ensure when routing_key_indexes are blank, + the routing key should be None + """ + prepared = self.session.prepare( + """ + INSERT INTO test3rf.test (k, v) VALUES (?, ?) + """) + prepared.routing_key_indexes = None + + self.assertIsInstance(prepared, PreparedStatement) + bound = prepared.bind((1, None)) + self.assertEqual(bound.routing_key, None) + + def test_predefined_routing_key(self): + """ + Basic test that ensures _set_routing_key() + overrides the current routing key + """ + prepared = self.session.prepare( + """ + INSERT INTO test3rf.test (k, v) VALUES (?, ?) + """) + + self.assertIsInstance(prepared, PreparedStatement) + bound = prepared.bind((1, None)) + bound._set_routing_key('fake_key') + self.assertEqual(bound.routing_key, 'fake_key') + + def test_multiple_routing_key_indexes(self): + """ + Basic test that uses a fake routing_key_index + """ + prepared = self.session.prepare( + """ + INSERT INTO test3rf.test (k, v) VALUES (?, ?) + """) + self.assertIsInstance(prepared, PreparedStatement) + + prepared.routing_key_indexes = [0, 1] + bound = prepared.bind((1, 2)) + self.assertEqual(bound.routing_key, b'\x00\x04\x00\x00\x00\x01\x00\x00\x04\x00\x00\x00\x02\x00') + + prepared.routing_key_indexes = [1, 0] + bound = prepared.bind((1, 2)) + self.assertEqual(bound.routing_key, b'\x00\x04\x00\x00\x00\x02\x00\x00\x04\x00\x00\x00\x01\x00') + + def test_bound_keyspace(self): + """ + Ensure that bound.keyspace works as expected + """ + prepared = self.session.prepare( + """ + INSERT INTO test3rf.test (k, v) VALUES (?, ?) 
+ """) + + self.assertIsInstance(prepared, PreparedStatement) + bound = prepared.bind((1, 2)) + self.assertEqual(bound.keyspace, 'test3rf') + + +class ForcedHostIndexPolicy(RoundRobinPolicy): + def __init__(self, host_index_to_use=0): + super(ForcedHostIndexPolicy, self).__init__() + self.host_index_to_use = host_index_to_use + + def set_host(self, host_index): + """ 0-based index of which host to use """ + self.host_index_to_use = host_index + + def make_query_plan(self, working_keyspace=None, query=None): + live_hosts = sorted(list(self._live_hosts)) + host = [] + try: + host = [live_hosts[self.host_index_to_use]] + except IndexError as e: + six.raise_from(IndexError( + 'You specified an index larger than the number of hosts. Total hosts: {}. Index specified: {}'.format( + len(live_hosts), self.host_index_to_use + )), e) + return host + + +class PreparedStatementMetdataTest(unittest.TestCase): + + def test_prepared_metadata_generation(self): + """ + Test to validate that result metadata is appropriately populated across protocol version + + In protocol version 1 result metadata is retrieved everytime the statement is issued. In all + other protocol versions it's set once upon the prepare, then re-used. This test ensures that it manifests + it's self the same across multiple protocol versions. + + @since 3.6.0 + @jira_ticket PYTHON-71 + @expected_result result metadata is consistent. + """ + + base_line = None + for proto_version in get_supported_protocol_versions(): + beta_flag = True if proto_version in ProtocolVersion.BETA_VERSIONS else False + cluster = Cluster(protocol_version=proto_version, allow_beta_protocol_version=beta_flag) + + session = cluster.connect() + select_statement = session.prepare("SELECT * FROM system.local") + if proto_version == 1: + self.assertEqual(select_statement.result_metadata, None) + else: + self.assertNotEqual(select_statement.result_metadata, None) + future = session.execute_async(select_statement) + results = future.result() + if base_line is None: + base_line = results[0]._asdict().keys() + else: + self.assertEqual(base_line, results[0]._asdict().keys()) + cluster.shutdown() + + +class PreparedStatementArgTest(unittest.TestCase): + + def setUp(self): + self.mock_handler = MockLoggingHandler() + logger = logging.getLogger(cluster.__name__) + logger.addHandler(self.mock_handler) + + def test_prepare_on_all_hosts(self): + """ + Test to validate prepare_on_all_hosts flag is honored. + + Force the host of each query to ensure prepared queries are cycled over nodes that should not + have them prepared. 
Check the logs to ensure they are being re-prepared on those nodes + + @since 3.4.0 + @jira_ticket PYTHON-556 + @expected_result queries will have to be re-prepared on hosts that aren't the control connection + """ + clus = Cluster(protocol_version=PROTOCOL_VERSION, prepare_on_all_hosts=False, reprepare_on_up=False) + self.addCleanup(clus.shutdown) + + session = clus.connect(wait_for_all_pools=True) + select_statement = session.prepare("SELECT k FROM test3rf.test WHERE k = ?") + for host in clus.metadata.all_hosts(): + session.execute(select_statement, (1, ), host=host) + self.assertEqual(2, self.mock_handler.get_message_count('debug', "Re-preparing")) + + def test_prepare_batch_statement(self): + """ + Test to validate a prepared statement used inside a batch statement is correctly handled + by the driver + + @since 3.10 + @jira_ticket PYTHON-706 + @expected_result queries will have to be re-prepared on hosts that aren't the control connection + and the batch statement will be sent. + """ + policy = ForcedHostIndexPolicy() + clus = Cluster( + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=policy), + }, + protocol_version=PROTOCOL_VERSION, + prepare_on_all_hosts=False, + reprepare_on_up=False, + ) + self.addCleanup(clus.shutdown) + + table = "test3rf.%s" % self._testMethodName.lower() + + session = clus.connect(wait_for_all_pools=True) + + session.execute("DROP TABLE IF EXISTS %s" % table) + session.execute("CREATE TABLE %s (k int PRIMARY KEY, v int )" % table) + + insert_statement = session.prepare("INSERT INTO %s (k, v) VALUES (?, ?)" % table) + + # This is going to query a host where the query + # is not prepared + policy.set_host(1) + batch_statement = BatchStatement(consistency_level=ConsistencyLevel.ONE) + batch_statement.add(insert_statement, (1, 2)) + session.execute(batch_statement) + + # To verify our test assumption that queries are getting re-prepared properly + self.assertEqual(1, self.mock_handler.get_message_count('debug', "Re-preparing")) + + select_results = session.execute(SimpleStatement("SELECT * FROM %s WHERE k = 1" % table, + consistency_level=ConsistencyLevel.ALL)) + first_row = select_results[0][:2] + self.assertEqual((1, 2), first_row) + + def test_prepare_batch_statement_after_alter(self): + """ + Test to validate a prepared statement used inside a batch statement is correctly handled + by the driver. The metadata might be updated when a table is altered. This test combines + queries not being prepared and an update of the prepared statement metadata + + @since 3.10 + @jira_ticket PYTHON-706 + @expected_result queries will have to be re-prepared on hosts that aren't the control connection + and the batch statement will be sent.
+ """ + clus = Cluster(protocol_version=PROTOCOL_VERSION, prepare_on_all_hosts=False, reprepare_on_up=False) + self.addCleanup(clus.shutdown) + + table = "test3rf.%s" % self._testMethodName.lower() + + session = clus.connect(wait_for_all_pools=True) + + session.execute("DROP TABLE IF EXISTS %s" % table) + session.execute("CREATE TABLE %s (k int PRIMARY KEY, a int, b int, d int)" % table) + insert_statement = session.prepare("INSERT INTO %s (k, b, d) VALUES (?, ?, ?)" % table) + + # Altering the table might trigger an update in the insert metadata + session.execute("ALTER TABLE %s ADD c int" % table) + + values_to_insert = [(1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)] + + # We query the three hosts in order (due to the ForcedHostIndexPolicy) + # the first three queries will have to be repreapred and the rest should + # work as normal batch prepared statements + hosts = clus.metadata.all_hosts() + for i in range(10): + value_to_insert = values_to_insert[i % len(values_to_insert)] + batch_statement = BatchStatement(consistency_level=ConsistencyLevel.ONE) + batch_statement.add(insert_statement, value_to_insert) + session.execute(batch_statement, host=hosts[i % len(hosts)]) + + select_results = session.execute("SELECT * FROM %s" % table) + expected_results = [ + (1, None, 2, None, 3), + (2, None, 3, None, 4), + (3, None, 4, None, 5), + (4, None, 5, None, 6) + ] + + self.assertEqual(set(expected_results), set(select_results._current_rows)) + + # To verify our test assumption that queries are getting re-prepared properly + self.assertEqual(3, self.mock_handler.get_message_count('debug', "Re-preparing")) + + +class PrintStatementTests(unittest.TestCase): + """ + Test that shows the format used when printing Statements + """ + + def test_simple_statement(self): + """ + Highlight the format of printing SimpleStatements + """ + + ss = SimpleStatement('SELECT * FROM test3rf.test', consistency_level=ConsistencyLevel.ONE) + self.assertEqual(str(ss), + '') + + def test_prepared_statement(self): + """ + Highlight the difference between Prepared and Bound statements + """ + + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect() + + prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)') + prepared.consistency_level = ConsistencyLevel.ONE + + self.assertEqual(str(prepared), + '') + + bound = prepared.bind((1, 2)) + self.assertEqual(str(bound), + '') + + cluster.shutdown() + + +class BatchStatementTests(BasicSharedKeyspaceUnitTestCase): + + def setUp(self): + if PROTOCOL_VERSION < 2: + raise unittest.SkipTest( + "Protocol 2.0+ is required for BATCH operations, currently testing against %r" + % (PROTOCOL_VERSION,)) + + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + if PROTOCOL_VERSION < 3: + self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) + self.session = self.cluster.connect(wait_for_all_pools=True) + + def tearDown(self): + self.cluster.shutdown() + + def confirm_results(self): + keys = set() + values = set() + # Assuming the test data is inserted at default CL.ONE, we need ALL here to guarantee we see + # everything inserted + results = self.session.execute(SimpleStatement("SELECT * FROM test3rf.test", + consistency_level=ConsistencyLevel.ALL)) + for result in results: + keys.add(result.k) + values.add(result.v) + + self.assertEqual(set(range(10)), keys, msg=results) + self.assertEqual(set(range(10)), values, msg=results) + + def test_string_statements(self): + batch = BatchStatement(BatchType.LOGGED) + for i in range(10): + 
batch.add("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", (i, i)) + + self.session.execute(batch) + self.session.execute_async(batch).result() + self.confirm_results() + + def test_simple_statements(self): + batch = BatchStatement(BatchType.LOGGED) + for i in range(10): + batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"), (i, i)) + + self.session.execute(batch) + self.session.execute_async(batch).result() + self.confirm_results() + + def test_prepared_statements(self): + prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)") + + batch = BatchStatement(BatchType.LOGGED) + for i in range(10): + batch.add(prepared, (i, i)) + + self.session.execute(batch) + self.session.execute_async(batch).result() + self.confirm_results() + + def test_bound_statements(self): + prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)") + + batch = BatchStatement(BatchType.LOGGED) + for i in range(10): + batch.add(prepared.bind((i, i))) + + self.session.execute(batch) + self.session.execute_async(batch).result() + self.confirm_results() + + def test_no_parameters(self): + batch = BatchStatement(BatchType.LOGGED) + batch.add("INSERT INTO test3rf.test (k, v) VALUES (0, 0)") + batch.add("INSERT INTO test3rf.test (k, v) VALUES (1, 1)", ()) + batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (2, 2)")) + batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (3, 3)"), ()) + + prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (4, 4)") + batch.add(prepared) + batch.add(prepared, ()) + batch.add(prepared.bind([])) + batch.add(prepared.bind([]), ()) + + batch.add("INSERT INTO test3rf.test (k, v) VALUES (5, 5)", ()) + batch.add("INSERT INTO test3rf.test (k, v) VALUES (6, 6)", ()) + batch.add("INSERT INTO test3rf.test (k, v) VALUES (7, 7)", ()) + batch.add("INSERT INTO test3rf.test (k, v) VALUES (8, 8)", ()) + batch.add("INSERT INTO test3rf.test (k, v) VALUES (9, 9)", ()) + + self.assertRaises(ValueError, batch.add, prepared.bind([]), (1)) + self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2)) + self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2, 3)) + + self.session.execute(batch) + self.confirm_results() + + def test_unicode(self): + ddl = ''' + CREATE TABLE test3rf.testtext ( + k int PRIMARY KEY, + v text )''' + self.session.execute(ddl) + unicode_text = u'Fran\u00E7ois' + query = u'INSERT INTO test3rf.testtext (k, v) VALUES (%s, %s)' + try: + batch = BatchStatement(BatchType.LOGGED) + batch.add(u"INSERT INTO test3rf.testtext (k, v) VALUES (%s, %s)", (0, unicode_text)) + self.session.execute(batch) + finally: + self.session.execute("DROP TABLE test3rf.testtext") + + def test_too_many_statements(self): + max_statements = 0xFFFF + ss = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (0, 0)") + b = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=ConsistencyLevel.ONE) + + # max works + b.add_all([ss] * max_statements, [None] * max_statements) + self.session.execute(b) + + # max + 1 raises + self.assertRaises(ValueError, b.add, ss) + + # also would have bombed trying to encode + b._statements_and_parameters.append((False, ss.query_string, ())) + self.assertRaises(NoHostAvailable, self.session.execute, b) + + +class SerialConsistencyTests(unittest.TestCase): + def setUp(self): + if PROTOCOL_VERSION < 2: + raise unittest.SkipTest( + "Protocol 2.0+ is required for Serial Consistency, currently testing against %r" + % (PROTOCOL_VERSION,)) + + 
self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + if PROTOCOL_VERSION < 3: + self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) + self.session = self.cluster.connect() + + def tearDown(self): + self.cluster.shutdown() + + def test_conditional_update(self): + self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)") + statement = SimpleStatement( + "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1", + serial_consistency_level=ConsistencyLevel.SERIAL) + # crazy test, but PYTHON-299 + # TODO: expand to check more parameters get passed to statement, and on to messages + self.assertEqual(statement.serial_consistency_level, ConsistencyLevel.SERIAL) + future = self.session.execute_async(statement) + result = future.result() + self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.SERIAL) + self.assertTrue(result) + self.assertFalse(result[0].applied) + + statement = SimpleStatement( + "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0", + serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL) + self.assertEqual(statement.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL) + future = self.session.execute_async(statement) + result = future.result() + self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL) + self.assertTrue(result) + self.assertTrue(result[0].applied) + + def test_conditional_update_with_prepared_statements(self): + self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)") + statement = self.session.prepare( + "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=2") + + statement.serial_consistency_level = ConsistencyLevel.SERIAL + future = self.session.execute_async(statement) + result = future.result() + self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.SERIAL) + self.assertTrue(result) + self.assertFalse(result[0].applied) + + statement = self.session.prepare( + "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0") + bound = statement.bind(()) + bound.serial_consistency_level = ConsistencyLevel.LOCAL_SERIAL + future = self.session.execute_async(bound) + result = future.result() + self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL) + self.assertTrue(result) + self.assertTrue(result[0].applied) + + def test_conditional_update_with_batch_statements(self): + self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)") + statement = BatchStatement(serial_consistency_level=ConsistencyLevel.SERIAL) + statement.add("UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1") + self.assertEqual(statement.serial_consistency_level, ConsistencyLevel.SERIAL) + future = self.session.execute_async(statement) + result = future.result() + self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.SERIAL) + self.assertTrue(result) + self.assertFalse(result[0].applied) + + statement = BatchStatement(serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL) + statement.add("UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0") + self.assertEqual(statement.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL) + future = self.session.execute_async(statement) + result = future.result() + self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL) + self.assertTrue(result) + self.assertTrue(result[0].applied) + + def test_bad_consistency_level(self): + statement = SimpleStatement("foo") + self.assertRaises(ValueError, setattr, statement, 'serial_consistency_level', ConsistencyLevel.ONE) + self.assertRaises(ValueError, 
SimpleStatement, 'foo', serial_consistency_level=ConsistencyLevel.ONE) + + +class LightweightTransactionTests(unittest.TestCase): + def setUp(self): + """ + Test is skipped if run with cql version < 2 + + """ + if PROTOCOL_VERSION < 2: + raise unittest.SkipTest( + "Protocol 2.0+ is required for Lightweight transactions, currently testing against %r" + % (PROTOCOL_VERSION,)) + + serial_profile = ExecutionProfile(consistency_level=ConsistencyLevel.SERIAL) + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, execution_profiles={'serial': serial_profile}) + self.session = self.cluster.connect() + + ddl = ''' + CREATE TABLE test3rf.lwt ( + k int PRIMARY KEY, + v int )''' + self.session.execute(ddl) + + ddl = ''' + CREATE TABLE test3rf.lwt_clustering ( + k int, + c int, + v int, + PRIMARY KEY (k, c))''' + self.session.execute(ddl) + + def tearDown(self): + """ + Shutdown cluster + """ + self.session.execute("DROP TABLE test3rf.lwt") + self.session.execute("DROP TABLE test3rf.lwt_clustering") + self.cluster.shutdown() + + def test_no_connection_refused_on_timeout(self): + """ + Test for PYTHON-91 "Connection closed after LWT timeout" + Verifies that connection to the cluster is not shut down when timeout occurs. + Number of iterations can be specified with LWT_ITERATIONS environment variable. + Default value is 1000 + """ + insert_statement = self.session.prepare("INSERT INTO test3rf.lwt (k, v) VALUES (0, 0) IF NOT EXISTS") + delete_statement = self.session.prepare("DELETE FROM test3rf.lwt WHERE k = 0 IF EXISTS") + + iterations = int(os.getenv("LWT_ITERATIONS", 1000)) + + # Prepare series of parallel statements + statements_and_params = [] + for i in range(iterations): + statements_and_params.append((insert_statement, ())) + statements_and_params.append((delete_statement, ())) + + received_timeout = False + results = execute_concurrent(self.session, statements_and_params, raise_on_first_error=False) + for (success, result) in results: + if success: + continue + else: + # In this case result is an exception + if type(result).__name__ == "NoHostAvailable": + self.fail("PYTHON-91: Disconnected from Cassandra: %s" % result.message) + if type(result).__name__ == "WriteTimeout": + received_timeout = True + continue + if type(result).__name__ == "WriteFailure": + received_timeout = True + continue + if type(result).__name__ == "ReadTimeout": + continue + if type(result).__name__ == "ReadFailure": + continue + + self.fail("Unexpected exception %s: %s" % (type(result).__name__, result.message)) + + # Make sure test passed + self.assertTrue(received_timeout) + + def test_was_applied_batch_stmt(self): + """ + Test to ensure `:attr:cassandra.cluster.ResultSet.was_applied` works as expected + with Batchstatements. 
+ + For both type of batches verify was_applied has the correct result + under different scenarios: + - If on LWT fails the rest of the statements fail including normal UPSERTS + - If on LWT fails the rest of the statements fail + - All the queries succeed + + @since 3.14 + @jira_ticket PYTHON-848 + @expected_result `:attr:cassandra.cluster.ResultSet.was_applied` is updated as + expected + + @test_category query + """ + for batch_type in (BatchType.UNLOGGED, BatchType.LOGGED): + batch_statement = BatchStatement(batch_type) + batch_statement.add_all(["INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 1, 10);", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 2, 10);"], [None] * 3) + result = self.session.execute(batch_statement) + #self.assertTrue(result.was_applied) + + # Should fail since (0, 0, 10) have already been written + # The non conditional insert shouldn't be written as well + batch_statement = BatchStatement(batch_type) + batch_statement.add_all(["INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 3, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 4, 10);", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS;"], [None] * 4) + result = self.session.execute(batch_statement) + self.assertFalse(result.was_applied) + + all_rows = self.session.execute("SELECT * from test3rf.lwt_clustering", execution_profile='serial') + # Verify the non conditional insert hasn't been inserted + self.assertEqual(len(all_rows.current_rows), 3) + + # Should fail since (0, 0, 10) have already been written + batch_statement = BatchStatement(batch_type) + batch_statement.add_all(["INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 3, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS;"], [None] * 3) + result = self.session.execute(batch_statement) + self.assertFalse(result.was_applied) + + # Should fail since (0, 0, 10) have already been written + batch_statement.add("INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;") + result = self.session.execute(batch_statement) + self.assertFalse(result.was_applied) + + # Should succeed + batch_statement = BatchStatement(batch_type) + batch_statement.add_all(["INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 3, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 4, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS;"], [None] * 3) + + result = self.session.execute(batch_statement) + self.assertTrue(result.was_applied) + + all_rows = self.session.execute("SELECT * from test3rf.lwt_clustering", execution_profile='serial') + for i, row in enumerate(all_rows): + self.assertEqual((0, i, 10), (row[0], row[1], row[2])) + + self.session.execute("TRUNCATE TABLE test3rf.lwt_clustering") + + def test_empty_batch_statement(self): + """ + Test to ensure `:attr:cassandra.cluster.ResultSet.was_applied` works as expected + with empty Batchstatements. 
+ + @since 3.14 + @jira_ticket PYTHON-848 + @expected_result an Exception is raised + expected + + @test_category query + """ + batch_statement = BatchStatement() + results = self.session.execute(batch_statement) + with self.assertRaises(RuntimeError): + results.was_applied + + @unittest.skip("Skipping until PYTHON-943 is resolved") + def test_was_applied_batch_string(self): + batch_statement = BatchStatement(BatchType.LOGGED) + batch_statement.add_all(["INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10);", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 1, 10);", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 2, 10);"], [None] * 3) + self.session.execute(batch_statement) + + batch_str = """ + BEGIN unlogged batch + INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS; + INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 1, 10) IF NOT EXISTS; + INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 2, 10) IF NOT EXISTS; + APPLY batch; + """ + result = self.session.execute(batch_str) + self.assertFalse(result.was_applied) + + batch_str = """ + BEGIN unlogged batch + INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 3, 10) IF NOT EXISTS; + INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 4, 10) IF NOT EXISTS; + INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS; + APPLY batch; + """ + result = self.session.execute(batch_str) + self.assertTrue(result.was_applied) + + +class BatchStatementDefaultRoutingKeyTests(unittest.TestCase): + # Test for PYTHON-126: BatchStatement.add() should set the routing key of the first added prepared statement + + def setUp(self): + if PROTOCOL_VERSION < 2: + raise unittest.SkipTest( + "Protocol 2.0+ is required for BATCH operations, currently testing against %r" + % (PROTOCOL_VERSION,)) + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.session = self.cluster.connect() + query = """ + INSERT INTO test3rf.test (k, v) VALUES (?, ?) 
+ """ + self.simple_statement = SimpleStatement(query, routing_key='ss_rk', keyspace='keyspace_name') + self.prepared = self.session.prepare(query) + + def tearDown(self): + self.cluster.shutdown() + + def test_rk_from_bound(self): + """ + batch routing key is inherited from BoundStatement + """ + bound = self.prepared.bind((1, None)) + batch = BatchStatement() + batch.add(bound) + self.assertIsNotNone(batch.routing_key) + self.assertEqual(batch.routing_key, bound.routing_key) + + def test_rk_from_simple(self): + """ + batch routing key is inherited from SimpleStatement + """ + batch = BatchStatement() + batch.add(self.simple_statement) + self.assertIsNotNone(batch.routing_key) + self.assertEqual(batch.routing_key, self.simple_statement.routing_key) + + def test_inherit_first_rk_bound(self): + """ + compound batch inherits the first routing key of the first added statement (bound statement is first) + """ + bound = self.prepared.bind((100000000, None)) + batch = BatchStatement() + batch.add("ss with no rk") + batch.add(bound) + batch.add(self.simple_statement) + + for i in range(3): + batch.add(self.prepared, (i, i)) + + self.assertIsNotNone(batch.routing_key) + self.assertEqual(batch.routing_key, bound.routing_key) + + def test_inherit_first_rk_simple_statement(self): + """ + compound batch inherits the first routing key of the first added statement (Simplestatement is first) + """ + bound = self.prepared.bind((1, None)) + batch = BatchStatement() + batch.add("ss with no rk") + batch.add(self.simple_statement) + batch.add(bound) + + for i in range(10): + batch.add(self.prepared, (i, i)) + + self.assertIsNotNone(batch.routing_key) + self.assertEqual(batch.routing_key, self.simple_statement.routing_key) + + def test_inherit_first_rk_prepared_param(self): + """ + compound batch inherits the first routing key of the first added statement (prepared statement is first) + """ + bound = self.prepared.bind((2, None)) + batch = BatchStatement() + batch.add("ss with no rk") + batch.add(self.prepared, (1, 0)) + batch.add(bound) + batch.add(self.simple_statement) + + self.assertIsNotNone(batch.routing_key) + self.assertEqual(batch.routing_key, self.prepared.bind((1, 0)).routing_key) + + +@greaterthanorequalcass30 +class MaterializedViewQueryTest(BasicSharedKeyspaceUnitTestCase): + + def test_mv_filtering(self): + """ + Test to ensure that cql filtering where clauses are properly supported in the python driver. + + test_mv_filtering Tests that various complex MV where clauses produce the correct results. It also validates that + these results and the grammar is supported appropriately. + + @since 3.0.0 + @jira_ticket PYTHON-399 + @expected_result Materialized view where clauses should produce the appropriate results. 
+ + @test_category materialized_view + """ + create_table = """CREATE TABLE {0}.scores( + user TEXT, + game TEXT, + year INT, + month INT, + day INT, + score INT, + PRIMARY KEY (user, game, year, month, day) + )""".format(self.keyspace_name) + + self.session.execute(create_table) + + create_mv_alltime = """CREATE MATERIALIZED VIEW {0}.alltimehigh AS + SELECT * FROM {0}.scores + WHERE game IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND day IS NOT NULL + PRIMARY KEY (game, score, user, year, month, day) + WITH CLUSTERING ORDER BY (score DESC)""".format(self.keyspace_name) + + create_mv_dailyhigh = """CREATE MATERIALIZED VIEW {0}.dailyhigh AS + SELECT * FROM {0}.scores + WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND day IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL + PRIMARY KEY ((game, year, month, day), score, user) + WITH CLUSTERING ORDER BY (score DESC)""".format(self.keyspace_name) + + create_mv_monthlyhigh = """CREATE MATERIALIZED VIEW {0}.monthlyhigh AS + SELECT * FROM {0}.scores + WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND day IS NOT NULL + PRIMARY KEY ((game, year, month), score, user, day) + WITH CLUSTERING ORDER BY (score DESC)""".format(self.keyspace_name) + + create_mv_filtereduserhigh = """CREATE MATERIALIZED VIEW {0}.filtereduserhigh AS + SELECT * FROM {0}.scores + WHERE user in ('jbellis', 'pcmanus') AND game IS NOT NULL AND score IS NOT NULL AND year is NOT NULL AND day is not NULL and month IS NOT NULL + PRIMARY KEY (game, score, user, year, month, day) + WITH CLUSTERING ORDER BY (score DESC)""".format(self.keyspace_name) + + self.session.execute(create_mv_alltime) + self.session.execute(create_mv_dailyhigh) + self.session.execute(create_mv_monthlyhigh) + self.session.execute(create_mv_filtereduserhigh) + + self.addCleanup(self.session.execute, "DROP MATERIALIZED VIEW {0}.alltimehigh".format(self.keyspace_name)) + self.addCleanup(self.session.execute, "DROP MATERIALIZED VIEW {0}.dailyhigh".format(self.keyspace_name)) + self.addCleanup(self.session.execute, "DROP MATERIALIZED VIEW {0}.monthlyhigh".format(self.keyspace_name)) + self.addCleanup(self.session.execute, "DROP MATERIALIZED VIEW {0}.filtereduserhigh".format(self.keyspace_name)) + + prepared_insert = self.session.prepare("""INSERT INTO {0}.scores (user, game, year, month, day, score) VALUES (?, ?, ? ,? 
,?, ?)""".format(self.keyspace_name)) + + bound = prepared_insert.bind(('pcmanus', 'Coup', 2015, 5, 1, 4000)) + self.session.execute(bound) + bound = prepared_insert.bind(('jbellis', 'Coup', 2015, 5, 3, 1750)) + self.session.execute(bound) + bound = prepared_insert.bind(('yukim', 'Coup', 2015, 5, 3, 2250)) + self.session.execute(bound) + bound = prepared_insert.bind(('tjake', 'Coup', 2015, 5, 3, 500)) + self.session.execute(bound) + bound = prepared_insert.bind(('iamaleksey', 'Coup', 2015, 6, 1, 2500)) + self.session.execute(bound) + bound = prepared_insert.bind(('tjake', 'Coup', 2015, 6, 2, 1000)) + self.session.execute(bound) + bound = prepared_insert.bind(('pcmanus', 'Coup', 2015, 6, 2, 2000)) + self.session.execute(bound) + bound = prepared_insert.bind(('jmckenzie', 'Coup', 2015, 6, 9, 2700)) + self.session.execute(bound) + bound = prepared_insert.bind(('jbellis', 'Coup', 2015, 6, 20, 3500)) + self.session.execute(bound) + bound = prepared_insert.bind(('jbellis', 'Checkers', 2015, 6, 20, 1200)) + self.session.execute(bound) + bound = prepared_insert.bind(('jbellis', 'Chess', 2015, 6, 21, 3500)) + self.session.execute(bound) + bound = prepared_insert.bind(('pcmanus', 'Chess', 2015, 1, 25, 3200)) + self.session.execute(bound) + + # Test simple statement and alltime high filtering + query_statement = SimpleStatement("SELECT * FROM {0}.alltimehigh WHERE game='Coup'".format(self.keyspace_name), + consistency_level=ConsistencyLevel.QUORUM) + results = self.session.execute(query_statement) + self.assertEqual(results[0].game, 'Coup') + self.assertEqual(results[0].year, 2015) + self.assertEqual(results[0].month, 5) + self.assertEqual(results[0].day, 1) + self.assertEqual(results[0].score, 4000) + self.assertEqual(results[0].user, "pcmanus") + + # Test prepared statement and daily high filtering + prepared_query = self.session.prepare("SELECT * FROM {0}.dailyhigh WHERE game=? AND year=? AND month=? and day=?".format(self.keyspace_name)) + bound_query = prepared_query.bind(("Coup", 2015, 6, 2)) + results = self.session.execute(bound_query) + self.assertEqual(results[0].game, 'Coup') + self.assertEqual(results[0].year, 2015) + self.assertEqual(results[0].month, 6) + self.assertEqual(results[0].day, 2) + self.assertEqual(results[0].score, 2000) + self.assertEqual(results[0].user, "pcmanus") + + self.assertEqual(results[1].game, 'Coup') + self.assertEqual(results[1].year, 2015) + self.assertEqual(results[1].month, 6) + self.assertEqual(results[1].day, 2) + self.assertEqual(results[1].score, 1000) + self.assertEqual(results[1].user, "tjake") + + # Test montly high range queries + prepared_query = self.session.prepare("SELECT * FROM {0}.monthlyhigh WHERE game=? AND year=? AND month=? and score >= ? 
and score <= ?".format(self.keyspace_name)) + bound_query = prepared_query.bind(("Coup", 2015, 6, 2500, 3500)) + results = self.session.execute(bound_query) + self.assertEqual(results[0].game, 'Coup') + self.assertEqual(results[0].year, 2015) + self.assertEqual(results[0].month, 6) + self.assertEqual(results[0].day, 20) + self.assertEqual(results[0].score, 3500) + self.assertEqual(results[0].user, "jbellis") + + self.assertEqual(results[1].game, 'Coup') + self.assertEqual(results[1].year, 2015) + self.assertEqual(results[1].month, 6) + self.assertEqual(results[1].day, 9) + self.assertEqual(results[1].score, 2700) + self.assertEqual(results[1].user, "jmckenzie") + + self.assertEqual(results[2].game, 'Coup') + self.assertEqual(results[2].year, 2015) + self.assertEqual(results[2].month, 6) + self.assertEqual(results[2].day, 1) + self.assertEqual(results[2].score, 2500) + self.assertEqual(results[2].user, "iamaleksey") + + # Test filtered user high scores + query_statement = SimpleStatement("SELECT * FROM {0}.filtereduserhigh WHERE game='Chess'".format(self.keyspace_name), + consistency_level=ConsistencyLevel.QUORUM) + results = self.session.execute(query_statement) + self.assertEqual(results[0].game, 'Chess') + self.assertEqual(results[0].year, 2015) + self.assertEqual(results[0].month, 6) + self.assertEqual(results[0].day, 21) + self.assertEqual(results[0].score, 3500) + self.assertEqual(results[0].user, "jbellis") + + self.assertEqual(results[1].game, 'Chess') + self.assertEqual(results[1].year, 2015) + self.assertEqual(results[1].month, 1) + self.assertEqual(results[1].day, 25) + self.assertEqual(results[1].score, 3200) + self.assertEqual(results[1].user, "pcmanus") + + +class UnicodeQueryTest(BasicSharedKeyspaceUnitTestCase): + + def setUp(self): + ddl = ''' + CREATE TABLE {0}.{1} ( + k int PRIMARY KEY, + v text )'''.format(self.keyspace_name, self.function_table_name) + self.session.execute(ddl) + + def tearDown(self): + self.session.execute("DROP TABLE {0}.{1}".format(self.keyspace_name,self.function_table_name)) + + def test_unicode(self): + """ + Test to validate that unicode query strings are handled appropriately by various query types + + @since 3.0.0 + @jira_ticket PYTHON-334 + @expected_result no unicode exceptions are thrown + + @test_category query + """ + + unicode_text = u'Fran\u00E7ois' + batch = BatchStatement(BatchType.LOGGED) + batch.add(u"INSERT INTO {0}.{1} (k, v) VALUES (%s, %s)".format(self.keyspace_name, self.function_table_name), (0, unicode_text)) + self.session.execute(batch) + self.session.execute(u"INSERT INTO {0}.{1} (k, v) VALUES (%s, %s)".format(self.keyspace_name, self.function_table_name), (0, unicode_text)) + prepared = self.session.prepare(u"INSERT INTO {0}.{1} (k, v) VALUES (?, ?)".format(self.keyspace_name, self.function_table_name)) + bound = prepared.bind((1, unicode_text)) + self.session.execute(bound) + + +class BaseKeyspaceTests(): + @classmethod + def setUpClass(cls): + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.session = cls.cluster.connect(wait_for_all_pools=True) + cls.ks_name = cls.__name__.lower() + + cls.alternative_ks = "alternative_keyspace" + cls.table_name = "table_query_keyspace_tests" + + ddl = """CREATE KEYSPACE {0} WITH replication = + {{'class': 'SimpleStrategy', + 'replication_factor': '{1}'}}""".format(cls.ks_name, 1) + cls.session.execute(ddl) + + ddl = """CREATE KEYSPACE {0} WITH replication = + {{'class': 'SimpleStrategy', + 'replication_factor': '{1}'}}""".format(cls.alternative_ks, 1) + 
cls.session.execute(ddl) + + ddl = ''' + CREATE TABLE {0}.{1} ( + k int PRIMARY KEY, + v int )'''.format(cls.ks_name, cls.table_name) + cls.session.execute(ddl) + ddl = ''' + CREATE TABLE {0}.{1} ( + k int PRIMARY KEY, + v int )'''.format(cls.alternative_ks, cls.table_name) + cls.session.execute(ddl) + + cls.session.execute("INSERT INTO {}.{} (k, v) VALUES (1, 1)".format(cls.ks_name, cls.table_name)) + cls.session.execute("INSERT INTO {}.{} (k, v) VALUES (2, 2)".format(cls.alternative_ks, cls.table_name)) + + @classmethod + def tearDownClass(cls): + ddl = "DROP KEYSPACE {}".format(cls.alternative_ks) + cls.session.execute(ddl) + ddl = "DROP KEYSPACE {}".format(cls.ks_name) + cls.session.execute(ddl) + cls.cluster.shutdown() + +class QueryKeyspaceTests(BaseKeyspaceTests): + + def test_setting_keyspace(self): + """ + Test the basic functionality of PYTHON-678, the keyspace can be set + independently of the query and read the results + + @since 3.12 + @jira_ticket PYTHON-678 + @expected_result the query is executed and the results retrieved + + @test_category query + """ + self._check_set_keyspace_in_statement(self.session) + + def test_setting_keyspace_and_session(self): + """ + Test we can still send the keyspace independently even the session + connects to a keyspace when it's created + + @since 3.12 + @jira_ticket PYTHON-678 + @expected_result the query is executed and the results retrieved + + @test_category query + """ + cluster = Cluster(protocol_version=ProtocolVersion.V5, allow_beta_protocol_version=True) + session = cluster.connect(self.alternative_ks) + self.addCleanup(cluster.shutdown) + + self._check_set_keyspace_in_statement(session) + + def test_setting_keyspace_and_session_after_created(self): + """ + Test we can still send the keyspace independently even the session + connects to a different keyspace after being created + + @since 3.12 + @jira_ticket PYTHON-678 + @expected_result the query is executed and the results retrieved + + @test_category query + """ + cluster = Cluster(protocol_version=ProtocolVersion.V5, allow_beta_protocol_version=True) + session = cluster.connect() + self.addCleanup(cluster.shutdown) + + session.set_keyspace(self.alternative_ks) + self._check_set_keyspace_in_statement(session) + + def test_setting_keyspace_and_same_session(self): + """ + Test we can still send the keyspace independently even if the session + is connected to the sent keyspace + + @since 3.12 + @jira_ticket PYTHON-678 + @expected_result the query is executed and the results retrieved + + @test_category query + """ + cluster = Cluster(protocol_version=ProtocolVersion.V5, allow_beta_protocol_version=True) + session = cluster.connect(self.ks_name) + self.addCleanup(cluster.shutdown) + + self._check_set_keyspace_in_statement(session) + + +@greaterthanorequalcass40 +class SimpleWithKeyspaceTests(QueryKeyspaceTests, unittest.TestCase): + @unittest.skip + def test_lower_protocol(self): + cluster = Cluster(protocol_version=ProtocolVersion.V4) + session = cluster.connect(self.ks_name) + self.addCleanup(cluster.shutdown) + + simple_stmt = SimpleStatement("SELECT * from {}".format(self.table_name), keyspace=self.ks_name) + # This raises cassandra.cluster.NoHostAvailable: ('Unable to complete the operation against + # any hosts', {: UnsupportedOperation('Keyspaces may only be + # set on queries with protocol version 5 or higher. 
Consider setting Cluster.protocol_version to 5.',),
+        # : ConnectionException('Host has been marked down or removed',),
+        # : ConnectionException('Host has been marked down or removed',)})
+        with self.assertRaises(NoHostAvailable):
+            session.execute(simple_stmt)
+
+    def _check_set_keyspace_in_statement(self, session):
+        simple_stmt = SimpleStatement("SELECT * from {}".format(self.table_name), keyspace=self.ks_name)
+        results = session.execute(simple_stmt)
+        self.assertEqual(results[0], (1, 1))
+
+        simple_stmt = SimpleStatement("SELECT * from {}".format(self.table_name))
+        simple_stmt.keyspace = self.ks_name
+        results = session.execute(simple_stmt)
+        self.assertEqual(results[0], (1, 1))
+
+
+@greaterthanorequalcass40
+class BatchWithKeyspaceTests(QueryKeyspaceTests, unittest.TestCase):
+    def _check_set_keyspace_in_statement(self, session):
+        batch_stmt = BatchStatement()
+        for i in range(10):
+            batch_stmt.add("INSERT INTO {} (k, v) VALUES (%s, %s)".format(self.table_name), (i, i))
+
+        batch_stmt.keyspace = self.ks_name
+        session.execute(batch_stmt)
+        self.confirm_results()
+
+    def confirm_results(self):
+        keys = set()
+        values = set()
+        # Assuming the test data is inserted at default CL.ONE, we need ALL here to guarantee we see
+        # everything inserted
+        results = self.session.execute(SimpleStatement("SELECT * FROM {}.{}".format(self.ks_name, self.table_name),
+                                                       consistency_level=ConsistencyLevel.ALL))
+        for result in results:
+            keys.add(result.k)
+            values.add(result.v)
+
+        self.assertEqual(set(range(10)), keys, msg=results)
+        self.assertEqual(set(range(10)), values, msg=results)
+
+
+@greaterthanorequalcass40
+class PreparedWithKeyspaceTests(BaseKeyspaceTests, unittest.TestCase):
+
+    def setUp(self):
+        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, allow_beta_protocol_version=True)
+        self.session = self.cluster.connect()
+
+    def tearDown(self):
+        self.cluster.shutdown()
+
+    def test_prepared_with_keyspace_explicit(self):
+        """
+        Test the basic functionality of PYTHON-678, the keyspace can be set
+        independently of the query and read the results
+
+        @since 3.12
+        @jira_ticket PYTHON-678
+        @expected_result the query is executed and the results retrieved
+
+        @test_category query
+        """
+        query = "SELECT * from {} WHERE k = ?".format(self.table_name)
+        prepared_statement = self.session.prepare(query, keyspace=self.ks_name)
+
+        results = self.session.execute(prepared_statement, (1, ))
+        self.assertEqual(results[0], (1, 1))
+
+        prepared_statement_alternative = self.session.prepare(query, keyspace=self.alternative_ks)
+
+        self.assertNotEqual(prepared_statement.query_id, prepared_statement_alternative.query_id)
+
+        results = self.session.execute(prepared_statement_alternative, (2,))
+        self.assertEqual(results[0], (2, 2))
+
+    def test_reprepare_after_host_is_down(self):
+        """
+        Test that Cluster._prepare_all_queries is called when a node comes
+        back up and that the prepared queries then succeed
+
+        @since 3.12
+        @jira_ticket PYTHON-678
+        @expected_result the query is executed and the results retrieved
+
+        @test_category query
+        """
+        mock_handler = MockLoggingHandler()
+        logger = logging.getLogger(cluster.__name__)
+        logger.addHandler(mock_handler)
+        get_node(1).stop(wait=True, gently=True, wait_other_notice=True)
+
+        only_first = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(["127.0.0.1"]))
+        self.cluster.add_execution_profile("only_first", only_first)
+
+        query = "SELECT v from {} WHERE k = ?".format(self.table_name)
+        prepared_statement = self.session.prepare(query,
keyspace=self.ks_name)
+        prepared_statement_alternative = self.session.prepare(query, keyspace=self.alternative_ks)
+
+        get_node(1).start(wait_for_binary_proto=True, wait_other_notice=True)
+
+        # We wait for cluster._prepare_all_queries to be called
+        time.sleep(5)
+        self.assertEqual(1, mock_handler.get_message_count('debug', 'Preparing all known prepared statements'))
+        results = self.session.execute(prepared_statement, (1,), execution_profile="only_first")
+        self.assertEqual(results[0], (1, ))
+
+        results = self.session.execute(prepared_statement_alternative, (2,), execution_profile="only_first")
+        self.assertEqual(results[0], (2, ))
+
+    def test_prepared_not_found(self):
+        """
+        Test that if a query fails on a node that didn't have
+        the query prepared, it is re-prepared as expected and then
+        the query is executed
+
+        @since 3.12
+        @jira_ticket PYTHON-678
+        @expected_result the query is executed and the results retrieved
+
+        @test_category query
+        """
+        cluster = Cluster(protocol_version=PROTOCOL_VERSION, allow_beta_protocol_version=True)
+        session = cluster.connect("system")
+        self.addCleanup(cluster.shutdown)
+
+        cluster.prepare_on_all_hosts = False
+        query = "SELECT k from {} WHERE k = ?".format(self.table_name)
+        prepared_statement = session.prepare(query, keyspace=self.ks_name)
+
+        for _ in range(10):
+            results = session.execute(prepared_statement, (1, ))
+            self.assertEqual(results[0], (1,))
+
+    def test_prepared_in_query_keyspace(self):
+        """
+        Test that the keyspace can be set in the query
+
+        @since 3.12
+        @jira_ticket PYTHON-678
+        @expected_result the results are retrieved correctly
+
+        @test_category query
+        """
+        cluster = Cluster(protocol_version=PROTOCOL_VERSION, allow_beta_protocol_version=True)
+        session = cluster.connect()
+        self.addCleanup(cluster.shutdown)
+
+        query = "SELECT k from {}.{} WHERE k = ?".format(self.ks_name, self.table_name)
+        prepared_statement = session.prepare(query)
+        results = session.execute(prepared_statement, (1,))
+        self.assertEqual(results[0], (1,))
+
+        query = "SELECT k from {}.{} WHERE k = ?".format(self.alternative_ks, self.table_name)
+        prepared_statement = session.prepare(query)
+        results = session.execute(prepared_statement, (2,))
+        self.assertEqual(results[0], (2,))
+
+    def test_prepared_in_query_keyspace_and_explicit(self):
+        """
+        Test that the keyspace set explicitly is ignored if it is
+        also specified in the query
+
+        @since 3.12
+        @jira_ticket PYTHON-678
+        @expected_result the keyspace set explicitly is ignored and
+        the results are retrieved correctly
+
+        @test_category query
+        """
+        query = "SELECT k from {}.{} WHERE k = ?".format(self.ks_name, self.table_name)
+        prepared_statement = self.session.prepare(query, keyspace="system")
+        results = self.session.execute(prepared_statement, (1,))
+        self.assertEqual(results[0], (1,))
+
+        query = "SELECT k from {}.{} WHERE k = ?".format(self.ks_name, self.table_name)
+        prepared_statement = self.session.prepare(query, keyspace=self.alternative_ks)
+        results = self.session.execute(prepared_statement, (1,))
+        self.assertEqual(results[0], (1,))
diff --git a/tests/integration/standard/test_query_paging.py b/tests/integration/standard/test_query_paging.py
new file mode 100644
index 0000000..dfe9f70
--- /dev/null
+++ b/tests/integration/standard/test_query_paging.py
@@ -0,0 +1,413 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tests.integration import use_singledc, PROTOCOL_VERSION
+
+import logging
+log = logging.getLogger(__name__)
+
+try:
+    import unittest2 as unittest
+except ImportError:
+    import unittest  # noqa
+
+from itertools import cycle, count
+from six.moves import range
+from threading import Event
+
+from cassandra.cluster import Cluster
+from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args
+from cassandra.policies import HostDistance
+from cassandra.query import SimpleStatement
+
+
+def setup_module():
+    use_singledc()
+
+
+class QueryPagingTests(unittest.TestCase):
+
+    def setUp(self):
+        if PROTOCOL_VERSION < 2:
+            raise unittest.SkipTest(
+                "Protocol 2.0+ is required for Paging state, currently testing against %r"
+                % (PROTOCOL_VERSION,))
+
+        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
+        if PROTOCOL_VERSION < 3:
+            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
+        self.session = self.cluster.connect(wait_for_all_pools=True)
+        self.session.execute("TRUNCATE test3rf.test")
+
+    def tearDown(self):
+        self.cluster.shutdown()
+
+    def test_paging(self):
+        statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
+                                    [(i, ) for i in range(100)])
+        execute_concurrent(self.session, list(statements_and_params))
+
+        prepared = self.session.prepare("SELECT * FROM test3rf.test")
+
+        for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
+            self.session.default_fetch_size = fetch_size
+            self.assertEqual(100, len(list(self.session.execute("SELECT * FROM test3rf.test"))))
+
+            statement = SimpleStatement("SELECT * FROM test3rf.test")
+            self.assertEqual(100, len(list(self.session.execute(statement))))
+
+            self.assertEqual(100, len(list(self.session.execute(prepared))))
+
+    def test_paging_state(self):
+        """
+        Test to validate the paging state api
+        @since 3.7.0
+        @jira_ticket PYTHON-200
+        @expected_result the paging state returned should be accurate, and should allow queries to be resumed.
+ + @test_category queries + """ + statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), + [(i, ) for i in range(100)]) + execute_concurrent(self.session, list(statements_and_params)) + + list_all_results = [] + self.session.default_fetch_size = 3 + + result_set = self.session.execute("SELECT * FROM test3rf.test") + while(result_set.has_more_pages): + for row in result_set.current_rows: + self.assertNotIn(row, list_all_results) + list_all_results.extend(result_set.current_rows) + page_state = result_set.paging_state + result_set = self.session.execute("SELECT * FROM test3rf.test", paging_state=page_state) + + if(len(result_set.current_rows) > 0): + list_all_results.append(result_set.current_rows) + self.assertEqual(len(list_all_results), 100) + + def test_paging_verify_writes(self): + statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), + [(i, ) for i in range(100)]) + execute_concurrent(self.session, statements_and_params) + + prepared = self.session.prepare("SELECT * FROM test3rf.test") + + for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000): + self.session.default_fetch_size = fetch_size + results = self.session.execute("SELECT * FROM test3rf.test") + result_array = set() + result_set = set() + for result in results: + result_array.add(result.k) + result_set.add(result.v) + + self.assertEqual(set(range(100)), result_array) + self.assertEqual(set([0]), result_set) + + statement = SimpleStatement("SELECT * FROM test3rf.test") + results = self.session.execute(statement) + result_array = set() + result_set = set() + for result in results: + result_array.add(result.k) + result_set.add(result.v) + + self.assertEqual(set(range(100)), result_array) + self.assertEqual(set([0]), result_set) + + results = self.session.execute(prepared) + result_array = set() + result_set = set() + for result in results: + result_array.add(result.k) + result_set.add(result.v) + + self.assertEqual(set(range(100)), result_array) + self.assertEqual(set([0]), result_set) + + def test_paging_verify_with_composite_keys(self): + ddl = ''' + CREATE TABLE test3rf.test_paging_verify_2 ( + k1 int, + k2 int, + v int, + PRIMARY KEY(k1, k2) + )''' + self.session.execute(ddl) + + statements_and_params = zip(cycle(["INSERT INTO test3rf.test_paging_verify_2 " + "(k1, k2, v) VALUES (0, %s, %s)"]), + [(i, i + 1) for i in range(100)]) + execute_concurrent(self.session, statements_and_params) + + prepared = self.session.prepare("SELECT * FROM test3rf.test_paging_verify_2") + + for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000): + self.session.default_fetch_size = fetch_size + results = self.session.execute("SELECT * FROM test3rf.test_paging_verify_2") + result_array = [] + value_array = [] + for result in results: + result_array.append(result.k2) + value_array.append(result.v) + + self.assertSequenceEqual(range(100), result_array) + self.assertSequenceEqual(range(1, 101), value_array) + + statement = SimpleStatement("SELECT * FROM test3rf.test_paging_verify_2") + results = self.session.execute(statement) + result_array = [] + value_array = [] + for result in results: + result_array.append(result.k2) + value_array.append(result.v) + + self.assertSequenceEqual(range(100), result_array) + self.assertSequenceEqual(range(1, 101), value_array) + + results = self.session.execute(prepared) + result_array = [] + value_array = [] + for result in results: + result_array.append(result.k2) + value_array.append(result.v) + + self.assertSequenceEqual(range(100), result_array) + 
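The paging-state round trip that test_paging_state relies on can be sketched outside the test harness as follows (the contact point and fetch size are assumptions, not part of this patch)::

    from cassandra.cluster import Cluster

    cluster = Cluster(["127.0.0.1"])
    session = cluster.connect()
    session.default_fetch_size = 10

    rows = []
    result_set = session.execute("SELECT * FROM test3rf.test")
    while result_set.has_more_pages:
        rows.extend(result_set.current_rows)
        # Resume the same query from where the previous page stopped
        result_set = session.execute("SELECT * FROM test3rf.test",
                                     paging_state=result_set.paging_state)
    rows.extend(result_set.current_rows)

    cluster.shutdown()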
self.assertSequenceEqual(range(1, 101), value_array) + + def test_async_paging(self): + statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), + [(i, ) for i in range(100)]) + execute_concurrent(self.session, list(statements_and_params)) + + prepared = self.session.prepare("SELECT * FROM test3rf.test") + + for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000): + self.session.default_fetch_size = fetch_size + self.assertEqual(100, len(list(self.session.execute_async("SELECT * FROM test3rf.test").result()))) + + statement = SimpleStatement("SELECT * FROM test3rf.test") + self.assertEqual(100, len(list(self.session.execute_async(statement).result()))) + + self.assertEqual(100, len(list(self.session.execute_async(prepared).result()))) + + def test_async_paging_verify_writes(self): + ddl = ''' + CREATE TABLE test3rf.test_async_paging_verify ( + k1 int, + k2 int, + v int, + PRIMARY KEY(k1, k2) + )''' + self.session.execute(ddl) + + statements_and_params = zip(cycle(["INSERT INTO test3rf.test_async_paging_verify " + "(k1, k2, v) VALUES (0, %s, %s)"]), + [(i, i + 1) for i in range(100)]) + execute_concurrent(self.session, statements_and_params) + + prepared = self.session.prepare("SELECT * FROM test3rf.test_async_paging_verify") + + for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000): + self.session.default_fetch_size = fetch_size + results = self.session.execute_async("SELECT * FROM test3rf.test_async_paging_verify").result() + result_array = [] + value_array = [] + for result in results: + result_array.append(result.k2) + value_array.append(result.v) + + self.assertSequenceEqual(range(100), result_array) + self.assertSequenceEqual(range(1, 101), value_array) + + statement = SimpleStatement("SELECT * FROM test3rf.test_async_paging_verify") + results = self.session.execute_async(statement).result() + result_array = [] + value_array = [] + for result in results: + result_array.append(result.k2) + value_array.append(result.v) + + self.assertSequenceEqual(range(100), result_array) + self.assertSequenceEqual(range(1, 101), value_array) + + results = self.session.execute_async(prepared).result() + result_array = [] + value_array = [] + for result in results: + result_array.append(result.k2) + value_array.append(result.v) + + self.assertSequenceEqual(range(100), result_array) + self.assertSequenceEqual(range(1, 101), value_array) + + def test_paging_callbacks(self): + """ + Test to validate callback api + @since 3.9.0 + @jira_ticket PYTHON-733 + @expected_result callbacks shouldn't be called twice per message + and the fetch_size should be handled in a transparent way to the user + + @test_category queries + """ + statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), + [(i, ) for i in range(100)]) + execute_concurrent(self.session, list(statements_and_params)) + + prepared = self.session.prepare("SELECT * FROM test3rf.test") + + for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000): + self.session.default_fetch_size = fetch_size + future = self.session.execute_async("SELECT * FROM test3rf.test", timeout=20) + + event = Event() + counter = count() + number_of_calls = count() + + def handle_page(rows, future, counter, number_of_calls): + next(number_of_calls) + for row in rows: + next(counter) + + if future.has_more_pages: + future.start_fetching_next_page() + else: + event.set() + + def handle_error(err): + event.set() + self.fail(err) + + future.add_callbacks(callback=handle_page, callback_args=(future, counter, number_of_calls), + 
errback=handle_error) + event.wait() + self.assertEqual(next(number_of_calls), 100 // fetch_size + 1) + self.assertEqual(next(counter), 100) + + # simple statement + future = self.session.execute_async(SimpleStatement("SELECT * FROM test3rf.test"), timeout=20) + event.clear() + counter = count() + number_of_calls = count() + + future.add_callbacks(callback=handle_page, callback_args=(future, counter, number_of_calls), + errback=handle_error) + event.wait() + self.assertEqual(next(number_of_calls), 100 // fetch_size + 1) + self.assertEqual(next(counter), 100) + + # prepared statement + future = self.session.execute_async(prepared, timeout=20) + event.clear() + counter = count() + number_of_calls = count() + + future.add_callbacks(callback=handle_page, callback_args=(future, counter, number_of_calls), + errback=handle_error) + event.wait() + self.assertEqual(next(number_of_calls), 100 // fetch_size + 1) + self.assertEqual(next(counter), 100) + + def test_concurrent_with_paging(self): + statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), + [(i, ) for i in range(100)]) + execute_concurrent(self.session, list(statements_and_params)) + + prepared = self.session.prepare("SELECT * FROM test3rf.test") + + for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000): + self.session.default_fetch_size = fetch_size + results = execute_concurrent_with_args(self.session, prepared, [None] * 10) + self.assertEqual(10, len(results)) + for (success, result) in results: + self.assertTrue(success) + self.assertEqual(100, len(list(result))) + + def test_fetch_size(self): + """ + Ensure per-statement fetch_sizes override the default fetch size. + """ + statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), + [(i, ) for i in range(100)]) + execute_concurrent(self.session, list(statements_and_params)) + + prepared = self.session.prepare("SELECT * FROM test3rf.test") + + self.session.default_fetch_size = 10 + result = self.session.execute(prepared, []) + self.assertTrue(result.has_more_pages) + + self.session.default_fetch_size = 2000 + result = self.session.execute(prepared, []) + self.assertFalse(result.has_more_pages) + + self.session.default_fetch_size = None + result = self.session.execute(prepared, []) + self.assertFalse(result.has_more_pages) + + self.session.default_fetch_size = 10 + + prepared.fetch_size = 2000 + result = self.session.execute(prepared, []) + self.assertFalse(result.has_more_pages) + + prepared.fetch_size = None + result = self.session.execute(prepared, []) + self.assertFalse(result.has_more_pages) + + prepared.fetch_size = 10 + result = self.session.execute(prepared, []) + self.assertTrue(result.has_more_pages) + + prepared.fetch_size = 2000 + bound = prepared.bind([]) + result = self.session.execute(bound, []) + self.assertFalse(result.has_more_pages) + + prepared.fetch_size = None + bound = prepared.bind([]) + result = self.session.execute(bound, []) + self.assertFalse(result.has_more_pages) + + prepared.fetch_size = 10 + bound = prepared.bind([]) + result = self.session.execute(bound, []) + self.assertTrue(result.has_more_pages) + + bound.fetch_size = 2000 + result = self.session.execute(bound, []) + self.assertFalse(result.has_more_pages) + + bound.fetch_size = None + result = self.session.execute(bound, []) + self.assertFalse(result.has_more_pages) + + bound.fetch_size = 10 + result = self.session.execute(bound, []) + self.assertTrue(result.has_more_pages) + + s = SimpleStatement("SELECT * FROM test3rf.test", fetch_size=None) 
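test_fetch_size checks the precedence the driver applies to page sizes: a fetch_size set on the prepared, bound, or simple statement overrides ``Session.default_fetch_size``, and ``None`` disables paging for that statement. A short sketch under the same test3rf schema assumption as above::

    from cassandra.cluster import Cluster
    from cassandra.query import SimpleStatement

    session = Cluster(["127.0.0.1"]).connect()
    session.default_fetch_size = 10

    stmt = SimpleStatement("SELECT * FROM test3rf.test")
    assert session.execute(stmt).has_more_pages        # session default (10) applies

    stmt.fetch_size = None                             # disable paging for this statement
    assert not session.execute(stmt).has_more_pages

    stmt.fetch_size = 2000                             # page larger than the row count
    assert not session.execute(stmt).has_more_pages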
+ result = self.session.execute(s, []) + self.assertFalse(result.has_more_pages) + + s = SimpleStatement("SELECT * FROM test3rf.test") + result = self.session.execute(s, []) + self.assertTrue(result.has_more_pages) + + s = SimpleStatement("SELECT * FROM test3rf.test") + s.fetch_size = None + result = self.session.execute(s, []) + self.assertFalse(result.has_more_pages) diff --git a/tests/integration/standard/test_routing.py b/tests/integration/standard/test_routing.py new file mode 100644 index 0000000..bf4c787 --- /dev/null +++ b/tests/integration/standard/test_routing.py @@ -0,0 +1,97 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from uuid import uuid1 + +import logging +log = logging.getLogger(__name__) + +from cassandra.cluster import Cluster + +from tests.integration import use_singledc, PROTOCOL_VERSION + + +def setup_module(): + use_singledc() + + +class RoutingTests(unittest.TestCase): + + @property + def cfname(self): + return self._testMethodName.lower() + + @classmethod + def setup_class(cls): + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.session = cls.cluster.connect('test1rf') + + @classmethod + def teardown_class(cls): + cls.cluster.shutdown() + + def insert_select_token(self, insert, select, key_values): + s = self.session + + bound = insert.bind(key_values) + s.execute(bound) + + my_token = s.cluster.metadata.token_map.token_class.from_key(bound.routing_key) + + cass_token = s.execute(select, key_values)[0][0] + token = s.cluster.metadata.token_map.token_class(cass_token) + self.assertEqual(my_token, token) + + def create_prepare(self, key_types): + s = self.session + table_name = "%s_%s" % (self.cfname, '_'.join(key_types)) + key_count = len(key_types) + key_decl = ', '.join(['k%d %s' % (i, t) for i, t in enumerate(key_types)]) + primary_key = ', '.join(['k%d' % i for i in range(key_count)]) + s.execute("CREATE TABLE %s (%s, v int, PRIMARY KEY ((%s)))" % + (table_name, key_decl, primary_key)) + + parameter_places = ', '.join(['?'] * key_count) + insert = s.prepare("INSERT INTO %s (%s, v) VALUES (%s, 1)" % + (table_name, primary_key, parameter_places)) + + where_clause = ' AND '.join(['k%d = ?' 
% i for i in range(key_count)]) + select = s.prepare("SELECT token(%s) FROM %s WHERE %s" % + (primary_key, table_name, where_clause)) + + return insert, select + + def test_singular_key(self): + # string + insert, select = self.create_prepare(('text',)) + self.insert_select_token(insert, select, ('some text value',)) + + # non-string + insert, select = self.create_prepare(('bigint',)) + self.insert_select_token(insert, select, (12390890177098123,)) + + def test_composite(self): + # double bool + insert, select = self.create_prepare(('double', 'boolean')) + self.insert_select_token(insert, select, (3.1459, True)) + self.insert_select_token(insert, select, (1.21e9, False)) + + # uuid string int + insert, select = self.create_prepare(('timeuuid', 'varchar', 'int')) + self.insert_select_token(insert, select, (uuid1(), 'asdf', 400)) + self.insert_select_token(insert, select, (uuid1(), 'fdsa', -1)) diff --git a/tests/integration/standard/test_row_factories.py b/tests/integration/standard/test_row_factories.py new file mode 100644 index 0000000..df709c3 --- /dev/null +++ b/tests/integration/standard/test_row_factories.py @@ -0,0 +1,259 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCaseWFunctionTable, BasicSharedKeyspaceUnitTestCase, execute_until_pass + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.cluster import Cluster, ResultSet +from cassandra.query import tuple_factory, named_tuple_factory, dict_factory, ordered_dict_factory +from cassandra.util import OrderedDict + + +def setup_module(): + use_singledc() + + +class NameTupleFactory(BasicSharedKeyspaceUnitTestCase): + + def setUp(self): + super(NameTupleFactory, self).setUp() + self.session.row_factory = named_tuple_factory + ddl = ''' + CREATE TABLE {0}.{1} ( + k int PRIMARY KEY, + v1 text, + v2 text, + v3 text)'''.format(self.ks_name, self.function_table_name) + self.session.execute(ddl) + execute_until_pass(self.session, ddl) + + def test_sanitizing(self): + """ + Test to ensure that same named results are surfaced in the NamedTupleFactory + + Creates a table with a few different text fields. Inserts a few values in that table. + It then fetches the values and confirms that despite all be being selected as the same name + they are propagated in the result set differently. + + @since 3.3 + @jira_ticket PYTHON-467 + @expected_result duplicate named results have unique row names. 
+ + @test_category queries + """ + + for x in range(5): + insert1 = ''' + INSERT INTO {0}.{1} + ( k , v1, v2, v3 ) + VALUES + ( 1 , 'v1{2}', 'v2{2}','v3{2}' ) + '''.format(self.keyspace_name, self.function_table_name, str(x)) + self.session.execute(insert1) + + query = "SELECT v1 AS duplicate, v2 AS duplicate, v3 AS duplicate from {0}.{1}".format(self.ks_name, self.function_table_name) + rs = self.session.execute(query) + row = rs[0] + self.assertTrue(hasattr(row, 'duplicate')) + self.assertTrue(hasattr(row, 'duplicate_')) + self.assertTrue(hasattr(row, 'duplicate__')) + + +class RowFactoryTests(BasicSharedKeyspaceUnitTestCaseWFunctionTable): + """ + Test different row_factories and access code + """ + def setUp(self): + super(RowFactoryTests, self).setUp() + self.insert1 = ''' + INSERT INTO {0}.{1} + ( k , v ) + VALUES + ( 1 , 1 ) + '''.format(self.keyspace_name, self.function_table_name) + + self.insert2 = ''' + INSERT INTO {0}.{1} + ( k , v ) + VALUES + ( 2 , 2 ) + '''.format(self.keyspace_name, self.function_table_name) + + self.select = ''' + SELECT * FROM {0}.{1} + '''.format(self.keyspace_name, self.function_table_name) + + def tearDown(self): + self.drop_function_table() + + def test_tuple_factory(self): + session = self.session + session.row_factory = tuple_factory + + session.execute(self.insert1) + session.execute(self.insert2) + + result = session.execute(self.select) + + self.assertIsInstance(result, ResultSet) + self.assertIsInstance(result[0], tuple) + + for row in result: + self.assertEqual(row[0], row[1]) + + self.assertEqual(result[0][0], result[0][1]) + self.assertEqual(result[0][0], 1) + self.assertEqual(result[1][0], result[1][1]) + self.assertEqual(result[1][0], 2) + + def test_named_tuple_factory(self): + session = self.session + session.row_factory = named_tuple_factory + + session.execute(self.insert1) + session.execute(self.insert2) + + result = session.execute(self.select) + + self.assertIsInstance(result, ResultSet) + result = list(result) + + for row in result: + self.assertEqual(row.k, row.v) + + self.assertEqual(result[0].k, result[0].v) + self.assertEqual(result[0].k, 1) + self.assertEqual(result[1].k, result[1].v) + self.assertEqual(result[1].k, 2) + + def test_dict_factory(self): + session = self.session + session.row_factory = dict_factory + + session.execute(self.insert1) + session.execute(self.insert2) + + result = session.execute(self.select) + + self.assertIsInstance(result, ResultSet) + self.assertIsInstance(result[0], dict) + + for row in result: + self.assertEqual(row['k'], row['v']) + + self.assertEqual(result[0]['k'], result[0]['v']) + self.assertEqual(result[0]['k'], 1) + self.assertEqual(result[1]['k'], result[1]['v']) + self.assertEqual(result[1]['k'], 2) + + def test_ordered_dict_factory(self): + session = self.session + session.row_factory = ordered_dict_factory + + session.execute(self.insert1) + session.execute(self.insert2) + + result = session.execute(self.select) + + self.assertIsInstance(result, ResultSet) + self.assertIsInstance(result[0], OrderedDict) + + for row in result: + self.assertEqual(row['k'], row['v']) + + self.assertEqual(result[0]['k'], result[0]['v']) + self.assertEqual(result[0]['k'], 1) + self.assertEqual(result[1]['k'], result[1]['v']) + self.assertEqual(result[1]['k'], 2) + + def test_generator_row_factory(self): + """ + Test that ResultSet.one() works with a row_factory that contains a generator. 
+ + @since 3.16 + @jira_ticket PYTHON-1026 + @expected_result one() returns the first row + + @test_category queries + """ + def generator_row_factory(column_names, rows): + return _gen_row_factory(rows) + + def _gen_row_factory(rows): + for r in rows: + yield r + + session = self.session + session.row_factory = generator_row_factory + + session.execute(self.insert1) + result = session.execute(self.select) + self.assertIsInstance(result, ResultSet) + first_row = result.one() + self.assertEqual(first_row[0], first_row[1]) + + +class NamedTupleFactoryAndNumericColNamesTests(unittest.TestCase): + """ + Test for PYTHON-122: Improve Error Handling/Reporting for named_tuple_factory and Numeric Column Names + """ + @classmethod + def setup_class(cls): + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.session = cls.cluster.connect() + cls._cass_version, cls._cql_version = get_server_versions() + ddl = ''' + CREATE TABLE test1rf.table_num_col ( key blob PRIMARY KEY, "626972746864617465" blob ) + WITH COMPACT STORAGE''' + cls.session.execute(ddl) + + @classmethod + def teardown_class(cls): + cls.session.execute("DROP TABLE test1rf.table_num_col") + cls.cluster.shutdown() + + def test_no_exception_on_select(self): + """ + no exception on SELECT for numeric column name + """ + try: + self.session.execute('SELECT * FROM test1rf.table_num_col') + except ValueError as e: + self.fail("Unexpected ValueError exception: %s" % e.message) + + def test_can_select_using_alias(self): + """ + can SELECT "" AS aliases + """ + if self._cass_version < (2, 0, 0): + raise unittest.SkipTest("Alias in SELECT not supported before 2.0") + + try: + self.session.execute('SELECT key, "626972746864617465" AS my_col from test1rf.table_num_col') + except ValueError as e: + self.fail("Unexpected ValueError exception: %s" % e.message) + + def test_can_select_with_dict_factory(self): + """ + can SELECT numeric column using dict_factory + """ + self.session.row_factory = dict_factory + try: + self.session.execute('SELECT * FROM test1rf.table_num_col') + except ValueError as e: + self.fail("Unexpected ValueError exception: %s" % e.message) diff --git a/tests/integration/standard/test_types.py b/tests/integration/standard/test_types.py new file mode 100644 index 0000000..b49ee06 --- /dev/null +++ b/tests/integration/standard/test_types.py @@ -0,0 +1,948 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
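The row factories exercised by RowFactoryTests above change only how each returned row is represented; a standalone sketch, reusing the test3rf.test table as an assumption::

    from cassandra.cluster import Cluster
    from cassandra.query import tuple_factory, dict_factory

    cluster = Cluster(["127.0.0.1"])
    session = cluster.connect("test3rf")

    # named_tuple_factory is the default: columns are attributes
    row = session.execute("SELECT k, v FROM test LIMIT 1").one()
    print(row.k, row.v)

    # dict_factory: each row is a dict keyed by column name
    session.row_factory = dict_factory
    row = session.execute("SELECT k, v FROM test LIMIT 1").one()
    print(row["k"], row["v"])

    # tuple_factory: plain tuples in selection order
    session.row_factory = tuple_factory
    row = session.execute("SELECT k, v FROM test LIMIT 1").one()
    print(row[0], row[1])

    cluster.shutdown()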
+ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from datetime import datetime +import math +import six + +import cassandra +from cassandra import InvalidRequest +from cassandra.cluster import Cluster +from cassandra.concurrent import execute_concurrent_with_args +from cassandra.cqltypes import Int32Type, EMPTY +from cassandra.query import dict_factory, ordered_dict_factory +from cassandra.util import sortedset, Duration +from tests.unit.cython.utils import cythontest + +from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass, notprotocolv1, \ + BasicSharedKeyspaceUnitTestCase, greaterthancass21, lessthancass30, greaterthanorequalcass3_10 +from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, PRIMITIVE_DATATYPES_KEYS, \ + get_sample, get_all_samples, get_collection_sample + + +def setup_module(): + use_singledc() + update_datatypes() + + +class TypeTests(BasicSharedKeyspaceUnitTestCase): + + @classmethod + def setUpClass(cls): + # cls._cass_version, cls. = get_server_versions() + super(TypeTests, cls).setUpClass() + cls.session.set_keyspace(cls.ks_name) + + def test_can_insert_blob_type_as_string(self): + """ + Tests that byte strings in Python maps to blob type in Cassandra + """ + s = self.session + + s.execute("CREATE TABLE blobstring (a ascii PRIMARY KEY, b blob)") + + params = ['key1', b'blobbyblob'] + query = "INSERT INTO blobstring (a, b) VALUES (%s, %s)" + + # In python2, with Cassandra > 2.0, we don't treat the 'byte str' type as a blob, so we'll encode it + # as a string literal and have the following failure. + if six.PY2 and self.cql_version >= (3, 1, 0): + # Blob values can't be specified using string notation in CQL 3.1.0 and + # above which is used by default in Cassandra 2.0. + if self.cass_version >= (2, 1, 0): + msg = r'.*Invalid STRING constant \(.*?\) for "b" of type blob.*' + else: + msg = r'.*Invalid STRING constant \(.*?\) for b of type blob.*' + self.assertRaisesRegexp(InvalidRequest, msg, s.execute, query, params) + return + + # In python2, with Cassandra < 2.0, we can manually encode the 'byte str' type as hex for insertion in a blob. + if six.PY2: + cass_params = [params[0], params[1].encode('hex')] + s.execute(query, cass_params) + # In python 3, the 'bytes' type is treated as a blob, so we can correctly encode it with hex notation. + else: + s.execute(query, params) + + results = s.execute("SELECT * FROM blobstring")[0] + for expected, actual in zip(params, results): + self.assertEqual(expected, actual) + + def test_can_insert_blob_type_as_bytearray(self): + """ + Tests that blob type in Cassandra maps to bytearray in Python + """ + s = self.session + + s.execute("CREATE TABLE blobbytes (a ascii PRIMARY KEY, b blob)") + + params = ['key1', bytearray(b'blob1')] + s.execute("INSERT INTO blobbytes (a, b) VALUES (%s, %s)", params) + + results = s.execute("SELECT * FROM blobbytes")[0] + for expected, actual in zip(params, results): + self.assertEqual(expected, actual) + + @unittest.skipIf(not hasattr(cassandra, 'deserializers'), "Cython required for to test DesBytesTypeArray deserializer") + def test_des_bytes_type_array(self): + """ + Simple test to ensure the DesBytesTypeByteArray deserializer functionally works + + @since 3.1 + @jira_ticket PYTHON-503 + @expected_result byte array should be deserialized appropriately. 
+ + @test_category queries:custom_payload + """ + original = None + try: + + original = cassandra.deserializers.DesBytesType + cassandra.deserializers.DesBytesType = cassandra.deserializers.DesBytesTypeByteArray + s = self.session + + s.execute("CREATE TABLE blobbytes2 (a ascii PRIMARY KEY, b blob)") + + params = ['key1', bytearray(b'blob1')] + s.execute("INSERT INTO blobbytes2 (a, b) VALUES (%s, %s)", params) + + results = s.execute("SELECT * FROM blobbytes2")[0] + for expected, actual in zip(params, results): + self.assertEqual(expected, actual) + finally: + if original is not None: + cassandra.deserializers.DesBytesType=original + + def test_can_insert_primitive_datatypes(self): + """ + Test insertion of all datatype primitives + """ + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name) + + # create table + alpha_type_list = ["zz int PRIMARY KEY"] + col_names = ["zz"] + start_index = ord('a') + for i, datatype in enumerate(PRIMITIVE_DATATYPES): + alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype)) + col_names.append(chr(start_index + i)) + + s.execute("CREATE TABLE alltypes ({0})".format(', '.join(alpha_type_list))) + + # create the input + params = [0] + for datatype in PRIMITIVE_DATATYPES: + params.append((get_sample(datatype))) + + # insert into table as a simple statement + columns_string = ', '.join(col_names) + placeholders = ', '.join(["%s"] * len(col_names)) + s.execute("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params) + + # verify data + results = s.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string))[0] + for expected, actual in zip(params, results): + self.assertEqual(actual, expected) + + # try the same thing sending one insert at the time + s.execute("TRUNCATE alltypes;") + for i, datatype in enumerate(PRIMITIVE_DATATYPES): + single_col_name = chr(start_index + i) + single_col_names = ["zz", single_col_name] + placeholders = ','.join(["%s"] * len(single_col_names)) + single_columns_string = ', '.join(single_col_names) + for j, data_sample in enumerate(get_all_samples(datatype)): + key = i + 1000 * j + single_params = (key, data_sample) + s.execute("INSERT INTO alltypes ({0}) VALUES ({1})".format(single_columns_string, placeholders), + single_params) + # verify data + result = s.execute("SELECT {0} FROM alltypes WHERE zz=%s".format(single_columns_string), (key,))[0][1] + compare_value = data_sample + if six.PY3: + import ipaddress + if isinstance(data_sample, ipaddress.IPv4Address) or isinstance(data_sample, ipaddress.IPv6Address): + compare_value = str(data_sample) + self.assertEqual(result, compare_value) + + # try the same thing with a prepared statement + placeholders = ','.join(["?"] * len(col_names)) + s.execute("TRUNCATE alltypes;") + insert = s.prepare("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders)) + s.execute(insert.bind(params)) + + # verify data + results = s.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string))[0] + for expected, actual in zip(params, results): + self.assertEqual(actual, expected) + + # verify data with prepared statement query + select = s.prepare("SELECT {0} FROM alltypes WHERE zz=?".format(columns_string)) + results = s.execute(select.bind([0]))[0] + for expected, actual in zip(params, results): + self.assertEqual(actual, expected) + + # verify data with with prepared statement, use dictionary with no explicit columns + s.row_factory = ordered_dict_factory + select = s.prepare("SELECT * 
FROM alltypes") + results = s.execute(select)[0] + + for expected, actual in zip(params, results.values()): + self.assertEqual(actual, expected) + + c.shutdown() + + def test_can_insert_collection_datatypes(self): + """ + Test insertion of all collection types + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name) + # use tuple encoding, to convert native python tuple into raw CQL + s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple + + # create table + alpha_type_list = ["zz int PRIMARY KEY"] + col_names = ["zz"] + start_index = ord('a') + for i, collection_type in enumerate(COLLECTION_TYPES): + for j, datatype in enumerate(PRIMITIVE_DATATYPES_KEYS): + if collection_type == "map": + type_string = "{0}_{1} {2}<{3}, {3}>".format(chr(start_index + i), chr(start_index + j), + collection_type, datatype) + elif collection_type == "tuple": + type_string = "{0}_{1} frozen<{2}<{3}>>".format(chr(start_index + i), chr(start_index + j), + collection_type, datatype) + else: + type_string = "{0}_{1} {2}<{3}>".format(chr(start_index + i), chr(start_index + j), + collection_type, datatype) + alpha_type_list.append(type_string) + col_names.append("{0}_{1}".format(chr(start_index + i), chr(start_index + j))) + + s.execute("CREATE TABLE allcoltypes ({0})".format(', '.join(alpha_type_list))) + columns_string = ', '.join(col_names) + + # create the input for simple statement + params = [0] + for collection_type in COLLECTION_TYPES: + for datatype in PRIMITIVE_DATATYPES_KEYS: + params.append((get_collection_sample(collection_type, datatype))) + + # insert into table as a simple statement + placeholders = ', '.join(["%s"] * len(col_names)) + s.execute("INSERT INTO allcoltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params) + + # verify data + results = s.execute("SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string))[0] + for expected, actual in zip(params, results): + self.assertEqual(actual, expected) + + # create the input for prepared statement + params = [0] + for collection_type in COLLECTION_TYPES: + for datatype in PRIMITIVE_DATATYPES_KEYS: + params.append((get_collection_sample(collection_type, datatype))) + + # try the same thing with a prepared statement + placeholders = ','.join(["?"] * len(col_names)) + insert = s.prepare("INSERT INTO allcoltypes ({0}) VALUES ({1})".format(columns_string, placeholders)) + s.execute(insert.bind(params)) + + # verify data + results = s.execute("SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string))[0] + for expected, actual in zip(params, results): + self.assertEqual(actual, expected) + + # verify data with prepared statement query + select = s.prepare("SELECT {0} FROM allcoltypes WHERE zz=?".format(columns_string)) + results = s.execute(select.bind([0]))[0] + for expected, actual in zip(params, results): + self.assertEqual(actual, expected) + + # verify data with with prepared statement, use dictionary with no explicit columns + s.row_factory = ordered_dict_factory + select = s.prepare("SELECT * FROM allcoltypes") + results = s.execute(select)[0] + + for expected, actual in zip(params, results.values()): + self.assertEqual(actual, expected) + + c.shutdown() + + def test_can_insert_empty_strings_and_nulls(self): + """ + Test insertion of empty strings and null values + """ + s = self.session + + # create table + alpha_type_list = ["zz int PRIMARY KEY"] + col_names = [] + string_types = set(('ascii', 'text', 'varchar')) + string_columns = set(('')) + # this is just a list of types to try with 
empty strings + non_string_types = PRIMITIVE_DATATYPES - string_types - set(('blob', 'date', 'inet', 'time', 'timestamp')) + non_string_columns = set() + start_index = ord('a') + for i, datatype in enumerate(PRIMITIVE_DATATYPES): + col_name = chr(start_index + i) + alpha_type_list.append("{0} {1}".format(col_name, datatype)) + col_names.append(col_name) + if datatype in non_string_types: + non_string_columns.add(col_name) + if datatype in string_types: + string_columns.add(col_name) + + execute_until_pass(s, "CREATE TABLE all_empty ({0})".format(', '.join(alpha_type_list))) + + # verify all types initially null with simple statement + columns_string = ','.join(col_names) + s.execute("INSERT INTO all_empty (zz) VALUES (2)") + results = s.execute("SELECT {0} FROM all_empty WHERE zz=2".format(columns_string))[0] + self.assertTrue(all(x is None for x in results)) + + # verify all types initially null with prepared statement + select = s.prepare("SELECT {0} FROM all_empty WHERE zz=?".format(columns_string)) + results = s.execute(select.bind([2]))[0] + self.assertTrue(all(x is None for x in results)) + + # insert empty strings for string-like fields + expected_values = dict((col, '') for col in string_columns) + columns_string = ','.join(string_columns) + placeholders = ','.join(["%s"] * len(string_columns)) + s.execute("INSERT INTO all_empty (zz, {0}) VALUES (3, {1})".format(columns_string, placeholders), expected_values.values()) + + # verify string types empty with simple statement + results = s.execute("SELECT {0} FROM all_empty WHERE zz=3".format(columns_string))[0] + for expected, actual in zip(expected_values.values(), results): + self.assertEqual(actual, expected) + + # verify string types empty with prepared statement + results = s.execute(s.prepare("SELECT {0} FROM all_empty WHERE zz=?".format(columns_string)), [3])[0] + for expected, actual in zip(expected_values.values(), results): + self.assertEqual(actual, expected) + + # non-string types shouldn't accept empty strings + for col in non_string_columns: + query = "INSERT INTO all_empty (zz, {0}) VALUES (4, %s)".format(col) + with self.assertRaises(InvalidRequest): + s.execute(query, ['']) + + insert = s.prepare("INSERT INTO all_empty (zz, {0}) VALUES (4, ?)".format(col)) + with self.assertRaises(TypeError): + s.execute(insert, ['']) + + # verify that Nones can be inserted and overwrites existing data + # create the input + params = [] + for datatype in PRIMITIVE_DATATYPES: + params.append((get_sample(datatype))) + + # insert the data + columns_string = ','.join(col_names) + placeholders = ','.join(["%s"] * len(col_names)) + simple_insert = "INSERT INTO all_empty (zz, {0}) VALUES (5, {1})".format(columns_string, placeholders) + s.execute(simple_insert, params) + + # then insert None, which should null them out + null_values = [None] * len(col_names) + s.execute(simple_insert, null_values) + + # check via simple statement + query = "SELECT {0} FROM all_empty WHERE zz=5".format(columns_string) + results = s.execute(query)[0] + for col in results: + self.assertEqual(None, col) + + # check via prepared statement + select = s.prepare("SELECT {0} FROM all_empty WHERE zz=?".format(columns_string)) + results = s.execute(select.bind([5]))[0] + for col in results: + self.assertEqual(None, col) + + # do the same thing again, but use a prepared statement to insert the nulls + s.execute(simple_insert, params) + + placeholders = ','.join(["?"] * len(col_names)) + insert = s.prepare("INSERT INTO all_empty (zz, {0}) VALUES (5, 
{1})".format(columns_string, placeholders)) + s.execute(insert, null_values) + + results = s.execute(query)[0] + for col in results: + self.assertEqual(None, col) + + results = s.execute(select.bind([5]))[0] + for col in results: + self.assertEqual(None, col) + + def test_can_insert_empty_values_for_int32(self): + """ + Ensure Int32Type supports empty values + """ + s = self.session + + execute_until_pass(s, "CREATE TABLE empty_values (a text PRIMARY KEY, b int)") + execute_until_pass(s, "INSERT INTO empty_values (a, b) VALUES ('a', blobAsInt(0x))") + try: + Int32Type.support_empty_values = True + results = execute_until_pass(s, "SELECT b FROM empty_values WHERE a='a'")[0] + self.assertIs(EMPTY, results.b) + finally: + Int32Type.support_empty_values = False + + def test_timezone_aware_datetimes_are_timestamps(self): + """ + Ensure timezone-aware datetimes are converted to timestamps correctly + """ + + try: + import pytz + except ImportError as exc: + raise unittest.SkipTest('pytz is not available: %r' % (exc,)) + + dt = datetime(1997, 8, 29, 11, 14) + eastern_tz = pytz.timezone('US/Eastern') + eastern_tz.localize(dt) + + s = self.session + + s.execute("CREATE TABLE tz_aware (a ascii PRIMARY KEY, b timestamp)") + + # test non-prepared statement + s.execute("INSERT INTO tz_aware (a, b) VALUES ('key1', %s)", [dt]) + result = s.execute("SELECT b FROM tz_aware WHERE a='key1'")[0].b + self.assertEqual(dt.utctimetuple(), result.utctimetuple()) + + # test prepared statement + insert = s.prepare("INSERT INTO tz_aware (a, b) VALUES ('key2', ?)") + s.execute(insert.bind([dt])) + result = s.execute("SELECT b FROM tz_aware WHERE a='key2'")[0].b + self.assertEqual(dt.utctimetuple(), result.utctimetuple()) + + def test_can_insert_tuples(self): + """ + Basic test of tuple functionality + """ + + if self.cass_version < (2, 1, 0): + raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name) + + # use this encoder in order to insert tuples + s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple + + s.execute("CREATE TABLE tuple_type (a int PRIMARY KEY, b frozen>)") + + # test non-prepared statement + complete = ('foo', 123, True) + s.execute("INSERT INTO tuple_type (a, b) VALUES (0, %s)", parameters=(complete,)) + result = s.execute("SELECT b FROM tuple_type WHERE a=0")[0] + self.assertEqual(complete, result.b) + + partial = ('bar', 456) + partial_result = partial + (None,) + s.execute("INSERT INTO tuple_type (a, b) VALUES (1, %s)", parameters=(partial,)) + result = s.execute("SELECT b FROM tuple_type WHERE a=1")[0] + self.assertEqual(partial_result, result.b) + + # test single value tuples + subpartial = ('zoo',) + subpartial_result = subpartial + (None, None) + s.execute("INSERT INTO tuple_type (a, b) VALUES (2, %s)", parameters=(subpartial,)) + result = s.execute("SELECT b FROM tuple_type WHERE a=2")[0] + self.assertEqual(subpartial_result, result.b) + + # test prepared statement + prepared = s.prepare("INSERT INTO tuple_type (a, b) VALUES (?, ?)") + s.execute(prepared, parameters=(3, complete)) + s.execute(prepared, parameters=(4, partial)) + s.execute(prepared, parameters=(5, subpartial)) + + # extra items in the tuple should result in an error + self.assertRaises(ValueError, s.execute, prepared, parameters=(0, (1, 2, 3, 4, 5, 6))) + + prepared = s.prepare("SELECT b FROM tuple_type WHERE a=?") + self.assertEqual(complete, s.execute(prepared, (3,))[0].b) + self.assertEqual(partial_result, 
s.execute(prepared, (4,))[0].b) + self.assertEqual(subpartial_result, s.execute(prepared, (5,))[0].b) + + c.shutdown() + + def test_can_insert_tuples_with_varying_lengths(self): + """ + Test tuple types of lengths of 1, 2, 3, and 384 to ensure edge cases work + as expected. + """ + + if self.cass_version < (2, 1, 0): + raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name) + + # set the row_factory to dict_factory for programmatic access + # set the encoder for tuples for the ability to write tuples + s.row_factory = dict_factory + s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple + + # programmatically create the table with tuples of said sizes + lengths = (1, 2, 3, 384) + value_schema = [] + for i in lengths: + value_schema += [' v_%s frozen<tuple<%s>>' % (i, ', '.join(['int'] * i))] + s.execute("CREATE TABLE tuple_lengths (k int PRIMARY KEY, %s)" % (', '.join(value_schema),)) + + # insert tuples into same key using different columns + # and verify the results + for i in lengths: + # ensure tuples of larger sizes throw an error + created_tuple = tuple(range(0, i + 1)) + self.assertRaises(InvalidRequest, s.execute, "INSERT INTO tuple_lengths (k, v_%s) VALUES (0, %s)", (i, created_tuple)) + + # ensure tuples of proper sizes are written and read correctly + created_tuple = tuple(range(0, i)) + + s.execute("INSERT INTO tuple_lengths (k, v_%s) VALUES (0, %s)", (i, created_tuple)) + + result = s.execute("SELECT v_%s FROM tuple_lengths WHERE k=0", (i,))[0] + self.assertEqual(tuple(created_tuple), result['v_%s' % i]) + c.shutdown() + + def test_can_insert_tuples_all_primitive_datatypes(self): + """ + Ensure tuple subtypes are appropriately handled. + """ + + if self.cass_version < (2, 1, 0): + raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name) + s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple + + s.execute("CREATE TABLE tuple_primitive (" + "k int PRIMARY KEY, " + "v frozen<tuple<%s>>)" % ','.join(PRIMITIVE_DATATYPES)) + + values = [] + type_count = len(PRIMITIVE_DATATYPES) + for i, data_type in enumerate(PRIMITIVE_DATATYPES): + # create tuples to be written and ensure they match with the expected response + # responses have trailing None values for every element that has not been written + values.append(get_sample(data_type)) + expected = tuple(values + [None] * (type_count - len(values))) + s.execute("INSERT INTO tuple_primitive (k, v) VALUES (%s, %s)", (i, tuple(values))) + result = s.execute("SELECT v FROM tuple_primitive WHERE k=%s", (i,))[0] + self.assertEqual(result.v, expected) + c.shutdown() + + def test_can_insert_tuples_all_collection_datatypes(self): + """ + Ensure tuple subtypes are appropriately handled for maps, sets, and lists.
+ """ + + if self.cass_version < (2, 1, 0): + raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name) + + # set the row_factory to dict_factory for programmatic access + # set the encoder for tuples for the ability to write tuples + s.row_factory = dict_factory + s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple + + values = [] + + # create list values + for datatype in PRIMITIVE_DATATYPES_KEYS: + values.append('v_{0} frozen>>'.format(len(values), datatype)) + + # create set values + for datatype in PRIMITIVE_DATATYPES_KEYS: + values.append('v_{0} frozen>>'.format(len(values), datatype)) + + # create map values + for datatype in PRIMITIVE_DATATYPES_KEYS: + datatype_1 = datatype_2 = datatype + if datatype == 'blob': + # unhashable type: 'bytearray' + datatype_1 = 'ascii' + values.append('v_{0} frozen>>'.format(len(values), datatype_1, datatype_2)) + + # make sure we're testing all non primitive data types in the future + if set(COLLECTION_TYPES) != set(['tuple', 'list', 'map', 'set']): + raise NotImplemented('Missing datatype not implemented: {}'.format( + set(COLLECTION_TYPES) - set(['tuple', 'list', 'map', 'set']) + )) + + # create table + s.execute("CREATE TABLE tuple_non_primative (" + "k int PRIMARY KEY, " + "%s)" % ', '.join(values)) + + i = 0 + # test tuple> + for datatype in PRIMITIVE_DATATYPES_KEYS: + created_tuple = tuple([[get_sample(datatype)]]) + s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) + + result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,))[0] + self.assertEqual(created_tuple, result['v_%s' % i]) + i += 1 + + # test tuple> + for datatype in PRIMITIVE_DATATYPES_KEYS: + created_tuple = tuple([sortedset([get_sample(datatype)])]) + s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) + + result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,))[0] + self.assertEqual(created_tuple, result['v_%s' % i]) + i += 1 + + # test tuple> + for datatype in PRIMITIVE_DATATYPES_KEYS: + if datatype == 'blob': + # unhashable type: 'bytearray' + created_tuple = tuple([{get_sample('ascii'): get_sample(datatype)}]) + else: + created_tuple = tuple([{get_sample(datatype): get_sample(datatype)}]) + + s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) + + result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,))[0] + self.assertEqual(created_tuple, result['v_%s' % i]) + i += 1 + c.shutdown() + + def nested_tuples_schema_helper(self, depth): + """ + Helper method for creating nested tuple schema + """ + + if depth == 0: + return 'int' + else: + return 'tuple<%s>' % self.nested_tuples_schema_helper(depth - 1) + + def nested_tuples_creator_helper(self, depth): + """ + Helper method for creating nested tuples + """ + + if depth == 0: + return 303 + else: + return (self.nested_tuples_creator_helper(depth - 1), ) + + def test_can_insert_nested_tuples(self): + """ + Ensure nested are appropriately handled. 
+ """ + + if self.cass_version < (2, 1, 0): + raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name) + + # set the row_factory to dict_factory for programmatic access + # set the encoder for tuples for the ability to write tuples + s.row_factory = dict_factory + s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple + + # create a table with multiple sizes of nested tuples + s.execute("CREATE TABLE nested_tuples (" + "k int PRIMARY KEY, " + "v_1 frozen<%s>," + "v_2 frozen<%s>," + "v_3 frozen<%s>," + "v_32 frozen<%s>" + ")" % (self.nested_tuples_schema_helper(1), + self.nested_tuples_schema_helper(2), + self.nested_tuples_schema_helper(3), + self.nested_tuples_schema_helper(32))) + + for i in (1, 2, 3, 32): + # create tuple + created_tuple = self.nested_tuples_creator_helper(i) + + # write tuple + s.execute("INSERT INTO nested_tuples (k, v_%s) VALUES (%s, %s)", (i, i, created_tuple)) + + # verify tuple was written and read correctly + result = s.execute("SELECT v_%s FROM nested_tuples WHERE k=%s", (i, i))[0] + self.assertEqual(created_tuple, result['v_%s' % i]) + c.shutdown() + + def test_can_insert_tuples_with_nulls(self): + """ + Test tuples with null and empty string fields. + """ + + if self.cass_version < (2, 1, 0): + raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") + + s = self.session + + s.execute("CREATE TABLE tuples_nulls (k int PRIMARY KEY, t frozen>)") + + insert = s.prepare("INSERT INTO tuples_nulls (k, t) VALUES (0, ?)") + s.execute(insert, [(None, None, None, None)]) + + result = s.execute("SELECT * FROM tuples_nulls WHERE k=0") + self.assertEqual((None, None, None, None), result[0].t) + + read = s.prepare("SELECT * FROM tuples_nulls WHERE k=0") + self.assertEqual((None, None, None, None), s.execute(read)[0].t) + + # also test empty strings where compatible + s.execute(insert, [('', None, None, b'')]) + result = s.execute("SELECT * FROM tuples_nulls WHERE k=0") + self.assertEqual(('', None, None, b''), result[0].t) + self.assertEqual(('', None, None, b''), s.execute(read)[0].t) + + def test_can_insert_unicode_query_string(self): + """ + Test to ensure unicode strings can be used in a query + """ + s = self.session + s.execute(u"SELECT * FROM system.local WHERE key = 'ef\u2052ef'") + s.execute(u"SELECT * FROM system.local WHERE key = %s", (u"fe\u2051fe",)) + + def test_can_read_composite_type(self): + """ + Test to ensure that CompositeTypes can be used in a query + """ + s = self.session + + s.execute(""" + CREATE TABLE composites ( + a int PRIMARY KEY, + b 'org.apache.cassandra.db.marshal.CompositeType(AsciiType, Int32Type)' + )""") + + # CompositeType string literals are split on ':' chars + s.execute("INSERT INTO composites (a, b) VALUES (0, 'abc:123')") + result = s.execute("SELECT * FROM composites WHERE a = 0")[0] + self.assertEqual(0, result.a) + self.assertEqual(('abc', 123), result.b) + + # CompositeType values can omit elements at the end + s.execute("INSERT INTO composites (a, b) VALUES (0, 'abc')") + result = s.execute("SELECT * FROM composites WHERE a = 0")[0] + self.assertEqual(0, result.a) + self.assertEqual(('abc',), result.b) + + @notprotocolv1 + def test_special_float_cql_encoding(self): + """ + Test to insure that Infinity -Infinity and NaN are supported by the python driver. + + @since 3.0.0 + @jira_ticket PYTHON-282 + @expected_result nan, inf and -inf can be inserted and selected correctly. 
+ + @test_category data_types + """ + s = self.session + + s.execute(""" + CREATE TABLE float_cql_encoding ( + f float PRIMARY KEY, + d double + )""") + items = (float('nan'), float('inf'), float('-inf')) + + def verify_insert_select(ins_statement, sel_statement): + execute_concurrent_with_args(s, ins_statement, ((f, f) for f in items)) + for f in items: + row = s.execute(sel_statement, (f,))[0] + if math.isnan(f): + self.assertTrue(math.isnan(row.f)) + self.assertTrue(math.isnan(row.d)) + else: + self.assertEqual(row.f, f) + self.assertEqual(row.d, f) + + # cql encoding + verify_insert_select('INSERT INTO float_cql_encoding (f, d) VALUES (%s, %s)', + 'SELECT * FROM float_cql_encoding WHERE f=%s') + + s.execute("TRUNCATE float_cql_encoding") + + # prepared binding + verify_insert_select(s.prepare('INSERT INTO float_cql_encoding (f, d) VALUES (?, ?)'), + s.prepare('SELECT * FROM float_cql_encoding WHERE f=?')) + + @cythontest + def test_cython_decimal(self): + """ + Test to validate that decimal deserialization works correctly with our cython extensions + + @since 3.0.0 + @jira_ticket PYTHON-212 + @expected_result no exceptions are thrown, decimal is decoded correctly + + @test_category data_types serialization + """ + + self.session.execute("CREATE TABLE {0} (dc decimal PRIMARY KEY)".format(self.function_table_name)) + try: + self.session.execute("INSERT INTO {0} (dc) VALUES (-1.08430792318105707)".format(self.function_table_name)) + results = self.session.execute("SELECT * FROM {0}".format(self.function_table_name)) + self.assertTrue(str(results[0].dc) == '-1.08430792318105707') + finally: + self.session.execute("DROP TABLE {0}".format(self.function_table_name)) + + @greaterthanorequalcass3_10 + def test_smoke_duration_values(self): + """ + Test to write several Duration values to the database and verify + they can be read correctly. It also verifies that an exception is raised + if the value is too big. + + @since 3.10 + @jira_ticket PYTHON-747 + @expected_result the read value in C* matches the written one + + @test_category data_types serialization + """ + self.session.execute(""" + CREATE TABLE duration_smoke (k int primary key, v duration) + """) + self.addCleanup(self.session.execute, "DROP TABLE duration_smoke") + + prepared = self.session.prepare(""" + INSERT INTO duration_smoke (k, v) + VALUES (?, ?)
+ """) + + nanosecond_smoke_values = [0, -1, 1, 100, 1000, 1000000, 1000000000, + 10000000000000,-9223372036854775807, 9223372036854775807, + int("7FFFFFFFFFFFFFFF", 16), int("-7FFFFFFFFFFFFFFF", 16)] + month_day_smoke_values = [0, -1, 1, 100, 1000, 1000000, 1000000000, + int("7FFFFFFF", 16), int("-7FFFFFFF", 16)] + + for nanosecond_value in nanosecond_smoke_values: + for month_day_value in month_day_smoke_values: + + # Must have the same sign + if (month_day_value <= 0) != (nanosecond_value <= 0): + continue + + self.session.execute(prepared, (1, Duration(month_day_value, month_day_value, nanosecond_value))) + results = self.session.execute("SELECT * FROM duration_smoke") + + v = results[0][1] + self.assertEqual(Duration(month_day_value, month_day_value, nanosecond_value), v, + "Error encoding value {0},{0},{1}".format(month_day_value, nanosecond_value)) + + self.assertRaises(ValueError, self.session.execute, prepared, + (1, Duration(0, 0, int("8FFFFFFFFFFFFFF0", 16)))) + self.assertRaises(ValueError, self.session.execute, prepared, + (1, Duration(0, int("8FFFFFFFFFFFFFF0", 16), 0))) + self.assertRaises(ValueError, self.session.execute, prepared, + (1, Duration(int("8FFFFFFFFFFFFFF0", 16), 0, 0))) + +class TypeTestsProtocol(BasicSharedKeyspaceUnitTestCase): + + @greaterthancass21 + @lessthancass30 + def test_nested_types_with_protocol_version(self): + """ + Test to validate that nested type serialization works on various protocol versions. Provided + the version of cassandra is greater the 2.1.3 we would expect to nested to types to work at all protocol versions. + + @since 3.0.0 + @jira_ticket PYTHON-215 + @expected_result no exceptions are thrown + + @test_category data_types serialization + """ + ddl = '''CREATE TABLE {0}.t ( + k int PRIMARY KEY, + v list>>)'''.format(self.keyspace_name) + + self.session.execute(ddl) + ddl = '''CREATE TABLE {0}.u ( + k int PRIMARY KEY, + v set>>)'''.format(self.keyspace_name) + self.session.execute(ddl) + ddl = '''CREATE TABLE {0}.v ( + k int PRIMARY KEY, + v map>, frozen>>, + v1 frozen>)'''.format(self.keyspace_name) + self.session.execute(ddl) + + self.session.execute("CREATE TYPE {0}.typ (v0 frozen>>>, v1 frozen>)".format(self.keyspace_name)) + + ddl = '''CREATE TABLE {0}.w ( + k int PRIMARY KEY, + v frozen)'''.format(self.keyspace_name) + + self.session.execute(ddl) + + for pvi in range(1, 5): + self.run_inserts_at_version(pvi) + for pvr in range(1, 5): + self.read_inserts_at_level(pvr) + + def read_inserts_at_level(self, proto_ver): + session = Cluster(protocol_version=proto_ver).connect(self.keyspace_name) + try: + results = session.execute('select * from t')[0] + self.assertEqual("[SortedSet([1, 2]), SortedSet([3, 5])]", str(results.v)) + + results = session.execute('select * from u')[0] + self.assertEqual("SortedSet([[1, 2], [3, 5]])", str(results.v)) + + results = session.execute('select * from v')[0] + self.assertEqual("{SortedSet([1, 2]): [1, 2, 3], SortedSet([3, 5]): [4, 5, 6]}", str(results.v)) + + results = session.execute('select * from w')[0] + self.assertEqual("typ(v0=OrderedMapSerializedKey([(1, [1, 2, 3]), (2, [4, 5, 6])]), v1=[7, 8, 9])", str(results.v)) + + finally: + session.cluster.shutdown() + + def run_inserts_at_version(self, proto_ver): + session = Cluster(protocol_version=proto_ver).connect(self.keyspace_name) + try: + p = session.prepare('insert into t (k, v) values (?, ?)') + session.execute(p, (0, [{1, 2}, {3, 5}])) + + p = session.prepare('insert into u (k, v) values (?, ?)') + session.execute(p, (0, {(1, 2), (3, 5)})) + 
+ p = session.prepare('insert into v (k, v, v1) values (?, ?, ?)') + session.execute(p, (0, {(1, 2): [1, 2, 3], (3, 5): [4, 5, 6]}, (123, 'four'))) + + p = session.prepare('insert into w (k, v) values (?, ?)') + session.execute(p, (0, ({1: [1, 2, 3], 2: [4, 5, 6]}, [7, 8, 9]))) + + finally: + session.cluster.shutdown() + + + diff --git a/tests/integration/standard/test_udts.py b/tests/integration/standard/test_udts.py new file mode 100644 index 0000000..514c562 --- /dev/null +++ b/tests/integration/standard/test_udts.py @@ -0,0 +1,760 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from collections import namedtuple +from functools import partial +import six + +from cassandra import InvalidRequest +from cassandra.cluster import Cluster, UserTypeDoesNotExist +from cassandra.query import dict_factory +from cassandra.util import OrderedMap + +from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass, BasicSegregatedKeyspaceUnitTestCase, \ + greaterthancass20, greaterthanorequalcass36, lessthancass30 +from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, PRIMITIVE_DATATYPES_KEYS, \ + COLLECTION_TYPES, get_sample, get_collection_sample + +nested_collection_udt = namedtuple('nested_collection_udt', ['m', 't', 'l', 's']) +nested_collection_udt_nested = namedtuple('nested_collection_udt_nested', ['m', 't', 'l', 's', 'u']) + + +def setup_module(): + use_singledc() + update_datatypes() + + +@greaterthancass20 +class UDTTests(BasicSegregatedKeyspaceUnitTestCase): + + @property + def table_name(self): + return self._testMethodName.lower() + + def setUp(self): + super(UDTTests, self).setUp() + self.session.set_keyspace(self.keyspace_name) + + @greaterthanorequalcass36 + def test_non_frozen_udts(self): + """ + Test to ensure that non frozen udt's work with C* >3.6. 
+ + @since 3.7.0 + @jira_ticket PYTHON-498 + @expected_result Non-frozen UDTs are supported + + @test_category data_types, udt + """ + self.session.execute("USE {0}".format(self.keyspace_name)) + self.session.execute("CREATE TYPE user (state text, has_corn boolean)") + self.session.execute("CREATE TABLE {0} (a int PRIMARY KEY, b user)".format(self.function_table_name)) + User = namedtuple('user', ('state', 'has_corn')) + self.cluster.register_user_type(self.keyspace_name, "user", User) + self.session.execute("INSERT INTO {0} (a, b) VALUES (%s, %s)".format(self.function_table_name), (0, User("Nebraska", True))) + self.session.execute("UPDATE {0} SET b.has_corn = False where a = 0".format(self.function_table_name)) + result = self.session.execute("SELECT * FROM {0}".format(self.function_table_name)) + self.assertFalse(result[0].b.has_corn) + table_sql = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].as_cql_query() + self.assertNotIn("<frozen>", table_sql) + + def test_can_insert_unprepared_registered_udts(self): + """ + Test the insertion of unprepared, registered UDTs + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name, wait_for_all_pools=True) + + s.execute("CREATE TYPE user (age int, name text)") + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)") + + User = namedtuple('user', ('age', 'name')) + c.register_user_type(self.keyspace_name, "user", User) + + s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User(42, 'bob'))) + result = s.execute("SELECT b FROM mytable WHERE a=0") + row = result[0] + self.assertEqual(42, row.b.age) + self.assertEqual('bob', row.b.name) + self.assertTrue(type(row.b) is User) + + # use the same UDT name in a different keyspace + s.execute(""" + CREATE KEYSPACE udt_test_unprepared_registered2 + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + """) + s.set_keyspace("udt_test_unprepared_registered2") + s.execute("CREATE TYPE user (state text, is_cool boolean)") + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)") + + User = namedtuple('user', ('state', 'is_cool')) + c.register_user_type("udt_test_unprepared_registered2", "user", User) + + s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User('Texas', True))) + result = s.execute("SELECT b FROM mytable WHERE a=0") + row = result[0] + self.assertEqual('Texas', row.b.state) + self.assertEqual(True, row.b.is_cool) + self.assertTrue(type(row.b) is User) + + s.execute("DROP KEYSPACE udt_test_unprepared_registered2") + + c.shutdown() + + def test_can_register_udt_before_connecting(self): + """ + Test the registration of UDTs before session creation + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(wait_for_all_pools=True) + + s.execute(""" + CREATE KEYSPACE udt_test_register_before_connecting + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + """) + s.set_keyspace("udt_test_register_before_connecting") + s.execute("CREATE TYPE user (age int, name text)") + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)") + + s.execute(""" + CREATE KEYSPACE udt_test_register_before_connecting2 + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + """) + s.set_keyspace("udt_test_register_before_connecting2") + s.execute("CREATE TYPE user (state text, is_cool boolean)") + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)") + + # now that types are defined, shutdown and re-create Cluster +
c.shutdown() + c = Cluster(protocol_version=PROTOCOL_VERSION) + + User1 = namedtuple('user', ('age', 'name')) + User2 = namedtuple('user', ('state', 'is_cool')) + + c.register_user_type("udt_test_register_before_connecting", "user", User1) + c.register_user_type("udt_test_register_before_connecting2", "user", User2) + + s = c.connect(wait_for_all_pools=True) + + s.set_keyspace("udt_test_register_before_connecting") + s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User1(42, 'bob'))) + result = s.execute("SELECT b FROM mytable WHERE a=0") + row = result[0] + self.assertEqual(42, row.b.age) + self.assertEqual('bob', row.b.name) + self.assertTrue(type(row.b) is User1) + + # use the same UDT name in a different keyspace + s.set_keyspace("udt_test_register_before_connecting2") + s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User2('Texas', True))) + result = s.execute("SELECT b FROM mytable WHERE a=0") + row = result[0] + self.assertEqual('Texas', row.b.state) + self.assertEqual(True, row.b.is_cool) + self.assertTrue(type(row.b) is User2) + + s.execute("DROP KEYSPACE udt_test_register_before_connecting") + s.execute("DROP KEYSPACE udt_test_register_before_connecting2") + + c.shutdown() + + def test_can_insert_prepared_unregistered_udts(self): + """ + Test the insertion of prepared, unregistered UDTs + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name, wait_for_all_pools=True) + + s.execute("CREATE TYPE user (age int, name text)") + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)") + + User = namedtuple('user', ('age', 'name')) + insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)") + s.execute(insert, (0, User(42, 'bob'))) + + select = s.prepare("SELECT b FROM mytable WHERE a=?") + result = s.execute(select, (0,)) + row = result[0] + self.assertEqual(42, row.b.age) + self.assertEqual('bob', row.b.name) + + # use the same UDT name in a different keyspace + s.execute(""" + CREATE KEYSPACE udt_test_prepared_unregistered2 + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + """) + s.set_keyspace("udt_test_prepared_unregistered2") + s.execute("CREATE TYPE user (state text, is_cool boolean)") + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)") + + User = namedtuple('user', ('state', 'is_cool')) + insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)") + s.execute(insert, (0, User('Texas', True))) + + select = s.prepare("SELECT b FROM mytable WHERE a=?") + result = s.execute(select, (0,)) + row = result[0] + self.assertEqual('Texas', row.b.state) + self.assertEqual(True, row.b.is_cool) + + s.execute("DROP KEYSPACE udt_test_prepared_unregistered2") + + c.shutdown() + + def test_can_insert_prepared_registered_udts(self): + """ + Test the insertion of prepared, registered UDTs + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name, wait_for_all_pools=True) + + s.execute("CREATE TYPE user (age int, name text)") + User = namedtuple('user', ('age', 'name')) + c.register_user_type(self.keyspace_name, "user", User) + + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)") + + insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)") + s.execute(insert, (0, User(42, 'bob'))) + + select = s.prepare("SELECT b FROM mytable WHERE a=?") + result = s.execute(select, (0,)) + row = result[0] + self.assertEqual(42, row.b.age) + self.assertEqual('bob', row.b.name) + self.assertTrue(type(row.b) is User) + + # use the same UDT name in a
different keyspace + s.execute(""" + CREATE KEYSPACE udt_test_prepared_registered2 + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + """) + s.set_keyspace("udt_test_prepared_registered2") + s.execute("CREATE TYPE user (state text, is_cool boolean)") + User = namedtuple('user', ('state', 'is_cool')) + c.register_user_type("udt_test_prepared_registered2", "user", User) + + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)") + + insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)") + s.execute(insert, (0, User('Texas', True))) + + select = s.prepare("SELECT b FROM mytable WHERE a=?") + result = s.execute(select, (0,)) + row = result[0] + self.assertEqual('Texas', row.b.state) + self.assertEqual(True, row.b.is_cool) + self.assertTrue(type(row.b) is User) + + s.execute("DROP KEYSPACE udt_test_prepared_registered2") + + c.shutdown() + + def test_can_insert_udts_with_nulls(self): + """ + Test the insertion of UDTs with null and empty string fields + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name, wait_for_all_pools=True) + + s.execute("CREATE TYPE user (a text, b int, c uuid, d blob)") + User = namedtuple('user', ('a', 'b', 'c', 'd')) + c.register_user_type(self.keyspace_name, "user", User) + + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)") + + insert = s.prepare("INSERT INTO mytable (a, b) VALUES (0, ?)") + s.execute(insert, [User(None, None, None, None)]) + + results = s.execute("SELECT b FROM mytable WHERE a=0") + self.assertEqual((None, None, None, None), results[0].b) + + select = s.prepare("SELECT b FROM mytable WHERE a=0") + self.assertEqual((None, None, None, None), s.execute(select)[0].b) + + # also test empty strings + s.execute(insert, [User('', None, None, six.binary_type())]) + results = s.execute("SELECT b FROM mytable WHERE a=0") + self.assertEqual(('', None, None, six.binary_type()), results[0].b) + + c.shutdown() + + def test_can_insert_udts_with_varying_lengths(self): + """ + Test for ensuring extra-lengthy udts are properly inserted + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name, wait_for_all_pools=True) + + MAX_TEST_LENGTH = 254 + + # create the seed udt, increase timeout to avoid the query failure on slow systems + s.execute("CREATE TYPE lengthy_udt ({0})" + .format(', '.join(['v_{0} int'.format(i) + for i in range(MAX_TEST_LENGTH)]))) + + # create a table with multiple sizes of nested udts + # no need for all nested types, only a spot checked few and the largest one + s.execute("CREATE TABLE mytable (" + "k int PRIMARY KEY, " + "v frozen<lengthy_udt>)") + + # create and register the seed udt type + udt = namedtuple('lengthy_udt', tuple(['v_{0}'.format(i) for i in range(MAX_TEST_LENGTH)])) + c.register_user_type(self.keyspace_name, "lengthy_udt", udt) + + # verify inserts and reads + for i in (0, 1, 2, 3, MAX_TEST_LENGTH): + # create udt + params = [j for j in range(i)] + [None for j in range(MAX_TEST_LENGTH - i)] + created_udt = udt(*params) + + # write udt + s.execute("INSERT INTO mytable (k, v) VALUES (0, %s)", (created_udt,)) + + # verify udt was written and read correctly, increase timeout to avoid the query failure on slow systems + result = s.execute("SELECT v FROM mytable WHERE k=0")[0] + self.assertEqual(created_udt, result.v) + + c.shutdown() + + def nested_udt_schema_helper(self, session, MAX_NESTING_DEPTH): + # create the seed udt + execute_until_pass(session, "CREATE TYPE depth_0 (age int, name text)") + + # create the
nested udts + for i in range(MAX_NESTING_DEPTH): + execute_until_pass(session, "CREATE TYPE depth_{0} (value frozen<depth_{1}>)".format(i + 1, i)) + + # create a table with multiple sizes of nested udts + # no need for all nested types, only a spot checked few and the largest one + execute_until_pass(session, "CREATE TABLE mytable (" + "k int PRIMARY KEY, " + "v_0 frozen<depth_0>, " + "v_1 frozen<depth_1>, " + "v_2 frozen<depth_2>, " + "v_3 frozen<depth_3>, " + "v_{0} frozen<depth_{0}>)".format(MAX_NESTING_DEPTH)) + + def nested_udt_creation_helper(self, udts, i): + if i == 0: + return udts[0](42, 'Bob') + else: + return udts[i](self.nested_udt_creation_helper(udts, i - 1)) + + def nested_udt_verification_helper(self, session, MAX_NESTING_DEPTH, udts): + for i in (0, 1, 2, 3, MAX_NESTING_DEPTH): + # create udt + udt = self.nested_udt_creation_helper(udts, i) + + # write udt via simple statement + session.execute("INSERT INTO mytable (k, v_%s) VALUES (0, %s)", [i, udt]) + + # verify udt was written and read correctly + result = session.execute("SELECT v_{0} FROM mytable WHERE k=0".format(i))[0] + self.assertEqual(udt, result["v_{0}".format(i)]) + + # write udt via prepared statement + insert = session.prepare("INSERT INTO mytable (k, v_{0}) VALUES (1, ?)".format(i)) + session.execute(insert, [udt]) + + # verify udt was written and read correctly + result = session.execute("SELECT v_{0} FROM mytable WHERE k=1".format(i))[0] + self.assertEqual(udt, result["v_{0}".format(i)]) + + def test_can_insert_nested_registered_udts(self): + """ + Test for ensuring nested registered udts are properly inserted + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name, wait_for_all_pools=True) + s.row_factory = dict_factory + + MAX_NESTING_DEPTH = 16 + + # create the schema + self.nested_udt_schema_helper(s, MAX_NESTING_DEPTH) + + # create and register the seed udt type + udts = [] + udt = namedtuple('depth_0', ('age', 'name')) + udts.append(udt) + c.register_user_type(self.keyspace_name, "depth_0", udts[0]) + + # create and register the nested udt types + for i in range(MAX_NESTING_DEPTH): + udt = namedtuple('depth_{0}'.format(i + 1), ('value')) + udts.append(udt) + c.register_user_type(self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1]) + + # insert udts and verify inserts with reads + self.nested_udt_verification_helper(s, MAX_NESTING_DEPTH, udts) + + c.shutdown() + + def test_can_insert_nested_unregistered_udts(self): + """ + Test for ensuring nested unregistered udts are properly inserted + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name, wait_for_all_pools=True) + s.row_factory = dict_factory + + MAX_NESTING_DEPTH = 16 + + # create the schema + self.nested_udt_schema_helper(s, MAX_NESTING_DEPTH) + + # create the seed udt type + udts = [] + udt = namedtuple('depth_0', ('age', 'name')) + udts.append(udt) + + # create the nested udt types + for i in range(MAX_NESTING_DEPTH): + udt = namedtuple('depth_{0}'.format(i + 1), ('value')) + udts.append(udt) + + # insert udts via prepared statements and verify inserts with reads + for i in (0, 1, 2, 3, MAX_NESTING_DEPTH): + # create udt + udt = self.nested_udt_creation_helper(udts, i) + + # write udt + insert = s.prepare("INSERT INTO mytable (k, v_{0}) VALUES (0, ?)".format(i)) + s.execute(insert, [udt]) + + # verify udt was written and read correctly + result = s.execute("SELECT v_{0} FROM mytable WHERE k=0".format(i))[0] + self.assertEqual(udt, result["v_{0}".format(i)]) + + c.shutdown() + + def
test_can_insert_nested_registered_udts_with_different_namedtuples(self): + """ + Test for ensuring nested udts are inserted correctly when the + created namedtuples use names that are different from the cql type. + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name, wait_for_all_pools=True) + s.row_factory = dict_factory + + MAX_NESTING_DEPTH = 16 + + # create the schema + self.nested_udt_schema_helper(s, MAX_NESTING_DEPTH) + + # create and register the seed udt type + udts = [] + udt = namedtuple('level_0', ('age', 'name')) + udts.append(udt) + c.register_user_type(self.keyspace_name, "depth_0", udts[0]) + + # create and register the nested udt types + for i in range(MAX_NESTING_DEPTH): + udt = namedtuple('level_{0}'.format(i + 1), ('value')) + udts.append(udt) + c.register_user_type(self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1]) + + # insert udts and verify inserts with reads + self.nested_udt_verification_helper(s, MAX_NESTING_DEPTH, udts) + + c.shutdown() + + def test_raise_error_on_nonexisting_udts(self): + """ + Test for ensuring that an error is raised for operating on a nonexisting udt or an invalid keyspace + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name, wait_for_all_pools=True) + User = namedtuple('user', ('age', 'name')) + + with self.assertRaises(UserTypeDoesNotExist): + c.register_user_type("some_bad_keyspace", "user", User) + + with self.assertRaises(UserTypeDoesNotExist): + c.register_user_type("system", "user", User) + + with self.assertRaises(InvalidRequest): + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)") + + c.shutdown() + + def test_can_insert_udt_all_datatypes(self): + """ + Test for inserting various types of PRIMITIVE_DATATYPES into UDT's + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name, wait_for_all_pools=True) + + # create UDT + alpha_type_list = [] + start_index = ord('a') + for i, datatype in enumerate(PRIMITIVE_DATATYPES): + alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype)) + + s.execute(""" + CREATE TYPE alldatatypes ({0}) + """.format(', '.join(alpha_type_list)) + ) + + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<alldatatypes>)") + + # register UDT + alphabet_list = [] + for i in range(ord('a'), ord('a') + len(PRIMITIVE_DATATYPES)): + alphabet_list.append('{0}'.format(chr(i))) + Alldatatypes = namedtuple("alldatatypes", alphabet_list) + c.register_user_type(self.keyspace_name, "alldatatypes", Alldatatypes) + + # insert UDT data + params = [] + for datatype in PRIMITIVE_DATATYPES: + params.append((get_sample(datatype))) + + insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)") + s.execute(insert, (0, Alldatatypes(*params))) + + # retrieve and verify data + results = s.execute("SELECT * FROM mytable") + + row = results[0].b + for expected, actual in zip(params, row): + self.assertEqual(expected, actual) + + c.shutdown() + + def test_can_insert_udt_all_collection_datatypes(self): + """ + Test for inserting various types of COLLECTION_TYPES into UDT's + """ + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name, wait_for_all_pools=True) + + # create UDT + alpha_type_list = [] + start_index = ord('a') + for i, collection_type in enumerate(COLLECTION_TYPES): + for j, datatype in enumerate(PRIMITIVE_DATATYPES_KEYS): + if collection_type == "map": + type_string = "{0}_{1} {2}<{3}, {3}>".format(chr(start_index + i), chr(start_index + j), +
collection_type, datatype) + elif collection_type == "tuple": + type_string = "{0}_{1} frozen<{2}<{3}>>".format(chr(start_index + i), chr(start_index + j), + collection_type, datatype) + else: + type_string = "{0}_{1} {2}<{3}>".format(chr(start_index + i), chr(start_index + j), + collection_type, datatype) + alpha_type_list.append(type_string) + + s.execute(""" + CREATE TYPE alldatatypes ({0}) + """.format(', '.join(alpha_type_list)) + ) + + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<alldatatypes>)") + + # register UDT + alphabet_list = [] + for i in range(ord('a'), ord('a') + len(COLLECTION_TYPES)): + for j in range(ord('a'), ord('a') + len(PRIMITIVE_DATATYPES_KEYS)): + alphabet_list.append('{0}_{1}'.format(chr(i), chr(j))) + + Alldatatypes = namedtuple("alldatatypes", alphabet_list) + c.register_user_type(self.keyspace_name, "alldatatypes", Alldatatypes) + + # insert UDT data + params = [] + for collection_type in COLLECTION_TYPES: + for datatype in PRIMITIVE_DATATYPES_KEYS: + params.append((get_collection_sample(collection_type, datatype))) + + insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)") + s.execute(insert, (0, Alldatatypes(*params))) + + # retrieve and verify data + results = s.execute("SELECT * FROM mytable") + + row = results[0].b + for expected, actual in zip(params, row): + self.assertEqual(expected, actual) + + c.shutdown() + + def insert_select_column(self, session, table_name, column_name, value): + insert = session.prepare("INSERT INTO %s (k, %s) VALUES (?, ?)" % (table_name, column_name)) + session.execute(insert, (0, value)) + result = session.execute("SELECT %s FROM %s WHERE k=%%s" % (column_name, table_name), (0,))[0][0] + self.assertEqual(result, value) + + def test_can_insert_nested_collections(self): + """ + Test for inserting various types of nested COLLECTION_TYPES into tables and UDTs + """ + + if self.cass_version < (2, 1, 3): + raise unittest.SkipTest("Support for nested collections was introduced in Cassandra 2.1.3") + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect(self.keyspace_name, wait_for_all_pools=True) + s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple + + name = self._testMethodName + + s.execute(""" + CREATE TYPE %s ( + m frozen<map<int, text>>, + t tuple<int, text>, + l frozen<list<int>>, + s frozen<set<int>> + )""" % name) + s.execute(""" + CREATE TYPE %s_nested ( + m frozen<map<int, text>>, + t tuple<int, text>, + l frozen<list<int>>, + s frozen<set<int>>, + u frozen<%s> + )""" % (name, name)) + s.execute(""" + CREATE TABLE %s ( + k int PRIMARY KEY, + map_map map<frozen<map<int, int>>, frozen<map<int, int>>>, + map_set map<frozen<set<int>>, frozen<set<int>>>, + map_list map<frozen<list<int>>, frozen<list<int>>>, + map_tuple map<frozen<tuple<int, int>>, frozen<tuple<int>>>, + map_udt map<frozen<%s_nested>, frozen<%s>>, + )""" % (name, name, name)) + + validate = partial(self.insert_select_column, s, name) + validate('map_map', OrderedMap([({1: 1, 2: 2}, {3: 3, 4: 4}), ({5: 5, 6: 6}, {7: 7, 8: 8})])) + validate('map_set', OrderedMap([(set((1, 2)), set((3, 4))), (set((5, 6)), set((7, 8)))])) + validate('map_list', OrderedMap([([1, 2], [3, 4]), ([5, 6], [7, 8])])) + validate('map_tuple', OrderedMap([((1, 2), (3,)), ((4, 5), (6,))])) + + value = nested_collection_udt({1: 'v1', 2: 'v2'}, (3, 'v3'), [4, 5, 6, 7], set((8, 9, 10))) + key = nested_collection_udt_nested(value.m, value.t, value.l, value.s, value) + key2 = nested_collection_udt_nested({3: 'v3'}, value.t, value.l, value.s, value) + validate('map_udt', OrderedMap([(key, value), (key2, value)])) + + c.shutdown() + + def test_non_alphanum_identifiers(self): + """ + PYTHON-413 + """ + s = self.session + non_alphanum_name = 'test.field@#$%@%#!'
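+ # double-quoting the identifier below lets CQL accept a type/field name that is not alphanumeric (see PYTHON-413)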
+ type_name = 'type2' + s.execute('CREATE TYPE "%s" ("%s" text)' % (non_alphanum_name, non_alphanum_name)) + s.execute('CREATE TYPE %s ("%s" text)' % (type_name, non_alphanum_name)) + # table with types as map keys to make sure the tuple lookup works + s.execute('CREATE TABLE %s (k int PRIMARY KEY, non_alphanum_type_map map<frozen<"%s">, int>, alphanum_type_map map<frozen<%s>, int>)' % (self.table_name, non_alphanum_name, type_name)) + s.execute('INSERT INTO %s (k, non_alphanum_type_map, alphanum_type_map) VALUES (%s, {{"%s": \'nonalphanum\'}: 0}, {{"%s": \'alphanum\'}: 1})' % (self.table_name, 0, non_alphanum_name, non_alphanum_name)) + row = s.execute('SELECT * FROM %s' % (self.table_name,))[0] + + k, v = row.non_alphanum_type_map.popitem() + self.assertEqual(v, 0) + self.assertEqual(k.__class__, tuple) + self.assertEqual(k[0], 'nonalphanum') + + k, v = row.alphanum_type_map.popitem() + self.assertEqual(v, 1) + self.assertNotEqual(k.__class__, tuple) # should be the namedtuple type + self.assertEqual(k[0], 'alphanum') + self.assertEqual(k.field_0_, 'alphanum') # named tuple with positional field name + + @lessthancass30 + def test_type_alteration(self): + s = self.session + type_name = "type_name" + self.assertNotIn(type_name, s.cluster.metadata.keyspaces['udttests'].user_types) + s.execute('CREATE TYPE %s (v0 int)' % (type_name,)) + self.assertIn(type_name, s.cluster.metadata.keyspaces['udttests'].user_types) + + s.execute('CREATE TABLE %s (k int PRIMARY KEY, v frozen<%s>)' % (self.table_name, type_name)) + s.execute('INSERT INTO %s (k, v) VALUES (0, {v0 : 1})' % (self.table_name,)) + + s.cluster.register_user_type('udttests', type_name, dict) + + val = s.execute('SELECT v FROM %s' % self.table_name)[0][0] + self.assertEqual(val['v0'], 1) + + # add field + s.execute('ALTER TYPE %s ADD v1 text' % (type_name,)) + val = s.execute('SELECT v FROM %s' % self.table_name)[0][0] + self.assertEqual(val['v0'], 1) + self.assertIsNone(val['v1']) + s.execute("INSERT INTO %s (k, v) VALUES (0, {v0 : 2, v1 : 'sometext'})" % (self.table_name,)) + val = s.execute('SELECT v FROM %s' % self.table_name)[0][0] + self.assertEqual(val['v0'], 2) + self.assertEqual(val['v1'], 'sometext') + + # alter field type + s.execute('ALTER TYPE %s ALTER v1 TYPE blob' % (type_name,)) + s.execute("INSERT INTO %s (k, v) VALUES (0, {v0 : 3, v1 : 0xdeadbeef})" % (self.table_name,)) + val = s.execute('SELECT v FROM %s' % self.table_name)[0][0] + self.assertEqual(val['v0'], 3) + self.assertEqual(val['v1'], six.b('\xde\xad\xbe\xef')) + + @lessthancass30 + def test_alter_udt(self): + """ + Test to ensure that altered UDT's are properly surfaced without needing to restart the underlying session. + + @since 3.0.0 + @jira_ticket PYTHON-226 + @expected_result UDT's will reflect added columns without a session restart. + + @test_category data_types, udt + """ + + # Create the UDT and ensure it has the proper column names.
+ self.session.set_keyspace(self.keyspace_name) + self.session.execute("CREATE TYPE typetoalter (a int)") + typetoalter = namedtuple('typetoalter', ('a')) + self.session.execute("CREATE TABLE {0} (pk int primary key, typetoalter frozen<typetoalter>)".format(self.function_table_name)) + insert_statement = self.session.prepare("INSERT INTO {0} (pk, typetoalter) VALUES (?, ?)".format(self.function_table_name)) + self.session.execute(insert_statement, [1, typetoalter(1)]) + results = self.session.execute("SELECT * from {0}".format(self.function_table_name)) + for result in results: + self.assertTrue(hasattr(result.typetoalter, 'a')) + self.assertFalse(hasattr(result.typetoalter, 'b')) + + # Alter UDT and ensure the alter is honored in results + self.session.execute("ALTER TYPE typetoalter add b int") + typetoalter = namedtuple('typetoalter', ('a', 'b')) + self.session.execute(insert_statement, [2, typetoalter(2, 2)]) + results = self.session.execute("SELECT * from {0}".format(self.function_table_name)) + for result in results: + self.assertTrue(hasattr(result.typetoalter, 'a')) + self.assertTrue(hasattr(result.typetoalter, 'b')) + diff --git a/tests/integration/standard/utils.py b/tests/integration/standard/utils.py new file mode 100644 index 0000000..917b3a7 --- /dev/null +++ b/tests/integration/standard/utils.py @@ -0,0 +1,58 @@ +""" +Helper module to populate dummy Cassandra tables with data. +""" + +from tests.integration.datatype_utils import PRIMITIVE_DATATYPES, get_sample + + +def create_table_with_all_types(table_name, session, N): + """ + Method that, given a table_name and session, constructs a table that contains + all possible primitive types. + + :param table_name: Name of table to create + :param session: session to use for table creation + :param N: the number of items to insert into the table + + :return: a list of column names + """ + # create table + alpha_type_list = ["primkey int PRIMARY KEY"] + col_names = ["primkey"] + start_index = ord('a') + for i, datatype in enumerate(PRIMITIVE_DATATYPES): + alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype)) + col_names.append(chr(start_index + i)) + + session.execute("CREATE TABLE {0} ({1})".format( + table_name, ', '.join(alpha_type_list)), timeout=120) + + # create the input + + for key in range(N): + params = get_all_primitive_params(key) + + # insert into table as a simple statement + columns_string = ', '.join(col_names) + placeholders = ', '.join(["%s"] * len(col_names)) + session.execute("INSERT INTO {0} ({1}) VALUES ({2})".format( + table_name, columns_string, placeholders), params, timeout=120) + return col_names + + +def get_all_primitive_params(key): + """ + Simple utility method used to give back a list of all possible primitive data sample types. + """ + params = [key] + for datatype in PRIMITIVE_DATATYPES: + # Also test for empty strings + if key == 1 and datatype == 'ascii': + params.append('') + else: + params.append(get_sample(datatype)) + return params + + +def get_primitive_datatypes(): + return ['int'] + list(PRIMITIVE_DATATYPES) diff --git a/tests/integration/upgrade/__init__.py b/tests/integration/upgrade/__init__.py new file mode 100644 index 0000000..d2b9076 --- /dev/null +++ b/tests/integration/upgrade/__init__.py @@ -0,0 +1,189 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests.integration import CCM_KWARGS, use_cluster, remove_cluster, MockLoggingHandler +from tests.integration import setup_keyspace + +from cassandra.cluster import Cluster +from cassandra import cluster + +from collections import namedtuple +from functools import wraps +import logging +from threading import Thread, Event +from ccmlib.node import TimeoutError +import time +import logging + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + + +def setup_module(): + remove_cluster() + + +UPGRADE_CLUSTER_NAME = "upgrade_cluster" +UpgradePath = namedtuple('UpgradePath', ('name', 'starting_version', 'upgrade_version', 'configuration_options')) + +log = logging.getLogger(__name__) + + +class upgrade_paths(object): + """ + Decorator used to specify the upgrade paths for a particular method + """ + def __init__(self, paths): + self.paths = paths + + def __call__(self, method): + @wraps(method) + def wrapper(*args, **kwargs): + for path in self.paths: + self_from_decorated = args[0] + log.debug('setting up {path}'.format(path=path)) + self_from_decorated.UPGRADE_PATH = path + self_from_decorated._upgrade_step_setup() + method(*args, **kwargs) + log.debug('tearing down {path}'.format(path=path)) + self_from_decorated._upgrade_step_teardown() + return wrapper + + +class UpgradeBase(unittest.TestCase): + """ + Base class for the upgrade tests. The _setup method + will clean the environment and start the appropriate C* version according + to the upgrade path. The upgrade can be done in a different thread using the + start_upgrade upgrade_method (this would be the most realistic scenario) + or node by node, waiting for the upgrade to happen, using _upgrade_one_node method + """ + UPGRADE_PATH = None + start_cluster = True + set_keyspace = True + + @classmethod + def setUpClass(cls): + cls.logger_handler = MockLoggingHandler() + logger = logging.getLogger(cluster.__name__) + logger.addHandler(cls.logger_handler) + + def _upgrade_step_setup(self): + """ + This is not the regular _setUp method because it will be called from + the decorator instead of letting nose handle it. + This setup method will start a cluster with the right version according + to the variable UPGRADE_PATH. 
+ """ + remove_cluster() + self.cluster = use_cluster(UPGRADE_CLUSTER_NAME + self.UPGRADE_PATH.name, [3], + ccm_options=self.UPGRADE_PATH.starting_version, set_keyspace=self.set_keyspace, + configuration_options=self.UPGRADE_PATH.configuration_options) + self.nodes = self.cluster.nodelist() + self.last_node_upgraded = None + self.upgrade_done = Event() + self.upgrade_thread = None + + if self.start_cluster: + setup_keyspace() + + self.cluster_driver = Cluster() + self.session = self.cluster_driver.connect() + self.logger_handler.reset() + + def _upgrade_step_teardown(self): + """ + special tearDown method called by the decorator after the method has ended + """ + if self.upgrade_thread: + self.upgrade_thread.join(timeout=5) + self.upgrade_thread = None + + if self.start_cluster: + self.cluster_driver.shutdown() + + def start_upgrade(self, time_node_upgrade): + """ + Starts the upgrade in a different thread + """ + log.debug('Starting upgrade in new thread') + self.upgrade_thread = Thread(target=self._upgrade, args=(time_node_upgrade,)) + self.upgrade_thread.start() + + def _upgrade(self, time_node_upgrade): + """ + Starts the upgrade in the same thread + """ + start_time = time.time() + for node in self.nodes: + self.upgrade_node(node) + end_time = time.time() + time_to_upgrade = end_time - start_time + if time_node_upgrade > time_to_upgrade: + time.sleep(time_node_upgrade - time_to_upgrade) + self.upgrade_done.set() + + def is_upgraded(self): + """ + Returns True if the upgrade has finished and False otherwise + """ + return self.upgrade_done.is_set() + + def wait_for_upgrade(self, timeout=None): + """ + Waits until the upgrade has completed + """ + self.upgrade_done.wait(timeout=timeout) + + def upgrade_node(self, node): + """ + Upgrades only one node. Return True if the upgrade + has finished and False otherwise + """ + node.drain() + node.stop(gently=True) + + node.set_install_dir(**self.UPGRADE_PATH.upgrade_version) + + # There must be a cleaner way of doing this, but it's necessary here + # to call the private method from cluster __update_topology_files + self.cluster._Cluster__update_topology_files() + try: + node.start(wait_for_binary_proto=True, wait_other_notice=True) + except TimeoutError: + self.fail("Error starting C* node while upgrading") + + return True + + +class UpgradeBaseAuth(UpgradeBase): + """ + Base class of authentication test, the authentication parameters for + C* still have to be specified within the upgrade path variable + """ + start_cluster = False + set_keyspace = False + + + def _upgrade_step_setup(self): + """ + We sleep here for the same reason as we do in test_authentication.py: + there seems to be some race, with some versions of C* taking longer to + get the auth (and default user) setup. Sleep here to give it a chance + """ + super(UpgradeBaseAuth, self)._upgrade_step_setup() + time.sleep(10) diff --git a/tests/integration/upgrade/test_upgrade.py b/tests/integration/upgrade/test_upgrade.py new file mode 100644 index 0000000..7fa88a9 --- /dev/null +++ b/tests/integration/upgrade/test_upgrade.py @@ -0,0 +1,279 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from itertools import count + +from cassandra.auth import PlainTextAuthProvider, SaslAuthProvider +from cassandra.cluster import ConsistencyLevel, Cluster, DriverException, ExecutionProfile +from cassandra.policies import ConstantSpeculativeExecutionPolicy +from tests.integration.upgrade import UpgradeBase, UpgradeBaseAuth, UpgradePath, upgrade_paths + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + + +two_to_three_path = upgrade_paths([ + UpgradePath("2.2.9-3.11", {"version": "2.2.9"}, {"version": "3.11.4"}, {}), +]) +class UpgradeTests(UpgradeBase): + @two_to_three_path + def test_can_write(self): + """ + Verify that the driver will keep querying C* even if there is a host down while being + upgraded and that all the writes will eventually succeed + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result all the writes succeed + + @test_category upgrade + """ + self.start_upgrade(0) + + c = count() + while not self.is_upgraded(): + self.session.execute("INSERT INTO test3rf.test(k, v) VALUES (%s, 0)", (next(c), )) + time.sleep(0.0001) + + self.session.default_consistency_level = ConsistencyLevel.ALL + total_number_of_inserted = self.session.execute("SELECT COUNT(*) from test3rf.test")[0][0] + self.assertEqual(total_number_of_inserted, next(c)) + + self.session.default_consistency_level = ConsistencyLevel.LOCAL_ONE + self.assertEqual(self.logger_handler.get_message_count("error", ""), 0) + + @two_to_three_path + def test_can_connect(self): + """ + Verify that the driver can connect to all the nodes + despite some nodes being in different versions + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result the driver connects successfully and can execute queries against + all the hosts + + @test_category upgrade + """ + def connect_and_shutdown(): + cluster = Cluster() + session = cluster.connect(wait_for_all_pools=True) + queried_hosts = set() + for _ in range(10): + results = session.execute("SELECT * from system.local") + self.assertGreater(len(results.current_rows), 0) + self.assertEqual(len(results.response_future.attempted_hosts), 1) + queried_hosts.add(results.response_future.attempted_hosts[0]) + self.assertEqual(len(queried_hosts), 3) + cluster.shutdown() + + connect_and_shutdown() + for node in self.nodes: + self.upgrade_node(node) + connect_and_shutdown() + + connect_and_shutdown() + + +class UpgradeTestsMetadata(UpgradeBase): + @two_to_three_path + def test_can_write(self): + """ + Verify that the driver will keep querying C* even if there is a host down while being + upgraded and that all the writes will eventually succeed + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result all the writes succeed + + @test_category upgrade + """ + self.start_upgrade(0) + + c = count() + while not self.is_upgraded(): + self.session.execute("INSERT INTO test3rf.test(k, v) VALUES (%s, 0)", (next(c),)) + time.sleep(0.0001) + + self.session.default_consistency_level = ConsistencyLevel.ALL + total_number_of_inserted = self.session.execute("SELECT COUNT(*) from test3rf.test")[0][0] + self.assertEqual(total_number_of_inserted, next(c)) 
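+ # the count above is read at ConsistencyLevel.ALL so every acknowledged write is visible; + # restore the default consistency level before checking the driver logs for errors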
+ + self.session.default_consistency_level = ConsistencyLevel.LOCAL_ONE + self.assertEqual(self.logger_handler.get_message_count("error", ""), 0) + + @two_to_three_path + def test_schema_metadata_gets_refreshed(self): + """ + Verify that the driver fails to update the metadata while connected against + different versions of nodes. This won't succeed because each node will report a + different schema version + + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result the driver raises DriverException when updating the schema + metadata while upgrading + all the hosts + + @test_category metadata + """ + original_meta = self.cluster_driver.metadata.keyspaces + number_of_nodes = len(self.cluster.nodelist()) + nodes = self.nodes + for node in nodes[1:]: + self.upgrade_node(node) + # Wait for the control connection to reconnect + time.sleep(20) + + with self.assertRaises(DriverException): + self.cluster_driver.refresh_schema_metadata(max_schema_agreement_wait=10) + + self.upgrade_node(nodes[0]) + # Wait for the control connection to reconnect + time.sleep(20) + self.cluster_driver.refresh_schema_metadata(max_schema_agreement_wait=40) + self.assertNotEqual(original_meta, self.cluster_driver.metadata.keyspaces) + + @two_to_three_path + def test_schema_nodes_gets_refreshed(self): + """ + Verify that the driver token map and node list gets rebuild correctly while upgrading. + The token map and the node list should be the same after each node upgrade + + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result the token map and the node list stays consistent with each node upgrade + metadata while upgrading + all the hosts + + @test_category metadata + """ + for node in self.nodes: + token_map = self.cluster_driver.metadata.token_map + self.upgrade_node(node) + # Wait for the control connection to reconnect + time.sleep(20) + + self.cluster_driver.refresh_nodes(force_token_rebuild=True) + self._assert_same_token_map(token_map, self.cluster_driver.metadata.token_map) + + def _assert_same_token_map(self, original, new): + self.assertIsNot(original, new) + self.assertEqual(original.tokens_to_hosts_by_ks, new.tokens_to_hosts_by_ks) + self.assertEqual(original.token_to_host_owner, new.token_to_host_owner) + self.assertEqual(original.ring, new.ring) + + +two_to_three_with_auth_path = upgrade_paths([ + UpgradePath("2.2.9-3.11-auth", {"version": "2.2.9"}, {"version": "3.11.4"}, + {'authenticator': 'PasswordAuthenticator', + 'authorizer': 'CassandraAuthorizer'}), +]) +class UpgradeTestsAuthentication(UpgradeBaseAuth): + @two_to_three_with_auth_path + def test_can_connect_auth_plain(self): + """ + Verify that the driver can connect despite some nodes being in different versions + with plain authentication + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result the driver connects successfully and can execute queries against + all the hosts + + @test_category upgrade + """ + auth_provider = PlainTextAuthProvider( + username="cassandra", + password="cassandra" + ) + self.connect_and_shutdown(auth_provider) + for node in self.nodes: + self.upgrade_node(node) + self.connect_and_shutdown(auth_provider) + + self.connect_and_shutdown(auth_provider) + + @two_to_three_with_auth_path + def test_can_connect_auth_sasl(self): + """ + Verify that the driver can connect despite some nodes being in different versions + with ssl authentication + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result the driver connects successfully and can execute queries against + all the hosts + + @test_category upgrade + """ + 
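+ # SASL PLAIN mirrors the plain-text cassandra/cassandra credentials used in the test above; + # note that SaslAuthProvider relies on the optional pure-sasl dependency being installed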
sasl_kwargs = {'service': 'cassandra', + 'mechanism': 'PLAIN', + 'qops': ['auth'], + 'username': 'cassandra', + 'password': 'cassandra'} + auth_provider = SaslAuthProvider(**sasl_kwargs) + self.connect_and_shutdown(auth_provider) + for node in self.nodes: + self.upgrade_node(node) + self.connect_and_shutdown(auth_provider) + + self.connect_and_shutdown(auth_provider) + + def connect_and_shutdown(self, auth_provider): + cluster = Cluster(idle_heartbeat_interval=0, + auth_provider=auth_provider) + session = cluster.connect(wait_for_all_pools=True) + queried_hosts = set() + for _ in range(10): + results = session.execute("SELECT * from system.local") + self.assertGreater(len(results.current_rows), 0) + self.assertEqual(len(results.response_future.attempted_hosts), 1) + queried_hosts.add(results.response_future.attempted_hosts[0]) + self.assertEqual(len(queried_hosts), 3) + cluster.shutdown() + + +class UpgradeTestsPolicies(UpgradeBase): + @two_to_three_path + def test_can_write_speculative(self): + """ + Verify that the driver will keep querying C* even if there is a host down while being + upgraded and that all the writes will eventually succeed using the ConstantSpeculativeExecutionPolicy + policy + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result all the writes succeed + + @test_category upgrade + """ + spec_ep_rr = ExecutionProfile(speculative_execution_policy=ConstantSpeculativeExecutionPolicy(.5, 10), + request_timeout=12) + cluster = Cluster() + self.addCleanup(cluster.shutdown) + cluster.add_execution_profile("spec_ep_rr", spec_ep_rr) + + session = cluster.connect() + + self.start_upgrade(0) + + c = count() + while not self.is_upgraded(): + session.execute("INSERT INTO test3rf.test(k, v) VALUES (%s, 0)", (next(c),), + execution_profile='spec_ep_rr') + time.sleep(0.0001) + + session.default_consistency_level = ConsistencyLevel.ALL + total_number_of_inserted = session.execute("SELECT COUNT(*) from test3rf.test")[0][0] + self.assertEqual(total_number_of_inserted, next(c)) + + self.assertEqual(self.logger_handler.get_message_count("error", ""), 0) diff --git a/tests/integration/util.py b/tests/integration/util.py new file mode 100644 index 0000000..a2ce9d5 --- /dev/null +++ b/tests/integration/util.py @@ -0,0 +1,109 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests.integration import PROTOCOL_VERSION +from functools import wraps +import time + + +def assert_quiescent_pool_state(test_case, cluster, wait=None): + """ + Checking the quiescent pool state checks that none of the requests ids have + been lost. However, the callback corresponding to a request_id is called + before the request_id is returned back to the pool, therefore + + session.execute("SELECT * from system.local") + assert_quiescent_pool_state(self, session.cluster) + + (with no wait) might fail because when execute comes back the request_id + hasn't yet been returned to the pool, therefore the wait. 
+    """
+    if wait is not None:
+        time.sleep(wait)
+
+    for session in cluster.sessions:
+        pool_states = session.get_pool_state().values()
+        test_case.assertTrue(pool_states)
+
+        for state in pool_states:
+            test_case.assertFalse(state['shutdown'])
+            test_case.assertGreater(state['open_count'], 0)
+            test_case.assertTrue(all((i == 0 for i in state['in_flights'])))
+
+    for holder in cluster.get_connection_holders():
+        for connection in holder.get_connections():
+            # all ids are unique
+            req_ids = connection.request_ids
+            test_case.assertEqual(len(req_ids), len(set(req_ids)))
+            test_case.assertEqual(connection.highest_request_id, len(req_ids) - 1)
+            test_case.assertEqual(connection.highest_request_id, max(req_ids))
+            if PROTOCOL_VERSION < 3:
+                test_case.assertEqual(connection.highest_request_id, connection.max_request_id)
+
+
+def wait_until(condition, delay, max_attempts):
+    """
+    Executes a function at regular intervals while the condition
+    is false and the number of attempts is less than max_attempts.
+    :param condition: a function
+    :param delay: the delay in seconds
+    :param max_attempts: the maximum number of attempts. So the timeout
+        of this function is delay*max_attempts
+    """
+    attempt = 0
+    while not condition() and attempt < max_attempts:
+        attempt += 1
+        time.sleep(delay)
+
+    if attempt >= max_attempts:
+        raise Exception("Condition is still False after {} attempts.".format(max_attempts))
+
+
+def wait_until_not_raised(condition, delay, max_attempts):
+    """
+    Executes a function at regular intervals while the condition
+    doesn't raise an exception and the number of attempts is less than max_attempts.
+    :param condition: a function
+    :param delay: the delay in seconds
+    :param max_attempts: the maximum number of attempts. So the timeout
+        of this function will be delay*max_attempts
+    """
+    def wrapped_condition():
+        try:
+            condition()
+        except:
+            return False
+
+        return True
+
+    attempt = 0
+    while attempt < (max_attempts-1):
+        attempt += 1
+        if wrapped_condition():
+            return
+
+        time.sleep(delay)
+
+    # last attempt, let the exception raise
+    condition()
+
+
+def late(seconds=1):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            time.sleep(seconds)
+            func(*args, **kwargs)
+        return wrapper
+    return decorator
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py
new file mode 100644
index 0000000..386372e
--- /dev/null
+++ b/tests/unit/__init__.py
@@ -0,0 +1,14 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/tests/unit/cqlengine/__init__.py b/tests/unit/cqlengine/__init__.py
new file mode 100644
index 0000000..386372e
--- /dev/null
+++ b/tests/unit/cqlengine/__init__.py
@@ -0,0 +1,14 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/cqlengine/test_columns.py b/tests/unit/cqlengine/test_columns.py new file mode 100644 index 0000000..bcb174a --- /dev/null +++ b/tests/unit/cqlengine/test_columns.py @@ -0,0 +1,71 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.cqlengine.columns import Column + + +class ColumnTest(unittest.TestCase): + + def test_comparisons(self): + c0 = Column() + c1 = Column() + self.assertEqual(c1.position - c0.position, 1) + + # __ne__ + self.assertNotEqual(c0, c1) + self.assertNotEqual(c0, object()) + + # __eq__ + self.assertEqual(c0, c0) + self.assertFalse(c0 == object()) + + # __lt__ + self.assertLess(c0, c1) + try: + c0 < object() # this raises for Python 3 + except TypeError: + pass + + # __le__ + self.assertLessEqual(c0, c1) + self.assertLessEqual(c0, c0) + try: + c0 <= object() # this raises for Python 3 + except TypeError: + pass + + # __gt__ + self.assertGreater(c1, c0) + try: + c1 > object() # this raises for Python 3 + except TypeError: + pass + + # __ge__ + self.assertGreaterEqual(c1, c0) + self.assertGreaterEqual(c1, c1) + try: + c1 >= object() # this raises for Python 3 + except TypeError: + pass + + def test_hash(self): + c0 = Column() + self.assertEqual(id(c0), c0.__hash__()) + diff --git a/tests/unit/cqlengine/test_connection.py b/tests/unit/cqlengine/test_connection.py new file mode 100644 index 0000000..9f8e500 --- /dev/null +++ b/tests/unit/cqlengine/test_connection.py @@ -0,0 +1,64 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.cluster import _ConfigMode +from cassandra.cqlengine import connection +from cassandra.query import dict_factory + +from mock import Mock + + +class ConnectionTest(unittest.TestCase): + + no_registered_connection_msg = "doesn't exist in the registry" + + def setUp(self): + super(ConnectionTest, self).setUp() + self.assertFalse( + connection._connections, + 'Test precondition not met: connections are registered: {cs}'.format(cs=connection._connections) + ) + + def test_set_session_without_existing_connection(self): + """ + Users can set the default session without having a default connection set. + """ + mock_cluster = Mock( + _config_mode=_ConfigMode.LEGACY, + ) + mock_session = Mock( + row_factory=dict_factory, + encoder=Mock(mapping={}), + cluster=mock_cluster, + ) + connection.set_session(mock_session) + + def test_get_session_fails_without_existing_connection(self): + """ + Users can't get the default session without having a default connection set. + """ + with self.assertRaisesRegexp(connection.CQLEngineException, self.no_registered_connection_msg): + connection.get_session(connection=None) + + def test_get_cluster_fails_without_existing_connection(self): + """ + Users can't get the default cluster without having a default connection set. + """ + with self.assertRaisesRegexp(connection.CQLEngineException, self.no_registered_connection_msg): + connection.get_cluster(connection=None) diff --git a/tests/unit/cqlengine/test_udt.py b/tests/unit/cqlengine/test_udt.py new file mode 100644 index 0000000..ebe1139 --- /dev/null +++ b/tests/unit/cqlengine/test_udt.py @@ -0,0 +1,41 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.cqlengine import columns +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.usertype import UserType + + +class UDTTest(unittest.TestCase): + + def test_initialization_without_existing_connection(self): + """ + Test that users can define models with UDTs without initializing + connections. + + Written to reproduce PYTHON-649. + """ + + class Value(UserType): + t = columns.Text() + + class DummyUDT(Model): + __keyspace__ = 'ks' + primary_key = columns.Integer(primary_key=True) + value = columns.UserDefinedType(Value) diff --git a/tests/unit/cython/__init__.py b/tests/unit/cython/__init__.py new file mode 100644 index 0000000..386372e --- /dev/null +++ b/tests/unit/cython/__init__.py @@ -0,0 +1,14 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/cython/test_bytesio.py b/tests/unit/cython/test_bytesio.py new file mode 100644 index 0000000..a156fc1 --- /dev/null +++ b/tests/unit/cython/test_bytesio.py @@ -0,0 +1,35 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests.unit.cython.utils import cyimport, cythontest +bytesio_testhelper = cyimport('tests.unit.cython.bytesio_testhelper') + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + + +class BytesIOTest(unittest.TestCase): + """Test Cython BytesIO proxy""" + + @cythontest + def test_reading(self): + bytesio_testhelper.test_read1(self.assertEqual, self.assertRaises) + bytesio_testhelper.test_read2(self.assertEqual, self.assertRaises) + bytesio_testhelper.test_read3(self.assertEqual, self.assertRaises) + + @cythontest + def test_reading_error(self): + bytesio_testhelper.test_read_eof(self.assertEqual, self.assertRaises) diff --git a/tests/unit/cython/test_types.py b/tests/unit/cython/test_types.py new file mode 100644 index 0000000..a0d2138 --- /dev/null +++ b/tests/unit/cython/test_types.py @@ -0,0 +1,32 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests.unit.cython.utils import cyimport, cythontest +types_testhelper = cyimport('tests.unit.cython.types_testhelper') + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + + +class TypesTest(unittest.TestCase): + + @cythontest + def test_datetype(self): + types_testhelper.test_datetype(self.assertEqual) + + @cythontest + def test_date_side_by_side(self): + types_testhelper.test_date_side_by_side(self.assertEqual) diff --git a/tests/unit/cython/test_utils.py b/tests/unit/cython/test_utils.py new file mode 100644 index 0000000..dc8745e --- /dev/null +++ b/tests/unit/cython/test_utils.py @@ -0,0 +1,29 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests.unit.cython.utils import cyimport, cythontest +utils_testhelper = cyimport('tests.unit.cython.utils_testhelper') + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + + +class UtilsTest(unittest.TestCase): + """Test Cython Utils functions""" + + @cythontest + def test_datetime_from_timestamp(self): + utils_testhelper.test_datetime_from_timestamp(self.assertEqual) \ No newline at end of file diff --git a/tests/unit/cython/utils.py b/tests/unit/cython/utils.py new file mode 100644 index 0000000..7f8be22 --- /dev/null +++ b/tests/unit/cython/utils.py @@ -0,0 +1,43 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY +try: + from tests import VERIFY_CYTHON +except ImportError: + VERIFY_CYTHON = False + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +def cyimport(import_path): + """ + Import a Cython module if available, otherwise return None + (and skip any relevant tests). + """ + if HAVE_CYTHON: + import pyximport + py_importer, pyx_importer = pyximport.install() + mod = __import__(import_path, fromlist=[True]) + pyximport.uninstall(py_importer, pyx_importer) + return mod + + +# @cythontest +# def test_something(self): ... +cythontest = unittest.skipUnless((HAVE_CYTHON or VERIFY_CYTHON) or VERIFY_CYTHON, 'Cython is not available') +notcython = unittest.skipIf(HAVE_CYTHON, 'Cython not supported') +numpytest = unittest.skipUnless((HAVE_CYTHON and HAVE_NUMPY) or VERIFY_CYTHON, 'NumPy is not available') diff --git a/tests/unit/io/__init__.py b/tests/unit/io/__init__.py new file mode 100644 index 0000000..386372e --- /dev/null +++ b/tests/unit/io/__init__.py @@ -0,0 +1,14 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/io/eventlet_utils.py b/tests/unit/io/eventlet_utils.py new file mode 100644 index 0000000..785856b --- /dev/null +++ b/tests/unit/io/eventlet_utils.py @@ -0,0 +1,48 @@ +# Copyright DataStax, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import select +import socket +try: + import thread + import Queue + import __builtin__ + #For python3 compatibility +except ImportError: + import _thread as thread + import queue as Queue + import builtins as __builtin__ + +import threading +import ssl +import time +import eventlet +from imp import reload + +def eventlet_un_patch_all(): + """ + A method to unpatch eventlet monkey patching used for the reactor tests + """ + + # These are the modules that are loaded by eventlet we reload them all + modules_to_unpatch = [os, select, socket, thread, time, Queue, threading, ssl, __builtin__] + for to_unpatch in modules_to_unpatch: + reload(to_unpatch) + +def restore_saved_module(module): + reload(module) + del eventlet.patcher.already_patched[module.__name__] + diff --git a/tests/unit/io/gevent_utils.py b/tests/unit/io/gevent_utils.py new file mode 100644 index 0000000..a341fd9 --- /dev/null +++ b/tests/unit/io/gevent_utils.py @@ -0,0 +1,56 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from gevent import monkey + + +def gevent_un_patch_all(): + """ + A method to unpatch gevent libraries. These are unloaded + in the same order that gevent monkey patch loads theirs. + Order cannot be arbitrary. This is used in the unit tests to + un monkey patch gevent + """ + restore_saved_module("os") + restore_saved_module("time") + restore_saved_module("thread") + restore_saved_module("threading") + restore_saved_module("_threading_local") + restore_saved_module("stdin") + restore_saved_module("stdout") + restore_saved_module("socket") + restore_saved_module("select") + restore_saved_module("ssl") + restore_saved_module("subprocess") + + +def restore_saved_module(module): + """ + gevent monkey patch keeps a list of all patched modules. 
+    This will restore the original ones
+    :param module: to unpatch
+    :return:
+    """
+
+    # Check the saved attributes in the gevent monkey patch
+    if not (module in monkey.saved):
+        return
+    _module = __import__(module)
+
+    # If it exists, unpatch it
+    for attr in monkey.saved[module]:
+        if hasattr(_module, attr):
+            setattr(_module, attr, monkey.saved[module][attr])
+
diff --git a/tests/unit/io/test_asyncioreactor.py b/tests/unit/io/test_asyncioreactor.py
new file mode 100644
index 0000000..be3c2bc
--- /dev/null
+++ b/tests/unit/io/test_asyncioreactor.py
@@ -0,0 +1,76 @@
+try:
+    from cassandra.io.asyncioreactor import AsyncioConnection
+    import asynctest
+    ASYNCIO_AVAILABLE = True
+except (ImportError, SyntaxError):
+    AsyncioConnection = None
+    ASYNCIO_AVAILABLE = False
+
+from tests import is_monkey_patched, connection_class
+from tests.unit.io.utils import TimerCallback, TimerTestMixin
+
+from mock import patch
+
+import unittest
+import time
+
+skip_me = (is_monkey_patched() or
+           (not ASYNCIO_AVAILABLE) or
+           (connection_class is not AsyncioConnection))
+
+
+@unittest.skipIf(is_monkey_patched(), 'runtime is monkey patched for another reactor')
+@unittest.skipIf(connection_class is not AsyncioConnection,
+                 'not running asyncio tests; current connection_class is {}'.format(connection_class))
+@unittest.skipUnless(ASYNCIO_AVAILABLE, "asyncio is not available for this runtime")
+class AsyncioTimerTests(TimerTestMixin, unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        if skip_me:
+            return
+        cls.connection_class = AsyncioConnection
+        AsyncioConnection.initialize_reactor()
+
+    @classmethod
+    def tearDownClass(cls):
+        if skip_me:
+            return
+        if ASYNCIO_AVAILABLE and AsyncioConnection._loop:
+            AsyncioConnection._loop.stop()
+
+    @property
+    def create_timer(self):
+        return self.connection.create_timer
+
+    @property
+    def _timers(self):
+        raise RuntimeError('no TimerManager for AsyncioConnection')
+
+    def setUp(self):
+        if skip_me:
+            return
+        socket_patcher = patch('socket.socket')
+        self.addCleanup(socket_patcher.stop)
+        socket_patcher.start()
+
+        old_selector = AsyncioConnection._loop._selector
+        AsyncioConnection._loop._selector = asynctest.TestSelector()
+
+        def reset_selector():
+            AsyncioConnection._loop._selector = old_selector
+
+        self.addCleanup(reset_selector)
+
+        super(AsyncioTimerTests, self).setUp()
+
+    def test_timer_cancellation(self):
+        # Various lists for tracking callback stage
+        timeout = .1
+        callback = TimerCallback(timeout)
+        timer = self.create_timer(timeout, callback.invoke)
+        timer.cancel()
+        # Release context to allow the timer thread to run.
+        time.sleep(.2)
+        # Assert that the cancellation was honored
+        self.assertFalse(callback.was_invoked())
diff --git a/tests/unit/io/test_asyncorereactor.py b/tests/unit/io/test_asyncorereactor.py
new file mode 100644
index 0000000..7e55059
--- /dev/null
+++ b/tests/unit/io/test_asyncorereactor.py
@@ -0,0 +1,85 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from mock import patch +import socket +import cassandra.io.asyncorereactor +from cassandra.io.asyncorereactor import AsyncoreConnection +from tests import is_monkey_patched +from tests.unit.io.utils import ReactorTestMixin, TimerTestMixin, noop_if_monkey_patched + + +class AsyncorePatcher(unittest.TestCase): + + @classmethod + @noop_if_monkey_patched + def setUpClass(cls): + if is_monkey_patched(): + return + AsyncoreConnection.initialize_reactor() + + socket_patcher = patch('socket.socket', spec=socket.socket) + channel_patcher = patch( + 'cassandra.io.asyncorereactor.AsyncoreConnection.add_channel', + new=(lambda *args, **kwargs: None) + ) + + cls.mock_socket = socket_patcher.start() + cls.mock_socket.connect_ex.return_value = 0 + cls.mock_socket.getsockopt.return_value = 0 + cls.mock_socket.fileno.return_value = 100 + + channel_patcher.start() + + cls.patchers = (socket_patcher, channel_patcher) + + @classmethod + @noop_if_monkey_patched + def tearDownClass(cls): + for p in cls.patchers: + try: + p.stop() + except: + pass + + +class AsyncoreConnectionTest(ReactorTestMixin, AsyncorePatcher): + + connection_class = AsyncoreConnection + socket_attr_name = 'socket' + + def setUp(self): + if is_monkey_patched(): + raise unittest.SkipTest("Can't test asyncore with monkey patching") + + +class TestAsyncoreTimer(TimerTestMixin, AsyncorePatcher): + connection_class = AsyncoreConnection + + @property + def create_timer(self): + return self.connection.create_timer + + @property + def _timers(self): + return cassandra.io.asyncorereactor._global_loop._timers + + def setUp(self): + if is_monkey_patched(): + raise unittest.SkipTest("Can't test asyncore with monkey patching") + super(TestAsyncoreTimer, self).setUp() diff --git a/tests/unit/io/test_eventletreactor.py b/tests/unit/io/test_eventletreactor.py new file mode 100644 index 0000000..ce828cd --- /dev/null +++ b/tests/unit/io/test_eventletreactor.py @@ -0,0 +1,79 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from tests.unit.io.utils import TimerTestMixin +from tests import notpypy, EVENT_LOOP_MANAGER + +from eventlet import monkey_patch +from mock import patch + +try: + from cassandra.io.eventletreactor import EventletConnection +except ImportError: + EventletConnection = None # noqa + +skip_condition = EventletConnection is None or EVENT_LOOP_MANAGER != "eventlet" +# There are some issues with some versions of pypy and eventlet +@notpypy +@unittest.skipIf(skip_condition, "Skipping the eventlet tests because it's not installed") +class EventletTimerTest(TimerTestMixin, unittest.TestCase): + + connection_class = EventletConnection + + @classmethod + def setUpClass(cls): + # This is run even though the class is skipped, so we need + # to make sure no monkey patching is happening + if skip_condition: + return + + # This is being added temporarily due to a bug in eventlet: + # https://github.com/eventlet/eventlet/issues/401 + import eventlet + eventlet.sleep() + monkey_patch() + # cls.connection_class = EventletConnection + + EventletConnection.initialize_reactor() + assert EventletConnection._timers is not None + + def setUp(self): + socket_patcher = patch('eventlet.green.socket.socket') + self.addCleanup(socket_patcher.stop) + socket_patcher.start() + + super(EventletTimerTest, self).setUp() + + recv_patcher = patch.object(self.connection._socket, + 'recv', + return_value=b'') + self.addCleanup(recv_patcher.stop) + recv_patcher.start() + + @property + def create_timer(self): + return self.connection.create_timer + + @property + def _timers(self): + return self.connection._timers + + # There is no unpatching because there is not a clear way + # of doing it reliably diff --git a/tests/unit/io/test_geventreactor.py b/tests/unit/io/test_geventreactor.py new file mode 100644 index 0000000..ec64ce3 --- /dev/null +++ b/tests/unit/io/test_geventreactor.py @@ -0,0 +1,68 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + + +from tests.unit.io.utils import TimerTestMixin +from tests import EVENT_LOOP_MANAGER +try: + from cassandra.io.geventreactor import GeventConnection + import gevent.monkey +except ImportError: + GeventConnection = None # noqa + +from mock import patch + + +skip_condition = GeventConnection is None or EVENT_LOOP_MANAGER != "gevent" +@unittest.skipIf(skip_condition, "Skipping the gevent tests because it's not installed") +class GeventTimerTest(TimerTestMixin, unittest.TestCase): + + connection_class = GeventConnection + + @classmethod + def setUpClass(cls): + # This is run even though the class is skipped, so we need + # to make sure no monkey patching is happening + if skip_condition: + return + # There is no unpatching because there is not a clear way + # of doing it reliably + gevent.monkey.patch_all() + GeventConnection.initialize_reactor() + + def setUp(self): + socket_patcher = patch('gevent.socket.socket') + self.addCleanup(socket_patcher.stop) + socket_patcher.start() + + super(GeventTimerTest, self).setUp() + + recv_patcher = patch.object(self.connection._socket, + 'recv', + return_value=b'') + self.addCleanup(recv_patcher.stop) + recv_patcher.start() + + @property + def create_timer(self): + return self.connection.create_timer + + @property + def _timers(self): + return self.connection._timers diff --git a/tests/unit/io/test_libevreactor.py b/tests/unit/io/test_libevreactor.py new file mode 100644 index 0000000..a02458e --- /dev/null +++ b/tests/unit/io/test_libevreactor.py @@ -0,0 +1,144 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from mock import patch, Mock +import weakref +import socket + +from tests import is_monkey_patched +from tests.unit.io.utils import ReactorTestMixin, TimerTestMixin, noop_if_monkey_patched + + +try: + from cassandra.io.libevreactor import _cleanup as libev__cleanup + from cassandra.io.libevreactor import LibevConnection +except ImportError: + LibevConnection = None # noqa + + +class LibevConnectionTest(ReactorTestMixin, unittest.TestCase): + + connection_class = LibevConnection + socket_attr_name = '_socket' + null_handle_function_args = None, 0 + + def setUp(self): + if is_monkey_patched(): + raise unittest.SkipTest("Can't test libev with monkey patching") + if LibevConnection is None: + raise unittest.SkipTest('libev does not appear to be installed correctly') + LibevConnection.initialize_reactor() + + # we patch here rather than as a decorator so that the Mixin can avoid + # specifying patch args to test methods + patchers = [patch(obj) for obj in + ('socket.socket', + 'cassandra.io.libevwrapper.IO', + 'cassandra.io.libevreactor.LibevLoop.maybe_start' + )] + for p in patchers: + self.addCleanup(p.stop) + for p in patchers: + p.start() + + def test_watchers_are_finished(self): + """ + Test for asserting that watchers are closed in LibevConnection + + This test simulates a process termination without calling cluster.shutdown(), which would trigger + _global_loop._cleanup. It will check the watchers have been closed + Finally it will restore the LibevConnection reactor so it doesn't affect + the rest of the tests + + @since 3.10 + @jira_ticket PYTHON-747 + @expected_result the watchers are closed + + @test_category connection + """ + from cassandra.io.libevreactor import _global_loop + with patch.object(_global_loop, "_thread"),\ + patch.object(_global_loop, "notify"): + + self.make_connection() + + # We have to make a copy because the connections shouldn't + # be alive when we verify them + live_connections = set(_global_loop._live_conns) + + # This simulates the process ending without cluster.shutdown() + # being called, then with atexit _cleanup for libevreactor would + # be called + libev__cleanup(_global_loop) + for conn in live_connections: + self.assertTrue(conn._write_watcher.stop.mock_calls) + self.assertTrue(conn._read_watcher.stop.mock_calls) + + _global_loop._shutdown = False + + +class LibevTimerPatcher(unittest.TestCase): + + @classmethod + @noop_if_monkey_patched + def setUpClass(cls): + if LibevConnection is None: + raise unittest.SkipTest('libev does not appear to be installed correctly') + cls.patchers = [ + patch('socket.socket', spec=socket.socket), + patch('cassandra.io.libevwrapper.IO') + ] + for p in cls.patchers: + p.start() + + @classmethod + @noop_if_monkey_patched + def tearDownClass(cls): + for p in cls.patchers: + try: + p.stop() + except: + pass + + +class LibevTimerTest(TimerTestMixin, LibevTimerPatcher): + connection_class = LibevConnection + + @property + def create_timer(self): + return self.connection.create_timer + + @property + def _timers(self): + from cassandra.io.libevreactor import _global_loop + return _global_loop._timers + + def make_connection(self): + c = LibevConnection('1.2.3.4', cql_version='3.0.1') + c._socket_impl = Mock() + c._socket.return_value.send.side_effect = lambda x: len(x) + return c + + def setUp(self): + if is_monkey_patched(): + raise unittest.SkipTest("Can't test libev with monkey patching.") + if LibevConnection is None: + raise 
unittest.SkipTest('libev does not appear to be installed correctly') + + LibevConnection.initialize_reactor() + super(LibevTimerTest, self).setUp() diff --git a/tests/unit/io/test_twistedreactor.py b/tests/unit/io/test_twistedreactor.py new file mode 100644 index 0000000..f0a1d73 --- /dev/null +++ b/tests/unit/io/test_twistedreactor.py @@ -0,0 +1,226 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest +from mock import Mock, patch + +from cassandra.connection import DefaultEndPoint + +try: + from twisted.test import proto_helpers + from twisted.python.failure import Failure + from cassandra.io import twistedreactor + from cassandra.io.twistedreactor import TwistedConnection +except ImportError: + twistedreactor = TwistedConnection = None # NOQA + + +from cassandra.connection import _Frame + +from tests.unit.io.utils import TimerTestMixin + +class TestTwistedTimer(TimerTestMixin, unittest.TestCase): + """ + Simple test class that is used to validate that the TimerManager, and timer + classes function appropriately with the twisted infrastructure + """ + + connection_class = TwistedConnection + + @property + def create_timer(self): + return self.connection.create_timer + + @property + def _timers(self): + return self.connection._loop._timers + + def setUp(self): + if twistedreactor is None: + raise unittest.SkipTest("Twisted libraries not available") + twistedreactor.TwistedConnection.initialize_reactor() + super(TestTwistedTimer, self).setUp() + + +class TestTwistedProtocol(unittest.TestCase): + + def setUp(self): + if twistedreactor is None: + raise unittest.SkipTest("Twisted libraries not available") + twistedreactor.TwistedConnection.initialize_reactor() + self.tr = proto_helpers.StringTransportWithDisconnection() + self.tr.connector = Mock() + self.mock_connection = Mock() + self.tr.connector.factory = twistedreactor.TwistedConnectionClientFactory( + self.mock_connection) + self.obj_ut = twistedreactor.TwistedConnectionProtocol() + self.tr.protocol = self.obj_ut + + def tearDown(self): + pass + + def test_makeConnection(self): + """ + Verify that the protocol class notifies the connection + object that a successful connection was made. + """ + self.obj_ut.makeConnection(self.tr) + self.assertTrue(self.mock_connection.client_connection_made.called) + + def test_receiving_data(self): + """ + Verify that the dataReceived() callback writes the data to + the connection object's buffer and calls handle_read(). 
+ """ + self.obj_ut.makeConnection(self.tr) + self.obj_ut.dataReceived('foobar') + self.assertTrue(self.mock_connection.handle_read.called) + self.mock_connection._iobuf.write.assert_called_with("foobar") + + +class TestTwistedClientFactory(unittest.TestCase): + def setUp(self): + if twistedreactor is None: + raise unittest.SkipTest("Twisted libraries not available") + twistedreactor.TwistedConnection.initialize_reactor() + self.mock_connection = Mock() + self.obj_ut = twistedreactor.TwistedConnectionClientFactory( + self.mock_connection) + + def test_client_connection_failed(self): + """ + Verify that connection failed causes the connection object to close. + """ + exc = Exception('a test') + self.obj_ut.clientConnectionFailed(None, Failure(exc)) + self.mock_connection.defunct.assert_called_with(exc) + + def test_client_connection_lost(self): + """ + Verify that connection lost causes the connection object to close. + """ + exc = Exception('a test') + self.obj_ut.clientConnectionLost(None, Failure(exc)) + self.mock_connection.defunct.assert_called_with(exc) + + +class TestTwistedConnection(unittest.TestCase): + def setUp(self): + if twistedreactor is None: + raise unittest.SkipTest("Twisted libraries not available") + twistedreactor.TwistedConnection.initialize_reactor() + self.reactor_cft_patcher = patch( + 'twisted.internet.reactor.callFromThread') + self.reactor_run_patcher = patch('twisted.internet.reactor.run') + self.mock_reactor_cft = self.reactor_cft_patcher.start() + self.mock_reactor_run = self.reactor_run_patcher.start() + self.obj_ut = twistedreactor.TwistedConnection(DefaultEndPoint('1.2.3.4'), + cql_version='3.0.1') + + def tearDown(self): + self.reactor_cft_patcher.stop() + self.reactor_run_patcher.stop() + + def test_connection_initialization(self): + """ + Verify that __init__() works correctly. + """ + self.mock_reactor_cft.assert_called_with(self.obj_ut.add_connection) + self.obj_ut._loop._cleanup() + self.mock_reactor_run.assert_called_with(installSignalHandlers=False) + + @patch('twisted.internet.reactor.connectTCP') + def test_add_connection(self, mock_connectTCP): + """ + Verify that add_connection() gives us a valid twisted connector. + """ + self.obj_ut.add_connection() + self.assertTrue(self.obj_ut.connector is not None) + self.assertTrue(mock_connectTCP.called) + + def test_client_connection_made(self): + """ + Verifiy that _send_options_message() is called in + client_connection_made() + """ + self.obj_ut._send_options_message = Mock() + self.obj_ut.client_connection_made(Mock()) + self.obj_ut._send_options_message.assert_called_with() + + @patch('twisted.internet.reactor.connectTCP') + def test_close(self, mock_connectTCP): + """ + Verify that close() disconnects the connector and errors callbacks. + """ + self.obj_ut.error_all_requests = Mock() + self.obj_ut.add_connection() + self.obj_ut.is_closed = False + self.obj_ut.close() + + self.assertTrue(self.obj_ut.connected_event.is_set()) + self.assertTrue(self.obj_ut.error_all_requests.called) + + def test_handle_read__incomplete(self): + """ + Verify that handle_read() processes incomplete messages properly. 
+ """ + self.obj_ut.process_msg = Mock() + self.assertEqual(self.obj_ut._iobuf.getvalue(), b'') # buf starts empty + # incomplete header + self.obj_ut._iobuf.write(b'\x84\x00\x00\x00\x00') + self.obj_ut.handle_read() + self.assertEqual(self.obj_ut._iobuf.getvalue(), b'\x84\x00\x00\x00\x00') + + # full header, but incomplete body + self.obj_ut._iobuf.write(b'\x00\x00\x00\x15') + self.obj_ut.handle_read() + self.assertEqual(self.obj_ut._iobuf.getvalue(), + b'\x84\x00\x00\x00\x00\x00\x00\x00\x15') + self.assertEqual(self.obj_ut._current_frame.end_pos, 30) + + # verify we never attempted to process the incomplete message + self.assertFalse(self.obj_ut.process_msg.called) + + def test_handle_read__fullmessage(self): + """ + Verify that handle_read() processes complete messages properly. + """ + self.obj_ut.process_msg = Mock() + self.assertEqual(self.obj_ut._iobuf.getvalue(), b'') # buf starts empty + + # write a complete message, plus 'NEXT' (to simulate next message) + # assumes protocol v3+ as default Connection.protocol_version + body = b'this is the drum roll' + extra = b'NEXT' + self.obj_ut._iobuf.write( + b'\x84\x01\x00\x02\x03\x00\x00\x00\x15' + body + extra) + self.obj_ut.handle_read() + self.assertEqual(self.obj_ut._iobuf.getvalue(), extra) + self.obj_ut.process_msg.assert_called_with( + _Frame(version=4, flags=1, stream=2, opcode=3, body_offset=9, end_pos=9 + len(body)), body) + + @patch('twisted.internet.reactor.connectTCP') + def test_push(self, mock_connectTCP): + """ + Verifiy that push() calls transport.write(data). + """ + self.obj_ut.add_connection() + transport_mock = Mock() + self.obj_ut.transport = transport_mock + self.obj_ut.push('123 pickup') + self.mock_reactor_cft.assert_called_with( + transport_mock.write, '123 pickup') diff --git a/tests/unit/io/utils.py b/tests/unit/io/utils.py new file mode 100644 index 0000000..2856b9d --- /dev/null +++ b/tests/unit/io/utils.py @@ -0,0 +1,516 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from cassandra.connection import (ConnectionException, ProtocolError, + HEADER_DIRECTION_TO_CLIENT) +from cassandra.marshal import int32_pack, uint8_pack, uint32_pack +from cassandra.protocol import (write_stringmultimap, write_int, write_string, + SupportedMessage, ReadyMessage, ServerError) +from cassandra.connection import DefaultEndPoint + +from tests import is_monkey_patched + +import io +import random +from functools import wraps +from itertools import cycle +import six +from six import binary_type, BytesIO +from mock import Mock + +import errno +import logging +import math +import os +from socket import error as socket_error +import ssl + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import time + + +log = logging.getLogger(__name__) + + +class TimerCallback(object): + + invoked = False + created_time = 0 + invoked_time = 0 + expected_wait = 0 + + def __init__(self, expected_wait): + self.invoked = False + self.created_time = time.time() + self.expected_wait = expected_wait + + def invoke(self): + self.invoked_time = time.time() + self.invoked = True + + def was_invoked(self): + return self.invoked + + def get_wait_time(self): + elapsed_time = self.invoked_time - self.created_time + return elapsed_time + + def wait_match_excepted(self): + if self.expected_wait - .01 <= self.get_wait_time() <= self.expected_wait + .01: + return True + return False + + +def get_timeout(gross_time, start, end, precision, split_range): + """ + A way to generate varying timeouts based on ranges + :param gross_time: Some integer between start and end + :param start: the start value of the range + :param end: the end value of the range + :param precision: the precision to use to generate the timeout. + :param split_range: generate values from both ends + :return: a timeout value to use + """ + if split_range: + top_num = float(end) / precision + bottom_num = float(start) / precision + if gross_time % 2 == 0: + timeout = top_num - float(gross_time) / precision + else: + timeout = bottom_num + float(gross_time) / precision + + else: + timeout = float(gross_time) / precision + + return timeout + + +def submit_and_wait_for_completion(unit_test, create_timer, start, end, increment, precision, split_range=False): + """ + This will submit a number of timers to the provided connection. It will then ensure that the corresponding + callback is invoked in the appropriate amount of time. + :param unit_test: Invoking unit tests + :param connection: Connection to create the timer on. + :param start: Lower bound of range. + :param end: Upper bound of the time range + :param increment: +1, or -1 + :param precision: 100 for centisecond, 1000 for milliseconds + :param split_range: True to split the range between incrementing and decrementing. 
+    """
+
+    # Various lists for tracking callbacks as completed or pending
+    pending_callbacks = []
+    completed_callbacks = []
+
+    # submit timers with various timeouts
+    for gross_time in range(start, end, increment):
+        timeout = get_timeout(gross_time, start, end, precision, split_range)
+        callback = TimerCallback(timeout)
+        create_timer(timeout, callback.invoke)
+        pending_callbacks.append(callback)
+
+    # wait for all the callbacks associated with the timers to be invoked
+    while len(pending_callbacks) != 0:
+        for callback in pending_callbacks:
+            if callback.was_invoked():
+                pending_callbacks.remove(callback)
+                completed_callbacks.append(callback)
+        time.sleep(.1)
+
+    # ensure they are all called back in a timely fashion
+    for callback in completed_callbacks:
+        unit_test.assertAlmostEqual(callback.expected_wait, callback.get_wait_time(), delta=.15)
+
+
+def noop_if_monkey_patched(f):
+    if is_monkey_patched():
+        @wraps(f)
+        def noop(*args, **kwargs):
+            return
+        return noop
+
+    return f
+
+
+class TimerTestMixin(object):
+
+    connection_class = connection = None
+    # replace with property returning the connection's create_timer and _timers
+    create_timer = _timers = None
+
+    def setUp(self):
+        self.connection = self.connection_class(
+            DefaultEndPoint("127.0.0.1"),
+            connect_timeout=5
+        )
+
+    def tearDown(self):
+        self.connection.close()
+
+    def test_multi_timer_validation(self):
+        """
+        Verify that timer timeouts are honored appropriately
+        """
+        # Tests timers submitted in order at various timeouts
+        submit_and_wait_for_completion(self, self.create_timer, 0, 100, 1, 100)
+        # Tests timers submitted in reverse order at various timeouts
+        submit_and_wait_for_completion(self, self.create_timer, 100, 0, -1, 100)
+        # Tests timers submitted in varying order at various timeouts
+        submit_and_wait_for_completion(self, self.create_timer, 0, 100, 1, 100, True)
+
+    def test_timer_cancellation(self):
+        """
+        Verify that timer cancellation is honored
+        """
+
+        # Various lists for tracking callback stage
+        timeout = .1
+        callback = TimerCallback(timeout)
+        timer = self.create_timer(timeout, callback.invoke)
+        timer.cancel()
+        # Release context to allow the timer thread to run.
+ time.sleep(.2) + timer_manager = self._timers + # Assert that the cancellation was honored + self.assertFalse(timer_manager._queue) + self.assertFalse(timer_manager._new_timers) + self.assertFalse(callback.was_invoked()) + + +class ReactorTestMixin(object): + + connection_class = socket_attr_name = None + null_handle_function_args = () + + def get_socket(self, connection): + return getattr(connection, self.socket_attr_name) + + def set_socket(self, connection, obj): + return setattr(connection, self.socket_attr_name, obj) + + def make_header_prefix(self, message_class, version=2, stream_id=0): + return binary_type().join(map(uint8_pack, [ + 0xff & (HEADER_DIRECTION_TO_CLIENT | version), + 0, # flags (compression) + stream_id, + message_class.opcode # opcode + ])) + + def make_connection(self): + c = self.connection_class(DefaultEndPoint('1.2.3.4'), cql_version='3.0.1', connect_timeout=5) + mocket = Mock() + mocket.send.side_effect = lambda x: len(x) + self.set_socket(c, mocket) + return c + + def make_options_body(self): + options_buf = BytesIO() + write_stringmultimap(options_buf, { + 'CQL_VERSION': ['3.0.1'], + 'COMPRESSION': [] + }) + return options_buf.getvalue() + + def make_error_body(self, code, msg): + buf = BytesIO() + write_int(buf, code) + write_string(buf, msg) + return buf.getvalue() + + def make_msg(self, header, body=binary_type()): + return header + uint32_pack(len(body)) + body + + def test_successful_connection(self): + c = self.make_connection() + + # let it write the OptionsMessage + c.handle_write(*self.null_handle_function_args) + + # read in a SupportedMessage response + header = self.make_header_prefix(SupportedMessage) + options = self.make_options_body() + self.get_socket(c).recv.return_value = self.make_msg(header, options) + c.handle_read(*self.null_handle_function_args) + + # let it write out a StartupMessage + c.handle_write(*self.null_handle_function_args) + + header = self.make_header_prefix(ReadyMessage, stream_id=1) + self.get_socket(c).recv.return_value = self.make_msg(header) + c.handle_read(*self.null_handle_function_args) + + self.assertTrue(c.connected_event.is_set()) + return c + + def test_eagain_on_buffer_size(self): + self._check_error_recovery_on_buffer_size(errno.EAGAIN) + + def test_ewouldblock_on_buffer_size(self): + self._check_error_recovery_on_buffer_size(errno.EWOULDBLOCK) + + def test_sslwantread_on_buffer_size(self): + self._check_error_recovery_on_buffer_size( + ssl.SSL_ERROR_WANT_READ, + error_class=ssl.SSLError) + + def test_sslwantwrite_on_buffer_size(self): + self._check_error_recovery_on_buffer_size( + ssl.SSL_ERROR_WANT_WRITE, + error_class=ssl.SSLError) + + def _check_error_recovery_on_buffer_size(self, error_code, error_class=socket_error): + c = self.test_successful_connection() + + # current data, used by the recv side_effect + message_chunks = None + + def recv_side_effect(*args): + response = message_chunks.pop(0) + if isinstance(response, error_class): + raise response + else: + return response + + # setup + self.get_socket(c).recv.side_effect = recv_side_effect + c.process_io_buffer = Mock() + + def chunk(size): + return six.b('a') * size + + buf_size = c.in_buffer_size + + # List of messages to test. 
A message = (chunks, expected_read_size) + messages = [ + ([chunk(200)], 200), + ([chunk(200), chunk(200)], 200), # first chunk < in_buffer_size, process the message + ([chunk(buf_size), error_class(error_code)], buf_size), + ([chunk(buf_size), chunk(buf_size), error_class(error_code)], buf_size*2), + ([chunk(buf_size), chunk(buf_size), chunk(10)], (buf_size*2) + 10), + ([chunk(buf_size), chunk(buf_size), error_class(error_code), chunk(10)], buf_size*2), + ([error_class(error_code), chunk(buf_size)], 0) + ] + + for message, expected_size in messages: + message_chunks = message + c._iobuf = io.BytesIO() + c.process_io_buffer.reset_mock() + c.handle_read(*self.null_handle_function_args) + c._iobuf.seek(0, os.SEEK_END) + + # Ensure the message size is the good one and that the + # message has been processed if it is non-empty + self.assertEqual(c._iobuf.tell(), expected_size) + if expected_size == 0: + c.process_io_buffer.assert_not_called() + else: + c.process_io_buffer.assert_called_once_with() + + def test_protocol_error(self): + c = self.make_connection() + + # let it write the OptionsMessage + c.handle_write(*self.null_handle_function_args) + + # read in a SupportedMessage response + header = self.make_header_prefix(SupportedMessage, version=0xa4) + options = self.make_options_body() + self.get_socket(c).recv.return_value = self.make_msg(header, options) + c.handle_read(*self.null_handle_function_args) + + # make sure it errored correctly + self.assertTrue(c.is_defunct) + self.assertTrue(c.connected_event.is_set()) + self.assertIsInstance(c.last_error, ProtocolError) + + def test_error_message_on_startup(self): + c = self.make_connection() + + # let it write the OptionsMessage + c.handle_write(*self.null_handle_function_args) + + # read in a SupportedMessage response + header = self.make_header_prefix(SupportedMessage) + options = self.make_options_body() + self.get_socket(c).recv.return_value = self.make_msg(header, options) + c.handle_read(*self.null_handle_function_args) + + # let it write out a StartupMessage + c.handle_write(*self.null_handle_function_args) + + header = self.make_header_prefix(ServerError, stream_id=1) + body = self.make_error_body(ServerError.error_code, ServerError.summary) + self.get_socket(c).recv.return_value = self.make_msg(header, body) + c.handle_read(*self.null_handle_function_args) + + # make sure it errored correctly + self.assertTrue(c.is_defunct) + self.assertIsInstance(c.last_error, ConnectionException) + self.assertTrue(c.connected_event.is_set()) + + def test_socket_error_on_write(self): + c = self.make_connection() + + # make the OptionsMessage write fail + self.get_socket(c).send.side_effect = socket_error(errno.EIO, "bad stuff!") + c.handle_write(*self.null_handle_function_args) + + # make sure it errored correctly + self.assertTrue(c.is_defunct) + self.assertIsInstance(c.last_error, socket_error) + self.assertTrue(c.connected_event.is_set()) + + def test_blocking_on_write(self): + c = self.make_connection() + + # make the OptionsMessage write block + self.get_socket(c).send.side_effect = socket_error(errno.EAGAIN, + "socket busy") + c.handle_write(*self.null_handle_function_args) + + self.assertFalse(c.is_defunct) + + # try again with normal behavior + self.get_socket(c).send.side_effect = lambda x: len(x) + c.handle_write(*self.null_handle_function_args) + self.assertFalse(c.is_defunct) + self.assertTrue(self.get_socket(c).send.call_args is not None) + + def test_partial_send(self): + c = self.make_connection() + + # only write the first four 
bytes of the OptionsMessage + write_size = 4 + self.get_socket(c).send.side_effect = None + self.get_socket(c).send.return_value = write_size + c.handle_write(*self.null_handle_function_args) + + msg_size = 9 # v3+ frame header + expected_writes = int(math.ceil(float(msg_size) / write_size)) + size_mod = msg_size % write_size + last_write_size = size_mod if size_mod else write_size + self.assertFalse(c.is_defunct) + self.assertEqual(expected_writes, self.get_socket(c).send.call_count) + self.assertEqual(last_write_size, + len(self.get_socket(c).send.call_args[0][0])) + + def test_socket_error_on_read(self): + c = self.make_connection() + + # let it write the OptionsMessage + c.handle_write(*self.null_handle_function_args) + + # read in a SupportedMessage response + self.get_socket(c).recv.side_effect = socket_error(errno.EIO, + "busy socket") + c.handle_read(*self.null_handle_function_args) + + # make sure it errored correctly + self.assertTrue(c.is_defunct) + self.assertIsInstance(c.last_error, socket_error) + self.assertTrue(c.connected_event.is_set()) + + def test_partial_header_read(self): + c = self.make_connection() + + header = self.make_header_prefix(SupportedMessage) + options = self.make_options_body() + message = self.make_msg(header, options) + + self.get_socket(c).recv.return_value = message[0:1] + c.handle_read(*self.null_handle_function_args) + self.assertEqual(c._iobuf.getvalue(), message[0:1]) + + self.get_socket(c).recv.return_value = message[1:] + c.handle_read(*self.null_handle_function_args) + self.assertEqual(six.binary_type(), c._iobuf.getvalue()) + + # let it write out a StartupMessage + c.handle_write(*self.null_handle_function_args) + + header = self.make_header_prefix(ReadyMessage, stream_id=1) + self.get_socket(c).recv.return_value = self.make_msg(header) + c.handle_read(*self.null_handle_function_args) + + self.assertTrue(c.connected_event.is_set()) + self.assertFalse(c.is_defunct) + + def test_partial_message_read(self): + c = self.make_connection() + + header = self.make_header_prefix(SupportedMessage) + options = self.make_options_body() + message = self.make_msg(header, options) + + # read in the first nine bytes + self.get_socket(c).recv.return_value = message[:9] + c.handle_read(*self.null_handle_function_args) + self.assertEqual(c._iobuf.getvalue(), message[:9]) + + # ... 
then read in the rest + self.get_socket(c).recv.return_value = message[9:] + c.handle_read(*self.null_handle_function_args) + self.assertEqual(six.binary_type(), c._iobuf.getvalue()) + + # let it write out a StartupMessage + c.handle_write(*self.null_handle_function_args) + + header = self.make_header_prefix(ReadyMessage, stream_id=1) + self.get_socket(c).recv.return_value = self.make_msg(header) + c.handle_read(*self.null_handle_function_args) + + self.assertTrue(c.connected_event.is_set()) + self.assertFalse(c.is_defunct) + + def test_mixed_message_and_buffer_sizes(self): + """ + Validate that all messages are processed with different scenarios: + + - various message sizes + - various socket buffer sizes + - random non-fatal errors raised + """ + c = self.make_connection() + c.process_io_buffer = Mock() + + errors = cycle([ + ssl.SSLError(ssl.SSL_ERROR_WANT_READ), + ssl.SSLError(ssl.SSL_ERROR_WANT_WRITE), + socket_error(errno.EWOULDBLOCK), + socket_error(errno.EAGAIN) + ]) + + for buffer_size in [512, 1024, 2048, 4096, 8192]: + c.in_buffer_size = buffer_size + + for i in range(1, 15): + c.process_io_buffer.reset_mock() + c._iobuf = io.BytesIO() + message = io.BytesIO(six.b('a') * (2**i)) + + def recv_side_effect(*args): + if random.randint(1,10) % 3 == 0: + raise next(errors) + return message.read(args[0]) + + self.get_socket(c).recv.side_effect = recv_side_effect + c.handle_read(*self.null_handle_function_args) + if c._iobuf.tell(): + c.process_io_buffer.assert_called_once() + else: + c.process_io_buffer.assert_not_called() diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py new file mode 100644 index 0000000..fd46731 --- /dev/null +++ b/tests/unit/test_cluster.py @@ -0,0 +1,518 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
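For orientation before the new ``tests/unit/test_cluster.py`` suite: it exercises the execution profile API alongside the legacy per-``Session`` settings. A minimal usage sketch of that API follows; the contact point, profile name, and query are illustrative only::

    from cassandra import ConsistencyLevel
    from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
    from cassandra.policies import RoundRobinPolicy, DowngradingConsistencyRetryPolicy
    from cassandra.query import tuple_factory

    # An ExecutionProfile bundles the per-request settings verified below:
    # load balancing, retry policy, consistency levels, timeout and row factory.
    analytics = ExecutionProfile(
        load_balancing_policy=RoundRobinPolicy(),
        retry_policy=DowngradingConsistencyRetryPolicy(),
        consistency_level=ConsistencyLevel.LOCAL_QUORUM,
        request_timeout=60,
        row_factory=tuple_factory)

    # Giving the default profile an explicit load-balancing policy avoids the
    # PYTHON-812 warning (tested below) when contact points are supplied.
    cluster = Cluster(
        contact_points=['127.0.0.1'],
        execution_profiles={
            EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=RoundRobinPolicy()),
            'analytics': analytics})
    session = cluster.connect()
    rows = session.execute("SELECT release_version FROM system.local",
                           execution_profile='analytics')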
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import logging +import six + +from mock import patch, Mock + +from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ + InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException +from cassandra.cluster import _Scheduler, Session, Cluster, _NOT_SET, default_lbp_factory, \ + ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT, NoHostAvailable +from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, \ + DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy +from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory +from cassandra.pool import Host +from tests.unit.utils import mock_session_pools +from tests import connection_class + + +log = logging.getLogger(__name__) + +class ExceptionTypeTest(unittest.TestCase): + + def test_exception_types(self): + """ + PYTHON-443 + Sanity check to ensure we don't unintentionally change class hierarchy of exception types + """ + self.assertTrue(issubclass(Unavailable, DriverException)) + self.assertTrue(issubclass(Unavailable, RequestExecutionException)) + + self.assertTrue(issubclass(ReadTimeout, DriverException)) + self.assertTrue(issubclass(ReadTimeout, RequestExecutionException)) + self.assertTrue(issubclass(ReadTimeout, Timeout)) + + self.assertTrue(issubclass(WriteTimeout, DriverException)) + self.assertTrue(issubclass(WriteTimeout, RequestExecutionException)) + self.assertTrue(issubclass(WriteTimeout, Timeout)) + + self.assertTrue(issubclass(CoordinationFailure, DriverException)) + self.assertTrue(issubclass(CoordinationFailure, RequestExecutionException)) + + self.assertTrue(issubclass(ReadFailure, DriverException)) + self.assertTrue(issubclass(ReadFailure, RequestExecutionException)) + self.assertTrue(issubclass(ReadFailure, CoordinationFailure)) + + self.assertTrue(issubclass(WriteFailure, DriverException)) + self.assertTrue(issubclass(WriteFailure, RequestExecutionException)) + self.assertTrue(issubclass(WriteFailure, CoordinationFailure)) + + self.assertTrue(issubclass(FunctionFailure, DriverException)) + self.assertTrue(issubclass(FunctionFailure, RequestExecutionException)) + + self.assertTrue(issubclass(RequestValidationException, DriverException)) + + self.assertTrue(issubclass(ConfigurationException, DriverException)) + self.assertTrue(issubclass(ConfigurationException, RequestValidationException)) + + self.assertTrue(issubclass(AlreadyExists, DriverException)) + self.assertTrue(issubclass(AlreadyExists, RequestValidationException)) + self.assertTrue(issubclass(AlreadyExists, ConfigurationException)) + + self.assertTrue(issubclass(InvalidRequest, DriverException)) + self.assertTrue(issubclass(InvalidRequest, RequestValidationException)) + + self.assertTrue(issubclass(Unauthorized, DriverException)) + self.assertTrue(issubclass(Unauthorized, RequestValidationException)) + + self.assertTrue(issubclass(AuthenticationFailed, DriverException)) + + self.assertTrue(issubclass(OperationTimedOut, DriverException)) + + self.assertTrue(issubclass(UnsupportedOperation, DriverException)) + + +class ClusterTest(unittest.TestCase): + + def test_invalid_contact_point_types(self): + with self.assertRaises(ValueError): + Cluster(contact_points=[None], protocol_version=4, 
connect_timeout=1) + with self.assertRaises(TypeError): + Cluster(contact_points="not a sequence", protocol_version=4, connect_timeout=1) + + def test_requests_in_flight_threshold(self): + d = HostDistance.LOCAL + mn = 3 + mx = 5 + c = Cluster(protocol_version=2) + c.set_min_requests_per_connection(d, mn) + c.set_max_requests_per_connection(d, mx) + # min underflow, max, overflow + for n in (-1, mx, 127): + self.assertRaises(ValueError, c.set_min_requests_per_connection, d, n) + # max underflow, under min, overflow + for n in (0, mn, 128): + self.assertRaises(ValueError, c.set_max_requests_per_connection, d, n) + + +class SchedulerTest(unittest.TestCase): + # TODO: this suite could be expanded; for now just adding a test covering a ticket + + @patch('time.time', return_value=3) # always queue at same time + @patch('cassandra.cluster._Scheduler.run') # don't actually run the thread + def test_event_delay_timing(self, *_): + """ + Schedule something with a time collision to make sure the heap comparison works + + PYTHON-473 + """ + sched = _Scheduler(None) + sched.schedule(0, lambda: None) + sched.schedule(0, lambda: None) # pre-473: "TypeError: unorderable types: function() < function()"t + + +class SessionTest(unittest.TestCase): + def setUp(self): + if connection_class is None: + raise unittest.SkipTest('libev does not appear to be installed correctly') + connection_class.initialize_reactor() + + # TODO: this suite could be expanded; for now just adding a test covering a PR + @mock_session_pools + def test_default_serial_consistency_level(self, *_): + """ + Make sure default_serial_consistency_level passes through to a query message. + Also make sure Statement.serial_consistency_level overrides the default. + + PR #510 + """ + s = Session(Cluster(protocol_version=4), [Host("127.0.0.1", SimpleConvictionPolicy)]) + + # default is None + self.assertIsNone(s.default_serial_consistency_level) + + # Should fail + with self.assertRaises(ValueError): + s.default_serial_consistency_level = ConsistencyLevel.ANY + with self.assertRaises(ValueError): + s.default_serial_consistency_level = 1001 + + for cl in (None, ConsistencyLevel.LOCAL_SERIAL, ConsistencyLevel.SERIAL): + s.default_serial_consistency_level = cl + + # default is passed through + f = s.execute_async(query='') + self.assertEqual(f.message.serial_consistency_level, cl) + + # any non-None statement setting takes precedence + for cl_override in (ConsistencyLevel.LOCAL_SERIAL, ConsistencyLevel.SERIAL): + f = s.execute_async(SimpleStatement(query_string='', serial_consistency_level=cl_override)) + self.assertEqual(s.default_serial_consistency_level, cl) + self.assertEqual(f.message.serial_consistency_level, cl_override) + + +class ExecutionProfileTest(unittest.TestCase): + def setUp(self): + if connection_class is None: + raise unittest.SkipTest('libev does not appear to be installed correctly') + connection_class.initialize_reactor() + + def _verify_response_future_profile(self, rf, prof): + self.assertEqual(rf._load_balancer, prof.load_balancing_policy) + self.assertEqual(rf._retry_policy, prof.retry_policy) + self.assertEqual(rf.message.consistency_level, prof.consistency_level) + self.assertEqual(rf.message.serial_consistency_level, prof.serial_consistency_level) + self.assertEqual(rf.timeout, prof.request_timeout) + self.assertEqual(rf.row_factory, prof.row_factory) + + @mock_session_pools + def test_default_exec_parameters(self): + cluster = Cluster() + self.assertEqual(cluster._config_mode, _ConfigMode.UNCOMMITTED) + 
self.assertEqual(cluster.load_balancing_policy.__class__, default_lbp_factory().__class__) + self.assertEqual(cluster.default_retry_policy.__class__, RetryPolicy) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + self.assertEqual(session.default_timeout, 10.0) + self.assertEqual(session.default_consistency_level, ConsistencyLevel.LOCAL_ONE) + self.assertEqual(session.default_serial_consistency_level, None) + self.assertEqual(session.row_factory, named_tuple_factory) + + @mock_session_pools + def test_default_legacy(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), default_retry_policy=DowngradingConsistencyRetryPolicy()) + self.assertEqual(cluster._config_mode, _ConfigMode.LEGACY) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session.default_timeout = 3.7 + session.default_consistency_level = ConsistencyLevel.ALL + session.default_serial_consistency_level = ConsistencyLevel.SERIAL + rf = session.execute_async("query") + expected_profile = ExecutionProfile(cluster.load_balancing_policy, cluster.default_retry_policy, + session.default_consistency_level, session.default_serial_consistency_level, + session.default_timeout, session.row_factory) + self._verify_response_future_profile(rf, expected_profile) + + @mock_session_pools + def test_default_profile(self): + non_default_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) + cluster = Cluster(execution_profiles={'non-default': non_default_profile}) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + + self.assertEqual(cluster._config_mode, _ConfigMode.PROFILES) + + default_profile = session.get_execution_profile(EXEC_PROFILE_DEFAULT) + rf = session.execute_async("query") + self._verify_response_future_profile(rf, default_profile) + + rf = session.execute_async("query", execution_profile='non-default') + self._verify_response_future_profile(rf, non_default_profile) + + for name, ep in six.iteritems(cluster.profile_manager.profiles): + self.assertEqual(ep, session.get_execution_profile(name)) + + # invalid ep + with self.assertRaises(ValueError): + session.get_execution_profile('non-existent') + + def test_serial_consistency_level_validation(self): + # should pass + ep = ExecutionProfile(RoundRobinPolicy(), serial_consistency_level=ConsistencyLevel.SERIAL) + ep = ExecutionProfile(RoundRobinPolicy(), serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL) + + # should not pass + with self.assertRaises(ValueError): + ep = ExecutionProfile(RoundRobinPolicy(), serial_consistency_level=ConsistencyLevel.ANY) + with self.assertRaises(ValueError): + ep = ExecutionProfile(RoundRobinPolicy(), serial_consistency_level=42) + + @mock_session_pools + def test_statement_params_override_legacy(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), default_retry_policy=DowngradingConsistencyRetryPolicy()) + self.assertEqual(cluster._config_mode, _ConfigMode.LEGACY) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + + ss = SimpleStatement("query", retry_policy=DowngradingConsistencyRetryPolicy(), + consistency_level=ConsistencyLevel.ALL, serial_consistency_level=ConsistencyLevel.SERIAL) + my_timeout = 1.1234 + + self.assertNotEqual(ss.retry_policy.__class__, cluster.default_retry_policy) + self.assertNotEqual(ss.consistency_level, session.default_consistency_level) + self.assertNotEqual(ss._serial_consistency_level, session.default_serial_consistency_level) + 
self.assertNotEqual(my_timeout, session.default_timeout) + + rf = session.execute_async(ss, timeout=my_timeout) + expected_profile = ExecutionProfile(load_balancing_policy=cluster.load_balancing_policy, retry_policy=ss.retry_policy, + request_timeout=my_timeout, consistency_level=ss.consistency_level, + serial_consistency_level=ss._serial_consistency_level) + self._verify_response_future_profile(rf, expected_profile) + + @mock_session_pools + def test_statement_params_override_profile(self): + non_default_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) + cluster = Cluster(execution_profiles={'non-default': non_default_profile}) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + + self.assertEqual(cluster._config_mode, _ConfigMode.PROFILES) + + rf = session.execute_async("query", execution_profile='non-default') + + ss = SimpleStatement("query", retry_policy=DowngradingConsistencyRetryPolicy(), + consistency_level=ConsistencyLevel.ALL, serial_consistency_level=ConsistencyLevel.SERIAL) + my_timeout = 1.1234 + + self.assertNotEqual(ss.retry_policy.__class__, rf._load_balancer.__class__) + self.assertNotEqual(ss.consistency_level, rf.message.consistency_level) + self.assertNotEqual(ss._serial_consistency_level, rf.message.serial_consistency_level) + self.assertNotEqual(my_timeout, rf.timeout) + + rf = session.execute_async(ss, timeout=my_timeout, execution_profile='non-default') + expected_profile = ExecutionProfile(non_default_profile.load_balancing_policy, ss.retry_policy, + ss.consistency_level, ss._serial_consistency_level, my_timeout, non_default_profile.row_factory) + self._verify_response_future_profile(rf, expected_profile) + + @mock_session_pools + def test_no_profile_with_legacy(self): + # don't construct with both + self.assertRaises(ValueError, Cluster, load_balancing_policy=RoundRobinPolicy(), execution_profiles={'a': ExecutionProfile()}) + self.assertRaises(ValueError, Cluster, default_retry_policy=DowngradingConsistencyRetryPolicy(), execution_profiles={'a': ExecutionProfile()}) + self.assertRaises(ValueError, Cluster, load_balancing_policy=RoundRobinPolicy(), + default_retry_policy=DowngradingConsistencyRetryPolicy(), execution_profiles={'a': ExecutionProfile()}) + + # can't add after + cluster = Cluster(load_balancing_policy=RoundRobinPolicy()) + self.assertRaises(ValueError, cluster.add_execution_profile, 'name', ExecutionProfile()) + + # session settings lock out profiles + cluster = Cluster() + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + for attr, value in (('default_timeout', 1), + ('default_consistency_level', ConsistencyLevel.ANY), + ('default_serial_consistency_level', ConsistencyLevel.SERIAL), + ('row_factory', tuple_factory)): + cluster._config_mode = _ConfigMode.UNCOMMITTED + setattr(session, attr, value) + self.assertRaises(ValueError, cluster.add_execution_profile, 'name' + attr, ExecutionProfile()) + + # don't accept profile + self.assertRaises(ValueError, session.execute_async, "query", execution_profile='some name here') + + @mock_session_pools + def test_no_legacy_with_profile(self): + cluster_init = Cluster(execution_profiles={'name': ExecutionProfile()}) + cluster_add = Cluster() + cluster_add.add_execution_profile('name', ExecutionProfile()) + # for clusters with profiles added either way... 
+        for cluster in (cluster_init, cluster_add):
+            # don't allow legacy parameters set
+            for attr, value in (('default_retry_policy', RetryPolicy()),
+                                ('load_balancing_policy', default_lbp_factory())):
+                self.assertRaises(ValueError, setattr, cluster, attr, value)
+            session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)])
+            for attr, value in (('default_timeout', 1),
+                                ('default_consistency_level', ConsistencyLevel.ANY),
+                                ('default_serial_consistency_level', ConsistencyLevel.SERIAL),
+                                ('row_factory', tuple_factory)):
+                self.assertRaises(ValueError, setattr, session, attr, value)
+
+    @mock_session_pools
+    def test_profile_name_value(self):
+
+        internalized_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)])
+        cluster = Cluster(execution_profiles={'by-name': internalized_profile})
+        session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)])
+        self.assertEqual(cluster._config_mode, _ConfigMode.PROFILES)
+
+        rf = session.execute_async("query", execution_profile='by-name')
+        self._verify_response_future_profile(rf, internalized_profile)
+
+        by_value = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)])
+        rf = session.execute_async("query", execution_profile=by_value)
+        self._verify_response_future_profile(rf, by_value)
+
+    @mock_session_pools
+    def test_exec_profile_clone(self):
+
+        cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(), 'one': ExecutionProfile()})
+        session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)])
+
+        profile_attrs = {'request_timeout': 1,
+                         'consistency_level': ConsistencyLevel.ANY,
+                         'serial_consistency_level': ConsistencyLevel.SERIAL,
+                         'row_factory': tuple_factory,
+                         'retry_policy': RetryPolicy(),
+                         'load_balancing_policy': default_lbp_factory()}
+        reference_attributes = ('retry_policy', 'load_balancing_policy')
+
+        # default and one named
+        for profile in (EXEC_PROFILE_DEFAULT, 'one'):
+            active = session.get_execution_profile(profile)
+            clone = session.execution_profile_clone_update(profile)
+            self.assertIsNot(clone, active)
+
+            all_updated = session.execution_profile_clone_update(clone, **profile_attrs)
+            self.assertIsNot(all_updated, clone)
+            for attr, value in profile_attrs.items():
+                self.assertEqual(getattr(clone, attr), getattr(active, attr))
+                if attr in reference_attributes:
+                    self.assertIs(getattr(clone, attr), getattr(active, attr))
+                self.assertNotEqual(getattr(all_updated, attr), getattr(active, attr))
+
+        # cannot clone nonexistent profile
+        self.assertRaises(ValueError, session.execution_profile_clone_update, 'DOES NOT EXIST', **profile_attrs)
+
+    def test_no_profiles_same_name(self):
+        # can override default in init
+        cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(), 'one': ExecutionProfile()})
+
+        # cannot update default
+        self.assertRaises(ValueError, cluster.add_execution_profile, EXEC_PROFILE_DEFAULT, ExecutionProfile())
+
+        # cannot update named init
+        self.assertRaises(ValueError, cluster.add_execution_profile, 'one', ExecutionProfile())
+
+        # can add new name
+        cluster.add_execution_profile('two', ExecutionProfile())
+
+        # cannot add a profile added dynamically
+        self.assertRaises(ValueError, cluster.add_execution_profile, 'two', ExecutionProfile())
+
+    def test_warning_on_no_lbp_with_contact_points_legacy_mode(self):
+        """
+        Test that users are warned when they instantiate a Cluster object in
+        legacy mode with contact points but no load-balancing policy.
+ + @since 3.12.0 + @jira_ticket PYTHON-812 + @expected_result logs + + @test_category configuration + """ + self._check_warning_on_no_lbp_with_contact_points( + cluster_kwargs={'contact_points': ['127.0.0.1']} + ) + + def test_warning_on_no_lbp_with_contact_points_profile_mode(self): + """ + Test that users are warned when they instantiate a Cluster object in + execution profile mode with contact points but no load-balancing + policy. + + @since 3.12.0 + @jira_ticket PYTHON-812 + @expected_result logs + + @test_category configuration + """ + self._check_warning_on_no_lbp_with_contact_points(cluster_kwargs={ + 'contact_points': ['127.0.0.1'], + 'execution_profiles': {EXEC_PROFILE_DEFAULT: ExecutionProfile()} + }) + + @mock_session_pools + def _check_warning_on_no_lbp_with_contact_points(self, cluster_kwargs): + with patch('cassandra.cluster.log') as patched_logger: + Cluster(**cluster_kwargs) + patched_logger.warning.assert_called_once() + warning_message = patched_logger.warning.call_args[0][0] + self.assertIn('please specify a load-balancing policy', warning_message) + self.assertIn("contact_points = ['127.0.0.1']", warning_message) + + def test_no_warning_on_contact_points_with_lbp_legacy_mode(self): + """ + Test that users aren't warned when they instantiate a Cluster object + with contact points and a load-balancing policy in legacy mode. + + @since 3.12.0 + @jira_ticket PYTHON-812 + @expected_result no logs + + @test_category configuration + """ + self._check_no_warning_on_contact_points_with_lbp({ + 'contact_points': ['127.0.0.1'], + 'load_balancing_policy': object() + }) + + def test_no_warning_on_contact_points_with_lbp_profiles_mode(self): + """ + Test that users aren't warned when they instantiate a Cluster object + with contact points and a load-balancing policy in execution profile + mode. + + @since 3.12.0 + @jira_ticket PYTHON-812 + @expected_result no logs + + @test_category configuration + """ + ep_with_lbp = ExecutionProfile(load_balancing_policy=object()) + self._check_no_warning_on_contact_points_with_lbp(cluster_kwargs={ + 'contact_points': ['127.0.0.1'], + 'execution_profiles': { + EXEC_PROFILE_DEFAULT: ep_with_lbp + } + }) + + @mock_session_pools + def _check_no_warning_on_contact_points_with_lbp(self, cluster_kwargs): + """ + Test that users aren't warned when they instantiate a Cluster object + with contact points and a load-balancing policy. 
+ + @since 3.12.0 + @jira_ticket PYTHON-812 + @expected_result no logs + + @test_category configuration + """ + with patch('cassandra.cluster.log') as patched_logger: + Cluster(**cluster_kwargs) + patched_logger.warning.assert_not_called() + + @mock_session_pools + def test_warning_adding_no_lbp_ep_to_cluster_with_contact_points(self): + ep_with_lbp = ExecutionProfile(load_balancing_policy=object()) + cluster = Cluster( + contact_points=['127.0.0.1'], + execution_profiles={EXEC_PROFILE_DEFAULT: ep_with_lbp}) + with patch('cassandra.cluster.log') as patched_logger: + cluster.add_execution_profile( + name='no_lbp', + profile=ExecutionProfile() + ) + + patched_logger.warning.assert_called_once() + warning_message = patched_logger.warning.call_args[0][0] + self.assertIn('no_lbp', warning_message) + self.assertIn('trying to add', warning_message) + self.assertIn('please specify a load-balancing policy', warning_message) + + @mock_session_pools + def test_no_warning_adding_lbp_ep_to_cluster_with_contact_points(self): + ep_with_lbp = ExecutionProfile(load_balancing_policy=object()) + cluster = Cluster( + contact_points=['127.0.0.1'], + execution_profiles={EXEC_PROFILE_DEFAULT: ep_with_lbp}) + with patch('cassandra.cluster.log') as patched_logger: + cluster.add_execution_profile( + name='with_lbp', + profile=ExecutionProfile(load_balancing_policy=Mock(name='lbp')) + ) + + patched_logger.warning.assert_not_called() diff --git a/tests/unit/test_concurrent.py b/tests/unit/test_concurrent.py new file mode 100644 index 0000000..cc6c12c --- /dev/null +++ b/tests/unit/test_concurrent.py @@ -0,0 +1,260 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from itertools import cycle +from mock import Mock +import time +import threading +from six.moves.queue import PriorityQueue +import sys +import platform + +from cassandra.cluster import Cluster, Session +from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args +from cassandra.pool import Host +from cassandra.policies import SimpleConvictionPolicy +from tests.unit.utils import mock_session_pools + + +class MockResponseResponseFuture(): + """ + This is a mock ResponseFuture. It is used to allow us to hook into the underlying session + and invoke callback with various timing. + """ + + _query_trace = None + _col_names = None + _col_types = None + + # a list pending callbacks, these will be prioritized in reverse or normal orderd + pending_callbacks = PriorityQueue() + + def __init__(self, reverse): + + # if this is true invoke callback in the reverse order then what they were insert + self.reverse = reverse + # hardcoded to avoid paging logic + self.has_more_pages = False + + if(reverse): + self.priority = 100 + else: + self.priority = 0 + + def add_callback(self, fn, *args, **kwargs): + """ + This is used to add a callback our pending list of callbacks. 
+        If reverse is specified, we will invoke the callbacks in the opposite order from that in which they were added
+        """
+        time_added = time.time()
+        self.pending_callbacks.put((self.priority, (fn, args, kwargs, time_added)))
+        if not self.reverse:
+            self.priority += 1
+        else:
+            self.priority -= 1
+
+    def add_callbacks(self, callback, errback,
+                      callback_args=(), callback_kwargs=None,
+                      errback_args=(), errback_kwargs=None):
+
+        self.add_callback(callback, *callback_args, **(callback_kwargs or {}))
+
+    def get_next_callback(self):
+        return self.pending_callbacks.get()
+
+    def has_next_callback(self):
+        return not self.pending_callbacks.empty()
+
+    def has_more_pages(self):
+        return False
+
+    def clear_callbacks(self):
+        return
+
+
+class TimedCallableInvoker(threading.Thread):
+    """
+    This is a local thread which runs and invokes all the callbacks on the pending callback queue.
+    The slowdown flag can be used to introduce random slowdowns into our simulated queries.
+    """
+    def __init__(self, handler, slowdown=False):
+        super(TimedCallableInvoker, self).__init__()
+        self.slowdown = slowdown
+        self._stopper = threading.Event()
+        self.handler = handler
+
+    def stop(self):
+        self._stopper.set()
+
+    def stopped(self):
+        return self._stopper.isSet()
+
+    def run(self):
+        while(not self.stopped()):
+            if(self.handler.has_next_callback()):
+                pending_callback = self.handler.get_next_callback()
+                priority_num = pending_callback[0]
+                if (priority_num % 10) == 0 and self.slowdown:
+                    self._stopper.wait(.1)
+                callback_args = pending_callback[1]
+                fn, args, kwargs, time_added = callback_args
+                fn([time_added], *args, **kwargs)
+            self._stopper.wait(.001)
+        return
+
+class ConcurrencyTest(unittest.TestCase):
+
+    def test_results_ordering_forward(self):
+        """
+        This tests the ordering of our various concurrent generator class ConcurrentExecutorListResults
+        when queries complete in the order they were executed.
+        """
+        self.insert_and_validate_list_results(False, False)
+
+    def test_results_ordering_reverse(self):
+        """
+        This tests the ordering of our various concurrent generator class ConcurrentExecutorListResults
+        when queries complete in the reverse order they were executed.
+        """
+        self.insert_and_validate_list_results(True, False)
+
+    def test_results_ordering_forward_slowdown(self):
+        """
+        This tests the ordering of our various concurrent generator class ConcurrentExecutorListResults
+        when queries complete in the order they were executed, with slow queries mixed in.
+        """
+        self.insert_and_validate_list_results(False, True)
+
+    def test_results_ordering_reverse_slowdown(self):
+        """
+        This tests the ordering of our various concurrent generator class ConcurrentExecutorListResults
+        when queries complete in the reverse order they were executed, with slow queries mixed in.
+        """
+        self.insert_and_validate_list_results(True, True)
+
+    def test_results_ordering_forward_generator(self):
+        """
+        This tests the ordering of our various concurrent generator class ConcurrentExecutorGenResults
+        when queries complete in the order they were executed.
+        """
+        self.insert_and_validate_list_generator(False, False)
+
+    def test_results_ordering_reverse_generator(self):
+        """
+        This tests the ordering of our various concurrent generator class ConcurrentExecutorGenResults
+        when queries complete in the reverse order they were executed.
+ """ + self.insert_and_validate_list_generator(True, False) + + def test_results_ordering_forward_generator_slowdown(self): + """ + This tests the ordering of our various concurrent generator class ConcurrentExecutorGenResults + when queries complete in the order they were executed, with slow queries mixed in. + """ + self.insert_and_validate_list_generator(False, True) + + def test_results_ordering_reverse_generator_slowdown(self): + """ + This tests the ordering of our various concurrent generator class ConcurrentExecutorGenResults + when queries complete in the reverse order they were executed, with slow queries mixed in. + """ + self.insert_and_validate_list_generator(True, True) + + def insert_and_validate_list_results(self, reverse, slowdown): + """ + This utility method will execute submit various statements for execution using the ConcurrentExecutorListResults, + then invoke a separate thread to execute the callback associated with the futures registered + for those statements. The parameters will toggle various timing, and ordering changes. + Finally it will validate that the results were returned in the order they were submitted + :param reverse: Execute the callbacks in the opposite order that they were submitted + :param slowdown: Cause intermittent queries to perform slowly + """ + our_handler = MockResponseResponseFuture(reverse=reverse) + mock_session = Mock() + statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), + [(i, ) for i in range(100)]) + mock_session.execute_async.return_value = our_handler + + t = TimedCallableInvoker(our_handler, slowdown=slowdown) + t.start() + results = execute_concurrent(mock_session, statements_and_params) + + while(not our_handler.pending_callbacks.empty()): + time.sleep(.01) + t.stop() + self.validate_result_ordering(results) + + def insert_and_validate_list_generator(self, reverse, slowdown): + """ + This utility method will execute submit various statements for execution using the ConcurrentExecutorGenResults, + then invoke a separate thread to execute the callback associated with the futures registered + for those statements. The parameters will toggle various timing, and ordering changes. + Finally it will validate that the results were returned in the order they were submitted + :param reverse: Execute the callbacks in the opposite order that they were submitted + :param slowdown: Cause intermittent queries to perform slowly + """ + our_handler = MockResponseResponseFuture(reverse=reverse) + mock_session = Mock() + statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), + [(i, ) for i in range(100)]) + mock_session.execute_async.return_value = our_handler + + t = TimedCallableInvoker(our_handler, slowdown=slowdown) + t.start() + try: + results = execute_concurrent(mock_session, statements_and_params, results_generator=True) + self.validate_result_ordering(results) + finally: + t.stop() + + def validate_result_ordering(self, results): + """ + This method will validate that the timestamps returned from the result are in order. 
This indicates that the + results were returned in the order they were submitted for execution + :param results: + """ + last_time_added = 0 + for success, result in results: + self.assertTrue(success) + current_time_added = list(result)[0] + + #Windows clock granularity makes this equal most of the times + if "Windows" in platform.system(): + self.assertLessEqual(last_time_added, current_time_added) + else: + self.assertLess(last_time_added, current_time_added) + last_time_added = current_time_added + + @mock_session_pools + def test_recursion_limited(self): + """ + Verify that recursion is controlled when raise_on_first_error=False and something is wrong with the query. + + PYTHON-585 + """ + max_recursion = sys.getrecursionlimit() + s = Session(Cluster(), [Host("127.0.0.1", SimpleConvictionPolicy)]) + self.assertRaises(TypeError, execute_concurrent_with_args, s, "doesn't matter", [('param',)] * max_recursion, raise_on_first_error=True) + + results = execute_concurrent_with_args(s, "doesn't matter", [('param',)] * max_recursion, raise_on_first_error=False) # previously + self.assertEqual(len(results), max_recursion) + for r in results: + self.assertFalse(r[0]) + self.assertIsInstance(r[1], TypeError) diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py new file mode 100644 index 0000000..fccf854 --- /dev/null +++ b/tests/unit/test_connection.py @@ -0,0 +1,489 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
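The new ``tests/unit/test_connection.py`` suite drives ``Connection``, ``ConnectionHeartbeat``, and ``DefaultEndPoint`` directly with mocks. The user-facing settings that control the same machinery look roughly like this; the values shown are illustrative, not recommendations::

    from cassandra.cluster import Cluster
    from cassandra.connection import DefaultEndPoint

    # Heartbeats are OPTIONS messages sent on idle connections; a connection
    # that misses its heartbeat is defuncted, as the tests below verify.
    cluster = Cluster(
        connect_timeout=5,            # socket connect timeout per connection, in seconds
        idle_heartbeat_interval=30,   # how often idle connections are pinged
        idle_heartbeat_timeout=30)    # how long to wait for the heartbeat response

    # A DefaultEndPoint pairs an address with a port (9042 if omitted).
    endpoint = DefaultEndPoint('10.0.0.1', 9042)
    assert endpoint.resolve() == ('10.0.0.1', 9042)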
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from mock import Mock, ANY, call, patch +import six +from six import BytesIO +import time +from threading import Lock + +from cassandra import OperationTimedOut +from cassandra.cluster import Cluster +from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, + locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, + ConnectionException, DefaultEndPoint) +from cassandra.marshal import uint8_pack, uint32_pack, int32_pack +from cassandra.protocol import (write_stringmultimap, write_int, write_string, + SupportedMessage, ProtocolHandler) + + +class ConnectionTest(unittest.TestCase): + + def make_connection(self): + c = Connection(DefaultEndPoint('1.2.3.4')) + c._socket = Mock() + c._socket.send.side_effect = lambda x: len(x) + return c + + def make_header_prefix(self, message_class, version=Connection.protocol_version, stream_id=0): + if Connection.protocol_version < 3: + return six.binary_type().join(map(uint8_pack, [ + 0xff & (HEADER_DIRECTION_TO_CLIENT | version), + 0, # flags (compression) + stream_id, + message_class.opcode # opcode + ])) + else: + return six.binary_type().join(map(uint8_pack, [ + 0xff & (HEADER_DIRECTION_TO_CLIENT | version), + 0, # flags (compression) + 0, # MSB for v3+ stream + stream_id, + message_class.opcode # opcode + ])) + + def make_options_body(self): + options_buf = BytesIO() + write_stringmultimap(options_buf, { + 'CQL_VERSION': ['3.0.1'], + 'COMPRESSION': [] + }) + return options_buf.getvalue() + + def make_error_body(self, code, msg): + buf = BytesIO() + write_int(buf, code) + write_string(buf, msg) + return buf.getvalue() + + def make_msg(self, header, body=""): + return header + uint32_pack(len(body)) + body + + def test_connection_endpoint(self): + endpoint = DefaultEndPoint('1.2.3.4') + c = Connection(endpoint) + self.assertEqual(c.endpoint, endpoint) + self.assertEqual(c.endpoint.address, endpoint.address) + + c = Connection(host=endpoint) # kwarg + self.assertEqual(c.endpoint, endpoint) + self.assertEqual(c.endpoint.address, endpoint.address) + + c = Connection('10.0.0.1') + endpoint = DefaultEndPoint('10.0.0.1') + self.assertEqual(c.endpoint, endpoint) + self.assertEqual(c.endpoint.address, endpoint.address) + + def test_bad_protocol_version(self, *args): + c = self.make_connection() + c._requests = Mock() + c.defunct = Mock() + + # read in a SupportedMessage response + header = self.make_header_prefix(SupportedMessage, version=0x7f) + options = self.make_options_body() + message = self.make_msg(header, options) + c._iobuf = BytesIO() + c._iobuf.write(message) + c.process_io_buffer() + + # make sure it errored correctly + c.defunct.assert_called_once_with(ANY) + args, kwargs = c.defunct.call_args + self.assertIsInstance(args[0], ProtocolError) + + def test_negative_body_length(self, *args): + c = self.make_connection() + c._requests = Mock() + c.defunct = Mock() + + # read in a SupportedMessage response + header = self.make_header_prefix(SupportedMessage) + message = header + int32_pack(-13) + c._iobuf = BytesIO() + c._iobuf.write(message) + c.process_io_buffer() + + # make sure it errored correctly + c.defunct.assert_called_once_with(ANY) + args, kwargs = c.defunct.call_args + self.assertIsInstance(args[0], ProtocolError) + + def test_unsupported_cql_version(self, *args): + c = self.make_connection() + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} + c.defunct = Mock() + 
c.cql_version = "3.0.3" + + options_buf = BytesIO() + write_stringmultimap(options_buf, { + 'CQL_VERSION': ['7.8.9'], + 'COMPRESSION': [] + }) + options = options_buf.getvalue() + + c.process_msg(_Frame(version=4, flags=0, stream=0, opcode=SupportedMessage.opcode, body_offset=9, end_pos=9 + len(options)), options) + + # make sure it errored correctly + c.defunct.assert_called_once_with(ANY) + args, kwargs = c.defunct.call_args + self.assertIsInstance(args[0], ProtocolError) + + def test_prefer_lz4_compression(self, *args): + c = self.make_connection() + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} + c.defunct = Mock() + c.cql_version = "3.0.3" + + locally_supported_compressions.pop('lz4', None) + locally_supported_compressions.pop('snappy', None) + locally_supported_compressions['lz4'] = ('lz4compress', 'lz4decompress') + locally_supported_compressions['snappy'] = ('snappycompress', 'snappydecompress') + + # read in a SupportedMessage response + options_buf = BytesIO() + write_stringmultimap(options_buf, { + 'CQL_VERSION': ['3.0.3'], + 'COMPRESSION': ['snappy', 'lz4'] + }) + options = options_buf.getvalue() + + c.process_msg(_Frame(version=4, flags=0, stream=0, opcode=SupportedMessage.opcode, body_offset=9, end_pos=9 + len(options)), options) + + self.assertEqual(c.decompressor, locally_supported_compressions['lz4'][1]) + + def test_requested_compression_not_available(self, *args): + c = self.make_connection() + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} + c.defunct = Mock() + # request lz4 compression + c.compression = "lz4" + + locally_supported_compressions.pop('lz4', None) + locally_supported_compressions.pop('snappy', None) + locally_supported_compressions['lz4'] = ('lz4compress', 'lz4decompress') + locally_supported_compressions['snappy'] = ('snappycompress', 'snappydecompress') + + # the server only supports snappy + options_buf = BytesIO() + write_stringmultimap(options_buf, { + 'CQL_VERSION': ['3.0.3'], + 'COMPRESSION': ['snappy'] + }) + options = options_buf.getvalue() + + c.process_msg(_Frame(version=4, flags=0, stream=0, opcode=SupportedMessage.opcode, body_offset=9, end_pos=9 + len(options)), options) + + # make sure it errored correctly + c.defunct.assert_called_once_with(ANY) + args, kwargs = c.defunct.call_args + self.assertIsInstance(args[0], ProtocolError) + + def test_use_requested_compression(self, *args): + c = self.make_connection() + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} + c.defunct = Mock() + # request snappy compression + c.compression = "snappy" + + locally_supported_compressions.pop('lz4', None) + locally_supported_compressions.pop('snappy', None) + locally_supported_compressions['lz4'] = ('lz4compress', 'lz4decompress') + locally_supported_compressions['snappy'] = ('snappycompress', 'snappydecompress') + + # the server only supports snappy + options_buf = BytesIO() + write_stringmultimap(options_buf, { + 'CQL_VERSION': ['3.0.3'], + 'COMPRESSION': ['snappy', 'lz4'] + }) + options = options_buf.getvalue() + + c.process_msg(_Frame(version=4, flags=0, stream=0, opcode=SupportedMessage.opcode, body_offset=9, end_pos=9 + len(options)), options) + + self.assertEqual(c.decompressor, locally_supported_compressions['snappy'][1]) + + def test_disable_compression(self, *args): + c = self.make_connection() + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} + c.defunct = Mock() + # disable compression + c.compression = 
False + + locally_supported_compressions.pop('lz4', None) + locally_supported_compressions.pop('snappy', None) + locally_supported_compressions['lz4'] = ('lz4compress', 'lz4decompress') + locally_supported_compressions['snappy'] = ('snappycompress', 'snappydecompress') + + # read in a SupportedMessage response + header = self.make_header_prefix(SupportedMessage) + + # the server only supports snappy + options_buf = BytesIO() + write_stringmultimap(options_buf, { + 'CQL_VERSION': ['3.0.3'], + 'COMPRESSION': ['snappy', 'lz4'] + }) + options = options_buf.getvalue() + + message = self.make_msg(header, options) + c.process_msg(message, len(message) - 8) + + self.assertEqual(c.decompressor, None) + + def test_not_implemented(self): + """ + Ensure the following methods throw NIE's. If not, come back and test them. + """ + c = self.make_connection() + self.assertRaises(NotImplementedError, c.close) + + def test_set_keyspace_blocking(self): + c = self.make_connection() + + self.assertEqual(c.keyspace, None) + c.set_keyspace_blocking(None) + self.assertEqual(c.keyspace, None) + + c.keyspace = 'ks' + c.set_keyspace_blocking('ks') + self.assertEqual(c.keyspace, 'ks') + + def test_set_connection_class(self): + cluster = Cluster(connection_class='test') + self.assertEqual('test', cluster.connection_class) + + +@patch('cassandra.connection.ConnectionHeartbeat._raise_if_stopped') +class ConnectionHeartbeatTest(unittest.TestCase): + + @staticmethod + def make_get_holders(len): + holders = [] + for _ in range(len): + holder = Mock() + holder.get_connections = Mock(return_value=[]) + holders.append(holder) + get_holders = Mock(return_value=holders) + return get_holders + + def run_heartbeat(self, get_holders_fun, count=2, interval=0.05, timeout=0.05): + ch = ConnectionHeartbeat(interval, get_holders_fun, timeout=timeout) + time.sleep(interval * count) + ch.stop() + self.assertTrue(get_holders_fun.call_count) + + def test_empty_connections(self, *args): + count = 3 + get_holders = self.make_get_holders(1) + + self.run_heartbeat(get_holders, count) + + self.assertGreaterEqual(get_holders.call_count, count - 1) # lower bound to account for thread spinup time + self.assertLessEqual(get_holders.call_count, count) + holder = get_holders.return_value[0] + holder.get_connections.assert_has_calls([call()] * get_holders.call_count) + + def test_idle_non_idle(self, *args): + request_id = 999 + + # connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback) + def send_msg(msg, req_id, msg_callback): + msg_callback(SupportedMessage([], {})) + + idle_connection = Mock(spec=Connection, host='localhost', + max_request_id=127, + lock=Lock(), + in_flight=0, is_idle=True, + is_defunct=False, is_closed=False, + get_request_id=lambda: request_id, + send_msg=Mock(side_effect=send_msg)) + non_idle_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=False, is_closed=False) + + get_holders = self.make_get_holders(1) + holder = get_holders.return_value[0] + holder.get_connections.return_value.append(idle_connection) + holder.get_connections.return_value.append(non_idle_connection) + + self.run_heartbeat(get_holders) + + holder.get_connections.assert_has_calls([call()] * get_holders.call_count) + self.assertEqual(idle_connection.in_flight, 0) + self.assertEqual(non_idle_connection.in_flight, 0) + + idle_connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count) + self.assertEqual(non_idle_connection.send_msg.call_count, 0) + + def 
test_closed_defunct(self, *args): + get_holders = self.make_get_holders(1) + closed_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=False, is_closed=True) + defunct_connection = Mock(spec=Connection, in_flight=0, is_idle=False, is_defunct=True, is_closed=False) + holder = get_holders.return_value[0] + holder.get_connections.return_value.append(closed_connection) + holder.get_connections.return_value.append(defunct_connection) + + self.run_heartbeat(get_holders) + + holder.get_connections.assert_has_calls([call()] * get_holders.call_count) + self.assertEqual(closed_connection.in_flight, 0) + self.assertEqual(defunct_connection.in_flight, 0) + self.assertEqual(closed_connection.send_msg.call_count, 0) + self.assertEqual(defunct_connection.send_msg.call_count, 0) + + def test_no_req_ids(self, *args): + in_flight = 3 + + get_holders = self.make_get_holders(1) + max_connection = Mock(spec=Connection, host='localhost', + lock=Lock(), + max_request_id=in_flight - 1, in_flight=in_flight, + is_idle=True, is_defunct=False, is_closed=False) + holder = get_holders.return_value[0] + holder.get_connections.return_value.append(max_connection) + + self.run_heartbeat(get_holders) + + holder.get_connections.assert_has_calls([call()] * get_holders.call_count) + self.assertEqual(max_connection.in_flight, in_flight) + self.assertEqual(max_connection.send_msg.call_count, 0) + self.assertEqual(max_connection.send_msg.call_count, 0) + max_connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count) + holder.return_connection.assert_has_calls( + [call(max_connection)] * get_holders.call_count) + + def test_unexpected_response(self, *args): + request_id = 999 + + get_holders = self.make_get_holders(1) + + def send_msg(msg, req_id, msg_callback): + msg_callback(object()) + + connection = Mock(spec=Connection, host='localhost', + max_request_id=127, + lock=Lock(), + in_flight=0, is_idle=True, + is_defunct=False, is_closed=False, + get_request_id=lambda: request_id, + send_msg=Mock(side_effect=send_msg)) + holder = get_holders.return_value[0] + holder.get_connections.return_value.append(connection) + + self.run_heartbeat(get_holders) + + self.assertEqual(connection.in_flight, get_holders.call_count) + connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count) + connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count) + exc = connection.defunct.call_args_list[0][0][0] + self.assertIsInstance(exc, ConnectionException) + self.assertRegexpMatches(exc.args[0], r'^Received unexpected response to OptionsMessage.*') + holder.return_connection.assert_has_calls( + [call(connection)] * get_holders.call_count) + + def test_timeout(self, *args): + request_id = 999 + + get_holders = self.make_get_holders(1) + + def send_msg(msg, req_id, msg_callback): + pass + + # we used endpoint=X here because it's a mock and we need connection.endpoint to be set + connection = Mock(spec=Connection, endpoint=DefaultEndPoint('localhost'), + max_request_id=127, + lock=Lock(), + in_flight=0, is_idle=True, + is_defunct=False, is_closed=False, + get_request_id=lambda: request_id, + send_msg=Mock(side_effect=send_msg)) + holder = get_holders.return_value[0] + holder.get_connections.return_value.append(connection) + + self.run_heartbeat(get_holders) + + self.assertEqual(connection.in_flight, get_holders.call_count) + connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count) + connection.defunct.assert_has_calls([call(ANY)] * 
get_holders.call_count) + exc = connection.defunct.call_args_list[0][0][0] + self.assertIsInstance(exc, OperationTimedOut) + self.assertEqual(exc.errors, 'Connection heartbeat timeout after 0.05 seconds') + self.assertEqual(exc.last_host, DefaultEndPoint('localhost')) + holder.return_connection.assert_has_calls( + [call(connection)] * get_holders.call_count) + + +class TimerTest(unittest.TestCase): + + def test_timer_collision(self): + # simple test demonstrating #466 + # same timeout, comparison will defer to the Timer object itself + t1 = Timer(0, lambda: None) + t2 = Timer(0, lambda: None) + t2.end = t1.end + + tm = TimerManager() + tm.add_timer(t1) + tm.add_timer(t2) + # Prior to #466: "TypeError: unorderable types: Timer() < Timer()" + tm.service_timeouts() + + +class DefaultEndPointTest(unittest.TestCase): + + def test_default_endpoint_properties(self): + endpoint = DefaultEndPoint('10.0.0.1') + self.assertEqual(endpoint.address, '10.0.0.1') + self.assertEqual(endpoint.port, 9042) + self.assertEqual(str(endpoint), '10.0.0.1:9042') + + endpoint = DefaultEndPoint('10.0.0.1', 8888) + self.assertEqual(endpoint.address, '10.0.0.1') + self.assertEqual(endpoint.port, 8888) + self.assertEqual(str(endpoint), '10.0.0.1:8888') + + def test_endpoint_equality(self): + self.assertEqual( + DefaultEndPoint('10.0.0.1'), + DefaultEndPoint('10.0.0.1') + ) + + self.assertEqual( + DefaultEndPoint('10.0.0.1'), + DefaultEndPoint('10.0.0.1', 9042) + ) + + self.assertNotEqual( + DefaultEndPoint('10.0.0.1'), + DefaultEndPoint('10.0.0.2') + ) + + self.assertNotEqual( + DefaultEndPoint('10.0.0.1'), + DefaultEndPoint('10.0.0.1', 0000) + ) + + def test_endpoint_resolve(self): + self.assertEqual( + DefaultEndPoint('10.0.0.1').resolve(), + ('10.0.0.1', 9042) + ) + + self.assertEqual( + DefaultEndPoint('10.0.0.1', 3232).resolve(), + ('10.0.0.1', 3232) + ) diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py new file mode 100644 index 0000000..e76fbd2 --- /dev/null +++ b/tests/unit/test_control_connection.py @@ -0,0 +1,519 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
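The new ``tests/unit/test_control_connection.py`` suite simulates the ``system.local`` and ``system.peers`` queries the control connection uses to track topology and schema agreement. At the application level, the same behavior is reached through cluster-level settings and metadata accessors, roughly as follows (values are illustrative)::

    from cassandra.cluster import Cluster

    cluster = Cluster(
        control_connection_timeout=2.0,   # timeout for control-connection queries, in seconds
        max_schema_agreement_wait=10)     # how long to wait for schema versions to converge
    session = cluster.connect()

    # Force a re-read of system.peers / system.local, much as the tests below do
    # through refresh_node_list_and_token_map(), then inspect the metadata.
    cluster.refresh_nodes()
    for host in cluster.metadata.all_hosts():
        print(host.address, host.datacenter, host.rack, host.is_up)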
+ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import six + +from concurrent.futures import ThreadPoolExecutor +from mock import Mock, ANY, call + +from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType +from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS +from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile +from cassandra.pool import Host +from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory +from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy, + ConstantReconnectionPolicy, IdentityTranslator) + +PEER_IP = "foobar" + + +class MockMetadata(object): + + def __init__(self): + self.hosts = { + DefaultEndPoint("192.168.1.0"): Host(DefaultEndPoint("192.168.1.0"), SimpleConvictionPolicy), + DefaultEndPoint("192.168.1.1"): Host(DefaultEndPoint("192.168.1.1"), SimpleConvictionPolicy), + DefaultEndPoint("192.168.1.2"): Host(DefaultEndPoint("192.168.1.2"), SimpleConvictionPolicy) + } + for host in self.hosts.values(): + host.set_up() + + self.cluster_name = None + self.partitioner = None + self.token_map = {} + + def get_host(self, endpoint_or_address): + if not isinstance(endpoint_or_address, EndPoint): + for host in six.itervalues(self.hosts): + if host.address == endpoint_or_address: + return host + else: + return self.hosts.get(endpoint_or_address) + + def all_hosts(self): + return self.hosts.values() + + def rebuild_token_map(self, partitioner, token_map): + self.partitioner = partitioner + self.token_map = token_map + + +class MockCluster(object): + + max_schema_agreement_wait = 5 + profile_manager = ProfileManager() + reconnection_policy = ConstantReconnectionPolicy(2) + address_translator = IdentityTranslator() + down_host = None + contact_points = [] + is_shutdown = False + + def __init__(self): + self.metadata = MockMetadata() + self.added_hosts = [] + self.removed_hosts = [] + self.scheduler = Mock(spec=_Scheduler) + self.executor = Mock(spec=ThreadPoolExecutor) + self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile(RoundRobinPolicy()) + self.endpoint_factory = DefaultEndPointFactory().configure(self) + + def add_host(self, endpoint, datacenter, rack, signal=False, refresh_nodes=True): + host = Host(endpoint, SimpleConvictionPolicy, datacenter, rack) + self.added_hosts.append(host) + return host + + def remove_host(self, host): + self.removed_hosts.append(host) + + def on_up(self, host): + pass + + def on_down(self, host, is_host_addition): + self.down_host = host + + +class MockConnection(object): + + is_defunct = False + + def __init__(self): + self.endpoint = DefaultEndPoint("192.168.1.0") + self.local_results = [ + ["schema_version", "cluster_name", "data_center", "rack", "partitioner", "release_version", "tokens"], + [["a", "foocluster", "dc1", "rack1", "Murmur3Partitioner", "2.2.0", ["0", "100", "200"]]] + ] + + self.peer_results = [ + ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens"], + [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"]], + ["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"]]] + ] + local_response = ResultMessage( + kind=RESULT_KIND_ROWS, results=self.local_results) + peer_response = ResultMessage( + kind=RESULT_KIND_ROWS, results=self.peer_results) + + self.wait_for_responses = Mock(return_value=(peer_response, local_response)) + + +class FakeTime(object): + + def __init__(self): + self.clock = 0 + + 
def time(self): + return self.clock + + def sleep(self, amount): + self.clock += amount + + +class ControlConnectionTest(unittest.TestCase): + + def setUp(self): + self.cluster = MockCluster() + self.connection = MockConnection() + self.time = FakeTime() + + self.control_connection = ControlConnection(self.cluster, 1, 0, 0, 0) + self.control_connection._connection = self.connection + self.control_connection._time = self.time + + def _get_matching_schema_preloaded_results(self): + local_results = [ + ["schema_version", "cluster_name", "data_center", "rack", "partitioner", "release_version", "tokens"], + [["a", "foocluster", "dc1", "rack1", "Murmur3Partitioner", "2.2.0", ["0", "100", "200"]]] + ] + local_response = ResultMessage(kind=RESULT_KIND_ROWS, results=local_results) + + peer_results = [ + ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens"], + [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"]], + ["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"]]] + ] + peer_response = ResultMessage(kind=RESULT_KIND_ROWS, results=peer_results) + + return (peer_response, local_response) + + def _get_nonmatching_schema_preloaded_results(self): + local_results = [ + ["schema_version", "cluster_name", "data_center", "rack", "partitioner", "release_version", "tokens"], + [["a", "foocluster", "dc1", "rack1", "Murmur3Partitioner", "2.2.0", ["0", "100", "200"]]] + ] + local_response = ResultMessage(kind=RESULT_KIND_ROWS, results=local_results) + + peer_results = [ + ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens"], + [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"]], + ["192.168.1.2", "10.0.0.2", "b", "dc1", "rack1", ["2", "102", "202"]]] + ] + peer_response = ResultMessage(kind=RESULT_KIND_ROWS, results=peer_results) + + return (peer_response, local_response) + + def test_wait_for_schema_agreement(self): + """ + Basic test with all schema versions agreeing + """ + self.assertTrue(self.control_connection.wait_for_schema_agreement()) + # the control connection should not have slept at all + self.assertEqual(self.time.clock, 0) + + def test_wait_for_schema_agreement_uses_preloaded_results_if_given(self): + """ + wait_for_schema_agreement uses preloaded results if given for shared table queries + """ + preloaded_results = self._get_matching_schema_preloaded_results() + + self.assertTrue(self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results)) + # the control connection should not have slept at all + self.assertEqual(self.time.clock, 0) + # the connection should not have made any queries if given preloaded results + self.assertEqual(self.connection.wait_for_responses.call_count, 0) + + def test_wait_for_schema_agreement_falls_back_to_querying_if_schemas_dont_match_preloaded_result(self): + """ + wait_for_schema_agreement requery if schema does not match using preloaded results + """ + preloaded_results = self._get_nonmatching_schema_preloaded_results() + + self.assertTrue(self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results)) + # the control connection should not have slept at all + self.assertEqual(self.time.clock, 0) + self.assertEqual(self.connection.wait_for_responses.call_count, 1) + + def test_wait_for_schema_agreement_fails(self): + """ + Make sure the control connection sleeps and retries + """ + # change the schema version on one node + self.connection.peer_results[1][1][2] = 'b' + 
self.assertFalse(self.control_connection.wait_for_schema_agreement()) + # the control connection should have slept until it hit the limit + self.assertGreaterEqual(self.time.clock, self.cluster.max_schema_agreement_wait) + + def test_wait_for_schema_agreement_skipping(self): + """ + If rpc_address or schema_version isn't set, the host should be skipped + """ + # an entry with no schema_version + self.connection.peer_results[1].append( + ["192.168.1.3", "10.0.0.3", None, "dc1", "rack1", ["3", "103", "203"]] + ) + # an entry with a different schema_version and no rpc_address + self.connection.peer_results[1].append( + [None, None, "b", "dc1", "rack1", ["4", "104", "204"]] + ) + + # change the schema version on one of the existing entries + self.connection.peer_results[1][1][3] = 'c' + self.cluster.metadata.get_host(DefaultEndPoint('192.168.1.1')).is_up = False + + self.assertTrue(self.control_connection.wait_for_schema_agreement()) + self.assertEqual(self.time.clock, 0) + + def test_wait_for_schema_agreement_rpc_lookup(self): + """ + If the rpc_address is 0.0.0.0, the "peer" column should be used instead. + """ + self.connection.peer_results[1].append( + ["0.0.0.0", PEER_IP, "b", "dc1", "rack1", ["3", "103", "203"]] + ) + host = Host(DefaultEndPoint("0.0.0.0"), SimpleConvictionPolicy) + self.cluster.metadata.hosts[DefaultEndPoint("foobar")] = host + host.is_up = False + + # even though the new host has a different schema version, it's + # marked as down, so the control connection shouldn't care + self.assertTrue(self.control_connection.wait_for_schema_agreement()) + self.assertEqual(self.time.clock, 0) + + # but once we mark it up, the control connection will care + host.is_up = True + self.assertFalse(self.control_connection.wait_for_schema_agreement()) + self.assertGreaterEqual(self.time.clock, self.cluster.max_schema_agreement_wait) + + def test_refresh_nodes_and_tokens(self): + self.control_connection.refresh_node_list_and_token_map() + meta = self.cluster.metadata + self.assertEqual(meta.partitioner, 'Murmur3Partitioner') + self.assertEqual(meta.cluster_name, 'foocluster') + + # check token map + self.assertEqual(sorted(meta.all_hosts()), sorted(meta.token_map.keys())) + for token_list in meta.token_map.values(): + self.assertEqual(3, len(token_list)) + + # check datacenter/rack + for host in meta.all_hosts(): + self.assertEqual(host.datacenter, "dc1") + self.assertEqual(host.rack, "rack1") + + self.assertEqual(self.connection.wait_for_responses.call_count, 1) + + def test_refresh_nodes_and_tokens_uses_preloaded_results_if_given(self): + """ + refresh_nodes_and_tokens uses preloaded results if given for shared table queries + """ + preloaded_results = self._get_matching_schema_preloaded_results() + + self.control_connection._refresh_node_list_and_token_map(self.connection, preloaded_results=preloaded_results) + meta = self.cluster.metadata + self.assertEqual(meta.partitioner, 'Murmur3Partitioner') + self.assertEqual(meta.cluster_name, 'foocluster') + + # check token map + self.assertEqual(sorted(meta.all_hosts()), sorted(meta.token_map.keys())) + for token_list in meta.token_map.values(): + self.assertEqual(3, len(token_list)) + + # check datacenter/rack + for host in meta.all_hosts(): + self.assertEqual(host.datacenter, "dc1") + self.assertEqual(host.rack, "rack1") + + # the connection should not have made any queries if given preloaded results + self.assertEqual(self.connection.wait_for_responses.call_count, 0) + + def test_refresh_nodes_and_tokens_no_partitioner(self): + """ + Test 
handling of an unknown partitioner. + """ + # set the partitioner column to None + self.connection.local_results[1][0][4] = None + self.control_connection.refresh_node_list_and_token_map() + meta = self.cluster.metadata + self.assertEqual(meta.partitioner, None) + self.assertEqual(meta.token_map, {}) + + def test_refresh_nodes_and_tokens_add_host(self): + self.connection.peer_results[1].append( + ["192.168.1.3", "10.0.0.3", "a", "dc1", "rack1", ["3", "103", "203"]] + ) + self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) + self.control_connection.refresh_node_list_and_token_map() + self.assertEqual(1, len(self.cluster.added_hosts)) + self.assertEqual(self.cluster.added_hosts[0].address, "192.168.1.3") + self.assertEqual(self.cluster.added_hosts[0].datacenter, "dc1") + self.assertEqual(self.cluster.added_hosts[0].rack, "rack1") + + def test_refresh_nodes_and_tokens_remove_host(self): + del self.connection.peer_results[1][1] + self.control_connection.refresh_node_list_and_token_map() + self.assertEqual(1, len(self.cluster.removed_hosts)) + self.assertEqual(self.cluster.removed_hosts[0].address, "192.168.1.2") + + def test_refresh_nodes_and_tokens_timeout(self): + + def bad_wait_for_responses(*args, **kwargs): + self.assertEqual(kwargs['timeout'], self.control_connection._timeout) + raise OperationTimedOut() + + self.connection.wait_for_responses = bad_wait_for_responses + self.control_connection.refresh_node_list_and_token_map() + self.cluster.executor.submit.assert_called_with(self.control_connection._reconnect) + + def test_refresh_schema_timeout(self): + + def bad_wait_for_responses(*args, **kwargs): + self.time.sleep(kwargs['timeout']) + raise OperationTimedOut() + + self.connection.wait_for_responses = Mock(side_effect=bad_wait_for_responses) + self.control_connection.refresh_schema() + self.assertEqual(self.connection.wait_for_responses.call_count, self.cluster.max_schema_agreement_wait / self.control_connection._timeout) + self.assertEqual(self.connection.wait_for_responses.call_args[1]['timeout'], self.control_connection._timeout) + + def test_handle_topology_change(self): + event = { + 'change_type': 'NEW_NODE', + 'address': ('1.2.3.4', 9000) + } + self.cluster.scheduler.reset_mock() + self.control_connection._handle_topology_change(event) + self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection._refresh_nodes_if_not_up, None) + + event = { + 'change_type': 'REMOVED_NODE', + 'address': ('1.2.3.4', 9000) + } + self.cluster.scheduler.reset_mock() + self.control_connection._handle_topology_change(event) + self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.cluster.remove_host, None) + + event = { + 'change_type': 'MOVED_NODE', + 'address': ('1.2.3.4', 9000) + } + self.cluster.scheduler.reset_mock() + self.control_connection._handle_topology_change(event) + self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection._refresh_nodes_if_not_up, None) + + def test_handle_status_change(self): + event = { + 'change_type': 'UP', + 'address': ('1.2.3.4', 9000) + } + self.cluster.scheduler.reset_mock() + self.control_connection._handle_status_change(event) + self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_node_list_and_token_map) + + # do the same with a known Host + event = { + 'change_type': 'UP', + 'address': ('192.168.1.0', 9000) + } + self.cluster.scheduler.reset_mock() + 
self.control_connection._handle_status_change(event) + host = self.cluster.metadata.hosts[DefaultEndPoint('192.168.1.0')] + self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.cluster.on_up, host) + + self.cluster.scheduler.schedule.reset_mock() + event = { + 'change_type': 'DOWN', + 'address': ('1.2.3.4', 9000) + } + self.control_connection._handle_status_change(event) + self.assertFalse(self.cluster.scheduler.schedule.called) + + # do the same with a known Host + event = { + 'change_type': 'DOWN', + 'address': ('192.168.1.0', 9000) + } + self.control_connection._handle_status_change(event) + host = self.cluster.metadata.hosts[DefaultEndPoint('192.168.1.0')] + self.assertIs(host, self.cluster.down_host) + + def test_handle_schema_change(self): + + change_types = [getattr(SchemaChangeType, attr) for attr in vars(SchemaChangeType) if attr[0] != '_'] + for change_type in change_types: + event = { + 'target_type': SchemaTargetType.TABLE, + 'change_type': change_type, + 'keyspace': 'ks1', + 'table': 'table1' + } + self.cluster.scheduler.reset_mock() + self.control_connection._handle_schema_change(event) + self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_schema, **event) + + self.cluster.scheduler.reset_mock() + event['target_type'] = SchemaTargetType.KEYSPACE + del event['table'] + self.control_connection._handle_schema_change(event) + self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_schema, **event) + + def test_refresh_disabled(self): + cluster = MockCluster() + + schema_event = { + 'target_type': SchemaTargetType.TABLE, + 'change_type': SchemaChangeType.CREATED, + 'keyspace': 'ks1', + 'table': 'table1' + } + + status_event = { + 'change_type': 'UP', + 'address': ('1.2.3.4', 9000) + } + + topo_event = { + 'change_type': 'MOVED_NODE', + 'address': ('1.2.3.4', 9000) + } + + cc_no_schema_refresh = ControlConnection(cluster, 1, -1, 0, 0) + cluster.scheduler.reset_mock() + + # no call on schema refresh + cc_no_schema_refresh._handle_schema_change(schema_event) + self.assertFalse(cluster.scheduler.schedule.called) + self.assertFalse(cluster.scheduler.schedule_unique.called) + + # topo and status changes as normal + cc_no_schema_refresh._handle_status_change(status_event) + cc_no_schema_refresh._handle_topology_change(topo_event) + cluster.scheduler.schedule_unique.assert_has_calls([call(ANY, cc_no_schema_refresh.refresh_node_list_and_token_map), + call(ANY, cc_no_schema_refresh._refresh_nodes_if_not_up, None)]) + + cc_no_topo_refresh = ControlConnection(cluster, 1, 0, -1, 0) + cluster.scheduler.reset_mock() + + # no call on topo refresh + cc_no_topo_refresh._handle_topology_change(topo_event) + self.assertFalse(cluster.scheduler.schedule.called) + self.assertFalse(cluster.scheduler.schedule_unique.called) + + # schema and status change refresh as normal + cc_no_topo_refresh._handle_status_change(status_event) + cc_no_topo_refresh._handle_schema_change(schema_event) + cluster.scheduler.schedule_unique.assert_has_calls([call(ANY, cc_no_topo_refresh.refresh_node_list_and_token_map), + call(0.0, cc_no_topo_refresh.refresh_schema, + **schema_event)]) + + +class EventTimingTest(unittest.TestCase): + """ + A simple test to validate that event scheduling happens in order + Added for PYTHON-358 + """ + def setUp(self): + self.cluster = MockCluster() + self.connection = MockConnection() + self.time = FakeTime() + + # Use 2 for the schema_event_refresh_window which is what we would 
normally default to. + self.control_connection = ControlConnection(self.cluster, 1, 2, 0, 0) + self.control_connection._connection = self.connection + self.control_connection._time = self.time + + def test_event_delay_timing(self): + """ + Submits a wide array of events make sure that each is scheduled to occur in the order they were received + """ + prior_delay = 0 + for _ in range(100): + for change_type in ('CREATED', 'DROPPED', 'UPDATED'): + event = { + 'change_type': change_type, + 'keyspace': '1', + 'table': 'table1' + } + # This is to increment the fake time, we don't actually sleep here. + self.time.sleep(.001) + self.cluster.scheduler.reset_mock() + self.control_connection._handle_schema_change(event) + self.cluster.scheduler.mock_calls + # Grabs the delay parameter from the scheduler invocation + current_delay = self.cluster.scheduler.mock_calls[0][1][0] + self.assertLess(prior_delay, current_delay) + prior_delay = current_delay diff --git a/tests/unit/test_exception.py b/tests/unit/test_exception.py new file mode 100644 index 0000000..3a082f7 --- /dev/null +++ b/tests/unit/test_exception.py @@ -0,0 +1,56 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest + +from cassandra import Unavailable, Timeout, ConsistencyLevel +import re + + +class ConsistencyExceptionTest(unittest.TestCase): + """ + Verify Cassandra Exception string representation + """ + + def extract_consistency(self, msg): + """ + Given message that has 'consistency': 'value', extract consistency value as a string + :param msg: message with consistency value + :return: String representing consistency value + """ + match = re.search("'consistency':\s+'([\w\s]+)'", msg) + return match and match.group(1) + + def test_timeout_consistency(self): + """ + Verify that Timeout exception object translates consistency from input value to correct output string + """ + consistency_str = self.extract_consistency(repr(Timeout("Timeout Message", consistency=None))) + self.assertEqual(consistency_str, 'Not Set') + for c in ConsistencyLevel.value_to_name.keys(): + consistency_str = self.extract_consistency(repr(Timeout("Timeout Message", consistency=c))) + self.assertEqual(consistency_str, ConsistencyLevel.value_to_name[c]) + + def test_unavailable_consistency(self): + """ + Verify that Unavailable exception object translates consistency from input value to correct output string + """ + consistency_str = self.extract_consistency(repr(Unavailable("Unavailable Message", consistency=None))) + self.assertEqual(consistency_str, 'Not Set') + for c in ConsistencyLevel.value_to_name.keys(): + consistency_str = self.extract_consistency(repr(Unavailable("Timeout Message", consistency=c))) + self.assertEqual(consistency_str, ConsistencyLevel.value_to_name[c]) diff --git a/tests/unit/test_marshalling.py b/tests/unit/test_marshalling.py new file mode 100644 index 0000000..c2363e0 --- /dev/null +++ b/tests/unit/test_marshalling.py @@ -0,0 +1,148 @@ 
+# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys + +from cassandra import ProtocolVersion + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import platform +from datetime import datetime, date +from decimal import Decimal +from uuid import UUID + +from cassandra.cqltypes import lookup_casstype, DecimalType, UTF8Type, DateType +from cassandra.util import OrderedMapSerializedKey, sortedset, Time, Date + +marshalled_value_pairs = ( + # binary form, type, python native type + (b'lorem ipsum dolor sit amet', 'AsciiType', 'lorem ipsum dolor sit amet'), + (b'', 'AsciiType', ''), + (b'\x01', 'BooleanType', True), + (b'\x00', 'BooleanType', False), + (b'', 'BooleanType', None), + (b'\xff\xfe\xfd\xfc\xfb', 'BytesType', b'\xff\xfe\xfd\xfc\xfb'), + (b'', 'BytesType', b''), + (b'\x7f\xff\xff\xff\xff\xff\xff\xff', 'CounterColumnType', 9223372036854775807), + (b'\x80\x00\x00\x00\x00\x00\x00\x00', 'CounterColumnType', -9223372036854775808), + (b'', 'CounterColumnType', None), + (b'\x00\x00\x013\x7fb\xeey', 'DateType', datetime(2011, 11, 7, 18, 55, 49, 881000)), + (b'\x00\x00\x01P\xc5~L\x00', 'DateType', datetime(2015, 11, 2)), + (b'', 'DateType', None), + (b'\x00\x00\x00\r\nJ\x04"^\x91\x04\x8a\xb1\x18\xfe', 'DecimalType', Decimal('1243878957943.1234124191998')), + (b'\x00\x00\x00\x06\xe5\xde]\x98Y', 'DecimalType', Decimal('-112233.441191')), + (b'\x00\x00\x00\x14\x00\xfa\xce', 'DecimalType', Decimal('0.00000000000000064206')), + (b'\x00\x00\x00\x14\xff\x052', 'DecimalType', Decimal('-0.00000000000000064206')), + (b'\xff\xff\xff\x9c\x00\xfa\xce', 'DecimalType', Decimal('64206e100')), + (b'', 'DecimalType', None), + (b'@\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', 19432.125), + (b'\xc0\xd2\xfa\x08\x00\x00\x00\x00', 'DoubleType', -19432.125), + (b'\x7f\xef\x00\x00\x00\x00\x00\x00', 'DoubleType', 1.7415152243978685e+308), + (b'', 'DoubleType', None), + (b'F\x97\xd0@', 'FloatType', 19432.125), + (b'\xc6\x97\xd0@', 'FloatType', -19432.125), + (b'\xc6\x97\xd0@', 'FloatType', -19432.125), + (b'\x7f\x7f\x00\x00', 'FloatType', 338953138925153547590470800371487866880.0), + (b'', 'FloatType', None), + (b'\x7f\x50\x00\x00', 'Int32Type', 2135949312), + (b'\xff\xfd\xcb\x91', 'Int32Type', -144495), + (b'', 'Int32Type', None), + (b'f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15', 'IntegerType', 123456789123456789123456789), + (b'', 'IntegerType', None), + (b'\x7f\xff\xff\xff\xff\xff\xff\xff', 'LongType', 9223372036854775807), + (b'\x80\x00\x00\x00\x00\x00\x00\x00', 'LongType', -9223372036854775808), + (b'', 'LongType', None), + (b'', 'InetAddressType', None), + (b'A46\xa9', 'InetAddressType', '65.52.54.169'), + (b'*\x00\x13(\xe1\x02\xcc\xc0\x00\x00\x00\x00\x00\x00\x01"', 'InetAddressType', '2a00:1328:e102:ccc0::122'), + (b'\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6', 'UTF8Type', u'\u307e\u3057\u3066'), + (b'\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6' * 1000, 'UTF8Type', u'\u307e\u3057\u3066' * 1000), + (b'', 'UTF8Type', u''), + (b'\xff' * 16, 'UUIDType', 
UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')), + (b'I\x15~\xfc\xef<\x9d\xe3\x16\x98\xaf\x80\x1f\xb4\x0b*', 'UUIDType', UUID('49157efc-ef3c-9de3-1698-af801fb40b2a')), + (b'', 'UUIDType', None), + (b'', 'MapType(AsciiType, BooleanType)', None), + (b'', 'ListType(FloatType)', None), + (b'', 'SetType(LongType)', None), + (b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMapSerializedKey(DecimalType, 0)), + (b'\x00\x00', 'ListType(FloatType)', []), + (b'\x00\x00', 'SetType(IntegerType)', sortedset()), + (b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', [UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')]), + (b'\x80\x00\x00\x01', 'SimpleDateType', Date(1)), + (b'\x7f\xff\xff\xff', 'SimpleDateType', Date('1969-12-31')), + (b'\x00\x00\x00\x00\x00\x00\x00\x01', 'TimeType', Time(1)), + (b'\x7f', 'ByteType', 127), + (b'\x80', 'ByteType', -128), + (b'\x7f\xff', 'ShortType', 32767), + (b'\x80\x00', 'ShortType', -32768) +) + +ordered_map_value = OrderedMapSerializedKey(UTF8Type, 2) +ordered_map_value._insert(u'\u307fbob', 199) +ordered_map_value._insert(u'', -1) +ordered_map_value._insert(u'\\', 0) + +# these following entries work for me right now, but they're dependent on +# vagaries of internal python ordering for unordered types +marshalled_value_pairs_unsafe = ( + (b'\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_map_value), + (b'\x00\x02\x00\x08@\x01\x99\x99\x99\x99\x99\x9a\x00\x08@\x14\x00\x00\x00\x00\x00\x00', 'SetType(DoubleType)', sortedset([2.2, 5.0])), + (b'\x00', 'IntegerType', 0), +) + +if platform.python_implementation() == 'CPython': + # Only run tests for entries which depend on internal python ordering under + # CPython + marshalled_value_pairs += marshalled_value_pairs_unsafe + + +class UnmarshalTest(unittest.TestCase): + def test_unmarshalling(self): + for serializedval, valtype, nativeval in marshalled_value_pairs: + unmarshaller = lookup_casstype(valtype) + whatwegot = unmarshaller.from_binary(serializedval, 1) + self.assertEqual(whatwegot, nativeval, + msg='Unmarshaller for %s (%s) failed: unmarshal(%r) got %r instead of %r' + % (valtype, unmarshaller, serializedval, whatwegot, nativeval)) + self.assertEqual(type(whatwegot), type(nativeval), + msg='Unmarshaller for %s (%s) gave wrong type (%s instead of %s)' + % (valtype, unmarshaller, type(whatwegot), type(nativeval))) + + def test_marshalling(self): + for serializedval, valtype, nativeval in marshalled_value_pairs: + marshaller = lookup_casstype(valtype) + whatwegot = marshaller.to_binary(nativeval, 1) + self.assertEqual(whatwegot, serializedval, + msg='Marshaller for %s (%s) failed: marshal(%r) got %r instead of %r' + % (valtype, marshaller, nativeval, whatwegot, serializedval)) + self.assertEqual(type(whatwegot), type(serializedval), + msg='Marshaller for %s (%s) gave wrong type (%s instead of %s)' + % (valtype, marshaller, type(whatwegot), type(serializedval))) + + def test_date(self): + # separate test because it will deserialize as datetime + self.assertEqual(DateType.from_binary(DateType.to_binary(date(2015, 11, 2), 1), 1), datetime(2015, 11, 2)) + + def test_decimal(self): + # testing implicit numeric conversion + # int, tuple(sign, digits, exp), float + converted_types = (10001, (0, (1, 0, 0, 0, 0, 1), -3), 100.1, -87.629798) + + for proto_ver in range(1, ProtocolVersion.MAX_SUPPORTED + 1): + for n in converted_types: + expected = Decimal(n) + 
self.assertEqual(DecimalType.from_binary(DecimalType.to_binary(n, proto_ver), proto_ver), expected) diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py new file mode 100644 index 0000000..49b2627 --- /dev/null +++ b/tests/unit/test_metadata.py @@ -0,0 +1,622 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from binascii import unhexlify +import logging +from mock import Mock +import os +import six +import timeit + +import cassandra +from cassandra.cqltypes import strip_frozen +from cassandra.marshal import uint16_unpack, uint16_pack +from cassandra.metadata import (Murmur3Token, MD5Token, + BytesToken, ReplicationStrategy, + NetworkTopologyStrategy, SimpleStrategy, + LocalStrategy, protect_name, + protect_names, protect_value, is_valid_name, + UserType, KeyspaceMetadata, get_schema_parser, + _UnknownStrategy, ColumnMetadata, TableMetadata, + IndexMetadata, Function, Aggregate, + Metadata, TokenMap) +from cassandra.policies import SimpleConvictionPolicy +from cassandra.pool import Host + + +log = logging.getLogger(__name__) + + +class StrategiesTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + "Hook method for setting up class fixture before running tests in the class." + if not hasattr(cls, 'assertItemsEqual'): + cls.assertItemsEqual = cls.assertCountEqual + + def test_replication_strategy(self): + """ + Basic code coverage testing that ensures different ReplicationStrategies + can be initiated using parameters correctly. 
+ """ + + rs = ReplicationStrategy() + + self.assertEqual(rs.create('OldNetworkTopologyStrategy', None), _UnknownStrategy('OldNetworkTopologyStrategy', None)) + fake_options_map = {'options': 'map'} + uks = rs.create('OldNetworkTopologyStrategy', fake_options_map) + self.assertEqual(uks, _UnknownStrategy('OldNetworkTopologyStrategy', fake_options_map)) + self.assertEqual(uks.make_token_replica_map({}, []), {}) + + fake_options_map = {'dc1': '3'} + self.assertIsInstance(rs.create('NetworkTopologyStrategy', fake_options_map), NetworkTopologyStrategy) + self.assertEqual(rs.create('NetworkTopologyStrategy', fake_options_map).dc_replication_factors, + NetworkTopologyStrategy(fake_options_map).dc_replication_factors) + + fake_options_map = {'options': 'map'} + self.assertIsNone(rs.create('SimpleStrategy', fake_options_map)) + + fake_options_map = {'options': 'map'} + self.assertIsInstance(rs.create('LocalStrategy', fake_options_map), LocalStrategy) + + fake_options_map = {'options': 'map', 'replication_factor': 3} + self.assertIsInstance(rs.create('SimpleStrategy', fake_options_map), SimpleStrategy) + self.assertEqual(rs.create('SimpleStrategy', fake_options_map).replication_factor, + SimpleStrategy(fake_options_map).replication_factor) + + self.assertEqual(rs.create('xxxxxxxx', fake_options_map), _UnknownStrategy('xxxxxxxx', fake_options_map)) + + self.assertRaises(NotImplementedError, rs.make_token_replica_map, None, None) + self.assertRaises(NotImplementedError, rs.export_for_schema) + + def test_nts_make_token_replica_map(self): + token_to_host_owner = {} + + dc1_1 = Host('dc1.1', SimpleConvictionPolicy) + dc1_2 = Host('dc1.2', SimpleConvictionPolicy) + dc1_3 = Host('dc1.3', SimpleConvictionPolicy) + for host in (dc1_1, dc1_2, dc1_3): + host.set_location_info('dc1', 'rack1') + token_to_host_owner[MD5Token(0)] = dc1_1 + token_to_host_owner[MD5Token(100)] = dc1_2 + token_to_host_owner[MD5Token(200)] = dc1_3 + + dc2_1 = Host('dc2.1', SimpleConvictionPolicy) + dc2_2 = Host('dc2.2', SimpleConvictionPolicy) + dc2_1.set_location_info('dc2', 'rack1') + dc2_2.set_location_info('dc2', 'rack1') + token_to_host_owner[MD5Token(1)] = dc2_1 + token_to_host_owner[MD5Token(101)] = dc2_2 + + dc3_1 = Host('dc3.1', SimpleConvictionPolicy) + dc3_1.set_location_info('dc3', 'rack3') + token_to_host_owner[MD5Token(2)] = dc3_1 + + ring = [MD5Token(0), + MD5Token(1), + MD5Token(2), + MD5Token(100), + MD5Token(101), + MD5Token(200)] + + nts = NetworkTopologyStrategy({'dc1': 2, 'dc2': 2, 'dc3': 1}) + replica_map = nts.make_token_replica_map(token_to_host_owner, ring) + + self.assertItemsEqual(replica_map[MD5Token(0)], (dc1_1, dc1_2, dc2_1, dc2_2, dc3_1)) + + def test_nts_token_performance(self): + """ + Tests to ensure that when rf exceeds the number of nodes available, that we dont' + needlessly iterate trying to construct tokens for nodes that don't exist. 
+ + @since 3.7 + @jira_ticket PYTHON-379 + @expected_result timing with 1500 rf should be same/similar to 3rf if we have 3 nodes + + @test_category metadata + """ + + token_to_host_owner = {} + ring = [] + dc1hostnum = 3 + current_token = 0 + vnodes_per_host = 500 + for i in range(dc1hostnum): + + host = Host('dc1.{0}'.format(i), SimpleConvictionPolicy) + host.set_location_info('dc1', "rack1") + for vnode_num in range(vnodes_per_host): + md5_token = MD5Token(current_token+vnode_num) + token_to_host_owner[md5_token] = host + ring.append(md5_token) + current_token += 1000 + + nts = NetworkTopologyStrategy({'dc1': 3}) + start_time = timeit.default_timer() + nts.make_token_replica_map(token_to_host_owner, ring) + elapsed_base = timeit.default_timer() - start_time + + nts = NetworkTopologyStrategy({'dc1': 1500}) + start_time = timeit.default_timer() + nts.make_token_replica_map(token_to_host_owner, ring) + elapsed_bad = timeit.default_timer() - start_time + difference = elapsed_bad - elapsed_base + self.assertTrue(difference < 1 and difference > -1) + + def test_nts_make_token_replica_map_multi_rack(self): + token_to_host_owner = {} + + # (A) not enough distinct racks, first skipped is used + dc1_1 = Host('dc1.1', SimpleConvictionPolicy) + dc1_2 = Host('dc1.2', SimpleConvictionPolicy) + dc1_3 = Host('dc1.3', SimpleConvictionPolicy) + dc1_4 = Host('dc1.4', SimpleConvictionPolicy) + dc1_1.set_location_info('dc1', 'rack1') + dc1_2.set_location_info('dc1', 'rack1') + dc1_3.set_location_info('dc1', 'rack2') + dc1_4.set_location_info('dc1', 'rack2') + token_to_host_owner[MD5Token(0)] = dc1_1 + token_to_host_owner[MD5Token(100)] = dc1_2 + token_to_host_owner[MD5Token(200)] = dc1_3 + token_to_host_owner[MD5Token(300)] = dc1_4 + + # (B) distinct racks, but not contiguous + dc2_1 = Host('dc2.1', SimpleConvictionPolicy) + dc2_2 = Host('dc2.2', SimpleConvictionPolicy) + dc2_3 = Host('dc2.3', SimpleConvictionPolicy) + dc2_1.set_location_info('dc2', 'rack1') + dc2_2.set_location_info('dc2', 'rack1') + dc2_3.set_location_info('dc2', 'rack2') + token_to_host_owner[MD5Token(1)] = dc2_1 + token_to_host_owner[MD5Token(101)] = dc2_2 + token_to_host_owner[MD5Token(201)] = dc2_3 + + ring = [MD5Token(0), + MD5Token(1), + MD5Token(100), + MD5Token(101), + MD5Token(200), + MD5Token(201), + MD5Token(300)] + + nts = NetworkTopologyStrategy({'dc1': 3, 'dc2': 2}) + replica_map = nts.make_token_replica_map(token_to_host_owner, ring) + + token_replicas = replica_map[MD5Token(0)] + self.assertItemsEqual(token_replicas, (dc1_1, dc1_2, dc1_3, dc2_1, dc2_3)) + + def test_nts_make_token_replica_map_empty_dc(self): + host = Host('1', SimpleConvictionPolicy) + host.set_location_info('dc1', 'rack1') + token_to_host_owner = {MD5Token(0): host} + ring = [MD5Token(0)] + nts = NetworkTopologyStrategy({'dc1': 1, 'dc2': 0}) + + replica_map = nts.make_token_replica_map(token_to_host_owner, ring) + self.assertEqual(set(replica_map[MD5Token(0)]), set([host])) + + def test_nts_export_for_schema(self): + strategy = NetworkTopologyStrategy({'dc1': '1', 'dc2': '2'}) + self.assertEqual("{'class': 'NetworkTopologyStrategy', 'dc1': '1', 'dc2': '2'}", + strategy.export_for_schema()) + + def test_simple_strategy_make_token_replica_map(self): + host1 = Host('1', SimpleConvictionPolicy) + host2 = Host('2', SimpleConvictionPolicy) + host3 = Host('3', SimpleConvictionPolicy) + token_to_host_owner = { + MD5Token(0): host1, + MD5Token(100): host2, + MD5Token(200): host3 + } + ring = [MD5Token(0), MD5Token(100), MD5Token(200)] + + rf1_replicas = 
SimpleStrategy({'replication_factor': '1'}).make_token_replica_map(token_to_host_owner, ring) + self.assertItemsEqual(rf1_replicas[MD5Token(0)], [host1]) + self.assertItemsEqual(rf1_replicas[MD5Token(100)], [host2]) + self.assertItemsEqual(rf1_replicas[MD5Token(200)], [host3]) + + rf2_replicas = SimpleStrategy({'replication_factor': '2'}).make_token_replica_map(token_to_host_owner, ring) + self.assertItemsEqual(rf2_replicas[MD5Token(0)], [host1, host2]) + self.assertItemsEqual(rf2_replicas[MD5Token(100)], [host2, host3]) + self.assertItemsEqual(rf2_replicas[MD5Token(200)], [host3, host1]) + + rf3_replicas = SimpleStrategy({'replication_factor': '3'}).make_token_replica_map(token_to_host_owner, ring) + self.assertItemsEqual(rf3_replicas[MD5Token(0)], [host1, host2, host3]) + self.assertItemsEqual(rf3_replicas[MD5Token(100)], [host2, host3, host1]) + self.assertItemsEqual(rf3_replicas[MD5Token(200)], [host3, host1, host2]) + + def test_ss_equals(self): + self.assertNotEqual(SimpleStrategy({'replication_factor': '1'}), NetworkTopologyStrategy({'dc1': 2})) + + +class NameEscapingTest(unittest.TestCase): + + def test_protect_name(self): + """ + Test cassandra.metadata.protect_name output + """ + self.assertEqual(protect_name('tests'), 'tests') + self.assertEqual(protect_name('test\'s'), '"test\'s"') + self.assertEqual(protect_name('test\'s'), "\"test's\"") + self.assertEqual(protect_name('tests ?!@#$%^&*()'), '"tests ?!@#$%^&*()"') + self.assertEqual(protect_name('1'), '"1"') + self.assertEqual(protect_name('1test'), '"1test"') + + def test_protect_names(self): + """ + Test cassandra.metadata.protect_names output + """ + self.assertEqual(protect_names(['tests']), ['tests']) + self.assertEqual(protect_names( + [ + 'tests', + 'test\'s', + 'tests ?!@#$%^&*()', + '1' + ]), + [ + 'tests', + "\"test's\"", + '"tests ?!@#$%^&*()"', + '"1"' + ]) + + def test_protect_value(self): + """ + Test cassandra.metadata.protect_value output + """ + self.assertEqual(protect_value(True), "true") + self.assertEqual(protect_value(False), "false") + self.assertEqual(protect_value(3.14), '3.14') + self.assertEqual(protect_value(3), '3') + self.assertEqual(protect_value('test'), "'test'") + self.assertEqual(protect_value('test\'s'), "'test''s'") + self.assertEqual(protect_value(None), 'NULL') + + def test_is_valid_name(self): + """ + Test cassandra.metadata.is_valid_name output + """ + self.assertEqual(is_valid_name(None), False) + self.assertEqual(is_valid_name('test'), True) + self.assertEqual(is_valid_name('Test'), False) + self.assertEqual(is_valid_name('t_____1'), True) + self.assertEqual(is_valid_name('test1'), True) + self.assertEqual(is_valid_name('1test1'), False) + + invalid_keywords = cassandra.metadata.cql_keywords - cassandra.metadata.cql_keywords_unreserved + for keyword in invalid_keywords: + self.assertEqual(is_valid_name(keyword), False) + + +class GetReplicasTest(unittest.TestCase): + def _get_replicas(self, token_klass): + tokens = [token_klass(i) for i in range(0, (2 ** 127 - 1), 2 ** 125)] + hosts = [Host("ip%d" % i, SimpleConvictionPolicy) for i in range(len(tokens))] + token_to_primary_replica = dict(zip(tokens, hosts)) + keyspace = KeyspaceMetadata("ks", True, "SimpleStrategy", {"replication_factor": "1"}) + metadata = Mock(spec=Metadata, keyspaces={'ks': keyspace}) + token_map = TokenMap(token_klass, token_to_primary_replica, tokens, metadata) + + # tokens match node tokens exactly + for token, expected_host in zip(tokens, hosts): + replicas = token_map.get_replicas("ks", token) + 
self.assertEqual(set(replicas), {expected_host}) + + # shift the tokens back by one + for token, expected_host in zip(tokens, hosts): + replicas = token_map.get_replicas("ks", token_klass(token.value - 1)) + self.assertEqual(set(replicas), {expected_host}) + + # shift the tokens forward by one + for i, token in enumerate(tokens): + replicas = token_map.get_replicas("ks", token_klass(token.value + 1)) + expected_host = hosts[(i + 1) % len(hosts)] + self.assertEqual(set(replicas), {expected_host}) + + def test_murmur3_tokens(self): + self._get_replicas(Murmur3Token) + + def test_md5_tokens(self): + self._get_replicas(MD5Token) + + def test_bytes_tokens(self): + self._get_replicas(BytesToken) + + +class Murmur3TokensTest(unittest.TestCase): + + def test_murmur3_init(self): + murmur3_token = Murmur3Token(cassandra.metadata.MIN_LONG - 1) + self.assertEqual(str(murmur3_token), '') + + def test_python_vs_c(self): + from cassandra.murmur3 import _murmur3 as mm3_python + try: + from cassandra.cmurmur3 import murmur3 as mm3_c + + iterations = 100 + for _ in range(iterations): + for len in range(0, 32): # zero to one block plus full range of tail lengths + key = os.urandom(len) + self.assertEqual(mm3_python(key), mm3_c(key)) + + except ImportError: + raise unittest.SkipTest('The cmurmur3 extension is not available') + + def test_murmur3_python(self): + from cassandra.murmur3 import _murmur3 + self._verify_hash(_murmur3) + + def test_murmur3_c(self): + try: + from cassandra.cmurmur3 import murmur3 + self._verify_hash(murmur3) + except ImportError: + raise unittest.SkipTest('The cmurmur3 extension is not available') + + def _verify_hash(self, fn): + self.assertEqual(fn(six.b('123')), -7468325962851647638) + self.assertEqual(fn(b'\x00\xff\x10\xfa\x99' * 10), 5837342703291459765) + self.assertEqual(fn(b'\xfe' * 8), -8927430733708461935) + self.assertEqual(fn(b'\x10' * 8), 1446172840243228796) + self.assertEqual(fn(six.b(str(cassandra.metadata.MAX_LONG))), 7162290910810015547) + + +class MD5TokensTest(unittest.TestCase): + + def test_md5_tokens(self): + md5_token = MD5Token(cassandra.metadata.MIN_LONG - 1) + self.assertEqual(md5_token.hash_fn('123'), 42767516990368493138776584305024125808) + self.assertEqual(md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 28528976619278518853815276204542453639) + self.assertEqual(str(md5_token), '' % -9223372036854775809) + + +class BytesTokensTest(unittest.TestCase): + + def test_bytes_tokens(self): + bytes_token = BytesToken(unhexlify(six.b('01'))) + self.assertEqual(bytes_token.value, six.b('\x01')) + self.assertEqual(str(bytes_token), "" % bytes_token.value) + self.assertEqual(bytes_token.hash_fn('123'), '123') + self.assertEqual(bytes_token.hash_fn(123), 123) + self.assertEqual(bytes_token.hash_fn(str(cassandra.metadata.MAX_LONG)), str(cassandra.metadata.MAX_LONG)) + + def test_from_string(self): + from_unicode = BytesToken.from_string(six.text_type('0123456789abcdef')) + from_bin = BytesToken.from_string(six.b('0123456789abcdef')) + self.assertEqual(from_unicode, from_bin) + self.assertIsInstance(from_unicode.value, six.binary_type) + self.assertIsInstance(from_bin.value, six.binary_type) + + def test_comparison(self): + tok = BytesToken.from_string(six.text_type('0123456789abcdef')) + token_high_order = uint16_unpack(tok.value[0:2]) + self.assertLess(BytesToken(uint16_pack(token_high_order - 1)), tok) + self.assertGreater(BytesToken(uint16_pack(token_high_order + 1)), tok) + + def test_comparison_unicode(self): + value = six.b('\'_-()"\xc2\xac') + t0 = 
BytesToken(value) + t1 = BytesToken.from_string('00') + self.assertGreater(t0, t1) + self.assertFalse(t0 < t1) + + +class KeyspaceMetadataTest(unittest.TestCase): + + def test_export_as_string_user_types(self): + keyspace_name = 'test' + keyspace = KeyspaceMetadata(keyspace_name, True, 'SimpleStrategy', dict(replication_factor=3)) + keyspace.user_types['a'] = UserType(keyspace_name, 'a', ['one', 'two'], ['c', 'int']) + keyspace.user_types['b'] = UserType(keyspace_name, 'b', ['one', 'two', 'three'], ['d', 'int', 'a']) + keyspace.user_types['c'] = UserType(keyspace_name, 'c', ['one'], ['int']) + keyspace.user_types['d'] = UserType(keyspace_name, 'd', ['one'], ['c']) + + self.assertEqual("""CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} AND durable_writes = true; + +CREATE TYPE test.c ( + one int +); + +CREATE TYPE test.a ( + one c, + two int +); + +CREATE TYPE test.d ( + one c +); + +CREATE TYPE test.b ( + one d, + two int, + three a +);""", keyspace.export_as_string()) + + +class UserTypesTest(unittest.TestCase): + + def test_as_cql_query(self): + field_types = ['varint', 'ascii', 'frozen>'] + udt = UserType("ks1", "mytype", ["a", "b", "c"], field_types) + self.assertEqual("CREATE TYPE ks1.mytype (a varint, b ascii, c frozen>)", udt.as_cql_query(formatted=False)) + + self.assertEqual("""CREATE TYPE ks1.mytype ( + a varint, + b ascii, + c frozen> +);""", udt.export_as_string()) + + def test_as_cql_query_name_escaping(self): + udt = UserType("MyKeyspace", "MyType", ["AbA", "keyspace"], ['ascii', 'ascii']) + self.assertEqual('CREATE TYPE "MyKeyspace"."MyType" ("AbA" ascii, "keyspace" ascii)', udt.as_cql_query(formatted=False)) + + +class UserDefinedFunctionTest(unittest.TestCase): + def test_as_cql_query_removes_frozen(self): + func = Function("ks1", "myfunction", ["frozen>"], ["a"], "int", "java", "return 0;", True) + expected_result = ( + "CREATE FUNCTION ks1.myfunction(a tuple) " + "CALLED ON NULL INPUT " + "RETURNS int " + "LANGUAGE java " + "AS $$return 0;$$" + ) + self.assertEqual(expected_result, func.as_cql_query(formatted=False)) + + +class UserDefinedAggregateTest(unittest.TestCase): + def test_as_cql_query_removes_frozen(self): + aggregate = Aggregate("ks1", "myaggregate", ["frozen>"], "statefunc", "frozen>", "finalfunc", "(0)", "tuple") + expected_result = ( + "CREATE AGGREGATE ks1.myaggregate(tuple) " + "SFUNC statefunc " + "STYPE tuple " + "FINALFUNC finalfunc " + "INITCOND (0)" + ) + self.assertEqual(expected_result, aggregate.as_cql_query(formatted=False)) + + +class IndexTest(unittest.TestCase): + + def test_build_index_as_cql(self): + column_meta = Mock() + column_meta.name = 'column_name_here' + column_meta.table.name = 'table_name_here' + column_meta.table.keyspace_name = 'keyspace_name_here' + column_meta.table.columns = {column_meta.name: column_meta} + parser = get_schema_parser(Mock(), '2.1.0', 0.1) + + row = {'index_name': 'index_name_here', 'index_type': 'index_type_here'} + index_meta = parser._build_index_metadata(column_meta, row) + self.assertEqual(index_meta.as_cql_query(), + 'CREATE INDEX index_name_here ON keyspace_name_here.table_name_here (column_name_here)') + + row['index_options'] = '{ "class_name": "class_name_here" }' + row['index_type'] = 'CUSTOM' + index_meta = parser._build_index_metadata(column_meta, row) + self.assertEqual(index_meta.as_cql_query(), + "CREATE CUSTOM INDEX index_name_here ON keyspace_name_here.table_name_here (column_name_here) USING 'class_name_here'") + + +class 
UnicodeIdentifiersTests(unittest.TestCase): + """ + Exercise cql generation with unicode characters. Keyspace, Table, and Index names + cannot have special chars because C* names files by those identifiers, but they are + tested anyway. + + Looking for encoding errors like PYTHON-447 + """ + + name = six.text_type(b'\'_-()"\xc2\xac'.decode('utf-8')) + + def test_keyspace_name(self): + km = KeyspaceMetadata(self.name, False, 'SimpleStrategy', {'replication_factor': 1}) + km.export_as_string() + + def test_table_name(self): + tm = TableMetadata(self.name, self.name) + tm.export_as_string() + + def test_column_name_single_partition(self): + tm = TableMetadata('ks', 'table') + cm = ColumnMetadata(tm, self.name, u'int') + tm.columns[cm.name] = cm + tm.partition_key.append(cm) + tm.export_as_string() + + def test_column_name_single_partition_single_clustering(self): + tm = TableMetadata('ks', 'table') + cm = ColumnMetadata(tm, self.name, u'int') + tm.columns[cm.name] = cm + tm.partition_key.append(cm) + cm = ColumnMetadata(tm, self.name + 'x', u'int') + tm.columns[cm.name] = cm + tm.clustering_key.append(cm) + tm.export_as_string() + + def test_column_name_multiple_partition(self): + tm = TableMetadata('ks', 'table') + cm = ColumnMetadata(tm, self.name, u'int') + tm.columns[cm.name] = cm + tm.partition_key.append(cm) + cm = ColumnMetadata(tm, self.name + 'x', u'int') + tm.columns[cm.name] = cm + tm.partition_key.append(cm) + tm.export_as_string() + + def test_index(self): + im = IndexMetadata(self.name, self.name, self.name, kind='', index_options={'target': self.name}) + log.debug(im.export_as_string()) + im = IndexMetadata(self.name, self.name, self.name, kind='CUSTOM', index_options={'target': self.name, 'class_name': 'Class'}) + log.debug(im.export_as_string()) + # PYTHON-1008 + im = IndexMetadata(self.name, self.name, self.name, kind='CUSTOM', index_options={'target': self.name, 'class_name': 'Class', 'delimiter': self.name}) + log.debug(im.export_as_string()) + + def test_function(self): + fm = Function(self.name, self.name, (u'int', u'int'), (u'x', u'y'), u'int', u'language', self.name, False) + fm.export_as_string() + + def test_aggregate(self): + am = Aggregate(self.name, self.name, (u'text',), self.name, u'text', self.name, self.name, u'text') + am.export_as_string() + + def test_user_type(self): + um = UserType(self.name, self.name, [self.name, self.name], [u'int', u'text']) + um.export_as_string() + + +class HostsTests(unittest.TestCase): + def test_iterate_all_hosts_and_modify(self): + """ + PYTHON-572 + """ + metadata = Metadata() + metadata.add_or_return_host(Host('dc1.1', SimpleConvictionPolicy)) + metadata.add_or_return_host(Host('dc1.2', SimpleConvictionPolicy)) + + self.assertEqual(len(metadata.all_hosts()), 2) + + for host in metadata.all_hosts(): # this would previously raise in Py3 + metadata.remove_host(host) + + self.assertEqual(len(metadata.all_hosts()), 0) + + +class MetadataHelpersTest(unittest.TestCase): + """ For any helper functions that need unit tests """ + def test_strip_frozen(self): + self.longMessage = True + + argument_to_expected_results = [ + ('int', 'int'), + ('tuple', 'tuple'), + (r'map<"!@#$%^&*()[]\ frozen >>>", int>', r'map<"!@#$%^&*()[]\ frozen >>>", int>'), # A valid UDT name + ('frozen>', 'tuple'), + (r'frozen>>", int>>', r'map<"!@#$%^&*()[]\ frozen >>>", int>'), + ('frozen>, int>>, frozen>>>>>', + 'map, int>, map>>'), + ] + for argument, expected_result in argument_to_expected_results: + result = strip_frozen(argument) + self.assertEqual(result, 
expected_result, "strip_frozen() arg: {}".format(argument)) diff --git a/tests/unit/test_orderedmap.py b/tests/unit/test_orderedmap.py new file mode 100644 index 0000000..f2baab4 --- /dev/null +++ b/tests/unit/test_orderedmap.py @@ -0,0 +1,186 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.util import OrderedMap, OrderedMapSerializedKey +from cassandra.cqltypes import EMPTY, UTF8Type, lookup_casstype +import six + +class OrderedMapTest(unittest.TestCase): + def test_init(self): + a = OrderedMap(zip(['one', 'three', 'two'], [1, 3, 2])) + b = OrderedMap([('one', 1), ('three', 3), ('two', 2)]) + c = OrderedMap(a) + builtin = {'one': 1, 'two': 2, 'three': 3} + self.assertEqual(a, b) + self.assertEqual(a, c) + self.assertEqual(a, builtin) + self.assertEqual(OrderedMap([(1, 1), (1, 2)]), {1: 2}) + + d = OrderedMap({'': 3}, key1='v1', key2='v2') + self.assertEqual(d[''], 3) + self.assertEqual(d['key1'], 'v1') + self.assertEqual(d['key2'], 'v2') + + with self.assertRaises(TypeError): + OrderedMap('too', 'many', 'args') + + def test_contains(self): + keys = ['first', 'middle', 'last'] + + om = OrderedMap() + + om = OrderedMap(zip(keys, range(len(keys)))) + + for k in keys: + self.assertTrue(k in om) + self.assertFalse(k not in om) + + self.assertTrue('notthere' not in om) + self.assertFalse('notthere' in om) + + def test_keys(self): + keys = ['first', 'middle', 'last'] + om = OrderedMap(zip(keys, range(len(keys)))) + + self.assertListEqual(list(om.keys()), keys) + + def test_values(self): + keys = ['first', 'middle', 'last'] + values = list(range(len(keys))) + om = OrderedMap(zip(keys, values)) + + self.assertListEqual(list(om.values()), values) + + def test_items(self): + keys = ['first', 'middle', 'last'] + items = list(zip(keys, range(len(keys)))) + om = OrderedMap(items) + + self.assertListEqual(list(om.items()), items) + + def test_get(self): + keys = ['first', 'middle', 'last'] + om = OrderedMap(zip(keys, range(len(keys)))) + + for v, k in enumerate(keys): + self.assertEqual(om.get(k), v) + + self.assertEqual(om.get('notthere', 'default'), 'default') + self.assertIsNone(om.get('notthere')) + + def test_equal(self): + d1 = {'one': 1} + d12 = {'one': 1, 'two': 2} + om1 = OrderedMap({'one': 1}) + om12 = OrderedMap([('one', 1), ('two', 2)]) + om21 = OrderedMap([('two', 2), ('one', 1)]) + + self.assertEqual(om1, d1) + self.assertEqual(om12, d12) + self.assertEqual(om21, d12) + self.assertNotEqual(om1, om12) + self.assertNotEqual(om12, om1) + self.assertNotEqual(om12, om21) + self.assertNotEqual(om1, d12) + self.assertNotEqual(om12, d1) + self.assertNotEqual(om1, EMPTY) + + self.assertFalse(OrderedMap([('three', 3), ('four', 4)]) == d12) + + def test_getitem(self): + keys = ['first', 'middle', 'last'] + om = OrderedMap(zip(keys, range(len(keys)))) + + for v, k in enumerate(keys): + self.assertEqual(om[k], v) + + with self.assertRaises(KeyError): + om['notthere'] + + def 
test_iter(self): + keys = ['first', 'middle', 'last'] + values = list(range(len(keys))) + items = list(zip(keys, values)) + om = OrderedMap(items) + + itr = iter(om) + self.assertEqual(sum([1 for _ in itr]), len(keys)) + self.assertRaises(StopIteration, six.next, itr) + + self.assertEqual(list(iter(om)), keys) + self.assertEqual(list(six.iteritems(om)), items) + self.assertEqual(list(six.itervalues(om)), values) + + def test_len(self): + self.assertEqual(len(OrderedMap()), 0) + self.assertEqual(len(OrderedMap([(1, 1)])), 1) + + def test_mutable_keys(self): + d = {'1': 1} + s = set([1, 2, 3]) + om = OrderedMap([(d, 'dict'), (s, 'set')]) + + def test_strings(self): + # changes in 3.x + d = {'map': 'inner'} + s = set([1, 2, 3]) + self.assertEqual(repr(OrderedMap([('two', 2), ('one', 1), (d, 'value'), (s, 'another')])), + "OrderedMap([('two', 2), ('one', 1), (%r, 'value'), (%r, 'another')])" % (d, s)) + + self.assertEqual(str(OrderedMap([('two', 2), ('one', 1), (d, 'value'), (s, 'another')])), + "{'two': 2, 'one': 1, %r: 'value', %r: 'another'}" % (d, s)) + + def test_popitem(self): + item = (1, 2) + om = OrderedMap((item,)) + self.assertEqual(om.popitem(), item) + self.assertRaises(KeyError, om.popitem) + + def test_delitem(self): + om = OrderedMap({1: 1, 2: 2}) + + self.assertRaises(KeyError, om.__delitem__, 3) + + del om[1] + self.assertEqual(om, {2: 2}) + del om[2] + self.assertFalse(om) + + self.assertRaises(KeyError, om.__delitem__, 1) + + +class OrderedMapSerializedKeyTest(unittest.TestCase): + def test_init(self): + om = OrderedMapSerializedKey(UTF8Type, 2) + self.assertEqual(om, {}) + + def test_normalized_lookup(self): + key_type = lookup_casstype('MapType(UTF8Type, Int32Type)') + protocol_version = 3 + om = OrderedMapSerializedKey(key_type, protocol_version) + key_ascii = {'one': 1} + key_unicode = {u'two': 2} + om._insert_unchecked(key_ascii, key_type.serialize(key_ascii, protocol_version), object()) + om._insert_unchecked(key_unicode, key_type.serialize(key_unicode, protocol_version), object()) + + # type lookup is normalized by key_type + # PYTHON-231 + self.assertIs(om[{'one': 1}], om[{u'one': 1}]) + self.assertIs(om[{'two': 2}], om[{u'two': 2}]) + self.assertIsNot(om[{'one': 1}], om[{'two': 2}]) diff --git a/tests/unit/test_parameter_binding.py b/tests/unit/test_parameter_binding.py new file mode 100644 index 0000000..9c91679 --- /dev/null +++ b/tests/unit/test_parameter_binding.py @@ -0,0 +1,223 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
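# --- Illustrative sketch (editorial, not part of the diffed test module) ---
# The OrderedMap tests above document the mapping semantics the driver uses
# for CQL map values: iteration preserves insertion order, equality matches a
# plain dict regardless of order, and a repeated key keeps the last value.
# A small restatement of those assertions:
from cassandra.util import OrderedMap

om = OrderedMap([('one', 1), ('three', 3), ('two', 2)])

assert list(om.keys()) == ['one', 'three', 'two']  # insertion order kept
assert om == {'one': 1, 'two': 2, 'three': 3}      # order-insensitive equality
assert OrderedMap([(1, 1), (1, 2)]) == {1: 2}      # last duplicate wins
# ---------------------------------------------------------------------------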
+ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.encoder import Encoder +from cassandra.protocol import ColumnMetadata +from cassandra.query import (bind_params, ValueSequence, PreparedStatement, + BoundStatement, UNSET_VALUE) +from cassandra.cqltypes import Int32Type +from cassandra.util import OrderedDict + +from six.moves import xrange +import six + + +class ParamBindingTest(unittest.TestCase): + + def test_bind_sequence(self): + result = bind_params("%s %s %s", (1, "a", 2.0), Encoder()) + self.assertEqual(result, "1 'a' 2.0") + + def test_bind_map(self): + result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), Encoder()) + self.assertEqual(result, "1 'a' 2.0") + + def test_sequence_param(self): + result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), Encoder()) + self.assertEqual(result, "(1, 'a', 2.0)") + + def test_generator_param(self): + result = bind_params("%s", ((i for i in xrange(3)),), Encoder()) + self.assertEqual(result, "[0, 1, 2]") + + def test_none_param(self): + result = bind_params("%s", (None,), Encoder()) + self.assertEqual(result, "NULL") + + def test_list_collection(self): + result = bind_params("%s", (['a', 'b', 'c'],), Encoder()) + self.assertEqual(result, "['a', 'b', 'c']") + + def test_set_collection(self): + result = bind_params("%s", (set(['a', 'b']),), Encoder()) + self.assertIn(result, ("{'a', 'b'}", "{'b', 'a'}")) + + def test_map_collection(self): + vals = OrderedDict() + vals['a'] = 'a' + vals['b'] = 'b' + vals['c'] = 'c' + result = bind_params("%s", (vals,), Encoder()) + self.assertEqual(result, "{'a': 'a', 'b': 'b', 'c': 'c'}") + + def test_quote_escaping(self): + result = bind_params("%s", ("""'ef''ef"ef""ef'""",), Encoder()) + self.assertEqual(result, """'''ef''''ef"ef""ef'''""") + + def test_float_precision(self): + f = 3.4028234663852886e+38 + self.assertEqual(float(bind_params("%s", (f,), Encoder())), f) + + +class BoundStatementTestV1(unittest.TestCase): + + protocol_version = 1 + + @classmethod + def setUpClass(cls): + column_metadata = [ColumnMetadata('keyspace', 'cf', 'rk0', Int32Type), + ColumnMetadata('keyspace', 'cf', 'rk1', Int32Type), + ColumnMetadata('keyspace', 'cf', 'ck0', Int32Type), + ColumnMetadata('keyspace', 'cf', 'v0', Int32Type)] + cls.prepared = PreparedStatement(column_metadata=column_metadata, + query_id=None, + routing_key_indexes=[1, 0], + query=None, + keyspace='keyspace', + protocol_version=cls.protocol_version, result_metadata=None, + result_metadata_id=None) + cls.bound = BoundStatement(prepared_statement=cls.prepared) + + def test_invalid_argument_type(self): + values = (0, 0, 0, 'string not int') + try: + self.bound.bind(values) + except TypeError as e: + self.assertIn('v0', str(e)) + self.assertIn('Int32Type', str(e)) + self.assertIn('str', str(e)) + else: + self.fail('Passed invalid type but exception was not thrown') + + values = (['1', '2'], 0, 0, 0) + + try: + self.bound.bind(values) + except TypeError as e: + self.assertIn('rk0', str(e)) + self.assertIn('Int32Type', str(e)) + self.assertIn('list', str(e)) + else: + self.fail('Passed invalid type but exception was not thrown') + + def test_inherit_fetch_size(self): + keyspace = 'keyspace1' + column_family = 'cf1' + + column_metadata = [ + ColumnMetadata(keyspace, column_family, 'foo1', Int32Type), + ColumnMetadata(keyspace, column_family, 'foo2', Int32Type) + ] + + prepared_statement = PreparedStatement(column_metadata=column_metadata, + query_id=None, + routing_key_indexes=[], + query=None, + 
keyspace=keyspace, + protocol_version=self.protocol_version, + result_metadata=None, + result_metadata_id=None) + prepared_statement.fetch_size = 1234 + bound_statement = BoundStatement(prepared_statement=prepared_statement) + self.assertEqual(1234, bound_statement.fetch_size) + + def test_too_few_parameters_for_routing_key(self): + self.assertRaises(ValueError, self.prepared.bind, (1,)) + + bound = self.prepared.bind((1, 2)) + self.assertEqual(bound.keyspace, 'keyspace') + + def test_dict_missing_routing_key(self): + self.assertRaises(KeyError, self.bound.bind, {'rk0': 0, 'ck0': 0, 'v0': 0}) + self.assertRaises(KeyError, self.bound.bind, {'rk1': 0, 'ck0': 0, 'v0': 0}) + + def test_missing_value(self): + self.assertRaises(KeyError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0}) + + def test_extra_value(self): + self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': 0, 'should_not_be_here': 123}) # okay to have extra keys in dict + self.assertEqual(self.bound.values, [six.b('\x00') * 4] * 4) # four encoded zeros + self.assertRaises(ValueError, self.bound.bind, (0, 0, 0, 0, 123)) + + def test_values_none(self): + # should have values + self.assertRaises(ValueError, self.bound.bind, None) + + # prepared statement with no values + prepared_statement = PreparedStatement(column_metadata=[], + query_id=None, + routing_key_indexes=[], + query=None, + keyspace='whatever', + protocol_version=self.protocol_version, + result_metadata=None, + result_metadata_id=None) + bound = prepared_statement.bind(None) + self.assertListEqual(bound.values, []) + + def test_bind_none(self): + self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': None}) + self.assertEqual(self.bound.values[-1], None) + + old_values = self.bound.values + self.bound.bind((0, 0, 0, None)) + self.assertIsNot(self.bound.values, old_values) + self.assertEqual(self.bound.values[-1], None) + + def test_unset_value(self): + self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE}) + self.assertRaises(ValueError, self.bound.bind, (0, 0, 0, UNSET_VALUE)) + + +class BoundStatementTestV2(BoundStatementTestV1): + protocol_version = 2 + + +class BoundStatementTestV3(BoundStatementTestV1): + protocol_version = 3 + + +class BoundStatementTestV4(BoundStatementTestV1): + protocol_version = 4 + + def test_dict_missing_routing_key(self): + # in v4 it implicitly binds UNSET_VALUE for missing items, + # UNSET_VALUE is ValueError for routing keys + self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'ck0': 0, 'v0': 0}) + self.assertRaises(ValueError, self.bound.bind, {'rk1': 0, 'ck0': 0, 'v0': 0}) + + def test_missing_value(self): + # in v4 missing values are UNSET_VALUE + self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0}) + self.assertEqual(self.bound.values[-1], UNSET_VALUE) + + old_values = self.bound.values + self.bound.bind((0, 0, 0)) + self.assertIsNot(self.bound.values, old_values) + self.assertEqual(self.bound.values[-1], UNSET_VALUE) + + def test_unset_value(self): + self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE}) + self.assertEqual(self.bound.values[-1], UNSET_VALUE) + + self.bound.bind((0, 0, 0, UNSET_VALUE)) + self.assertEqual(self.bound.values[-1], UNSET_VALUE) + +class BoundStatementTestV5(BoundStatementTestV4): + protocol_version = 5 diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py new file mode 100644 index 0000000..15fa316 --- /dev/null +++ b/tests/unit/test_policies.py @@ -0,0 +1,1500 @@ +# Copyright DataStax, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from itertools import islice, cycle +from mock import Mock, patch, call +from random import randint +import six +from six.moves._thread import LockType +import sys +import struct +from threading import Thread + +from cassandra import ConsistencyLevel +from cassandra.cluster import Cluster +from cassandra.metadata import Metadata +from cassandra.policies import (RoundRobinPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy, + TokenAwarePolicy, SimpleConvictionPolicy, + HostDistance, ExponentialReconnectionPolicy, + RetryPolicy, WriteType, + DowngradingConsistencyRetryPolicy, ConstantReconnectionPolicy, + LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy, + IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy) +from cassandra.pool import Host +from cassandra.connection import DefaultEndPoint +from cassandra.query import Statement + +from six.moves import xrange + + +class LoadBalancingPolicyTest(unittest.TestCase): + def test_non_implemented(self): + """ + Code coverage for interface-style base class + """ + + policy = LoadBalancingPolicy() + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + host.set_location_info("dc1", "rack1") + + self.assertRaises(NotImplementedError, policy.distance, host) + self.assertRaises(NotImplementedError, policy.populate, None, host) + self.assertRaises(NotImplementedError, policy.make_query_plan) + self.assertRaises(NotImplementedError, policy.on_up, host) + self.assertRaises(NotImplementedError, policy.on_down, host) + self.assertRaises(NotImplementedError, policy.on_add, host) + self.assertRaises(NotImplementedError, policy.on_remove, host) + + def test_instance_check(self): + self.assertRaises(TypeError, Cluster, load_balancing_policy=RoundRobinPolicy) + + +class RoundRobinPolicyTest(unittest.TestCase): + + def test_basic(self): + hosts = [0, 1, 2, 3] + policy = RoundRobinPolicy() + policy.populate(None, hosts) + qplan = list(policy.make_query_plan()) + self.assertEqual(sorted(qplan), hosts) + + def test_multiple_query_plans(self): + hosts = [0, 1, 2, 3] + policy = RoundRobinPolicy() + policy.populate(None, hosts) + for i in xrange(20): + qplan = list(policy.make_query_plan()) + self.assertEqual(sorted(qplan), hosts) + + def test_single_host(self): + policy = RoundRobinPolicy() + policy.populate(None, [0]) + qplan = list(policy.make_query_plan()) + self.assertEqual(qplan, [0]) + + def test_status_updates(self): + hosts = [0, 1, 2, 3] + policy = RoundRobinPolicy() + policy.populate(None, hosts) + policy.on_down(0) + policy.on_remove(1) + policy.on_up(4) + policy.on_add(5) + qplan = list(policy.make_query_plan()) + self.assertEqual(sorted(qplan), [2, 3, 4, 5]) + + def test_thread_safety(self): + hosts = range(100) + policy = RoundRobinPolicy() + policy.populate(None, hosts) + + def check_query_plan(): + for i in range(100): + qplan = list(policy.make_query_plan()) + 
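+                # descriptive note (added): no hosts are added or removed concurrently in this
+                # test, so every plan built by any thread must contain the full host set,
+                # just in a different rotation order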
self.assertEqual(sorted(qplan), list(hosts)) + + threads = [Thread(target=check_query_plan) for i in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + def test_thread_safety_during_modification(self): + hosts = range(100) + policy = RoundRobinPolicy() + policy.populate(None, hosts) + + errors = [] + + def check_query_plan(): + try: + for i in xrange(100): + list(policy.make_query_plan()) + except Exception as exc: + errors.append(exc) + + def host_up(): + for i in xrange(1000): + policy.on_up(randint(0, 99)) + + def host_down(): + for i in xrange(1000): + policy.on_down(randint(0, 99)) + + threads = [] + for i in range(5): + threads.append(Thread(target=check_query_plan)) + threads.append(Thread(target=host_up)) + threads.append(Thread(target=host_down)) + + # make the GIL switch after every instruction, maximizing + # the chance of race conditions + check = six.PY2 or '__pypy__' in sys.builtin_module_names + if check: + original_interval = sys.getcheckinterval() + else: + original_interval = sys.getswitchinterval() + + try: + if check: + sys.setcheckinterval(0) + else: + sys.setswitchinterval(0.0001) + map(lambda t: t.start(), threads) + map(lambda t: t.join(), threads) + finally: + if check: + sys.setcheckinterval(original_interval) + else: + sys.setswitchinterval(original_interval) + + if errors: + self.fail("Saw errors: %s" % (errors,)) + + def test_no_live_nodes(self): + """ + Ensure query plan for a downed cluster will execute without errors + """ + hosts = [0, 1, 2, 3] + policy = RoundRobinPolicy() + policy.populate(None, hosts) + + for i in range(4): + policy.on_down(i) + + qplan = list(policy.make_query_plan()) + self.assertEqual(qplan, []) + + +class DCAwareRoundRobinPolicyTest(unittest.TestCase): + + def test_no_remote(self): + hosts = [] + for i in range(4): + h = Host(DefaultEndPoint(i), SimpleConvictionPolicy) + h.set_location_info("dc1", "rack1") + hosts.append(h) + + policy = DCAwareRoundRobinPolicy("dc1") + policy.populate(None, hosts) + qplan = list(policy.make_query_plan()) + self.assertEqual(sorted(qplan), sorted(hosts)) + + def test_with_remotes(self): + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] + for h in hosts[:2]: + h.set_location_info("dc1", "rack1") + for h in hosts[2:]: + h.set_location_info("dc2", "rack1") + + local_hosts = set(h for h in hosts if h.datacenter == "dc1") + remote_hosts = set(h for h in hosts if h.datacenter != "dc1") + + # allow all of the remote hosts to be used + policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=2) + policy.populate(Mock(), hosts) + qplan = list(policy.make_query_plan()) + self.assertEqual(set(qplan[:2]), local_hosts) + self.assertEqual(set(qplan[2:]), remote_hosts) + + # allow only one of the remote hosts to be used + policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1) + policy.populate(Mock(), hosts) + qplan = list(policy.make_query_plan()) + self.assertEqual(set(qplan[:2]), local_hosts) + + used_remotes = set(qplan[2:]) + self.assertEqual(1, len(used_remotes)) + self.assertIn(qplan[2], remote_hosts) + + # allow no remote hosts to be used + policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0) + policy.populate(Mock(), hosts) + qplan = list(policy.make_query_plan()) + self.assertEqual(2, len(qplan)) + self.assertEqual(local_hosts, set(qplan)) + + def test_get_distance(self): + policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + 
host.set_location_info("dc1", "rack1") + policy.populate(Mock(), [host]) + + self.assertEqual(policy.distance(host), HostDistance.LOCAL) + + # used_hosts_per_remote_dc is set to 0, so ignore it + remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy) + remote_host.set_location_info("dc2", "rack1") + self.assertEqual(policy.distance(remote_host), HostDistance.IGNORED) + + # dc2 isn't registered in the policy's live_hosts dict + policy.used_hosts_per_remote_dc = 1 + self.assertEqual(policy.distance(remote_host), HostDistance.IGNORED) + + # make sure the policy has both dcs registered + policy.populate(Mock(), [host, remote_host]) + self.assertEqual(policy.distance(remote_host), HostDistance.REMOTE) + + # since used_hosts_per_remote_dc is set to 1, only the first + # remote host in dc2 will be REMOTE, the rest are IGNORED + second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy) + second_remote_host.set_location_info("dc2", "rack1") + policy.populate(Mock(), [host, remote_host, second_remote_host]) + distances = set([policy.distance(remote_host), policy.distance(second_remote_host)]) + self.assertEqual(distances, set([HostDistance.REMOTE, HostDistance.IGNORED])) + + def test_status_updates(self): + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] + for h in hosts[:2]: + h.set_location_info("dc1", "rack1") + for h in hosts[2:]: + h.set_location_info("dc2", "rack1") + + policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1) + policy.populate(Mock(), hosts) + policy.on_down(hosts[0]) + policy.on_remove(hosts[2]) + + new_local_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy) + new_local_host.set_location_info("dc1", "rack1") + policy.on_up(new_local_host) + + new_remote_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) + new_remote_host.set_location_info("dc9000", "rack1") + policy.on_add(new_remote_host) + + # we now have two local hosts and two remote hosts in separate dcs + qplan = list(policy.make_query_plan()) + self.assertEqual(set(qplan[:2]), set([hosts[1], new_local_host])) + self.assertEqual(set(qplan[2:]), set([hosts[3], new_remote_host])) + + # since we have hosts in dc9000, the distance shouldn't be IGNORED + self.assertEqual(policy.distance(new_remote_host), HostDistance.REMOTE) + + policy.on_down(new_local_host) + policy.on_down(hosts[1]) + qplan = list(policy.make_query_plan()) + self.assertEqual(set(qplan), set([hosts[3], new_remote_host])) + + policy.on_down(new_remote_host) + policy.on_down(hosts[3]) + qplan = list(policy.make_query_plan()) + self.assertEqual(qplan, []) + + def test_modification_during_generation(self): + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] + for h in hosts[:2]: + h.set_location_info("dc1", "rack1") + for h in hosts[2:]: + h.set_location_info("dc2", "rack1") + + policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=3) + policy.populate(Mock(), hosts) + + # The general concept here is to change thee internal state of the + # policy during plan generation. In this case we use a grey-box + # approach that changes specific things during known phases of the + # generator. 
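+        # descriptive note (added): each sub-case below follows the same pattern: build a
+        # query-plan generator, mutate host/DC state (on_up/on_down or a DC change) at a
+        # known point of iteration, then exhaust the plan and assert how many hosts it yields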
+ + new_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy) + new_host.set_location_info("dc1", "rack1") + + # new local before iteration + plan = policy.make_query_plan() + policy.on_up(new_host) + # local list is not bound yet, so we get to see that one + self.assertEqual(len(list(plan)), 3 + 2) + + # remove local before iteration + plan = policy.make_query_plan() + policy.on_down(new_host) + # local list is not bound yet, so we don't see it + self.assertEqual(len(list(plan)), 2 + 2) + + # new local after starting iteration + plan = policy.make_query_plan() + next(plan) + policy.on_up(new_host) + # local list was is bound, and one consumed, so we only see the other original + self.assertEqual(len(list(plan)), 1 + 2) + + # remove local after traversing available + plan = policy.make_query_plan() + for _ in range(3): + next(plan) + policy.on_down(new_host) + # we should be past the local list + self.assertEqual(len(list(plan)), 0 + 2) + + # REMOTES CHANGE + new_host.set_location_info("dc2", "rack1") + + # new remote after traversing local, but not starting remote + plan = policy.make_query_plan() + for _ in range(2): + next(plan) + policy.on_up(new_host) + # list is updated before we get to it + self.assertEqual(len(list(plan)), 0 + 3) + + # remove remote after traversing local, but not starting remote + plan = policy.make_query_plan() + for _ in range(2): + next(plan) + policy.on_down(new_host) + # list is updated before we get to it + self.assertEqual(len(list(plan)), 0 + 2) + + # new remote after traversing local, and starting remote + plan = policy.make_query_plan() + for _ in range(3): + next(plan) + policy.on_up(new_host) + # slice is already made, and we've consumed one + self.assertEqual(len(list(plan)), 0 + 1) + + # remove remote after traversing local, and starting remote + plan = policy.make_query_plan() + for _ in range(3): + next(plan) + policy.on_down(new_host) + # slice is created with all present, and we've consumed one + self.assertEqual(len(list(plan)), 0 + 2) + + # local DC disappears after finishing it, but not starting remote + plan = policy.make_query_plan() + for _ in range(2): + next(plan) + policy.on_down(hosts[0]) + policy.on_down(hosts[1]) + # dict traversal starts as normal + self.assertEqual(len(list(plan)), 0 + 2) + policy.on_up(hosts[0]) + policy.on_up(hosts[1]) + + # PYTHON-297 addresses the following cases, where DCs come and go + # during generation + # local DC disappears after finishing it, and starting remote + plan = policy.make_query_plan() + for _ in range(3): + next(plan) + policy.on_down(hosts[0]) + policy.on_down(hosts[1]) + # dict traversal has begun and consumed one + self.assertEqual(len(list(plan)), 0 + 1) + policy.on_up(hosts[0]) + policy.on_up(hosts[1]) + + # remote DC disappears after finishing local, but not starting remote + plan = policy.make_query_plan() + for _ in range(2): + next(plan) + policy.on_down(hosts[2]) + policy.on_down(hosts[3]) + # nothing left + self.assertEqual(len(list(plan)), 0 + 0) + policy.on_up(hosts[2]) + policy.on_up(hosts[3]) + + # remote DC disappears while traversing it + plan = policy.make_query_plan() + for _ in range(3): + next(plan) + policy.on_down(hosts[2]) + policy.on_down(hosts[3]) + # we continue with remainder of original list + self.assertEqual(len(list(plan)), 0 + 1) + policy.on_up(hosts[2]) + policy.on_up(hosts[3]) + + another_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) + another_host.set_location_info("dc3", "rack1") + new_host.set_location_info("dc3", "rack1") + + # new DC while 
traversing remote + plan = policy.make_query_plan() + for _ in range(3): + next(plan) + policy.on_up(new_host) + policy.on_up(another_host) + # we continue with remainder of original list + self.assertEqual(len(list(plan)), 0 + 1) + + # remote DC disappears after finishing it + plan = policy.make_query_plan() + for _ in range(3): + next(plan) + last_host_in_this_dc = next(plan) + if last_host_in_this_dc in (new_host, another_host): + down_hosts = [new_host, another_host] + else: + down_hosts = hosts[2:] + for h in down_hosts: + policy.on_down(h) + # the last DC has two + self.assertEqual(len(list(plan)), 0 + 2) + + def test_no_live_nodes(self): + """ + Ensure query plan for a downed cluster will execute without errors + """ + + hosts = [] + for i in range(4): + h = Host(DefaultEndPoint(i), SimpleConvictionPolicy) + h.set_location_info("dc1", "rack1") + hosts.append(h) + + policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1) + policy.populate(Mock(), hosts) + + for host in hosts: + policy.on_down(host) + + qplan = list(policy.make_query_plan()) + self.assertEqual(qplan, []) + + def test_no_nodes(self): + """ + Ensure query plan for an empty cluster will execute without errors + """ + + policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1) + policy.populate(None, []) + + qplan = list(policy.make_query_plan()) + self.assertEqual(qplan, []) + + def test_default_dc(self): + host_local = Host(DefaultEndPoint(1), SimpleConvictionPolicy, 'local') + host_remote = Host(DefaultEndPoint(2), SimpleConvictionPolicy, 'remote') + host_none = Host(DefaultEndPoint(1), SimpleConvictionPolicy) + + # contact point is '1' + cluster = Mock(endpoints_resolved=[DefaultEndPoint(1)]) + + # contact DC first + policy = DCAwareRoundRobinPolicy() + policy.populate(cluster, [host_none]) + self.assertFalse(policy.local_dc) + policy.on_add(host_local) + policy.on_add(host_remote) + self.assertNotEqual(policy.local_dc, host_remote.datacenter) + self.assertEqual(policy.local_dc, host_local.datacenter) + + # contact DC second + policy = DCAwareRoundRobinPolicy() + policy.populate(cluster, [host_none]) + self.assertFalse(policy.local_dc) + policy.on_add(host_remote) + policy.on_add(host_local) + self.assertNotEqual(policy.local_dc, host_remote.datacenter) + self.assertEqual(policy.local_dc, host_local.datacenter) + + # no DC + policy = DCAwareRoundRobinPolicy() + policy.populate(cluster, [host_none]) + self.assertFalse(policy.local_dc) + policy.on_add(host_none) + self.assertFalse(policy.local_dc) + + # only other DC + policy = DCAwareRoundRobinPolicy() + policy.populate(cluster, [host_none]) + self.assertFalse(policy.local_dc) + policy.on_add(host_remote) + self.assertFalse(policy.local_dc) + + +class TokenAwarePolicyTest(unittest.TestCase): + + def test_wrap_round_robin(self): + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + for host in hosts: + host.set_up() + + def get_replicas(keyspace, packed_key): + index = struct.unpack('>i', packed_key)[0] + return list(islice(cycle(hosts), index, index + 2)) + + cluster.metadata.get_replicas.side_effect = get_replicas + + policy = TokenAwarePolicy(RoundRobinPolicy()) + policy.populate(cluster, hosts) + + for i in range(4): + query = Statement(routing_key=struct.pack('>i', i), keyspace='keyspace_name') + qplan = list(policy.make_query_plan(None, query)) + + replicas = get_replicas(None, struct.pack('>i', i)) + other = set(h for h in hosts if h not 
in replicas) + self.assertEqual(replicas, qplan[:2]) + self.assertEqual(other, set(qplan[2:])) + + # Should use the secondary policy + for i in range(4): + qplan = list(policy.make_query_plan()) + + self.assertEqual(set(qplan), set(hosts)) + + def test_wrap_dc_aware(self): + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + for host in hosts: + host.set_up() + for h in hosts[:2]: + h.set_location_info("dc1", "rack1") + for h in hosts[2:]: + h.set_location_info("dc2", "rack1") + + def get_replicas(keyspace, packed_key): + index = struct.unpack('>i', packed_key)[0] + # return one node from each DC + if index % 2 == 0: + return [hosts[0], hosts[2]] + else: + return [hosts[1], hosts[3]] + + cluster.metadata.get_replicas.side_effect = get_replicas + + policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1)) + policy.populate(cluster, hosts) + + for i in range(4): + query = Statement(routing_key=struct.pack('>i', i), keyspace='keyspace_name') + qplan = list(policy.make_query_plan(None, query)) + replicas = get_replicas(None, struct.pack('>i', i)) + + # first should be the only local replica + self.assertIn(qplan[0], replicas) + self.assertEqual(qplan[0].datacenter, "dc1") + + # then the local non-replica + self.assertNotIn(qplan[1], replicas) + self.assertEqual(qplan[1].datacenter, "dc1") + + # then one of the remotes (used_hosts_per_remote_dc is 1, so we + # shouldn't see two remotes) + self.assertEqual(qplan[2].datacenter, "dc2") + self.assertEqual(3, len(qplan)) + + class FakeCluster: + def __init__(self): + self.metadata = Mock(spec=Metadata) + + def test_get_distance(self): + """ + Same test as DCAwareRoundRobinPolicyTest.test_get_distance() + Except a FakeCluster is needed for the metadata variable and + policy.child_policy is needed to change child policy settings + """ + + policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0)) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + host.set_location_info("dc1", "rack1") + + policy.populate(self.FakeCluster(), [host]) + + self.assertEqual(policy.distance(host), HostDistance.LOCAL) + + # used_hosts_per_remote_dc is set to 0, so ignore it + remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy) + remote_host.set_location_info("dc2", "rack1") + self.assertEqual(policy.distance(remote_host), HostDistance.IGNORED) + + # dc2 isn't registered in the policy's live_hosts dict + policy._child_policy.used_hosts_per_remote_dc = 1 + self.assertEqual(policy.distance(remote_host), HostDistance.IGNORED) + + # make sure the policy has both dcs registered + policy.populate(self.FakeCluster(), [host, remote_host]) + self.assertEqual(policy.distance(remote_host), HostDistance.REMOTE) + + # since used_hosts_per_remote_dc is set to 1, only the first + # remote host in dc2 will be REMOTE, the rest are IGNORED + second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy) + second_remote_host.set_location_info("dc2", "rack1") + policy.populate(self.FakeCluster(), [host, remote_host, second_remote_host]) + distances = set([policy.distance(remote_host), policy.distance(second_remote_host)]) + self.assertEqual(distances, set([HostDistance.REMOTE, HostDistance.IGNORED])) + + def test_status_updates(self): + """ + Same test as DCAwareRoundRobinPolicyTest.test_status_updates() + """ + + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] + for h in 
hosts[:2]: + h.set_location_info("dc1", "rack1") + for h in hosts[2:]: + h.set_location_info("dc2", "rack1") + + policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1)) + policy.populate(self.FakeCluster(), hosts) + policy.on_down(hosts[0]) + policy.on_remove(hosts[2]) + + new_local_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy) + new_local_host.set_location_info("dc1", "rack1") + policy.on_up(new_local_host) + + new_remote_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) + new_remote_host.set_location_info("dc9000", "rack1") + policy.on_add(new_remote_host) + + # we now have two local hosts and two remote hosts in separate dcs + qplan = list(policy.make_query_plan()) + self.assertEqual(set(qplan[:2]), set([hosts[1], new_local_host])) + self.assertEqual(set(qplan[2:]), set([hosts[3], new_remote_host])) + + # since we have hosts in dc9000, the distance shouldn't be IGNORED + self.assertEqual(policy.distance(new_remote_host), HostDistance.REMOTE) + + policy.on_down(new_local_host) + policy.on_down(hosts[1]) + qplan = list(policy.make_query_plan()) + self.assertEqual(set(qplan), set([hosts[3], new_remote_host])) + + policy.on_down(new_remote_host) + policy.on_down(hosts[3]) + qplan = list(policy.make_query_plan()) + self.assertEqual(qplan, []) + + def test_statement_keyspace(self): + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + for host in hosts: + host.set_up() + + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + replicas = hosts[2:] + cluster.metadata.get_replicas.return_value = replicas + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy) + policy.populate(cluster, hosts) + + # no keyspace, child policy is called + keyspace = None + routing_key = 'routing_key' + query = Statement(routing_key=routing_key) + qplan = list(policy.make_query_plan(keyspace, query)) + self.assertEqual(hosts, qplan) + self.assertEqual(cluster.metadata.get_replicas.call_count, 0) + child_policy.make_query_plan.assert_called_once_with(keyspace, query) + + # working keyspace, no statement + cluster.metadata.get_replicas.reset_mock() + keyspace = 'working_keyspace' + routing_key = 'routing_key' + query = Statement(routing_key=routing_key) + qplan = list(policy.make_query_plan(keyspace, query)) + self.assertEqual(replicas + hosts[:2], qplan) + cluster.metadata.get_replicas.assert_called_with(keyspace, routing_key) + + # statement keyspace, no working + cluster.metadata.get_replicas.reset_mock() + working_keyspace = None + statement_keyspace = 'statement_keyspace' + routing_key = 'routing_key' + query = Statement(routing_key=routing_key, keyspace=statement_keyspace) + qplan = list(policy.make_query_plan(working_keyspace, query)) + self.assertEqual(replicas + hosts[:2], qplan) + cluster.metadata.get_replicas.assert_called_with(statement_keyspace, routing_key) + + # both keyspaces set, statement keyspace used for routing + cluster.metadata.get_replicas.reset_mock() + working_keyspace = 'working_keyspace' + statement_keyspace = 'statement_keyspace' + routing_key = 'routing_key' + query = Statement(routing_key=routing_key, keyspace=statement_keyspace) + qplan = list(policy.make_query_plan(working_keyspace, query)) + self.assertEqual(replicas + hosts[:2], qplan) + cluster.metadata.get_replicas.assert_called_with(statement_keyspace, routing_key) + + def 
test_shuffles_if_given_keyspace_and_routing_key(self): + """ + Test to validate the hosts are shuffled when `shuffle_replicas` is truthy + @since 3.8 + @jira_ticket PYTHON-676 + @expected_result shuffle should be called, because the keyspace and the + routing key are set + + @test_category policy + """ + self._assert_shuffle(keyspace='keyspace', routing_key='routing_key') + + def test_no_shuffle_if_given_no_keyspace(self): + """ + Test to validate the hosts are not shuffled when no keyspace is provided + @since 3.8 + @jira_ticket PYTHON-676 + @expected_result shuffle should be called, because keyspace is None + + @test_category policy + """ + self._assert_shuffle(keyspace=None, routing_key='routing_key') + + def test_no_shuffle_if_given_no_routing_key(self): + """ + Test to validate the hosts are not shuffled when no routing_key is provided + @since 3.8 + @jira_ticket PYTHON-676 + @expected_result shuffle should be called, because routing_key is None + + @test_category policy + """ + self._assert_shuffle(keyspace='keyspace', routing_key=None) + + @patch('cassandra.policies.shuffle') + def _assert_shuffle(self, patched_shuffle, keyspace, routing_key): + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + for host in hosts: + host.set_up() + + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + replicas = hosts[2:] + cluster.metadata.get_replicas.return_value = replicas + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=True) + policy.populate(cluster, hosts) + + cluster.metadata.get_replicas.reset_mock() + child_policy.make_query_plan.reset_mock() + query = Statement(routing_key=routing_key) + qplan = list(policy.make_query_plan(keyspace, query)) + if keyspace is None or routing_key is None: + self.assertEqual(hosts, qplan) + self.assertEqual(cluster.metadata.get_replicas.call_count, 0) + child_policy.make_query_plan.assert_called_once_with(keyspace, query) + self.assertEqual(patched_shuffle.call_count, 0) + else: + self.assertEqual(set(replicas), set(qplan[:2])) + self.assertEqual(hosts[:2], qplan[2:]) + child_policy.make_query_plan.assert_called_once_with(keyspace, query) + self.assertEqual(patched_shuffle.call_count, 1) + + +class ConvictionPolicyTest(unittest.TestCase): + def test_not_implemented(self): + """ + Code coverage for interface-style base class + """ + + conviction_policy = ConvictionPolicy(1) + self.assertRaises(NotImplementedError, conviction_policy.add_failure, 1) + self.assertRaises(NotImplementedError, conviction_policy.reset) + + +class SimpleConvictionPolicyTest(unittest.TestCase): + def test_basic_responses(self): + """ + Code coverage for SimpleConvictionPolicy + """ + + conviction_policy = SimpleConvictionPolicy(1) + self.assertEqual(conviction_policy.add_failure(1), True) + self.assertEqual(conviction_policy.reset(), None) + + +class ReconnectionPolicyTest(unittest.TestCase): + def test_basic_responses(self): + """ + Code coverage for interface-style base class + """ + + policy = ReconnectionPolicy() + self.assertRaises(NotImplementedError, policy.new_schedule) + + +class ConstantReconnectionPolicyTest(unittest.TestCase): + + def test_bad_vals(self): + """ + Test initialization values + """ + + self.assertRaises(ValueError, ConstantReconnectionPolicy, -1, 0) + + def test_schedule(self): + """ + Test ConstantReconnectionPolicy schedule + """ + + delay = 2 + max_attempts = 
100 + policy = ConstantReconnectionPolicy(delay=delay, max_attempts=max_attempts) + schedule = list(policy.new_schedule()) + self.assertEqual(len(schedule), max_attempts) + for i, delay in enumerate(schedule): + self.assertEqual(delay, delay) + + def test_schedule_negative_max_attempts(self): + """ + Test how negative max_attempts are handled + """ + + delay = 2 + max_attempts = -100 + + try: + ConstantReconnectionPolicy(delay=delay, max_attempts=max_attempts) + self.fail('max_attempts should throw ValueError when negative') + except ValueError: + pass + + def test_schedule_infinite_attempts(self): + delay = 2 + max_attempts = None + crp = ConstantReconnectionPolicy(delay=delay, max_attempts=max_attempts) + # this is infinite. we'll just verify one more than default + for _, d in zip(range(65), crp.new_schedule()): + self.assertEqual(d, delay) + + +class ExponentialReconnectionPolicyTest(unittest.TestCase): + + def _assert_between(self, value, min, max): + self.assertTrue(min <= value <= max) + + def test_bad_vals(self): + self.assertRaises(ValueError, ExponentialReconnectionPolicy, -1, 0) + self.assertRaises(ValueError, ExponentialReconnectionPolicy, 0, -1) + self.assertRaises(ValueError, ExponentialReconnectionPolicy, 9000, 1) + self.assertRaises(ValueError, ExponentialReconnectionPolicy, 1, 2, -1) + + def test_schedule_no_max(self): + base_delay = 2.0 + max_delay = 100.0 + test_iter = 10000 + policy = ExponentialReconnectionPolicy(base_delay=base_delay, max_delay=max_delay, max_attempts=None) + sched_slice = list(islice(policy.new_schedule(), 0, test_iter)) + self._assert_between(sched_slice[0], base_delay*0.85, base_delay*1.15) + self._assert_between(sched_slice[-1], max_delay*0.85, max_delay*1.15) + self.assertEqual(len(sched_slice), test_iter) + + def test_schedule_with_max(self): + base_delay = 2.0 + max_delay = 100.0 + max_attempts = 64 + policy = ExponentialReconnectionPolicy(base_delay=base_delay, max_delay=max_delay, max_attempts=max_attempts) + schedule = list(policy.new_schedule()) + self.assertEqual(len(schedule), max_attempts) + for i, delay in enumerate(schedule): + if i == 0: + self._assert_between(delay, base_delay*0.85, base_delay*1.15) + elif i < 6: + value = base_delay * (2 ** i) + self._assert_between(delay, value*85/100, value*1.15) + else: + self._assert_between(delay, max_delay*85/100, max_delay*1.15) + + def test_schedule_exactly_one_attempt(self): + base_delay = 2.0 + max_delay = 100.0 + max_attempts = 1 + policy = ExponentialReconnectionPolicy( + base_delay=base_delay, max_delay=max_delay, max_attempts=max_attempts + ) + self.assertEqual(len(list(policy.new_schedule())), 1) + + def test_schedule_overflow(self): + """ + Test to verify an OverflowError is handled correctly + in the ExponentialReconnectionPolicy + @since 3.10 + @jira_ticket PYTHON-707 + @expected_result all numbers should be less than sys.float_info.max + since that's the biggest max we can possibly have as that argument must be a float. + Note that is possible for a float to be inf. 
+ + @test_category policy + """ + + # This should lead to overflow + # Note that this may not happen in the fist iterations + # as sys.float_info.max * 2 = inf + base_delay = sys.float_info.max - 1 + max_delay = sys.float_info.max + max_attempts = 2**12 + policy = ExponentialReconnectionPolicy(base_delay=base_delay, max_delay=max_delay, max_attempts=max_attempts) + schedule = list(policy.new_schedule()) + for number in schedule: + self.assertLessEqual(number, sys.float_info.max) + + def test_schedule_with_jitter(self): + """ + Test to verify jitter is added properly and is always between -/+ 15%. + + @since 3.18 + @jira_ticket PYTHON-1065 + """ + for i in range(100): + base_delay = float(randint(2, 5)) + max_delay = (base_delay - 1) * 100.0 + ep = ExponentialReconnectionPolicy(base_delay, max_delay, max_attempts=64) + schedule = ep.new_schedule() + for i in range(64): + exp_delay = min(base_delay * (2 ** i), max_delay) + min_jitter_delay = max(base_delay, exp_delay*85/100) + max_jitter_delay = min(max_delay, exp_delay*115/100) + delay = next(schedule) + self._assert_between(delay, min_jitter_delay, max_jitter_delay) + + +ONE = ConsistencyLevel.ONE + + +class RetryPolicyTest(unittest.TestCase): + + def test_read_timeout(self): + policy = RetryPolicy() + + # if this is the second or greater attempt, rethrow + retry, consistency = policy.on_read_timeout( + query=None, consistency=ONE, required_responses=1, received_responses=2, + data_retrieved=True, retry_num=1) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + # if we didn't get enough responses, rethrow + retry, consistency = policy.on_read_timeout( + query=None, consistency=ONE, required_responses=2, received_responses=1, + data_retrieved=True, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + # if we got enough responses, but also got a data response, rethrow + retry, consistency = policy.on_read_timeout( + query=None, consistency=ONE, required_responses=2, received_responses=2, + data_retrieved=True, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + # we got enough responses but no data response, so retry + retry, consistency = policy.on_read_timeout( + query=None, consistency=ONE, required_responses=2, received_responses=2, + data_retrieved=False, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETRY) + self.assertEqual(consistency, ONE) + + def test_write_timeout(self): + policy = RetryPolicy() + + # if this is the second or greater attempt, rethrow + retry, consistency = policy.on_write_timeout( + query=None, consistency=ONE, write_type=WriteType.SIMPLE, + required_responses=1, received_responses=2, retry_num=1) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + # if it's not a BATCH_LOG write, don't retry it + retry, consistency = policy.on_write_timeout( + query=None, consistency=ONE, write_type=WriteType.SIMPLE, + required_responses=1, received_responses=2, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + # retry BATCH_LOG writes regardless of received responses + retry, consistency = policy.on_write_timeout( + query=None, consistency=ONE, write_type=WriteType.BATCH_LOG, + required_responses=10000, received_responses=1, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETRY) + self.assertEqual(consistency, ONE) + + def test_unavailable(self): + """ + Use the same tests for test_write_timeout, but 
ensure they only RETHROW + """ + policy = RetryPolicy() + + retry, consistency = policy.on_unavailable( + query=None, consistency=ONE, + required_replicas=1, alive_replicas=2, retry_num=1) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + retry, consistency = policy.on_unavailable( + query=None, consistency=ONE, + required_replicas=1, alive_replicas=2, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETRY_NEXT_HOST) + self.assertEqual(consistency, None) + + retry, consistency = policy.on_unavailable( + query=None, consistency=ONE, + required_replicas=10000, alive_replicas=1, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETRY_NEXT_HOST) + self.assertEqual(consistency, None) + + +class FallthroughRetryPolicyTest(unittest.TestCase): + + """ + Use the same tests for test_write_timeout, but ensure they only RETHROW + """ + + def test_read_timeout(self): + policy = FallthroughRetryPolicy() + + retry, consistency = policy.on_read_timeout( + query=None, consistency=ONE, required_responses=1, received_responses=2, + data_retrieved=True, retry_num=1) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + retry, consistency = policy.on_read_timeout( + query=None, consistency=ONE, required_responses=2, received_responses=1, + data_retrieved=True, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + retry, consistency = policy.on_read_timeout( + query=None, consistency=ONE, required_responses=2, received_responses=2, + data_retrieved=True, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + retry, consistency = policy.on_read_timeout( + query=None, consistency=ONE, required_responses=2, received_responses=2, + data_retrieved=False, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + def test_write_timeout(self): + policy = FallthroughRetryPolicy() + + retry, consistency = policy.on_write_timeout( + query=None, consistency=ONE, write_type=WriteType.SIMPLE, + required_responses=1, received_responses=2, retry_num=1) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + retry, consistency = policy.on_write_timeout( + query=None, consistency=ONE, write_type=WriteType.SIMPLE, + required_responses=1, received_responses=2, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + retry, consistency = policy.on_write_timeout( + query=None, consistency=ONE, write_type=WriteType.BATCH_LOG, + required_responses=10000, received_responses=1, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + def test_unavailable(self): + policy = FallthroughRetryPolicy() + + retry, consistency = policy.on_unavailable( + query=None, consistency=ONE, + required_replicas=1, alive_replicas=2, retry_num=1) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + retry, consistency = policy.on_unavailable( + query=None, consistency=ONE, + required_replicas=1, alive_replicas=2, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + retry, consistency = policy.on_unavailable( + query=None, consistency=ONE, + required_replicas=10000, alive_replicas=1, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + +class 
DowngradingConsistencyRetryPolicyTest(unittest.TestCase): + + def test_read_timeout(self): + policy = DowngradingConsistencyRetryPolicy() + + # if this is the second or greater attempt, rethrow + retry, consistency = policy.on_read_timeout( + query=None, consistency=ONE, required_responses=1, received_responses=2, + data_retrieved=True, retry_num=1) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + # if we didn't get enough responses, retry at a lower consistency + retry, consistency = policy.on_read_timeout( + query=None, consistency=ONE, required_responses=4, received_responses=3, + data_retrieved=True, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETRY) + self.assertEqual(consistency, ConsistencyLevel.THREE) + + # if we didn't get enough responses, retry at a lower consistency + retry, consistency = policy.on_read_timeout( + query=None, consistency=ONE, required_responses=3, received_responses=2, + data_retrieved=True, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETRY) + self.assertEqual(consistency, ConsistencyLevel.TWO) + + # retry consistency level goes down based on the # of recv'd responses + retry, consistency = policy.on_read_timeout( + query=None, consistency=ONE, required_responses=3, received_responses=1, + data_retrieved=True, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETRY) + self.assertEqual(consistency, ConsistencyLevel.ONE) + + # if we got no responses, rethrow + retry, consistency = policy.on_read_timeout( + query=None, consistency=ONE, required_responses=3, received_responses=0, + data_retrieved=True, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + # if we got enough response but no data, retry + retry, consistency = policy.on_read_timeout( + query=None, consistency=ONE, required_responses=3, received_responses=3, + data_retrieved=False, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETRY) + self.assertEqual(consistency, ONE) + + # if we got enough responses, but also got a data response, rethrow + retry, consistency = policy.on_read_timeout( + query=None, consistency=ONE, required_responses=2, received_responses=2, + data_retrieved=True, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + def test_write_timeout(self): + policy = DowngradingConsistencyRetryPolicy() + + # if this is the second or greater attempt, rethrow + retry, consistency = policy.on_write_timeout( + query=None, consistency=ONE, write_type=WriteType.SIMPLE, + required_responses=1, received_responses=2, retry_num=1) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + for write_type in (WriteType.SIMPLE, WriteType.BATCH, WriteType.COUNTER): + # ignore failures if at least one response (replica persisted) + retry, consistency = policy.on_write_timeout( + query=None, consistency=ONE, write_type=write_type, + required_responses=1, received_responses=2, retry_num=0) + self.assertEqual(retry, RetryPolicy.IGNORE) + # retrhow if we can't be sure we have a replica + retry, consistency = policy.on_write_timeout( + query=None, consistency=ONE, write_type=write_type, + required_responses=1, received_responses=0, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETHROW) + + # downgrade consistency level on unlogged batch writes + retry, consistency = policy.on_write_timeout( + query=None, consistency=ONE, write_type=WriteType.UNLOGGED_BATCH, + required_responses=3, received_responses=1, retry_num=0) + 
self.assertEqual(retry, RetryPolicy.RETRY) + self.assertEqual(consistency, ConsistencyLevel.ONE) + + # retry batch log writes at the same consistency level + retry, consistency = policy.on_write_timeout( + query=None, consistency=ONE, write_type=WriteType.BATCH_LOG, + required_responses=3, received_responses=1, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETRY) + self.assertEqual(consistency, ONE) + + # timeout on an unknown write_type + retry, consistency = policy.on_write_timeout( + query=None, consistency=ONE, write_type=None, + required_responses=1, received_responses=2, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + def test_unavailable(self): + policy = DowngradingConsistencyRetryPolicy() + + # if this is the second or greater attempt, rethrow + retry, consistency = policy.on_unavailable( + query=None, consistency=ONE, required_replicas=3, alive_replicas=1, retry_num=1) + self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(consistency, None) + + # downgrade consistency on unavailable exceptions + retry, consistency = policy.on_unavailable( + query=None, consistency=ONE, required_replicas=3, alive_replicas=1, retry_num=0) + self.assertEqual(retry, RetryPolicy.RETRY) + self.assertEqual(consistency, ConsistencyLevel.ONE) + + +class WhiteListRoundRobinPolicyTest(unittest.TestCase): + + def test_hosts_with_hostname(self): + hosts = ['localhost'] + policy = WhiteListRoundRobinPolicy(hosts) + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy) + policy.populate(None, [host]) + + qplan = list(policy.make_query_plan()) + self.assertEqual(sorted(qplan), [host]) + + self.assertEqual(policy.distance(host), HostDistance.LOCAL) + + +class AddressTranslatorTest(unittest.TestCase): + + def test_identity_translator(self): + IdentityTranslator() + + @patch('socket.getfqdn', return_value='localhost') + def test_ec2_multi_region_translator(self, *_): + ec2t = EC2MultiRegionTranslator() + addr = '127.0.0.1' + translated = ec2t.translate(addr) + self.assertIsNot(translated, addr) # verifies that the resolver path is followed + self.assertEqual(translated, addr) # and that it resolves to the same address + + +class HostFilterPolicyInitTest(unittest.TestCase): + + def setUp(self): + self.child_policy, self.predicate = (Mock(name='child_policy'), + Mock(name='predicate')) + + def _check_init(self, hfp): + self.assertIs(hfp._child_policy, self.child_policy) + self.assertIsInstance(hfp._hosts_lock, LockType) + + # we can't use a simple assertIs because we wrap the function + arg0, arg1 = Mock(name='arg0'), Mock(name='arg1') + hfp.predicate(arg0) + hfp.predicate(arg1) + self.predicate.assert_has_calls([call(arg0), call(arg1)]) + + def test_init_arg_order(self): + self._check_init(HostFilterPolicy(self.child_policy, self.predicate)) + + def test_init_kwargs(self): + self._check_init(HostFilterPolicy( + predicate=self.predicate, child_policy=self.child_policy + )) + + def test_immutable_predicate(self): + expected_message_regex = "can't set attribute" + hfp = HostFilterPolicy(child_policy=Mock(name='child_policy'), + predicate=Mock(name='predicate')) + with self.assertRaisesRegexp(AttributeError, expected_message_regex): + hfp.predicate = object() + + +class HostFilterPolicyDeferralTest(unittest.TestCase): + + def setUp(self): + self.passthrough_hfp = HostFilterPolicy( + child_policy=Mock(name='child_policy'), + predicate=Mock(name='passthrough_predicate', + return_value=True) + ) + self.filterall_hfp = HostFilterPolicy( + 
child_policy=Mock(name='child_policy'), + predicate=Mock(name='filterall_predicate', + return_value=False) + ) + + def _check_host_triggered_method(self, policy, name): + arg, kwarg = Mock(name='arg'), Mock(name='kwarg') + method, child_policy_method = (getattr(policy, name), + getattr(policy._child_policy, name)) + + result = method(arg, kw=kwarg) + + # method calls the child policy's method... + child_policy_method.assert_called_once_with(arg, kw=kwarg) + # and returns its return value + self.assertIs(result, child_policy_method.return_value) + + def test_defer_on_up_to_child_policy(self): + self._check_host_triggered_method(self.passthrough_hfp, 'on_up') + + def test_defer_on_down_to_child_policy(self): + self._check_host_triggered_method(self.passthrough_hfp, 'on_down') + + def test_defer_on_add_to_child_policy(self): + self._check_host_triggered_method(self.passthrough_hfp, 'on_add') + + def test_defer_on_remove_to_child_policy(self): + self._check_host_triggered_method(self.passthrough_hfp, 'on_remove') + + def test_filtered_host_on_up_doesnt_call_child_policy(self): + self._check_host_triggered_method(self.filterall_hfp, 'on_up') + + def test_filtered_host_on_down_doesnt_call_child_policy(self): + self._check_host_triggered_method(self.filterall_hfp, 'on_down') + + def test_filtered_host_on_add_doesnt_call_child_policy(self): + self._check_host_triggered_method(self.filterall_hfp, 'on_add') + + def test_filtered_host_on_remove_doesnt_call_child_policy(self): + self._check_host_triggered_method(self.filterall_hfp, 'on_remove') + + def _check_check_supported_deferral(self, policy): + policy.check_supported() + policy._child_policy.check_supported.assert_called_once() + + def test_check_supported_defers_to_child(self): + self._check_check_supported_deferral(self.passthrough_hfp) + + def test_check_supported_defers_to_child_when_predicate_filtered(self): + self._check_check_supported_deferral(self.filterall_hfp) + + +class HostFilterPolicyDistanceTest(unittest.TestCase): + + def setUp(self): + self.hfp = HostFilterPolicy( + child_policy=Mock(name='child_policy', distance=Mock(name='distance')), + predicate=lambda host: host.address == 'acceptme' + ) + self.ignored_host = Host(DefaultEndPoint('ignoreme'), conviction_policy_factory=Mock()) + self.accepted_host = Host(DefaultEndPoint('acceptme'), conviction_policy_factory=Mock()) + + def test_ignored_with_filter(self): + self.assertEqual(self.hfp.distance(self.ignored_host), + HostDistance.IGNORED) + self.assertNotEqual(self.hfp.distance(self.accepted_host), + HostDistance.IGNORED) + + def test_accepted_filter_defers_to_child_policy(self): + self.hfp._child_policy.distance.side_effect = distances = Mock(), Mock() + + # getting the distance for an ignored host shouldn't affect subsequent results + self.hfp.distance(self.ignored_host) + # first call of _child_policy with count() side effect + self.assertEqual(self.hfp.distance(self.accepted_host), distances[0]) + # second call of _child_policy with count() side effect + self.assertEqual(self.hfp.distance(self.accepted_host), distances[1]) + + +class HostFilterPolicyPopulateTest(unittest.TestCase): + + def test_populate_deferred_to_child(self): + hfp = HostFilterPolicy( + child_policy=Mock(name='child_policy'), + predicate=lambda host: True + ) + mock_cluster, hosts = (Mock(name='cluster'), + ['host1', 'host2', 'host3']) + hfp.populate(mock_cluster, hosts) + hfp._child_policy.populate.assert_called_once_with( + cluster=mock_cluster, + hosts=hosts + ) + + def 
test_child_is_populated_with_filtered_hosts(self): + hfp = HostFilterPolicy( + child_policy=Mock(name='child_policy'), + predicate=lambda host: False + ) + mock_cluster, hosts = (Mock(name='cluster'), + ['acceptme0', 'acceptme1']) + hfp.populate(mock_cluster, hosts) + hfp._child_policy.populate.assert_called_once() + self.assertEqual( + hfp._child_policy.populate.call_args[1]['hosts'], + ['acceptme0', 'acceptme1'] + ) + + +class HostFilterPolicyQueryPlanTest(unittest.TestCase): + + def test_query_plan_deferred_to_child(self): + child_policy = Mock( + name='child_policy', + make_query_plan=Mock( + return_value=[object(), object(), object()] + ) + ) + hfp = HostFilterPolicy( + child_policy=child_policy, + predicate=lambda host: True + ) + working_keyspace, query = (Mock(name='working_keyspace'), + Mock(name='query')) + qp = list(hfp.make_query_plan(working_keyspace=working_keyspace, + query=query)) + hfp._child_policy.make_query_plan.assert_called_once_with( + working_keyspace=working_keyspace, + query=query + ) + self.assertEqual(qp, hfp._child_policy.make_query_plan.return_value) + + def test_wrap_token_aware(self): + cluster = Mock(spec=Cluster) + hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in range(1, 6)] + for host in hosts: + host.set_up() + + def get_replicas(keyspace, packed_key): + return hosts[:2] + + cluster.metadata.get_replicas.side_effect = get_replicas + + child_policy = TokenAwarePolicy(RoundRobinPolicy()) + + hfp = HostFilterPolicy( + child_policy=child_policy, + predicate=lambda host: host.address != "127.0.0.1" and host.address != "127.0.0.4" + ) + hfp.populate(cluster, hosts) + + # We don't allow randomness for ordering the replicas in RoundRobin + hfp._child_policy._child_policy._position = 0 + + + mocked_query = Mock() + query_plan = hfp.make_query_plan("keyspace", mocked_query) + # First the not filtered replica, and then the rest of the allowed hosts ordered + query_plan = list(query_plan) + self.assertEqual(query_plan[0], Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy)) + self.assertEqual(set(query_plan[1:]),{Host(DefaultEndPoint("127.0.0.3"), SimpleConvictionPolicy), + Host(DefaultEndPoint("127.0.0.5"), SimpleConvictionPolicy)}) + + def test_create_whitelist(self): + cluster = Mock(spec=Cluster) + hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in range(1, 6)] + for host in hosts: + host.set_up() + + child_policy = RoundRobinPolicy() + + hfp = HostFilterPolicy( + child_policy=child_policy, + predicate=lambda host: host.address == "127.0.0.1" or host.address == "127.0.0.4" + ) + hfp.populate(cluster, hosts) + + # We don't allow randomness for ordering the replicas in RoundRobin + hfp._child_policy._position = 0 + + mocked_query = Mock() + query_plan = hfp.make_query_plan("keyspace", mocked_query) + # Only the filtered replicas should be allowed + self.assertEqual(set(query_plan), {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy), + Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy)}) + diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py new file mode 100644 index 0000000..21223d2 --- /dev/null +++ b/tests/unit/test_protocol.py @@ -0,0 +1,174 @@ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from mock import Mock +from cassandra import ProtocolVersion, UnsupportedOperation +from cassandra.protocol import (PrepareMessage, QueryMessage, ExecuteMessage, + BatchMessage) +from cassandra.query import 
SimpleStatement, BatchType + +class MessageTest(unittest.TestCase): + + def test_prepare_message(self): + """ + Test to check the appropriate calls are made + + @since 3.9 + @jira_ticket PYTHON-713 + @expected_result the values are correctly written + + @test_category connection + """ + message = PrepareMessage("a") + io = Mock() + + message.send_body(io, 4) + self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',)]) + + io.reset_mock() + message.send_body(io, 5) + + self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x00\x00\x00',)]) + + def test_execute_message(self): + message = ExecuteMessage('1', [], 4) + io = Mock() + + message.send_body(io, 4) + self._check_calls(io, [(b'\x00\x01',), (b'1',), (b'\x00\x04',), (b'\x01',), (b'\x00\x00',)]) + + io.reset_mock() + message.result_metadata_id = 'foo' + message.send_body(io, 5) + + self._check_calls(io, [(b'\x00\x01',), (b'1',), + (b'\x00\x03',), (b'foo',), + (b'\x00\x04',), + (b'\x00\x00\x00\x01',), (b'\x00\x00',)]) + + def test_query_message(self): + """ + Test to check the appropriate calls are made + + @since 3.9 + @jira_ticket PYTHON-713 + @expected_result the values are correctly written + + @test_category connection + """ + message = QueryMessage("a", 3) + io = Mock() + + message.send_body(io, 4) + self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00',)]) + + io.reset_mock() + message.send_body(io, 5) + self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00\x00\x00\x00',)]) + + def _check_calls(self, io, expected): + self.assertEqual( + tuple(c[1] for c in io.write.mock_calls), + tuple(expected) + ) + + def test_prepare_flag(self): + """ + Test to check the prepare flag is properly set, This should only happen for V5 at the moment. 
+ + @since 3.9 + @jira_ticket PYTHON-713 + @expected_result the values are correctly written + + @test_category connection + """ + message = PrepareMessage("a") + io = Mock() + for version in ProtocolVersion.SUPPORTED_VERSIONS: + message.send_body(io, version) + if ProtocolVersion.uses_prepare_flags(version): + self.assertEqual(len(io.write.mock_calls), 3) + else: + self.assertEqual(len(io.write.mock_calls), 2) + io.reset_mock() + + def test_prepare_flag_with_keyspace(self): + message = PrepareMessage("a", keyspace='ks') + io = Mock() + + for version in ProtocolVersion.SUPPORTED_VERSIONS: + if ProtocolVersion.uses_keyspace_flag(version): + message.send_body(io, version) + self._check_calls(io, [ + (b'\x00\x00\x00\x01',), + (b'a',), + (b'\x00\x00\x00\x01',), + (b'\x00\x02',), + (b'ks',), + ]) + else: + with self.assertRaises(UnsupportedOperation): + message.send_body(io, version) + io.reset_mock() + + def test_keyspace_flag_raises_before_v5(self): + keyspace_message = QueryMessage('a', consistency_level=3, keyspace='ks') + io = Mock(name='io') + + with self.assertRaisesRegexp(UnsupportedOperation, 'Keyspaces.*set'): + keyspace_message.send_body(io, protocol_version=4) + io.assert_not_called() + + def test_keyspace_written_with_length(self): + io = Mock(name='io') + base_expected = [ + (b'\x00\x00\x00\x01',), + (b'a',), + (b'\x00\x03',), + (b'\x00\x00\x00\x80',), # options w/ keyspace flag + ] + + QueryMessage('a', consistency_level=3, keyspace='ks').send_body( + io, protocol_version=5 + ) + self._check_calls(io, base_expected + [ + (b'\x00\x02',), # length of keyspace string + (b'ks',), + ]) + + io.reset_mock() + + QueryMessage('a', consistency_level=3, keyspace='keyspace').send_body( + io, protocol_version=5 + ) + self._check_calls(io, base_expected + [ + (b'\x00\x08',), # length of keyspace string + (b'keyspace',), + ]) + + def test_batch_message_with_keyspace(self): + self.maxDiff = None + io = Mock(name='io') + batch = BatchMessage( + batch_type=BatchType.LOGGED, + queries=((False, 'stmt a', ('param a',)), + (False, 'stmt b', ('param b',)), + (False, 'stmt c', ('param c',)) + ), + consistency_level=3, + keyspace='ks' + ) + batch.send_body(io, protocol_version=5) + self._check_calls(io, + ((b'\x00',), (b'\x00\x03',), (b'\x00',), + (b'\x00\x00\x00\x06',), (b'stmt a',), + (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param a',), + (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt b',), + (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param b',), + (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt c',), + (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param c',), + (b'\x00\x03',), + (b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',)) + ) diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py new file mode 100644 index 0000000..7c2bfc0 --- /dev/null +++ b/tests/unit/test_query.py @@ -0,0 +1,75 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
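# A minimal sketch of the protocol-v5 behaviour exercised in test_protocol.py
# above: a QueryMessage carrying a per-query keyspace serializes the keyspace
# flag (0x80 in the options bytes) under protocol v5+, and raises
# UnsupportedOperation on older versions. Illustrative only; it assumes the
# ProtocolVersion.V4/V5 constants and a mocked write buffer, as in the tests.
from mock import Mock

from cassandra import ProtocolVersion, UnsupportedOperation
from cassandra.protocol import QueryMessage

io = Mock(name='io')
msg = QueryMessage('SELECT * FROM tbl', consistency_level=3, keyspace='ks')

msg.send_body(io, protocol_version=ProtocolVersion.V5)  # options flag 0x80 + 'ks' written
try:
    msg.send_body(io, protocol_version=ProtocolVersion.V4)
except UnsupportedOperation:
    pass  # per-query keyspaces are a v5+ feature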
+ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import six + +from cassandra.query import BatchStatement, SimpleStatement + + +class BatchStatementTest(unittest.TestCase): + # TODO: this suite could be expanded; for now just adding a test covering a PR + + def test_clear(self): + keyspace = 'keyspace' + routing_key = 'routing_key' + custom_payload = {'key': six.b('value')} + + ss = SimpleStatement('whatever', keyspace=keyspace, routing_key=routing_key, custom_payload=custom_payload) + + batch = BatchStatement() + batch.add(ss) + + self.assertTrue(batch._statements_and_parameters) + self.assertEqual(batch.keyspace, keyspace) + self.assertEqual(batch.routing_key, routing_key) + self.assertEqual(batch.custom_payload, custom_payload) + + batch.clear() + self.assertFalse(batch._statements_and_parameters) + self.assertIsNone(batch.keyspace) + self.assertIsNone(batch.routing_key) + self.assertFalse(batch.custom_payload) + + batch.add(ss) + + def test_clear_empty(self): + batch = BatchStatement() + batch.clear() + self.assertFalse(batch._statements_and_parameters) + self.assertIsNone(batch.keyspace) + self.assertIsNone(batch.routing_key) + self.assertFalse(batch.custom_payload) + + batch.add('something') + + def test_add_all(self): + batch = BatchStatement() + statements = ['%s'] * 10 + parameters = [(i,) for i in range(10)] + batch.add_all(statements, parameters) + bound_statements = [t[1] for t in batch._statements_and_parameters] + str_parameters = [str(i) for i in range(10)] + self.assertEqual(bound_statements, str_parameters) + + def test_len(self): + for n in 0, 10, 100: + batch = BatchStatement() + batch.add_all(statements=['%s'] * n, + parameters=[(i,) for i in range(n)]) + self.assertEqual(len(batch), n) diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py new file mode 100644 index 0000000..87f1bd6 --- /dev/null +++ b/tests/unit/test_response_future.py @@ -0,0 +1,574 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
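# A minimal usage sketch of the BatchStatement API covered in test_query.py
# above: statements accumulate via add()/add_all(), len() reports the batch
# size, and clear() resets statements, keyspace, routing_key and
# custom_payload so the batch can be reused. Illustrative only; executing the
# batch would require a connected Session, which these unit tests mock away.
from cassandra.query import BatchStatement, SimpleStatement

batch = BatchStatement()
batch.add(SimpleStatement("INSERT INTO users (id, name) VALUES (%s, %s)"),
          (1, 'alice'))
batch.add_all(statements=["INSERT INTO users (id, name) VALUES (%s, %s)"] * 2,
              parameters=[(2, 'bob'), (3, 'carol')])
assert len(batch) == 3

batch.clear()
assert len(batch) == 0
# session.execute(batch)  # hypothetical: needs a live Session/Cluster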
+ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from mock import Mock, MagicMock, ANY + +from cassandra import ConsistencyLevel, Unavailable, SchemaTargetType, SchemaChangeType, OperationTimedOut +from cassandra.cluster import Session, ResponseFuture, NoHostAvailable, ProtocolVersion +from cassandra.connection import Connection, ConnectionException +from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, + UnavailableErrorMessage, ResultMessage, QueryMessage, + OverloadedErrorMessage, IsBootstrappingErrorMessage, + PreparedQueryNotFound, PrepareMessage, + RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE, + RESULT_KIND_SCHEMA_CHANGE, RESULT_KIND_PREPARED, + ProtocolHandler) +from cassandra.policies import RetryPolicy +from cassandra.pool import NoConnectionsAvailable +from cassandra.query import SimpleStatement + + +class ResponseFutureTests(unittest.TestCase): + + def make_basic_session(self): + return Mock(spec=Session, row_factory=lambda *x: list(x)) + + def make_pool(self): + pool = Mock() + pool.is_shutdown = False + pool.borrow_connection.return_value = [Mock(), Mock()] + return pool + + def make_session(self): + session = self.make_basic_session() + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] + session._pools.get.return_value = self.make_pool() + return session + + def make_response_future(self, session): + query = SimpleStatement("SELECT * FROM foo") + message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + return ResponseFuture(session, message, query, 1) + + def make_mock_response(self, results): + return Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=results, paging_state=None, col_types=None) + + def test_result_message(self): + session = self.make_basic_session() + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] + pool = session._pools.get.return_value + pool.is_shutdown = False + + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + + rf = self.make_response_future(session) + rf.send_request() + + rf.session._pools.get.assert_called_once_with('ip1') + pool.borrow_connection.assert_called_once_with(timeout=ANY) + + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) + + rf._set_result(None, None, None, self.make_mock_response([{'col': 'val'}])) + result = rf.result() + self.assertEqual(result, [{'col': 'val'}]) + + def test_unknown_result_class(self): + session = self.make_session() + pool = session._pools.get.return_value + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + + rf = self.make_response_future(session) + rf.send_request() + rf._set_result(None, None, None, object()) + self.assertRaises(ConnectionException, rf.result) + + def test_set_keyspace_result(self): + session = self.make_session() + rf = self.make_response_future(session) + rf.send_request() + + result = Mock(spec=ResultMessage, + kind=RESULT_KIND_SET_KEYSPACE, + results="keyspace1") + rf._set_result(None, None, None, result) + rf._set_keyspace_completed({}) + self.assertFalse(rf.result()) + + def test_schema_change_result(self): + session = self.make_session() + rf = self.make_response_future(session) + rf.send_request() + + event_results={'target_type': SchemaTargetType.TABLE, 'change_type': SchemaChangeType.CREATED, + 
'keyspace': "keyspace1", "table": "table1"} + result = Mock(spec=ResultMessage, + kind=RESULT_KIND_SCHEMA_CHANGE, + results=event_results) + connection = Mock() + rf._set_result(None, connection, None, result) + session.submit.assert_called_once_with(ANY, ANY, rf, connection, **event_results) + + def test_other_result_message_kind(self): + session = self.make_session() + rf = self.make_response_future(session) + rf.send_request() + result = [1, 2, 3] + rf._set_result(None, None, None, Mock(spec=ResultMessage, kind=999, results=result)) + self.assertListEqual(list(rf.result()), result) + + def test_heartbeat_defunct_deadlock(self): + """ + Heartbeat defuncts all connections and clears request queues. Response future times out and even + if it has been removed from request queue, timeout exception must be thrown. Otherwise event loop + will deadlock on eventual ResponseFuture.result() call. + + PYTHON-1044 + """ + + connection = MagicMock(spec=Connection) + connection._requests = {} + + pool = Mock() + pool.is_shutdown = False + pool.borrow_connection.return_value = [connection, 1] + + session = self.make_basic_session() + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [Mock(), Mock()] + session._pools.get.return_value = pool + + query = SimpleStatement("SELECT * FROM foo") + message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + + rf = ResponseFuture(session, message, query, 1) + rf.send_request() + + # Simulate Connection.error_all_requests() after heartbeat defuncts + connection._requests = {} + + # Simulate ResponseFuture timing out + rf._on_timeout() + self.assertRaisesRegexp(OperationTimedOut, "Connection defunct by heartbeat", rf.result) + + def test_read_timeout_error_message(self): + session = self.make_session() + query = SimpleStatement("SELECT * FROM foo") + message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + + rf = ResponseFuture(session, message, query, 1) + rf.send_request() + + result = Mock(spec=ReadTimeoutErrorMessage, info={"data_retrieved": "", "required_responses":2, + "received_responses":1, "consistency": 1}) + rf._set_result(None, None, None, result) + + self.assertRaises(Exception, rf.result) + + def test_write_timeout_error_message(self): + session = self.make_session() + query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") + message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + + rf = ResponseFuture(session, message, query, 1) + rf.send_request() + + result = Mock(spec=WriteTimeoutErrorMessage, info={"write_type": 1, "required_responses":2, + "received_responses":1, "consistency": 1}) + rf._set_result(None, None, None, result) + self.assertRaises(Exception, rf.result) + + def test_unavailable_error_message(self): + session = self.make_session() + query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") + query.retry_policy = Mock() + query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None) + message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + + rf = ResponseFuture(session, message, query, 1) + rf._query_retries = 1 + rf.send_request() + + result = Mock(spec=UnavailableErrorMessage, info={"required_replicas":2, "alive_replicas": 1, "consistency": 1}) + rf._set_result(None, None, None, result) + self.assertRaises(Exception, rf.result) + + def test_retry_policy_says_ignore(self): + session = self.make_session() + query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") + message = 
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + + retry_policy = Mock() + retry_policy.on_unavailable.return_value = (RetryPolicy.IGNORE, None) + rf = ResponseFuture(session, message, query, 1, retry_policy=retry_policy) + rf.send_request() + + result = Mock(spec=UnavailableErrorMessage, info={}) + rf._set_result(None, None, None, result) + self.assertFalse(rf.result()) + + def test_retry_policy_says_retry(self): + session = self.make_session() + pool = session._pools.get.return_value + + query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") + message = QueryMessage(query=query, consistency_level=ConsistencyLevel.QUORUM) + + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + + retry_policy = Mock() + retry_policy.on_unavailable.return_value = (RetryPolicy.RETRY, ConsistencyLevel.ONE) + + rf = ResponseFuture(session, message, query, 1, retry_policy=retry_policy) + rf.send_request() + + rf.session._pools.get.assert_called_once_with('ip1') + pool.borrow_connection.assert_called_once_with(timeout=ANY) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) + + result = Mock(spec=UnavailableErrorMessage, info={}) + host = Mock() + rf._set_result(host, None, None, result) + + session.submit.assert_called_once_with(rf._retry_task, True, host) + self.assertEqual(1, rf._query_retries) + + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 2) + + # simulate the executor running this + rf._retry_task(True, host) + + # it should try again with the same host since this was + # an UnavailableException + rf.session._pools.get.assert_called_with(host) + pool.borrow_connection.assert_called_with(timeout=ANY) + connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) + + def test_retry_with_different_host(self): + session = self.make_session() + pool = session._pools.get.return_value + + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + + rf = self.make_response_future(session) + rf.message.consistency_level = ConsistencyLevel.QUORUM + rf.send_request() + + rf.session._pools.get.assert_called_once_with('ip1') + pool.borrow_connection.assert_called_once_with(timeout=ANY) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) + self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) + + result = Mock(spec=OverloadedErrorMessage, info={}) + host = Mock() + rf._set_result(host, None, None, result) + + session.submit.assert_called_once_with(rf._retry_task, False, host) + # query_retries does get incremented for Overloaded/Bootstrapping errors (since 3.18) + self.assertEqual(1, rf._query_retries) + + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 2) + # simulate the executor running this + rf._retry_task(False, host) + + # it should try with a different host + rf.session._pools.get.assert_called_with('ip2') + pool.borrow_connection.assert_called_with(timeout=ANY) + connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) + + # the consistency level should be the same + 
self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) + + def test_all_retries_fail(self): + session = self.make_session() + pool = session._pools.get.return_value + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + + rf = self.make_response_future(session) + rf.send_request() + rf.session._pools.get.assert_called_once_with('ip1') + + result = Mock(spec=IsBootstrappingErrorMessage, info={}) + host = Mock() + rf._set_result(host, None, None, result) + + # simulate the executor running this + session.submit.assert_called_once_with(rf._retry_task, False, host) + rf._retry_task(False, host) + + # it should try with a different host + rf.session._pools.get.assert_called_with('ip2') + + result = Mock(spec=IsBootstrappingErrorMessage, info={}) + rf._set_result(host, None, None, result) + + # simulate the executor running this + session.submit.assert_called_with(rf._retry_task, False, host) + rf._retry_task(False, host) + + self.assertRaises(NoHostAvailable, rf.result) + + def test_all_pools_shutdown(self): + session = self.make_basic_session() + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] + session._pools.get.return_value.is_shutdown = True + + rf = ResponseFuture(session, Mock(), Mock(), 1) + rf.send_request() + self.assertRaises(NoHostAvailable, rf.result) + + def test_first_pool_shutdown(self): + session = self.make_basic_session() + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] + # first return a pool with is_shutdown=True, then is_shutdown=False + pool_shutdown = self.make_pool() + pool_shutdown.is_shutdown = True + pool_ok = self.make_pool() + pool_ok.is_shutdown = True + session._pools.get.side_effect = [pool_shutdown, pool_ok] + + rf = self.make_response_future(session) + rf.send_request() + + rf._set_result(None, None, None, self.make_mock_response([{'col': 'val'}])) + + result = rf.result() + self.assertEqual(result, [{'col': 'val'}]) + + def test_timeout_getting_connection_from_pool(self): + session = self.make_basic_session() + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] + + # the first pool will raise an exception on borrow_connection() + exc = NoConnectionsAvailable() + first_pool = Mock(is_shutdown=False) + first_pool.borrow_connection.side_effect = exc + + # the second pool will return a connection + second_pool = Mock(is_shutdown=False) + connection = Mock(spec=Connection) + second_pool.borrow_connection.return_value = (connection, 1) + + session._pools.get.side_effect = [first_pool, second_pool] + + rf = self.make_response_future(session) + rf.send_request() + + rf._set_result(None, None, None, self.make_mock_response([{'col': 'val'}])) + self.assertEqual(rf.result(), [{'col': 'val'}]) + + # make sure the exception is recorded correctly + self.assertEqual(rf._errors, {'ip1': exc}) + + def test_callback(self): + session = self.make_session() + rf = self.make_response_future(session) + rf.send_request() + + callback = Mock() + expected_result = [{'col': 'val'}] + arg = "positional" + kwargs = {'one': 1, 'two': 2} + rf.add_callback(callback, arg, **kwargs) + + rf._set_result(None, None, None, self.make_mock_response(expected_result)) + + result = rf.result() + self.assertEqual(result, expected_result) + + callback.assert_called_once_with(expected_result, arg, **kwargs) + + # this should get called immediately now that the result is set + rf.add_callback(self.assertEqual, [{'col': 
'val'}]) + + def test_errback(self): + session = self.make_session() + pool = session._pools.get.return_value + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + + query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") + message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + + rf = ResponseFuture(session, message, query, 1) + rf._query_retries = 1 + rf.send_request() + + rf.add_errback(self.assertIsInstance, Exception) + + result = Mock(spec=UnavailableErrorMessage, info={"required_replicas":2, "alive_replicas": 1, "consistency": 1}) + result.to_exception.return_value = Exception() + + rf._set_result(None, None, None, result) + self.assertRaises(Exception, rf.result) + + # this should get called immediately now that the error is set + rf.add_errback(self.assertIsInstance, Exception) + + def test_multiple_callbacks(self): + session = self.make_session() + rf = self.make_response_future(session) + rf.send_request() + + callback = Mock() + expected_result = [{'col': 'val'}] + arg = "positional" + kwargs = {'one': 1, 'two': 2} + rf.add_callback(callback, arg, **kwargs) + + callback2 = Mock() + arg2 = "another" + kwargs2 = {'three': 3, 'four': 4} + rf.add_callback(callback2, arg2, **kwargs2) + + rf._set_result(None, None, None, self.make_mock_response(expected_result)) + + result = rf.result() + self.assertEqual(result, expected_result) + + callback.assert_called_once_with(expected_result, arg, **kwargs) + callback2.assert_called_once_with(expected_result, arg2, **kwargs2) + + def test_multiple_errbacks(self): + session = self.make_session() + pool = session._pools.get.return_value + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + + query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") + message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + + retry_policy = Mock() + retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None) + rf = ResponseFuture(session, message, query, 1, retry_policy=retry_policy) + rf.send_request() + + callback = Mock() + arg = "positional" + kwargs = {'one': 1, 'two': 2} + rf.add_errback(callback, arg, **kwargs) + + callback2 = Mock() + arg2 = "another" + kwargs2 = {'three': 3, 'four': 4} + rf.add_errback(callback2, arg2, **kwargs2) + + expected_exception = Unavailable("message", 1, 2, 3) + result = Mock(spec=UnavailableErrorMessage, info={"required_replicas":2, "alive_replicas": 1, "consistency": 1}) + result.to_exception.return_value = expected_exception + rf._set_result(None, None, None, result) + rf._event.set() + self.assertRaises(Exception, rf.result) + + callback.assert_called_once_with(expected_exception, arg, **kwargs) + callback2.assert_called_once_with(expected_exception, arg2, **kwargs2) + + def test_add_callbacks(self): + session = self.make_session() + query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") + message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + + # test errback + rf = ResponseFuture(session, message, query, 1) + rf._query_retries = 1 + rf.send_request() + + rf.add_callbacks( + callback=self.assertEqual, callback_args=([{'col': 'val'}],), + errback=self.assertIsInstance, errback_args=(Exception,)) + + result = Mock(spec=UnavailableErrorMessage, + info={"required_replicas":2, "alive_replicas": 1, "consistency": 1}) + result.to_exception.return_value = Exception() + rf._set_result(None, None, None, result) + self.assertRaises(Exception, rf.result) + 
+ # test callback + rf = ResponseFuture(session, message, query, 1) + rf.send_request() + + callback = Mock() + expected_result = [{'col': 'val'}] + arg = "positional" + kwargs = {'one': 1, 'two': 2} + rf.add_callbacks( + callback=callback, callback_args=(arg,), callback_kwargs=kwargs, + errback=self.assertIsInstance, errback_args=(Exception,)) + + rf._set_result(None, None, None, self.make_mock_response(expected_result)) + self.assertEqual(rf.result(), expected_result) + + callback.assert_called_once_with(expected_result, arg, **kwargs) + + def test_prepared_query_not_found(self): + session = self.make_session() + pool = session._pools.get.return_value + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + + rf = self.make_response_future(session) + rf.send_request() + + session.cluster.protocol_version = ProtocolVersion.V4 + session.cluster._prepared_statements = MagicMock(dict) + prepared_statement = session.cluster._prepared_statements.__getitem__.return_value + prepared_statement.query_string = "SELECT * FROM foobar" + prepared_statement.keyspace = "FooKeyspace" + rf._connection.keyspace = "FooKeyspace" + + result = Mock(spec=PreparedQueryNotFound, info='a' * 16) + rf._set_result(None, None, None, result) + + self.assertTrue(session.submit.call_args) + args, kwargs = session.submit.call_args + self.assertEqual(rf._reprepare, args[-5]) + self.assertIsInstance(args[-4], PrepareMessage) + self.assertEqual(args[-4].query, "SELECT * FROM foobar") + + def test_prepared_query_not_found_bad_keyspace(self): + session = self.make_session() + pool = session._pools.get.return_value + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + + rf = self.make_response_future(session) + rf.send_request() + + session.cluster.protocol_version = ProtocolVersion.V4 + session.cluster._prepared_statements = MagicMock(dict) + prepared_statement = session.cluster._prepared_statements.__getitem__.return_value + prepared_statement.query_string = "SELECT * FROM foobar" + prepared_statement.keyspace = "FooKeyspace" + rf._connection.keyspace = "BarKeyspace" + + result = Mock(spec=PreparedQueryNotFound, info='a' * 16) + rf._set_result(None, None, None, result) + self.assertRaises(ValueError, rf.result) + + def test_repeat_orig_query_after_succesful_reprepare(self): + session = self.make_session() + rf = self.make_response_future(session) + + response = Mock(spec=ResultMessage, kind=RESULT_KIND_PREPARED) + response.results = (None, None, None, None, None) + + rf._query = Mock(return_value=True) + rf._execute_after_prepare('host', None, None, response) + rf._query.assert_called_once_with('host') + + rf.prepared_statement = Mock() + rf._query = Mock(return_value=True) + rf._execute_after_prepare('host', None, None, response) + rf._query.assert_called_once_with('host') diff --git a/tests/unit/test_resultset.py b/tests/unit/test_resultset.py new file mode 100644 index 0000000..541ef6f --- /dev/null +++ b/tests/unit/test_resultset.py @@ -0,0 +1,209 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from cassandra.query import named_tuple_factory, dict_factory, tuple_factory + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from mock import Mock, PropertyMock, patch + +from cassandra.cluster import ResultSet + + +class ResultSetTests(unittest.TestCase): + + def test_iter_non_paged(self): + expected = list(range(10)) + rs = ResultSet(Mock(has_more_pages=False), expected) + itr = iter(rs) + self.assertListEqual(list(itr), expected) + + def test_iter_paged(self): + expected = list(range(10)) + response_future = Mock(has_more_pages=True) + response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock + rs = ResultSet(response_future, expected[:5]) + itr = iter(rs) + # this is brittle, depends on internal impl details. Would like to find a better way + type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, False)) # after init to avoid side effects being consumed by init + self.assertListEqual(list(itr), expected) + + def test_list_non_paged(self): + # list access on RS for backwards-compatibility + expected = list(range(10)) + rs = ResultSet(Mock(has_more_pages=False), expected) + for i in range(10): + self.assertEqual(rs[i], expected[i]) + self.assertEqual(list(rs), expected) + + def test_list_paged(self): + # list access on RS for backwards-compatibility + expected = list(range(10)) + response_future = Mock(has_more_pages=True) + response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock + rs = ResultSet(response_future, expected[:5]) + # this is brittle, depends on internal impl details. 
Would like to find a better way + type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, True, False)) # First two True are consumed on check entering list mode + self.assertEqual(rs[9], expected[9]) + self.assertEqual(list(rs), expected) + + def test_has_more_pages(self): + response_future = Mock() + response_future.has_more_pages.side_effect = PropertyMock(side_effect=(True, False)) + rs = ResultSet(response_future, []) + type(response_future).has_more_pages = PropertyMock(side_effect=(True, False)) # after init to avoid side effects being consumed by init + self.assertTrue(rs.has_more_pages) + self.assertFalse(rs.has_more_pages) + + def test_iterate_then_index(self): + # RuntimeError if indexing with no pages + expected = list(range(10)) + rs = ResultSet(Mock(has_more_pages=False), expected) + itr = iter(rs) + # before consuming + with self.assertRaises(RuntimeError): + rs[0] + list(itr) + # after consuming + with self.assertRaises(RuntimeError): + rs[0] + + self.assertFalse(rs) + self.assertFalse(list(rs)) + + # RuntimeError if indexing during or after pages + response_future = Mock(has_more_pages=True) + response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock + rs = ResultSet(response_future, expected[:5]) + type(response_future).has_more_pages = PropertyMock(side_effect=(True, False)) + itr = iter(rs) + # before consuming + with self.assertRaises(RuntimeError): + rs[0] + for row in itr: + # while consuming + with self.assertRaises(RuntimeError): + rs[0] + # after consuming + with self.assertRaises(RuntimeError): + rs[0] + self.assertFalse(rs) + self.assertFalse(list(rs)) + + def test_index_list_mode(self): + # no pages + expected = list(range(10)) + rs = ResultSet(Mock(has_more_pages=False), expected) + + # index access before iteration causes list to be materialized + self.assertEqual(rs[0], expected[0]) + + # resusable iteration + self.assertListEqual(list(rs), expected) + self.assertListEqual(list(rs), expected) + + self.assertTrue(rs) + + # pages + response_future = Mock(has_more_pages=True) + response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock + rs = ResultSet(response_future, expected[:5]) + # this is brittle, depends on internal impl details. 
Would like to find a better way + type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, True, False)) # First two True are consumed on check entering list mode + # index access before iteration causes list to be materialized + self.assertEqual(rs[0], expected[0]) + self.assertEqual(rs[9], expected[9]) + # resusable iteration + self.assertListEqual(list(rs), expected) + self.assertListEqual(list(rs), expected) + + self.assertTrue(rs) + + def test_eq(self): + # no pages + expected = list(range(10)) + rs = ResultSet(Mock(has_more_pages=False), expected) + + # eq before iteration causes list to be materialized + self.assertEqual(rs, expected) + + # results can be iterated or indexed once we're materialized + self.assertListEqual(list(rs), expected) + self.assertEqual(rs[9], expected[9]) + self.assertTrue(rs) + + # pages + response_future = Mock(has_more_pages=True) + response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock + rs = ResultSet(response_future, expected[:5]) + type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, True, False)) + # eq before iteration causes list to be materialized + self.assertEqual(rs, expected) + + # results can be iterated or indexed once we're materialized + self.assertListEqual(list(rs), expected) + self.assertEqual(rs[9], expected[9]) + self.assertTrue(rs) + + def test_bool(self): + self.assertFalse(ResultSet(Mock(has_more_pages=False), [])) + self.assertTrue(ResultSet(Mock(has_more_pages=False), [1])) + + def test_was_applied(self): + # unknown row factory raises + with self.assertRaises(RuntimeError): + ResultSet(Mock(), []).was_applied + + response_future = Mock(row_factory=named_tuple_factory) + + # no row + with self.assertRaises(RuntimeError): + ResultSet(response_future, []).was_applied + + # too many rows + with self.assertRaises(RuntimeError): + ResultSet(response_future, [tuple(), tuple()]).was_applied + + # various internal row factories + for row_factory in (named_tuple_factory, tuple_factory): + for applied in (True, False): + rs = ResultSet(Mock(row_factory=row_factory), [(applied,)]) + self.assertEqual(rs.was_applied, applied) + + row_factory = dict_factory + for applied in (True, False): + rs = ResultSet(Mock(row_factory=row_factory), [{'[applied]': applied}]) + self.assertEqual(rs.was_applied, applied) + + def test_one(self): + # no pages + first, second = Mock(), Mock() + rs = ResultSet(Mock(has_more_pages=False), [first, second]) + + self.assertEqual(rs.one(), first) + + @patch('cassandra.cluster.warn') + def test_indexing_deprecation(self, mocked_warn): + # normally we'd use catch_warnings to test this, but that doesn't work + # pre-Py3.0 for some reason + first, second = Mock(), Mock() + rs = ResultSet(Mock(has_more_pages=False), [first, second]) + self.assertEqual(rs[0], first) + self.assertEqual(len(mocked_warn.mock_calls), 1) + index_warning_args = tuple(mocked_warn.mock_calls[0])[1] + self.assertIn('indexing support will be removed in 4.0', + str(index_warning_args[0])) + self.assertIs(index_warning_args[1], DeprecationWarning) diff --git a/tests/unit/test_sortedset.py b/tests/unit/test_sortedset.py new file mode 100644 index 0000000..3845c2c --- /dev/null +++ b/tests/unit/test_sortedset.py @@ -0,0 +1,403 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.util import sortedset +from cassandra.cqltypes import EMPTY + +from datetime import datetime +from itertools import permutations + +class SortedSetTest(unittest.TestCase): + def test_init(self): + input = [5, 4, 3, 2, 1, 1, 1] + expected = sorted(set(input)) + ss = sortedset(input) + self.assertEqual(len(ss), len(expected)) + self.assertEqual(list(ss), expected) + + def test_repr(self): + self.assertEqual(repr(sortedset([1, 2, 3, 4])), "SortedSet([1, 2, 3, 4])") + + def test_contains(self): + input = [5, 4, 3, 2, 1, 1, 1] + expected = sorted(set(input)) + ss = sortedset(input) + + for i in expected: + self.assertTrue(i in ss) + self.assertFalse(i not in ss) + + hi = max(expected)+1 + lo = min(expected)-1 + + self.assertFalse(hi in ss) + self.assertFalse(lo in ss) + + def test_mutable_contents(self): + ba = bytearray(b'some data here') + ss = sortedset([ba, ba]) + self.assertEqual(list(ss), [ba]) + + def test_clear(self): + ss = sortedset([1, 2, 3]) + ss.clear() + self.assertEqual(len(ss), 0) + + def test_equal(self): + s1 = set([1]) + s12 = set([1, 2]) + ss1 = sortedset(s1) + ss12 = sortedset(s12) + + self.assertEqual(ss1, s1) + self.assertEqual(ss12, s12) + self.assertEqual(ss12, s12) + self.assertEqual(ss1.__eq__(None), NotImplemented) + self.assertNotEqual(ss1, ss12) + self.assertNotEqual(ss12, ss1) + self.assertNotEqual(ss1, s12) + self.assertNotEqual(ss12, s1) + self.assertNotEqual(ss1, EMPTY) + + def test_copy(self): + class comparable(object): + def __lt__(self, other): + return id(self) < id(other) + + o = comparable() + ss = sortedset([comparable(), o]) + ss2 = ss.copy() + self.assertNotEqual(id(ss), id(ss2)) + self.assertTrue(o in ss) + self.assertTrue(o in ss2) + + def test_isdisjoint(self): + # set, ss + s12 = set([1, 2]) + s2 = set([2]) + ss1 = sortedset([1]) + ss13 = sortedset([1, 3]) + ss3 = sortedset([3]) + # s ss disjoint + self.assertTrue(s2.isdisjoint(ss1)) + self.assertTrue(s2.isdisjoint(ss13)) + # s ss not disjoint + self.assertFalse(s12.isdisjoint(ss1)) + self.assertFalse(s12.isdisjoint(ss13)) + # ss s disjoint + self.assertTrue(ss1.isdisjoint(s2)) + self.assertTrue(ss13.isdisjoint(s2)) + # ss s not disjoint + self.assertFalse(ss1.isdisjoint(s12)) + self.assertFalse(ss13.isdisjoint(s12)) + # ss ss disjoint + self.assertTrue(ss1.isdisjoint(ss3)) + self.assertTrue(ss3.isdisjoint(ss1)) + # ss ss not disjoint + self.assertFalse(ss1.isdisjoint(ss13)) + self.assertFalse(ss13.isdisjoint(ss1)) + self.assertFalse(ss3.isdisjoint(ss13)) + self.assertFalse(ss13.isdisjoint(ss3)) + + def test_issubset(self): + s12 = set([1, 2]) + ss1 = sortedset([1]) + ss13 = sortedset([1, 3]) + ss3 = sortedset([3]) + + self.assertTrue(ss1.issubset(s12)) + self.assertTrue(ss1.issubset(ss13)) + + self.assertFalse(ss1.issubset(ss3)) + self.assertFalse(ss13.issubset(ss3)) + self.assertFalse(ss13.issubset(ss1)) + self.assertFalse(ss13.issubset(s12)) + + def test_issuperset(self): + s12 = set([1, 2]) + ss1 = sortedset([1]) + ss13 = sortedset([1, 3]) + ss3 = sortedset([3]) + + 
self.assertTrue(s12.issuperset(ss1)) + self.assertTrue(ss13.issuperset(ss3)) + self.assertTrue(ss13.issuperset(ss13)) + + self.assertFalse(s12.issuperset(ss13)) + self.assertFalse(ss1.issuperset(ss3)) + self.assertFalse(ss1.issuperset(ss13)) + + def test_union(self): + s1 = set([1]) + ss12 = sortedset([1, 2]) + ss23 = sortedset([2, 3]) + + self.assertEqual(sortedset().union(s1), sortedset([1])) + self.assertEqual(ss12.union(s1), sortedset([1, 2])) + self.assertEqual(ss12.union(ss23), sortedset([1, 2, 3])) + self.assertEqual(ss23.union(ss12), sortedset([1, 2, 3])) + self.assertEqual(ss23.union(s1), sortedset([1, 2, 3])) + + def test_intersection(self): + s12 = set([1, 2]) + ss23 = sortedset([2, 3]) + self.assertEqual(s12.intersection(ss23), set([2])) + self.assertEqual(ss23.intersection(s12), sortedset([2])) + self.assertEqual(ss23.intersection(s12, [2], (2,)), sortedset([2])) + self.assertEqual(ss23.intersection(s12, [900], (2,)), sortedset()) + + def test_difference(self): + s1 = set([1]) + ss12 = sortedset([1, 2]) + ss23 = sortedset([2, 3]) + + self.assertEqual(sortedset().difference(s1), sortedset()) + self.assertEqual(ss12.difference(s1), sortedset([2])) + self.assertEqual(ss12.difference(ss23), sortedset([1])) + self.assertEqual(ss23.difference(ss12), sortedset([3])) + self.assertEqual(ss23.difference(s1), sortedset([2, 3])) + + def test_symmetric_difference(self): + s = set([1, 3, 5]) + ss = sortedset([2, 3, 4]) + ss2 = sortedset([5, 6, 7]) + + self.assertEqual(ss.symmetric_difference(s), sortedset([1, 2, 4, 5])) + self.assertFalse(ss.symmetric_difference(ss)) + self.assertEqual(ss.symmetric_difference(s), sortedset([1, 2, 4, 5])) + self.assertEqual(ss2.symmetric_difference(ss), sortedset([2, 3, 4, 5, 6, 7])) + + def test_pop(self): + ss = sortedset([2, 1]) + self.assertEqual(ss.pop(), 2) + self.assertEqual(ss.pop(), 1) + try: + ss.pop() + self.fail("Error not thrown") + except (KeyError, IndexError) as e: + pass + + def test_remove(self): + ss = sortedset([2, 1]) + self.assertEqual(len(ss), 2) + self.assertRaises(KeyError, ss.remove, 3) + self.assertEqual(len(ss), 2) + ss.remove(1) + self.assertEqual(len(ss), 1) + ss.remove(2) + self.assertFalse(ss) + self.assertRaises(KeyError, ss.remove, 2) + self.assertFalse(ss) + + def test_getitem(self): + ss = sortedset(range(3)) + for i in range(len(ss)): + self.assertEqual(ss[i], i) + with self.assertRaises(IndexError): + ss[len(ss)] + + def test_delitem(self): + expected = [1,2,3,4] + ss = sortedset(expected) + for i in range(len(ss)): + self.assertListEqual(list(ss), expected[i:]) + del ss[0] + with self.assertRaises(IndexError): + ss[0] + + def test_delslice(self): + expected = [1, 2, 3, 4, 5] + ss = sortedset(expected) + del ss[1:3] + self.assertListEqual(list(ss), [1, 4, 5]) + del ss[-1:] + self.assertListEqual(list(ss), [1, 4]) + del ss[1:] + self.assertListEqual(list(ss), [1]) + del ss[:] + self.assertFalse(ss) + with self.assertRaises(IndexError): + del ss[0] + + def test_reversed(self): + expected = range(10) + self.assertListEqual(list(reversed(sortedset(expected))), list(reversed(expected))) + + def test_operators(self): + + ss1 = sortedset([1]) + ss12 = sortedset([1, 2]) + # __ne__ + self.assertFalse(ss12 != ss12) + self.assertFalse(ss12 != sortedset([1, 2])) + self.assertTrue(ss12 != sortedset()) + + # __le__ + self.assertTrue(ss1 <= ss12) + self.assertTrue(ss12 <= ss12) + self.assertFalse(ss12 <= ss1) + + # __lt__ + self.assertTrue(ss1 < ss12) + self.assertFalse(ss12 < ss12) + self.assertFalse(ss12 < ss1) + + # __ge__ + 
self.assertFalse(ss1 >= ss12) + self.assertTrue(ss12 >= ss12) + self.assertTrue(ss12 >= ss1) + + # __gt__ + self.assertFalse(ss1 > ss12) + self.assertFalse(ss12 > ss12) + self.assertTrue(ss12 > ss1) + + # __and__ + self.assertEqual(ss1 & ss12, ss1) + self.assertEqual(ss12 & ss12, ss12) + self.assertEqual(ss12 & set(), sortedset()) + + # __iand__ + tmp = sortedset(ss12) + tmp &= ss1 + self.assertEqual(tmp, ss1) + tmp = sortedset(ss1) + tmp &= ss12 + self.assertEqual(tmp, ss1) + tmp = sortedset(ss12) + tmp &= ss12 + self.assertEqual(tmp, ss12) + tmp = sortedset(ss12) + tmp &= set() + self.assertEqual(tmp, sortedset()) + + # __rand__ + self.assertEqual(set([1]) & ss12, ss1) + + # __or__ + self.assertEqual(ss1 | ss12, ss12) + self.assertEqual(ss12 | ss12, ss12) + self.assertEqual(ss12 | set(), ss12) + self.assertEqual(sortedset() | ss1 | ss12, ss12) + + # __ior__ + tmp = sortedset(ss1) + tmp |= ss12 + self.assertEqual(tmp, ss12) + tmp = sortedset(ss12) + tmp |= ss12 + self.assertEqual(tmp, ss12) + tmp = sortedset(ss12) + tmp |= set() + self.assertEqual(tmp, ss12) + tmp = sortedset() + tmp |= ss1 + tmp |= ss12 + self.assertEqual(tmp, ss12) + + # __ror__ + self.assertEqual(set([1]) | ss12, ss12) + + # __sub__ + self.assertEqual(ss1 - ss12, set()) + self.assertEqual(ss12 - ss12, set()) + self.assertEqual(ss12 - set(), ss12) + self.assertEqual(ss12 - ss1, sortedset([2])) + + # __isub__ + tmp = sortedset(ss1) + tmp -= ss12 + self.assertEqual(tmp, set()) + tmp = sortedset(ss12) + tmp -= ss12 + self.assertEqual(tmp, set()) + tmp = sortedset(ss12) + tmp -= set() + self.assertEqual(tmp, ss12) + tmp = sortedset(ss12) + tmp -= ss1 + self.assertEqual(tmp, sortedset([2])) + + # __rsub__ + self.assertEqual(set((1,2,3)) - ss12, set((3,))) + + # __xor__ + self.assertEqual(ss1 ^ ss12, set([2])) + self.assertEqual(ss12 ^ ss1, set([2])) + self.assertEqual(ss12 ^ ss12, set()) + self.assertEqual(ss12 ^ set(), ss12) + + # __ixor__ + tmp = sortedset(ss1) + tmp ^= ss12 + self.assertEqual(tmp, set([2])) + tmp = sortedset(ss12) + tmp ^= ss1 + self.assertEqual(tmp, set([2])) + tmp = sortedset(ss12) + tmp ^= ss12 + self.assertEqual(tmp, set()) + tmp = sortedset(ss12) + tmp ^= set() + self.assertEqual(tmp, ss12) + + # __rxor__ + self.assertEqual(set([1, 2]) ^ ss1, (set([2]))) + + def test_reduce_pickle(self): + ss = sortedset((4,3,2,1)) + import pickle + s = pickle.dumps(ss) + self.assertEqual(pickle.loads(s), ss) + + def _test_uncomparable_types(self, items): + for perm in permutations(items): + ss = sortedset(perm) + s = set(perm) + self.assertEqual(s, ss) + self.assertEqual(ss, ss.union(s)) + for x in range(len(ss)): + subset = set(s) + for _ in range(x): + subset.pop() + self.assertEqual(ss.difference(subset), s.difference(subset)) + self.assertEqual(ss.intersection(subset), s.intersection(subset)) + for x in ss: + self.assertIn(x, ss) + ss.remove(x) + self.assertNotIn(x, ss) + + def test_uncomparable_types_with_tuples(self): + # PYTHON-1087 - make set handle uncomparable types + dt = datetime(2019, 5, 16) + items = (('samekey', 3, 1), + ('samekey', None, 0), + ('samekey', dt), + ("samekey", None, 2), + ("samekey", None, 1), + ('samekey', dt), + ('samekey', None, 0), + ("samekey", datetime.now())) + + self._test_uncomparable_types(items) + + def test_uncomparable_types_with_integers(self): + # PYTHON-1087 - make set handle uncomparable types + items = (None, 1, 2, 6, None, None, 92) + self._test_uncomparable_types(items) diff --git a/tests/unit/test_time_util.py b/tests/unit/test_time_util.py new file mode 100644 
index 0000000..7025f15 --- /dev/null +++ b/tests/unit/test_time_util.py @@ -0,0 +1,122 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra import marshal +from cassandra import util +import calendar +import datetime +import time +import uuid + + +class TimeUtilTest(unittest.TestCase): + def test_datetime_from_timestamp(self): + self.assertEqual(util.datetime_from_timestamp(0), datetime.datetime(1970, 1, 1)) + # large negative; test PYTHON-110 workaround for windows + self.assertEqual(util.datetime_from_timestamp(-62135596800), datetime.datetime(1, 1, 1)) + self.assertEqual(util.datetime_from_timestamp(-62135596199), datetime.datetime(1, 1, 1, 0, 10, 1)) + + self.assertEqual(util.datetime_from_timestamp(253402300799), datetime.datetime(9999, 12, 31, 23, 59, 59)) + + self.assertEqual(util.datetime_from_timestamp(0.123456), datetime.datetime(1970, 1, 1, 0, 0, 0, 123456)) + + self.assertEqual(util.datetime_from_timestamp(2177403010.123456), datetime.datetime(2038, 12, 31, 10, 10, 10, 123456)) + + def test_times_from_uuid1(self): + node = uuid.getnode() + now = time.time() + u = uuid.uuid1(node, 0) + + t = util.unix_time_from_uuid1(u) + self.assertAlmostEqual(now, t, 2) + + dt = util.datetime_from_uuid1(u) + t = calendar.timegm(dt.timetuple()) + dt.microsecond / 1e6 + self.assertAlmostEqual(now, t, 2) + + def test_uuid_from_time(self): + t = time.time() + seq = 0x2aa5 + node = uuid.getnode() + u = util.uuid_from_time(t, node, seq) + # using AlmostEqual because time precision is different for + # some platforms + self.assertAlmostEqual(util.unix_time_from_uuid1(u), t, 4) + self.assertEqual(u.node, node) + self.assertEqual(u.clock_seq, seq) + + # random node + u1 = util.uuid_from_time(t, clock_seq=seq) + u2 = util.uuid_from_time(t, clock_seq=seq) + self.assertAlmostEqual(util.unix_time_from_uuid1(u1), t, 4) + self.assertAlmostEqual(util.unix_time_from_uuid1(u2), t, 4) + self.assertEqual(u.clock_seq, seq) + # not impossible, but we shouldn't get the same value twice + self.assertNotEqual(u1.node, u2.node) + + # random seq + u1 = util.uuid_from_time(t, node=node) + u2 = util.uuid_from_time(t, node=node) + self.assertAlmostEqual(util.unix_time_from_uuid1(u1), t, 4) + self.assertAlmostEqual(util.unix_time_from_uuid1(u2), t, 4) + self.assertEqual(u.node, node) + # not impossible, but we shouldn't get the same value twice + self.assertNotEqual(u1.clock_seq, u2.clock_seq) + + # node too large + with self.assertRaises(ValueError): + u = util.uuid_from_time(t, node=2 ** 48) + + # clock_seq too large + with self.assertRaises(ValueError): + u = util.uuid_from_time(t, clock_seq=0x4000) + + # construct from datetime + dt = util.datetime_from_timestamp(t) + u = util.uuid_from_time(dt, node, seq) + self.assertAlmostEqual(util.unix_time_from_uuid1(u), t, 4) + self.assertEqual(u.node, node) + self.assertEqual(u.clock_seq, seq) + +# 0 1 2 3 +# 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 
4 5 6 7 8 9 0 1 +# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +# | time_low | +# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +# | time_mid | time_hi_and_version | +# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +# |clk_seq_hi_res | clk_seq_low | node (0-1) | +# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +# | node (2-5) | +# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + def test_min_uuid(self): + u = util.min_uuid_from_time(0) + # cassandra does a signed comparison of the remaining bytes + for i in range(8, 16): + self.assertEqual(marshal.int8_unpack(u.bytes[i:i + 1]), -128) + + def test_max_uuid(self): + u = util.max_uuid_from_time(0) + # cassandra does a signed comparison of the remaining bytes + # the first non-time byte has the variant in it + # This byte is always negative, but should be the smallest negative + # number with high-order bits '10' + self.assertEqual(marshal.int8_unpack(u.bytes[8:9]), -65) + for i in range(9, 16): + self.assertEqual(marshal.int8_unpack(u.bytes[i:i + 1]), 127) diff --git a/tests/unit/test_timestamps.py b/tests/unit/test_timestamps.py new file mode 100644 index 0000000..bbca352 --- /dev/null +++ b/tests/unit/test_timestamps.py @@ -0,0 +1,278 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import mock + +from cassandra import timestamps +import time +from threading import Thread, Lock + +class _TimestampTestMixin(object): + + @mock.patch('cassandra.timestamps.time') + def _call_and_check_results(self, + patched_time_module, + system_time_expected_stamp_pairs, + timestamp_generator=None): + """ + For each element in an iterable of (system_time, expected_timestamp) + pairs, call a :class:`cassandra.timestamps.MonotonicTimestampGenerator` + with system_times as the underlying time.time() result, then assert + that the result is expected_timestamp. Skips the check if + expected_timestamp is None. + """ + patched_time_module.time = mock.Mock() + system_times, expected_timestamps = zip(*system_time_expected_stamp_pairs) + + patched_time_module.time.side_effect = system_times + tsg = timestamp_generator or timestamps.MonotonicTimestampGenerator() + + for expected in expected_timestamps: + actual = tsg() + if expected is not None: + self.assertEqual(actual, expected) + + # assert we patched timestamps.time.time correctly + with self.assertRaises(StopIteration): + tsg() + + +class TestTimestampGeneratorOutput(unittest.TestCase, _TimestampTestMixin): + """ + Mock time.time and test the output of MonotonicTimestampGenerator.__call__ + given different patterns of changing results. + """ + + def test_timestamps_during_and_after_same_system_time(self): + """ + Test that MonotonicTimestampGenerator's output increases by 1 when the + underlying system time is the same, then returns to normal when the + system time increases again. 
+ + @since 3.8.0 + @expected_result Timestamps should increase monotonically over repeated system time. + @test_category timing + """ + self._call_and_check_results( + system_time_expected_stamp_pairs=( + (15.0, 15 * 1e6), + (15.0, 15 * 1e6 + 1), + (15.0, 15 * 1e6 + 2), + (15.01, 15.01 * 1e6)) + ) + + def test_timestamps_during_and_after_backwards_system_time(self): + """ + Test that MonotonicTimestampGenerator's output increases by 1 when the + underlying system time goes backward, then returns to normal when the + system time increases again. + + @since 3.8.0 + @expected_result Timestamps should increase monotonically over system time going backwards. + @test_category timing + """ + self._call_and_check_results( + system_time_expected_stamp_pairs=( + (15.0, 15 * 1e6), + (13.0, 15 * 1e6 + 1), + (14.0, 15 * 1e6 + 2), + (13.5, 15 * 1e6 + 3), + (15.01, 15.01 * 1e6)) + ) + + +class TestTimestampGeneratorLogging(unittest.TestCase): + + def setUp(self): + self.log_patcher = mock.patch('cassandra.timestamps.log') + self.addCleanup(self.log_patcher.stop) + self.patched_timestamp_log = self.log_patcher.start() + + def assertLastCallArgRegex(self, call, pattern): + last_warn_args, last_warn_kwargs = call + self.assertEqual(len(last_warn_args), 1) + self.assertEqual(len(last_warn_kwargs), 0) + self.assertRegexpMatches( + last_warn_args[0], + pattern, + ) + + def test_basic_log_content(self): + """ + Tests there are logs + + @since 3.8.0 + @jira_ticket PYTHON-676 + @expected_result logs + + @test_category timing + """ + tsg = timestamps.MonotonicTimestampGenerator( + warning_threshold=1e-6, + warning_interval=1e-6 + ) + #The units of _last_warn is seconds + tsg._last_warn = 12 + + tsg._next_timestamp(20, tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 0) + tsg._next_timestamp(16, tsg.last) + + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 1) + self.assertLastCallArgRegex( + self.patched_timestamp_log.warning.call_args, + r'Clock skew detected:.*\b16\b.*\b4\b.*\b20\b' + ) + + def test_disable_logging(self): + """ + Tests there are no logs when there is a clock skew if logging is disabled + + @since 3.8.0 + @jira_ticket PYTHON-676 + @expected_result no logs + + @test_category timing + """ + no_warn_tsg = timestamps.MonotonicTimestampGenerator(warn_on_drift=False) + + no_warn_tsg.last = 100 + no_warn_tsg._next_timestamp(99, no_warn_tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 0) + + def test_warning_threshold_respected_no_logging(self): + """ + Tests there are no logs if `warning_threshold` is not exceeded + + @since 3.8.0 + @jira_ticket PYTHON-676 + @expected_result no logs + + @test_category timing + """ + tsg = timestamps.MonotonicTimestampGenerator( + warning_threshold=2e-6, + ) + tsg.last, tsg._last_warn = 100, 97 + tsg._next_timestamp(98, tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 0) + + def test_warning_threshold_respected_logs(self): + """ + Tests there are logs if `warning_threshold` is exceeded + + @since 3.8.0 + @jira_ticket PYTHON-676 + @expected_result logs + + @test_category timing + """ + tsg = timestamps.MonotonicTimestampGenerator( + warning_threshold=1e-6, + warning_interval=1e-6 + ) + tsg.last, tsg._last_warn = 100, 97 + tsg._next_timestamp(98, tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 1) + + def test_warning_interval_respected_no_logging(self): + """ + Tests there is only one log in the interval 
`warning_interval` + + @since 3.8.0 + @jira_ticket PYTHON-676 + @expected_result one log + + @test_category timing + """ + tsg = timestamps.MonotonicTimestampGenerator( + warning_threshold=1e-6, + warning_interval=2e-6 + ) + tsg.last = 100 + tsg._next_timestamp(70, tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 1) + + tsg._next_timestamp(71, tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 1) + + def test_warning_interval_respected_logs(self): + """ + Tests there are logs again if the + clock skew happens after`warning_interval` + + @since 3.8.0 + @jira_ticket PYTHON-676 + @expected_result logs + + @test_category timing + """ + tsg = timestamps.MonotonicTimestampGenerator( + warning_interval=1e-6, + warning_threshold=1e-6, + ) + tsg.last = 100 + tsg._next_timestamp(70, tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 1) + + tsg._next_timestamp(72, tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 2) + + +class TestTimestampGeneratorMultipleThreads(unittest.TestCase): + + def test_should_generate_incrementing_timestamps_for_all_threads(self): + """ + Tests when time is "stopped", values are assigned incrementally + + @since 3.8.0 + @jira_ticket PYTHON-676 + @expected_result the returned values increase + + @test_category timing + """ + lock = Lock() + + def request_time(): + for _ in range(timestamp_to_generate): + timestamp = tsg() + with lock: + generated_timestamps.append(timestamp) + + tsg = timestamps.MonotonicTimestampGenerator() + fixed_time = 1 + num_threads = 5 + + timestamp_to_generate = 1000 + generated_timestamps = [] + + with mock.patch('time.time', new=mock.Mock(return_value=fixed_time)): + threads = [] + for _ in range(num_threads): + threads.append(Thread(target=request_time)) + + for t in threads: + t.start() + + for t in threads: + t.join() + + self.assertEqual(len(generated_timestamps), num_threads * timestamp_to_generate) + for i, timestamp in enumerate(sorted(generated_timestamps)): + self.assertEqual(int(i + 1e6), timestamp) diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py new file mode 100644 index 0000000..c8f3011 --- /dev/null +++ b/tests/unit/test_types.py @@ -0,0 +1,386 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
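# A sketch of how the monotonic timestamp generator tested in
# test_timestamps.py above is typically wired in, assuming the public
# Cluster(timestamp_generator=...) hook (not shown in this diff): it emits
# microseconds since the epoch and bumps the last value by 1 whenever the
# system clock stalls or runs backwards, logging a warning once the drift
# exceeds warning_threshold.
from cassandra.cluster import Cluster
from cassandra.timestamps import MonotonicTimestampGenerator

tsg = MonotonicTimestampGenerator(
    warn_on_drift=True,
    warning_threshold=1,    # seconds of drift tolerated before warning
    warning_interval=10,    # seconds between repeated warnings
)
first, second = tsg(), tsg()
assert second > first  # strictly increasing even if time.time() repeats

cluster = Cluster(['127.0.0.1'], timestamp_generator=tsg)  # address is illustrative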
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from binascii import unhexlify +import datetime +import tempfile +import six +import time + +import cassandra +from cassandra.cqltypes import (BooleanType, lookup_casstype_simple, lookup_casstype, + LongType, DecimalType, SetType, cql_typename, + CassandraType, UTF8Type, parse_casstype_args, + SimpleDateType, TimeType, ByteType, ShortType, + EmptyValue, _CassandraType, DateType, int64_pack) +from cassandra.encoder import cql_quote +from cassandra.protocol import (write_string, read_longstring, write_stringmap, + read_stringmap, read_inet, write_inet, + read_string, write_longstring) +from cassandra.query import named_tuple_factory +from cassandra.pool import Host +from cassandra.policies import SimpleConvictionPolicy, ConvictionPolicy +from cassandra.util import Date, Time +from cassandra.metadata import Token + + +class TypeTests(unittest.TestCase): + + def test_lookup_casstype_simple(self): + """ + Ensure lookup_casstype_simple returns the correct classes + """ + + self.assertEqual(lookup_casstype_simple('AsciiType'), cassandra.cqltypes.AsciiType) + self.assertEqual(lookup_casstype_simple('LongType'), cassandra.cqltypes.LongType) + self.assertEqual(lookup_casstype_simple('BytesType'), cassandra.cqltypes.BytesType) + self.assertEqual(lookup_casstype_simple('BooleanType'), cassandra.cqltypes.BooleanType) + self.assertEqual(lookup_casstype_simple('CounterColumnType'), cassandra.cqltypes.CounterColumnType) + self.assertEqual(lookup_casstype_simple('DecimalType'), cassandra.cqltypes.DecimalType) + self.assertEqual(lookup_casstype_simple('DoubleType'), cassandra.cqltypes.DoubleType) + self.assertEqual(lookup_casstype_simple('FloatType'), cassandra.cqltypes.FloatType) + self.assertEqual(lookup_casstype_simple('InetAddressType'), cassandra.cqltypes.InetAddressType) + self.assertEqual(lookup_casstype_simple('Int32Type'), cassandra.cqltypes.Int32Type) + self.assertEqual(lookup_casstype_simple('UTF8Type'), cassandra.cqltypes.UTF8Type) + self.assertEqual(lookup_casstype_simple('DateType'), cassandra.cqltypes.DateType) + self.assertEqual(lookup_casstype_simple('SimpleDateType'), cassandra.cqltypes.SimpleDateType) + self.assertEqual(lookup_casstype_simple('ByteType'), cassandra.cqltypes.ByteType) + self.assertEqual(lookup_casstype_simple('ShortType'), cassandra.cqltypes.ShortType) + self.assertEqual(lookup_casstype_simple('TimeUUIDType'), cassandra.cqltypes.TimeUUIDType) + self.assertEqual(lookup_casstype_simple('TimeType'), cassandra.cqltypes.TimeType) + self.assertEqual(lookup_casstype_simple('UUIDType'), cassandra.cqltypes.UUIDType) + self.assertEqual(lookup_casstype_simple('IntegerType'), cassandra.cqltypes.IntegerType) + self.assertEqual(lookup_casstype_simple('MapType'), cassandra.cqltypes.MapType) + self.assertEqual(lookup_casstype_simple('ListType'), cassandra.cqltypes.ListType) + self.assertEqual(lookup_casstype_simple('SetType'), cassandra.cqltypes.SetType) + self.assertEqual(lookup_casstype_simple('CompositeType'), cassandra.cqltypes.CompositeType) + self.assertEqual(lookup_casstype_simple('ColumnToCollectionType'), cassandra.cqltypes.ColumnToCollectionType) + self.assertEqual(lookup_casstype_simple('ReversedType'), cassandra.cqltypes.ReversedType) + self.assertEqual(lookup_casstype_simple('DurationType'), cassandra.cqltypes.DurationType) + + self.assertEqual(str(lookup_casstype_simple('unknown')), str(cassandra.cqltypes.mkUnrecognizedType('unknown'))) + + def test_lookup_casstype(self): + """ + Ensure 
lookup_casstype returns the correct classes + """ + + self.assertEqual(lookup_casstype('AsciiType'), cassandra.cqltypes.AsciiType) + self.assertEqual(lookup_casstype('LongType'), cassandra.cqltypes.LongType) + self.assertEqual(lookup_casstype('BytesType'), cassandra.cqltypes.BytesType) + self.assertEqual(lookup_casstype('BooleanType'), cassandra.cqltypes.BooleanType) + self.assertEqual(lookup_casstype('CounterColumnType'), cassandra.cqltypes.CounterColumnType) + self.assertEqual(lookup_casstype('DateType'), cassandra.cqltypes.DateType) + self.assertEqual(lookup_casstype('DecimalType'), cassandra.cqltypes.DecimalType) + self.assertEqual(lookup_casstype('DoubleType'), cassandra.cqltypes.DoubleType) + self.assertEqual(lookup_casstype('FloatType'), cassandra.cqltypes.FloatType) + self.assertEqual(lookup_casstype('InetAddressType'), cassandra.cqltypes.InetAddressType) + self.assertEqual(lookup_casstype('Int32Type'), cassandra.cqltypes.Int32Type) + self.assertEqual(lookup_casstype('UTF8Type'), cassandra.cqltypes.UTF8Type) + self.assertEqual(lookup_casstype('DateType'), cassandra.cqltypes.DateType) + self.assertEqual(lookup_casstype('TimeType'), cassandra.cqltypes.TimeType) + self.assertEqual(lookup_casstype('ByteType'), cassandra.cqltypes.ByteType) + self.assertEqual(lookup_casstype('ShortType'), cassandra.cqltypes.ShortType) + self.assertEqual(lookup_casstype('TimeUUIDType'), cassandra.cqltypes.TimeUUIDType) + self.assertEqual(lookup_casstype('UUIDType'), cassandra.cqltypes.UUIDType) + self.assertEqual(lookup_casstype('IntegerType'), cassandra.cqltypes.IntegerType) + self.assertEqual(lookup_casstype('MapType'), cassandra.cqltypes.MapType) + self.assertEqual(lookup_casstype('ListType'), cassandra.cqltypes.ListType) + self.assertEqual(lookup_casstype('SetType'), cassandra.cqltypes.SetType) + self.assertEqual(lookup_casstype('CompositeType'), cassandra.cqltypes.CompositeType) + self.assertEqual(lookup_casstype('ColumnToCollectionType'), cassandra.cqltypes.ColumnToCollectionType) + self.assertEqual(lookup_casstype('ReversedType'), cassandra.cqltypes.ReversedType) + self.assertEqual(lookup_casstype('DurationType'), cassandra.cqltypes.DurationType) + + self.assertEqual(str(lookup_casstype('unknown')), str(cassandra.cqltypes.mkUnrecognizedType('unknown'))) + + self.assertRaises(ValueError, lookup_casstype, 'AsciiType~') + + def test_casstype_parameterized(self): + self.assertEqual(LongType.cass_parameterized_type_with(()), 'LongType') + self.assertEqual(LongType.cass_parameterized_type_with((), full=True), 'org.apache.cassandra.db.marshal.LongType') + self.assertEqual(SetType.cass_parameterized_type_with([DecimalType], full=True), 'org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.DecimalType)') + + self.assertEqual(LongType.cql_parameterized_type(), 'bigint') + + subtypes = (cassandra.cqltypes.UTF8Type, cassandra.cqltypes.UTF8Type) + self.assertEqual('map', + cassandra.cqltypes.MapType.apply_parameters(subtypes).cql_parameterized_type()) + + def test_datetype_from_string(self): + # Ensure all formats can be parsed, without exception + for format in cassandra.cqltypes.cql_timestamp_formats: + date_string = str(datetime.datetime.now().strftime(format)) + cassandra.cqltypes.DateType.interpret_datestring(date_string) + + def test_cql_typename(self): + """ + Smoke test cql_typename + """ + + self.assertEqual(cql_typename('DateType'), 'timestamp') + self.assertEqual(cql_typename('org.apache.cassandra.db.marshal.ListType(IntegerType)'), 'list') + + def 
test_named_tuple_colname_substitution(self): + colnames = ("func(abc)", "[applied]", "func(func(abc))", "foo_bar", "foo_bar_") + rows = [(1, 2, 3, 4, 5)] + result = named_tuple_factory(colnames, rows)[0] + self.assertEqual(result[0], result.func_abc) + self.assertEqual(result[1], result.applied) + self.assertEqual(result[2], result.func_func_abc) + self.assertEqual(result[3], result.foo_bar) + self.assertEqual(result[4], result.foo_bar_) + + def test_parse_casstype_args(self): + class FooType(CassandraType): + typename = 'org.apache.cassandra.db.marshal.FooType' + + def __init__(self, subtypes, names): + self.subtypes = subtypes + self.names = names + + @classmethod + def apply_parameters(cls, subtypes, names): + return cls(subtypes, [unhexlify(six.b(name)) if name is not None else name for name in names]) + + class BarType(FooType): + typename = 'org.apache.cassandra.db.marshal.BarType' + + ctype = parse_casstype_args(''.join(( + 'org.apache.cassandra.db.marshal.FooType(', + '63697479:org.apache.cassandra.db.marshal.UTF8Type,', + 'BarType(61646472657373:org.apache.cassandra.db.marshal.UTF8Type),', + '7a6970:org.apache.cassandra.db.marshal.UTF8Type', + ')'))) + + self.assertEqual(FooType, ctype.__class__) + + self.assertEqual(UTF8Type, ctype.subtypes[0]) + + # middle subtype should be a BarType instance with its own subtypes and names + self.assertIsInstance(ctype.subtypes[1], BarType) + self.assertEqual([UTF8Type], ctype.subtypes[1].subtypes) + self.assertEqual([b"address"], ctype.subtypes[1].names) + + self.assertEqual(UTF8Type, ctype.subtypes[2]) + self.assertEqual([b'city', None, b'zip'], ctype.names) + + def test_empty_value(self): + self.assertEqual(str(EmptyValue()), 'EMPTY') + + def test_datetype(self): + now_time_seconds = time.time() + now_datetime = datetime.datetime.utcfromtimestamp(now_time_seconds) + + # Cassandra timestamps in millis + now_timestamp = now_time_seconds * 1e3 + + # same results serialized + self.assertEqual(DateType.serialize(now_datetime, 0), DateType.serialize(now_timestamp, 0)) + + # deserialize + # epoc + expected = 0 + self.assertEqual(DateType.deserialize(int64_pack(1000 * expected), 0), datetime.datetime.utcfromtimestamp(expected)) + + # beyond 32b + expected = 2 ** 33 + self.assertEqual(DateType.deserialize(int64_pack(1000 * expected), 0), datetime.datetime(2242, 3, 16, 12, 56, 32)) + + # less than epoc (PYTHON-119) + expected = -770172256 + self.assertEqual(DateType.deserialize(int64_pack(1000 * expected), 0), datetime.datetime(1945, 8, 5, 23, 15, 44)) + + # work around rounding difference among Python versions (PYTHON-230) + expected = 1424817268.274 + self.assertEqual(DateType.deserialize(int64_pack(int(1000 * expected)), 0), datetime.datetime(2015, 2, 24, 22, 34, 28, 274000)) + + # Large date overflow (PYTHON-452) + expected = 2177403010.123 + self.assertEqual(DateType.deserialize(int64_pack(int(1000 * expected)), 0), datetime.datetime(2038, 12, 31, 10, 10, 10, 123000)) + + def test_write_read_string(self): + with tempfile.TemporaryFile() as f: + value = u'test' + write_string(f, value) + f.seek(0) + self.assertEqual(read_string(f), value) + + def test_write_read_longstring(self): + with tempfile.TemporaryFile() as f: + value = u'test' + write_longstring(f, value) + f.seek(0) + self.assertEqual(read_longstring(f), value) + + def test_write_read_stringmap(self): + with tempfile.TemporaryFile() as f: + value = {'key': 'value'} + write_stringmap(f, value) + f.seek(0) + self.assertEqual(read_stringmap(f), value) + + def test_write_read_inet(self): + 
with tempfile.TemporaryFile() as f: + value = ('192.168.1.1', 9042) + write_inet(f, value) + f.seek(0) + self.assertEqual(read_inet(f), value) + + with tempfile.TemporaryFile() as f: + value = ('2001:db8:0:f101::1', 9042) + write_inet(f, value) + f.seek(0) + self.assertEqual(read_inet(f), value) + + def test_cql_quote(self): + self.assertEqual(cql_quote(u'test'), "'test'") + self.assertEqual(cql_quote('test'), "'test'") + self.assertEqual(cql_quote(0), '0') + + +class TestOrdering(unittest.TestCase): + def _check_order_consistency(self, smaller, bigger, equal=False): + self.assertLessEqual(smaller, bigger) + self.assertGreaterEqual(bigger, smaller) + if equal: + self.assertEqual(smaller, bigger) + else: + self.assertNotEqual(smaller, bigger) + self.assertLess(smaller, bigger) + self.assertGreater(bigger, smaller) + + def _shuffle_lists(self, *args): + return [item for sublist in zip(*args) for item in sublist] + + def _check_sequence_consistency(self, ordered_sequence, equal=False): + for i, el in enumerate(ordered_sequence): + for previous in ordered_sequence[:i]: + self._check_order_consistency(previous, el, equal) + for posterior in ordered_sequence[i + 1:]: + self._check_order_consistency(el, posterior, equal) + + def test_host_order(self): + """ + Test Host class is ordered consistently + + @since 3.9 + @jira_ticket PYTHON-714 + @expected_result the hosts are ordered correctly + + @test_category data_types + """ + hosts = [Host(addr, SimpleConvictionPolicy) for addr in + ("127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4")] + hosts_equal = [Host(addr, SimpleConvictionPolicy) for addr in + ("127.0.0.1", "127.0.0.1")] + hosts_equal_conviction = [Host("127.0.0.1", SimpleConvictionPolicy), Host("127.0.0.1", ConvictionPolicy)] + self._check_sequence_consistency(hosts) + self._check_sequence_consistency(hosts_equal, equal=True) + self._check_sequence_consistency(hosts_equal_conviction, equal=True) + + def test_date_order(self): + """ + Test Date class is ordered consistently + + @since 3.9 + @jira_ticket PYTHON-714 + @expected_result the dates are ordered correctly + + @test_category data_types + """ + dates_from_string = [Date("2017-01-01"), Date("2017-01-05"), Date("2017-01-09"), Date("2017-01-13")] + dates_from_string_equal = [Date("2017-01-01"), Date("2017-01-01")] + self._check_sequence_consistency(dates_from_string) + self._check_sequence_consistency(dates_from_string_equal, equal=True) + + date_format = "%Y-%m-%d" + + dates_from_value = [ + Date((datetime.datetime.strptime(dtstr, date_format) - + datetime.datetime(1970, 1, 1)).days) + for dtstr in ("2017-01-02", "2017-01-06", "2017-01-10", "2017-01-14") + ] + dates_from_value_equal = [Date(1), Date(1)] + self._check_sequence_consistency(dates_from_value) + self._check_sequence_consistency(dates_from_value_equal, equal=True) + + dates_from_datetime = [Date(datetime.datetime.strptime(dtstr, date_format)) + for dtstr in ("2017-01-03", "2017-01-07", "2017-01-11", "2017-01-15")] + dates_from_datetime_equal = [Date(datetime.datetime.strptime("2017-01-01", date_format)), + Date(datetime.datetime.strptime("2017-01-01", date_format))] + self._check_sequence_consistency(dates_from_datetime) + self._check_sequence_consistency(dates_from_datetime_equal, equal=True) + + dates_from_date = [ + Date(datetime.datetime.strptime(dtstr, date_format).date()) for dtstr in + ("2017-01-04", "2017-01-08", "2017-01-12", "2017-01-16") + ] + dates_from_date_equal = [datetime.datetime.strptime(dtstr, date_format) for dtstr in + ("2017-01-09", "2017-01-9")] + + 
self._check_sequence_consistency(dates_from_date) + self._check_sequence_consistency(dates_from_date_equal, equal=True) + + self._check_sequence_consistency(self._shuffle_lists(dates_from_string, dates_from_value, + dates_from_datetime, dates_from_date)) + + def test_timer_order(self): + """ + Test Time class is ordered consistently + + @since 3.9 + @jira_ticket PYTHON-714 + @expected_result the times are ordered correctly + + @test_category data_types + """ + time_from_int = [Time(1000), Time(4000), Time(7000), Time(10000)] + time_from_int_equal = [Time(1), Time(1)] + self._check_sequence_consistency(time_from_int) + self._check_sequence_consistency(time_from_int_equal, equal=True) + + time_from_datetime = [Time(datetime.time(hour=0, minute=0, second=0, microsecond=us)) + for us in (2, 5, 8, 11)] + time_from_datetime_equal = [Time(datetime.time(hour=0, minute=0, second=0, microsecond=us)) + for us in (1, 1)] + self._check_sequence_consistency(time_from_datetime) + self._check_sequence_consistency(time_from_datetime_equal, equal=True) + + time_from_string = [Time("00:00:00.000003000"), Time("00:00:00.000006000"), + Time("00:00:00.000009000"), Time("00:00:00.000012000")] + time_from_string_equal = [Time("00:00:00.000004000"), Time("00:00:00.000004000")] + self._check_sequence_consistency(time_from_string) + self._check_sequence_consistency(time_from_string_equal, equal=True) + + self._check_sequence_consistency(self._shuffle_lists(time_from_int, time_from_datetime, time_from_string)) + + def test_token_order(self): + """ + Test Token class is ordered consistently + + @since 3.9 + @jira_ticket PYTHON-714 + @expected_result the tokens are ordered correctly + + @test_category data_types + """ + tokens = [Token(1), Token(2), Token(3), Token(4)] + tokens_equal = [Token(1), Token(1)] + self._check_sequence_consistency(tokens) + self._check_sequence_consistency(tokens_equal, equal=True) diff --git a/tests/unit/test_util_types.py b/tests/unit/test_util_types.py new file mode 100644 index 0000000..8c60bfe --- /dev/null +++ b/tests/unit/test_util_types.py @@ -0,0 +1,296 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import datetime + +from cassandra.util import Date, Time, Duration, Version + + +class DateTests(unittest.TestCase): + + def test_from_datetime(self): + expected_date = datetime.date(1492, 10, 12) + d = Date(expected_date) + self.assertEqual(str(d), str(expected_date)) + + def test_from_string(self): + expected_date = datetime.date(1492, 10, 12) + d = Date(expected_date) + sd = Date('1492-10-12') + self.assertEqual(sd, d) + sd = Date('+1492-10-12') + self.assertEqual(sd, d) + + def test_from_date(self): + expected_date = datetime.date(1492, 10, 12) + d = Date(expected_date) + self.assertEqual(d.date(), expected_date) + + def test_from_days(self): + sd = Date(0) + self.assertEqual(sd, Date(datetime.date(1970, 1, 1))) + sd = Date(-1) + self.assertEqual(sd, Date(datetime.date(1969, 12, 31))) + sd = Date(1) + self.assertEqual(sd, Date(datetime.date(1970, 1, 2))) + + def test_limits(self): + min_builtin = Date(datetime.date(1, 1, 1)) + max_builtin = Date(datetime.date(9999, 12, 31)) + self.assertEqual(Date(min_builtin.days_from_epoch), min_builtin) + self.assertEqual(Date(max_builtin.days_from_epoch), max_builtin) + # just proving we can construct with on offset outside buildin range + self.assertEqual(Date(min_builtin.days_from_epoch - 1).days_from_epoch, + min_builtin.days_from_epoch - 1) + self.assertEqual(Date(max_builtin.days_from_epoch + 1).days_from_epoch, + max_builtin.days_from_epoch + 1) + + def test_invalid_init(self): + self.assertRaises(ValueError, Date, '-1999-10-10') + self.assertRaises(TypeError, Date, 1.234) + + def test_str(self): + date_str = '2015-03-16' + self.assertEqual(str(Date(date_str)), date_str) + + def test_out_of_range(self): + self.assertEqual(str(Date(2932897)), '2932897') + self.assertEqual(repr(Date(1)), 'Date(1)') + + def test_equals(self): + self.assertEqual(Date(1234), 1234) + self.assertEqual(Date(1), datetime.date(1970, 1, 2)) + self.assertFalse(Date(2932897) == datetime.date(9999, 12, 31)) # date can't represent year > 9999 + self.assertEqual(Date(2932897), 2932897) + + +class TimeTests(unittest.TestCase): + + def test_units_from_string(self): + one_micro = 1000 + one_milli = 1000 * one_micro + one_second = 1000 * one_milli + one_minute = 60 * one_second + one_hour = 60 * one_minute + + tt = Time('00:00:00.000000001') + self.assertEqual(tt.nanosecond_time, 1) + tt = Time('00:00:00.000001') + self.assertEqual(tt.nanosecond_time, one_micro) + tt = Time('00:00:00.001') + self.assertEqual(tt.nanosecond_time, one_milli) + tt = Time('00:00:01') + self.assertEqual(tt.nanosecond_time, one_second) + tt = Time('00:01:00') + self.assertEqual(tt.nanosecond_time, one_minute) + tt = Time('01:00:00') + self.assertEqual(tt.nanosecond_time, one_hour) + tt = Time('01:00:00.') + self.assertEqual(tt.nanosecond_time, one_hour) + + tt = Time('23:59:59.123456') + self.assertEqual(tt.nanosecond_time, 23 * one_hour + 59 * one_minute + 59 * one_second + 123 * one_milli + 456 * one_micro) + + tt = Time('23:59:59.1234567') + self.assertEqual(tt.nanosecond_time, 23 * one_hour + 59 * one_minute + 59 * one_second + 123 * one_milli + 456 * one_micro + 700) + + tt = Time('23:59:59.12345678') + self.assertEqual(tt.nanosecond_time, 23 * one_hour + 59 * one_minute + 59 * one_second + 123 * one_milli + 456 * one_micro + 780) + + tt = Time('23:59:59.123456789') + self.assertEqual(tt.nanosecond_time, 23 * one_hour + 59 * one_minute + 59 * one_second + 123 * one_milli + 456 * one_micro + 789) + + def 
test_micro_precision(self): + Time('23:59:59.1') + Time('23:59:59.12') + Time('23:59:59.123') + Time('23:59:59.1234') + Time('23:59:59.12345') + + def test_from_int(self): + tt = Time(12345678) + self.assertEqual(tt.nanosecond_time, 12345678) + + def test_from_time(self): + expected_time = datetime.time(12, 1, 2, 3) + tt = Time(expected_time) + self.assertEqual(tt, expected_time) + + def test_as_time(self): + expected_time = datetime.time(12, 1, 2, 3) + tt = Time(expected_time) + self.assertEqual(tt.time(), expected_time) + + def test_equals(self): + # util.Time self equality + self.assertEqual(Time(1234), Time(1234)) + + def test_str_repr(self): + time_str = '12:13:14.123456789' + self.assertEqual(str(Time(time_str)), time_str) + self.assertEqual(repr(Time(1)), 'Time(1)') + + def test_invalid_init(self): + self.assertRaises(ValueError, Time, '1999-10-10 11:11:11.1234') + self.assertRaises(TypeError, Time, 1.234) + self.assertRaises(ValueError, Time, 123456789000000) + self.assertRaises(TypeError, Time, datetime.datetime(2004, 12, 23, 11, 11, 1)) + + +class DurationTests(unittest.TestCase): + + def test_valid_format(self): + + valid = Duration(1, 1, 1) + self.assertEqual(valid.months, 1) + self.assertEqual(valid.days, 1) + self.assertEqual(valid.nanoseconds, 1) + + valid = Duration(nanoseconds=100000) + self.assertEqual(valid.months, 0) + self.assertEqual(valid.days, 0) + self.assertEqual(valid.nanoseconds, 100000) + + valid = Duration() + self.assertEqual(valid.months, 0) + self.assertEqual(valid.days, 0) + self.assertEqual(valid.nanoseconds, 0) + + valid = Duration(-10, -21, -1000) + self.assertEqual(valid.months, -10) + self.assertEqual(valid.days, -21) + self.assertEqual(valid.nanoseconds, -1000) + + def test_equality(self): + + first = Duration(1, 1, 1) + second = Duration(-1, 1, 1) + self.assertNotEqual(first, second) + + first = Duration(1, 1, 1) + second = Duration(1, 1, 1) + self.assertEqual(first, second) + + first = Duration() + second = Duration(0, 0, 0) + self.assertEqual(first, second) + + first = Duration(1000, 10000, 2345345) + second = Duration(1000, 10000, 2345345) + self.assertEqual(first, second) + + first = Duration(12, 0 , 100) + second = Duration(nanoseconds=100, months=12) + self.assertEqual(first, second) + + def test_str(self): + + self.assertEqual(str(Duration(1, 1, 1)), "1mo1d1ns") + self.assertEqual(str(Duration(1, 1, -1)), "-1mo1d1ns") + self.assertEqual(str(Duration(1, 1, 1000000000000000)), "1mo1d1000000000000000ns") + self.assertEqual(str(Duration(52, 23, 564564)), "52mo23d564564ns") + + +class VersionTests(unittest.TestCase): + + def test_version_parsing(self): + versions = [ + ('2.0.0', (2, 0, 0, 0, 0)), + ('3.1.0', (3, 1, 0, 0, 0)), + ('2.4.54', (2, 4, 54, 0, 0)), + ('3.1.1.12', (3, 1, 1, 12, 0)), + ('3.55.1.build12', (3, 55, 1, 'build12', 0)), + ('3.55.1.20190429-TEST', (3, 55, 1, 20190429, 'TEST')), + ('4.0-SNAPSHOT', (4, 0, 0, 0, 'SNAPSHOT')), + ] + + for str_version, expected_result in versions: + v = Version(str_version) + self.assertEqual(str_version, str(v)) + self.assertEqual(v.major, expected_result[0]) + self.assertEqual(v.minor, expected_result[1]) + self.assertEqual(v.patch, expected_result[2]) + self.assertEqual(v.build, expected_result[3]) + self.assertEqual(v.prerelease, expected_result[4]) + + # not supported version formats + with self.assertRaises(ValueError): + Version('2.1.hello') + + with self.assertRaises(ValueError): + Version('2.test.1') + + with self.assertRaises(ValueError): + Version('test.1.0') + + with 
self.assertRaises(ValueError): + Version('1.0.0.0.1') + + def test_version_compare(self): + # just tests a bunch of versions + + # major wins + self.assertTrue(Version('3.3.0') > Version('2.5.0')) + self.assertTrue(Version('3.3.0') > Version('2.5.0.66')) + self.assertTrue(Version('3.3.0') > Version('2.5.21')) + + # minor wins + self.assertTrue(Version('2.3.0') > Version('2.2.0')) + self.assertTrue(Version('2.3.0') > Version('2.2.7')) + self.assertTrue(Version('2.3.0') > Version('2.2.7.9')) + + # patch wins + self.assertTrue(Version('2.3.1') > Version('2.3.0')) + self.assertTrue(Version('2.3.1') > Version('2.3.0.4post0')) + self.assertTrue(Version('2.3.1') > Version('2.3.0.44')) + + # various + self.assertTrue(Version('2.3.0.1') > Version('2.3.0.0')) + self.assertTrue(Version('2.3.0.680') > Version('2.3.0.670')) + self.assertTrue(Version('2.3.0.681') > Version('2.3.0.680')) + self.assertTrue(Version('2.3.0.1build0') > Version('2.3.0.1')) # 4th part fallback to str cmp + self.assertTrue(Version('2.3.0.build0') > Version('2.3.0.1')) # 4th part fallback to str cmp + self.assertTrue(Version('2.3.0') < Version('2.3.0.build')) + + self.assertTrue(Version('4-a') <= Version('4.0.0')) + self.assertTrue(Version('4-a') <= Version('4.0-alpha1')) + self.assertTrue(Version('4-a') <= Version('4.0-beta1')) + self.assertTrue(Version('4.0.0') >= Version('4.0.0')) + self.assertTrue(Version('4.0.0.421') >= Version('4.0.0')) + self.assertTrue(Version('4.0.1') >= Version('4.0.0')) + self.assertTrue(Version('2.3.0') == Version('2.3.0')) + self.assertTrue(Version('2.3.32') == Version('2.3.32')) + self.assertTrue(Version('2.3.32') == Version('2.3.32.0')) + self.assertTrue(Version('2.3.0.build') == Version('2.3.0.build')) + + self.assertTrue(Version('4') == Version('4.0.0')) + self.assertTrue(Version('4.0') == Version('4.0.0.0')) + self.assertTrue(Version('4.0') > Version('3.9.3')) + + self.assertTrue(Version('4.0') > Version('4.0-SNAPSHOT')) + self.assertTrue(Version('4.0-SNAPSHOT') == Version('4.0-SNAPSHOT')) + self.assertTrue(Version('4.0.0-SNAPSHOT') == Version('4.0-SNAPSHOT')) + self.assertTrue(Version('4.0.0-SNAPSHOT') == Version('4.0.0-SNAPSHOT')) + self.assertTrue(Version('4.0.0.build5-SNAPSHOT') == Version('4.0.0.build5-SNAPSHOT')) + self.assertTrue(Version('4.1-SNAPSHOT') > Version('4.0-SNAPSHOT')) + self.assertTrue(Version('4.0.1-SNAPSHOT') > Version('4.0.0-SNAPSHOT')) + self.assertTrue(Version('4.0.0.build6-SNAPSHOT') > Version('4.0.0.build5-SNAPSHOT')) + self.assertTrue(Version('4.0-SNAPSHOT2') > Version('4.0-SNAPSHOT1')) + self.assertTrue(Version('4.0-SNAPSHOT2') > Version('4.0.0-SNAPSHOT1')) + + self.assertTrue(Version('4.0.0-alpha1-SNAPSHOT') > Version('4.0.0-SNAPSHOT')) diff --git a/tests/unit/utils.py b/tests/unit/utils.py new file mode 100644 index 0000000..b3ac113 --- /dev/null +++ b/tests/unit/utils.py @@ -0,0 +1,18 @@ +from concurrent.futures import Future +from functools import wraps +from cassandra.cluster import Session +from mock import patch + +def mock_session_pools(f): + """ + Helper decorator that allows tests to initialize :class:.`Session` objects + without actually connecting to a Cassandra cluster. + """ + @wraps(f) + def wrapper(*args, **kwargs): + with patch.object(Session, "add_or_renew_pool") as mocked_add_or_renew_pool: + future = Future() + future.set_result(object()) + mocked_add_or_renew_pool.return_value = future + f(*args, **kwargs) + return wrapper
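For readers skimming the new helpers: the ``mock_session_pools`` decorator added in ``tests/unit/utils.py`` above is meant to be applied to unit tests that need a ``Session`` but must not touch a live cluster. The following is an illustrative sketch only, not part of the patch; the test class and method names are invented for the example::

    import unittest
    from tests.unit.utils import mock_session_pools

    class ExampleSessionTest(unittest.TestCase):

        @mock_session_pools
        def test_uses_a_session_without_a_cluster(self):
            # Inside the decorated test, Session.add_or_renew_pool is patched to
            # return an already-completed Future, so code paths that would
            # normally open connection pools against live Cassandra nodes
            # return immediately without any network I/O.
            pass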
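Similarly, the timestamp tests earlier in this patch (``TestTimestampGeneratorLogging`` and ``TestTimestampGeneratorMultipleThreads``) exercise ``cassandra.timestamps.MonotonicTimestampGenerator``. A rough sketch of the behaviour they verify is shown below; only the constructor arguments that actually appear in the tests are assumed, and the numeric settings are arbitrary example values::

    from cassandra.timestamps import MonotonicTimestampGenerator

    tsg = MonotonicTimestampGenerator(
        warn_on_drift=True,      # log a "Clock skew detected" warning when system time goes backwards
        warning_threshold=1.0,   # only warn once the detected drift exceeds about one second
        warning_interval=60.0,   # and then at most once per minute
    )

    first = tsg()   # microseconds since the epoch
    second = tsg()
    assert second > first  # strictly increasing, even if time.time() stalls or rewinds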