diff --git a/PKG-INFO b/PKG-INFO index 0c04873..11e02d9 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,113 +1,112 @@ Metadata-Version: 1.1 Name: cassandra-driver -Version: 3.14.0 +Version: 3.16.0 Summary: Python driver for Cassandra Home-page: http://github.com/datastax/python-driver -Author: Tyler Hobbs -Author-email: tyler@datastax.com +Author: Tyler Hobbs +Author-email: tyler@datastax.com License: UNKNOWN Description: DataStax Python Driver for Apache Cassandra =========================================== .. image:: https://travis-ci.org/datastax/python-driver.png?branch=master :target: https://travis-ci.org/datastax/python-driver A modern, `feature-rich `_ and highly-tunable Python client library for Apache Cassandra (2.1+) using exclusively Cassandra's binary protocol and Cassandra Query Language v3. - The driver supports Python 2.7, 3.3, 3.4, 3.5, and 3.6. + The driver supports Python 2.7, 3.4, 3.5, and 3.6. If you require compatibility with DataStax Enterprise, use the `DataStax Enterprise Python Driver `_. **Note:** DataStax products do not support big-endian systems. Feedback Requested ------------------ **Help us focus our efforts!** Provide your input on the `Platform and Runtime Survey `_ (we kept it short). Features -------- * `Synchronous `_ and `Asynchronous `_ APIs * `Simple, Prepared, and Batch statements `_ * Asynchronous IO, parallel execution, request pipelining * `Connection pooling `_ * Automatic node discovery * `Automatic reconnection `_ * Configurable `load balancing `_ and `retry policies `_ * `Concurrent execution utilities `_ * `Object mapper `_ Installation ------------ Installation through pip is recommended:: $ pip install cassandra-driver For more complete installation instructions, see the `installation guide `_. Documentation ------------- The documentation can be found online `here `_. A couple of links for getting up to speed: * `Installation `_ * `Getting started guide `_ * `API docs `_ * `Performance tips `_ Object Mapper ------------- cqlengine (originally developed by Blake Eggleston and Jon Haddad, with contributions from the community) is now maintained as an integral part of this package. Refer to `documentation here `_. Contributing ------------ See `CONTRIBUTING.md `_. Reporting Problems ------------------ Please report any bugs and make any feature requests on the `JIRA `_ issue tracker. If you would like to contribute, please feel free to open a pull request. Getting Help ------------ Your best options for getting help with the driver are the `mailing list `_ and the ``#datastax-drivers`` channel in the `DataStax Academy Slack `_. License ------- Copyright 2013-2017 DataStax Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
Keywords: cassandra,cql,orm Platform: UNKNOWN Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: Apache Software License Classifier: Natural Language :: English Classifier: Operating System :: OS Independent Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 2.7 -Classifier: Programming Language :: Python :: 3.3 Classifier: Programming Language :: Python :: 3.4 Classifier: Programming Language :: Python :: 3.5 Classifier: Programming Language :: Python :: 3.6 Classifier: Programming Language :: Python :: Implementation :: CPython Classifier: Programming Language :: Python :: Implementation :: PyPy Classifier: Topic :: Software Development :: Libraries :: Python Modules diff --git a/README.rst b/README.rst index f30a916..f14cc77 100644 --- a/README.rst +++ b/README.rst @@ -1,88 +1,88 @@ DataStax Python Driver for Apache Cassandra =========================================== .. image:: https://travis-ci.org/datastax/python-driver.png?branch=master :target: https://travis-ci.org/datastax/python-driver A modern, `feature-rich `_ and highly-tunable Python client library for Apache Cassandra (2.1+) using exclusively Cassandra's binary protocol and Cassandra Query Language v3. -The driver supports Python 2.7, 3.3, 3.4, 3.5, and 3.6. +The driver supports Python 2.7, 3.4, 3.5, and 3.6. If you require compatibility with DataStax Enterprise, use the `DataStax Enterprise Python Driver `_. **Note:** DataStax products do not support big-endian systems. Feedback Requested ------------------ **Help us focus our efforts!** Provide your input on the `Platform and Runtime Survey `_ (we kept it short). Features -------- * `Synchronous `_ and `Asynchronous `_ APIs * `Simple, Prepared, and Batch statements `_ * Asynchronous IO, parallel execution, request pipelining * `Connection pooling `_ * Automatic node discovery * `Automatic reconnection `_ * Configurable `load balancing `_ and `retry policies `_ * `Concurrent execution utilities `_ * `Object mapper `_ Installation ------------ Installation through pip is recommended:: $ pip install cassandra-driver For more complete installation instructions, see the `installation guide `_. Documentation ------------- The documentation can be found online `here `_. A couple of links for getting up to speed: * `Installation `_ * `Getting started guide `_ * `API docs `_ * `Performance tips `_ Object Mapper ------------- cqlengine (originally developed by Blake Eggleston and Jon Haddad, with contributions from the community) is now maintained as an integral part of this package. Refer to `documentation here `_. Contributing ------------ See `CONTRIBUTING.md `_. Reporting Problems ------------------ Please report any bugs and make any feature requests on the `JIRA `_ issue tracker. If you would like to contribute, please feel free to open a pull request. Getting Help ------------ Your best options for getting help with the driver are the `mailing list `_ and the ``#datastax-drivers`` channel in the `DataStax Academy Slack `_. License ------- Copyright 2013-2017 DataStax Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. diff --git a/cassandra/__init__.py b/cassandra/__init__.py index a3936b1..94a2bc9 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -1,688 +1,698 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging class NullHandler(logging.Handler): def emit(self, record): pass logging.getLogger('cassandra').addHandler(NullHandler()) -__version_info__ = (3, 14, 0) +__version_info__ = (3, 16, 0) __version__ = '.'.join(map(str, __version_info__)) class ConsistencyLevel(object): """ Specifies how many replicas must respond for an operation to be considered a success. By default, ``ONE`` is used for all operations. """ ANY = 0 """ Only requires that one replica receives the write *or* the coordinator stores a hint to replay later. Valid only for writes. """ ONE = 1 """ Only one replica needs to respond to consider the operation a success """ TWO = 2 """ Two replicas must respond to consider the operation a success """ THREE = 3 """ Three replicas must respond to consider the operation a success """ QUORUM = 4 """ ``ceil(RF/2)`` replicas must respond to consider the operation a success """ ALL = 5 """ All replicas must respond to consider the operation a success """ LOCAL_QUORUM = 6 """ Requires a quorum of replicas in the local datacenter """ EACH_QUORUM = 7 """ Requires a quorum of replicas in each datacenter """ SERIAL = 8 """ For conditional inserts/updates that utilize Cassandra's lightweight transactions, this requires consensus among all replicas for the modified data. """ LOCAL_SERIAL = 9 """ Like :attr:`~ConsistencyLevel.SERIAL`, but only requires consensus among replicas in the local datacenter. """ LOCAL_ONE = 10 """ Sends a request only to replicas in the local datacenter and waits for one response. 
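As a minimal sketch of how the consistency levels above are applied in practice; the contact point, keyspace, and table names are placeholders, and the per-statement setting is how the ``ONE`` default is overridden::

    from cassandra import ConsistencyLevel
    from cassandra.cluster import Cluster
    from cassandra.query import SimpleStatement

    cluster = Cluster(['127.0.0.1'])        # placeholder contact point
    session = cluster.connect('examples')   # placeholder keyspace

    # Require a quorum of replicas in the local datacenter for this one read;
    # statements that don't set a level fall back to the session/profile default.
    stmt = SimpleStatement("SELECT * FROM users WHERE id = %s",
                           consistency_level=ConsistencyLevel.LOCAL_QUORUM)
    rows = session.execute(stmt, (42,))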
""" ConsistencyLevel.value_to_name = { ConsistencyLevel.ANY: 'ANY', ConsistencyLevel.ONE: 'ONE', ConsistencyLevel.TWO: 'TWO', ConsistencyLevel.THREE: 'THREE', ConsistencyLevel.QUORUM: 'QUORUM', ConsistencyLevel.ALL: 'ALL', ConsistencyLevel.LOCAL_QUORUM: 'LOCAL_QUORUM', ConsistencyLevel.EACH_QUORUM: 'EACH_QUORUM', ConsistencyLevel.SERIAL: 'SERIAL', ConsistencyLevel.LOCAL_SERIAL: 'LOCAL_SERIAL', ConsistencyLevel.LOCAL_ONE: 'LOCAL_ONE' } ConsistencyLevel.name_to_value = { 'ANY': ConsistencyLevel.ANY, 'ONE': ConsistencyLevel.ONE, 'TWO': ConsistencyLevel.TWO, 'THREE': ConsistencyLevel.THREE, 'QUORUM': ConsistencyLevel.QUORUM, 'ALL': ConsistencyLevel.ALL, 'LOCAL_QUORUM': ConsistencyLevel.LOCAL_QUORUM, 'EACH_QUORUM': ConsistencyLevel.EACH_QUORUM, 'SERIAL': ConsistencyLevel.SERIAL, 'LOCAL_SERIAL': ConsistencyLevel.LOCAL_SERIAL, 'LOCAL_ONE': ConsistencyLevel.LOCAL_ONE } def consistency_value_to_name(value): return ConsistencyLevel.value_to_name[value] if value is not None else "Not Set" class ProtocolVersion(object): """ Defines native protocol versions supported by this driver. """ V1 = 1 """ v1, supported in Cassandra 1.2-->2.2 """ V2 = 2 """ v2, supported in Cassandra 2.0-->2.2; added support for lightweight transactions, batch operations, and automatic query paging. """ V3 = 3 """ v3, supported in Cassandra 2.1-->3.x+; added support for protocol-level client-side timestamps (see :attr:`.Session.use_client_timestamp`), serial consistency levels for :class:`~.BatchStatement`, and an improved connection pool. """ V4 = 4 """ v4, supported in Cassandra 2.2-->3.x+; added a number of new types, server warnings, new failure messages, and custom payloads. Details in the `project docs `_ """ V5 = 5 """ v5, in beta from 3.x+ """ SUPPORTED_VERSIONS = (V5, V4, V3, V2, V1) """ A tuple of all supported protocol versions """ BETA_VERSIONS = (V5,) """ A tuple of all beta protocol versions """ MIN_SUPPORTED = min(SUPPORTED_VERSIONS) """ Minimum protocol version supported by this driver. """ MAX_SUPPORTED = max(SUPPORTED_VERSIONS) """ Maximum protocol versioni supported by this driver. """ @classmethod def get_lower_supported(cls, previous_version): """ Return the lower supported protocol version. Beta versions are omitted. """ try: version = next(v for v in sorted(ProtocolVersion.SUPPORTED_VERSIONS, reverse=True) if v not in ProtocolVersion.BETA_VERSIONS and v < previous_version) except StopIteration: version = 0 return version @classmethod def uses_int_query_flags(cls, version): return version >= cls.V5 @classmethod def uses_prepare_flags(cls, version): return version >= cls.V5 @classmethod def uses_prepared_metadata(cls, version): return version >= cls.V5 @classmethod def uses_error_code_map(cls, version): return version >= cls.V5 @classmethod def uses_keyspace_flag(cls, version): return version >= cls.V5 class WriteType(object): """ For usage with :class:`.RetryPolicy`, this describe a type of write operation. """ SIMPLE = 0 """ A write to a single partition key. Such writes are guaranteed to be atomic and isolated. """ BATCH = 1 """ A write to multiple partition keys that used the distributed batch log to ensure atomicity. """ UNLOGGED_BATCH = 2 """ A write to multiple partition keys that did not use the distributed batch log. Atomicity for such writes is not guaranteed. """ COUNTER = 3 """ A counter write (for one or multiple partition keys). Such writes should not be replayed in order to avoid overcount. 
""" BATCH_LOG = 4 """ The initial write to the distributed batch log that Cassandra performs internally before a BATCH write. """ CAS = 5 """ A lighweight-transaction write, such as "DELETE ... IF EXISTS". """ VIEW = 6 """ This WriteType is only seen in results for requests that were unable to complete MV operations. """ CDC = 7 """ This WriteType is only seen in results for requests that were unable to complete CDC operations. """ WriteType.name_to_value = { 'SIMPLE': WriteType.SIMPLE, 'BATCH': WriteType.BATCH, 'UNLOGGED_BATCH': WriteType.UNLOGGED_BATCH, 'COUNTER': WriteType.COUNTER, 'BATCH_LOG': WriteType.BATCH_LOG, 'CAS': WriteType.CAS, 'VIEW': WriteType.VIEW, 'CDC': WriteType.CDC } WriteType.value_to_name = {v: k for k, v in WriteType.name_to_value.items()} class SchemaChangeType(object): DROPPED = 'DROPPED' CREATED = 'CREATED' UPDATED = 'UPDATED' class SchemaTargetType(object): KEYSPACE = 'KEYSPACE' TABLE = 'TABLE' TYPE = 'TYPE' FUNCTION = 'FUNCTION' AGGREGATE = 'AGGREGATE' class SignatureDescriptor(object): def __init__(self, name, argument_types): self.name = name self.argument_types = argument_types @property def signature(self): """ function signature string in the form 'name([type0[,type1[...]]])' can be used to uniquely identify overloaded function names within a keyspace """ return self.format_signature(self.name, self.argument_types) @staticmethod def format_signature(name, argument_types): return "%s(%s)" % (name, ','.join(t for t in argument_types)) def __repr__(self): return "%s(%s, %s)" % (self.__class__.__name__, self.name, self.argument_types) class UserFunctionDescriptor(SignatureDescriptor): """ Describes a User function by name and argument signature """ name = None """ name of the function """ argument_types = None """ Ordered list of CQL argument type names comprising the type signature """ class UserAggregateDescriptor(SignatureDescriptor): """ Describes a User aggregate function by name and argument signature """ name = None """ name of the aggregate """ argument_types = None """ Ordered list of CQL argument type names comprising the type signature """ class DriverException(Exception): """ Base for all exceptions explicitly raised by the driver. """ pass class RequestExecutionException(DriverException): """ Base for request execution exceptions returned from the server. """ pass class Unavailable(RequestExecutionException): """ There were not enough live replicas to satisfy the requested consistency level, so the coordinator node immediately failed the request without forwarding it to any replicas. """ consistency = None """ The requested :class:`ConsistencyLevel` """ required_replicas = None """ The number of replicas that needed to be live to complete the operation """ alive_replicas = None """ The number of replicas that were actually alive """ def __init__(self, summary_message, consistency=None, required_replicas=None, alive_replicas=None): self.consistency = consistency self.required_replicas = required_replicas self.alive_replicas = alive_replicas Exception.__init__(self, summary_message + ' info=' + repr({'consistency': consistency_value_to_name(consistency), 'required_replicas': required_replicas, 'alive_replicas': alive_replicas})) class Timeout(RequestExecutionException): """ Replicas failed to respond to the coordinator node before timing out. 
""" consistency = None """ The requested :class:`ConsistencyLevel` """ required_responses = None """ The number of required replica responses """ received_responses = None """ The number of replicas that responded before the coordinator timed out the operation """ def __init__(self, summary_message, consistency=None, required_responses=None, received_responses=None, **kwargs): self.consistency = consistency self.required_responses = required_responses self.received_responses = received_responses if "write_type" in kwargs: kwargs["write_type"] = WriteType.value_to_name[kwargs["write_type"]] info = {'consistency': consistency_value_to_name(consistency), 'required_responses': required_responses, 'received_responses': received_responses} info.update(kwargs) Exception.__init__(self, summary_message + ' info=' + repr(info)) class ReadTimeout(Timeout): """ A subclass of :exc:`Timeout` for read operations. This indicates that the replicas failed to respond to the coordinator node before the configured timeout. This timeout is configured in ``cassandra.yaml`` with the ``read_request_timeout_in_ms`` and ``range_request_timeout_in_ms`` options. """ data_retrieved = None """ A boolean indicating whether the requested data was retrieved by the coordinator from any replicas before it timed out the operation """ def __init__(self, message, data_retrieved=None, **kwargs): Timeout.__init__(self, message, **kwargs) self.data_retrieved = data_retrieved class WriteTimeout(Timeout): """ A subclass of :exc:`Timeout` for write operations. This indicates that the replicas failed to respond to the coordinator node before the configured timeout. This timeout is configured in ``cassandra.yaml`` with the ``write_request_timeout_in_ms`` option. """ write_type = None """ The type of write operation, enum on :class:`~cassandra.policies.WriteType` """ def __init__(self, message, write_type=None, **kwargs): kwargs["write_type"] = write_type Timeout.__init__(self, message, **kwargs) self.write_type = write_type class CDCWriteFailure(RequestExecutionException): """ Hit limit on data in CDC folder, writes are rejected """ def __init__(self, message): Exception.__init__(self, message) class CoordinationFailure(RequestExecutionException): """ Replicas sent a failure to the coordinator. """ consistency = None """ The requested :class:`ConsistencyLevel` """ required_responses = None """ The number of required replica responses """ received_responses = None """ The number of replicas that responded before the coordinator timed out the operation """ failures = None """ The number of replicas that sent a failure message """ error_code_map = None """ A map of inet addresses to error codes representing replicas that sent a failure message. Only set when `protocol_version` is 5 or higher. 
""" def __init__(self, summary_message, consistency=None, required_responses=None, received_responses=None, failures=None, error_code_map=None): self.consistency = consistency self.required_responses = required_responses self.received_responses = received_responses self.failures = failures self.error_code_map = error_code_map info_dict = { 'consistency': consistency_value_to_name(consistency), 'required_responses': required_responses, 'received_responses': received_responses, 'failures': failures } if error_code_map is not None: # make error codes look like "0x002a" formatted_map = dict((addr, '0x%04x' % err_code) for (addr, err_code) in error_code_map.items()) info_dict['error_code_map'] = formatted_map Exception.__init__(self, summary_message + ' info=' + repr(info_dict)) class ReadFailure(CoordinationFailure): """ A subclass of :exc:`CoordinationFailure` for read operations. This indicates that the replicas sent a failure message to the coordinator. """ data_retrieved = None """ A boolean indicating whether the requested data was retrieved by the coordinator from any replicas before it timed out the operation """ def __init__(self, message, data_retrieved=None, **kwargs): CoordinationFailure.__init__(self, message, **kwargs) self.data_retrieved = data_retrieved class WriteFailure(CoordinationFailure): """ A subclass of :exc:`CoordinationFailure` for write operations. This indicates that the replicas sent a failure message to the coordinator. """ write_type = None """ The type of write operation, enum on :class:`~cassandra.policies.WriteType` """ def __init__(self, message, write_type=None, **kwargs): CoordinationFailure.__init__(self, message, **kwargs) self.write_type = write_type class FunctionFailure(RequestExecutionException): """ User Defined Function failed during execution """ keyspace = None """ Keyspace of the function """ function = None """ Name of the function """ arg_types = None """ List of argument type names of the function """ def __init__(self, summary_message, keyspace, function, arg_types): self.keyspace = keyspace self.function = function self.arg_types = arg_types Exception.__init__(self, summary_message) class RequestValidationException(DriverException): """ Server request validation failed """ pass class ConfigurationException(RequestValidationException): """ Server indicated request errro due to current configuration """ pass class AlreadyExists(ConfigurationException): """ An attempt was made to create a keyspace or table that already exists. """ keyspace = None """ The name of the keyspace that already exists, or, if an attempt was made to create a new table, the keyspace that the table is in. """ table = None """ The name of the table that already exists, or, if an attempt was make to create a keyspace, :const:`None`. """ def __init__(self, keyspace=None, table=None): if table: message = "Table '%s.%s' already exists" % (keyspace, table) else: message = "Keyspace '%s' already exists" % (keyspace,) Exception.__init__(self, message) self.keyspace = keyspace self.table = table class InvalidRequest(RequestValidationException): """ A query was made that was invalid for some reason, such as trying to set the keyspace for a connection to a nonexistent keyspace. """ pass class Unauthorized(RequestValidationException): """ The current user is not authorized to perform the requested operation. """ pass class AuthenticationFailed(DriverException): """ Failed to authenticate. 
""" pass class OperationTimedOut(DriverException): """ The operation took longer than the specified (client-side) timeout to complete. This is not an error generated by Cassandra, only the driver. """ errors = None """ A dict of errors keyed by the :class:`~.Host` against which they occurred. """ last_host = None """ The last :class:`~.Host` this operation was attempted against. """ def __init__(self, errors=None, last_host=None): self.errors = errors self.last_host = last_host message = "errors=%s, last_host=%s" % (self.errors, self.last_host) Exception.__init__(self, message) class UnsupportedOperation(DriverException): """ An attempt was made to use a feature that is not supported by the selected protocol version. See :attr:`Cluster.protocol_version` for more details. """ pass + + +class UnresolvableContactPoints(DriverException): + """ + The driver was unable to resolve any provided hostnames. + + Note that this is *not* raised when a :class:`.Cluster` is created with no + contact points, only when lookup fails for all hosts + """ + pass diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 22ea621..e119605 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1,4373 +1,4409 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ This module houses the main classes you will interact with, :class:`.Cluster` and :class:`.Session`. 
""" from __future__ import absolute_import import atexit from collections import defaultdict, Mapping from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures from copy import copy from functools import partial, wraps from itertools import groupby, count import logging from warnings import warn from random import random import six from six.moves import filter, range, queue as Queue import socket import sys import time from threading import Lock, RLock, Thread, Event import weakref from weakref import WeakValueDictionary try: from weakref import WeakSet except ImportError: from cassandra.util import WeakSet # NOQA from cassandra import (ConsistencyLevel, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, - SchemaTargetType, DriverException, ProtocolVersion) + SchemaTargetType, DriverException, ProtocolVersion, + UnresolvableContactPoints) from cassandra.connection import (ConnectionException, ConnectionShutdown, ConnectionHeartbeat, ProtocolVersionUnsupported) from cassandra.cqltypes import UserType from cassandra.encoder import Encoder from cassandra.protocol import (QueryMessage, ResultMessage, ErrorMessage, ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableErrorMessage, OverloadedErrorMessage, PrepareMessage, ExecuteMessage, PreparedQueryNotFound, IsBootstrappingErrorMessage, BatchMessage, RESULT_KIND_PREPARED, RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS, RESULT_KIND_SCHEMA_CHANGE, ProtocolHandler) from cassandra.metadata import Metadata, protect_name, murmur3 from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy, ExponentialReconnectionPolicy, HostDistance, RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan, NoSpeculativeExecutionPolicy) from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler, HostConnectionPool, HostConnection, NoConnectionsAvailable) from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement, BatchStatement, bind_params, QueryTrace, TraceUnavailable, named_tuple_factory, dict_factory, tuple_factory, FETCH_SIZE_UNSET) from cassandra.timestamps import MonotonicTimestampGenerator def _is_eventlet_monkey_patched(): if 'eventlet.patcher' not in sys.modules: return False import eventlet.patcher return eventlet.patcher.is_monkey_patched('socket') def _is_gevent_monkey_patched(): if 'gevent.monkey' not in sys.modules: return False import gevent.socket return socket.socket is gevent.socket.socket + # default to gevent when we are monkey patched with gevent, eventlet when # monkey patched with eventlet, otherwise if libev is available, use that as # the default because it's fastest. Otherwise, use asyncore. if _is_gevent_monkey_patched(): from cassandra.io.geventreactor import GeventConnection as DefaultConnection elif _is_eventlet_monkey_patched(): from cassandra.io.eventletreactor import EventletConnection as DefaultConnection else: try: from cassandra.io.libevreactor import LibevConnection as DefaultConnection # NOQA except ImportError: from cassandra.io.asyncorereactor import AsyncoreConnection as DefaultConnection # NOQA # Forces load of utf8 encoding module to avoid deadlock that occurs # if code that is being imported tries to import the module in a seperate # thread. 
# See http://bugs.python.org/issue10923 "".encode('utf8') log = logging.getLogger(__name__) DEFAULT_MIN_REQUESTS = 5 DEFAULT_MAX_REQUESTS = 100 DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST = 2 DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST = 8 DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST = 1 DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST = 2 _NOT_SET = object() class NoHostAvailable(Exception): """ Raised when an operation is attempted but all connections are busy, defunct, closed, or resulted in errors when used. """ errors = None """ A map of the form ``{ip: exception}`` which details the particular Exception that was caught for each host the operation was attempted against. """ def __init__(self, message, errors): Exception.__init__(self, message, errors) self.errors = errors def _future_completed(future): """ Helper for run_in_executor() """ exc = future.exception() if exc: log.debug("Failed to run task on executor", exc_info=exc) def run_in_executor(f): """ A decorator to run the given method in the ThreadPoolExecutor. """ @wraps(f) def new_f(self, *args, **kwargs): if self.is_shutdown: return try: future = self.executor.submit(f, self, *args, **kwargs) future.add_done_callback(_future_completed) except Exception: log.exception("Failed to submit task to executor") return new_f _clusters_for_shutdown = set() def _register_cluster_shutdown(cluster): _clusters_for_shutdown.add(cluster) def _discard_cluster_shutdown(cluster): _clusters_for_shutdown.discard(cluster) def _shutdown_clusters(): clusters = _clusters_for_shutdown.copy() # copy because shutdown modifies the global set "discard" for cluster in clusters: cluster.shutdown() + atexit.register(_shutdown_clusters) def default_lbp_factory(): if murmur3 is not None: return TokenAwarePolicy(DCAwareRoundRobinPolicy()) return DCAwareRoundRobinPolicy() +def _addrinfo_or_none(contact_point, port): + """ + A helper function that wraps socket.getaddrinfo and returns None + when it fails to, e.g. resolve one of the hostnames. Used to address + PYTHON-895. + """ + try: + return socket.getaddrinfo(contact_point, port, + socket.AF_UNSPEC, socket.SOCK_STREAM) + except socket.gaierror: + log.debug('Could not resolve hostname "{}" ' + 'with port {}'.format(contact_point, port)) + return None + + +def _resolve_contact_points(contact_points, port): + resolved = tuple(_addrinfo_or_none(p, port) + for p in contact_points) + + if resolved and all((x is None for x in resolved)): + raise UnresolvableContactPoints(contact_points, port) + + resolved = tuple(r for r in resolved if r is not None) + + return [endpoint[4][0] + for addrinfo in resolved + for endpoint in addrinfo] + + class ExecutionProfile(object): load_balancing_policy = None """ An instance of :class:`.policies.LoadBalancingPolicy` or one of its subclasses. Used in determining host distance for establishing connections, and routing requests. Defaults to ``TokenAwarePolicy(DCAwareRoundRobinPolicy())`` if not specified """ retry_policy = None """ An instance of :class:`.policies.RetryPolicy` instance used when :class:`.Statement` objects do not have a :attr:`~.Statement.retry_policy` explicitly set. Defaults to :class:`.RetryPolicy` if not specified """ consistency_level = ConsistencyLevel.LOCAL_ONE """ :class:`.ConsistencyLevel` used when not specified on a :class:`.Statement`. """ serial_consistency_level = None """ Serial :class:`.ConsistencyLevel` used when not specified on a :class:`.Statement` (for LWT conditional statements). 
""" request_timeout = 10.0 """ Request timeout used when not overridden in :meth:`.Session.execute` """ row_factory = staticmethod(tuple_factory) """ A callable to format results, accepting ``(colnames, rows)`` where ``colnames`` is a list of column names, and ``rows`` is a list of tuples, with each tuple representing a row of parsed values. Some example implementations: - :func:`cassandra.query.tuple_factory` - return a result row as a tuple - :func:`cassandra.query.named_tuple_factory` - return a result row as a named tuple - :func:`cassandra.query.dict_factory` - return a result row as a dict - :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict """ speculative_execution_policy = None """ An instance of :class:`.policies.SpeculativeExecutionPolicy` Defaults to :class:`.NoSpeculativeExecutionPolicy` if not specified """ # indicates if lbp was set explicitly or uses default values _load_balancing_policy_explicit = False def __init__(self, load_balancing_policy=_NOT_SET, retry_policy=None, consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None, request_timeout=10.0, row_factory=named_tuple_factory, speculative_execution_policy=None): if load_balancing_policy is _NOT_SET: self._load_balancing_policy_explicit = False self.load_balancing_policy = default_lbp_factory() else: self._load_balancing_policy_explicit = True self.load_balancing_policy = load_balancing_policy self.retry_policy = retry_policy or RetryPolicy() self.consistency_level = consistency_level self.serial_consistency_level = serial_consistency_level self.request_timeout = request_timeout self.row_factory = row_factory self.speculative_execution_policy = speculative_execution_policy or NoSpeculativeExecutionPolicy() class ProfileManager(object): def __init__(self): self.profiles = dict() def _profiles_without_explicit_lbps(self): names = (profile_name for profile_name, profile in self.profiles.items() if not profile._load_balancing_policy_explicit) return tuple( 'EXEC_PROFILE_DEFAULT' if n is EXEC_PROFILE_DEFAULT else n for n in names ) def distance(self, host): distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values()) return HostDistance.LOCAL if HostDistance.LOCAL in distances else \ HostDistance.REMOTE if HostDistance.REMOTE in distances else \ HostDistance.IGNORED def populate(self, cluster, hosts): for p in self.profiles.values(): p.load_balancing_policy.populate(cluster, hosts) def check_supported(self): for p in self.profiles.values(): p.load_balancing_policy.check_supported() def on_up(self, host): for p in self.profiles.values(): p.load_balancing_policy.on_up(host) def on_down(self, host): for p in self.profiles.values(): p.load_balancing_policy.on_down(host) def on_add(self, host): for p in self.profiles.values(): p.load_balancing_policy.on_add(host) def on_remove(self, host): for p in self.profiles.values(): p.load_balancing_policy.on_remove(host) @property def default(self): """ internal-only; no checks are done because this entry is populated on cluster init """ return self.profiles[EXEC_PROFILE_DEFAULT] EXEC_PROFILE_DEFAULT = object() """ Key for the ``Cluster`` default execution profile, used when no other profile is selected in ``Session.execute(execution_profile)``. Use this as the key in ``Cluster(execution_profiles)`` to override the default profile. """ class _ConfigMode(object): UNCOMMITTED = 0 LEGACY = 1 PROFILES = 2 class Cluster(object): """ The main class to use when interacting with a Cassandra cluster. 
Typically, one instance of this class will be created for each separate Cassandra cluster that your application interacts with. Example usage:: >>> from cassandra.cluster import Cluster >>> cluster = Cluster(['192.168.1.1', '192.168.1.2']) >>> session = cluster.connect() >>> session.execute("CREATE KEYSPACE ...") >>> ... >>> cluster.shutdown() ``Cluster`` and ``Session`` also provide context management functions which implicitly handle shutdown when leaving scope. """ contact_points = ['127.0.0.1'] """ The list of contact points to try connecting for cluster discovery. Defaults to loopback interface. Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit local_dc set (as is the default), the DC is chosen from an arbitrary host in contact_points. In this case, contact_points should contain only nodes from a single, local DC. Note: In the next major version, if you specify contact points, you will also be required to also explicitly specify a load-balancing policy. This change will help prevent cases where users had hard-to-debug issues surrounding unintuitive default load-balancing policy behavior. """ # tracks if contact_points was set explicitly or with default values _contact_points_explicit = None port = 9042 """ The server-side port to open connections to. Defaults to 9042. """ cql_version = None """ If a specific version of CQL should be used, this may be set to that string version. Otherwise, the highest CQL version supported by the server will be automatically used. """ protocol_version = ProtocolVersion.V4 """ The maximum version of the native protocol to use. See :class:`.ProtocolVersion` for more information about versions. If not set in the constructor, the driver will automatically downgrade version based on a negotiation with the server, but it is most efficient to set this to the maximum supported by your version of Cassandra. Setting this will also prevent conflicting versions negotiated if your cluster is upgraded. """ allow_beta_protocol_version = False no_compact = False """ Setting true injects a flag in all messages that makes the server accept and use "beta" protocol version. Used for testing new protocol features incrementally before the new version is complete. """ compression = True """ Controls compression for communications between the driver and Cassandra. If left as the default of :const:`True`, either lz4 or snappy compression may be used, depending on what is supported by both the driver and Cassandra. If both are fully supported, lz4 will be preferred. You may also set this to 'snappy' or 'lz4' to request that specific compression type. Setting this to :const:`False` disables compression. """ _auth_provider = None _auth_provider_callable = None @property def auth_provider(self): """ When :attr:`~.Cluster.protocol_version` is 2 or higher, this should be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`, such as :class:`~.PlainTextAuthProvider`. When :attr:`~.Cluster.protocol_version` is 1, this should be a function that accepts one argument, the IP address of a node, and returns a dict of credentials for that node. When not using authentication, this should be left as :const:`None`. 
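A sketch of the ``auth_provider`` contract described above for protocol v2+; the addresses and credentials are placeholders::

    from cassandra.auth import PlainTextAuthProvider
    from cassandra.cluster import Cluster

    auth_provider = PlainTextAuthProvider(username='cassandra',   # placeholder credentials
                                          password='cassandra')
    cluster = Cluster(['10.0.0.1', '10.0.0.2'], auth_provider=auth_provider)
    session = cluster.connect()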
""" return self._auth_provider @auth_provider.setter # noqa def auth_provider(self, value): if not value: self._auth_provider = value return try: self._auth_provider_callable = value.new_authenticator except AttributeError: if self.protocol_version > 1: raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider " "interface when protocol_version >= 2") elif not callable(value): raise TypeError("auth_provider must be callable when protocol_version == 1") self._auth_provider_callable = value self._auth_provider = value _load_balancing_policy = None @property def load_balancing_policy(self): """ An instance of :class:`.policies.LoadBalancingPolicy` or one of its subclasses. .. versionchanged:: 2.6.0 Defaults to :class:`~.TokenAwarePolicy` (:class:`~.DCAwareRoundRobinPolicy`). when using CPython (where the murmur3 extension is available). :class:`~.DCAwareRoundRobinPolicy` otherwise. Default local DC will be chosen from contact points. **Please see** :class:`~.DCAwareRoundRobinPolicy` **for a discussion on default behavior with respect to DC locality and remote nodes.** """ return self._load_balancing_policy @load_balancing_policy.setter def load_balancing_policy(self, lbp): if self._config_mode == _ConfigMode.PROFILES: raise ValueError("Cannot set Cluster.load_balancing_policy while using Configuration Profiles. Set this in a profile instead.") self._load_balancing_policy = lbp self._config_mode = _ConfigMode.LEGACY @property def _default_load_balancing_policy(self): return self.profile_manager.default.load_balancing_policy reconnection_policy = ExponentialReconnectionPolicy(1.0, 600.0) """ An instance of :class:`.policies.ReconnectionPolicy`. Defaults to an instance of :class:`.ExponentialReconnectionPolicy` with a base delay of one second and a max delay of ten minutes. """ _default_retry_policy = RetryPolicy() @property def default_retry_policy(self): """ A default :class:`.policies.RetryPolicy` instance to use for all :class:`.Statement` objects which do not have a :attr:`~.Statement.retry_policy` explicitly set. """ return self._default_retry_policy @default_retry_policy.setter def default_retry_policy(self, policy): if self._config_mode == _ConfigMode.PROFILES: raise ValueError("Cannot set Cluster.default_retry_policy while using Configuration Profiles. Set this in a profile instead.") self._default_retry_policy = policy self._config_mode = _ConfigMode.LEGACY conviction_policy_factory = SimpleConvictionPolicy """ A factory function which creates instances of :class:`.policies.ConvictionPolicy`. Defaults to :class:`.policies.SimpleConvictionPolicy`. """ address_translator = IdentityTranslator() """ :class:`.policies.AddressTranslator` instance to be used in translating server node addresses to driver connection addresses. """ connect_to_remote_hosts = True """ If left as :const:`True`, hosts that are considered :attr:`~.HostDistance.REMOTE` by the :attr:`~.Cluster.load_balancing_policy` will have a connection opened to them. Otherwise, they will not have a connection opened to them. Note that the default load balancing policy ignores remote hosts by default. .. versionadded:: 2.1.0 """ metrics_enabled = False """ Whether or not metric collection is enabled. If enabled, :attr:`.metrics` will be an instance of :class:`~cassandra.metrics.Metrics`. """ metrics = None """ An instance of :class:`cassandra.metrics.Metrics` if :attr:`.metrics_enabled` is :const:`True`, else :const:`None`. 
""" ssl_options = None """ A optional dict which will be used as kwargs for ``ssl.wrap_socket()`` when new sockets are created. This should be used when client encryption is enabled in Cassandra. By default, a ``ca_certs`` value should be supplied (the value should be a string pointing to the location of the CA certs file), and you probably want to specify ``ssl_version`` as ``ssl.PROTOCOL_TLSv1`` to match Cassandra's default protocol. .. versionchanged:: 3.3.0 In addition to ``wrap_socket`` kwargs, clients may also specify ``'check_hostname': True`` to verify the cert hostname as outlined in RFC 2818 and RFC 6125. Note that this requires the certificate to be transferred, so should almost always require the option ``'cert_reqs': ssl.CERT_REQUIRED``. Note also that this functionality was not built into Python standard library until (2.7.9, 3.2). To enable this mechanism in earlier versions, patch ``ssl.match_hostname`` - with a custom or `back-ported function `_. + with a custom or `back-ported function `_. """ sockopts = None """ An optional list of tuples which will be used as arguments to ``socket.setsockopt()`` for all created sockets. Note: some drivers find setting TCPNODELAY beneficial in the context of their execution model. It was not found generally beneficial for this driver. To try with your own workload, set ``sockopts = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]`` """ max_schema_agreement_wait = 10 """ The maximum duration (in seconds) that the driver will wait for schema agreement across the cluster. Defaults to ten seconds. If set <= 0, the driver will bypass schema agreement waits altogether. """ metadata = None """ An instance of :class:`cassandra.metadata.Metadata`. """ connection_class = DefaultConnection """ This determines what event loop system will be used for managing I/O with Cassandra. These are the current options: * :class:`cassandra.io.asyncorereactor.AsyncoreConnection` * :class:`cassandra.io.libevreactor.LibevConnection` * :class:`cassandra.io.eventletreactor.EventletConnection` (requires monkey-patching - see doc for details) * :class:`cassandra.io.geventreactor.GeventConnection` (requires monkey-patching - see doc for details) * :class:`cassandra.io.twistedreactor.TwistedConnection` * EXPERIMENTAL: :class:`cassandra.io.asyncioreactor.AsyncioConnection` By default, ``AsyncoreConnection`` will be used, which uses the ``asyncore`` module in the Python standard library. If ``libev`` is installed, ``LibevConnection`` will be used instead. If ``gevent`` or ``eventlet`` monkey-patching is detected, the corresponding connection class will be used automatically. ``AsyncioConnection``, which uses the ``asyncio`` module in the Python standard library, is also available, but currently experimental. Note that it requires ``asyncio`` features that were only introduced in the 3.4 line in 3.4.6, and in the 3.5 line in 3.5.1. """ control_connection_timeout = 2.0 """ A timeout, in seconds, for queries made by the control connection, such as querying the current schema and information about nodes in the cluster. If set to :const:`None`, there will be no timeout for these queries. """ idle_heartbeat_interval = 30 """ Interval, in seconds, on which to heartbeat idle connections. This helps keep connections open through network devices that expire idle connections. It also helps discover bad connections early in low-traffic scenarios. Setting to zero disables heartbeats. 
""" idle_heartbeat_timeout = 30 """ Timeout, in seconds, on which the heartbeat wait for idle connection responses. Lowering this value can help to discover bad connections earlier. """ schema_event_refresh_window = 2 """ Window, in seconds, within which a schema component will be refreshed after receiving a schema_change event. The driver delays a random amount of time in the range [0.0, window) before executing the refresh. This serves two purposes: 1.) Spread the refresh for deployments with large fanout from C* to client tier, preventing a 'thundering herd' problem with many clients refreshing simultaneously. 2.) Remove redundant refreshes. Redundant events arriving within the delay period are discarded, and only one refresh is executed. Setting this to zero will execute refreshes immediately. Setting this negative will disable schema refreshes in response to push events (refreshes will still occur in response to schema change responses to DDL statements executed by Sessions of this Cluster). """ topology_event_refresh_window = 10 """ Window, in seconds, within which the node and token list will be refreshed after receiving a topology_change event. Setting this to zero will execute refreshes immediately. Setting this negative will disable node refreshes in response to push events. See :attr:`.schema_event_refresh_window` for discussion of rationale """ status_event_refresh_window = 2 """ Window, in seconds, within which the driver will start the reconnect after receiving a status_change event. Setting this to zero will connect immediately. This is primarily used to avoid 'thundering herd' in deployments with large fanout from cluster to clients. When nodes come up, clients attempt to reprepare prepared statements (depending on :attr:`.reprepare_on_up`), and establish connection pools. This can cause a rush of connections and queries if not mitigated with this factor. """ prepare_on_all_hosts = True """ Specifies whether statements should be prepared on all hosts, or just one. This can reasonably be disabled on long-running applications with numerous clients preparing statements on startup, where a randomized initial condition of the load balancing policy can be expected to distribute prepares from different clients across the cluster. """ reprepare_on_up = True """ Specifies whether all known prepared statements should be prepared on a node when it comes up. May be used to avoid overwhelming a node on return, or if it is supposed that the node was only marked down due to network. If statements are not reprepared, they are prepared on the first execution, causing an extra roundtrip for one or more client requests. """ connect_timeout = 5 """ Timeout, in seconds, for creating new connections. This timeout covers the entire connection negotiation, including TCP establishment, options passing, and authentication. """ timestamp_generator = None """ An object, shared between all sessions created by this cluster instance, that generates timestamps when client-side timestamp generation is enabled. By default, each :class:`Cluster` uses a new :class:`~.MonotonicTimestampGenerator`. Applications can set this value for custom timestamp behavior. See the documentation for :meth:`Session.timestamp_generator`. """ @property def schema_metadata_enabled(self): """ Flag indicating whether internal schema metadata is updated. When disabled, the driver does not populate Cluster.metadata.keyspaces on connect, or on schema change events. 
This can be used to speed initial connection, and reduce load on client and server during operation. Turning this off gives away token aware request routing, and programmatic inspection of the metadata model. """ return self.control_connection._schema_meta_enabled @schema_metadata_enabled.setter def schema_metadata_enabled(self, enabled): self.control_connection._schema_meta_enabled = bool(enabled) @property def token_metadata_enabled(self): """ Flag indicating whether internal token metadata is updated. When disabled, the driver does not query node token information on connect, or on topology change events. This can be used to speed initial connection, and reduce load on client and server during operation. It is most useful in large clusters using vnodes, where the token map can be expensive to compute. Turning this off gives away token aware request routing, and programmatic inspection of the token ring. """ return self.control_connection._token_meta_enabled @token_metadata_enabled.setter def token_metadata_enabled(self, enabled): self.control_connection._token_meta_enabled = bool(enabled) profile_manager = None _config_mode = _ConfigMode.UNCOMMITTED sessions = None control_connection = None scheduler = None executor = None is_shutdown = False _is_setup = False _prepared_statements = None _prepared_statement_lock = None _idle_heartbeat = None _protocol_version_explicit = False _discount_down_events = True _user_types = None """ A map of {keyspace: {type_name: UserType}} """ _listeners = None _listener_lock = None def __init__(self, contact_points=_NOT_SET, port=9042, compression=True, auth_provider=None, load_balancing_policy=None, reconnection_policy=None, default_retry_policy=None, conviction_policy_factory=None, metrics_enabled=False, connection_class=None, ssl_options=None, sockopts=None, cql_version=None, protocol_version=_NOT_SET, executor_threads=2, max_schema_agreement_wait=10, control_connection_timeout=2.0, idle_heartbeat_interval=30, schema_event_refresh_window=2, topology_event_refresh_window=10, connect_timeout=5, schema_metadata_enabled=True, token_metadata_enabled=True, address_translator=None, status_event_refresh_window=2, prepare_on_all_hosts=True, reprepare_on_up=True, execution_profiles=None, allow_beta_protocol_version=False, timestamp_generator=None, idle_heartbeat_timeout=30, no_compact=False): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as extablishing connection pools or refreshing metadata. Any of the mutable Cluster attributes may be set as keyword arguments to the constructor. """ if contact_points is not None: if contact_points is _NOT_SET: self._contact_points_explicit = False contact_points = ['127.0.0.1'] else: self._contact_points_explicit = True if isinstance(contact_points, six.string_types): raise TypeError("contact_points should not be a string, it should be a sequence (e.g. 
list) of strings") if None in contact_points: raise ValueError("contact_points should not contain None (it can resolve to localhost)") self.contact_points = contact_points self.port = port - self.contact_points_resolved = [endpoint[4][0] for a in self.contact_points - for endpoint in socket.getaddrinfo(a, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM)] + self.contact_points_resolved = _resolve_contact_points(self.contact_points, + self.port) self.compression = compression if protocol_version is not _NOT_SET: self.protocol_version = protocol_version self._protocol_version_explicit = True self.allow_beta_protocol_version = allow_beta_protocol_version self.no_compact = no_compact self.auth_provider = auth_provider if load_balancing_policy is not None: if isinstance(load_balancing_policy, type): raise TypeError("load_balancing_policy should not be a class, it should be an instance of that class") self.load_balancing_policy = load_balancing_policy else: self._load_balancing_policy = default_lbp_factory() # set internal attribute to avoid committing to legacy config mode if reconnection_policy is not None: if isinstance(reconnection_policy, type): raise TypeError("reconnection_policy should not be a class, it should be an instance of that class") self.reconnection_policy = reconnection_policy if default_retry_policy is not None: if isinstance(default_retry_policy, type): raise TypeError("default_retry_policy should not be a class, it should be an instance of that class") self.default_retry_policy = default_retry_policy if conviction_policy_factory is not None: if not callable(conviction_policy_factory): raise ValueError("conviction_policy_factory must be callable") self.conviction_policy_factory = conviction_policy_factory if address_translator is not None: if isinstance(address_translator, type): raise TypeError("address_translator should not be a class, it should be an instance of that class") self.address_translator = address_translator if connection_class is not None: self.connection_class = connection_class if timestamp_generator is not None: if not callable(timestamp_generator): raise ValueError("timestamp_generator must be callable") self.timestamp_generator = timestamp_generator else: self.timestamp_generator = MonotonicTimestampGenerator() self.profile_manager = ProfileManager() self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile(self.load_balancing_policy, self.default_retry_policy, Session._default_consistency_level, Session._default_serial_consistency_level, Session._default_timeout, Session._row_factory) # legacy mode if either of these is not default if load_balancing_policy or default_retry_policy: if execution_profiles: raise ValueError("Clusters constructed with execution_profiles should not specify legacy parameters " "load_balancing_policy or default_retry_policy. Configure this in a profile instead.") self._config_mode = _ConfigMode.LEGACY warn("Legacy execution parameters will be removed in 4.0. Consider using " "execution profiles.", DeprecationWarning) else: if execution_profiles: self.profile_manager.profiles.update(execution_profiles) self._config_mode = _ConfigMode.PROFILES if self._contact_points_explicit: if self._config_mode is _ConfigMode.PROFILES: default_lbp_profiles = self.profile_manager._profiles_without_explicit_lbps() if default_lbp_profiles: log.warning( 'Cluster.__init__ called with contact_points ' 'specified, but load-balancing policies are not ' 'specified in some ExecutionProfiles. 
In the next ' 'major version, this will raise an error; please ' 'specify a load-balancing policy. ' '(contact_points = {cp}, ' 'EPs without explicit LBPs = {eps})' ''.format(cp=contact_points, eps=default_lbp_profiles)) else: if load_balancing_policy is None: log.warning( 'Cluster.__init__ called with contact_points ' 'specified, but no load_balancing_policy. In the next ' 'major version, this will raise an error; please ' 'specify a load-balancing policy. ' '(contact_points = {cp}, lbp = {lbp})' ''.format(cp=contact_points, lbp=load_balancing_policy)) self.metrics_enabled = metrics_enabled self.ssl_options = ssl_options self.sockopts = sockopts self.cql_version = cql_version self.max_schema_agreement_wait = max_schema_agreement_wait self.control_connection_timeout = control_connection_timeout self.idle_heartbeat_interval = idle_heartbeat_interval self.idle_heartbeat_timeout = idle_heartbeat_timeout self.schema_event_refresh_window = schema_event_refresh_window self.topology_event_refresh_window = topology_event_refresh_window self.status_event_refresh_window = status_event_refresh_window self.connect_timeout = connect_timeout self.prepare_on_all_hosts = prepare_on_all_hosts self.reprepare_on_up = reprepare_on_up self._listeners = set() self._listener_lock = Lock() # let Session objects be GC'ed (and shutdown) when the user no longer # holds a reference. self.sessions = WeakSet() self.metadata = Metadata() self.control_connection = None self._prepared_statements = WeakValueDictionary() self._prepared_statement_lock = Lock() self._user_types = defaultdict(dict) self._min_requests_per_connection = { HostDistance.LOCAL: DEFAULT_MIN_REQUESTS, HostDistance.REMOTE: DEFAULT_MIN_REQUESTS } self._max_requests_per_connection = { HostDistance.LOCAL: DEFAULT_MAX_REQUESTS, HostDistance.REMOTE: DEFAULT_MAX_REQUESTS } self._core_connections_per_host = { HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST, HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST } self._max_connections_per_host = { HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST, HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST } self.executor = ThreadPoolExecutor(max_workers=executor_threads) self.scheduler = _Scheduler(self.executor) self._lock = RLock() if self.metrics_enabled: from cassandra.metrics import Metrics self.metrics = Metrics(weakref.proxy(self)) self.control_connection = ControlConnection( self, self.control_connection_timeout, self.schema_event_refresh_window, self.topology_event_refresh_window, self.status_event_refresh_window, schema_metadata_enabled, token_metadata_enabled) def register_user_type(self, keyspace, user_type, klass): """ Registers a class to use to represent a particular user-defined type. Query parameters for this user-defined type will be assumed to be instances of `klass`. Result sets for this user-defined type will be instances of `klass`. If no class is registered for a user-defined type, a namedtuple will be used for result sets, and non-prepared statements may not encode parameters for this type correctly. `keyspace` is the name of the keyspace that the UDT is defined in. `user_type` is the string name of the UDT to register the mapping for. `klass` should be a class with attributes whose names match the fields of the user-defined type. The constructor must accepts kwargs for each of the fields in the UDT. This method should only be called after the type has been created within Cassandra. 
Example:: cluster = Cluster(protocol_version=3) session = cluster.connect() session.set_keyspace('mykeyspace') session.execute("CREATE TYPE address (street text, zipcode int)") session.execute("CREATE TABLE users (id int PRIMARY KEY, location address)") # create a class to map to the "address" UDT class Address(object): def __init__(self, street, zipcode): self.street = street self.zipcode = zipcode cluster.register_user_type('mykeyspace', 'address', Address) # insert a row using an instance of Address session.execute("INSERT INTO users (id, location) VALUES (%s, %s)", (0, Address("123 Main St.", 78723))) # results will include Address instances results = session.execute("SELECT * FROM users") row = results[0] print row.id, row.location.street, row.location.zipcode """ if self.protocol_version < 3: log.warning("User Type serialization is only supported in native protocol version 3+ (%d in use). " "CQL encoding for simple statements will still work, but named tuples will " "be returned when reading type %s.%s.", self.protocol_version, keyspace, user_type) self._user_types[keyspace][user_type] = klass for session in tuple(self.sessions): session.user_type_registered(keyspace, user_type, klass) UserType.evict_udt_class(keyspace, user_type) def add_execution_profile(self, name, profile, pool_wait_timeout=5): """ Adds an :class:`.ExecutionProfile` to the cluster. This makes it available for use by ``name`` in :meth:`.Session.execute` and :meth:`.Session.execute_async`. This method will raise if the profile already exists. Normally profiles will be injected at cluster initialization via ``Cluster(execution_profiles)``. This method provides a way of adding them dynamically. Adding a new profile updates the connection pools according to the specified ``load_balancing_policy``. By default, this method will wait up to five seconds for the pool creation to complete, so the profile can be used immediately upon return. This behavior can be controlled using ``pool_wait_timeout`` (see `concurrent.futures.wait `_ for timeout semantics). """ if not isinstance(profile, ExecutionProfile): raise TypeError("profile must be an instance of ExecutionProfile") if self._config_mode == _ConfigMode.LEGACY: raise ValueError("Cannot add execution profiles when legacy parameters are set explicitly.") if name in self.profile_manager.profiles: - raise ValueError("Profile %s already exists") + raise ValueError("Profile {} already exists".format(name)) contact_points_but_no_lbp = ( self._contact_points_explicit and not profile._load_balancing_policy_explicit) if contact_points_but_no_lbp: log.warning( 'Tried to add an ExecutionProfile with name {name}. ' '{self} was explicitly configured with contact_points, but ' '{ep} was not explicitly configured with a ' 'load_balancing_policy. In the next major version, trying to ' 'add an ExecutionProfile without an explicitly configured LBP ' 'to a cluster with explicitly configured contact_points will ' 'raise an exception; please specify a load-balancing policy ' 'in the ExecutionProfile.' 
''.format(name=repr(name), self=self, ep=profile)) self.profile_manager.profiles[name] = profile profile.load_balancing_policy.populate(self, self.metadata.all_hosts()) # on_up after populate allows things like DCA LBP to choose default local dc for host in filter(lambda h: h.is_up, self.metadata.all_hosts()): profile.load_balancing_policy.on_up(host) futures = set() for session in tuple(self.sessions): futures.update(session.update_created_pools()) _, not_done = wait_futures(futures, pool_wait_timeout) if not_done: raise OperationTimedOut("Failed to create all new connection pools in the %ss timeout.") - def get_min_requests_per_connection(self, host_distance): return self._min_requests_per_connection[host_distance] def set_min_requests_per_connection(self, host_distance, min_requests): """ Sets a threshold for concurrent requests per connection, below which connections will be considered for disposal (down to core connections; see :meth:`~Cluster.set_core_connections_per_host`). Pertains to connection pool management in protocol versions {1,2}. """ if self.protocol_version >= 3: raise UnsupportedOperation( "Cluster.set_min_requests_per_connection() only has an effect " "when using protocol_version 1 or 2.") if min_requests < 0 or min_requests > 126 or \ min_requests >= self._max_requests_per_connection[host_distance]: raise ValueError("min_requests must be 0-126 and less than the max_requests for this host_distance (%d)" % (self._min_requests_per_connection[host_distance],)) self._min_requests_per_connection[host_distance] = min_requests def get_max_requests_per_connection(self, host_distance): return self._max_requests_per_connection[host_distance] def set_max_requests_per_connection(self, host_distance, max_requests): """ Sets a threshold for concurrent requests per connection, above which new connections will be created to a host (up to max connections; see :meth:`~Cluster.set_max_connections_per_host`). Pertains to connection pool management in protocol versions {1,2}. """ if self.protocol_version >= 3: raise UnsupportedOperation( "Cluster.set_max_requests_per_connection() only has an effect " "when using protocol_version 1 or 2.") if max_requests < 1 or max_requests > 127 or \ max_requests <= self._min_requests_per_connection[host_distance]: raise ValueError("max_requests must be 1-127 and greater than the min_requests for this host_distance (%d)" % (self._min_requests_per_connection[host_distance],)) self._max_requests_per_connection[host_distance] = max_requests def get_core_connections_per_host(self, host_distance): """ Gets the minimum number of connections per Session that will be opened for each host with :class:`~.HostDistance` equal to `host_distance`. The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for :attr:`~HostDistance.REMOTE`. This property is ignored if :attr:`~.Cluster.protocol_version` is 3 or higher. """ return self._core_connections_per_host[host_distance] def set_core_connections_per_host(self, host_distance, core_connections): """ Sets the minimum number of connections per Session that will be opened for each host with :class:`~.HostDistance` equal to `host_distance`. The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for :attr:`~HostDistance.REMOTE`. Protocol version 1 and 2 are limited in the number of concurrent requests they can send per connection. The driver implements connection pooling to support higher levels of concurrency. 
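For illustration only (a sketch that assumes a contact point at ``127.0.0.1`` and a cluster still speaking protocol version 2; under protocol version 3 or higher these setters raise :exc:`~.UnsupportedOperation`), the per-host pool thresholds might be tuned like this::

    from cassandra.cluster import Cluster
    from cassandra.policies import HostDistance

    cluster = Cluster(['127.0.0.1'], protocol_version=2)
    # keep at least 2 and at most 8 connections open to each local host
    cluster.set_core_connections_per_host(HostDistance.LOCAL, 2)
    cluster.set_max_connections_per_host(HostDistance.LOCAL, 8)
    # open another connection once 100 requests are in flight on one connection
    cluster.set_max_requests_per_connection(HostDistance.LOCAL, 100)
    session = cluster.connect()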
If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this is not supported (there is always one connection per host, unless the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) and using this will result in an :exc:`~.UnsupportedOperation`. """ if self.protocol_version >= 3: raise UnsupportedOperation( "Cluster.set_core_connections_per_host() only has an effect " "when using protocol_version 1 or 2.") old = self._core_connections_per_host[host_distance] self._core_connections_per_host[host_distance] = core_connections if old < core_connections: self._ensure_core_connections() def get_max_connections_per_host(self, host_distance): """ Gets the maximum number of connections per Session that will be opened for each host with :class:`~.HostDistance` equal to `host_distance`. The default is 8 for :attr:`~HostDistance.LOCAL` and 2 for :attr:`~HostDistance.REMOTE`. This property is ignored if :attr:`~.Cluster.protocol_version` is 3 or higher. """ return self._max_connections_per_host[host_distance] def set_max_connections_per_host(self, host_distance, max_connections): """ Sets the maximum number of connections per Session that will be opened for each host with :class:`~.HostDistance` equal to `host_distance`. The default is 8 for :attr:`~HostDistance.LOCAL` and 2 for :attr:`~HostDistance.REMOTE`. If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this is not supported (there is always one connection per host, unless the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) and using this will result in an :exc:`~.UnsupportedOperation`. """ if self.protocol_version >= 3: raise UnsupportedOperation( "Cluster.set_max_connections_per_host() only has an effect " "when using protocol_version 1 or 2.") self._max_connections_per_host[host_distance] = max_connections def connection_factory(self, address, *args, **kwargs): """ Called to create a new connection with proper configuration. Intended for internal use only.
""" kwargs = self._make_connection_kwargs(address, kwargs) return self.connection_class.factory(address, self.connect_timeout, *args, **kwargs) def _make_connection_factory(self, host, *args, **kwargs): kwargs = self._make_connection_kwargs(host.address, kwargs) return partial(self.connection_class.factory, host.address, self.connect_timeout, *args, **kwargs) def _make_connection_kwargs(self, address, kwargs_dict): if self._auth_provider_callable: kwargs_dict.setdefault('authenticator', self._auth_provider_callable(address)) kwargs_dict.setdefault('port', self.port) kwargs_dict.setdefault('compression', self.compression) kwargs_dict.setdefault('sockopts', self.sockopts) kwargs_dict.setdefault('ssl_options', self.ssl_options) kwargs_dict.setdefault('cql_version', self.cql_version) kwargs_dict.setdefault('protocol_version', self.protocol_version) kwargs_dict.setdefault('user_type_map', self._user_types) kwargs_dict.setdefault('allow_beta_protocol_version', self.allow_beta_protocol_version) kwargs_dict.setdefault('no_compact', self.no_compact) return kwargs_dict def protocol_downgrade(self, host_addr, previous_version): if self._protocol_version_explicit: raise DriverException("ProtocolError returned from server while using explicitly set client protocol_version %d" % (previous_version,)) new_version = ProtocolVersion.get_lower_supported(previous_version) if new_version < ProtocolVersion.MIN_SUPPORTED: raise DriverException( "Cannot downgrade protocol version below minimum supported version: %d" % (ProtocolVersion.MIN_SUPPORTED,)) log.warning("Downgrading core protocol version from %d to %d for %s. " "To avoid this, it is best practice to explicitly set Cluster(protocol_version) to the version supported by your cluster. " "http://datastax.github.io/python-driver/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_addr) self.protocol_version = new_version def connect(self, keyspace=None, wait_for_all_pools=False): """ Creates and returns a new :class:`~.Session` object. If `keyspace` is specified, that keyspace will be the default keyspace for operations on the ``Session``. 
""" with self._lock: if self.is_shutdown: raise DriverException("Cluster is already shut down") if not self._is_setup: log.debug("Connecting to cluster, contact points: %s; protocol version: %s", self.contact_points, self.protocol_version) self.connection_class.initialize_reactor() _register_cluster_shutdown(self) for address in self.contact_points_resolved: host, new = self.add_host(address, signal=False) if new: host.set_up() for listener in self.listeners: listener.on_add(host) self.profile_manager.populate( weakref.proxy(self), self.metadata.all_hosts()) self.load_balancing_policy.populate( weakref.proxy(self), self.metadata.all_hosts() ) try: self.control_connection.connect() # we set all contact points up for connecting, but we won't infer state after this for address in self.contact_points_resolved: h = self.metadata.get_host(address) if h and self.profile_manager.distance(h) == HostDistance.IGNORED: h.is_up = None log.debug("Control connection created") except Exception: log.exception("Control connection failed to connect, " "shutting down Cluster:") self.shutdown() raise self.profile_manager.check_supported() # todo: rename this method if self.idle_heartbeat_interval: self._idle_heartbeat = ConnectionHeartbeat( self.idle_heartbeat_interval, self.get_connection_holders, timeout=self.idle_heartbeat_timeout ) self._is_setup = True session = self._new_session(keyspace) if wait_for_all_pools: wait_futures(session._initial_connect_futures) return session def get_connection_holders(self): holders = [] for s in tuple(self.sessions): holders.extend(s.get_pools()) holders.append(self.control_connection) return holders def shutdown(self): """ Closes all sessions and connection associated with this Cluster. To ensure all connections are properly closed, **you should always call shutdown() on a Cluster instance when you are done with it**. Once shutdown, a Cluster should not be used for any purpose. 
""" with self._lock: if self.is_shutdown: return else: self.is_shutdown = True if self._idle_heartbeat: self._idle_heartbeat.stop() self.scheduler.shutdown() self.control_connection.shutdown() for session in tuple(self.sessions): session.shutdown() self.executor.shutdown() _discard_cluster_shutdown(self) def __enter__(self): return self def __exit__(self, *args): self.shutdown() def _new_session(self, keyspace): session = Session(self, self.metadata.all_hosts(), keyspace) self._session_register_user_types(session) self.sessions.add(session) return session def _session_register_user_types(self, session): for keyspace, type_map in six.iteritems(self._user_types): for udt_name, klass in six.iteritems(type_map): session.user_type_registered(keyspace, udt_name, klass) def _cleanup_failed_on_up_handling(self, host): self.profile_manager.on_down(host) self.control_connection.on_down(host) for session in tuple(self.sessions): session.remove_pool(host) self._start_reconnector(host, is_host_addition=False) def _on_up_future_completed(self, host, futures, results, lock, finished_future): with lock: futures.discard(finished_future) try: results.append(finished_future.result()) except Exception as exc: results.append(exc) if futures: return try: # all futures have completed at this point for exc in [f for f in results if isinstance(f, Exception)]: log.error("Unexpected failure while marking node %s up:", host, exc_info=exc) self._cleanup_failed_on_up_handling(host) return if not all(results): log.debug("Connection pool could not be created, not marking node %s up", host) self._cleanup_failed_on_up_handling(host) return log.info("Connection pools established for node %s", host) # mark the host as up and notify all listeners host.set_up() for listener in self.listeners: listener.on_up(host) finally: with host.lock: host._currently_handling_node_up = False # see if there are any pools to add or remove now that the host is marked up for session in tuple(self.sessions): session.update_created_pools() def on_up(self, host): """ Intended for internal use only. 
""" if self.is_shutdown: return log.debug("Waiting to acquire lock for handling up status of node %s", host) with host.lock: if host._currently_handling_node_up: log.debug("Another thread is already handling up status of node %s", host) return if host.is_up: log.debug("Host %s was already marked up", host) return host._currently_handling_node_up = True log.debug("Starting to handle up status of node %s", host) have_future = False futures = set() try: log.info("Host %s may be up; will prepare queries and open connection pool", host) reconnector = host.get_and_set_reconnection_handler(None) if reconnector: log.debug("Now that host %s is up, cancelling the reconnection handler", host) reconnector.cancel() if self.profile_manager.distance(host) != HostDistance.IGNORED: self._prepare_all_queries(host) log.debug("Done preparing all queries for host %s, ", host) for session in tuple(self.sessions): session.remove_pool(host) log.debug("Signalling to load balancing policies that host %s is up", host) self.profile_manager.on_up(host) log.debug("Signalling to control connection that host %s is up", host) self.control_connection.on_up(host) log.debug("Attempting to open new connection pools for host %s", host) futures_lock = Lock() futures_results = [] callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock) for session in tuple(self.sessions): future = session.add_or_renew_pool(host, is_host_addition=False) if future is not None: have_future = True future.add_done_callback(callback) futures.add(future) except Exception: log.exception("Unexpected failure handling node %s being marked up:", host) for future in futures: future.cancel() self._cleanup_failed_on_up_handling(host) with host.lock: host._currently_handling_node_up = False raise else: if not have_future: with host.lock: host.set_up() host._currently_handling_node_up = False # for testing purposes return futures def _start_reconnector(self, host, is_host_addition): if self.profile_manager.distance(host) == HostDistance.IGNORED: return schedule = self.reconnection_policy.new_schedule() # in order to not hold references to this Cluster open and prevent # proper shutdown when the program ends, we'll just make a closure # of the current Cluster attributes to create new Connections with conn_factory = self._make_connection_factory(host) reconnector = _HostReconnectionHandler( host, conn_factory, is_host_addition, self.on_add, self.on_up, self.scheduler, schedule, host.get_and_set_reconnection_handler, new_handler=None) old_reconnector = host.get_and_set_reconnection_handler(reconnector) if old_reconnector: log.debug("Old host reconnector found for %s, cancelling", host) old_reconnector.cancel() log.debug("Starting reconnector for host %s", host) reconnector.start() @run_in_executor def on_down(self, host, is_host_addition, expect_host_to_be_down=False): """ Intended for internal use only. 
""" if self.is_shutdown: return with host.lock: was_up = host.is_up # ignore down signals if we have open pools to the host # this is to avoid closing pools when a control connection host became isolated if self._discount_down_events and self.profile_manager.distance(host) != HostDistance.IGNORED: connected = False for session in tuple(self.sessions): pool_states = session.get_pool_state() pool_state = pool_states.get(host) if pool_state: connected |= pool_state['open_count'] > 0 if connected: return host.set_down() if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): return log.warning("Host %s has been marked down", host) self.profile_manager.on_down(host) self.control_connection.on_down(host) for session in tuple(self.sessions): session.on_down(host) for listener in self.listeners: listener.on_down(host) self._start_reconnector(host, is_host_addition) def on_add(self, host, refresh_nodes=True): if self.is_shutdown: return log.debug("Handling new host %r and notifying listeners", host) distance = self.profile_manager.distance(host) if distance != HostDistance.IGNORED: self._prepare_all_queries(host) log.debug("Done preparing queries for new host %r", host) self.profile_manager.on_add(host) self.control_connection.on_add(host, refresh_nodes) if distance == HostDistance.IGNORED: log.debug("Not adding connection pool for new host %r because the " "load balancing policy has marked it as IGNORED", host) self._finalize_add(host, set_up=False) return futures_lock = Lock() futures_results = [] futures = set() def future_completed(future): with futures_lock: futures.discard(future) try: futures_results.append(future.result()) except Exception as exc: futures_results.append(exc) if futures: return log.debug('All futures have completed for added host %s', host) for exc in [f for f in futures_results if isinstance(f, Exception)]: log.error("Unexpected failure while adding node %s, will not mark up:", host, exc_info=exc) return if not all(futures_results): log.warning("Connection pool could not be created, not marking node %s up", host) return self._finalize_add(host) have_future = False for session in tuple(self.sessions): future = session.add_or_renew_pool(host, is_host_addition=True) if future is not None: have_future = True futures.add(future) future.add_done_callback(future_completed) if not have_future: self._finalize_add(host) def _finalize_add(self, host, set_up=True): if set_up: host.set_up() for listener in self.listeners: listener.on_add(host) # see if there are any pools to add or remove now that the host is marked up for session in tuple(self.sessions): session.update_created_pools() def on_remove(self, host): if self.is_shutdown: return log.debug("Removing host %s", host) host.set_down() self.profile_manager.on_remove(host) for session in tuple(self.sessions): session.on_remove(host) for listener in self.listeners: listener.on_remove(host) self.control_connection.on_remove(host) def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False): is_down = host.signal_connection_failure(connection_exc) if is_down: self.on_down(host, is_host_addition, expect_host_to_be_down) return is_down def add_host(self, address, datacenter=None, rack=None, signal=True, refresh_nodes=True): """ Called when adding initial contact points and when the control connection subsequently discovers a new node. Returns a Host instance, and a flag indicating whether it was new in the metadata. Intended for internal use only. 
""" host, new = self.metadata.add_or_return_host(Host(address, self.conviction_policy_factory, datacenter, rack)) if new and signal: log.info("New Cassandra host %r discovered", host) self.on_add(host, refresh_nodes) return host, new def remove_host(self, host): """ Called when the control connection observes that a node has left the ring. Intended for internal use only. """ if host and self.metadata.remove_host(host): log.info("Cassandra host %s removed", host) self.on_remove(host) def register_listener(self, listener): """ Adds a :class:`cassandra.policies.HostStateListener` subclass instance to the list of listeners to be notified when a host is added, removed, marked up, or marked down. """ with self._listener_lock: self._listeners.add(listener) def unregister_listener(self, listener): """ Removes a registered listener. """ with self._listener_lock: self._listeners.remove(listener) @property def listeners(self): with self._listener_lock: return self._listeners.copy() def _ensure_core_connections(self): """ If any host has fewer than the configured number of core connections open, attempt to open connections until that number is met. """ for session in tuple(self.sessions): for pool in tuple(session._pools.values()): pool.ensure_core_connections() @staticmethod def _validate_refresh_schema(keyspace, table, usertype, function, aggregate): if any((table, usertype, function, aggregate)): if not keyspace: raise ValueError("keyspace is required to refresh specific sub-entity {table, usertype, function, aggregate}") if sum(1 for e in (table, usertype, function) if e) > 1: raise ValueError("{table, usertype, function, aggregate} are mutually exclusive") @staticmethod def _target_type_from_refresh_args(keyspace, table, usertype, function, aggregate): if aggregate: return SchemaTargetType.AGGREGATE elif function: return SchemaTargetType.FUNCTION elif usertype: return SchemaTargetType.TYPE elif table: return SchemaTargetType.TABLE elif keyspace: return SchemaTargetType.KEYSPACE return None def get_control_connection_host(self): """ Returns the control connection host metadata. """ connection = self.control_connection._connection host = connection.host if connection else None return self.metadata.get_host(host) if host else None def refresh_schema_metadata(self, max_schema_agreement_wait=None): """ Synchronously refresh all schema metadata. By default, the timeout for this operation is governed by :attr:`~.Cluster.max_schema_agreement_wait` and :attr:`~.Cluster.control_connection_timeout`. Passing max_schema_agreement_wait here overrides :attr:`~.Cluster.max_schema_agreement_wait`. Setting max_schema_agreement_wait <= 0 will bypass schema agreement and refresh schema immediately. An Exception is raised if schema refresh fails for any reason. """ if not self.control_connection.refresh_schema(schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("Schema metadata was not refreshed. See log for details.") def refresh_keyspace_metadata(self, keyspace, max_schema_agreement_wait=None): """ Synchronously refresh keyspace metadata. This applies to keyspace-level information such as replication and durability settings. It does not refresh tables, types, etc. contained in the keyspace. 
See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.KEYSPACE, keyspace=keyspace, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("Keyspace metadata was not refreshed. See log for details.") def refresh_table_metadata(self, keyspace, table, max_schema_agreement_wait=None): """ Synchronously refresh table metadata. This applies to a table, and any triggers or indexes attached to the table. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=table, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("Table metadata was not refreshed. See log for details.") def refresh_materialized_view_metadata(self, keyspace, view, max_schema_agreement_wait=None): """ Synchronously refresh materialized view metadata. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=view, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("View metadata was not refreshed. See log for details.") def refresh_user_type_metadata(self, keyspace, user_type, max_schema_agreement_wait=None): """ Synchronously refresh user defined type metadata. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TYPE, keyspace=keyspace, type=user_type, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("User Type metadata was not refreshed. See log for details.") def refresh_user_function_metadata(self, keyspace, function, max_schema_agreement_wait=None): """ Synchronously refresh user defined function metadata. ``function`` is a :class:`cassandra.UserFunctionDescriptor`. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.FUNCTION, keyspace=keyspace, function=function, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("User Function metadata was not refreshed. See log for details.") def refresh_user_aggregate_metadata(self, keyspace, aggregate, max_schema_agreement_wait=None): """ Synchronously refresh user defined aggregate metadata. ``aggregate`` is a :class:`cassandra.UserAggregateDescriptor`. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.AGGREGATE, keyspace=keyspace, aggregate=aggregate, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("User Aggregate metadata was not refreshed. See log for details.") def refresh_nodes(self, force_token_rebuild=False): """ Synchronously refresh the node list and token metadata `force_token_rebuild` can be used to rebuild the token map metadata, even if no new nodes are discovered. An Exception is raised if node refresh fails for any reason. 
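For example (a sketch; ``cluster`` is assumed to be connected)::

    cluster.refresh_nodes(force_token_rebuild=True)
    for host in cluster.metadata.all_hosts():
        print("%s dc=%s rack=%s up=%s" % (host.address, host.datacenter, host.rack, host.is_up))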
""" if not self.control_connection.refresh_node_list_and_token_map(force_token_rebuild): raise DriverException("Node list was not refreshed. See log for details.") def set_meta_refresh_enabled(self, enabled): """ *Deprecated:* set :attr:`~.Cluster.schema_metadata_enabled` :attr:`~.Cluster.token_metadata_enabled` instead Sets a flag to enable (True) or disable (False) all metadata refresh queries. This applies to both schema and node topology. Disabling this is useful to minimize refreshes during multiple changes. Meta refresh must be enabled for the driver to become aware of any cluster topology changes or schema updates. """ warn("Cluster.set_meta_refresh_enabled is deprecated and will be removed in 4.0. Set " "Cluster.schema_metadata_enabled and Cluster.token_metadata_enabled instead.", DeprecationWarning) self.schema_metadata_enabled = enabled self.token_metadata_enabled = enabled @classmethod def _send_chunks(cls, connection, host, chunks, set_keyspace=False): for ks_chunk in chunks: messages = [PrepareMessage(query=s.query_string, keyspace=s.keyspace if set_keyspace else None) for s in ks_chunk] # TODO: make this timeout configurable somehow? responses = connection.wait_for_responses(*messages, timeout=5.0, fail_on_error=False) for success, response in responses: if not success: log.debug("Got unexpected response when preparing " "statement on host %s: %r", host, response) def _prepare_all_queries(self, host): if not self._prepared_statements or not self.reprepare_on_up: return log.debug("Preparing all known prepared statements against host %s", host) connection = None try: connection = self.connection_factory(host.address) statements = self._prepared_statements.values() if ProtocolVersion.uses_keyspace_flag(self.protocol_version): # V5 protocol and higher, no need to set the keyspace chunks = [] for i in range(0, len(statements), 10): chunks.append(statements[i:i + 10]) self._send_chunks(connection, host, chunks, True) else: for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace): if keyspace is not None: connection.set_keyspace_blocking(keyspace) # prepare 10 statements at a time ks_statements = list(ks_statements) chunks = [] for i in range(0, len(ks_statements), 10): chunks.append(ks_statements[i:i + 10]) self._send_chunks(connection, host, chunks) log.debug("Done preparing all known prepared statements against host %s", host) except OperationTimedOut as timeout: log.warning("Timed out trying to prepare all statements on host %s: %s", host, timeout) except (ConnectionException, socket.error) as exc: log.warning("Error trying to prepare all statements on host %s: %r", host, exc) except Exception: log.exception("Error trying to prepare all statements on host %s", host) finally: if connection: connection.close() def add_prepared(self, query_id, prepared_statement): with self._prepared_statement_lock: self._prepared_statements[query_id] = prepared_statement class Session(object): """ A collection of connection pools for each host in the cluster. Instances of this class should not be created directly, only using :meth:`.Cluster.connect()`. Queries and statements can be executed through ``Session`` instances using the :meth:`~.Session.execute()` and :meth:`~.Session.execute_async()` methods. 
Example usage:: >>> session = cluster.connect() >>> session.set_keyspace("mykeyspace") >>> session.execute("SELECT * FROM mycf") """ cluster = None hosts = None keyspace = None is_shutdown = False _row_factory = staticmethod(named_tuple_factory) @property def row_factory(self): """ The format to return row results in. By default, each returned row will be a named tuple. You can alternatively use any of the following: - :func:`cassandra.query.tuple_factory` - return a result row as a tuple - :func:`cassandra.query.named_tuple_factory` - return a result row as a named tuple - :func:`cassandra.query.dict_factory` - return a result row as a dict - :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict """ return self._row_factory @row_factory.setter def row_factory(self, rf): self._validate_set_legacy_config('row_factory', rf) _default_timeout = 10.0 @property def default_timeout(self): """ A default timeout, measured in seconds, for queries executed through :meth:`.execute()` or :meth:`.execute_async()`. This default may be overridden with the `timeout` parameter for either of those methods. Setting this to :const:`None` will cause no timeouts to be set by default. Please see :meth:`.ResponseFuture.result` for details on the scope and effect of this timeout. .. versionadded:: 2.0.0 """ return self._default_timeout @default_timeout.setter def default_timeout(self, timeout): self._validate_set_legacy_config('default_timeout', timeout) _default_consistency_level = ConsistencyLevel.LOCAL_ONE @property def default_consistency_level(self): """ *Deprecated:* use execution profiles instead The default :class:`~ConsistencyLevel` for operations executed through this session. This default may be overridden by setting the :attr:`~.Statement.consistency_level` on individual statements. .. versionadded:: 1.2.0 .. versionchanged:: 3.0.0 default changed from ONE to LOCAL_ONE """ return self._default_consistency_level @default_consistency_level.setter def default_consistency_level(self, cl): """ *Deprecated:* use execution profiles instead """ warn("Setting the consistency level at the session level will be removed in 4.0. Consider using " "execution profiles and setting the desired consitency level to the EXEC_PROFILE_DEFAULT profile." , DeprecationWarning) self._validate_set_legacy_config('default_consistency_level', cl) _default_serial_consistency_level = None @property def default_serial_consistency_level(self): """ The default :class:`~ConsistencyLevel` for serial phase of conditional updates executed through this session. This default may be overridden by setting the :attr:`~.Statement.serial_consistency_level` on individual statements. Only valid for ``protocol_version >= 2``. """ return self._default_serial_consistency_level @default_serial_consistency_level.setter def default_serial_consistency_level(self, cl): self._validate_set_legacy_config('default_serial_consistency_level', cl) max_trace_wait = 2.0 """ The maximum amount of time (in seconds) the driver will wait for trace details to be populated server-side for a query before giving up. If the `trace` parameter for :meth:`~.execute()` or :meth:`~.execute_async()` is :const:`True`, the driver will repeatedly attempt to fetch trace details for the query (using exponential backoff) until this limit is hit. If the limit is passed, an error will be logged and the :attr:`.Statement.trace` will be left as :const:`None`. """ default_fetch_size = 5000 """ By default, this many rows will be fetched at a time. 
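As a sketch (``session`` is assumed to be connected, and ``mykeyspace.users`` is an illustrative table), the fetch size can be changed globally or per statement, and the returned result set pages transparently as it is iterated::

    from cassandra.query import SimpleStatement

    session.default_fetch_size = 500   # affects subsequently executed queries
    statement = SimpleStatement("SELECT * FROM mykeyspace.users", fetch_size=50)
    for row in session.execute(statement):
        # additional pages are fetched from the server as iteration proceeds
        print(row)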
Setting this to :const:`None` will disable automatic paging for large query results. The fetch size can be also specified per-query through :attr:`.Statement.fetch_size`. This only takes effect when protocol version 2 or higher is used. See :attr:`.Cluster.protocol_version` for details. .. versionadded:: 2.0.0 """ use_client_timestamp = True """ When using protocol version 3 or higher, write timestamps may be supplied client-side at the protocol level. (Normally they are generated server-side by the coordinator node.) Note that timestamps specified within a CQL query will override this timestamp. .. versionadded:: 2.1.0 """ timestamp_generator = None """ When :attr:`use_client_timestamp` is set, sessions call this object and use the result as the timestamp. (Note that timestamps specified within a CQL query will override this timestamp.) By default, a new :class:`~.MonotonicTimestampGenerator` is created for each :class:`Cluster` instance. Applications can set this value for custom timestamp behavior. For example, an application could share a timestamp generator across :class:`Cluster` objects to guarantee that the application will use unique, increasing timestamps across clusters, or set it to to ``lambda: int(time.time() * 1e6)`` if losing records over clock inconsistencies is acceptable for the application. Custom :attr:`timestamp_generator` s should be callable, and calling them should return an integer representing microseconds since some point in time, typically UNIX epoch. .. versionadded:: 3.8.0 """ - encoder = None """ A :class:`~cassandra.encoder.Encoder` instance that will be used when formatting query parameters for non-prepared statements. This is not used for prepared statements (because prepared statements give the driver more information about what CQL types are expected, allowing it to accept a wider range of python types). The encoder uses a mapping from python types to encoder methods (for specific CQL types). This mapping can be be modified by users as they see fit. Methods of :class:`~cassandra.encoder.Encoder` should be used for mapping values if possible, because they take precautions to avoid injections and properly sanitize data. Example:: cluster = Cluster() session = cluster.connect("mykeyspace") session.encoder.mapping[tuple] = session.encoder.cql_encode_tuple session.execute("CREATE TABLE mytable (k int PRIMARY KEY, col tuple)") session.execute("INSERT INTO mytable (k, col) VALUES (%s, %s)", [0, (123, 'abc')]) .. versionadded:: 2.1.0 """ client_protocol_handler = ProtocolHandler """ Specifies a protocol handler that will be used for client-initiated requests (i.e. no internal driver requests). This can be used to override or extend features such as message or type ser/des. The default pure python implementation is :class:`cassandra.protocol.ProtocolHandler`. When compiled with Cython, there are also built-in faster alternatives. 
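A sketch of opting into one of those alternatives (this assumes the driver was built with its optional Cython extensions; ``LazyProtocolHandler`` may otherwise be ``None``)::

    from cassandra.cluster import Cluster
    from cassandra.protocol import LazyProtocolHandler

    cluster = Cluster(['127.0.0.1'])
    session = cluster.connect()
    # deserialize rows lazily, as the result set is iterated
    session.client_protocol_handler = LazyProtocolHandler
    rows = session.execute("SELECT * FROM system.local")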
See :ref:`faster_deser` """ _lock = None _pools = None _profile_manager = None _metrics = None _request_init_callbacks = None def __init__(self, cluster, hosts, keyspace=None): self.cluster = cluster self.hosts = hosts self.keyspace = keyspace self._lock = RLock() self._pools = {} self._profile_manager = cluster.profile_manager self._metrics = cluster.metrics self._request_init_callbacks = [] self._protocol_version = self.cluster.protocol_version self.encoder = Encoder() # create connection pools in parallel self._initial_connect_futures = set() for host in hosts: future = self.add_or_renew_pool(host, is_host_addition=False) if future: self._initial_connect_futures.add(future) futures = wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED) while futures.not_done and not any(f.result() for f in futures.done): futures = wait_futures(futures.not_done, return_when=FIRST_COMPLETED) if not any(f.result() for f in self._initial_connect_futures): msg = "Unable to connect to any servers" if self.keyspace: msg += " using keyspace '%s'" % self.keyspace raise NoHostAvailable(msg, [h.address for h in hosts]) def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, custom_payload=None, execution_profile=EXEC_PROFILE_DEFAULT, paging_state=None): """ Execute the given query and synchronously wait for the response. If an error is encountered while executing the query, an Exception will be raised. `query` may be a query string or an instance of :class:`cassandra.query.Statement`. `parameters` may be a sequence or dict of parameters to bind. If a sequence is used, ``%s`` should be used the placeholder for each argument. If a dict is used, ``%(name)s`` style placeholders must be used. `timeout` should specify a floating-point timeout (in seconds) after which an :exc:`.OperationTimedOut` exception will be raised if the query has not completed. If not set, the timeout defaults to :attr:`~.Session.default_timeout`. If set to :const:`None`, there is no timeout. Please see :meth:`.ResponseFuture.result` for details on the scope and effect of this timeout. If `trace` is set to :const:`True`, the query will be sent with tracing enabled. The trace details can be obtained using the returned :class:`.ResultSet` object. `custom_payload` is a :ref:`custom_payload` dict to be passed to the server. If `query` is a Statement with its own custom_payload. The message payload will be a union of the two, with the values specified here taking precedence. `execution_profile` is the execution profile to use for this request. It can be a key to a profile configured via :meth:`Cluster.add_execution_profile` or an instance (from :meth:`Session.execution_profile_clone_update`, for example `paging_state` is an optional paging state, reused from a previous :class:`ResultSet`. """ return self.execute_async(query, parameters, trace, custom_payload, timeout, execution_profile, paging_state).result() def execute_async(self, query, parameters=None, trace=False, custom_payload=None, timeout=_NOT_SET, execution_profile=EXEC_PROFILE_DEFAULT, paging_state=None): """ Execute the given query and return a :class:`~.ResponseFuture` object which callbacks may be attached to for asynchronous response delivery. You may also call :meth:`~.ResponseFuture.result()` on the :class:`.ResponseFuture` to synchronously block for results at any time. See :meth:`Session.execute` for parameter definitions. 
Example usage:: >>> session = cluster.connect() >>> future = session.execute_async("SELECT * FROM mycf") >>> def log_results(results): ... for row in results: ... log.info("Results: %s", row) >>> def log_error(exc): >>> log.error("Operation failed: %s", exc) >>> future.add_callbacks(log_results, log_error) Async execution with blocking wait for results:: >>> future = session.execute_async("SELECT * FROM mycf") >>> # do other stuff... >>> try: ... results = future.result() ... except Exception: ... log.exception("Operation failed:") """ future = self._create_response_future(query, parameters, trace, custom_payload, timeout, execution_profile, paging_state) future._protocol_handler = self.client_protocol_handler self._on_request(future) future.send_request() return future def _create_response_future(self, query, parameters, trace, custom_payload, timeout, execution_profile=EXEC_PROFILE_DEFAULT, paging_state=None): """ Returns the ResponseFuture before calling send_request() on it """ prepared_statement = None if isinstance(query, six.string_types): query = SimpleStatement(query) elif isinstance(query, PreparedStatement): query = query.bind(parameters) if self.cluster._config_mode == _ConfigMode.LEGACY: if execution_profile is not EXEC_PROFILE_DEFAULT: raise ValueError("Cannot specify execution_profile while using legacy parameters.") if timeout is _NOT_SET: timeout = self.default_timeout cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level serial_cl = query.serial_consistency_level if query.serial_consistency_level is not None else self.default_serial_consistency_level retry_policy = query.retry_policy or self.cluster.default_retry_policy row_factory = self.row_factory load_balancing_policy = self.cluster.load_balancing_policy spec_exec_policy = None else: execution_profile = self._get_execution_profile(execution_profile) if timeout is _NOT_SET: timeout = execution_profile.request_timeout cl = query.consistency_level if query.consistency_level is not None else execution_profile.consistency_level serial_cl = query.serial_consistency_level if query.serial_consistency_level is not None else execution_profile.serial_consistency_level retry_policy = query.retry_policy or execution_profile.retry_policy row_factory = execution_profile.row_factory load_balancing_policy = execution_profile.load_balancing_policy spec_exec_policy = execution_profile.speculative_execution_policy - fetch_size = query.fetch_size if fetch_size is FETCH_SIZE_UNSET and self._protocol_version >= 2: fetch_size = self.default_fetch_size elif self._protocol_version == 1: fetch_size = None start_time = time.time() if self._protocol_version >= 3 and self.use_client_timestamp: timestamp = self.cluster.timestamp_generator() else: timestamp = None if isinstance(query, SimpleStatement): query_string = query.query_string statement_keyspace = query.keyspace if ProtocolVersion.uses_keyspace_flag(self._protocol_version) else None if parameters: query_string = bind_params(query_string, parameters, self.encoder) message = QueryMessage( query_string, cl, serial_cl, fetch_size, timestamp=timestamp, keyspace=statement_keyspace) elif isinstance(query, BoundStatement): prepared_statement = query.prepared_statement message = ExecuteMessage( prepared_statement.query_id, query.values, cl, serial_cl, fetch_size, timestamp=timestamp, skip_meta=bool(prepared_statement.result_metadata), result_metadata_id=prepared_statement.result_metadata_id) elif isinstance(query, BatchStatement): if 
self._protocol_version < 2: raise UnsupportedOperation( "BatchStatement execution is only supported with protocol version " "2 or higher (supported in Cassandra 2.0 and higher). Consider " "setting Cluster.protocol_version to 2 to support this operation.") statement_keyspace = query.keyspace if ProtocolVersion.uses_keyspace_flag(self._protocol_version) else None message = BatchMessage( query.batch_type, query._statements_and_parameters, cl, serial_cl, timestamp, statement_keyspace) message.tracing = trace message.update_custom_payload(query.custom_payload) message.update_custom_payload(custom_payload) message.allow_beta_protocol_version = self.cluster.allow_beta_protocol_version message.paging_state = paging_state spec_exec_plan = spec_exec_policy.new_plan(query.keyspace or self.keyspace, query) if query.is_idempotent and spec_exec_policy else None return ResponseFuture( self, message, query, timeout, metrics=self._metrics, prepared_statement=prepared_statement, retry_policy=retry_policy, row_factory=row_factory, load_balancer=load_balancing_policy, start_time=start_time, speculative_execution_plan=spec_exec_plan) def _get_execution_profile(self, ep): profiles = self.cluster.profile_manager.profiles try: return ep if isinstance(ep, ExecutionProfile) else profiles[ep] except KeyError: raise ValueError("Invalid execution_profile: '%s'; valid profiles are %s" % (ep, profiles.keys())) def execution_profile_clone_update(self, ep, **kwargs): """ Returns a clone of the ``ep`` profile. ``kwargs`` can be specified to update attributes of the returned profile. This is a shallow clone, so any objects referenced by the profile are shared. This means Load Balancing Policy is maintained by inclusion in the active profiles. It also means updating any other rich objects will be seen by the active profile. In cases where this is not desirable, be sure to replace the instance instead of manipulating the shared object. """ clone = copy(self._get_execution_profile(ep)) for attr, value in kwargs.items(): setattr(clone, attr, value) return clone def add_request_init_listener(self, fn, *args, **kwargs): """ Adds a callback with arguments to be called when any request is created. It will be invoked as `fn(response_future, *args, **kwargs)` after each client request is created, and before the request is sent\*. This can be used to create extensions by adding result callbacks to the response future. \* where `response_future` is the :class:`.ResponseFuture` for the request. Note that the init callback is done on the client thread creating the request, so you may need to consider synchronization if you have multiple threads. Any callbacks added to the response future will be executed on the event loop thread, so the normal advice about minimizing cycles and avoiding blocking apply (see Note in :meth:`.ResponseFuture.add_callbacks`. See `this example `_ in the source tree for an example. """ self._request_init_callbacks.append((fn, args, kwargs)) def remove_request_init_listener(self, fn, *args, **kwargs): """ Removes a callback and arguments from the list. See :meth:`.Session.add_request_init_listener`. 
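A sketch of pairing the two calls (``session`` is assumed to exist; the callback and its argument are illustrative)::

    import logging

    log = logging.getLogger(__name__)

    def tag_request(response_future, job_name):
        # invoked on the client thread creating the request, before it is sent
        response_future.add_callback(lambda rows: log.debug("%s: request completed", job_name))

    session.add_request_init_listener(tag_request, "reporting-job")
    # ... later; the callback and arguments must match what was registered
    session.remove_request_init_listener(tag_request, "reporting-job")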
""" self._request_init_callbacks.remove((fn, args, kwargs)) def _on_request(self, response_future): for fn, args, kwargs in self._request_init_callbacks: fn(response_future, *args, **kwargs) def prepare(self, query, custom_payload=None, keyspace=None): """ Prepares a query string, returning a :class:`~cassandra.query.PreparedStatement` instance which can be used as follows:: >>> session = cluster.connect("mykeyspace") >>> query = "INSERT INTO users (id, name, age) VALUES (?, ?, ?)" >>> prepared = session.prepare(query) >>> session.execute(prepared, (user.id, user.name, user.age)) Or you may bind values to the prepared statement ahead of time:: >>> prepared = session.prepare(query) >>> bound_stmt = prepared.bind((user.id, user.name, user.age)) >>> session.execute(bound_stmt) Of course, prepared statements may (and should) be reused:: >>> prepared = session.prepare(query) >>> for user in users: ... bound = prepared.bind((user.id, user.name, user.age)) ... session.execute(bound) Alternatively, if :attr:`~.Cluster.protocol_version` is 5 or higher (requires Cassandra 4.0+), the keyspace can be specified as a parameter. This will allow you to avoid specifying the keyspace in the query without specifying a keyspace in :meth:`~.Cluster.connect`. It even will let you prepare and use statements against a keyspace other than the one originally specified on connection: >>> analyticskeyspace_prepared = session.prepare( ... "INSERT INTO user_activity id, last_activity VALUES (?, ?)", ... keyspace="analyticskeyspace") # note the different keyspace **Important**: PreparedStatements should be prepared only once. Preparing the same query more than once will likely affect performance. `custom_payload` is a key value map to be passed along with the prepare message. See :ref:`custom_payload`. """ message = PrepareMessage(query=query, keyspace=keyspace) future = ResponseFuture(self, message, query=None, timeout=self.default_timeout) try: future.send_request() query_id, bind_metadata, pk_indexes, result_metadata, result_metadata_id = future.result() except Exception: log.exception("Error preparing query:") raise prepared_keyspace = keyspace if keyspace else None prepared_statement = PreparedStatement.from_message( query_id, bind_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace, self._protocol_version, result_metadata, result_metadata_id) prepared_statement.custom_payload = future.custom_payload self.cluster.add_prepared(query_id, prepared_statement) if self.cluster.prepare_on_all_hosts: host = future._current_host try: self.prepare_on_all_hosts(prepared_statement.query_string, host, prepared_keyspace) except Exception: log.exception("Error preparing query on all hosts:") return prepared_statement def prepare_on_all_hosts(self, query, excluded_host, keyspace=None): """ Prepare the given query on all hosts, excluding ``excluded_host``. Intended for internal use only. """ futures = [] for host in tuple(self._pools.keys()): if host != excluded_host and host.is_up: future = ResponseFuture(self, PrepareMessage(query=query, keyspace=keyspace), None, self.default_timeout) # we don't care about errors preparing against specific hosts, # since we can always prepare them as needed when the prepared # statement is used. Just log errors and continue on. 
try: request_id = future._query(host) except Exception: log.exception("Error preparing query for host %s:", host) continue if request_id is None: # the error has already been logged by ResponsFuture log.debug("Failed to prepare query for host %s: %r", host, future._errors.get(host)) continue futures.append((host, future)) for host, future in futures: try: future.result() except Exception: log.exception("Error preparing query for host %s:", host) def shutdown(self): """ Close all connections. ``Session`` instances should not be used for any purpose after being shutdown. """ with self._lock: if self.is_shutdown: return else: self.is_shutdown = True # PYTHON-673. If shutdown was called shortly after session init, avoid # a race by cancelling any initial connection attempts haven't started, # then blocking on any that have. for future in self._initial_connect_futures: future.cancel() wait_futures(self._initial_connect_futures) for pool in tuple(self._pools.values()): pool.shutdown() def __enter__(self): return self def __exit__(self, *args): self.shutdown() def __del__(self): try: # Ensure all connections are closed, in case the Session object is deleted by the GC self.shutdown() except: # Ignore all errors. Shutdown errors can be caught by the user # when cluster.shutdown() is called explicitly. pass def add_or_renew_pool(self, host, is_host_addition): """ For internal use only. """ distance = self._profile_manager.distance(host) if distance == HostDistance.IGNORED: return None def run_add_or_renew_pool(): try: if self._protocol_version >= 3: new_pool = HostConnection(host, distance, self) else: new_pool = HostConnectionPool(host, distance, self) except AuthenticationFailed as auth_exc: conn_exc = ConnectionException(str(auth_exc), host=host) self.cluster.signal_connection_failure(host, conn_exc, is_host_addition) return False except Exception as conn_exc: log.warning("Failed to create connection pool for new host %s:", host, exc_info=conn_exc) # the host itself will still be marked down, so we need to pass # a special flag to make sure the reconnector is created self.cluster.signal_connection_failure( host, conn_exc, is_host_addition, expect_host_to_be_down=True) return False previous = self._pools.get(host) with self._lock: while new_pool._keyspace != self.keyspace: self._lock.release() set_keyspace_event = Event() errors_returned = [] def callback(pool, errors): errors_returned.extend(errors) set_keyspace_event.set() new_pool._set_keyspace_for_all_conns(self.keyspace, callback) set_keyspace_event.wait(self.cluster.connect_timeout) if not set_keyspace_event.is_set() or errors_returned: log.warning("Failed setting keyspace for pool after keyspace changed during connect: %s", errors_returned) self.cluster.on_down(host, is_host_addition) new_pool.shutdown() self._lock.acquire() return False self._lock.acquire() self._pools[host] = new_pool log.debug("Added pool for host %s to session", host) if previous: previous.shutdown() return True return self.submit(run_add_or_renew_pool) def remove_pool(self, host): pool = self._pools.pop(host, None) if pool: log.debug("Removed connection pool for %r", host) return self.submit(pool.shutdown) else: return None def update_created_pools(self): """ When the set of live nodes change, the loadbalancer will change its mind on host distances. It might change it on the node that came/left but also on other nodes (for instance, if a node dies, another previously ignored node may be now considered). 
This method ensures that all hosts for which a pool should exist have one, and hosts that shouldn't don't. For internal use only. """ futures = set() for host in self.cluster.metadata.all_hosts(): distance = self._profile_manager.distance(host) pool = self._pools.get(host) future = None if not pool or pool.is_shutdown: # we don't eagerly set is_up on previously ignored hosts. None is included here # to allow us to attempt connections to hosts that have gone from ignored to something # else. if distance != HostDistance.IGNORED and host.is_up in (True, None): future = self.add_or_renew_pool(host, False) elif distance != pool.host_distance: # the distance has changed if distance == HostDistance.IGNORED: future = self.remove_pool(host) else: pool.host_distance = distance if future: futures.add(future) return futures def on_down(self, host): """ Called by the parent Cluster instance when a node is marked down. Only intended for internal use. """ future = self.remove_pool(host) if future: future.add_done_callback(lambda f: self.update_created_pools()) def on_remove(self, host): """ Internal """ self.on_down(host) def set_keyspace(self, keyspace): """ Set the default keyspace for all queries made through this Session. This operation blocks until complete. """ self.execute('USE %s' % (protect_name(keyspace),)) def _set_keyspace_for_all_pools(self, keyspace, callback): """ Asynchronously sets the keyspace on all pools. When all pools have set all of their connections, `callback` will be called with a dictionary of all errors that occurred, keyed by the `Host` that they occurred against. """ with self._lock: self.keyspace = keyspace remaining_callbacks = set(self._pools.values()) errors = {} if not remaining_callbacks: callback(errors) return def pool_finished_setting_keyspace(pool, host_errors): remaining_callbacks.remove(pool) if host_errors: errors[pool.host] = host_errors if not remaining_callbacks: callback(host_errors) for pool in tuple(self._pools.values()): pool._set_keyspace_for_all_conns(keyspace, pool_finished_setting_keyspace) def user_type_registered(self, keyspace, user_type, klass): """ Called by the parent Cluster instance when the user registers a new mapping from a user-defined type to a class. Intended for internal use only. """ try: ks_meta = self.cluster.metadata.keyspaces[keyspace] except KeyError: raise UserTypeDoesNotExist( 'Keyspace %s does not exist or has not been discovered by the driver' % (keyspace,)) try: type_meta = ks_meta.user_types[user_type] except KeyError: raise UserTypeDoesNotExist( 'User type %s does not exist in keyspace %s' % (user_type, keyspace)) field_names = type_meta.field_names if six.PY2: # go from unicode to string to avoid decode errors from implicit # decode when formatting non-ascii values field_names = [fn.encode('utf-8') for fn in field_names] def encode(val): return '{ %s }' % ' , '.join('%s : %s' % ( field_name, self.encoder.cql_encode_all_types(getattr(val, field_name, None)) ) for field_name in field_names) self.encoder.mapping[klass] = encode def submit(self, fn, *args, **kwargs): """ Internal """ if not self.is_shutdown: return self.cluster.executor.submit(fn, *args, **kwargs) def get_pool_state(self): return dict((host, pool.get_state()) for host, pool in tuple(self._pools.items())) def get_pools(self): return self._pools.values() def _validate_set_legacy_config(self, attr_name, value): if self.cluster._config_mode == _ConfigMode.PROFILES: raise ValueError("Cannot set Session.%s while using Configuration Profiles. 
Set this in a profile instead." % (attr_name,)) setattr(self, '_' + attr_name, value) self.cluster._config_mode = _ConfigMode.LEGACY class UserTypeDoesNotExist(Exception): """ An attempt was made to use a user-defined type that does not exist. .. versionadded:: 2.1.0 """ pass class _ControlReconnectionHandler(_ReconnectionHandler): """ Internal """ def __init__(self, control_connection, *args, **kwargs): _ReconnectionHandler.__init__(self, *args, **kwargs) self.control_connection = weakref.proxy(control_connection) def try_reconnect(self): return self.control_connection._reconnect_internal() def on_reconnection(self, connection): self.control_connection._set_new_connection(connection) def on_exception(self, exc, next_delay): # TODO only overridden to add logging, so add logging if isinstance(exc, AuthenticationFailed): return False else: log.debug("Error trying to reconnect control connection: %r", exc) return True def _watch_callback(obj_weakref, method_name, *args, **kwargs): """ A callback handler for the ControlConnection that tolerates weak references. """ obj = obj_weakref() if obj is None: return getattr(obj, method_name)(*args, **kwargs) def _clear_watcher(conn, expiring_weakref): """ Called when the ControlConnection object is about to be finalized. This clears watchers on the underlying Connection object. """ try: conn.control_conn_disposed() except ReferenceError: pass class ControlConnection(object): """ Internal """ _SELECT_PEERS = "SELECT * FROM system.peers" _SELECT_PEERS_NO_TOKENS = "SELECT peer, data_center, rack, rpc_address, release_version, schema_version FROM system.peers" _SELECT_LOCAL = "SELECT * FROM system.local WHERE key='local'" _SELECT_LOCAL_NO_TOKENS = "SELECT cluster_name, data_center, rack, partitioner, release_version, schema_version FROM system.local WHERE key='local'" _SELECT_SCHEMA_PEERS = "SELECT peer, rpc_address, schema_version FROM system.peers" _SELECT_SCHEMA_LOCAL = "SELECT schema_version FROM system.local WHERE key='local'" _is_shutdown = False _timeout = None _protocol_version = None _schema_event_refresh_window = None _topology_event_refresh_window = None _status_event_refresh_window = None _schema_meta_enabled = True _token_meta_enabled = True # for testing purposes _time = time def __init__(self, cluster, timeout, schema_event_refresh_window, topology_event_refresh_window, status_event_refresh_window, schema_meta_enabled=True, token_meta_enabled=True): # use a weak reference to allow the Cluster instance to be GC'ed (and # shutdown) since implementing __del__ disables the cycle detector self._cluster = weakref.proxy(cluster) self._connection = None self._timeout = timeout self._schema_event_refresh_window = schema_event_refresh_window self._topology_event_refresh_window = topology_event_refresh_window self._status_event_refresh_window = status_event_refresh_window self._schema_meta_enabled = schema_meta_enabled self._token_meta_enabled = token_meta_enabled self._lock = RLock() self._schema_agreement_lock = Lock() self._reconnection_handler = None self._reconnection_lock = RLock() self._event_schedule_times = {} def connect(self): if self._is_shutdown: return self._protocol_version = self._cluster.protocol_version self._set_new_connection(self._reconnect_internal()) def _set_new_connection(self, conn): """ Replace existing connection (if there is one) and close it. 
""" with self._lock: old = self._connection self._connection = conn if old: log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn) old.close() def _reconnect_internal(self): """ Tries to connect to each host in the query plan until one succeeds or every attempt fails. If successful, a new Connection will be returned. Otherwise, :exc:`NoHostAvailable` will be raised with an "errors" arg that is a dict mapping host addresses to the exception that was raised when an attempt was made to open a connection to that host. """ errors = {} lbp = ( self._cluster.load_balancing_policy if self._cluster._config_mode == _ConfigMode.LEGACY else self._cluster._default_load_balancing_policy ) for host in lbp.make_query_plan(): try: return self._try_connect(host) except ConnectionException as exc: errors[host.address] = exc log.warning("[control connection] Error connecting to %s:", host, exc_info=True) self._cluster.signal_connection_failure(host, exc, is_host_addition=False) except Exception as exc: errors[host.address] = exc log.warning("[control connection] Error connecting to %s:", host, exc_info=True) if self._is_shutdown: raise DriverException("[control connection] Reconnection in progress during shutdown") raise NoHostAvailable("Unable to connect to any servers", errors) def _try_connect(self, host): """ Creates a new Connection, registers for pushed events, and refreshes node/token and schema metadata. """ log.debug("[control connection] Opening new connection to %s", host) while True: try: connection = self._cluster.connection_factory(host.address, is_control_connection=True) if self._is_shutdown: connection.close() raise DriverException("Reconnecting during shutdown") break except ProtocolVersionUnsupported as e: self._cluster.protocol_downgrade(host.address, e.startup_version) log.debug("[control connection] Established new connection %r, " "registering watchers and refreshing schema and topology", connection) # use weak references in both directions # _clear_watcher will be called when this ControlConnection is about to be finalized # _watch_callback will get the actual callback from the Connection and relay it to # this object (after a dereferencing a weakref) self_weakref = weakref.ref(self, partial(_clear_watcher, weakref.proxy(connection))) try: connection.register_watchers({ "TOPOLOGY_CHANGE": partial(_watch_callback, self_weakref, '_handle_topology_change'), "STATUS_CHANGE": partial(_watch_callback, self_weakref, '_handle_status_change'), "SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change') }, register_timeout=self._timeout) sel_peers = self._SELECT_PEERS if self._token_meta_enabled else self._SELECT_PEERS_NO_TOKENS sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS peers_query = QueryMessage(query=sel_peers, consistency_level=ConsistencyLevel.ONE) local_query = QueryMessage(query=sel_local, consistency_level=ConsistencyLevel.ONE) shared_results = connection.wait_for_responses( peers_query, local_query, timeout=self._timeout) self._refresh_node_list_and_token_map(connection, preloaded_results=shared_results) self._refresh_schema(connection, preloaded_results=shared_results, schema_agreement_wait=-1) except Exception: connection.close() raise return connection def reconnect(self): if self._is_shutdown: return self._submit(self._reconnect) def _reconnect(self): log.debug("[control connection] Attempting to reconnect") try: self._set_new_connection(self._reconnect_internal()) except 
NoHostAvailable: # make a retry schedule (which includes backoff) schedule = self._cluster.reconnection_policy.new_schedule() with self._reconnection_lock: # cancel existing reconnection attempts if self._reconnection_handler: self._reconnection_handler.cancel() # when a connection is successfully made, _set_new_connection # will be called with the new connection and then our # _reconnection_handler will be cleared out self._reconnection_handler = _ControlReconnectionHandler( self, self._cluster.scheduler, schedule, self._get_and_set_reconnection_handler, new_handler=None) self._reconnection_handler.start() except Exception: log.debug("[control connection] error reconnecting", exc_info=True) raise def _get_and_set_reconnection_handler(self, new_handler): """ Called by the _ControlReconnectionHandler when a new connection is successfully created. Clears out the _reconnection_handler on this ControlConnection. """ with self._reconnection_lock: old = self._reconnection_handler self._reconnection_handler = new_handler return old def _submit(self, *args, **kwargs): try: if not self._cluster.is_shutdown: return self._cluster.executor.submit(*args, **kwargs) except ReferenceError: pass return None def shutdown(self): # stop trying to reconnect (if we are) with self._reconnection_lock: if self._reconnection_handler: self._reconnection_handler.cancel() with self._lock: if self._is_shutdown: return else: self._is_shutdown = True log.debug("Shutting down control connection") if self._connection: self._connection.close() self._connection = None def refresh_schema(self, force=False, **kwargs): try: if self._connection: return self._refresh_schema(self._connection, force=force, **kwargs) except ReferenceError: pass # our weak reference to the Cluster is no good except Exception: log.debug("[control connection] Error refreshing schema", exc_info=True) self._signal_error() return False def _refresh_schema(self, connection, preloaded_results=None, schema_agreement_wait=None, force=False, **kwargs): if self._cluster.is_shutdown: return False agreed = self.wait_for_schema_agreement(connection, preloaded_results=preloaded_results, wait_time=schema_agreement_wait) if not self._schema_meta_enabled and not force: log.debug("[control connection] Skipping schema refresh because schema metadata is disabled") return False if not agreed: log.debug("Skipping schema refresh due to lack of schema agreement") return False self._cluster.metadata.refresh(connection, self._timeout, **kwargs) return True def refresh_node_list_and_token_map(self, force_token_rebuild=False): try: if self._connection: self._refresh_node_list_and_token_map(self._connection, force_token_rebuild=force_token_rebuild) return True except ReferenceError: pass # our weak reference to the Cluster is no good except Exception: log.debug("[control connection] Error refreshing node list and token map", exc_info=True) self._signal_error() return False def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, force_token_rebuild=False): if preloaded_results: log.debug("[control connection] Refreshing node list and token map using preloaded results") peers_result = preloaded_results[0] local_result = preloaded_results[1] else: cl = ConsistencyLevel.ONE if not self._token_meta_enabled: log.debug("[control connection] Refreshing node list without token map") sel_peers = self._SELECT_PEERS_NO_TOKENS sel_local = self._SELECT_LOCAL_NO_TOKENS else: log.debug("[control connection] Refreshing node list and token map") sel_peers = 
self._SELECT_PEERS sel_local = self._SELECT_LOCAL peers_query = QueryMessage(query=sel_peers, consistency_level=cl) local_query = QueryMessage(query=sel_local, consistency_level=cl) peers_result, local_result = connection.wait_for_responses( peers_query, local_query, timeout=self._timeout) peers_result = dict_factory(*peers_result.results) partitioner = None token_map = {} found_hosts = set() if local_result.results: found_hosts.add(connection.host) local_rows = dict_factory(*(local_result.results)) local_row = local_rows[0] cluster_name = local_row["cluster_name"] self._cluster.metadata.cluster_name = cluster_name partitioner = local_row.get("partitioner") tokens = local_row.get("tokens") host = self._cluster.metadata.get_host(connection.host) if host: datacenter = local_row.get("data_center") rack = local_row.get("rack") self._update_location_info(host, datacenter, rack) host.listen_address = local_row.get("listen_address") host.broadcast_address = local_row.get("broadcast_address") host.release_version = local_row.get("release_version") host.dse_version = local_row.get("dse_version") host.dse_workload = local_row.get("workload") if partitioner and tokens: token_map[host] = tokens # Check metadata.partitioner to see if we haven't built anything yet. If # every node in the cluster was in the contact points, we won't discover # any new nodes, so we need this additional check. (See PYTHON-90) should_rebuild_token_map = force_token_rebuild or self._cluster.metadata.partitioner is None for row in peers_result: addr = self._rpc_from_peer_row(row) tokens = row.get("tokens", None) if 'tokens' in row and not tokens: # it was selected, but empty log.warning("Excluding host (%s) with no tokens in system.peers table of %s." % (addr, connection.host)) continue if addr in found_hosts: log.warning("Found multiple hosts with the same rpc_address (%s). Excluding peer %s", addr, row.get("peer")) continue found_hosts.add(addr) host = self._cluster.metadata.get_host(addr) datacenter = row.get("data_center") rack = row.get("rack") if host is None: log.debug("[control connection] Found new host to connect to: %s", addr) host, _ = self._cluster.add_host(addr, datacenter, rack, signal=True, refresh_nodes=False) should_rebuild_token_map = True else: should_rebuild_token_map |= self._update_location_info(host, datacenter, rack) host.broadcast_address = row.get("peer") host.release_version = row.get("release_version") host.dse_version = row.get("dse_version") host.dse_workload = row.get("workload") if partitioner and tokens: token_map[host] = tokens for old_host in self._cluster.metadata.all_hosts(): if old_host.address != connection.host and old_host.address not in found_hosts: should_rebuild_token_map = True log.debug("[control connection] Removing host not found in peers metadata: %r", old_host) self._cluster.remove_host(old_host) log.debug("[control connection] Finished fetching ring info") if partitioner and should_rebuild_token_map: log.debug("[control connection] Rebuilding token map due to topology changes") self._cluster.metadata.rebuild_token_map(partitioner, token_map) def _update_location_info(self, host, datacenter, rack): if host.datacenter == datacenter and host.rack == rack: return False # If the dc/rack information changes, we need to update the load balancing policy. # For that, we remove and re-add the node against the policy. Not the most elegant, and assumes # that the policy will update correctly, but in practice this should work. 
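The node and token information gathered here is what backs ``Cluster.metadata``, so the effect of a refresh can be observed from application code. A minimal sketch, assuming a node is reachable at ``127.0.0.1`` (the host attributes shown are populated from the ``system.local`` and ``system.peers`` rows parsed above)::

    from cassandra.cluster import Cluster

    cluster = Cluster(['127.0.0.1'])
    session = cluster.connect()

    print(cluster.metadata.cluster_name, cluster.metadata.partitioner)
    for host in cluster.metadata.all_hosts():
        print(host.address, host.datacenter, host.rack, host.release_version, host.is_up)

    cluster.shutdown()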
self._cluster.profile_manager.on_down(host) host.set_location_info(datacenter, rack) self._cluster.profile_manager.on_up(host) return True def _delay_for_event_type(self, event_type, delay_window): # this serves to order processing correlated events (received within the window) # the window and randomization still have the desired effect of skew across client instances next_time = self._event_schedule_times.get(event_type, 0) now = self._time.time() if now <= next_time: this_time = next_time + 0.01 delay = this_time - now else: delay = random() * delay_window this_time = now + delay self._event_schedule_times[event_type] = this_time return delay def _refresh_nodes_if_not_up(self, addr): """ Used to mitigate refreshes for nodes that are already known. Some versions of the server send superfluous NEW_NODE messages in addition to UP events. """ host = self._cluster.metadata.get_host(addr) if not host or not host.is_up: self.refresh_node_list_and_token_map() def _handle_topology_change(self, event): change_type = event["change_type"] addr = self._translate_address(event["address"][0]) if change_type == "NEW_NODE" or change_type == "MOVED_NODE": if self._topology_event_refresh_window >= 0: delay = self._delay_for_event_type('topology_change', self._topology_event_refresh_window) self._cluster.scheduler.schedule_unique(delay, self._refresh_nodes_if_not_up, addr) elif change_type == "REMOVED_NODE": host = self._cluster.metadata.get_host(addr) self._cluster.scheduler.schedule_unique(0, self._cluster.remove_host, host) def _handle_status_change(self, event): change_type = event["change_type"] addr = self._translate_address(event["address"][0]) host = self._cluster.metadata.get_host(addr) if change_type == "UP": delay = self._delay_for_event_type('status_change', self._status_event_refresh_window) if host is None: # this is the first time we've seen the node self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map) else: self._cluster.scheduler.schedule_unique(delay, self._cluster.on_up, host) elif change_type == "DOWN": # Note that there is a slight risk we can receive the event late and thus # mark the host down even though we already had reconnected successfully. # But it is unlikely, and don't have too much consequence since we'll try reconnecting # right away, so we favor the detection to make the Host.is_up more accurate. if host is not None: # this will be run by the scheduler self._cluster.on_down(host, is_host_addition=False) def _translate_address(self, addr): return self._cluster.address_translator.translate(addr) def _handle_schema_change(self, event): if self._schema_event_refresh_window < 0: return delay = self._delay_for_event_type('schema_change', self._schema_event_refresh_window) self._cluster.scheduler.schedule_unique(delay, self.refresh_schema, **event) def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None): total_timeout = wait_time if wait_time is not None else self._cluster.max_schema_agreement_wait if total_timeout <= 0: return True # Each schema change typically generates two schema refreshes, one # from the response type and one from the pushed notification. Holding # a lock is just a simple way to cut down on the number of schema queries # we'll make. 
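The refresh windows and the schema-agreement timeout consulted by these handlers are ordinary ``Cluster`` constructor arguments, so the coalescing behaviour described in the comments above can be tuned per application. A sketch with illustrative values only::

    from cassandra.cluster import Cluster

    cluster = Cluster(
        ['127.0.0.1'],
        schema_event_refresh_window=2,      # coalesce schema events over ~2s; a negative value disables the refresh
        topology_event_refresh_window=10,   # spread topology-triggered refreshes across clients
        status_event_refresh_window=2,
        max_schema_agreement_wait=10,       # seconds to poll for agreement; <= 0 skips the wait entirely
    )
    session = cluster.connect()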
with self._schema_agreement_lock: if self._is_shutdown: return if not connection: connection = self._connection if preloaded_results: log.debug("[control connection] Attempting to use preloaded results for schema agreement") peers_result = preloaded_results[0] local_result = preloaded_results[1] schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.host) if schema_mismatches is None: return True log.debug("[control connection] Waiting for schema agreement") start = self._time.time() elapsed = 0 cl = ConsistencyLevel.ONE schema_mismatches = None while elapsed < total_timeout: peers_query = QueryMessage(query=self._SELECT_SCHEMA_PEERS, consistency_level=cl) local_query = QueryMessage(query=self._SELECT_SCHEMA_LOCAL, consistency_level=cl) try: timeout = min(self._timeout, total_timeout - elapsed) peers_result, local_result = connection.wait_for_responses( peers_query, local_query, timeout=timeout) except OperationTimedOut as timeout: log.debug("[control connection] Timed out waiting for " "response during schema agreement check: %s", timeout) elapsed = self._time.time() - start continue except ConnectionShutdown: if self._is_shutdown: log.debug("[control connection] Aborting wait for schema match due to shutdown") return None else: raise schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.host) if schema_mismatches is None: return True log.debug("[control connection] Schemas mismatched, trying again") self._time.sleep(0.2) elapsed = self._time.time() - start log.warning("Node %s is reporting a schema disagreement: %s", connection.host, schema_mismatches) return False def _get_schema_mismatches(self, peers_result, local_result, local_address): peers_result = dict_factory(*peers_result.results) versions = defaultdict(set) if local_result.results: local_row = dict_factory(*local_result.results)[0] if local_row.get("schema_version"): versions[local_row.get("schema_version")].add(local_address) for row in peers_result: schema_ver = row.get('schema_version') if not schema_ver: continue addr = self._rpc_from_peer_row(row) peer = self._cluster.metadata.get_host(addr) if peer and peer.is_up is not False: versions[schema_ver].add(addr) if len(versions) == 1: log.debug("[control connection] Schemas match") return None return dict((version, list(nodes)) for version, nodes in six.iteritems(versions)) def _rpc_from_peer_row(self, row): addr = row.get("rpc_address") if not addr or addr in ["0.0.0.0", "::"]: addr = row.get("peer") return self._translate_address(addr) def _signal_error(self): with self._lock: if self._is_shutdown: return # try just signaling the cluster, as this will trigger a reconnect # as part of marking the host down if self._connection and self._connection.is_defunct: host = self._cluster.metadata.get_host(self._connection.host) # host may be None if it's already been removed, but that indicates # that errors have already been reported, so we're fine if host: self._cluster.signal_connection_failure( host, self._connection.last_error, is_host_addition=False) return # if the connection is not defunct or the host already left, reconnect # manually self.reconnect() def on_up(self, host): pass def on_down(self, host): conn = self._connection if conn and conn.host == host.address and \ self._reconnection_handler is None: log.debug("[control connection] Control connection host (%s) is " "considered down, starting reconnection", host) # this will result in a task being submitted to the executor to reconnect self.reconnect() def 
on_add(self, host, refresh_nodes=True): if refresh_nodes: self.refresh_node_list_and_token_map(force_token_rebuild=True) def on_remove(self, host): c = self._connection if c and c.host == host.address: log.debug("[control connection] Control connection host (%s) is being removed. Reconnecting", host) # refresh will be done on reconnect self.reconnect() else: self.refresh_node_list_and_token_map(force_token_rebuild=True) def get_connections(self): c = getattr(self, '_connection', None) return [c] if c else [] def return_connection(self, connection): if connection is self._connection and (connection.is_defunct or connection.is_closed): self.reconnect() def _stop_scheduler(scheduler, thread): try: if not scheduler.is_shutdown: scheduler.shutdown() except ReferenceError: pass thread.join() class _Scheduler(Thread): _queue = None _scheduled_tasks = None _executor = None is_shutdown = False def __init__(self, executor): self._queue = Queue.PriorityQueue() self._scheduled_tasks = set() self._count = count() self._executor = executor Thread.__init__(self, name="Task Scheduler") self.daemon = True self.start() def shutdown(self): try: log.debug("Shutting down Cluster Scheduler") except AttributeError: # this can happen on interpreter shutdown pass self.is_shutdown = True self._queue.put_nowait((0, 0, None)) self.join() def schedule(self, delay, fn, *args, **kwargs): self._insert_task(delay, (fn, args, tuple(kwargs.items()))) def schedule_unique(self, delay, fn, *args, **kwargs): task = (fn, args, tuple(kwargs.items())) if task not in self._scheduled_tasks: self._insert_task(delay, task) else: log.debug("Ignoring schedule_unique for already-scheduled task: %r", task) def _insert_task(self, delay, task): if not self.is_shutdown: run_at = time.time() + delay self._scheduled_tasks.add(task) self._queue.put_nowait((run_at, next(self._count), task)) else: log.debug("Ignoring scheduled task after shutdown: %r", task) def run(self): while True: if self.is_shutdown: return try: while True: run_at, i, task = self._queue.get(block=True, timeout=None) if self.is_shutdown: if task: log.debug("Not executing scheduled task due to Scheduler shutdown") return if run_at <= time.time(): self._scheduled_tasks.discard(task) fn, args, kwargs = task kwargs = dict(kwargs) future = self._executor.submit(fn, *args, **kwargs) future.add_done_callback(self._log_if_failed) else: self._queue.put_nowait((run_at, i, task)) break except Queue.Empty: pass time.sleep(0.1) def _log_if_failed(self, future): exc = future.exception() if exc: log.warning( "An internally scheduled tasked failed with an unhandled exception:", exc_info=exc) def refresh_schema_and_set_result(control_conn, response_future, connection, **kwargs): try: log.debug("Refreshing schema in response to schema change. " "%s", kwargs) response_future.is_schema_agreed = control_conn._refresh_schema(connection, **kwargs) except Exception: log.exception("Exception refreshing schema in response to schema change:") response_future.session.submit(control_conn.refresh_schema, **kwargs) finally: response_future._set_final_result(None) class ResponseFuture(object): """ An asynchronous response delivery mechanism that is returned from calls to :meth:`.Session.execute_async()`. There are two ways for results to be delivered: - Synchronously, by calling :meth:`.result()` - Asynchronously, by attaching callback and errback functions via :meth:`.add_callback()`, :meth:`.add_errback()`, and :meth:`.add_callbacks()`. 
""" query = None """ The :class:`~.Statement` instance that is being executed through this :class:`.ResponseFuture`. """ is_schema_agreed = True """ For DDL requests, this may be set ``False`` if the schema agreement poll after the response fails. Always ``True`` for non-DDL requests. """ request_encoded_size = None """ Size of the request message sent """ coordinator_host = None """ The host from which we recieved a response """ attempted_hosts = None """ A list of hosts tried, including all speculative executions, retries, and pages """ session = None row_factory = None message = None default_timeout = None _retry_policy = None _profile_manager = None _req_id = None _final_result = _NOT_SET _col_names = None _col_types = None _final_exception = None _query_traces = None _callbacks = None _errbacks = None _current_host = None _connection = None _query_retries = 0 _start_time = None _metrics = None _paging_state = None _custom_payload = None _warnings = None _timer = None _protocol_handler = ProtocolHandler _spec_execution_plan = NoSpeculativeExecutionPlan() _warned_timeout = False def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None, retry_policy=RetryPolicy(), row_factory=None, load_balancer=None, start_time=None, speculative_execution_plan=None): self.session = session # TODO: normalize handling of retry policy and row factory self.row_factory = row_factory or session.row_factory self._load_balancer = load_balancer or session.cluster._default_load_balancing_policy self.message = message self.query = query self.timeout = timeout self._retry_policy = retry_policy self._metrics = metrics self.prepared_statement = prepared_statement self._callback_lock = Lock() self._start_time = start_time or time.time() self._spec_execution_plan = speculative_execution_plan or self._spec_execution_plan self._make_query_plan() self._event = Event() self._errors = {} self._callbacks = [] self._errbacks = [] self.attempted_hosts = [] self._start_timer() @property def _time_remaining(self): if self.timeout is None: return None return (self._start_time + self.timeout) - time.time() def _start_timer(self): if self._timer is None: spec_delay = self._spec_execution_plan.next_execution(self._current_host) if spec_delay >= 0: if self._time_remaining is None or self._time_remaining > spec_delay: self._timer = self.session.cluster.connection_class.create_timer(spec_delay, self._on_speculative_execute) return if self._time_remaining is not None: self._timer = self.session.cluster.connection_class.create_timer(self._time_remaining, self._on_timeout) def _cancel_timer(self): if self._timer: self._timer.cancel() def _on_timeout(self, _attempts=0): """ Called when the request associated with this ResponseFuture times out. This function may reschedule itself. The ``_attempts`` parameter tracks the number of times this has happened. This parameter should only be set in those cases, where ``_on_timeout`` reschedules itself. 
""" # PYTHON-853: for short timeouts, we sometimes race with our __init__ if self._connection is None and _attempts < 3: self._timer = self.session.cluster.connection_class.create_timer( 0.01, partial(self._on_timeout, _attempts=_attempts + 1) ) return if self._connection is not None: try: self._connection._requests.pop(self._req_id) # This prevents the race condition of the # event loop thread just receiving the waited message # If it arrives after this, it will be ignored except KeyError: return pool = self.session._pools.get(self._current_host) if pool and not pool.is_shutdown: with self._connection.lock: self._connection.request_ids.append(self._req_id) pool.return_connection(self._connection) errors = self._errors if not errors: if self.is_schema_agreed: key = self._current_host.address if self._current_host else 'no host queried before timeout' errors = {key: "Client request timeout. See Session.execute[_async](timeout)"} else: connection = self.session.cluster.control_connection._connection host = connection.host if connection else 'unknown' errors = {host: "Request timed out while waiting for schema agreement. See Session.execute[_async](timeout) and Cluster.max_schema_agreement_wait."} self._set_final_exception(OperationTimedOut(errors, self._current_host)) def _on_speculative_execute(self): self._timer = None if not self._event.is_set(): # PYTHON-836, the speculative queries must be after # the query is sent from the main thread, otherwise the # query from the main thread may raise NoHostAvailable # if the _query_plan has been exhausted by the specualtive queries. # This also prevents a race condition accessing the iterator. # We reschedule this call until the main thread has succeeded # making a query if not self.attempted_hosts: self._timer = self.session.cluster.connection_class.create_timer(0.01, self._on_speculative_execute) return if self._time_remaining is not None: if self._time_remaining <= 0: self._on_timeout() return self.send_request(error_no_hosts=False) self._start_timer() def _make_query_plan(self): # convert the list/generator/etc to an iterator so that subsequent # calls to send_request (which retries may do) will resume where # they last left off self.query_plan = iter(self._load_balancer.make_query_plan(self.session.keyspace, self.query)) def send_request(self, error_no_hosts=True): """ Internal """ # query_plan is an iterator, so this will resume where we last left # off if send_request() is called multiple times for host in self.query_plan: req_id = self._query(host) if req_id is not None: self._req_id = req_id return True if self.timeout is not None and time.time() - self._start_time > self.timeout: self._on_timeout() return True if error_no_hosts: self._set_final_exception(NoHostAvailable( "Unable to complete the operation against any hosts", self._errors)) return False def _query(self, host, message=None, cb=None): if message is None: message = self.message pool = self.session._pools.get(host) if not pool: self._errors[host] = ConnectionException("Host has been marked down or removed") return None elif pool.is_shutdown: self._errors[host] = ConnectionException("Pool is shutdown") return None self._current_host = host connection = None try: # TODO get connectTimeout from cluster settings connection, request_id = pool.borrow_connection(timeout=2.0) self._connection = connection result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] if cb is None: cb = partial(self._set_result, host, connection, pool) 
self.request_encoded_size = connection.send_msg(message, request_id, cb=cb, encoder=self._protocol_handler.encode_message, decoder=self._protocol_handler.decode_message, result_metadata=result_meta) self.attempted_hosts.append(host) return request_id except NoConnectionsAvailable as exc: log.debug("All connections for host %s are at capacity, moving to the next host", host) self._errors[host] = exc return None except Exception as exc: log.debug("Error querying host %s", host, exc_info=True) self._errors[host] = exc if self._metrics is not None: self._metrics.on_connection_error() if connection: pool.return_connection(connection) return None @property def has_more_pages(self): """ Returns :const:`True` if there are more pages left in the query results, :const:`False` otherwise. This should only be checked after the first page has been returned. .. versionadded:: 2.0.0 """ return self._paging_state is not None @property def warnings(self): """ Warnings returned from the server, if any. This will only be set for protocol_version 4+. Warnings may be returned for such things as oversized batches, or too many tombstones in slice queries. Ensure the future is complete before trying to access this property (call :meth:`.result()`, or after callback is invoked). Otherwise it may throw if the response has not been received. """ # TODO: When timers are introduced, just make this wait if not self._event.is_set(): raise DriverException("warnings cannot be retrieved before ResponseFuture is finalized") return self._warnings @property def custom_payload(self): """ The custom payload returned from the server, if any. This will only be set by Cassandra servers implementing a custom QueryHandler, and only for protocol_version 4+. Ensure the future is complete before trying to access this property (call :meth:`.result()`, or after callback is invoked). Otherwise it may throw if the response has not been received. :return: :ref:`custom_payload`. """ # TODO: When timers are introduced, just make this wait if not self._event.is_set(): raise DriverException("custom_payload cannot be retrieved before ResponseFuture is finalized") return self._custom_payload def start_fetching_next_page(self): """ If there are more pages left in the query result, this asynchronously starts fetching the next page. If there are no pages left, :exc:`.QueryExhausted` is raised. Also see :attr:`.has_more_pages`. This should only be called after the first page has been returned. .. 
versionadded:: 2.0.0 """ if not self._paging_state: raise QueryExhausted() self._make_query_plan() self.message.paging_state = self._paging_state self._event.clear() self._final_result = _NOT_SET self._final_exception = None self._start_timer() self.send_request() def _reprepare(self, prepare_message, host, connection, pool): cb = partial(self.session.submit, self._execute_after_prepare, host, connection, pool) request_id = self._query(host, prepare_message, cb=cb) if request_id is None: # try to submit the original prepared statement on some other host self.send_request() def _set_result(self, host, connection, pool, response): try: self.coordinator_host = host if pool: pool.return_connection(connection) trace_id = getattr(response, 'trace_id', None) if trace_id: if not self._query_traces: self._query_traces = [] self._query_traces.append(QueryTrace(trace_id, self.session)) self._warnings = getattr(response, 'warnings', None) self._custom_payload = getattr(response, 'custom_payload', None) if isinstance(response, ResultMessage): if response.kind == RESULT_KIND_SET_KEYSPACE: session = getattr(self, 'session', None) # since we're running on the event loop thread, we need to # use a non-blocking method for setting the keyspace on # all connections in this session, otherwise the event # loop thread will deadlock waiting for keyspaces to be # set. This uses a callback chain which ends with # self._set_keyspace_completed() being called in the # event loop thread. if session: session._set_keyspace_for_all_pools( response.results, self._set_keyspace_completed) elif response.kind == RESULT_KIND_SCHEMA_CHANGE: # refresh the schema before responding, but do it in another # thread instead of the event loop thread self.is_schema_agreed = False self.session.submit( refresh_schema_and_set_result, self.session.cluster.control_connection, self, connection, **response.results) else: results = getattr(response, 'results', None) if results is not None and response.kind == RESULT_KIND_ROWS: self._paging_state = response.paging_state self._col_types = response.col_types self._col_names = results[0] results = self.row_factory(*results) self._set_final_result(results) elif isinstance(response, ErrorMessage): retry_policy = self._retry_policy if isinstance(response, ReadTimeoutErrorMessage): if self._metrics is not None: self._metrics.on_read_timeout() retry = retry_policy.on_read_timeout( self.query, retry_num=self._query_retries, **response.info) elif isinstance(response, WriteTimeoutErrorMessage): if self._metrics is not None: self._metrics.on_write_timeout() retry = retry_policy.on_write_timeout( self.query, retry_num=self._query_retries, **response.info) elif isinstance(response, UnavailableErrorMessage): if self._metrics is not None: self._metrics.on_unavailable() retry = retry_policy.on_unavailable( self.query, retry_num=self._query_retries, **response.info) elif isinstance(response, OverloadedErrorMessage): if self._metrics is not None: self._metrics.on_other_error() # need to retry against a different host here log.warning("Host %s is overloaded, retrying against a different " "host", host) self._retry(reuse_connection=False, consistency_level=None, host=host) return elif isinstance(response, IsBootstrappingErrorMessage): if self._metrics is not None: self._metrics.on_other_error() # need to retry against a different host here self._retry(reuse_connection=False, consistency_level=None, host=host) return elif isinstance(response, PreparedQueryNotFound): if self.prepared_statement: query_id = 
self.prepared_statement.query_id assert query_id == response.info, \ "Got different query ID in server response (%s) than we " \ "had before (%s)" % (response.info, query_id) else: query_id = response.info try: prepared_statement = self.session.cluster._prepared_statements[query_id] except KeyError: if not self.prepared_statement: log.error("Tried to execute unknown prepared statement: id=%s", query_id.encode('hex')) self._set_final_exception(response) return else: prepared_statement = self.prepared_statement self.session.cluster._prepared_statements[query_id] = prepared_statement current_keyspace = self._connection.keyspace prepared_keyspace = prepared_statement.keyspace if not ProtocolVersion.uses_keyspace_flag(self.session.cluster.protocol_version) \ and prepared_keyspace and current_keyspace != prepared_keyspace: self._set_final_exception( ValueError("The Session's current keyspace (%s) does " "not match the keyspace the statement was " "prepared with (%s)" % (current_keyspace, prepared_keyspace))) return log.debug("Re-preparing unrecognized prepared statement against host %s: %s", host, prepared_statement.query_string) prepared_keyspace = prepared_statement.keyspace \ if ProtocolVersion.uses_keyspace_flag(self.session.cluster.protocol_version) else None prepare_message = PrepareMessage(query=prepared_statement.query_string, keyspace=prepared_keyspace) # since this might block, run on the executor to avoid hanging # the event loop thread self.session.submit(self._reprepare, prepare_message, host, connection, pool) return else: if hasattr(response, 'to_exception'): self._set_final_exception(response.to_exception()) else: self._set_final_exception(response) return retry_type, consistency = retry if retry_type in (RetryPolicy.RETRY, RetryPolicy.RETRY_NEXT_HOST): self._query_retries += 1 reuse = retry_type == RetryPolicy.RETRY self._retry(reuse, consistency, host) elif retry_type is RetryPolicy.RETHROW: self._set_final_exception(response.to_exception()) else: # IGNORE if self._metrics is not None: self._metrics.on_ignore() self._set_final_result(None) self._errors[host] = response.to_exception() elif isinstance(response, ConnectionException): if self._metrics is not None: self._metrics.on_connection_error() if not isinstance(response, ConnectionShutdown): self._connection.defunct(response) self._retry(reuse_connection=False, consistency_level=None, host=host) elif isinstance(response, Exception): if hasattr(response, 'to_exception'): self._set_final_exception(response.to_exception()) else: self._set_final_exception(response) else: # we got some other kind of response message msg = "Got unexpected message: %r" % (response,) exc = ConnectionException(msg, host) self._cancel_timer() self._connection.defunct(exc) self._set_final_exception(exc) except Exception as exc: # almost certainly caused by a bug, but we need to set something here log.exception("Unexpected exception while handling result in ResponseFuture:") self._set_final_exception(exc) def _set_keyspace_completed(self, errors): if not errors: self._set_final_result(None) else: self._set_final_exception(ConnectionException( "Failed to set keyspace on all hosts: %s" % (errors,))) def _execute_after_prepare(self, host, connection, pool, response): """ Handle the response to our attempt to prepare a statement. If it succeeded, run the original query again against the same host. 
""" if pool: pool.return_connection(connection) if self._final_exception: return if isinstance(response, ResultMessage): if response.kind == RESULT_KIND_PREPARED: if self.prepared_statement: # result metadata is the only thing that could have # changed from an alter (_, _, _, self.prepared_statement.result_metadata, new_metadata_id) = response.results if new_metadata_id is not None: self.prepared_statement.result_metadata_id = new_metadata_id # use self._query to re-use the same host and # at the same time properly borrow the connection request_id = self._query(host) if request_id is None: # this host errored out, move on to the next self.send_request() else: self._set_final_exception(ConnectionException( "Got unexpected response when preparing statement " "on host %s: %s" % (host, response))) elif isinstance(response, ErrorMessage): if hasattr(response, 'to_exception'): self._set_final_exception(response.to_exception()) else: self._set_final_exception(response) elif isinstance(response, ConnectionException): log.debug("Connection error when preparing statement on host %s: %s", host, response) # try again on a different host, preparing again if necessary self._errors[host] = response self.send_request() else: self._set_final_exception(ConnectionException( "Got unexpected response type when preparing " "statement on host %s: %s" % (host, response))) def _set_final_result(self, response): self._cancel_timer() if self._metrics is not None: self._metrics.request_timer.addValue(time.time() - self._start_time) with self._callback_lock: self._final_result = response # save off current callbacks inside lock for execution outside it # -- prevents case where _final_result is set, then a callback is # added and executed on the spot, then executed again as a # registered callback to_call = tuple( partial(fn, response, *args, **kwargs) for (fn, args, kwargs) in self._callbacks ) self._event.set() # apply each callback for callback_partial in to_call: callback_partial() def _set_final_exception(self, response): self._cancel_timer() if self._metrics is not None: self._metrics.request_timer.addValue(time.time() - self._start_time) with self._callback_lock: self._final_exception = response # save off current errbacks inside lock for execution outside it -- # prevents case where _final_exception is set, then an errback is # added and executed on the spot, then executed again as a # registered errback to_call = tuple( partial(fn, response, *args, **kwargs) for (fn, args, kwargs) in self._errbacks ) self._event.set() # apply each callback for callback_partial in to_call: callback_partial() def _retry(self, reuse_connection, consistency_level, host): if self._final_exception: # the connection probably broke while we were waiting # to retry the operation return if self._metrics is not None: self._metrics.on_retry() if consistency_level is not None: self.message.consistency_level = consistency_level # don't retry on the event loop thread self.session.submit(self._retry_task, reuse_connection, host) def _retry_task(self, reuse_connection, host): if self._final_exception: # the connection probably broke while we were waiting # to retry the operation return if reuse_connection and self._query(host) is not None: return # otherwise, move onto another host self.send_request() def result(self): """ Return the final result or raise an Exception if errors were encountered. If the final result or error has not been set yet, this method will block until it is set, or the timeout set for the request expires. 
Timeout is specified in the Session request execution functions. If the timeout is exceeded, an :exc:`cassandra.OperationTimedOut` will be raised. This is a client-side timeout. For more information about server-side coordinator timeouts, see :class:`.policies.RetryPolicy`. Example usage:: >>> future = session.execute_async("SELECT * FROM mycf") >>> # do other stuff... >>> try: ... rows = future.result() ... for row in rows: ... ... # process results ... except Exception: ... log.exception("Operation failed:") """ self._event.wait() if self._final_result is not _NOT_SET: return ResultSet(self, self._final_result) else: raise self._final_exception def get_query_trace_ids(self): """ Returns the trace session ids for this future, if tracing was enabled (does not fetch trace data). """ return [trace.trace_id for trace in self._query_traces] def get_query_trace(self, max_wait=None, query_cl=ConsistencyLevel.LOCAL_ONE): """ Fetches and returns the query trace of the last response, or `None` if tracing was not enabled. Note that this may raise an exception if there are problems retrieving the trace details from Cassandra. If the trace is not available after `max_wait`, :exc:`cassandra.query.TraceUnavailable` will be raised. If the ResponseFuture is not done (async execution) and you try to retrieve the trace, :exc:`cassandra.query.TraceUnavailable` will be raised. `query_cl` is the consistency level used to poll the trace tables. """ if self._final_result is _NOT_SET and self._final_exception is None: raise TraceUnavailable( "Trace information was not available. The ResponseFuture is not done.") if self._query_traces: return self._get_query_trace(len(self._query_traces) - 1, max_wait, query_cl) def get_all_query_traces(self, max_wait_per=None, query_cl=ConsistencyLevel.LOCAL_ONE): """ Fetches and returns the query traces for all query pages, if tracing was enabled. See note in :meth:`~.get_query_trace` regarding possible exceptions. """ if self._query_traces: return [self._get_query_trace(i, max_wait_per, query_cl) for i in range(len(self._query_traces))] return [] def _get_query_trace(self, i, max_wait, query_cl): trace = self._query_traces[i] if not trace.events: trace.populate(max_wait=max_wait, query_cl=query_cl) return trace def add_callback(self, fn, *args, **kwargs): """ Attaches a callback function to be called when the final results arrive. By default, `fn` will be called with the results as the first and only argument. If `*args` or `**kwargs` are supplied, they will be passed through as additional positional or keyword arguments to `fn`. If an error is hit while executing the operation, a callback attached here will not be called. Use :meth:`.add_errback()` or :meth:`add_callbacks()` if you wish to handle that case. If the final result has already been seen when this method is called, the callback will be called immediately (before this method returns). Note: in the case that the result is not available when the callback is added, the callback is executed by IO event thread. This means that the callback should not block or attempt further synchronous requests, because no further IO will be processed until the callback returns. **Important**: if the callback you attach results in an exception being raised, **the exception will be ignored**, so please ensure your callback handles all error cases that you care about. Usage example:: >>> session = cluster.connect("mykeyspace") >>> def handle_results(rows, start_time, should_log=False): ... if should_log: ... 
log.info("Total time: %f", time.time() - start_time) ... ... >>> future = session.execute_async("SELECT * FROM users") >>> future.add_callback(handle_results, time.time(), should_log=True) """ run_now = False with self._callback_lock: # Always add fn to self._callbacks, even when we're about to # execute it, to prevent races with functions like # start_fetching_next_page that reset _final_result self._callbacks.append((fn, args, kwargs)) if self._final_result is not _NOT_SET: run_now = True if run_now: fn(self._final_result, *args, **kwargs) return self def add_errback(self, fn, *args, **kwargs): """ Like :meth:`.add_callback()`, but handles error cases. An Exception instance will be passed as the first positional argument to `fn`. """ run_now = False with self._callback_lock: # Always add fn to self._errbacks, even when we're about to execute # it, to prevent races with functions like start_fetching_next_page # that reset _final_exception self._errbacks.append((fn, args, kwargs)) if self._final_exception: run_now = True if run_now: fn(self._final_exception, *args, **kwargs) return self def add_callbacks(self, callback, errback, callback_args=(), callback_kwargs=None, errback_args=(), errback_kwargs=None): """ A convenient combination of :meth:`.add_callback()` and :meth:`.add_errback()`. Example usage:: >>> session = cluster.connect() >>> query = "SELECT * FROM mycf" >>> future = session.execute_async(query) >>> def log_results(results, level='debug'): ... for row in results: ... log.log(level, "Result: %s", row) >>> def log_error(exc, query): ... log.error("Query '%s' failed: %s", query, exc) >>> future.add_callbacks( ... callback=log_results, callback_kwargs={'level': 'info'}, ... errback=log_error, errback_args=(query,)) """ self.add_callback(callback, *callback_args, **(callback_kwargs or {})) self.add_errback(errback, *errback_args, **(errback_kwargs or {})) def clear_callbacks(self): with self._callback_lock: self._callbacks = [] self._errbacks = [] def __str__(self): result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result return "" \ % (self.query, self._req_id, result, self._final_exception, self.coordinator_host) __repr__ = __str__ class QueryExhausted(Exception): """ Raised when :meth:`.ResponseFuture.start_fetching_next_page()` is called and there are no more pages. You can check :attr:`.ResponseFuture.has_more_pages` before calling to avoid this. .. versionadded:: 2.0.0 """ pass class ResultSet(object): """ An iterator over the rows from a query result. Also supplies basic equality and indexing methods for backward-compatability. These methods materialize the entire result set (loading all pages), and should only be used if the total result size is understood. Warnings are emitted when paged results are materialized in this fashion. You can treat this as a normal iterator over rows:: >>> from cassandra.query import SimpleStatement >>> statement = SimpleStatement("SELECT * FROM users", fetch_size=10) >>> for user_row in session.execute(statement): ... process_user(user_row) Whenever there are no more rows in the current page, the next page will be fetched transparently. However, note that it *is* possible for an :class:`Exception` to be raised while fetching the next page, just like you might see on a normal call to ``session.execute()``. 
""" def __init__(self, response_future, initial_response): self.response_future = response_future self.column_names = response_future._col_names self.column_types = response_future._col_types self._set_current_rows(initial_response) self._page_iter = None self._list_mode = False @property def has_more_pages(self): """ True if the last response indicated more pages; False otherwise """ return self.response_future.has_more_pages @property def current_rows(self): """ The list of current page rows. May be empty if the result was empty, or this is the last page. """ return self._current_rows or [] def one(self): """ Return a single row of the results or None if empty. This is basically a shortcut to `result_set.current_rows[0]` and should only be used when you know a query returns a single row. Consider using an iterator if the ResultSet contains more than one row. """ - return self._current_rows[0] if self._current_rows else None + row = None + if self._current_rows: + try: + row = self._current_rows[0] + except TypeError: # generator object is not subscriptable, PYTHON-1026 + row = next(iter(self._current_rows)) + + return row def __iter__(self): if self._list_mode: return iter(self._current_rows) self._page_iter = iter(self._current_rows) return self def next(self): try: return next(self._page_iter) except StopIteration: if not self.response_future.has_more_pages: if not self._list_mode: self._current_rows = [] raise self.fetch_next_page() self._page_iter = iter(self._current_rows) return next(self._page_iter) __next__ = next def fetch_next_page(self): """ Manually, synchronously fetch the next page. Supplied for manually retrieving pages and inspecting :meth:`~.current_page`. It is not necessary to call this when iterating through results; paging happens implicitly in iteration. """ if self.response_future.has_more_pages: self.response_future.start_fetching_next_page() result = self.response_future.result() self._current_rows = result._current_rows # ResultSet has already _set_current_rows to the appropriate form else: self._current_rows = [] def _set_current_rows(self, result): if isinstance(result, Mapping): self._current_rows = [result] if result else [] return try: iter(result) # can't check directly for generator types because cython generators are different self._current_rows = result except TypeError: self._current_rows = [result] if result else [] def _fetch_all(self): self._current_rows = list(self) self._page_iter = None def _enter_list_mode(self, operator): if self._list_mode: return if self._page_iter: raise RuntimeError("Cannot use %s when results have been iterated." % operator) if self.response_future.has_more_pages: log.warning("Using %s on paged results causes entire result set to be materialized.", operator) self._fetch_all() # done regardless of paging status in case the row factory produces a generator self._list_mode = True def __eq__(self, other): self._enter_list_mode("equality operator") return self._current_rows == other def __getitem__(self, i): if i == 0: warn("ResultSet indexing support will be removed in 4.0. Consider using " "ResultSet.one() to get a single row.", DeprecationWarning) self._enter_list_mode("index operator") return self._current_rows[i] def __nonzero__(self): return bool(self._current_rows) __bool__ = __nonzero__ def get_query_trace(self, max_wait_sec=None): """ Gets the last query trace from the associated future. See :meth:`.ResponseFuture.get_query_trace` for details. 
""" return self.response_future.get_query_trace(max_wait_sec) def get_all_query_traces(self, max_wait_sec_per=None): """ Gets all query traces from the associated future. See :meth:`.ResponseFuture.get_all_query_traces` for details. """ return self.response_future.get_all_query_traces(max_wait_sec_per) @property def was_applied(self): """ For LWT results, returns whether the transaction was applied. Result is indeterminate if called on a result that was not an LWT request or on a :class:`.query.BatchStatement` containing LWT. In the latter case either all the batch succeeds or fails. Only valid when one of the of the internal row factories is in use. """ if self.response_future.row_factory not in (named_tuple_factory, dict_factory, tuple_factory): raise RuntimeError("Cannot determine LWT result with row factory %s" % (self.response_future.row_factory,)) is_batch_statement = isinstance(self.response_future.query, BatchStatement) if is_batch_statement and (not self.column_names or self.column_names[0] != "[applied]"): raise RuntimeError("No LWT were present in the BatchStatement") if not is_batch_statement and len(self.current_rows) != 1: raise RuntimeError("LWT result should have exactly one row. This has %d." % (len(self.current_rows))) row = self.current_rows[0] if isinstance(row, tuple): return row[0] else: return row['[applied]'] @property def paging_state(self): """ Server paging state of the query. Can be `None` if the query was not paged. The driver treats paging state as opaque, but it may contain primary key data, so applications may want to avoid sending this to untrusted parties. """ return self.response_future._paging_state diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py index 41a8e2f..11f664e 100644 --- a/cassandra/cqlengine/query.py +++ b/cassandra/cqlengine/query.py @@ -1,1531 +1,1531 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy from datetime import datetime, timedelta from functools import partial import time import six from warnings import warn from cassandra.query import SimpleStatement, BatchType as CBatchType, BatchStatement from cassandra.cqlengine import columns, CQLEngineException, ValidationError, UnicodeMixin from cassandra.cqlengine import connection as conn from cassandra.cqlengine.functions import Token, BaseQueryFunction, QueryValue from cassandra.cqlengine.operators import (InOperator, EqualsOperator, GreaterThanOperator, GreaterThanOrEqualOperator, LessThanOperator, LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator) from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement, UpdateStatement, InsertStatement, BaseCQLStatement, MapDeleteClause, ConditionalClause) class QueryException(CQLEngineException): pass class IfNotExistsWithCounterColumn(CQLEngineException): pass class IfExistsWithCounterColumn(CQLEngineException): pass class LWTException(CQLEngineException): """Lightweight conditional exception. 
This exception will be raised when a write using an `IF` clause could not be applied due to existing data violating the condition. The existing data is available through the ``existing`` attribute. :param existing: The current state of the data which prevented the write. """ def __init__(self, existing): super(LWTException, self).__init__("LWT Query was not applied") self.existing = existing class DoesNotExist(QueryException): pass class MultipleObjectsReturned(QueryException): pass def check_applied(result): """ Raises LWTException if it looks like a failed LWT request. A LWTException won't be raised in the special case in which there are several failed LWT in a :class:`~cqlengine.query.BatchQuery`. """ try: applied = result.was_applied except Exception: applied = True # result was not LWT form if not applied: - raise LWTException(result[0]) + raise LWTException(result.one()) class AbstractQueryableColumn(UnicodeMixin): """ exposes cql query operators through pythons builtin comparator symbols """ def _get_column(self): raise NotImplementedError def __unicode__(self): raise NotImplementedError def _to_database(self, val): if isinstance(val, QueryValue): return val else: return self._get_column().to_database(val) def in_(self, item): """ Returns an in operator used where you'd typically want to use python's `in` operator """ return WhereClause(six.text_type(self), InOperator(), item) def contains_(self, item): """ Returns a CONTAINS operator """ return WhereClause(six.text_type(self), ContainsOperator(), item) def __eq__(self, other): return WhereClause(six.text_type(self), EqualsOperator(), self._to_database(other)) def __gt__(self, other): return WhereClause(six.text_type(self), GreaterThanOperator(), self._to_database(other)) def __ge__(self, other): return WhereClause(six.text_type(self), GreaterThanOrEqualOperator(), self._to_database(other)) def __lt__(self, other): return WhereClause(six.text_type(self), LessThanOperator(), self._to_database(other)) def __le__(self, other): return WhereClause(six.text_type(self), LessThanOrEqualOperator(), self._to_database(other)) class BatchType(object): Unlogged = 'UNLOGGED' Counter = 'COUNTER' class BatchQuery(object): """ Handles the batching of queries http://docs.datastax.com/en/cql/3.0/cql/cql_reference/batch_r.html See :doc:`/cqlengine/batches` for more details. """ warn_multiple_exec = True _consistency = None _connection = None _connection_explicit = False def __init__(self, batch_type=None, timestamp=None, consistency=None, execute_on_exception=False, timeout=conn.NOT_SET, connection=None): """ :param batch_type: (optional) One of batch type values available through BatchType enum :type batch_type: BatchType, str or None :param timestamp: (optional) A datetime or timedelta object with desired timestamp to be applied to the batch conditional. :type timestamp: datetime or timedelta or None :param consistency: (optional) One of consistency values ("ANY", "ONE", "QUORUM" etc) :type consistency: The :class:`.ConsistencyLevel` to be used for the batch query, or None. :param execute_on_exception: (Defaults to False) Indicates that when the BatchQuery instance is used as a context manager the queries accumulated within the context must be executed despite encountering an error within the context. By default, any exception raised from within the context scope will cause the batched queries not to be executed. 
:type execute_on_exception: bool :param timeout: (optional) Timeout for the entire batch (in seconds), if not specified fallback to default session timeout :type timeout: float or None :param str connection: Connection name to use for the batch execution """ self.queries = [] self.batch_type = batch_type if timestamp is not None and not isinstance(timestamp, (datetime, timedelta)): raise CQLEngineException('timestamp object must be an instance of datetime') self.timestamp = timestamp self._consistency = consistency self._execute_on_exception = execute_on_exception self._timeout = timeout self._callbacks = [] self._executed = False self._context_entered = False self._connection = connection if connection: self._connection_explicit = True def add_query(self, query): if not isinstance(query, BaseCQLStatement): raise CQLEngineException('only BaseCQLStatements can be added to a batch query') self.queries.append(query) def consistency(self, consistency): self._consistency = consistency def _execute_callbacks(self): for callback, args, kwargs in self._callbacks: callback(*args, **kwargs) def add_callback(self, fn, *args, **kwargs): """Add a function and arguments to be passed to it to be executed after the batch executes. A batch can support multiple callbacks. Note, that if the batch does not execute, the callbacks are not executed. A callback, thus, is an "on batch success" handler. :param fn: Callable object :type fn: callable :param \*args: Positional arguments to be passed to the callback at the time of execution :param \*\*kwargs: Named arguments to be passed to the callback at the time of execution """ if not callable(fn): raise ValueError("Value for argument 'fn' is {0} and is not a callable object.".format(type(fn))) self._callbacks.append((fn, args, kwargs)) def execute(self): if self._executed and self.warn_multiple_exec: msg = "Batch executed multiple times." if self._context_entered: msg += " If using the batch as a context manager, there is no need to call execute directly." warn(msg) self._executed = True if len(self.queries) == 0: # Empty batch is a no-op # except for callbacks self._execute_callbacks() return batch_type = None if self.batch_type is CBatchType.LOGGED else self.batch_type opener = 'BEGIN ' + (str(batch_type) + ' ' if batch_type else '') + ' BATCH' if self.timestamp: if isinstance(self.timestamp, six.integer_types): ts = self.timestamp elif isinstance(self.timestamp, (datetime, timedelta)): ts = self.timestamp if isinstance(self.timestamp, timedelta): ts += datetime.now() # Apply timedelta ts = int(time.mktime(ts.timetuple()) * 1e+6 + ts.microsecond) else: raise ValueError("Batch expects a long, a timedelta, or a datetime") opener += ' USING TIMESTAMP {0}'.format(ts) query_list = [opener] parameters = {} ctx_counter = 0 for query in self.queries: query.update_context_id(ctx_counter) ctx = query.get_context() ctx_counter += len(ctx) query_list.append(' ' + str(query)) parameters.update(ctx) query_list.append('APPLY BATCH;') tmp = conn.execute('\n'.join(query_list), parameters, self._consistency, self._timeout, connection=self._connection) check_applied(tmp) self.queries = [] self._execute_callbacks() def __enter__(self): self._context_entered = True return self def __exit__(self, exc_type, exc_val, exc_tb): # don't execute if there was an exception by default if exc_type is not None and not self._execute_on_exception: return self.execute() class ContextQuery(object): """ A Context manager to allow a Model to switch context easily. 
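Referring back to ``BatchQuery`` above, a short usage sketch (``Row`` is a hypothetical model):

.. code-block:: python

    from cassandra.cqlengine.query import BatchQuery

    with BatchQuery() as b:
        Row.batch(b).create(row_id=1, value='a')
        Row.batch(b).create(row_id=2, value='b')
        b.add_callback(print, 'batch applied')
    # the accumulated statements execute on leaving the context;
    # callbacks run only after a successful execute()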
Presently, the context only specifies a keyspace for model IO. :param \*args: One or more models. A model should be a class type, not an instance. :param \*\*kwargs: (optional) Context parameters: can be *keyspace* or *connection* For example: .. code-block:: python with ContextQuery(Automobile, keyspace='test2') as A: A.objects.create(manufacturer='honda', year=2008, model='civic') print len(A.objects.all()) # 1 result with ContextQuery(Automobile, keyspace='test4') as A: print len(A.objects.all()) # 0 result # Multiple models with ContextQuery(Automobile, Automobile2, connection='cluster2') as (A, A2): print len(A.objects.all()) print len(A2.objects.all()) """ def __init__(self, *args, **kwargs): from cassandra.cqlengine import models self.models = [] if len(args) < 1: raise ValueError("No model provided.") keyspace = kwargs.pop('keyspace', None) connection = kwargs.pop('connection', None) if kwargs: raise ValueError("Unknown keyword argument(s): {0}".format( ','.join(kwargs.keys()))) for model in args: try: issubclass(model, models.Model) except TypeError: raise ValueError("Models must be derived from base Model.") m = models._clone_model_class(model, {}) if keyspace: m.__keyspace__ = keyspace if connection: m.__connection__ = connection self.models.append(m) def __enter__(self): if len(self.models) > 1: return tuple(self.models) return self.models[0] def __exit__(self, exc_type, exc_val, exc_tb): return class AbstractQuerySet(object): def __init__(self, model): super(AbstractQuerySet, self).__init__() self.model = model # Where clause filters self._where = [] # Conditional clause filters self._conditional = [] # ordering arguments self._order = [] self._allow_filtering = False # CQL has a default limit of 10000, it's defined here # because explicit is better than implicit self._limit = 10000 # We store the fields for which we use the Equal operator # in a query, so we don't select it from the DB. 
_defer_fields # will contain the names of the fields in the DB, not the names # of the variables used by the mapper self._defer_fields = set() self._deferred_values = {} # This variable will hold the names in the database of the fields # for which we want to query self._only_fields = [] self._values_list = False self._flat_values_list = False # results cache self._result_cache = None self._result_idx = None self._result_generator = None self._materialize_results = True self._distinct_fields = None self._count = None self._batch = None self._ttl = None self._consistency = None self._timestamp = None self._if_not_exists = False self._timeout = conn.NOT_SET self._if_exists = False self._fetch_size = None self._connection = None @property def column_family_name(self): return self.model.column_family_name() def _execute(self, statement): if self._batch: return self._batch.add_query(statement) else: connection = self._connection or self.model._get_connection() result = _execute_statement(self.model, statement, self._consistency, self._timeout, connection=connection) if self._if_not_exists or self._if_exists or self._conditional: check_applied(result) return result def __unicode__(self): return six.text_type(self._select_query()) def __str__(self): return str(self.__unicode__()) def __call__(self, *args, **kwargs): return self.filter(*args, **kwargs) def __deepcopy__(self, memo): clone = self.__class__(self.model) for k, v in self.__dict__.items(): if k in ['_con', '_cur', '_result_cache', '_result_idx', '_result_generator', '_construct_result']: # don't clone these, which are per-request-execution clone.__dict__[k] = None elif k == '_batch': # we need to keep the same batch instance across # all queryset clones, otherwise the batched queries # fly off into other batch instances which are never # executed, thx @dokai clone.__dict__[k] = self._batch elif k == '_timeout': clone.__dict__[k] = self._timeout else: clone.__dict__[k] = copy.deepcopy(v, memo) return clone def __len__(self): self._execute_query() return self.count() # ----query generation / execution---- def _select_fields(self): """ returns the fields to select """ return [] def _validate_select_where(self): """ put select query validation here """ def _select_query(self): """ Returns a select clause based on the given filter args """ if self._where: self._validate_select_where() return SelectStatement( self.column_family_name, fields=self._select_fields(), where=self._where, order_by=self._order, limit=self._limit, allow_filtering=self._allow_filtering, distinct_fields=self._distinct_fields, fetch_size=self._fetch_size ) # ----Reads------ def _execute_query(self): if self._batch: raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode") if self._result_cache is None: self._result_generator = (i for i in self._execute(self._select_query())) self._result_cache = [] self._construct_result = self._maybe_inject_deferred(self._get_result_constructor()) # "DISTINCT COUNT()" is not supported in C* < 2.2, so we need to materialize all results to get # len() and count() working with DISTINCT queries if self._materialize_results or self._distinct_fields: self._fill_result_cache() def _fill_result_cache(self): """ Fill the result cache with all results. 
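At the queryset level the cache behaves roughly like this (``User`` is a hypothetical model):

.. code-block:: python

    users = User.objects.all()
    n = len(users)          # materializes the result cache (and the count)
    first = users.first()   # served from the cached rows
    for u in users:         # iterates the cache, fetching more rows only if needed
        print(u)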
""" idx = 0 try: while True: idx += 1000 self._fill_result_cache_to_idx(idx) except StopIteration: pass self._count = len(self._result_cache) def _fill_result_cache_to_idx(self, idx): self._execute_query() if self._result_idx is None: self._result_idx = -1 qty = idx - self._result_idx if qty < 1: return else: for idx in range(qty): self._result_idx += 1 while True: try: self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx]) break except IndexError: self._result_cache.append(next(self._result_generator)) def __iter__(self): self._execute_query() idx = 0 while True: if len(self._result_cache) <= idx: try: self._result_cache.append(next(self._result_generator)) except StopIteration: break instance = self._result_cache[idx] if isinstance(instance, dict): self._fill_result_cache_to_idx(idx) yield self._result_cache[idx] idx += 1 def __getitem__(self, s): self._execute_query() if isinstance(s, slice): start = s.start if s.start else 0 if start < 0 or (s.stop is not None and s.stop < 0): warn("ModelQuerySet slicing with negative indices support will be removed in 4.0.", DeprecationWarning) # calculate the amount of results that need to be loaded end = s.stop if start < 0 or s.stop is None or s.stop < 0: end = self.count() try: self._fill_result_cache_to_idx(end) except StopIteration: pass return self._result_cache[start:s.stop:s.step] else: try: s = int(s) except (ValueError, TypeError): raise TypeError('QuerySet indices must be integers') if s < 0: warn("ModelQuerySet indexing with negative indices support will be removed in 4.0.", DeprecationWarning) # Using negative indexing is costly since we have to execute a count() if s < 0: num_results = self.count() s += num_results try: self._fill_result_cache_to_idx(s) except StopIteration: raise IndexError return self._result_cache[s] def _get_result_constructor(self): """ Returns a function that will be used to instantiate query results """ raise NotImplementedError @staticmethod def _construct_with_deferred(f, deferred, row): row.update(deferred) return f(row) def _maybe_inject_deferred(self, constructor): return partial(self._construct_with_deferred, constructor, self._deferred_values)\ if self._deferred_values else constructor def batch(self, batch_obj): """ Set a batch object to run the query on. Note: running a select query with a batch object will raise an exception """ if self._connection: raise CQLEngineException("Cannot specify the connection on model in batch mode.") if batch_obj is not None and not isinstance(batch_obj, BatchQuery): raise CQLEngineException('batch_obj must be a BatchQuery instance or None') clone = copy.deepcopy(self) clone._batch = batch_obj return clone def first(self): try: return six.next(iter(self)) except StopIteration: return None def all(self): """ Returns a queryset matching all rows .. code-block:: python for user in User.objects().all(): print(user) """ return copy.deepcopy(self) def consistency(self, consistency): """ Sets the consistency level for the operation. See :class:`.ConsistencyLevel`. .. 
code-block:: python for user in User.objects(id=3).consistency(CL.ONE): print(user) """ clone = copy.deepcopy(self) clone._consistency = consistency return clone def _parse_filter_arg(self, arg): """ Parses a filter arg in the format: __ :returns: colname, op tuple """ statement = arg.rsplit('__', 1) if len(statement) == 1: return arg, None elif len(statement) == 2: return (statement[0], statement[1]) if arg != 'pk__token' else (arg, None) else: raise QueryException("Can't parse '{0}'".format(arg)) def iff(self, *args, **kwargs): """Adds IF statements to queryset""" if len([x for x in kwargs.values() if x is None]): raise CQLEngineException("None values on iff are not allowed") clone = copy.deepcopy(self) for operator in args: if not isinstance(operator, ConditionalClause): raise QueryException('{0} is not a valid query operator'.format(operator)) clone._conditional.append(operator) for arg, val in kwargs.items(): if isinstance(val, Token): raise QueryException("Token() values are not valid in conditionals") col_name, col_op = self._parse_filter_arg(arg) try: column = self.model._get_column(col_name) except KeyError: raise QueryException("Can't resolve column name: '{0}'".format(col_name)) if isinstance(val, BaseQueryFunction): query_val = val else: query_val = column.to_database(val) operator_class = BaseWhereOperator.get_operator(col_op or 'EQ') operator = operator_class() clone._conditional.append(WhereClause(column.db_field_name, operator, query_val)) return clone def filter(self, *args, **kwargs): """ Adds WHERE arguments to the queryset, returning a new queryset See :ref:`retrieving-objects-with-filters` Returns a QuerySet filtered on the keyword arguments """ # add arguments to the where clause filters if len([x for x in kwargs.values() if x is None]): raise CQLEngineException("None values on filter are not allowed") clone = copy.deepcopy(self) for operator in args: if not isinstance(operator, WhereClause): raise QueryException('{0} is not a valid query operator'.format(operator)) clone._where.append(operator) for arg, val in kwargs.items(): col_name, col_op = self._parse_filter_arg(arg) quote_field = True if not isinstance(val, Token): try: column = self.model._get_column(col_name) except KeyError: raise QueryException("Can't resolve column name: '{0}'".format(col_name)) else: if col_name != 'pk__token': raise QueryException("Token() values may only be compared to the 'pk__token' virtual column") column = columns._PartitionKeysToken(self.model) quote_field = False partition_columns = column.partition_columns if len(partition_columns) != len(val.value): raise QueryException( 'Token() received {0} arguments but model has {1} partition keys'.format( len(val.value), len(partition_columns))) val.set_columns(partition_columns) # get query operator, or use equals if not supplied operator_class = BaseWhereOperator.get_operator(col_op or 'EQ') operator = operator_class() if isinstance(operator, InOperator): if not isinstance(val, (list, tuple)): raise QueryException('IN queries must use a list/tuple value') query_val = [column.to_database(v) for v in val] elif isinstance(val, BaseQueryFunction): query_val = val elif (isinstance(operator, ContainsOperator) and isinstance(column, (columns.List, columns.Set, columns.Map))): # For ContainsOperator and collections, we query using the value, not the container query_val = val else: query_val = column.to_database(val) if not col_op: # only equal values should be deferred clone._defer_fields.add(column.db_field_name) 
clone._deferred_values[column.db_field_name] = val # map by db field name for substitution in results clone._where.append(WhereClause(column.db_field_name, operator, query_val, quote_field=quote_field)) return clone def get(self, *args, **kwargs): """ Returns a single instance matching this query, optionally with additional filter kwargs. See :ref:`retrieving-objects-with-filters` Returns a single object matching the QuerySet. .. code-block:: python user = User.get(id=1) If no objects are matched, a :class:`~.DoesNotExist` exception is raised. If more than one object is found, a :class:`~.MultipleObjectsReturned` exception is raised. """ if args or kwargs: return self.filter(*args, **kwargs).get() self._execute_query() # Check that the resultset only contains one element, avoiding sending a COUNT query try: self[1] raise self.model.MultipleObjectsReturned('Multiple objects found') except IndexError: pass try: obj = self[0] except IndexError: raise self.model.DoesNotExist return obj def _get_ordering_condition(self, colname): order_type = 'DESC' if colname.startswith('-') else 'ASC' colname = colname.replace('-', '') return colname, order_type def order_by(self, *colnames): """ Sets the column(s) to be used for ordering Default order is ascending, prepend a '-' to any column name for descending *Note: column names must be a clustering key* .. code-block:: python from uuid import uuid1,uuid4 class Comment(Model): photo_id = UUID(primary_key=True) comment_id = TimeUUID(primary_key=True, default=uuid1) # second primary key component is a clustering key comment = Text() sync_table(Comment) u = uuid4() for x in range(5): Comment.create(photo_id=u, comment="test %d" % x) print("Normal") for comment in Comment.objects(photo_id=u): print comment.comment_id print("Reversed") for comment in Comment.objects(photo_id=u).order_by("-comment_id"): print comment.comment_id """ if len(colnames) == 0: clone = copy.deepcopy(self) clone._order = [] return clone conditions = [] for colname in colnames: conditions.append('"{0}" {1}'.format(*self._get_ordering_condition(colname))) clone = copy.deepcopy(self) clone._order.extend(conditions) return clone def count(self): """ Returns the number of rows matched by this query. *Note: This function executes a SELECT COUNT() and has a performance cost on large datasets* """ if self._batch: raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode") if self._count is None: query = self._select_query() query.count = True result = self._execute(query) - count_row = result[0].popitem() + count_row = result.one().popitem() self._count = count_row[1] return self._count def distinct(self, distinct_fields=None): """ Returns the DISTINCT rows matched by this query. distinct_fields default to the partition key fields if not specified. *Note: distinct_fields must be a partition key or a static column* .. code-block:: python class Automobile(Model): manufacturer = columns.Text(partition_key=True) year = columns.Integer(primary_key=True) model = columns.Text(primary_key=True) price = columns.Decimal() sync_table(Automobile) # create rows Automobile.objects.distinct() # or Automobile.objects.distinct(['manufacturer']) """ clone = copy.deepcopy(self) if distinct_fields: clone._distinct_fields = distinct_fields else: clone._distinct_fields = [x.column_name for x in self.model._partition_keys.values()] return clone def limit(self, v): """ Limits the number of results returned by Cassandra. Use *0* or *None* to disable. 
*Note that CQL's default limit is 10,000, so all queries without a limit set explicitly will have an implicit limit of 10,000* .. code-block:: python # Fetch 100 users for user in User.objects().limit(100): print(user) # Fetch all users for user in User.objects().limit(None): print(user) """ if v is None: v = 0 if not isinstance(v, six.integer_types): raise TypeError if v == self._limit: return self if v < 0: raise QueryException("Negative limit is not allowed") clone = copy.deepcopy(self) clone._limit = v return clone def fetch_size(self, v): """ Sets the number of rows that are fetched at a time. *Note that driver's default fetch size is 5000.* .. code-block:: python for user in User.objects().fetch_size(500): print(user) """ if not isinstance(v, six.integer_types): raise TypeError if v == self._fetch_size: return self if v < 1: raise QueryException("fetch size less than 1 is not allowed") clone = copy.deepcopy(self) clone._fetch_size = v return clone def allow_filtering(self): """ Enables the (usually) unwise practive of querying on a clustering key without also defining a partition key """ clone = copy.deepcopy(self) clone._allow_filtering = True return clone def _only_or_defer(self, action, fields): if action == 'only' and self._only_fields: raise QueryException("QuerySet already has 'only' fields defined") clone = copy.deepcopy(self) # check for strange fields missing_fields = [f for f in fields if f not in self.model._columns.keys()] if missing_fields: raise QueryException( "Can't resolve fields {0} in {1}".format( ', '.join(missing_fields), self.model.__name__)) fields = [self.model._columns[field].db_field_name for field in fields] if action == 'defer': clone._defer_fields.update(fields) elif action == 'only': clone._only_fields = fields else: raise ValueError return clone def only(self, fields): """ Load only these fields for the returned query """ return self._only_or_defer('only', fields) def defer(self, fields): """ Don't load these fields for the returned query """ return self._only_or_defer('defer', fields) def create(self, **kwargs): return self.model(**kwargs) \ .batch(self._batch) \ .ttl(self._ttl) \ .consistency(self._consistency) \ .if_not_exists(self._if_not_exists) \ .timestamp(self._timestamp) \ .if_exists(self._if_exists) \ .using(connection=self._connection) \ .save() def delete(self): """ Deletes the contents of a query """ # validate where clause partition_keys = set(x.db_field_name for x in self.model._partition_keys.values()) if partition_keys - set(c.field for c in self._where): raise QueryException("The partition key must be defined on delete queries") dq = DeleteStatement( self.column_family_name, where=self._where, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists ) self._execute(dq) def __eq__(self, q): if len(self._where) == len(q._where): return all([w in q._where for w in self._where]) return False def __ne__(self, q): return not (self != q) def timeout(self, timeout): """ :param timeout: Timeout for the query (in seconds) :type timeout: float or None """ clone = copy.deepcopy(self) clone._timeout = timeout return clone def using(self, keyspace=None, connection=None): """ Change the context on-the-fly of the Model class (keyspace, connection) """ if connection and self._batch: raise CQLEngineException("Cannot specify a connection on model in batch mode.") clone = copy.deepcopy(self) if keyspace: from cassandra.cqlengine.models import _clone_model_class clone.model = _clone_model_class(self.model, {'__keyspace__': 
keyspace}) if connection: clone._connection = connection return clone class ResultObject(dict): """ adds attribute access to a dictionary """ def __getattr__(self, item): try: return self[item] except KeyError: raise AttributeError class SimpleQuerySet(AbstractQuerySet): """ Overrides _get_result_constructor for querysets that do not define a model (e.g. NamedTable queries) """ def _get_result_constructor(self): """ Returns a function that will be used to instantiate query results """ return ResultObject class ModelQuerySet(AbstractQuerySet): """ """ def _validate_select_where(self): """ Checks that a filterset will not create invalid select statement """ # check that there's either a =, a IN or a CONTAINS (collection) # relationship with a primary key or indexed field. We also allow # custom indexes to be queried with any operator (a difference # between a secondary index) equal_ops = [self.model._get_column_by_db_name(w.field) \ for w in self._where if not isinstance(w.value, Token) and (isinstance(w.operator, EqualsOperator) or self.model._get_column_by_db_name(w.field).custom_index)] token_comparison = any([w for w in self._where if isinstance(w.value, Token)]) if not any(w.primary_key or w.has_index for w in equal_ops) and not token_comparison and not self._allow_filtering: raise QueryException( ('Where clauses require either =, a IN or a CONTAINS ' '(collection) comparison with either a primary key or ' 'indexed field. You might want to consider setting ' 'custom_index on fields that you manage index outside ' 'cqlengine.')) if not self._allow_filtering: # if the query is not on an indexed field if not any(w.has_index for w in equal_ops): if not any([w.partition_key for w in equal_ops]) and not token_comparison: raise QueryException( ('Filtering on a clustering key without a partition ' 'key is not allowed unless allow_filtering() is ' 'called on the queryset. You might want to consider ' 'setting custom_index on fields that you manage ' 'index outside cqlengine.')) def _select_fields(self): if self._defer_fields or self._only_fields: fields = [columns.db_field_name for columns in self.model._columns.values()] if self._defer_fields: fields = [f for f in fields if f not in self._defer_fields] # select the partition keys if all model fields are set defer if not fields: fields = [columns.db_field_name for columns in self.model._partition_keys.values()] if self._only_fields: fields = [f for f in fields if f in self._only_fields] if not fields: raise QueryException('No fields in select query. 
Only fields: "{0}", defer fields: "{1}"'.format( ','.join(self._only_fields), ','.join(self._defer_fields))) return fields return super(ModelQuerySet, self)._select_fields() def _get_result_constructor(self): """ Returns a function that will be used to instantiate query results """ if not self._values_list: # we want models return self.model._construct_instance elif self._flat_values_list: # the user has requested flattened list (1 value per row) key = self._only_fields[0] return lambda row: row[key] else: return lambda row: [row[f] for f in self._only_fields] def _get_ordering_condition(self, colname): colname, order_type = super(ModelQuerySet, self)._get_ordering_condition(colname) column = self.model._columns.get(colname) if column is None: raise QueryException("Can't resolve the column name: '{0}'".format(colname)) # validate the column selection if not column.primary_key: raise QueryException( "Can't order on '{0}', can only order on (clustered) primary keys".format(colname)) pks = [v for k, v in self.model._columns.items() if v.primary_key] if column == pks[0]: raise QueryException( "Can't order by the first primary key (partition key), clustering (secondary) keys only") return column.db_field_name, order_type def values_list(self, *fields, **kwargs): """ Instructs the query set to return tuples, not model instance """ flat = kwargs.pop('flat', False) if kwargs: raise TypeError('Unexpected keyword arguments to values_list: %s' % (kwargs.keys(),)) if flat and len(fields) > 1: raise TypeError("'flat' is not valid when values_list is called with more than one field.") clone = self.only(fields) clone._values_list = True clone._flat_values_list = flat return clone def ttl(self, ttl): """ Sets the ttl (in seconds) for modified data. *Note that running a select query with a ttl value will raise an exception* """ clone = copy.deepcopy(self) clone._ttl = ttl return clone def timestamp(self, timestamp): """ Allows for custom timestamps to be saved with the record. """ clone = copy.deepcopy(self) clone._timestamp = timestamp return clone def if_not_exists(self): """ Check the existence of an object before insertion. If the insertion isn't applied, a LWTException is raised. """ if self.model._has_counter: raise IfNotExistsWithCounterColumn('if_not_exists cannot be used with tables containing counter columns') clone = copy.deepcopy(self) clone._if_not_exists = True return clone def if_exists(self): """ Check the existence of an object before an update or delete. If the update or delete isn't applied, a LWTException is raised. """ if self.model._has_counter: raise IfExistsWithCounterColumn('if_exists cannot be used with tables containing counter columns') clone = copy.deepcopy(self) clone._if_exists = True return clone def update(self, **values): """ Performs an update on the row selected by the queryset. Include values to update in the update like so: .. code-block:: python Model.objects(key=n).update(value='x') Passing in updates for columns which are not part of the model will raise a ValidationError. Per column validation will be performed, but instance level validation will not (i.e., `Model.validate` is not called). This is sometimes referred to as a blind update. For example: .. 
code-block:: python class User(Model): id = Integer(primary_key=True) name = Text() setup(["localhost"], "test") sync_table(User) u = User.create(id=1, name="jon") User.objects(id=1).update(name="Steve") # sets name to null User.objects(id=1).update(name=None) Also supported is blindly adding and removing elements from container columns, without loading a model instance from Cassandra. Using the syntax `.update(column_name={x, y, z})` will overwrite the contents of the container, like updating a non container column. However, adding `__` to the end of the keyword arg, makes the update call add or remove items from the collection, without overwriting then entire column. Given the model below, here are the operations that can be performed on the different container columns: .. code-block:: python class Row(Model): row_id = columns.Integer(primary_key=True) set_column = columns.Set(Integer) list_column = columns.List(Integer) map_column = columns.Map(Integer, Integer) :class:`~cqlengine.columns.Set` - `add`: adds the elements of the given set to the column - `remove`: removes the elements of the given set to the column .. code-block:: python # add elements to a set Row.objects(row_id=5).update(set_column__add={6}) # remove elements to a set Row.objects(row_id=5).update(set_column__remove={4}) :class:`~cqlengine.columns.List` - `append`: appends the elements of the given list to the end of the column - `prepend`: prepends the elements of the given list to the beginning of the column .. code-block:: python # append items to a list Row.objects(row_id=5).update(list_column__append=[6, 7]) # prepend items to a list Row.objects(row_id=5).update(list_column__prepend=[1, 2]) :class:`~cqlengine.columns.Map` - `update`: adds the given keys/values to the columns, creating new entries if they didn't exist, and overwriting old ones if they did .. code-block:: python # add items to a map Row.objects(row_id=5).update(map_column__update={1: 2, 3: 4}) # remove items from a map Row.objects(row_id=5).update(map_column__remove={1, 2}) """ if not values: return nulled_columns = set() updated_columns = set() us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) for name, val in values.items(): col_name, col_op = self._parse_filter_arg(name) col = self.model._columns.get(col_name) # check for nonexistant columns if col is None: raise ValidationError("{0}.{1} has no column named: {2}".format(self.__module__, self.model.__name__, col_name)) # check for primary key update attempts if col.is_primary_key: raise ValidationError("Cannot apply update to primary key '{0}' for {1}.{2}".format(col_name, self.__module__, self.model.__name__)) if col_op == 'remove' and isinstance(col, columns.Map): if not isinstance(val, set): raise ValidationError( "Cannot apply update operation '{0}' on column '{1}' with value '{2}'. A set is required.".format(col_op, col_name, val)) val = {v: None for v in val} else: # we should not provide default values in this use case. 
val = col.validate(val) if val is None: nulled_columns.add(col_name) continue us.add_update(col, val, operation=col_op) updated_columns.add(col_name) if us.assignments: self._execute(us) if nulled_columns: delete_conditional = [condition for condition in self._conditional if condition.field not in updated_columns] if self._conditional else None ds = DeleteStatement(self.column_family_name, fields=nulled_columns, where=self._where, conditionals=delete_conditional, if_exists=self._if_exists) self._execute(ds) class DMLQuery(object): """ A query object used for queries performing inserts, updates, or deletes this is usually instantiated by the model instance to be modified unlike the read query object, this is mutable """ _ttl = None _consistency = None _timestamp = None _if_not_exists = False _if_exists = False def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False, conditional=None, timeout=conn.NOT_SET, if_exists=False): self.model = model self.column_family_name = self.model.column_family_name() self.instance = instance self._batch = batch self._ttl = ttl self._consistency = consistency self._timestamp = timestamp self._if_not_exists = if_not_exists self._if_exists = if_exists self._conditional = conditional self._timeout = timeout def _execute(self, statement): connection = self.instance._get_connection() if self.instance else self.model._get_connection() if self._batch: if self._batch._connection: if not self._batch._connection_explicit and connection and \ connection != self._batch._connection: raise CQLEngineException('BatchQuery queries must be executed on the same connection') else: # set the BatchQuery connection from the model self._batch._connection = connection return self._batch.add_query(statement) else: results = _execute_statement(self.model, statement, self._consistency, self._timeout, connection=connection) if self._if_not_exists or self._if_exists or self._conditional: check_applied(results) return results def batch(self, batch_obj): if batch_obj is not None and not isinstance(batch_obj, BatchQuery): raise CQLEngineException('batch_obj must be a BatchQuery instance or None') self._batch = batch_obj return self def _delete_null_columns(self, conditionals=None): """ executes a delete query to remove columns that have changed to null """ ds = DeleteStatement(self.column_family_name, conditionals=conditionals, if_exists=self._if_exists) deleted_fields = False static_only = True for _, v in self.instance._values.items(): col = v.column if v.deleted: ds.add_field(col.db_field_name) deleted_fields = True static_only &= col.static elif isinstance(col, columns.Map): uc = MapDeleteClause(col.db_field_name, v.value, v.previous_value) if uc.get_context_size() > 0: ds.add_field(uc) deleted_fields = True static_only |= col.static if deleted_fields: keys = self.model._partition_keys if static_only else self.model._primary_keys for name, col in keys.items(): ds.add_where(col, EqualsOperator(), getattr(self.instance, name)) self._execute(ds) def update(self): """ updates a row. This is a blind update call. All validation and cleaning needs to happen prior to calling this. 
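In application code this path is normally reached through the model instance rather than directly (``User`` is a hypothetical model):

.. code-block:: python

    user = User.get(id=1)
    user.update(name='steve')   # only changed, non-primary-key columns are sent
    user.update(name=None)      # nulled columns are removed with a follow-up DELETE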
""" if self.instance is None: raise CQLEngineException("DML Query intance attribute is None") assert type(self.instance) == self.model null_clustering_key = False if len(self.instance._clustering_keys) == 0 else True static_changed_only = True statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) for name, col in self.instance._clustering_keys.items(): null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None)) updated_columns = set() # get defined fields and their column names for name, col in self.model._columns.items(): # if clustering key is null, don't include non static columns if null_clustering_key and not col.static and not col.partition_key: continue if not col.is_primary_key: val = getattr(self.instance, name, None) val_mgr = self.instance._values[name] if val is None: continue if not val_mgr.changed and not isinstance(col, columns.Counter): continue static_changed_only = static_changed_only and col.static statement.add_update(col, val, previous=val_mgr.previous_value) updated_columns.add(col.db_field_name) if statement.assignments: for name, col in self.model._primary_keys.items(): # only include clustering key if clustering key is not null, and non static columns are changed to avoid cql error if (null_clustering_key or static_changed_only) and (not col.partition_key): continue statement.add_where(col, EqualsOperator(), getattr(self.instance, name)) self._execute(statement) if not null_clustering_key: # remove conditions on fields that have been updated delete_conditionals = [condition for condition in self._conditional if condition.field not in updated_columns] if self._conditional else None self._delete_null_columns(delete_conditionals) def save(self): """ Creates / updates a row. This is a blind insert call. All validation and cleaning needs to happen prior to calling this. """ if self.instance is None: raise CQLEngineException("DML Query intance attribute is None") assert type(self.instance) == self.model nulled_fields = set() if self.instance._has_counter or self.instance._can_update(): if self.instance._has_counter: warn("'create' and 'save' actions on Counters are deprecated. It will be disallowed in 4.0. 
" "Use the 'update' mechanism instead.", DeprecationWarning) return self.update() else: insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists) static_save_only = False if len(self.instance._clustering_keys) == 0 else True for name, col in self.instance._clustering_keys.items(): static_save_only = static_save_only and col._val_is_null(getattr(self.instance, name, None)) for name, col in self.instance._columns.items(): if static_save_only and not col.static and not col.partition_key: continue val = getattr(self.instance, name, None) if col._val_is_null(val): if self.instance._values[name].changed: nulled_fields.add(col.db_field_name) continue if col.has_default and not self.instance._values[name].changed: # Ensure default columns included in a save() are marked as explicit, to get them *persisted* properly self.instance._values[name].explicit = True insert.add_assignment(col, getattr(self.instance, name, None)) # skip query execution if it's empty # caused by pointless update queries if not insert.is_empty: self._execute(insert) # delete any nulled columns if not static_save_only: self._delete_null_columns() def delete(self): """ Deletes one instance """ if self.instance is None: raise CQLEngineException("DML Query instance attribute is None") ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) for name, col in self.model._primary_keys.items(): val = getattr(self.instance, name) if val is None and not col.partition_key: continue ds.add_where(col, EqualsOperator(), val) self._execute(ds) def _execute_statement(model, statement, consistency_level, timeout, connection=None): params = statement.get_context() s = SimpleStatement(str(statement), consistency_level=consistency_level, fetch_size=statement.fetch_size) if model._partition_key_index: key_values = statement.partition_key_values(model._partition_key_index) if not any(v is None for v in key_values): parts = model._routing_key_from_values(key_values, conn.get_cluster(connection).protocol_version) s.routing_key = parts s.keyspace = model._get_keyspace() connection = connection or model._get_connection() return conn.execute(s, params, timeout=timeout, connection=connection) diff --git a/cassandra/encoder.py b/cassandra/encoder.py index 1f02348..00f7bf1 100644 --- a/cassandra/encoder.py +++ b/cassandra/encoder.py @@ -1,240 +1,243 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ These functions are used to convert Python objects into CQL strings. When non-prepared statements are executed, these encoder functions are called on each query parameter. 
""" import logging log = logging.getLogger(__name__) from binascii import hexlify import calendar import datetime import math import sys import types from uuid import UUID import six if six.PY3: import ipaddress from cassandra.util import (OrderedDict, OrderedMap, OrderedMapSerializedKey, sortedset, Time, Date) if six.PY3: long = int def cql_quote(term): # The ordering of this method is important for the result of this method to # be a native str type (for both Python 2 and 3) if isinstance(term, str): return "'%s'" % str(term).replace("'", "''") # This branch of the if statement will only be used by Python 2 to catch # unicode strings, text_type is used to prevent type errors with Python 3. elif isinstance(term, six.text_type): return "'%s'" % term.encode('utf8').replace("'", "''") else: return str(term) class ValueSequence(list): pass class Encoder(object): """ A container for mapping python types to CQL string literals when working with non-prepared statements. The type :attr:`~.Encoder.mapping` can be directly customized by users. """ mapping = None """ A map of python types to encoder functions. """ def __init__(self): self.mapping = { float: self.cql_encode_float, bytearray: self.cql_encode_bytes, str: self.cql_encode_str, int: self.cql_encode_object, UUID: self.cql_encode_object, datetime.datetime: self.cql_encode_datetime, datetime.date: self.cql_encode_date, datetime.time: self.cql_encode_time, Date: self.cql_encode_date_ext, Time: self.cql_encode_time, dict: self.cql_encode_map_collection, OrderedDict: self.cql_encode_map_collection, OrderedMap: self.cql_encode_map_collection, OrderedMapSerializedKey: self.cql_encode_map_collection, list: self.cql_encode_list_collection, tuple: self.cql_encode_list_collection, # TODO: change to tuple in next major set: self.cql_encode_set_collection, sortedset: self.cql_encode_set_collection, frozenset: self.cql_encode_set_collection, types.GeneratorType: self.cql_encode_list_collection, ValueSequence: self.cql_encode_sequence } if six.PY2: self.mapping.update({ unicode: self.cql_encode_unicode, buffer: self.cql_encode_bytes, long: self.cql_encode_object, types.NoneType: self.cql_encode_none, }) else: self.mapping.update({ memoryview: self.cql_encode_bytes, bytes: self.cql_encode_bytes, type(None): self.cql_encode_none, ipaddress.IPv4Address: self.cql_encode_ipaddress, ipaddress.IPv6Address: self.cql_encode_ipaddress }) def cql_encode_none(self, val): """ Converts :const:`None` to the string 'NULL'. """ return 'NULL' def cql_encode_unicode(self, val): """ Converts :class:`unicode` objects to UTF-8 encoded strings with quote escaping. """ return cql_quote(val.encode('utf-8')) def cql_encode_str(self, val): """ Escapes quotes in :class:`str` objects. """ return cql_quote(val) if six.PY3: def cql_encode_bytes(self, val): return (b'0x' + hexlify(val)).decode('utf-8') elif sys.version_info >= (2, 7): def cql_encode_bytes(self, val): # noqa return b'0x' + hexlify(val) else: # python 2.6 requires string or read-only buffer for hexlify def cql_encode_bytes(self, val): # noqa return b'0x' + hexlify(buffer(val)) def cql_encode_object(self, val): """ Default encoder for all objects that do not have a specific encoder function registered. This function simply calls :meth:`str()` on the object. 
""" return str(val) def cql_encode_float(self, val): """ Encode floats using repr to preserve precision """ if math.isinf(val): return 'Infinity' if val > 0 else '-Infinity' elif math.isnan(val): return 'NaN' else: return repr(val) def cql_encode_datetime(self, val): """ Converts a :class:`datetime.datetime` object to a (string) integer timestamp with millisecond precision. """ timestamp = calendar.timegm(val.utctimetuple()) return str(long(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3)) def cql_encode_date(self, val): """ Converts a :class:`datetime.date` object to a string with format ``YYYY-MM-DD``. """ return "'%s'" % val.strftime('%Y-%m-%d') def cql_encode_time(self, val): """ Converts a :class:`cassandra.util.Time` object to a string with format ``HH:MM:SS.mmmuuunnn``. """ return "'%s'" % val def cql_encode_date_ext(self, val): """ Encodes a :class:`cassandra.util.Date` object as an integer """ # using the int form in case the Date exceeds datetime.[MIN|MAX]YEAR return str(val.days_from_epoch + 2 ** 31) def cql_encode_sequence(self, val): """ Converts a sequence to a string of the form ``(item1, item2, ...)``. This is suitable for ``IN`` value lists. """ return '(%s)' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val) cql_encode_tuple = cql_encode_sequence """ Converts a sequence to a string of the form ``(item1, item2, ...)``. This is suitable for ``tuple`` type columns. """ def cql_encode_map_collection(self, val): """ Converts a dict into a string of the form ``{key1: val1, key2: val2, ...}``. This is suitable for ``map`` type columns. """ return '{%s}' % ', '.join('%s: %s' % ( self.mapping.get(type(k), self.cql_encode_object)(k), self.mapping.get(type(v), self.cql_encode_object)(v) ) for k, v in six.iteritems(val)) def cql_encode_list_collection(self, val): """ Converts a sequence to a string of the form ``[item1, item2, ...]``. This is suitable for ``list`` type columns. """ return '[%s]' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val) def cql_encode_set_collection(self, val): """ Converts a sequence to a string of the form ``{item1, item2, ...}``. This is suitable for ``set`` type columns. """ return '{%s}' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val) - def cql_encode_all_types(self, val): + def cql_encode_all_types(self, val, as_text_type=False): """ Converts any type into a CQL string, defaulting to ``cql_encode_object`` if :attr:`~Encoder.mapping` does not contain an entry for the type. """ - return self.mapping.get(type(val), self.cql_encode_object)(val) + encoded = self.mapping.get(type(val), self.cql_encode_object)(val) + if as_text_type and not isinstance(encoded, six.text_type): + return encoded.decode('utf-8') + return encoded if six.PY3: def cql_encode_ipaddress(self, val): """ Converts an ipaddress (IPV4Address, IPV6Address) to a CQL string. This is suitable for ``inet`` type columns. """ return "'%s'" % val.compressed diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py index f7b7cac..91431aa 100644 --- a/cassandra/io/asyncorereactor.py +++ b/cassandra/io/asyncorereactor.py @@ -1,460 +1,461 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import atexit from collections import deque from functools import partial import logging import os import socket import sys from threading import Lock, Thread, Event import time import weakref import sys from six.moves import range try: from weakref import WeakSet except ImportError: from cassandra.util import WeakSet # noqa import asyncore try: import ssl except ImportError: ssl = None # NOQA from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING, Timer, TimerManager log = logging.getLogger(__name__) _dispatcher_map = {} -def _cleanup(loop_weakref): - try: - loop = loop_weakref() - except ReferenceError: - return - - loop._cleanup() +def _cleanup(loop): + if loop: + loop._cleanup() class WaitableTimer(Timer): def __init__(self, timeout, callback): Timer.__init__(self, timeout, callback) self.callback = callback self.event = Event() self.final_exception = None def finish(self, time_now): try: finished = Timer.finish(self, time_now) if finished: self.event.set() return True return False except Exception as e: self.final_exception = e self.event.set() return True def wait(self, timeout=None): self.event.wait(timeout) if self.final_exception: raise self.final_exception class _PipeWrapper(object): def __init__(self, fd): self.fd = fd def fileno(self): return self.fd def close(self): os.close(self.fd) def getsockopt(self, level, optname, buflen=None): # act like an unerrored socket for the asyncore error handling if level == socket.SOL_SOCKET and optname == socket.SO_ERROR and not buflen: return 0 raise NotImplementedError() class _AsyncoreDispatcher(asyncore.dispatcher): def __init__(self, socket): asyncore.dispatcher.__init__(self, map=_dispatcher_map) # inject after to avoid base class validation self.set_socket(socket) self._notified = False def writable(self): return False def validate(self): assert not self._notified self.notify_loop() assert self._notified self.loop(0.1) assert not self._notified def loop(self, timeout): asyncore.loop(timeout=timeout, use_poll=True, map=_dispatcher_map, count=1) class _AsyncorePipeDispatcher(_AsyncoreDispatcher): def __init__(self): self.read_fd, self.write_fd = os.pipe() _AsyncoreDispatcher.__init__(self, _PipeWrapper(self.read_fd)) def writable(self): return False def handle_read(self): while len(os.read(self.read_fd, 4096)) == 4096: pass self._notified = False def notify_loop(self): if not self._notified: self._notified = True os.write(self.write_fd, b'x') class _AsyncoreUDPDispatcher(_AsyncoreDispatcher): """ Experimental alternate dispatcher for avoiding busy wait in the asyncore loop. It is not used by default because it relies on local port binding. Port scanning is not implemented, so multiple clients on one host will collide. This address would need to be set per instance, or this could be specialized to scan until an address is found. 
To use:: from cassandra.io.asyncorereactor import _AsyncoreUDPDispatcher, AsyncoreLoop AsyncoreLoop._loop_dispatch_class = _AsyncoreUDPDispatcher """ bind_address = ('localhost', 10000) def __init__(self): self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self._socket.bind(self.bind_address) self._socket.setblocking(0) _AsyncoreDispatcher.__init__(self, self._socket) def handle_read(self): try: d = self._socket.recvfrom(1) while d and d[1]: d = self._socket.recvfrom(1) except socket.error as e: pass self._notified = False def notify_loop(self): if not self._notified: self._notified = True self._socket.sendto(b'', self.bind_address) def loop(self, timeout): asyncore.loop(timeout=timeout, use_poll=False, map=_dispatcher_map, count=1) class _BusyWaitDispatcher(object): max_write_latency = 0.001 """ Timeout pushed down to asyncore select/poll. Dictates the amount of time it will sleep before coming back to check if anything is writable. """ def notify_loop(self): pass def loop(self, timeout): if not _dispatcher_map: time.sleep(0.005) count = timeout // self.max_write_latency asyncore.loop(timeout=self.max_write_latency, use_poll=True, map=_dispatcher_map, count=count) def validate(self): pass def close(self): pass class AsyncoreLoop(object): timer_resolution = 0.1 # used as the max interval to be in the io loop before returning to service timeouts _loop_dispatch_class = _AsyncorePipeDispatcher if os.name != 'nt' else _BusyWaitDispatcher def __init__(self): self._pid = os.getpid() self._loop_lock = Lock() self._started = False self._shutdown = False self._thread = None self._timers = TimerManager() try: dispatcher = self._loop_dispatch_class() dispatcher.validate() log.debug("Validated loop dispatch with %s", self._loop_dispatch_class) except Exception: log.exception("Failed validating loop dispatch with %s. Using busy wait execution instead.", self._loop_dispatch_class) dispatcher.close() dispatcher = _BusyWaitDispatcher() self._loop_dispatcher = dispatcher - atexit.register(partial(_cleanup, weakref.ref(self))) - def maybe_start(self): should_start = False did_acquire = False try: did_acquire = self._loop_lock.acquire(False) if did_acquire and not self._started: self._started = True should_start = True finally: if did_acquire: self._loop_lock.release() if should_start: - self._thread = Thread(target=self._run_loop, name="cassandra_driver_event_loop") + self._thread = Thread(target=self._run_loop, name="asyncore_cassandra_driver_event_loop") self._thread.daemon = True self._thread.start() def wake_loop(self): self._loop_dispatcher.notify_loop() def _run_loop(self): log.debug("Starting asyncore event loop") with self._loop_lock: while not self._shutdown: try: self._loop_dispatcher.loop(self.timer_resolution) self._timers.service_timeouts() except Exception: log.debug("Asyncore event loop stopped unexepectedly", exc_info=True) break self._started = False log.debug("Asyncore event loop ended") def add_timer(self, timer): self._timers.add_timer(timer) # This function is called from a different thread than the event loop # thread, so for this call to be thread safe, we must wake up the loop # in case it's stuck at a select self.wake_loop() def _cleanup(self): global _dispatcher_map self._shutdown = True if not self._thread: return log.debug("Waiting for event loop thread to join...") self._thread.join(timeout=1.0) if self._thread.is_alive(): log.warning( "Event loop thread could not be joined, so shutdown may not be clean. 
" "Please call Cluster.shutdown() to avoid this.") log.debug("Event loop thread was joined") # Ensure all connections are closed and in-flight requests cancelled for conn in tuple(_dispatcher_map.values()): if conn is not self._loop_dispatcher: conn.close() self._timers.service_timeouts() # Once all the connections are closed, close the dispatcher self._loop_dispatcher.close() log.debug("Dispatchers were closed") +_global_loop = None +atexit.register(partial(_cleanup, _global_loop)) + + class AsyncoreConnection(Connection, asyncore.dispatcher): """ An implementation of :class:`.Connection` that uses the ``asyncore`` module in the Python standard library for its event loop. """ - _loop = None - _writable = False _readable = False @classmethod def initialize_reactor(cls): - if not cls._loop: - cls._loop = AsyncoreLoop() + global _global_loop + if not _global_loop: + _global_loop = AsyncoreLoop() else: current_pid = os.getpid() - if cls._loop._pid != current_pid: + if _global_loop._pid != current_pid: log.debug("Detected fork, clearing and reinitializing reactor state") cls.handle_fork() - cls._loop = AsyncoreLoop() + _global_loop = AsyncoreLoop() @classmethod def handle_fork(cls): - global _dispatcher_map + global _dispatcher_map, _global_loop _dispatcher_map = {} - if cls._loop: - cls._loop._cleanup() - cls._loop = None + if _global_loop: + _global_loop._cleanup() + _global_loop = None @classmethod def create_timer(cls, timeout, callback): timer = Timer(timeout, callback) - cls._loop.add_timer(timer) + _global_loop.add_timer(timer) return timer def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) self.deque = deque() self.deque_lock = Lock() self._connect_socket() # start the event loop if needed - self._loop.maybe_start() + _global_loop.maybe_start() init_handler = WaitableTimer( timeout=0, callback=partial(asyncore.dispatcher.__init__, self, self._socket, _dispatcher_map) ) - self._loop.add_timer(init_handler) + _global_loop.add_timer(init_handler) init_handler.wait(kwargs["connect_timeout"]) self._writable = True self._readable = True self._send_options_message() def close(self): with self.lock: if self.is_closed: return self.is_closed = True log.debug("Closing connection (%s) to %s", id(self), self.host) self._writable = False self._readable = False # We don't have to wait for this to be closed, we can just schedule it self.create_timer(0, partial(asyncore.dispatcher.close, self)) log.debug("Closed socket to %s", self.host) if not self.is_defunct: self.error_all_requests( ConnectionShutdown("Connection to %s was closed" % self.host)) #This happens when the connection is shutdown while waiting for the ReadyMessage if not self.connected_event.is_set(): self.last_error = ConnectionShutdown("Connection to %s was closed" % self.host) # don't leave in-progress operations hanging self.connected_event.set() def handle_error(self): self.defunct(sys.exc_info()[1]) def handle_close(self): log.debug("Connection %s closed by server", self) self.close() def handle_write(self): while True: with self.deque_lock: try: next_msg = self.deque.popleft() except IndexError: self._writable = False return try: sent = self.send(next_msg) self._readable = True except socket.error as err: if (err.args[0] in NONBLOCKING): with self.deque_lock: self.deque.appendleft(next_msg) else: self.defunct(err) return else: if sent < len(next_msg): with self.deque_lock: self.deque.appendleft(next_msg[sent:]) if sent == 0: return def handle_read(self): try: while True: buf = 
self.recv(self.in_buffer_size) self._iobuf.write(buf) if len(buf) < self.in_buffer_size: break except socket.error as err: if ssl and isinstance(err, ssl.SSLError): - if err.args[0] not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + return + else: self.defunct(err) return - elif err.args[0] not in NONBLOCKING: + elif err.args[0] in NONBLOCKING: + return + else: self.defunct(err) return if self._iobuf.tell(): self.process_io_buffer() if not self._requests and not self.is_control_connection: self._readable = False def push(self, data): sabs = self.out_buffer_size if len(data) > sabs: chunks = [] for i in range(0, len(data), sabs): chunks.append(data[i:i + sabs]) else: chunks = [data] with self.deque_lock: self.deque.extend(chunks) self._writable = True - self._loop.wake_loop() + _global_loop.wake_loop() def writable(self): return self._writable def readable(self): return self._readable or (self.is_control_connection and not (self.is_defunct or self.is_closed)) diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py index 90bd761..21111b0 100644 --- a/cassandra/io/libevreactor.py +++ b/cassandra/io/libevreactor.py @@ -1,370 +1,375 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import atexit from collections import deque from functools import partial import logging import os import socket import ssl from threading import Lock, Thread import time import weakref from six.moves import range from cassandra.connection import (Connection, ConnectionShutdown, NONBLOCKING, Timer, TimerManager) try: import cassandra.io.libevwrapper as libev except ImportError: raise ImportError( "The C extension needed to use libev was not found. This " "probably means that you didn't have the required build dependencies " "when installing the driver. 
See " "http://datastax.github.io/python-driver/installation.html#c-extensions " "for instructions on installing build dependencies and building " "the C extension.") log = logging.getLogger(__name__) -def _cleanup(loop_weakref): - try: - loop = loop_weakref() - except ReferenceError: - return - loop._cleanup() +def _cleanup(loop): + if loop: + loop._cleanup() class LibevLoop(object): def __init__(self): self._pid = os.getpid() self._loop = libev.Loop() self._notifier = libev.Async(self._loop) self._notifier.start() # prevent _notifier from keeping the loop from returning self._loop.unref() self._started = False self._shutdown = False self._lock = Lock() self._lock_thread = Lock() self._thread = None # set of all connections; only replaced with a new copy # while holding _conn_set_lock, never modified in place self._live_conns = set() # newly created connections that need their write/read watcher started self._new_conns = set() # recently closed connections that need their write/read watcher stopped self._closed_conns = set() self._conn_set_lock = Lock() self._preparer = libev.Prepare(self._loop, self._loop_will_run) # prevent _preparer from keeping the loop from returning self._loop.unref() self._preparer.start() self._timers = TimerManager() self._loop_timer = libev.Timer(self._loop, self._on_loop_timer) - atexit.register(partial(_cleanup, weakref.ref(self))) - def maybe_start(self): should_start = False with self._lock: if not self._started: log.debug("Starting libev event loop") self._started = True should_start = True if should_start: with self._lock_thread: if not self._shutdown: self._thread = Thread(target=self._run_loop, name="event_loop") self._thread.daemon = True self._thread.start() self._notifier.send() def _run_loop(self): while True: self._loop.start() # there are still active watchers, no deadlock with self._lock: if not self._shutdown and self._live_conns: log.debug("Restarting event loop") continue else: # all Connections have been closed, no active watchers log.debug("All Connections currently closed, event loop ended") self._started = False break def _cleanup(self): self._shutdown = True if not self._thread: return for conn in self._live_conns | self._new_conns | self._closed_conns: conn.close() for watcher in (conn._write_watcher, conn._read_watcher): if watcher: watcher.stop() self.notify() # wake the timer watcher # PYTHON-752 Thread might have just been created and not started with self._lock_thread: self._thread.join(timeout=1.0) if self._thread.is_alive(): log.warning( "Event loop thread could not be joined, so shutdown may not be clean. 
" "Please call Cluster.shutdown() to avoid this.") log.debug("Event loop thread was joined") def add_timer(self, timer): self._timers.add_timer(timer) self._notifier.send() # wake up in case this timer is earlier def _update_timer(self): if not self._shutdown: next_end = self._timers.service_timeouts() if next_end: self._loop_timer.start(next_end - time.time()) # timer handles negative values else: self._loop_timer.stop() def _on_loop_timer(self): self._timers.service_timeouts() def notify(self): self._notifier.send() def connection_created(self, conn): with self._conn_set_lock: new_live_conns = self._live_conns.copy() new_live_conns.add(conn) self._live_conns = new_live_conns new_new_conns = self._new_conns.copy() new_new_conns.add(conn) self._new_conns = new_new_conns def connection_destroyed(self, conn): with self._conn_set_lock: new_live_conns = self._live_conns.copy() new_live_conns.discard(conn) self._live_conns = new_live_conns new_closed_conns = self._closed_conns.copy() new_closed_conns.add(conn) self._closed_conns = new_closed_conns self._notifier.send() def _loop_will_run(self, prepare): changed = False for conn in self._live_conns: if not conn.deque and conn._write_watcher_is_active: if conn._write_watcher: conn._write_watcher.stop() conn._write_watcher_is_active = False changed = True elif conn.deque and not conn._write_watcher_is_active: conn._write_watcher.start() conn._write_watcher_is_active = True changed = True if self._new_conns: with self._conn_set_lock: to_start = self._new_conns self._new_conns = set() for conn in to_start: conn._read_watcher.start() changed = True if self._closed_conns: with self._conn_set_lock: to_stop = self._closed_conns self._closed_conns = set() for conn in to_stop: if conn._write_watcher: conn._write_watcher.stop() # clear reference cycles from IO callback del conn._write_watcher if conn._read_watcher: conn._read_watcher.stop() # clear reference cycles from IO callback del conn._read_watcher changed = True # TODO: update to do connection management, timer updates through dedicated async 'notifier' callbacks self._update_timer() if changed: self._notifier.send() +_global_loop = None +atexit.register(partial(_cleanup, _global_loop)) + + class LibevConnection(Connection): """ An implementation of :class:`.Connection` that uses libev for its event loop. 
""" - _libevloop = None _write_watcher_is_active = False _read_watcher = None _write_watcher = None _socket = None @classmethod def initialize_reactor(cls): - if not cls._libevloop: - cls._libevloop = LibevLoop() + global _global_loop + if not _global_loop: + _global_loop = LibevLoop() else: - if cls._libevloop._pid != os.getpid(): + if _global_loop._pid != os.getpid(): log.debug("Detected fork, clearing and reinitializing reactor state") cls.handle_fork() - cls._libevloop = LibevLoop() + _global_loop = LibevLoop() @classmethod def handle_fork(cls): - if cls._libevloop: - cls._libevloop._cleanup() - cls._libevloop = None + global _global_loop + if _global_loop: + _global_loop._cleanup() + _global_loop = None @classmethod def create_timer(cls, timeout, callback): timer = Timer(timeout, callback) - cls._libevloop.add_timer(timer) + _global_loop.add_timer(timer) return timer def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) self.deque = deque() self._deque_lock = Lock() self._connect_socket() self._socket.setblocking(0) - with self._libevloop._lock: - self._read_watcher = libev.IO(self._socket.fileno(), libev.EV_READ, self._libevloop._loop, self.handle_read) - self._write_watcher = libev.IO(self._socket.fileno(), libev.EV_WRITE, self._libevloop._loop, self.handle_write) + with _global_loop._lock: + self._read_watcher = libev.IO(self._socket.fileno(), libev.EV_READ, _global_loop._loop, self.handle_read) + self._write_watcher = libev.IO(self._socket.fileno(), libev.EV_WRITE, _global_loop._loop, self.handle_write) self._send_options_message() - self._libevloop.connection_created(self) + _global_loop.connection_created(self) # start the global event loop if needed - self._libevloop.maybe_start() + _global_loop.maybe_start() def close(self): with self.lock: if self.is_closed: return self.is_closed = True log.debug("Closing connection (%s) to %s", id(self), self.host) - self._libevloop.connection_destroyed(self) + + _global_loop.connection_destroyed(self) self._socket.close() log.debug("Closed socket to %s", self.host) # don't leave in-progress operations hanging if not self.is_defunct: self.error_all_requests( ConnectionShutdown("Connection to %s was closed" % self.host)) def handle_write(self, watcher, revents, errno=None): if revents & libev.EV_ERROR: if errno: exc = IOError(errno, os.strerror(errno)) else: exc = Exception("libev reported an error") self.defunct(exc) return while True: try: with self._deque_lock: next_msg = self.deque.popleft() except IndexError: return try: sent = self._socket.send(next_msg) except socket.error as err: if (err.args[0] in NONBLOCKING): with self._deque_lock: self.deque.appendleft(next_msg) else: self.defunct(err) return else: if sent < len(next_msg): with self._deque_lock: self.deque.appendleft(next_msg[sent:]) def handle_read(self, watcher, revents, errno=None): if revents & libev.EV_ERROR: if errno: exc = IOError(errno, os.strerror(errno)) else: exc = Exception("libev reported an error") self.defunct(exc) return try: while True: buf = self._socket.recv(self.in_buffer_size) self._iobuf.write(buf) if len(buf) < self.in_buffer_size: break except socket.error as err: if ssl and isinstance(err, ssl.SSLError): - if err.args[0] not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + return + else: self.defunct(err) return - elif err.args[0] not in NONBLOCKING: + elif err.args[0] in NONBLOCKING: + return + else: self.defunct(err) return if self._iobuf.tell(): 
self.process_io_buffer() else: log.debug("Connection %s closed by server", self) self.close() def push(self, data): sabs = self.out_buffer_size if len(data) > sabs: chunks = [] for i in range(0, len(data), sabs): chunks.append(data[i:i + sabs]) else: chunks = [data] with self._deque_lock: self.deque.extend(chunks) - self._libevloop.notify() + _global_loop.notify() diff --git a/cassandra/io/twistedreactor.py b/cassandra/io/twistedreactor.py index 8121c78..3611cdf 100644 --- a/cassandra/io/twistedreactor.py +++ b/cassandra/io/twistedreactor.py @@ -1,317 +1,317 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Module that implements an event loop based on twisted ( https://twistedmatrix.com ). """ import atexit from functools import partial import logging from threading import Thread, Lock import time from twisted.internet import reactor, protocol import weakref from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager log = logging.getLogger(__name__) def _cleanup(cleanup_weakref): try: cleanup_weakref()._cleanup() except ReferenceError: return class TwistedConnectionProtocol(protocol.Protocol): """ Twisted Protocol class for handling data received and connection made events. """ def __init__(self): self.connection = None def dataReceived(self, data): """ Callback function that is called when data has been received on the connection. Reaches back to the Connection object and queues the data for processing. """ self.connection._iobuf.write(data) self.connection.handle_read() def connectionMade(self): """ Callback function that is called when a connection has succeeded. Reaches back to the Connection object and confirms that the connection is ready. """ try: # Non SSL connection self.connection = self.transport.connector.factory.conn except AttributeError: # SSL connection self.connection = self.transport.connector.factory.wrappedFactory.conn self.connection.client_connection_made(self.transport) def connectionLost(self, reason): # reason is a Failure instance self.connection.defunct(reason.value) class TwistedConnectionClientFactory(protocol.ClientFactory): def __init__(self, connection): # ClientFactory does not define __init__() in parent classes # and does not inherit from object. self.conn = connection def buildProtocol(self, addr): """ Twisted function that defines which kind of protocol to use in the ClientFactory. """ return TwistedConnectionProtocol() def clientConnectionFailed(self, connector, reason): """ Overridden twisted callback which is called when the connection attempt fails. """ log.debug("Connect failed: %s", reason) self.conn.defunct(reason.value) def clientConnectionLost(self, connector, reason): """ Overridden twisted callback which is called when the connection goes away (cleanly or otherwise). It should be safe to call defunct() here instead of just close, because we can assume that if the connection was closed cleanly, there are no requests to error out. 
If this assumption turns out to be false, we can call close() instead of defunct() when "reason" is an appropriate type. """ log.debug("Connect lost: %s", reason) self.conn.defunct(reason.value) class TwistedLoop(object): _lock = None _thread = None _timeout_task = None _timeout = None def __init__(self): self._lock = Lock() self._timers = TimerManager() def maybe_start(self): with self._lock: if not reactor.running: self._thread = Thread(target=reactor.run, - name="cassandra_driver_event_loop", + name="cassandra_driver_twisted_event_loop", kwargs={'installSignalHandlers': False}) self._thread.daemon = True self._thread.start() atexit.register(partial(_cleanup, weakref.ref(self))) def _cleanup(self): if self._thread: reactor.callFromThread(reactor.stop) self._thread.join(timeout=1.0) if self._thread.is_alive(): log.warning("Event loop thread could not be joined, so " "shutdown may not be clean. Please call " "Cluster.shutdown() to avoid this.") log.debug("Event loop thread was joined") def add_timer(self, timer): self._timers.add_timer(timer) # callFromThread to schedule from the loop thread, where # the timeout task can safely be modified reactor.callFromThread(self._schedule_timeout, timer.end) def _schedule_timeout(self, next_timeout): if next_timeout: delay = max(next_timeout - time.time(), 0) if self._timeout_task and self._timeout_task.active(): if next_timeout < self._timeout: self._timeout_task.reset(delay) self._timeout = next_timeout else: self._timeout_task = reactor.callLater(delay, self._on_loop_timer) self._timeout = next_timeout def _on_loop_timer(self): self._timers.service_timeouts() self._schedule_timeout(self._timers.next_timeout) try: from twisted.internet import ssl import OpenSSL.crypto from OpenSSL.crypto import load_certificate, FILETYPE_PEM class _SSLContextFactory(ssl.ClientContextFactory): def __init__(self, ssl_options, check_hostname, host): self.ssl_options = ssl_options self.check_hostname = check_hostname self.host = host def getContext(self): # This version has to be OpenSSL.SSL.DESIRED_VERSION # instead of ssl.DESIRED_VERSION as in other loops self.method = self.ssl_options["ssl_version"] context = ssl.ClientContextFactory.getContext(self) if "certfile" in self.ssl_options: context.use_certificate_file(self.ssl_options["certfile"]) if "keyfile" in self.ssl_options: context.use_privatekey_file(self.ssl_options["keyfile"]) if "ca_certs" in self.ssl_options: x509 = load_certificate(FILETYPE_PEM, open(self.ssl_options["ca_certs"]).read()) store = context.get_cert_store() store.add_cert(x509) if "cert_reqs" in self.ssl_options: # This expects OpenSSL.SSL.VERIFY_NONE/OpenSSL.SSL.VERIFY_PEER # or OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT context.set_verify(self.ssl_options["cert_reqs"], callback=self.verify_callback) return context def verify_callback(self, connection, x509, errnum, errdepth, ok): if ok: if self.check_hostname and self.host != x509.get_subject().commonName: return False return ok _HAS_SSL = True except ImportError as e: _HAS_SSL = False class TwistedConnection(Connection): """ An implementation of :class:`.Connection` that utilizes the Twisted event loop. """ _loop = None @classmethod def initialize_reactor(cls): if not cls._loop: cls._loop = TwistedLoop() @classmethod def create_timer(cls, timeout, callback): timer = Timer(timeout, callback) cls._loop.add_timer(timer) return timer def __init__(self, *args, **kwargs): """ Initialization method. 
Note that we can't call reactor methods directly here because it's not thread-safe, so we schedule the reactor/connection stuff to be run from the event loop thread when it gets the chance. """ Connection.__init__(self, *args, **kwargs) self.is_closed = True self.connector = None self.transport = None reactor.callFromThread(self.add_connection) self._loop.maybe_start() def add_connection(self): """ Convenience function to connect and store the resulting connector. """ if self.ssl_options: if not _HAS_SSL: raise ImportError( str(e) + ', pyOpenSSL must be installed to enable SSL support with the Twisted event loop' ) self.connector = reactor.connectSSL( host=self.host, port=self.port, factory=TwistedConnectionClientFactory(self), contextFactory=_SSLContextFactory(self.ssl_options, self._check_hostname, self.host), timeout=self.connect_timeout) else: self.connector = reactor.connectTCP( host=self.host, port=self.port, factory=TwistedConnectionClientFactory(self), timeout=self.connect_timeout) def client_connection_made(self, transport): """ Called by twisted protocol when a connection attempt has succeeded. """ with self.lock: self.is_closed = False self.transport = transport self._send_options_message() def close(self): """ Disconnect and error-out all requests. """ with self.lock: if self.is_closed: return self.is_closed = True log.debug("Closing connection (%s) to %s", id(self), self.host) reactor.callFromThread(self.connector.disconnect) log.debug("Closed socket to %s", self.host) if not self.is_defunct: self.error_all_requests( ConnectionShutdown("Connection to %s was closed" % self.host)) # don't leave in-progress operations hanging self.connected_event.set() def handle_read(self): """ Process the incoming data buffer. """ self.process_io_buffer() def push(self, data): """ This function is called when outgoing data should be queued for sending. Note that we can't call transport.write() directly because it is not thread-safe, so we schedule it to run from within the event loop when it gets the chance. """ reactor.callFromThread(self.transport.write, data) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index b8f5f11..377ea4d 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -1,2620 +1,2821 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from binascii import unhexlify -from bisect import bisect_right +from bisect import bisect_left from collections import defaultdict, Mapping from functools import total_ordering from hashlib import md5 from itertools import islice, cycle import json import logging import re import six from six.moves import zip import sys from threading import RLock import struct import random murmur3 = None try: from cassandra.murmur3 import murmur3 except ImportError as e: pass from cassandra import SignatureDescriptor, ConsistencyLevel, InvalidRequest, Unauthorized import cassandra.cqltypes as types from cassandra.encoder import Encoder from cassandra.marshal import varint_unpack from cassandra.protocol import QueryMessage -from cassandra.query import dict_factory, bind_params, Statement +from cassandra.query import dict_factory, bind_params from cassandra.util import OrderedDict from cassandra.pool import HostDistance log = logging.getLogger(__name__) cql_keywords = set(( 'add', 'aggregate', 'all', 'allow', 'alter', 'and', 'apply', 'as', 'asc', 'ascii', 'authorize', 'batch', 'begin', 'bigint', 'blob', 'boolean', 'by', 'called', 'clustering', 'columnfamily', 'compact', 'contains', 'count', 'counter', 'create', 'custom', 'date', 'decimal', 'delete', 'desc', 'describe', 'distinct', 'double', 'drop', 'entries', 'execute', 'exists', 'filtering', 'finalfunc', 'float', 'from', 'frozen', 'full', 'function', 'functions', 'grant', 'if', 'in', 'index', 'inet', 'infinity', 'initcond', 'input', 'insert', 'int', 'into', 'is', 'json', 'key', 'keys', 'keyspace', 'keyspaces', 'language', 'limit', 'list', 'login', 'map', 'materialized', 'modify', 'nan', 'nologin', 'norecursive', 'nosuperuser', 'not', 'null', 'of', 'on', 'options', 'or', 'order', 'password', 'permission', 'permissions', 'primary', 'rename', 'replace', 'returns', 'revoke', 'role', 'roles', 'schema', 'select', 'set', 'sfunc', 'smallint', 'static', 'storage', 'stype', 'superuser', 'table', 'text', 'time', 'timestamp', 'timeuuid', 'tinyint', 'to', 'token', 'trigger', 'truncate', 'ttl', 'tuple', 'type', 'unlogged', 'update', 'use', 'user', 'users', 'using', 'uuid', 'values', 'varchar', 'varint', 'view', 'where', 'with', 'writetime' )) """ Set of keywords in CQL. Derived from .../cassandra/src/java/org/apache/cassandra/cql3/Cql.g """ cql_keywords_unreserved = set(( 'aggregate', 'all', 'as', 'ascii', 'bigint', 'blob', 'boolean', 'called', 'clustering', 'compact', 'contains', 'count', 'counter', 'custom', 'date', 'decimal', 'distinct', 'double', 'exists', 'filtering', 'finalfunc', 'float', 'frozen', 'function', 'functions', 'inet', 'initcond', 'input', 'int', 'json', 'key', 'keys', 'keyspaces', 'language', 'list', 'login', 'map', 'nologin', 'nosuperuser', 'options', 'password', 'permission', 'permissions', 'returns', 'role', 'roles', 'sfunc', 'smallint', 'static', 'storage', 'stype', 'superuser', 'text', 'time', 'timestamp', 'timeuuid', 'tinyint', 'trigger', 'ttl', 'tuple', 'type', 'user', 'users', 'uuid', 'values', 'varchar', 'varint', 'writetime' )) """ Set of unreserved keywords in CQL. Derived from .../cassandra/src/java/org/apache/cassandra/cql3/Cql.g """ cql_keywords_reserved = cql_keywords - cql_keywords_unreserved """ Set of reserved keywords in CQL. """ _encoder = Encoder() class Metadata(object): """ Holds a representation of the cluster schema and topology. """ cluster_name = None """ The string name of the cluster. """ keyspaces = None """ A map from keyspace names to matching :class:`~.KeyspaceMetadata` instances. 
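    For example, a sketch assuming a connected :class:`~.Cluster` named
    ``cluster``::

        ks_meta = cluster.metadata.keyspaces['system']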
""" partitioner = None """ The string name of the partitioner for the cluster. """ token_map = None """ A :class:`~.TokenMap` instance describing the ring topology. """ def __init__(self): self.keyspaces = {} self._hosts = {} self._hosts_lock = RLock() def export_schema_as_string(self): """ Returns a string that can be executed as a query in order to recreate the entire schema. The string is formatted to be human readable. """ return "\n\n".join(ks.export_as_string() for ks in self.keyspaces.values()) def refresh(self, connection, timeout, target_type=None, change_type=None, **kwargs): server_version = self.get_host(connection.host).release_version parser = get_schema_parser(connection, server_version, timeout) if not target_type: self._rebuild_all(parser) return tt_lower = target_type.lower() try: parse_method = getattr(parser, 'get_' + tt_lower) meta = parse_method(self.keyspaces, **kwargs) if meta: update_method = getattr(self, '_update_' + tt_lower) if tt_lower == 'keyspace' and connection.protocol_version < 3: # we didn't have 'type' target in legacy protocol versions, so we need to query those too user_types = parser.get_types_map(self.keyspaces, **kwargs) self._update_keyspace(meta, user_types) else: update_method(meta) else: drop_method = getattr(self, '_drop_' + tt_lower) drop_method(**kwargs) except AttributeError: raise ValueError("Unknown schema target_type: '%s'" % target_type) def _rebuild_all(self, parser): current_keyspaces = set() for keyspace_meta in parser.get_all_keyspaces(): current_keyspaces.add(keyspace_meta.name) old_keyspace_meta = self.keyspaces.get(keyspace_meta.name, None) self.keyspaces[keyspace_meta.name] = keyspace_meta if old_keyspace_meta: self._keyspace_updated(keyspace_meta.name) else: self._keyspace_added(keyspace_meta.name) # remove not-just-added keyspaces removed_keyspaces = [name for name in self.keyspaces.keys() if name not in current_keyspaces] self.keyspaces = dict((name, meta) for name, meta in self.keyspaces.items() if name in current_keyspaces) for ksname in removed_keyspaces: self._keyspace_removed(ksname) def _update_keyspace(self, keyspace_meta, new_user_types=None): ks_name = keyspace_meta.name old_keyspace_meta = self.keyspaces.get(ks_name, None) self.keyspaces[ks_name] = keyspace_meta if old_keyspace_meta: keyspace_meta.tables = old_keyspace_meta.tables keyspace_meta.user_types = new_user_types if new_user_types is not None else old_keyspace_meta.user_types keyspace_meta.indexes = old_keyspace_meta.indexes keyspace_meta.functions = old_keyspace_meta.functions keyspace_meta.aggregates = old_keyspace_meta.aggregates keyspace_meta.views = old_keyspace_meta.views if (keyspace_meta.replication_strategy != old_keyspace_meta.replication_strategy): self._keyspace_updated(ks_name) else: self._keyspace_added(ks_name) def _drop_keyspace(self, keyspace): if self.keyspaces.pop(keyspace, None): self._keyspace_removed(keyspace) def _update_table(self, meta): try: keyspace_meta = self.keyspaces[meta.keyspace_name] # this is unfortunate, but protocol v4 does not differentiate # between events for tables and views. .get_table will # return one or the other based on the query results. # Here we deal with that. 
if isinstance(meta, TableMetadata): keyspace_meta._add_table_metadata(meta) else: keyspace_meta._add_view_metadata(meta) except KeyError: # can happen if keyspace disappears while processing async event pass def _drop_table(self, keyspace, table): try: keyspace_meta = self.keyspaces[keyspace] keyspace_meta._drop_table_metadata(table) # handles either table or view except KeyError: # can happen if keyspace disappears while processing async event pass def _update_type(self, type_meta): try: self.keyspaces[type_meta.keyspace].user_types[type_meta.name] = type_meta except KeyError: # can happen if keyspace disappears while processing async event pass def _drop_type(self, keyspace, type): try: self.keyspaces[keyspace].user_types.pop(type, None) except KeyError: # can happen if keyspace disappears while processing async event pass def _update_function(self, function_meta): try: self.keyspaces[function_meta.keyspace].functions[function_meta.signature] = function_meta except KeyError: # can happen if keyspace disappears while processing async event pass def _drop_function(self, keyspace, function): try: self.keyspaces[keyspace].functions.pop(function.signature, None) except KeyError: pass def _update_aggregate(self, aggregate_meta): try: self.keyspaces[aggregate_meta.keyspace].aggregates[aggregate_meta.signature] = aggregate_meta except KeyError: pass def _drop_aggregate(self, keyspace, aggregate): try: self.keyspaces[keyspace].aggregates.pop(aggregate.signature, None) except KeyError: pass def _keyspace_added(self, ksname): if self.token_map: self.token_map.rebuild_keyspace(ksname, build_if_absent=False) def _keyspace_updated(self, ksname): if self.token_map: self.token_map.rebuild_keyspace(ksname, build_if_absent=False) def _keyspace_removed(self, ksname): if self.token_map: self.token_map.remove_keyspace(ksname) def rebuild_token_map(self, partitioner, token_map): """ Rebuild our view of the topology from fresh rows from the system topology tables. For internal use only. """ self.partitioner = partitioner if partitioner.endswith('RandomPartitioner'): token_class = MD5Token elif partitioner.endswith('Murmur3Partitioner'): token_class = Murmur3Token elif partitioner.endswith('ByteOrderedPartitioner'): token_class = BytesToken else: self.token_map = None return token_to_host_owner = {} ring = [] for host, token_strings in six.iteritems(token_map): for token_string in token_strings: token = token_class.from_string(token_string) ring.append(token) token_to_host_owner[token] = host all_tokens = sorted(ring) self.token_map = TokenMap( token_class, token_to_host_owner, all_tokens, self) def get_replicas(self, keyspace, key): """ Returns a list of :class:`.Host` instances that are replicas for a given partition key. """ t = self.token_map if not t: return [] try: return t.get_replicas(keyspace, t.token_class.from_key(key)) except NoMurmur3: return [] def can_support_partitioner(self): if self.partitioner.endswith('Murmur3Partitioner') and murmur3 is None: return False else: return True def add_or_return_host(self, host): """ Returns a tuple (host, new), where ``host`` is a Host instance, and ``new`` is a bool indicating whether the host was newly added. 
""" with self._hosts_lock: try: return self._hosts[host.address], False except KeyError: self._hosts[host.address] = host return host, True def remove_host(self, host): with self._hosts_lock: return bool(self._hosts.pop(host.address, False)) def get_host(self, address): return self._hosts.get(address) def all_hosts(self): """ Returns a list of all known :class:`.Host` instances in the cluster. """ with self._hosts_lock: return list(self._hosts.values()) REPLICATION_STRATEGY_CLASS_PREFIX = "org.apache.cassandra.locator." def trim_if_startswith(s, prefix): if s.startswith(prefix): return s[len(prefix):] return s _replication_strategies = {} class ReplicationStrategyTypeType(type): def __new__(metacls, name, bases, dct): dct.setdefault('name', name) cls = type.__new__(metacls, name, bases, dct) if not name.startswith('_'): _replication_strategies[name] = cls return cls @six.add_metaclass(ReplicationStrategyTypeType) class _ReplicationStrategy(object): options_map = None @classmethod def create(cls, strategy_class, options_map): if not strategy_class: return None strategy_name = trim_if_startswith(strategy_class, REPLICATION_STRATEGY_CLASS_PREFIX) rs_class = _replication_strategies.get(strategy_name, None) if rs_class is None: rs_class = _UnknownStrategyBuilder(strategy_name) _replication_strategies[strategy_name] = rs_class try: rs_instance = rs_class(options_map) except Exception as exc: log.warning("Failed creating %s with options %s: %s", strategy_name, options_map, exc) return None return rs_instance def make_token_replica_map(self, token_to_host_owner, ring): raise NotImplementedError() def export_for_schema(self): raise NotImplementedError() ReplicationStrategy = _ReplicationStrategy class _UnknownStrategyBuilder(object): def __init__(self, name): self.name = name def __call__(self, options_map): strategy_instance = _UnknownStrategy(self.name, options_map) return strategy_instance class _UnknownStrategy(ReplicationStrategy): def __init__(self, name, options_map): self.name = name self.options_map = options_map.copy() if options_map is not None else dict() self.options_map['class'] = self.name def __eq__(self, other): - return (isinstance(other, _UnknownStrategy) - and self.name == other.name - and self.options_map == other.options_map) + return (isinstance(other, _UnknownStrategy) and + self.name == other.name and + self.options_map == other.options_map) def export_for_schema(self): """ Returns a string version of these replication options which are suitable for use in a CREATE KEYSPACE statement. """ if self.options_map: return dict((str(key), str(value)) for key, value in self.options_map.items()) return "{'class': '%s'}" % (self.name, ) def make_token_replica_map(self, token_to_host_owner, ring): return {} class SimpleStrategy(ReplicationStrategy): replication_factor = None """ The replication factor for this keyspace. 
""" def __init__(self, options_map): try: self.replication_factor = int(options_map['replication_factor']) except Exception: raise ValueError("SimpleStrategy requires an integer 'replication_factor' option") def make_token_replica_map(self, token_to_host_owner, ring): replica_map = {} for i in range(len(ring)): j, hosts = 0, list() while len(hosts) < self.replication_factor and j < len(ring): token = ring[(i + j) % len(ring)] host = token_to_host_owner[token] if host not in hosts: hosts.append(host) j += 1 replica_map[ring[i]] = hosts return replica_map def export_for_schema(self): """ Returns a string version of these replication options which are suitable for use in a CREATE KEYSPACE statement. """ return "{'class': 'SimpleStrategy', 'replication_factor': '%d'}" \ % (self.replication_factor,) def __eq__(self, other): if not isinstance(other, SimpleStrategy): return False return self.replication_factor == other.replication_factor class NetworkTopologyStrategy(ReplicationStrategy): dc_replication_factors = None """ A map of datacenter names to the replication factor for that DC. """ def __init__(self, dc_replication_factors): self.dc_replication_factors = dict( (str(k), int(v)) for k, v in dc_replication_factors.items()) def make_token_replica_map(self, token_to_host_owner, ring): dc_rf_map = dict((dc, int(rf)) for dc, rf in self.dc_replication_factors.items() if rf > 0) # build a map of DCs to lists of indexes into `ring` for tokens that # belong to that DC dc_to_token_offset = defaultdict(list) dc_racks = defaultdict(set) hosts_per_dc = defaultdict(set) for i, token in enumerate(ring): host = token_to_host_owner[token] dc_to_token_offset[host.datacenter].append(i) if host.datacenter and host.rack: dc_racks[host.datacenter].add(host.rack) hosts_per_dc[host.datacenter].add(host) # A map of DCs to an index into the dc_to_token_offset value for that dc. # This is how we keep track of advancing around the ring for each DC. dc_to_current_index = defaultdict(int) replica_map = defaultdict(list) for i in range(len(ring)): replicas = replica_map[ring[i]] # go through each DC and find the replicas in that DC for dc in dc_to_token_offset.keys(): if dc not in dc_rf_map: continue # advance our per-DC index until we're up to at least the # current token in the ring token_offsets = dc_to_token_offset[dc] index = dc_to_current_index[dc] num_tokens = len(token_offsets) while index < num_tokens and token_offsets[index] < i: index += 1 dc_to_current_index[dc] = index replicas_remaining = dc_rf_map[dc] replicas_this_dc = 0 skipped_hosts = [] racks_placed = set() racks_this_dc = dc_racks[dc] hosts_this_dc = len(hosts_per_dc[dc]) for token_offset in islice(cycle(token_offsets), index, index + num_tokens): host = token_to_host_owner[ring[token_offset]] if replicas_remaining == 0 or replicas_this_dc == hosts_this_dc: break if host in replicas: continue if host.rack in racks_placed and len(racks_placed) < len(racks_this_dc): skipped_hosts.append(host) continue replicas.append(host) replicas_this_dc += 1 replicas_remaining -= 1 racks_placed.add(host.rack) if len(racks_placed) == len(racks_this_dc): for host in skipped_hosts: if replicas_remaining == 0: break replicas.append(host) replicas_remaining -= 1 del skipped_hosts[:] return replica_map def export_for_schema(self): """ Returns a string version of these replication options which are suitable for use in a CREATE KEYSPACE statement. 
""" ret = "{'class': 'NetworkTopologyStrategy'" for dc, repl_factor in sorted(self.dc_replication_factors.items()): ret += ", '%s': '%d'" % (dc, repl_factor) return ret + "}" def __eq__(self, other): if not isinstance(other, NetworkTopologyStrategy): return False return self.dc_replication_factors == other.dc_replication_factors class LocalStrategy(ReplicationStrategy): def __init__(self, options_map): pass def make_token_replica_map(self, token_to_host_owner, ring): return {} def export_for_schema(self): """ Returns a string version of these replication options which are suitable for use in a CREATE KEYSPACE statement. """ return "{'class': 'LocalStrategy'}" def __eq__(self, other): return isinstance(other, LocalStrategy) class KeyspaceMetadata(object): """ A representation of the schema for a single keyspace. """ name = None """ The string name of the keyspace. """ durable_writes = True """ A boolean indicating whether durable writes are enabled for this keyspace or not. """ replication_strategy = None """ A :class:`.ReplicationStrategy` subclass object. """ tables = None """ A map from table names to instances of :class:`~.TableMetadata`. """ indexes = None """ A dict mapping index names to :class:`.IndexMetadata` instances. """ user_types = None """ A map from user-defined type names to instances of :class:`~cassandra.metadata.UserType`. .. versionadded:: 2.1.0 """ functions = None """ A map from user-defined function signatures to instances of :class:`~cassandra.metadata.Function`. .. versionadded:: 2.6.0 """ aggregates = None """ A map from user-defined aggregate signatures to instances of :class:`~cassandra.metadata.Aggregate`. .. versionadded:: 2.6.0 """ views = None """ A dict mapping view names to :class:`.MaterializedViewMetadata` instances. """ + virtual = False + """ + A boolean indicating if this is a virtual keyspace or not. Always ``False`` + for clusters running pre-4.0 versions of Cassandra. + + .. versionadded:: 3.15 + """ + _exc_info = None """ set if metadata parsing failed """ def __init__(self, name, durable_writes, strategy_class, strategy_options): self.name = name self.durable_writes = durable_writes self.replication_strategy = ReplicationStrategy.create(strategy_class, strategy_options) self.tables = {} self.indexes = {} self.user_types = {} self.functions = {} self.aggregates = {} self.views = {} def export_as_string(self): """ Returns a CQL query string that can be used to recreate the entire keyspace, including user-defined types and tables. 
""" - cql = "\n\n".join([self.as_cql_query() + ';'] - + self.user_type_strings() - + [f.export_as_string() for f in self.functions.values()] - + [a.export_as_string() for a in self.aggregates.values()] - + [t.export_as_string() for t in self.tables.values()]) + cql = "\n\n".join([self.as_cql_query() + ';'] + + self.user_type_strings() + + [f.export_as_string() for f in self.functions.values()] + + [a.export_as_string() for a in self.aggregates.values()] + + [t.export_as_string() for t in self.tables.values()]) if self._exc_info: import traceback ret = "/*\nWarning: Keyspace %s is incomplete because of an error processing metadata.\n" % \ (self.name) for line in traceback.format_exception(*self._exc_info): ret += line ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % cql return ret + if self.virtual: + return ("/*\nWarning: Keyspace {ks} is a virtual keyspace and cannot be recreated with CQL.\n" + "Structure, for reference:*/\n" + "{cql}\n" + "").format(ks=self.name, cql=cql) return cql def as_cql_query(self): """ Returns a CQL query string that can be used to recreate just this keyspace, not including user-defined types and tables. """ + if self.virtual: + return "// VIRTUAL KEYSPACE {}".format(protect_name(self.name)) ret = "CREATE KEYSPACE %s WITH replication = %s " % ( protect_name(self.name), self.replication_strategy.export_for_schema()) return ret + (' AND durable_writes = %s' % ("true" if self.durable_writes else "false")) def user_type_strings(self): user_type_strings = [] user_types = self.user_types.copy() keys = sorted(user_types.keys()) for k in keys: if k in user_types: self.resolve_user_types(k, user_types, user_type_strings) return user_type_strings def resolve_user_types(self, key, user_types, user_type_strings): user_type = user_types.pop(key) for type_name in user_type.field_types: for sub_type in types.cql_types_from_string(type_name): if sub_type in user_types: self.resolve_user_types(sub_type, user_types, user_type_strings) user_type_strings.append(user_type.export_as_string()) def _add_table_metadata(self, table_metadata): old_indexes = {} old_meta = self.tables.get(table_metadata.name, None) if old_meta: # views are not queried with table, so they must be transferred to new table_metadata.views = old_meta.views # indexes will be updated with what is on the new metadata old_indexes = old_meta.indexes # note the intentional order of add before remove # this makes sure the maps are never absent something that existed before this update for index_name, index_metadata in six.iteritems(table_metadata.indexes): self.indexes[index_name] = index_metadata for index_name in (n for n in old_indexes if n not in table_metadata.indexes): self.indexes.pop(index_name, None) self.tables[table_metadata.name] = table_metadata def _drop_table_metadata(self, table_name): table_meta = self.tables.pop(table_name, None) if table_meta: for index_name in table_meta.indexes: self.indexes.pop(index_name, None) for view_name in table_meta.views: self.views.pop(view_name, None) return # we can't tell table drops from views, so drop both # (name is unique among them, within a keyspace) view_meta = self.views.pop(table_name, None) if view_meta: try: self.tables[view_meta.base_table_name].views.pop(table_name, None) except KeyError: pass def _add_view_metadata(self, view_metadata): try: self.tables[view_metadata.base_table_name].views[view_metadata.name] = view_metadata self.views[view_metadata.name] = view_metadata except KeyError: pass 
class UserType(object): """ A user defined type, as created by ``CREATE TYPE`` statements. User-defined types were introduced in Cassandra 2.1. .. versionadded:: 2.1.0 """ keyspace = None """ The string name of the keyspace in which this type is defined. """ name = None """ The name of this type. """ field_names = None """ An ordered list of the names for each field in this user-defined type. """ field_types = None """ An ordered list of the types for each field in this user-defined type. """ def __init__(self, keyspace, name, field_names, field_types): self.keyspace = keyspace self.name = name # non-frozen collections can return None self.field_names = field_names or [] self.field_types = field_types or [] def as_cql_query(self, formatted=False): """ Returns a CQL query that can be used to recreate this type. If `formatted` is set to :const:`True`, extra whitespace will be added to make the query more readable. """ ret = "CREATE TYPE %s.%s (%s" % ( protect_name(self.keyspace), protect_name(self.name), "\n" if formatted else "") if formatted: field_join = ",\n" padding = " " else: field_join = ", " padding = "" fields = [] for field_name, field_type in zip(self.field_names, self.field_types): fields.append("%s %s" % (protect_name(field_name), field_type)) ret += field_join.join("%s%s" % (padding, field) for field in fields) ret += "\n)" if formatted else ")" return ret def export_as_string(self): return self.as_cql_query(formatted=True) + ';' class Aggregate(object): """ A user defined aggregate function, as created by ``CREATE AGGREGATE`` statements. Aggregate functions were introduced in Cassandra 2.2 .. versionadded:: 2.6.0 """ keyspace = None """ The string name of the keyspace in which this aggregate is defined """ name = None """ The name of this aggregate """ argument_types = None """ An ordered list of the types for each argument to the aggregate """ final_func = None """ Name of a final function """ initial_condition = None """ Initial condition of the aggregate """ return_type = None """ Return type of the aggregate """ state_func = None """ Name of a state function """ state_type = None """ Type of the aggregate state """ def __init__(self, keyspace, name, argument_types, state_func, state_type, final_func, initial_condition, return_type): self.keyspace = keyspace self.name = name self.argument_types = argument_types self.state_func = state_func self.state_type = state_type self.final_func = final_func self.initial_condition = initial_condition self.return_type = return_type def as_cql_query(self, formatted=False): """ Returns a CQL query that can be used to recreate this aggregate. If `formatted` is set to :const:`True`, extra whitespace will be added to make the query more readable. 
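    The formatted output has the general shape below; the keyspace, aggregate,
    and function names are illustrative::

        CREATE AGGREGATE my_ks.average(int)
            SFUNC avg_state
            STYPE tuple<int, bigint>
            FINALFUNC avg_final
            INITCOND (0, 0)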
""" sep = '\n ' if formatted else ' ' keyspace = protect_name(self.keyspace) name = protect_name(self.name) type_list = ', '.join(self.argument_types) state_func = protect_name(self.state_func) state_type = self.state_type ret = "CREATE AGGREGATE %(keyspace)s.%(name)s(%(type_list)s)%(sep)s" \ "SFUNC %(state_func)s%(sep)s" \ "STYPE %(state_type)s" % locals() ret += ''.join((sep, 'FINALFUNC ', protect_name(self.final_func))) if self.final_func else '' ret += ''.join((sep, 'INITCOND ', self.initial_condition)) if self.initial_condition is not None else '' return ret def export_as_string(self): return self.as_cql_query(formatted=True) + ';' @property def signature(self): return SignatureDescriptor.format_signature(self.name, self.argument_types) class Function(object): """ A user defined function, as created by ``CREATE FUNCTION`` statements. User-defined functions were introduced in Cassandra 2.2 .. versionadded:: 2.6.0 """ keyspace = None """ The string name of the keyspace in which this function is defined """ name = None """ The name of this function """ argument_types = None """ An ordered list of the types for each argument to the function """ argument_names = None """ An ordered list of the names of each argument to the function """ return_type = None """ Return type of the function """ language = None """ Language of the function body """ body = None """ Function body string """ called_on_null_input = None """ Flag indicating whether this function should be called for rows with null values (convenience function to avoid handling nulls explicitly if the result will just be null) """ def __init__(self, keyspace, name, argument_types, argument_names, return_type, language, body, called_on_null_input): self.keyspace = keyspace self.name = name self.argument_types = argument_types # argument_types (frozen>) will always be a list # argument_name is not frozen in C* < 3.0 and may return None self.argument_names = argument_names or [] self.return_type = return_type self.language = language self.body = body self.called_on_null_input = called_on_null_input def as_cql_query(self, formatted=False): """ Returns a CQL query that can be used to recreate this function. If `formatted` is set to :const:`True`, extra whitespace will be added to make the query more readable. """ sep = '\n ' if formatted else ' ' keyspace = protect_name(self.keyspace) name = protect_name(self.name) arg_list = ', '.join(["%s %s" % (protect_name(n), t) for n, t in zip(self.argument_names, self.argument_types)]) typ = self.return_type lang = self.language body = self.body on_null = "CALLED" if self.called_on_null_input else "RETURNS NULL" return "CREATE FUNCTION %(keyspace)s.%(name)s(%(arg_list)s)%(sep)s" \ "%(on_null)s ON NULL INPUT%(sep)s" \ "RETURNS %(typ)s%(sep)s" \ "LANGUAGE %(lang)s%(sep)s" \ "AS $$%(body)s$$" % locals() def export_as_string(self): return self.as_cql_query(formatted=True) + ';' @property def signature(self): return SignatureDescriptor.format_signature(self.name, self.argument_types) class TableMetadata(object): """ A representation of the schema for a single table. """ keyspace_name = None """ String name of this Table's keyspace """ name = None """ The string name of the table. """ partition_key = None """ A list of :class:`.ColumnMetadata` instances representing the columns in the partition key for this table. This will always hold at least one column. """ clustering_key = None """ A list of :class:`.ColumnMetadata` instances representing the columns in the clustering key for this table. 
These are all of the :attr:`.primary_key` columns that are not in the :attr:`.partition_key`. Note that a table may have no clustering keys, in which case this will be an empty list. """ @property def primary_key(self): """ A list of :class:`.ColumnMetadata` representing the components of the primary key for this table. """ return self.partition_key + self.clustering_key columns = None """ A dict mapping column names to :class:`.ColumnMetadata` instances. """ indexes = None """ A dict mapping index names to :class:`.IndexMetadata` instances. """ is_compact_storage = False options = None """ A dict mapping table option names to their specific settings for this table. """ compaction_options = { "min_compaction_threshold": "min_threshold", "max_compaction_threshold": "max_threshold", "compaction_strategy_class": "class"} triggers = None """ A dict mapping trigger names to :class:`.TriggerMetadata` instances. """ views = None """ A dict mapping view names to :class:`.MaterializedViewMetadata` instances. """ _exc_info = None """ set if metadata parsing failed """ + virtual = False + """ + A boolean indicating if this is a virtual table or not. Always ``False`` + for clusters running pre-4.0 versions of Cassandra. + + .. versionadded:: 3.15 + """ + @property def is_cql_compatible(self): """ A boolean indicating if this table can be represented as CQL in export """ + if self.virtual: + return False comparator = getattr(self, 'comparator', None) if comparator: # no compact storage with more than one column beyond PK if there # are clustering columns incompatible = (self.is_compact_storage and len(self.columns) > len(self.primary_key) + 1 and len(self.clustering_key) >= 1) return not incompatible return True extensions = None """ Metadata describing configuration for table extensions """ - def __init__(self, keyspace_name, name, partition_key=None, clustering_key=None, columns=None, triggers=None, options=None): + def __init__(self, keyspace_name, name, partition_key=None, clustering_key=None, columns=None, triggers=None, options=None, virtual=False): self.keyspace_name = keyspace_name self.name = name self.partition_key = [] if partition_key is None else partition_key self.clustering_key = [] if clustering_key is None else clustering_key self.columns = OrderedDict() if columns is None else columns self.indexes = {} self.options = {} if options is None else options self.comparator = None self.triggers = OrderedDict() if triggers is None else triggers self.views = {} + self.virtual = virtual def export_as_string(self): """ Returns a string of CQL queries that can be used to recreate this table along with all indexes on it. The returned string is formatted to be human readable. 
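    A short sketch, assuming a connected ``cluster``; the keyspace and table
    names are illustrative::

        table_meta = cluster.metadata.keyspaces['my_ks'].tables['my_table']
        print(table_meta.export_as_string())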
""" if self._exc_info: import traceback ret = "/*\nWarning: Table %s.%s is incomplete because of an error processing metadata.\n" % \ (self.keyspace_name, self.name) for line in traceback.format_exception(*self._exc_info): ret += line ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % self._all_as_cql() elif not self.is_cql_compatible: # If we can't produce this table with CQL, comment inline ret = "/*\nWarning: Table %s.%s omitted because it has constructs not compatible with CQL (was created via legacy API).\n" % \ (self.keyspace_name, self.name) ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % self._all_as_cql() + elif self.virtual: + ret = ('/*\nWarning: Table {ks}.{tab} is a virtual table and cannot be recreated with CQL.\n' + 'Structure, for reference:\n' + '{cql}\n*/').format(ks=self.keyspace_name, tab=self.name, cql=self._all_as_cql()) + else: ret = self._all_as_cql() return ret def _all_as_cql(self): ret = self.as_cql_query(formatted=True) ret += ";" for index in self.indexes.values(): ret += "\n%s;" % index.as_cql_query() for trigger_meta in self.triggers.values(): ret += "\n%s;" % (trigger_meta.as_cql_query(),) for view_meta in self.views.values(): ret += "\n\n%s;" % (view_meta.as_cql_query(formatted=True),) if self.extensions: registry = _RegisteredExtensionType._extension_registry for k in six.viewkeys(registry) & self.extensions: # no viewkeys on OrderedMapSerializeKey ext = registry[k] cql = ext.after_table_cql(self, k, self.extensions[k]) if cql: ret += "\n\n%s" % (cql,) return ret def as_cql_query(self, formatted=False): """ Returns a CQL query that can be used to recreate this table (index creations are not included). If `formatted` is set to :const:`True`, extra whitespace will be added to make the query human readable. 
""" - ret = "CREATE TABLE %s.%s (%s" % ( + ret = "%s TABLE %s.%s (%s" % ( + ('VIRTUAL' if self.virtual else 'CREATE'), protect_name(self.keyspace_name), protect_name(self.name), "\n" if formatted else "") if formatted: column_join = ",\n" padding = " " else: column_join = ", " padding = "" columns = [] for col in self.columns.values(): columns.append("%s %s%s" % (protect_name(col.name), col.cql_type, ' static' if col.is_static else '')) if len(self.partition_key) == 1 and not self.clustering_key: columns[0] += " PRIMARY KEY" ret += column_join.join("%s%s" % (padding, col) for col in columns) # primary key if len(self.partition_key) > 1 or self.clustering_key: ret += "%s%sPRIMARY KEY (" % (column_join, padding) if len(self.partition_key) > 1: ret += "(%s)" % ", ".join(protect_name(col.name) for col in self.partition_key) else: ret += protect_name(self.partition_key[0].name) if self.clustering_key: ret += ", %s" % ", ".join(protect_name(col.name) for col in self.clustering_key) ret += ")" # properties ret += "%s) WITH " % ("\n" if formatted else "") ret += self._property_string(formatted, self.clustering_key, self.options, self.is_compact_storage) return ret @classmethod def _property_string(cls, formatted, clustering_key, options_map, is_compact_storage=False): properties = [] if is_compact_storage: properties.append("COMPACT STORAGE") if clustering_key: cluster_str = "CLUSTERING ORDER BY " inner = [] for col in clustering_key: ordering = "DESC" if col.is_reversed else "ASC" inner.append("%s %s" % (protect_name(col.name), ordering)) cluster_str += "(%s)" % ", ".join(inner) properties.append(cluster_str) properties.extend(cls._make_option_strings(options_map)) join_str = "\n AND " if formatted else " AND " return join_str.join(properties) @classmethod def _make_option_strings(cls, options_map): ret = [] options_copy = dict(options_map.items()) actual_options = json.loads(options_copy.pop('compaction_strategy_options', '{}')) value = options_copy.pop("compaction_strategy_class", None) actual_options.setdefault("class", value) compaction_option_strings = ["'%s': '%s'" % (k, v) for k, v in actual_options.items()] ret.append('compaction = {%s}' % ', '.join(compaction_option_strings)) for system_table_name in cls.compaction_options.keys(): options_copy.pop(system_table_name, None) # delete if present options_copy.pop('compaction_strategy_option', None) if not options_copy.get('compression'): params = json.loads(options_copy.pop('compression_parameters', '{}')) param_strings = ["'%s': '%s'" % (k, v) for k, v in params.items()] ret.append('compression = {%s}' % ', '.join(param_strings)) for name, value in options_copy.items(): if value is not None: if name == "comment": value = value or "" ret.append("%s = %s" % (name, protect_value(value))) return list(sorted(ret)) class TableExtensionInterface(object): """ Defines CQL/DDL for Cassandra table extensions. """ # limited API for now. Could be expanded as new extension types materialize -- "extend_option_strings", for example @classmethod def after_table_cql(cls, ext_key, ext_blob): """ Called to produce CQL/DDL to follow the table definition. Should contain requisite terminating semicolon(s). 
""" pass class _RegisteredExtensionType(type): _extension_registry = {} def __new__(mcs, name, bases, dct): cls = super(_RegisteredExtensionType, mcs).__new__(mcs, name, bases, dct) if name != 'RegisteredTableExtension': mcs._extension_registry[cls.name] = cls return cls @six.add_metaclass(_RegisteredExtensionType) class RegisteredTableExtension(TableExtensionInterface): """ Extending this class registers it by name (associated by key in the `system_schema.tables.extensions` map). """ name = None """ Name of the extension (key in the map) """ def protect_name(name): return maybe_escape_name(name) def protect_names(names): return [protect_name(n) for n in names] def protect_value(value): if value is None: return 'NULL' if isinstance(value, (int, float, bool)): return str(value).lower() return "'%s'" % value.replace("'", "''") valid_cql3_word_re = re.compile(r'^[a-z][0-9a-z_]*$') def is_valid_name(name): if name is None: return False if name.lower() in cql_keywords_reserved: return False return valid_cql3_word_re.match(name) is not None def maybe_escape_name(name): if is_valid_name(name): return name return escape_name(name) def escape_name(name): return '"%s"' % (name.replace('"', '""'),) class ColumnMetadata(object): """ A representation of a single column in a table. """ table = None """ The :class:`.TableMetadata` this column belongs to. """ name = None """ The string name of this column. """ cql_type = None """ The CQL type for the column. """ is_static = False """ If this column is static (available in Cassandra 2.1+), this will be :const:`True`, otherwise :const:`False`. """ is_reversed = False """ If this column is reversed (DESC) as in clustering order """ _cass_type = None def __init__(self, table_metadata, column_name, cql_type, is_static=False, is_reversed=False): self.table = table_metadata self.name = column_name self.cql_type = cql_type self.is_static = is_static self.is_reversed = is_reversed def __str__(self): return "%s %s" % (self.name, self.cql_type) class IndexMetadata(object): """ A representation of a secondary index on a column. """ keyspace_name = None """ A string name of the keyspace. """ table_name = None """ A string name of the table this index is on. """ name = None """ A string name for the index. """ kind = None """ A string representing the kind of index (COMPOSITE, CUSTOM,...). """ index_options = {} """ A dict of index options. """ def __init__(self, keyspace_name, table_name, index_name, kind, index_options): self.keyspace_name = keyspace_name self.table_name = table_name self.name = index_name self.kind = kind self.index_options = index_options def as_cql_query(self): """ Returns a CQL query that can be used to recreate this index. 
""" options = dict(self.index_options) index_target = options.pop("target") if self.kind != "CUSTOM": return "CREATE INDEX %s ON %s.%s (%s)" % ( protect_name(self.name), protect_name(self.keyspace_name), protect_name(self.table_name), index_target) else: class_name = options.pop("class_name") ret = "CREATE CUSTOM INDEX %s ON %s.%s (%s) USING '%s'" % ( protect_name(self.name), protect_name(self.keyspace_name), protect_name(self.table_name), index_target, class_name) if options: - ret += " WITH OPTIONS = %s" % Encoder().cql_encode_all_types(options) + # PYTHON-1008: `ret` will always be a unicode + opts_cql_encoded = _encoder.cql_encode_all_types(options, as_text_type=True) + ret += " WITH OPTIONS = %s" % opts_cql_encoded return ret def export_as_string(self): """ Returns a CQL query string that can be used to recreate this index. """ return self.as_cql_query() + ';' class TokenMap(object): """ Information about the layout of the ring. """ token_class = None """ A subclass of :class:`.Token`, depending on what partitioner the cluster uses. """ token_to_host_owner = None """ A map of :class:`.Token` objects to the :class:`.Host` that owns that token. """ tokens_to_hosts_by_ks = None """ A map of keyspace names to a nested map of :class:`.Token` objects to sets of :class:`.Host` objects. """ ring = None """ An ordered list of :class:`.Token` instances in the ring. """ _metadata = None def __init__(self, token_class, token_to_host_owner, all_tokens, metadata): self.token_class = token_class self.ring = all_tokens self.token_to_host_owner = token_to_host_owner self.tokens_to_hosts_by_ks = {} self._metadata = metadata self._rebuild_lock = RLock() def rebuild_keyspace(self, keyspace, build_if_absent=False): with self._rebuild_lock: try: current = self.tokens_to_hosts_by_ks.get(keyspace, None) if (build_if_absent and current is None) or (not build_if_absent and current is not None): ks_meta = self._metadata.keyspaces.get(keyspace) if ks_meta: replica_map = self.replica_map_for_keyspace(self._metadata.keyspaces[keyspace]) self.tokens_to_hosts_by_ks[keyspace] = replica_map except Exception: # should not happen normally, but we don't want to blow up queries because of unexpected meta state # bypass until new map is generated self.tokens_to_hosts_by_ks[keyspace] = {} log.exception("Failed creating a token map for keyspace '%s' with %s. PLEASE REPORT THIS: https://datastax-oss.atlassian.net/projects/PYTHON", keyspace, self.token_to_host_owner) def replica_map_for_keyspace(self, ks_metadata): strategy = ks_metadata.replication_strategy if strategy: return strategy.make_token_replica_map(self.token_to_host_owner, self.ring) else: return None def remove_keyspace(self, keyspace): self.tokens_to_hosts_by_ks.pop(keyspace, None) def get_replicas(self, keyspace, token): """ Get a set of :class:`.Host` instances representing all of the replica nodes for a given :class:`.Token`. """ tokens_to_hosts = self.tokens_to_hosts_by_ks.get(keyspace, None) if tokens_to_hosts is None: self.rebuild_keyspace(keyspace, build_if_absent=True) tokens_to_hosts = self.tokens_to_hosts_by_ks.get(keyspace, None) if tokens_to_hosts: - # token range ownership is exclusive on the LHS (the start token), so - # we use bisect_right, which, in the case of a tie/exact match, - # picks an insertion point to the right of the existing match - point = bisect_right(self.ring, token) + # The values in self.ring correspond to the end of the + # token range up to and including the value listed. 
+ point = bisect_left(self.ring, token) if point == len(self.ring): return tokens_to_hosts[self.ring[0]] else: return tokens_to_hosts[self.ring[point]] return [] @total_ordering class Token(object): """ Abstract class representing a token. """ def __init__(self, token): self.value = token @classmethod def hash_fn(cls, key): return key @classmethod def from_key(cls, key): return cls(cls.hash_fn(key)) @classmethod def from_string(cls, token_string): raise NotImplementedError() def __eq__(self, other): return self.value == other.value def __lt__(self, other): return self.value < other.value def __hash__(self): return hash(self.value) def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self.value) __str__ = __repr__ + MIN_LONG = -(2 ** 63) MAX_LONG = (2 ** 63) - 1 class NoMurmur3(Exception): pass class HashToken(Token): @classmethod def from_string(cls, token_string): """ `token_string` should be the string representation from the server. """ # The hash partitioners just store the deciman value return cls(int(token_string)) class Murmur3Token(HashToken): """ A token for ``Murmur3Partitioner``. """ @classmethod def hash_fn(cls, key): if murmur3 is not None: h = int(murmur3(key)) return h if h != MIN_LONG else MAX_LONG else: raise NoMurmur3() def __init__(self, token): """ `token` is an int or string representing the token. """ self.value = int(token) class MD5Token(HashToken): """ A token for ``RandomPartitioner``. """ @classmethod def hash_fn(cls, key): if isinstance(key, six.text_type): key = key.encode('UTF-8') return abs(varint_unpack(md5(key).digest())) class BytesToken(Token): """ A token for ``ByteOrderedPartitioner``. """ @classmethod def from_string(cls, token_string): """ `token_string` should be the string representation from the server. """ # unhexlify works fine with unicode input in everythin but pypy3, where it Raises "TypeError: 'str' does not support the buffer interface" if isinstance(token_string, six.text_type): token_string = token_string.encode('ascii') # The BOP stores a hex string return cls(unhexlify(token_string)) class TriggerMetadata(object): """ A representation of a trigger for a table. """ table = None """ The :class:`.TableMetadata` this trigger belongs to. """ name = None """ The string name of this trigger. """ options = None """ A dict mapping trigger option names to their specific settings for this table. """ def __init__(self, table_metadata, trigger_name, options=None): self.table = table_metadata self.name = trigger_name self.options = options def as_cql_query(self): ret = "CREATE TRIGGER %s ON %s.%s USING %s" % ( protect_name(self.name), protect_name(self.table.keyspace_name), protect_name(self.table.name), protect_value(self.options['class']) ) return ret def export_as_string(self): return self.as_cql_query() + ';' class _SchemaParser(object): def __init__(self, connection, timeout): self.connection = connection self.timeout = timeout - def _handle_results(self, success, result): - if success: + def _handle_results(self, success, result, expected_failures=tuple()): + """ + Given a bool and a ResultSet (the form returned per result from + Connection.wait_for_responses), return a dictionary containing the + results. Used to process results from asynchronous queries to system + tables. + + ``expected_failures`` will usually be used to allow callers to ignore + ``InvalidRequest`` errors caused by a missing system keyspace. For + example, some DSE versions report a 4.X server version, but do not have + virtual tables. 
Thus, running against 4.X servers, SchemaParserV4 uses + expected_failures to make a best-effort attempt to read those + keyspaces, but treat them as empty if they're not found. + + :param success: A boolean representing whether or not the query + succeeded + :param result: The resultset in question. + :expected_failures: An Exception class or an iterable thereof. If the + query failed, but raised an instance of an expected failure class, this + will ignore the failure and return an empty list. + """ + if not success and isinstance(result, expected_failures): + return [] + elif success: return dict_factory(*result.results) if result else [] else: raise result def _query_build_row(self, query_string, build_func): result = self._query_build_rows(query_string, build_func) return result[0] if result else None def _query_build_rows(self, query_string, build_func): query = QueryMessage(query=query_string, consistency_level=ConsistencyLevel.ONE) responses = self.connection.wait_for_responses((query), timeout=self.timeout, fail_on_error=False) (success, response) = responses[0] if success: result = dict_factory(*response.results) return [build_func(row) for row in result] elif isinstance(response, InvalidRequest): log.debug("user types table not found") return [] else: raise response class SchemaParserV22(_SchemaParser): _SELECT_KEYSPACES = "SELECT * FROM system.schema_keyspaces" _SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies" _SELECT_COLUMNS = "SELECT * FROM system.schema_columns" _SELECT_TRIGGERS = "SELECT * FROM system.schema_triggers" _SELECT_TYPES = "SELECT * FROM system.schema_usertypes" _SELECT_FUNCTIONS = "SELECT * FROM system.schema_functions" _SELECT_AGGREGATES = "SELECT * FROM system.schema_aggregates" _table_name_col = 'columnfamily_name' _function_agg_arument_type_col = 'signature' recognized_table_options = ( "comment", "read_repair_chance", "dclocal_read_repair_chance", # kept to be safe, but see _build_table_options() "local_read_repair_chance", "replicate_on_write", "gc_grace_seconds", "bloom_filter_fp_chance", "caching", "compaction_strategy_class", "compaction_strategy_options", "min_compaction_threshold", "max_compaction_threshold", "compression_parameters", "min_index_interval", "max_index_interval", "index_interval", "speculative_retry", "rows_per_partition_to_cache", "memtable_flush_period_in_ms", "populate_io_cache_on_flush", "compression", "default_time_to_live") def __init__(self, connection, timeout): super(SchemaParserV22, self).__init__(connection, timeout) self.keyspaces_result = [] self.tables_result = [] self.columns_result = [] self.triggers_result = [] self.types_result = [] self.functions_result = [] self.aggregates_result = [] self.keyspace_table_rows = defaultdict(list) self.keyspace_table_col_rows = defaultdict(lambda: defaultdict(list)) self.keyspace_type_rows = defaultdict(list) self.keyspace_func_rows = defaultdict(list) self.keyspace_agg_rows = defaultdict(list) self.keyspace_table_trigger_rows = defaultdict(lambda: defaultdict(list)) def get_all_keyspaces(self): self._query_all() for row in self.keyspaces_result: keyspace_meta = self._build_keyspace_metadata(row) try: for table_row in self.keyspace_table_rows.get(keyspace_meta.name, []): table_meta = self._build_table_metadata(table_row) keyspace_meta._add_table_metadata(table_meta) for usertype_row in self.keyspace_type_rows.get(keyspace_meta.name, []): usertype = self._build_user_type(usertype_row) keyspace_meta.user_types[usertype.name] = usertype for fn_row in 
self.keyspace_func_rows.get(keyspace_meta.name, []): fn = self._build_function(fn_row) keyspace_meta.functions[fn.signature] = fn for agg_row in self.keyspace_agg_rows.get(keyspace_meta.name, []): agg = self._build_aggregate(agg_row) keyspace_meta.aggregates[agg.signature] = agg except Exception: log.exception("Error while parsing metadata for keyspace %s. Metadata model will be incomplete.", keyspace_meta.name) keyspace_meta._exc_info = sys.exc_info() yield keyspace_meta def get_table(self, keyspaces, keyspace, table): cl = ConsistencyLevel.ONE where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col,), (keyspace, table), _encoder) cf_query = QueryMessage(query=self._SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl) col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl) triggers_query = QueryMessage(query=self._SELECT_TRIGGERS + where_clause, consistency_level=cl) (cf_success, cf_result), (col_success, col_result), (triggers_success, triggers_result) \ = self.connection.wait_for_responses(cf_query, col_query, triggers_query, timeout=self.timeout, fail_on_error=False) table_result = self._handle_results(cf_success, cf_result) col_result = self._handle_results(col_success, col_result) - # handle the triggers table not existing in Cassandra 1.2 - if not triggers_success and isinstance(triggers_result, InvalidRequest): - triggers_result = [] - else: - triggers_result = self._handle_results(triggers_success, triggers_result) + # the triggers table doesn't exist in C* 1.2 + triggers_result = self._handle_results(triggers_success, triggers_result, + expected_failures=InvalidRequest) if table_result: return self._build_table_metadata(table_result[0], col_result, triggers_result) def get_type(self, keyspaces, keyspace, type): where_clause = bind_params(" WHERE keyspace_name = %s AND type_name = %s", (keyspace, type), _encoder) return self._query_build_row(self._SELECT_TYPES + where_clause, self._build_user_type) def get_types_map(self, keyspaces, keyspace): where_clause = bind_params(" WHERE keyspace_name = %s", (keyspace,), _encoder) types = self._query_build_rows(self._SELECT_TYPES + where_clause, self._build_user_type) return dict((t.name, t) for t in types) def get_function(self, keyspaces, keyspace, function): where_clause = bind_params(" WHERE keyspace_name = %%s AND function_name = %%s AND %s = %%s" % (self._function_agg_arument_type_col,), (keyspace, function.name, function.argument_types), _encoder) return self._query_build_row(self._SELECT_FUNCTIONS + where_clause, self._build_function) def get_aggregate(self, keyspaces, keyspace, aggregate): where_clause = bind_params(" WHERE keyspace_name = %%s AND aggregate_name = %%s AND %s = %%s" % (self._function_agg_arument_type_col,), (keyspace, aggregate.name, aggregate.argument_types), _encoder) return self._query_build_row(self._SELECT_AGGREGATES + where_clause, self._build_aggregate) def get_keyspace(self, keyspaces, keyspace): where_clause = bind_params(" WHERE keyspace_name = %s", (keyspace,), _encoder) return self._query_build_row(self._SELECT_KEYSPACES + where_clause, self._build_keyspace_metadata) @classmethod def _build_keyspace_metadata(cls, row): try: ksm = cls._build_keyspace_metadata_internal(row) except Exception: name = row["keyspace_name"] ksm = KeyspaceMetadata(name, False, 'UNKNOWN', {}) ksm._exc_info = sys.exc_info() # capture exc_info before log because nose (test) logging clears it in certain circumstances log.exception("Error while parsing 
metadata for keyspace %s row(%s)", name, row) return ksm @staticmethod def _build_keyspace_metadata_internal(row): name = row["keyspace_name"] durable_writes = row["durable_writes"] strategy_class = row["strategy_class"] strategy_options = json.loads(row["strategy_options"]) return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) @classmethod def _build_user_type(cls, usertype_row): field_types = list(map(cls._schema_type_to_cql, usertype_row['field_types'])) return UserType(usertype_row['keyspace_name'], usertype_row['type_name'], usertype_row['field_names'], field_types) @classmethod def _build_function(cls, function_row): return_type = cls._schema_type_to_cql(function_row['return_type']) return Function(function_row['keyspace_name'], function_row['function_name'], function_row[cls._function_agg_arument_type_col], function_row['argument_names'], return_type, function_row['language'], function_row['body'], function_row['called_on_null_input']) @classmethod def _build_aggregate(cls, aggregate_row): cass_state_type = types.lookup_casstype(aggregate_row['state_type']) initial_condition = aggregate_row['initcond'] if initial_condition is not None: initial_condition = _encoder.cql_encode_all_types(cass_state_type.deserialize(initial_condition, 3)) state_type = _cql_from_cass_type(cass_state_type) return_type = cls._schema_type_to_cql(aggregate_row['return_type']) return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'], aggregate_row['signature'], aggregate_row['state_func'], state_type, aggregate_row['final_func'], initial_condition, return_type) def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): keyspace_name = row["keyspace_name"] cfname = row[self._table_name_col] col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][cfname] trigger_rows = trigger_rows or self.keyspace_table_trigger_rows[keyspace_name][cfname] if not col_rows: # CASSANDRA-8487 log.warning("Building table metadata with no column meta for %s.%s", keyspace_name, cfname) table_meta = TableMetadata(keyspace_name, cfname) try: comparator = types.lookup_casstype(row["comparator"]) table_meta.comparator = comparator is_dct_comparator = issubclass(comparator, types.DynamicCompositeType) is_composite_comparator = issubclass(comparator, types.CompositeType) column_name_types = comparator.subtypes if is_composite_comparator else (comparator,) num_column_name_components = len(column_name_types) last_col = column_name_types[-1] column_aliases = row.get("column_aliases", None) clustering_rows = [r for r in col_rows if r.get('type', None) == "clustering_key"] if len(clustering_rows) > 1: clustering_rows = sorted(clustering_rows, key=lambda row: row.get('component_index')) if column_aliases is not None: column_aliases = json.loads(column_aliases) if not column_aliases: # json load failed or column_aliases empty PYTHON-562 column_aliases = [r.get('column_name') for r in clustering_rows] if is_composite_comparator: if issubclass(last_col, types.ColumnToCollectionType): # collections is_compact = False has_value = False clustering_size = num_column_name_components - 2 - elif (len(column_aliases) == num_column_name_components - 1 - and issubclass(last_col, types.UTF8Type)): + elif (len(column_aliases) == num_column_name_components - 1 and + issubclass(last_col, types.UTF8Type)): # aliases? 
is_compact = False has_value = False clustering_size = num_column_name_components - 1 else: # compact table is_compact = True has_value = column_aliases or not col_rows clustering_size = num_column_name_components # Some thrift tables define names in composite types (see PYTHON-192) if not column_aliases and hasattr(comparator, 'fieldnames'): column_aliases = filter(None, comparator.fieldnames) else: is_compact = True if column_aliases or not col_rows or is_dct_comparator: has_value = True clustering_size = num_column_name_components else: has_value = False clustering_size = 0 # partition key partition_rows = [r for r in col_rows if r.get('type', None) == "partition_key"] if len(partition_rows) > 1: partition_rows = sorted(partition_rows, key=lambda row: row.get('component_index')) key_aliases = row.get("key_aliases") if key_aliases is not None: key_aliases = json.loads(key_aliases) if key_aliases else [] else: # In 2.0+, we can use the 'type' column. In 3.0+, we have to use it. key_aliases = [r.get('column_name') for r in partition_rows] key_validator = row.get("key_validator") if key_validator is not None: key_type = types.lookup_casstype(key_validator) key_types = key_type.subtypes if issubclass(key_type, types.CompositeType) else [key_type] else: key_types = [types.lookup_casstype(r.get('validator')) for r in partition_rows] for i, col_type in enumerate(key_types): if len(key_aliases) > i: column_name = key_aliases[i] elif i == 0: column_name = "key" else: column_name = "key%d" % i col = ColumnMetadata(table_meta, column_name, col_type.cql_parameterized_type()) table_meta.columns[column_name] = col table_meta.partition_key.append(col) # clustering key for i in range(clustering_size): if len(column_aliases) > i: column_name = column_aliases[i] else: column_name = "column%d" % (i + 1) data_type = column_name_types[i] cql_type = _cql_from_cass_type(data_type) is_reversed = types.is_reversed_casstype(data_type) col = ColumnMetadata(table_meta, column_name, cql_type, is_reversed=is_reversed) table_meta.columns[column_name] = col table_meta.clustering_key.append(col) # value alias (if present) if has_value: value_alias_rows = [r for r in col_rows if r.get('type', None) == "compact_value"] if not key_aliases: # TODO are we checking the right thing here? value_alias = "value" else: value_alias = row.get("value_alias", None) if value_alias is None and value_alias_rows: # CASSANDRA-8487 # In 2.0+, we can use the 'type' column. In 3.0+, we have to use it. 
value_alias = value_alias_rows[0].get('column_name') default_validator = row.get("default_validator") if default_validator: validator = types.lookup_casstype(default_validator) else: if value_alias_rows: # CASSANDRA-8487 validator = types.lookup_casstype(value_alias_rows[0].get('validator')) cql_type = _cql_from_cass_type(validator) col = ColumnMetadata(table_meta, value_alias, cql_type) if value_alias: # CASSANDRA-8487 table_meta.columns[value_alias] = col # other normal columns for col_row in col_rows: column_meta = self._build_column_metadata(table_meta, col_row) if column_meta.name: table_meta.columns[column_meta.name] = column_meta index_meta = self._build_index_metadata(column_meta, col_row) if index_meta: table_meta.indexes[index_meta.name] = index_meta for trigger_row in trigger_rows: trigger_meta = self._build_trigger_metadata(table_meta, trigger_row) table_meta.triggers[trigger_meta.name] = trigger_meta table_meta.options = self._build_table_options(row) table_meta.is_compact_storage = is_compact except Exception: table_meta._exc_info = sys.exc_info() log.exception("Error while parsing metadata for table %s.%s row(%s) columns(%s)", keyspace_name, cfname, row, col_rows) return table_meta def _build_table_options(self, row): """ Setup the mostly-non-schema table options, like caching settings """ options = dict((o, row.get(o)) for o in self.recognized_table_options if o in row) # the option name when creating tables is "dclocal_read_repair_chance", # but the column name in system.schema_columnfamilies is # "local_read_repair_chance". We'll store this as dclocal_read_repair_chance, # since that's probably what users are expecting (and we need it for the # CREATE TABLE statement anyway). if "local_read_repair_chance" in options: val = options.pop("local_read_repair_chance") options["dclocal_read_repair_chance"] = val return options @classmethod def _build_column_metadata(cls, table_metadata, row): name = row["column_name"] type_string = row["validator"] data_type = types.lookup_casstype(type_string) cql_type = _cql_from_cass_type(data_type) is_static = row.get("type", None) == "static" is_reversed = types.is_reversed_casstype(data_type) column_meta = ColumnMetadata(table_metadata, name, cql_type, is_static, is_reversed) column_meta._cass_type = data_type return column_meta @staticmethod def _build_index_metadata(column_metadata, row): index_name = row.get("index_name") kind = row.get("index_type") if index_name or kind: options = row.get("index_options") options = json.loads(options) if options else {} options = options or {} # if the json parsed to None, init empty dict # generate a CQL index identity string target = protect_name(column_metadata.name) if kind != "CUSTOM": if "index_keys" in options: target = 'keys(%s)' % (target,) elif "index_values" in options: # don't use any "function" for collection values pass else: # it might be a "full" index on a frozen collection, but # we need to check the data type to verify that, because # there is no special index option for full-collection # indexes. 
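                # Editor's illustration (hypothetical column names): a map
                # column m indexed on its keys renders the target keys(m); an
                # index on its values keeps the bare name m; an index on a
                # frozen<list<int>> column l carries neither option, so it
                # falls through to the check below and renders full(l).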
data_type = column_metadata._cass_type collection_types = ('map', 'set', 'list') if data_type.typename == "frozen" and data_type.subtypes[0].typename in collection_types: # no index option for full-collection index target = 'full(%s)' % (target,) options['target'] = target return IndexMetadata(column_metadata.table.keyspace_name, column_metadata.table.name, index_name, kind, options) @staticmethod def _build_trigger_metadata(table_metadata, row): name = row["trigger_name"] options = row["trigger_options"] trigger_meta = TriggerMetadata(table_metadata, name, options) return trigger_meta def _query_all(self): cl = ConsistencyLevel.ONE queries = [ QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl), QueryMessage(query=self._SELECT_COLUMN_FAMILIES, consistency_level=cl), QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl), QueryMessage(query=self._SELECT_TYPES, consistency_level=cl), QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl), QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl), QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl) ] - responses = self.connection.wait_for_responses(*queries, timeout=self.timeout, fail_on_error=False) - (ks_success, ks_result), (table_success, table_result), \ - (col_success, col_result), (types_success, types_result), \ - (functions_success, functions_result), \ - (aggregates_success, aggregates_result), \ - (triggers_success, triggers_result) = responses + ((ks_success, ks_result), + (table_success, table_result), + (col_success, col_result), + (types_success, types_result), + (functions_success, functions_result), + (aggregates_success, aggregates_result), + (triggers_success, triggers_result)) = ( + self.connection.wait_for_responses(*queries, timeout=self.timeout, + fail_on_error=False) + ) self.keyspaces_result = self._handle_results(ks_success, ks_result) self.tables_result = self._handle_results(table_success, table_result) self.columns_result = self._handle_results(col_success, col_result) # if we're connected to Cassandra < 2.0, the triggers table will not exist if triggers_success: self.triggers_result = dict_factory(*triggers_result.results) else: if isinstance(triggers_result, InvalidRequest): log.debug("triggers table not found") elif isinstance(triggers_result, Unauthorized): log.warning("this version of Cassandra does not allow access to schema_triggers metadata with authorization enabled (CASSANDRA-7967); " "The driver will operate normally, but will not reflect triggers in the local metadata model, or schema strings.") else: raise triggers_result # if we're connected to Cassandra < 2.1, the usertypes table will not exist if types_success: self.types_result = dict_factory(*types_result.results) else: if isinstance(types_result, InvalidRequest): log.debug("user types table not found") self.types_result = {} else: raise types_result # functions were introduced in Cassandra 2.2 if functions_success: self.functions_result = dict_factory(*functions_result.results) else: if isinstance(functions_result, InvalidRequest): log.debug("user functions table not found") else: raise functions_result # aggregates were introduced in Cassandra 2.2 if aggregates_success: self.aggregates_result = dict_factory(*aggregates_result.results) else: if isinstance(aggregates_result, InvalidRequest): log.debug("user aggregates table not found") else: raise aggregates_result self._aggregate_results() def _aggregate_results(self): m = self.keyspace_table_rows for row in self.tables_result: 
m[row["keyspace_name"]].append(row) m = self.keyspace_table_col_rows for row in self.columns_result: ksname = row["keyspace_name"] cfname = row[self._table_name_col] m[ksname][cfname].append(row) m = self.keyspace_type_rows for row in self.types_result: m[row["keyspace_name"]].append(row) m = self.keyspace_func_rows for row in self.functions_result: m[row["keyspace_name"]].append(row) m = self.keyspace_agg_rows for row in self.aggregates_result: m[row["keyspace_name"]].append(row) m = self.keyspace_table_trigger_rows for row in self.triggers_result: ksname = row["keyspace_name"] cfname = row[self._table_name_col] m[ksname][cfname].append(row) @staticmethod def _schema_type_to_cql(type_string): cass_type = types.lookup_casstype(type_string) return _cql_from_cass_type(cass_type) class SchemaParserV3(SchemaParserV22): _SELECT_KEYSPACES = "SELECT * FROM system_schema.keyspaces" _SELECT_TABLES = "SELECT * FROM system_schema.tables" _SELECT_COLUMNS = "SELECT * FROM system_schema.columns" _SELECT_INDEXES = "SELECT * FROM system_schema.indexes" _SELECT_TRIGGERS = "SELECT * FROM system_schema.triggers" _SELECT_TYPES = "SELECT * FROM system_schema.types" _SELECT_FUNCTIONS = "SELECT * FROM system_schema.functions" _SELECT_AGGREGATES = "SELECT * FROM system_schema.aggregates" _SELECT_VIEWS = "SELECT * FROM system_schema.views" _table_name_col = 'table_name' _function_agg_arument_type_col = 'argument_types' recognized_table_options = ( 'bloom_filter_fp_chance', 'caching', 'cdc', 'comment', 'compaction', 'compression', 'crc_check_chance', 'dclocal_read_repair_chance', 'default_time_to_live', 'gc_grace_seconds', 'max_index_interval', 'memtable_flush_period_in_ms', 'min_index_interval', 'read_repair_chance', 'speculative_retry') def __init__(self, connection, timeout): super(SchemaParserV3, self).__init__(connection, timeout) self.indexes_result = [] self.keyspace_table_index_rows = defaultdict(lambda: defaultdict(list)) self.keyspace_view_rows = defaultdict(list) def get_all_keyspaces(self): for keyspace_meta in super(SchemaParserV3, self).get_all_keyspaces(): for row in self.keyspace_view_rows[keyspace_meta.name]: view_meta = self._build_view_metadata(row) keyspace_meta._add_view_metadata(view_meta) yield keyspace_meta def get_table(self, keyspaces, keyspace, table): cl = ConsistencyLevel.ONE where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col), (keyspace, table), _encoder) cf_query = QueryMessage(query=self._SELECT_TABLES + where_clause, consistency_level=cl) col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl) indexes_query = QueryMessage(query=self._SELECT_INDEXES + where_clause, consistency_level=cl) triggers_query = QueryMessage(query=self._SELECT_TRIGGERS + where_clause, consistency_level=cl) # in protocol v4 we don't know if this event is a view or a table, so we look for both where_clause = bind_params(" WHERE keyspace_name = %s AND view_name = %s", (keyspace, table), _encoder) view_query = QueryMessage(query=self._SELECT_VIEWS + where_clause, consistency_level=cl) - (cf_success, cf_result), (col_success, col_result), (indexes_sucess, indexes_result), \ - (triggers_success, triggers_result), (view_success, view_result) \ - = self.connection.wait_for_responses(cf_query, col_query, indexes_query, triggers_query, view_query, - timeout=self.timeout, fail_on_error=False) + ((cf_success, cf_result), (col_success, col_result), + (indexes_sucess, indexes_result), (triggers_success, triggers_result), + (view_success, 
view_result)) = ( + self.connection.wait_for_responses( + cf_query, col_query, indexes_query, triggers_query, + view_query, timeout=self.timeout, fail_on_error=False) + ) table_result = self._handle_results(cf_success, cf_result) col_result = self._handle_results(col_success, col_result) if table_result: indexes_result = self._handle_results(indexes_sucess, indexes_result) triggers_result = self._handle_results(triggers_success, triggers_result) return self._build_table_metadata(table_result[0], col_result, triggers_result, indexes_result) view_result = self._handle_results(view_success, view_result) if view_result: return self._build_view_metadata(view_result[0], col_result) @staticmethod def _build_keyspace_metadata_internal(row): name = row["keyspace_name"] durable_writes = row["durable_writes"] strategy_options = dict(row["replication"]) strategy_class = strategy_options.pop("class") return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) @staticmethod def _build_aggregate(aggregate_row): return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'], aggregate_row['argument_types'], aggregate_row['state_func'], aggregate_row['state_type'], aggregate_row['final_func'], aggregate_row['initcond'], aggregate_row['return_type']) - def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_rows=None): + def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_rows=None, virtual=False): keyspace_name = row["keyspace_name"] table_name = row[self._table_name_col] col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][table_name] trigger_rows = trigger_rows or self.keyspace_table_trigger_rows[keyspace_name][table_name] index_rows = index_rows or self.keyspace_table_index_rows[keyspace_name][table_name] - table_meta = TableMetadataV3(keyspace_name, table_name) + table_meta = TableMetadataV3(keyspace_name, table_name, virtual=virtual) try: table_meta.options = self._build_table_options(row) flags = row.get('flags', set()) if flags: compact_static = False table_meta.is_compact_storage = 'dense' in flags or 'super' in flags or 'compound' not in flags is_dense = 'dense' in flags + elif virtual: + compact_static = False + table_meta.is_compact_storage = False + is_dense = False else: compact_static = True table_meta.is_compact_storage = True is_dense = False - self._build_table_columns(table_meta, col_rows, compact_static, is_dense) + self._build_table_columns(table_meta, col_rows, compact_static, is_dense, virtual) for trigger_row in trigger_rows: trigger_meta = self._build_trigger_metadata(table_meta, trigger_row) table_meta.triggers[trigger_meta.name] = trigger_meta for index_row in index_rows: index_meta = self._build_index_metadata(table_meta, index_row) if index_meta: table_meta.indexes[index_meta.name] = index_meta table_meta.extensions = row.get('extensions', {}) except Exception: table_meta._exc_info = sys.exc_info() log.exception("Error while parsing metadata for table %s.%s row(%s) columns(%s)", keyspace_name, table_name, row, col_rows) return table_meta def _build_table_options(self, row): """ Setup the mostly-non-schema table options, like caching settings """ return dict((o, row.get(o)) for o in self.recognized_table_options if o in row) - def _build_table_columns(self, meta, col_rows, compact_static=False, is_dense=False): + def _build_table_columns(self, meta, col_rows, compact_static=False, is_dense=False, virtual=False): # partition key partition_rows = [r for r in col_rows if r.get('kind', 
None) == "partition_key"] if len(partition_rows) > 1: partition_rows = sorted(partition_rows, key=lambda row: row.get('position')) for r in partition_rows: # we have to add meta here (and not in the later loop) because TableMetadata.columns is an # OrderedDict, and it assumes keys are inserted first, in order, when exporting CQL column_meta = self._build_column_metadata(meta, r) meta.columns[column_meta.name] = column_meta meta.partition_key.append(meta.columns[r.get('column_name')]) # clustering key if not compact_static: clustering_rows = [r for r in col_rows if r.get('kind', None) == "clustering"] if len(clustering_rows) > 1: clustering_rows = sorted(clustering_rows, key=lambda row: row.get('position')) for r in clustering_rows: column_meta = self._build_column_metadata(meta, r) meta.columns[column_meta.name] = column_meta meta.clustering_key.append(meta.columns[r.get('column_name')]) for col_row in (r for r in col_rows if r.get('kind', None) not in ('partition_key', 'clustering_key')): column_meta = self._build_column_metadata(meta, col_row) if is_dense and column_meta.cql_type == types.cql_empty_type: continue if compact_static and not column_meta.is_static: # for compact static tables, we omit the clustering key and value, and only add the logical columns. # They are marked not static so that it generates appropriate CQL continue if compact_static: column_meta.is_static = False meta.columns[column_meta.name] = column_meta def _build_view_metadata(self, row, col_rows=None): keyspace_name = row["keyspace_name"] view_name = row["view_name"] base_table_name = row["base_table_name"] include_all_columns = row["include_all_columns"] where_clause = row["where_clause"] col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][view_name] view_meta = MaterializedViewMetadata(keyspace_name, view_name, base_table_name, include_all_columns, where_clause, self._build_table_options(row)) self._build_table_columns(view_meta, col_rows) view_meta.extensions = row.get('extensions', {}) return view_meta @staticmethod def _build_column_metadata(table_metadata, row): name = row["column_name"] cql_type = row["type"] is_static = row.get("kind", None) == "static" is_reversed = row["clustering_order"].upper() == "DESC" column_meta = ColumnMetadata(table_metadata, name, cql_type, is_static, is_reversed) return column_meta @staticmethod def _build_index_metadata(table_metadata, row): index_name = row.get("index_name") kind = row.get("kind") if index_name or kind: index_options = row.get("options") return IndexMetadata(table_metadata.keyspace_name, table_metadata.name, index_name, kind, index_options) else: return None @staticmethod def _build_trigger_metadata(table_metadata, row): name = row["trigger_name"] options = row["options"] trigger_meta = TriggerMetadata(table_metadata, name, options) return trigger_meta def _query_all(self): cl = ConsistencyLevel.ONE queries = [ QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl), QueryMessage(query=self._SELECT_TABLES, consistency_level=cl), QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl), QueryMessage(query=self._SELECT_TYPES, consistency_level=cl), QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl), QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl), QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl), QueryMessage(query=self._SELECT_INDEXES, consistency_level=cl), QueryMessage(query=self._SELECT_VIEWS, consistency_level=cl) ] - responses = self.connection.wait_for_responses(*queries, 
timeout=self.timeout, fail_on_error=False) - (ks_success, ks_result), (table_success, table_result), \ - (col_success, col_result), (types_success, types_result), \ - (functions_success, functions_result), \ - (aggregates_success, aggregates_result), \ - (triggers_success, triggers_result), \ - (indexes_success, indexes_result), \ - (views_success, views_result) = responses + ((ks_success, ks_result), + (table_success, table_result), + (col_success, col_result), + (types_success, types_result), + (functions_success, functions_result), + (aggregates_success, aggregates_result), + (triggers_success, triggers_result), + (indexes_success, indexes_result), + (views_success, views_result)) = self.connection.wait_for_responses( + *queries, timeout=self.timeout, fail_on_error=False + ) self.keyspaces_result = self._handle_results(ks_success, ks_result) self.tables_result = self._handle_results(table_success, table_result) self.columns_result = self._handle_results(col_success, col_result) self.triggers_result = self._handle_results(triggers_success, triggers_result) self.types_result = self._handle_results(types_success, types_result) self.functions_result = self._handle_results(functions_success, functions_result) self.aggregates_result = self._handle_results(aggregates_success, aggregates_result) self.indexes_result = self._handle_results(indexes_success, indexes_result) self.views_result = self._handle_results(views_success, views_result) self._aggregate_results() def _aggregate_results(self): super(SchemaParserV3, self)._aggregate_results() m = self.keyspace_table_index_rows for row in self.indexes_result: ksname = row["keyspace_name"] cfname = row[self._table_name_col] m[ksname][cfname].append(row) m = self.keyspace_view_rows for row in self.views_result: m[row["keyspace_name"]].append(row) @staticmethod def _schema_type_to_cql(type_string): return type_string +class SchemaParserV4(SchemaParserV3): + + recognized_table_options = tuple( + opt for opt in + SchemaParserV3.recognized_table_options + if opt not in ( + # removed in V4: CASSANDRA-13910 + 'dclocal_read_repair_chance', 'read_repair_chance' + ) + ) + + _SELECT_VIRTUAL_KEYSPACES = 'SELECT * from system_virtual_schema.keyspaces' + _SELECT_VIRTUAL_TABLES = 'SELECT * from system_virtual_schema.tables' + _SELECT_VIRTUAL_COLUMNS = 'SELECT * from system_virtual_schema.columns' + + def __init__(self, connection, timeout): + super(SchemaParserV4, self).__init__(connection, timeout) + self.virtual_keyspaces_rows = defaultdict(list) + self.virtual_tables_rows = defaultdict(list) + self.virtual_columns_rows = defaultdict(lambda: defaultdict(list)) + + def _query_all(self): + cl = ConsistencyLevel.ONE + # todo: this duplicates V3; we should find a way for _query_all methods + # to extend each other. 
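        # Editor's sketch of one way to do that (not implemented here): have
        # SchemaParserV3 build its QueryMessage list and unpack its responses
        # in overridable helpers, so this subclass would only append the three
        # system_virtual_schema queries and consume their extra results.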
+ queries = [ + # copied from V3 + QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl), + QueryMessage(query=self._SELECT_TABLES, consistency_level=cl), + QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl), + QueryMessage(query=self._SELECT_TYPES, consistency_level=cl), + QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl), + QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl), + QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl), + QueryMessage(query=self._SELECT_INDEXES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIEWS, consistency_level=cl), + # V4-only queries + QueryMessage(query=self._SELECT_VIRTUAL_KEYSPACES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIRTUAL_TABLES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIRTUAL_COLUMNS, consistency_level=cl) + ] + + responses = self.connection.wait_for_responses( + *queries, timeout=self.timeout, fail_on_error=False) + ( + # copied from V3 + (ks_success, ks_result), + (table_success, table_result), + (col_success, col_result), + (types_success, types_result), + (functions_success, functions_result), + (aggregates_success, aggregates_result), + (triggers_success, triggers_result), + (indexes_success, indexes_result), + (views_success, views_result), + # V4-only responses + (virtual_ks_success, virtual_ks_result), + (virtual_table_success, virtual_table_result), + (virtual_column_success, virtual_column_result) + ) = responses + + # copied from V3 + self.keyspaces_result = self._handle_results(ks_success, ks_result) + self.tables_result = self._handle_results(table_success, table_result) + self.columns_result = self._handle_results(col_success, col_result) + self.triggers_result = self._handle_results(triggers_success, triggers_result) + self.types_result = self._handle_results(types_success, types_result) + self.functions_result = self._handle_results(functions_success, functions_result) + self.aggregates_result = self._handle_results(aggregates_success, aggregates_result) + self.indexes_result = self._handle_results(indexes_success, indexes_result) + self.views_result = self._handle_results(views_success, views_result) + # V4-only results + # These tables don't exist in some DSE versions reporting 4.X so we can + # ignore them if we got an error + self.virtual_keyspaces_result = self._handle_results( + virtual_ks_success, virtual_ks_result, + expected_failures=InvalidRequest + ) + self.virtual_tables_result = self._handle_results( + virtual_table_success, virtual_table_result, + expected_failures=InvalidRequest + ) + self.virtual_columns_result = self._handle_results( + virtual_column_success, virtual_column_result, + expected_failures=InvalidRequest + ) + + self._aggregate_results() + + def _aggregate_results(self): + super(SchemaParserV4, self)._aggregate_results() + + m = self.virtual_tables_rows + for row in self.virtual_tables_result: + m[row["keyspace_name"]].append(row) + + m = self.virtual_columns_rows + for row in self.virtual_columns_result: + ks_name = row['keyspace_name'] + tab_name = row[self._table_name_col] + m[ks_name][tab_name].append(row) + + def get_all_keyspaces(self): + for x in super(SchemaParserV4, self).get_all_keyspaces(): + yield x + + for row in self.virtual_keyspaces_result: + ks_name = row['keyspace_name'] + keyspace_meta = self._build_keyspace_metadata(row) + keyspace_meta.virtual = True + + for table_row in self.virtual_tables_rows.get(ks_name, []): + table_name = table_row[self._table_name_col] + + col_rows 
= self.virtual_columns_rows[ks_name][table_name] + keyspace_meta._add_table_metadata( + self._build_table_metadata(table_row, + col_rows=col_rows, + virtual=True) + ) + yield keyspace_meta + + @staticmethod + def _build_keyspace_metadata_internal(row): + # necessary fields that aren't int virtual ks + row["durable_writes"] = row.get("durable_writes", None) + row["replication"] = row.get("replication", {}) + row["replication"]["class"] = row["replication"].get("class", None) + return super(SchemaParserV4, SchemaParserV4)._build_keyspace_metadata_internal(row) + + class TableMetadataV3(TableMetadata): compaction_options = {} option_maps = ['compaction', 'compression', 'caching'] @property def is_cql_compatible(self): return True @classmethod def _make_option_strings(cls, options_map): ret = [] options_copy = dict(options_map.items()) for option in cls.option_maps: value = options_copy.get(option) if isinstance(value, Mapping): del options_copy[option] params = ("'%s': '%s'" % (k, v) for k, v in value.items()) ret.append("%s = {%s}" % (option, ', '.join(params))) for name, value in options_copy.items(): if value is not None: if name == "comment": value = value or "" ret.append("%s = %s" % (name, protect_value(value))) return list(sorted(ret)) class MaterializedViewMetadata(object): """ A representation of a materialized view on a table """ keyspace_name = None """ A string name of the view.""" name = None """ A string name of the view.""" base_table_name = None """ A string name of the base table for this view.""" partition_key = None """ A list of :class:`.ColumnMetadata` instances representing the columns in the partition key for this view. This will always hold at least one column. """ clustering_key = None """ A list of :class:`.ColumnMetadata` instances representing the columns in the clustering key for this view. Note that a table may have no clustering keys, in which case this will be an empty list. """ columns = None """ A dict mapping column names to :class:`.ColumnMetadata` instances. """ include_all_columns = None """ A flag indicating whether the view was created AS SELECT * """ where_clause = None """ String WHERE clause for the view select statement. From server metadata """ options = None """ A dict mapping table option names to their specific settings for this view. """ extensions = None """ Metadata describing configuration for table extensions """ def __init__(self, keyspace_name, view_name, base_table_name, include_all_columns, where_clause, options): self.keyspace_name = keyspace_name self.name = view_name self.base_table_name = base_table_name self.partition_key = [] self.clustering_key = [] self.columns = OrderedDict() self.include_all_columns = include_all_columns self.where_clause = where_clause self.options = options or {} def as_cql_query(self, formatted=False): """ Returns a CQL query that can be used to recreate this function. If `formatted` is set to :const:`True`, extra whitespace will be added to make the query more readable. 
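        For illustration (editor's sketch, hypothetical names), the formatted
        output has the shape::

            CREATE MATERIALIZED VIEW ks.users_by_name AS
                SELECT name, user_id
                FROM ks.users
                WHERE name IS NOT NULL AND user_id IS NOT NULL
                PRIMARY KEY (name, user_id)
                WITH ...options...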
""" sep = '\n ' if formatted else ' ' keyspace = protect_name(self.keyspace_name) name = protect_name(self.name) selected_cols = '*' if self.include_all_columns else ', '.join(protect_name(col.name) for col in self.columns.values()) base_table = protect_name(self.base_table_name) where_clause = self.where_clause part_key = ', '.join(protect_name(col.name) for col in self.partition_key) if len(self.partition_key) > 1: pk = "((%s)" % part_key else: pk = "(%s" % part_key if self.clustering_key: pk += ", %s" % ', '.join(protect_name(col.name) for col in self.clustering_key) pk += ")" properties = TableMetadataV3._property_string(formatted, self.clustering_key, self.options) - ret = "CREATE MATERIALIZED VIEW %(keyspace)s.%(name)s AS%(sep)s" \ - "SELECT %(selected_cols)s%(sep)s" \ - "FROM %(keyspace)s.%(base_table)s%(sep)s" \ - "WHERE %(where_clause)s%(sep)s" \ - "PRIMARY KEY %(pk)s%(sep)s" \ - "WITH %(properties)s" % locals() + ret = ("CREATE MATERIALIZED VIEW %(keyspace)s.%(name)s AS%(sep)s" + "SELECT %(selected_cols)s%(sep)s" + "FROM %(keyspace)s.%(base_table)s%(sep)s" + "WHERE %(where_clause)s%(sep)s" + "PRIMARY KEY %(pk)s%(sep)s" + "WITH %(properties)s") % locals() if self.extensions: registry = _RegisteredExtensionType._extension_registry for k in six.viewkeys(registry) & self.extensions: # no viewkeys on OrderedMapSerializeKey ext = registry[k] cql = ext.after_table_cql(self, k, self.extensions[k]) if cql: ret += "\n\n%s" % (cql,) return ret def export_as_string(self): return self.as_cql_query(formatted=True) + ";" def get_schema_parser(connection, server_version, timeout): server_major_version = int(server_version.split('.')[0]) + if server_major_version >= 4: + return SchemaParserV4(connection, timeout) if server_major_version >= 3: return SchemaParserV3(connection, timeout) else: # we could further specialize by version. Right now just refactoring the # multi-version parser we have as of C* 2.2.0rc1. return SchemaParserV22(connection, timeout) def _cql_from_cass_type(cass_type): """ A string representation of the type for this column, such as "varchar" or "map". """ if issubclass(cass_type, types.ReversedType): return cass_type.subtypes[0].cql_parameterized_type() else: return cass_type.cql_parameterized_type() NO_VALID_REPLICA = object() def group_keys_by_replica(session, keyspace, table, keys): """ Returns a :class:`dict` with the keys grouped per host. This can be used to more accurately group by IN clause or to batch the keys per host. 
If a valid replica is not found for a particular key it will be grouped under :class:`~.NO_VALID_REPLICA` Example usage:: result = group_keys_by_replica( session, "system", "peers", (("127.0.0.1", ), ("127.0.0.2", )) ) """ cluster = session.cluster partition_keys = cluster.metadata.keyspaces[keyspace].tables[table].partition_key serializers = list(types._cqltypes[partition_key.cql_type] for partition_key in partition_keys) keys_per_host = defaultdict(list) distance = cluster._default_load_balancing_policy.distance for key in keys: serialized_key = [serializer.serialize(pk, cluster.protocol_version) for serializer, pk in zip(serializers, key)] if len(serialized_key) == 1: routing_key = serialized_key[0] else: routing_key = b"".join(struct.pack(">H%dsB" % len(p), len(p), p, 0) for p in serialized_key) all_replicas = cluster.metadata.get_replicas(keyspace, routing_key) # First check if there are local replicas valid_replicas = [host for host in all_replicas if host.is_up and distance(host) == HostDistance.LOCAL] if not valid_replicas: valid_replicas = [host for host in all_replicas if host.is_up] if valid_replicas: keys_per_host[random.choice(valid_replicas)].append(key) else: # We will group under this statement all the keys for which # we haven't found a valid replica keys_per_host[NO_VALID_REPLICA].append(key) return dict(keys_per_host) diff --git a/cassandra/metrics.py b/cassandra/metrics.py index 473e527..223b0c7 100644 --- a/cassandra/metrics.py +++ b/cassandra/metrics.py @@ -1,201 +1,201 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from itertools import chain import logging try: from greplin import scales except ImportError: raise ImportError( "The scales library is required for metrics support: " - "https://pypi.python.org/pypi/scales") + "https://pypi.org/project/scales/") log = logging.getLogger(__name__) class Metrics(object): """ A collection of timers and counters for various performance metrics. Timer metrics are represented as floating point seconds. """ request_timer = None """ A :class:`greplin.scales.PmfStat` timer for requests. This is a dict-like object with the following keys: * count - number of requests that have been timed * min - min latency * max - max latency * mean - mean latency * stddev - standard deviation for latencies * median - median latency * 75percentile - 75th percentile latencies * 95percentile - 95th percentile latencies * 98percentile - 98th percentile latencies * 99percentile - 99th percentile latencies * 999percentile - 99.9th percentile latencies """ connection_errors = None """ A :class:`greplin.scales.IntStat` count of the number of times that a request to a Cassandra node has failed due to a connection problem. """ write_timeouts = None """ A :class:`greplin.scales.IntStat` count of write requests that resulted in a timeout. """ read_timeouts = None """ A :class:`greplin.scales.IntStat` count of read requests that resulted in a timeout. 
""" unavailables = None """ A :class:`greplin.scales.IntStat` count of write or read requests that failed due to an insufficient number of replicas being alive to meet the requested :class:`.ConsistencyLevel`. """ other_errors = None """ A :class:`greplin.scales.IntStat` count of all other request failures, including failures caused by invalid requests, bootstrapping nodes, overloaded nodes, etc. """ retries = None """ A :class:`greplin.scales.IntStat` count of the number of times a request was retried based on the :class:`.RetryPolicy` decision. """ ignores = None """ A :class:`greplin.scales.IntStat` count of the number of times a failed request was ignored based on the :class:`.RetryPolicy` decision. """ known_hosts = None """ A :class:`greplin.scales.IntStat` count of the number of nodes in the cluster that the driver is aware of, regardless of whether any connections are opened to those nodes. """ connected_to = None """ A :class:`greplin.scales.IntStat` count of the number of nodes that the driver currently has at least one connection open to. """ open_connections = None """ A :class:`greplin.scales.IntStat` count of the number connections the driver currently has open. """ _stats_counter = 0 def __init__(self, cluster_proxy): log.debug("Starting metric capture") self.stats_name = 'cassandra-{0}'.format(str(self._stats_counter)) Metrics._stats_counter += 1 self.stats = scales.collection(self.stats_name, scales.PmfStat('request_timer'), scales.IntStat('connection_errors'), scales.IntStat('write_timeouts'), scales.IntStat('read_timeouts'), scales.IntStat('unavailables'), scales.IntStat('other_errors'), scales.IntStat('retries'), scales.IntStat('ignores'), # gauges scales.Stat('known_hosts', lambda: len(cluster_proxy.metadata.all_hosts())), scales.Stat('connected_to', lambda: len(set(chain.from_iterable(s._pools.keys() for s in cluster_proxy.sessions)))), scales.Stat('open_connections', lambda: sum(sum(p.open_count for p in s._pools.values()) for s in cluster_proxy.sessions))) # TODO, to be removed in 4.0 # /cassandra contains the metrics of the first cluster registered if 'cassandra' not in scales._Stats.stats: scales._Stats.stats['cassandra'] = scales._Stats.stats[self.stats_name] self.request_timer = self.stats.request_timer self.connection_errors = self.stats.connection_errors self.write_timeouts = self.stats.write_timeouts self.read_timeouts = self.stats.read_timeouts self.unavailables = self.stats.unavailables self.other_errors = self.stats.other_errors self.retries = self.stats.retries self.ignores = self.stats.ignores self.known_hosts = self.stats.known_hosts self.connected_to = self.stats.connected_to self.open_connections = self.stats.open_connections def on_connection_error(self): self.stats.connection_errors += 1 def on_write_timeout(self): self.stats.write_timeouts += 1 def on_read_timeout(self): self.stats.read_timeouts += 1 def on_unavailable(self): self.stats.unavailables += 1 def on_other_error(self): self.stats.other_errors += 1 def on_ignore(self): self.stats.ignores += 1 def on_retry(self): self.stats.retries += 1 def get_stats(self): """ Returns the metrics for the registered cluster instance. """ return scales.getStats()[self.stats_name] def set_stats_name(self, stats_name): """ Set the metrics stats name. The stats_name is a string used to access the metris through scales: scales.getStats()[] Default is 'cassandra-'. 
""" if self.stats_name == stats_name: return if stats_name in scales._Stats.stats: raise ValueError('"{0}" already exists in stats.'.format(stats_name)) stats = scales._Stats.stats[self.stats_name] del scales._Stats.stats[self.stats_name] self.stats_name = stats_name scales._Stats.stats[self.stats_name] = stats diff --git a/cassandra/query.py b/cassandra/query.py index 56b470d..b2193d6 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -1,1041 +1,1089 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ This module holds classes for working with prepared statements and specifying consistency levels and retry policies for individual queries. """ from collections import namedtuple from datetime import datetime, timedelta import re import struct import time import six from six.moves import range, zip +import warnings from cassandra import ConsistencyLevel, OperationTimedOut from cassandra.util import unix_time_from_uuid1 from cassandra.encoder import Encoder import cassandra.encoder from cassandra.protocol import _UNSET_VALUE from cassandra.util import OrderedDict, _sanitize_identifiers import logging log = logging.getLogger(__name__) UNSET_VALUE = _UNSET_VALUE """ Specifies an unset value when binding a prepared statement. Unset values are ignored, allowing prepared statements to be used without specify See https://issues.apache.org/jira/browse/CASSANDRA-7304 for further details on semantics. .. versionadded:: 2.6.0 Only valid when using native protocol v4+ """ NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]') START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*') END_BADCHAR_REGEX = re.compile('[^a-zA-Z0-9_]*$') _clean_name_cache = {} def _clean_column_name(name): try: return _clean_name_cache[name] except KeyError: clean = NON_ALPHA_REGEX.sub("_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name))) _clean_name_cache[name] = clean return clean def tuple_factory(colnames, rows): """ Returns each row as a tuple Example:: >>> from cassandra.query import tuple_factory >>> session = cluster.connect('mykeyspace') >>> session.row_factory = tuple_factory >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") >>> print rows[0] ('Bob', 42) .. versionchanged:: 2.0.0 moved from ``cassandra.decoder`` to ``cassandra.query`` """ return rows +class PseudoNamedTupleRow(object): + """ + Helper class for pseudo_named_tuple_factory. These objects provide an + __iter__ interface, as well as index- and attribute-based access to values, + but otherwise do not attempt to implement the full namedtuple or iterable + interface. 
+ """ + def __init__(self, ordered_dict): + self._dict = ordered_dict + self._tuple = tuple(ordered_dict.values()) + + def __getattr__(self, name): + return self._dict[name] + + def __getitem__(self, idx): + return self._tuple[idx] + + def __iter__(self): + return iter(self._tuple) + + def __repr__(self): + return '{t}({od})'.format(t=self.__class__.__name__, + od=self._dict) + + +def pseudo_namedtuple_factory(colnames, rows): + """ + Returns each row as a :class:`.PseudoNamedTupleRow`. This is the fallback + factory for cases where :meth:`.named_tuple_factory` fails to create rows. + """ + return [PseudoNamedTupleRow(od) + for od in ordered_dict_factory(colnames, rows)] + def named_tuple_factory(colnames, rows): """ Returns each row as a `namedtuple `_. This is the default row factory. Example:: >>> from cassandra.query import named_tuple_factory >>> session = cluster.connect('mykeyspace') >>> session.row_factory = named_tuple_factory >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") >>> user = rows[0] >>> # you can access field by their name: >>> print "name: %s, age: %d" % (user.name, user.age) name: Bob, age: 42 >>> # or you can access fields by their position (like a tuple) >>> name, age = user >>> print "name: %s, age: %d" % (name, age) name: Bob, age: 42 >>> name = user[0] >>> age = user[1] >>> print "name: %s, age: %d" % (name, age) name: Bob, age: 42 .. versionchanged:: 2.0.0 moved from ``cassandra.decoder`` to ``cassandra.query`` """ clean_column_names = map(_clean_column_name, colnames) try: Row = namedtuple('Row', clean_column_names) + except SyntaxError: + warnings.warn( + "Failed creating namedtuple for a result because there were too " + "many columns. This is due to a Python limitation that affects " + "namedtuple in Python 3.0-3.6 (see issue18896). The row will be " + "created with {substitute_factory_name}, which lacks some namedtuple " + "features and is slower. To avoid slower performance accessing " + "values on row objects, Upgrade to Python 3.7, or use a different " + "row factory. (column names: {colnames})".format( + substitute_factory_name=pseudo_namedtuple_factory.__name__, + colnames=colnames + ) + ) + return pseudo_namedtuple_factory(colnames, rows) except Exception: clean_column_names = list(map(_clean_column_name, colnames)) # create list because py3 map object will be consumed by first attempt log.warning("Failed creating named tuple for results with column names %s (cleaned: %s) " "(see Python 'namedtuple' documentation for details on name rules). " "Results will be returned with positional names. " "Avoid this by choosing different names, using SELECT \"\" AS aliases, " "or specifying a different row_factory on your Session" % (colnames, clean_column_names)) Row = namedtuple('Row', _sanitize_identifiers(clean_column_names)) return [Row(*row) for row in rows] def dict_factory(colnames, rows): """ Returns each row as a dict. Example:: >>> from cassandra.query import dict_factory >>> session = cluster.connect('mykeyspace') >>> session.row_factory = dict_factory >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") >>> print rows[0] {u'age': 42, u'name': u'Bob'} .. versionchanged:: 2.0.0 moved from ``cassandra.decoder`` to ``cassandra.query`` """ return [dict(zip(colnames, row)) for row in rows] def ordered_dict_factory(colnames, rows): """ Like :meth:`~cassandra.query.dict_factory`, but returns each row as an OrderedDict, so the order of the columns is preserved. .. 
versionchanged:: 2.0.0 moved from ``cassandra.decoder`` to ``cassandra.query`` """ return [OrderedDict(zip(colnames, row)) for row in rows] FETCH_SIZE_UNSET = object() class Statement(object): """ An abstract class representing a single query. There are three subclasses: :class:`.SimpleStatement`, :class:`.BoundStatement`, and :class:`.BatchStatement`. These can be passed to :meth:`.Session.execute()`. """ retry_policy = None """ An instance of a :class:`cassandra.policies.RetryPolicy` or one of its subclasses. This controls when a query will be retried and how it will be retried. """ consistency_level = None """ The :class:`.ConsistencyLevel` to be used for this operation. Defaults to :const:`None`, which means that the default consistency level for the Session this is executed in will be used. """ fetch_size = FETCH_SIZE_UNSET """ How many rows will be fetched at a time. This overrides the default of :attr:`.Session.default_fetch_size` This only takes effect when protocol version 2 or higher is used. See :attr:`.Cluster.protocol_version` for details. .. versionadded:: 2.0.0 """ keyspace = None """ The string name of the keyspace this query acts on. This is used when :class:`~.TokenAwarePolicy` is configured for :attr:`.Cluster.load_balancing_policy` It is set implicitly on :class:`.BoundStatement`, and :class:`.BatchStatement`, but must be set explicitly on :class:`.SimpleStatement`. .. versionadded:: 2.1.3 """ custom_payload = None """ :ref:`custom_payload` to be passed to the server. These are only allowed when using protocol version 4 or higher. .. versionadded:: 2.6.0 """ is_idempotent = False """ Flag indicating whether this statement is safe to run multiple times in speculative execution. """ _serial_consistency_level = None _routing_key = None def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None, is_idempotent=False): if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy') if retry_policy is not None: self.retry_policy = retry_policy if consistency_level is not None: self.consistency_level = consistency_level self._routing_key = routing_key if serial_consistency_level is not None: self.serial_consistency_level = serial_consistency_level if fetch_size is not FETCH_SIZE_UNSET: self.fetch_size = fetch_size if keyspace is not None: self.keyspace = keyspace if custom_payload is not None: self.custom_payload = custom_payload self.is_idempotent = is_idempotent def _key_parts_packed(self, parts): for p in parts: l = len(p) yield struct.pack(">H%dsB" % l, l, p, 0) def _get_routing_key(self): return self._routing_key def _set_routing_key(self, key): if isinstance(key, (list, tuple)): if len(key) == 1: self._routing_key = key[0] else: self._routing_key = b"".join(self._key_parts_packed(key)) else: self._routing_key = key def _del_routing_key(self): self._routing_key = None routing_key = property( _get_routing_key, _set_routing_key, _del_routing_key, """ The :attr:`~.TableMetadata.partition_key` portion of the primary key, which can be used to determine which nodes are replicas for the query. If the partition key is a composite, a list or tuple must be passed in. Each key component should be in its packed (binary) format, so all components should be strings. 
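
    A sketch of setting a routing key on a :class:`.SimpleStatement` (names and
    values are illustrative; each component must already be serialized to bytes)::

        statement = SimpleStatement("SELECT * FROM t WHERE k1=%s AND k2=%s")
        # a single-component key can be assigned directly
        statement.routing_key = b'serialized_k1'
        # a composite key is passed as a list/tuple of packed components;
        # the setter joins them using the same >H{len}sB layout used internally
        statement.routing_key = [b'serialized_k1', b'serialized_k2']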
""") def _get_serial_consistency_level(self): return self._serial_consistency_level def _set_serial_consistency_level(self, serial_consistency_level): acceptable = (None, ConsistencyLevel.SERIAL, ConsistencyLevel.LOCAL_SERIAL) if serial_consistency_level not in acceptable: raise ValueError( "serial_consistency_level must be either ConsistencyLevel.SERIAL " "or ConsistencyLevel.LOCAL_SERIAL") self._serial_consistency_level = serial_consistency_level def _del_serial_consistency_level(self): self._serial_consistency_level = None serial_consistency_level = property( _get_serial_consistency_level, _set_serial_consistency_level, _del_serial_consistency_level, """ The serial consistency level is only used by conditional updates (``INSERT``, ``UPDATE`` and ``DELETE`` with an ``IF`` condition). For those, the ``serial_consistency_level`` defines the consistency level of the serial phase (or "paxos" phase) while the normal :attr:`~.consistency_level` defines the consistency for the "learn" phase, i.e. what type of reads will be guaranteed to see the update right away. For example, if a conditional write has a :attr:`~.consistency_level` of :attr:`~.ConsistencyLevel.QUORUM` (and is successful), then a :attr:`~.ConsistencyLevel.QUORUM` read is guaranteed to see that write. But if the regular :attr:`~.consistency_level` of that write is :attr:`~.ConsistencyLevel.ANY`, then only a read with a :attr:`~.consistency_level` of :attr:`~.ConsistencyLevel.SERIAL` is guaranteed to see it (even a read with consistency :attr:`~.ConsistencyLevel.ALL` is not guaranteed to be enough). The serial consistency can only be one of :attr:`~.ConsistencyLevel.SERIAL` or :attr:`~.ConsistencyLevel.LOCAL_SERIAL`. While ``SERIAL`` guarantees full linearizability (with other ``SERIAL`` updates), ``LOCAL_SERIAL`` only guarantees it in the local data center. The serial consistency level is ignored for any query that is not a conditional update. Serial reads should use the regular :attr:`consistency_level`. Serial consistency levels may only be used against Cassandra 2.0+ and the :attr:`~.Cluster.protocol_version` must be set to 2 or higher. See :doc:`/lwt` for a discussion on how to work with results returned from conditional statements. .. versionadded:: 2.0.0 """) class SimpleStatement(Statement): """ A simple, un-prepared query. """ def __init__(self, query_string, retry_policy=None, consistency_level=None, routing_key=None, serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None, is_idempotent=False): """ `query_string` should be a literal CQL statement with the exception of parameter placeholders that will be filled through the `parameters` argument of :meth:`.Session.execute()`. See :class:`Statement` attributes for a description of the other parameters. """ Statement.__init__(self, retry_policy, consistency_level, routing_key, serial_consistency_level, fetch_size, keyspace, custom_payload, is_idempotent) self._query_string = query_string @property def query_string(self): return self._query_string def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'' % (self.query_string, consistency)) __repr__ = __str__ class PreparedStatement(object): """ A statement that has been prepared against at least one Cassandra node. Instances of this class should not be created directly, but through :meth:`.Session.prepare()`. A :class:`.PreparedStatement` should be prepared only once. 
Re-preparing a statement may affect performance (as the operation requires a network roundtrip). |prepared_stmt_head|: Do not use ``*`` in prepared statements if you might change the schema of the table being queried. The driver and server each maintain a map between metadata for a schema and statements that were prepared against that schema. When a user changes a schema, e.g. by adding or removing a column, the server invalidates its mappings involving that schema. However, there is currently no way to propagate that invalidation to drivers. Thus, after a schema change, the driver will incorrectly interpret the results of ``SELECT *`` queries prepared before the schema change. This is currently being addressed in `CASSANDRA-10786 `_. .. |prepared_stmt_head| raw:: html A note about * in prepared statements """ column_metadata = None #TODO: make this bind_metadata in next major retry_policy = None consistency_level = None custom_payload = None fetch_size = FETCH_SIZE_UNSET keyspace = None # change to prepared_keyspace in major release protocol_version = None query_id = None query_string = None result_metadata = None result_metadata_id = None routing_key_indexes = None _routing_key_index_set = None serial_consistency_level = None def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace, protocol_version, result_metadata, result_metadata_id): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes self.query_string = query self.keyspace = keyspace self.protocol_version = protocol_version self.result_metadata = result_metadata self.result_metadata_id = result_metadata_id self.is_idempotent = False @classmethod def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, query, prepared_keyspace, protocol_version, result_metadata, result_metadata_id): if not column_metadata: return PreparedStatement(column_metadata, query_id, None, query, prepared_keyspace, protocol_version, result_metadata, result_metadata_id) if pk_indexes: routing_key_indexes = pk_indexes else: routing_key_indexes = None first_col = column_metadata[0] ks_meta = cluster_metadata.keyspaces.get(first_col.keyspace_name) if ks_meta: table_meta = ks_meta.tables.get(first_col.table_name) if table_meta: partition_key_columns = table_meta.partition_key # make a map of {column_name: index} for each column in the statement statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata)) # a list of which indexes in the statement correspond to partition key items try: routing_key_indexes = [statement_indexes[c.name] for c in partition_key_columns] except KeyError: # we're missing a partition key component in the prepared pass # statement; just leave routing_key_indexes as None return PreparedStatement(column_metadata, query_id, routing_key_indexes, query, prepared_keyspace, protocol_version, result_metadata, result_metadata_id) def bind(self, values): """ Creates and returns a :class:`BoundStatement` instance using `values`. See :meth:`BoundStatement.bind` for rules on input ``values``. 
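
        Example (a minimal sketch; assumes a connected ``session`` and a
        ``users`` table with a text ``user_id`` partition key)::

            prepared = session.prepare("SELECT name, age FROM users WHERE user_id=?")
            bound = prepared.bind(("alice",))
            rows = session.execute(bound)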
""" return BoundStatement(self).bind(values) def is_routing_key_index(self, i): if self._routing_key_index_set is None: self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set() return i in self._routing_key_index_set def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'' % (self.query_string, consistency)) __repr__ = __str__ class BoundStatement(Statement): """ A prepared statement that has been bound to a particular set of values. These may be created directly or through :meth:`.PreparedStatement.bind()`. """ prepared_statement = None """ The :class:`PreparedStatement` instance that this was created from. """ values = None """ The sequence of values that were bound to the prepared statement. """ def __init__(self, prepared_statement, retry_policy=None, consistency_level=None, routing_key=None, serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None): """ `prepared_statement` should be an instance of :class:`PreparedStatement`. See :class:`Statement` attributes for a description of the other parameters. """ self.prepared_statement = prepared_statement self.retry_policy = prepared_statement.retry_policy self.consistency_level = prepared_statement.consistency_level self.serial_consistency_level = prepared_statement.serial_consistency_level self.fetch_size = prepared_statement.fetch_size self.custom_payload = prepared_statement.custom_payload self.is_idempotent = prepared_statement.is_idempotent self.values = [] meta = prepared_statement.column_metadata if meta: self.keyspace = meta[0].keyspace_name Statement.__init__(self, retry_policy, consistency_level, routing_key, serial_consistency_level, fetch_size, keyspace, custom_payload, prepared_statement.is_idempotent) def bind(self, values): """ Binds a sequence of values for the prepared statement parameters and returns this instance. Note that `values` *must* be: * a sequence, even if you are only binding one value, or * a dict that relates 1-to-1 between dict keys and columns .. versionchanged:: 2.6.0 :data:`~.UNSET_VALUE` was introduced. These can be bound as positional parameters in a sequence, or by name in a dict. Additionally, when using protocol v4+: * short sequences will be extended to match bind parameters with UNSET_VALUE * names may be omitted from a dict with UNSET_VALUE implied. .. versionchanged:: 3.0.0 method will not throw if extra keys are present in bound dict (PYTHON-178) """ if values is None: values = () proto_version = self.prepared_statement.protocol_version col_meta = self.prepared_statement.column_metadata # special case for binding dicts if isinstance(values, dict): values_dict = values values = [] # sort values accordingly for col in col_meta: try: values.append(values_dict[col.name]) except KeyError: if proto_version >= 4: values.append(UNSET_VALUE) else: raise KeyError( 'Column name `%s` not found in bound dict.' % (col.name)) value_len = len(values) col_meta_len = len(col_meta) if value_len > col_meta_len: raise ValueError( "Too many arguments provided to bind() (got %d, expected %d)" % (len(values), len(col_meta))) # this is fail-fast for clarity pre-v4. When v4 can be assumed, # the error will be better reported when UNSET_VALUE is implicitly added. 
if proto_version < 4 and self.prepared_statement.routing_key_indexes and \ value_len < len(self.prepared_statement.routing_key_indexes): raise ValueError( "Too few arguments provided to bind() (got %d, required %d for routing key)" % (value_len, len(self.prepared_statement.routing_key_indexes))) self.raw_values = values self.values = [] for value, col_spec in zip(values, col_meta): if value is None: self.values.append(None) elif value is UNSET_VALUE: if proto_version >= 4: self._append_unset_value() else: raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) else: try: self.values.append(col_spec.type.serialize(value, proto_version)) except (TypeError, struct.error) as exc: actual_type = type(value) message = ('Received an argument of invalid type for column "%s". ' 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) raise TypeError(message) if proto_version >= 4: diff = col_meta_len - len(self.values) if diff: for _ in range(diff): self._append_unset_value() return self def _append_unset_value(self): next_index = len(self.values) if self.prepared_statement.is_routing_key_index(next_index): col_meta = self.prepared_statement.column_metadata[next_index] raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name) self.values.append(UNSET_VALUE) @property def routing_key(self): if not self.prepared_statement.routing_key_indexes: return None if self._routing_key is not None: return self._routing_key routing_indexes = self.prepared_statement.routing_key_indexes if len(routing_indexes) == 1: self._routing_key = self.values[routing_indexes[0]] else: self._routing_key = b"".join(self._key_parts_packed(self.values[i] for i in routing_indexes)) return self._routing_key def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'' % (self.prepared_statement.query_string, self.raw_values, consistency)) __repr__ = __str__ class BatchType(object): """ A BatchType is used with :class:`.BatchStatement` instances to control the atomicity of the batch operation. .. versionadded:: 2.0.0 """ LOGGED = None """ Atomic batch operation. """ UNLOGGED = None """ Non-atomic batch operation. """ COUNTER = None """ Batches of counter operations. """ def __init__(self, name, value): self.name = name self.value = value def __str__(self): return self.name def __repr__(self): return "BatchType.%s" % (self.name, ) BatchType.LOGGED = BatchType("LOGGED", 0) BatchType.UNLOGGED = BatchType("UNLOGGED", 1) BatchType.COUNTER = BatchType("COUNTER", 2) class BatchStatement(Statement): """ A protocol-level batch of operations which are applied atomically by default. .. versionadded:: 2.0.0 """ batch_type = None """ The :class:`.BatchType` for the batch operation. Defaults to :attr:`.BatchType.LOGGED`. """ serial_consistency_level = None """ The same as :attr:`.Statement.serial_consistency_level`, but is only supported when using protocol version 3 or higher. """ _statements_and_parameters = None _session = None def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, consistency_level=None, serial_consistency_level=None, session=None, custom_payload=None): """ `batch_type` specifies The :class:`.BatchType` for the batch operation. Defaults to :attr:`.BatchType.LOGGED`. `retry_policy` should be a :class:`~.RetryPolicy` instance for controlling retries on the operation. 
`consistency_level` should be a :class:`~.ConsistencyLevel` value to be used for all operations in the batch. `custom_payload` is a :ref:`custom_payload` passed to the server. Note: as Statement objects are added to the batch, this map is updated with any values found in their custom payloads. These are only allowed when using protocol version 4 or higher. Example usage: .. code-block:: python insert_user = session.prepare("INSERT INTO users (name, age) VALUES (?, ?)") batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM) for (name, age) in users_to_insert: batch.add(insert_user, (name, age)) session.execute(batch) You can also mix different types of operations within a batch: .. code-block:: python batch = BatchStatement() batch.add(SimpleStatement("INSERT INTO users (name, age) VALUES (%s, %s)"), (name, age)) batch.add(SimpleStatement("DELETE FROM pending_users WHERE name=%s"), (name,)) session.execute(batch) .. versionadded:: 2.0.0 .. versionchanged:: 2.1.0 Added `serial_consistency_level` as a parameter .. versionchanged:: 2.6.0 Added `custom_payload` as a parameter """ self.batch_type = batch_type self._statements_and_parameters = [] self._session = session Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level, serial_consistency_level=serial_consistency_level, custom_payload=custom_payload) def clear(self): """ This is a convenience method to clear a batch statement for reuse. *Note:* it should not be used concurrently with uncompleted execution futures executing the same ``BatchStatement``. """ del self._statements_and_parameters[:] self.keyspace = None self.routing_key = None if self.custom_payload: self.custom_payload.clear() def add(self, statement, parameters=None): """ Adds a :class:`.Statement` and optional sequence of parameters to be used with the statement to the batch. Like with other statements, parameters must be a sequence, even if there is only one item. """ if isinstance(statement, six.string_types): if parameters: encoder = Encoder() if self._session is None else self._session.encoder statement = bind_params(statement, parameters, encoder) self._add_statement_and_params(False, statement, ()) elif isinstance(statement, PreparedStatement): query_id = statement.query_id bound_statement = statement.bind(() if parameters is None else parameters) self._update_state(bound_statement) self._add_statement_and_params(True, query_id, bound_statement.values) elif isinstance(statement, BoundStatement): if parameters: raise ValueError( "Parameters cannot be passed with a BoundStatement " "to BatchStatement.add()") self._update_state(statement) self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values) else: # it must be a SimpleStatement query_string = statement.query_string if parameters: encoder = Encoder() if self._session is None else self._session.encoder query_string = bind_params(query_string, parameters, encoder) self._update_state(statement) self._add_statement_and_params(False, query_string, ()) return self def add_all(self, statements, parameters): """ Adds a sequence of :class:`.Statement` objects and a matching sequence of parameters to the batch. Statement and parameter sequences must be of equal length or one will be truncated. :const:`None` can be used in the parameters position where no parameters are needed.
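
        Example (a sketch; assumes ``insert_user`` is a statement previously
        prepared with ``session.prepare``)::

            batch = BatchStatement()
            batch.add_all([insert_user, insert_user],
                          [("Alice", 30), ("Bob", 42)])
            session.execute(batch)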
""" for statement, value in zip(statements, parameters): self.add(statement, value) def _add_statement_and_params(self, is_prepared, statement, parameters): if len(self._statements_and_parameters) >= 0xFFFF: raise ValueError("Batch statement cannot contain more than %d statements." % 0xFFFF) self._statements_and_parameters.append((is_prepared, statement, parameters)) def _maybe_set_routing_attributes(self, statement): if self.routing_key is None: if statement.keyspace and statement.routing_key: self.routing_key = statement.routing_key self.keyspace = statement.keyspace def _update_custom_payload(self, statement): if statement.custom_payload: if self.custom_payload is None: self.custom_payload = {} self.custom_payload.update(statement.custom_payload) def _update_state(self, statement): self._maybe_set_routing_attributes(statement) self._update_custom_payload(statement) def __len__(self): return len(self._statements_and_parameters) def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'' % (self.batch_type, len(self), consistency)) __repr__ = __str__ ValueSequence = cassandra.encoder.ValueSequence """ A wrapper class that is used to specify that a sequence of values should be treated as a CQL list of values instead of a single column collection when used as part of the `parameters` argument for :meth:`.Session.execute()`. This is typically needed when supplying a list of keys to select. For example:: >>> my_user_ids = ('alice', 'bob', 'charles') >>> query = "SELECT * FROM users WHERE user_id IN %s" >>> session.execute(query, parameters=[ValueSequence(my_user_ids)]) """ def bind_params(query, params, encoder): if six.PY2 and isinstance(query, six.text_type): query = query.encode('utf-8') if isinstance(params, dict): return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in six.iteritems(params)) else: return query % tuple(encoder.cql_encode_all_types(v) for v in params) class TraceUnavailable(Exception): """ Raised when complete trace details cannot be fetched from Cassandra. """ pass class QueryTrace(object): """ A trace of the duration and events that occurred when executing an operation. """ trace_id = None """ :class:`uuid.UUID` unique identifier for this tracing session. Matches the ``session_id`` column in ``system_traces.sessions`` and ``system_traces.events``. """ request_type = None """ A string that very generally describes the traced operation. """ duration = None """ A :class:`datetime.timedelta` measure of the duration of the query. """ client = None """ The IP address of the client that issued this request This is only available when using Cassandra 2.2+ """ coordinator = None """ The IP address of the host that acted as coordinator for this request. """ parameters = None """ A :class:`dict` of parameters for the traced operation, such as the specific query string. """ started_at = None """ A UTC :class:`datetime.datetime` object describing when the operation was started. """ events = None """ A chronologically sorted list of :class:`.TraceEvent` instances representing the steps the traced operation went through. This corresponds to the rows in ``system_traces.events`` for this tracing session. 
""" _session = None _SELECT_SESSIONS_FORMAT = "SELECT * FROM system_traces.sessions WHERE session_id = %s" _SELECT_EVENTS_FORMAT = "SELECT * FROM system_traces.events WHERE session_id = %s" _BASE_RETRY_SLEEP = 0.003 def __init__(self, trace_id, session): self.trace_id = trace_id self._session = session def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): """ Retrieves the actual tracing details from Cassandra and populates the attributes of this instance. Because tracing details are stored asynchronously by Cassandra, this may need to retry the session detail fetch. If the trace is still not available after `max_wait` seconds, :exc:`.TraceUnavailable` will be raised; if `max_wait` is :const:`None`, this will retry forever. `wait_for_complete=False` bypasses the wait for duration to be populated. This can be used to query events from partial sessions. `query_cl` specifies a consistency level to use for polling the trace tables, if it should be different than the session default. """ attempt = 0 start = time.time() while True: time_spent = time.time() - start if max_wait is not None and time_spent >= max_wait: raise TraceUnavailable( "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." % (max_wait,)) log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id) session_results = self._execute( SimpleStatement(self._SELECT_SESSIONS_FORMAT, consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) # PYTHON-730: There is race condition that the duration mutation is written before started_at the for fast queries is_complete = session_results and session_results[0].duration is not None and session_results[0].started_at is not None if not session_results or (wait_for_complete and not is_complete): time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt)) attempt += 1 continue if is_complete: log.debug("Fetched trace info for trace ID: %s", self.trace_id) else: log.debug("Fetching parital trace info for trace ID: %s", self.trace_id) session_row = session_results[0] self.request_type = session_row.request self.duration = timedelta(microseconds=session_row.duration) if is_complete else None self.started_at = session_row.started_at self.coordinator = session_row.coordinator self.parameters = session_row.parameters # since C* 2.2 self.client = getattr(session_row, 'client', None) log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id) time_spent = time.time() - start event_results = self._execute( SimpleStatement(self._SELECT_EVENTS_FORMAT, consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) log.debug("Fetched trace events for trace ID: %s", self.trace_id) self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) for r in event_results) break def _execute(self, query, parameters, time_spent, max_wait): timeout = (max_wait - time_spent) if max_wait is not None else None future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout) # in case the user switched the row factory, set it to namedtuple for this query future.row_factory = named_tuple_factory future.send_request() try: return future.result() except OperationTimedOut: raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) def __str__(self): return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \ % (self.request_type, self.trace_id, self.coordinator, self.started_at, 
self.duration, self.parameters) class TraceEvent(object): """ Representation of a single event within a query trace. """ description = None """ A brief description of the event. """ datetime = None """ A UTC :class:`datetime.datetime` marking when the event occurred. """ source = None """ The IP address of the node this event occurred on. """ source_elapsed = None """ A :class:`datetime.timedelta` measuring the amount of time until this event occurred starting from when :attr:`.source` first received the query. """ thread_name = None """ The name of the thread that this event occurred on. """ def __init__(self, description, timeuuid, source, source_elapsed, thread_name): self.description = description self.datetime = datetime.utcfromtimestamp(unix_time_from_uuid1(timeuuid)) self.source = source if source_elapsed is not None: self.source_elapsed = timedelta(microseconds=source_elapsed) else: self.source_elapsed = None self.thread_name = thread_name def __str__(self): return "%s on %s[%s] at %s" % (self.description, self.source, self.thread_name, self.datetime) diff --git a/cassandra_driver.egg-info/PKG-INFO b/cassandra_driver.egg-info/PKG-INFO index 0c04873..11e02d9 100644 --- a/cassandra_driver.egg-info/PKG-INFO +++ b/cassandra_driver.egg-info/PKG-INFO @@ -1,113 +1,112 @@ Metadata-Version: 1.1 Name: cassandra-driver -Version: 3.14.0 +Version: 3.16.0 Summary: Python driver for Cassandra Home-page: http://github.com/datastax/python-driver -Author: Tyler Hobbs -Author-email: tyler@datastax.com +Author:: Tyler Hobbs +Author-email:: tyler@datastax.com License: UNKNOWN Description: DataStax Python Driver for Apache Cassandra =========================================== .. image:: https://travis-ci.org/datastax/python-driver.png?branch=master :target: https://travis-ci.org/datastax/python-driver A modern, `feature-rich `_ and highly-tunable Python client library for Apache Cassandra (2.1+) using exclusively Cassandra's binary protocol and Cassandra Query Language v3. - The driver supports Python 2.7, 3.3, 3.4, 3.5, and 3.6. + The driver supports Python 2.7, 3.4, 3.5, and 3.6. If you require compatibility with DataStax Enterprise, use the `DataStax Enterprise Python Driver `_. **Note:** DataStax products do not support big-endian systems. Feedback Requested ------------------ **Help us focus our efforts!** Provide your input on the `Platform and Runtime Survey `_ (we kept it short). Features -------- * `Synchronous `_ and `Asynchronous `_ APIs * `Simple, Prepared, and Batch statements `_ * Asynchronous IO, parallel execution, request pipelining * `Connection pooling `_ * Automatic node discovery * `Automatic reconnection `_ * Configurable `load balancing `_ and `retry policies `_ * `Concurrent execution utilities `_ * `Object mapper `_ Installation ------------ Installation through pip is recommended:: $ pip install cassandra-driver For more complete installation instructions, see the `installation guide `_. Documentation ------------- The documentation can be found online `here `_. A couple of links for getting up to speed: * `Installation `_ * `Getting started guide `_ * `API docs `_ * `Performance tips `_ Object Mapper ------------- cqlengine (originally developed by Blake Eggleston and Jon Haddad, with contributions from the community) is now maintained as an integral part of this package. Refer to `documentation here `_. Contributing ------------ See `CONTRIBUTING.md `_. Reporting Problems ------------------ Please report any bugs and make any feature requests on the `JIRA `_ issue tracker. 
If you would like to contribute, please feel free to open a pull request. Getting Help ------------ Your best options for getting help with the driver are the `mailing list `_ and the ``#datastax-drivers`` channel in the `DataStax Academy Slack `_. License ------- Copyright 2013-2017 DataStax Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Keywords: cassandra,cql,orm Platform: UNKNOWN Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: Apache Software License Classifier: Natural Language :: English Classifier: Operating System :: OS Independent Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 2.7 -Classifier: Programming Language :: Python :: 3.3 Classifier: Programming Language :: Python :: 3.4 Classifier: Programming Language :: Python :: 3.5 Classifier: Programming Language :: Python :: 3.6 Classifier: Programming Language :: Python :: Implementation :: CPython Classifier: Programming Language :: Python :: Implementation :: PyPy Classifier: Topic :: Software Development :: Libraries :: Python Modules diff --git a/setup.py b/setup.py index a49bb1e..1b0ebf6 100644 --- a/setup.py +++ b/setup.py @@ -1,447 +1,446 @@ # Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import print_function import os import sys import warnings if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests": print("Running gevent tests") from gevent.monkey import patch_all patch_all() if __name__ == '__main__' and sys.argv[1] == "eventlet_nosetests": print("Running eventlet tests") from eventlet import monkey_patch monkey_patch() import ez_setup ez_setup.use_setuptools() from setuptools import setup from distutils.command.build_ext import build_ext from distutils.core import Extension from distutils.errors import (CCompilerError, DistutilsPlatformError, DistutilsExecError) from distutils.cmd import Command PY3 = sys.version_info[0] == 3 try: import subprocess has_subprocess = True except ImportError: has_subprocess = False from cassandra import __version__ long_description = "" with open("README.rst") as f: long_description = f.read() try: from nose.commands import nosetests except ImportError: gevent_nosetests = None eventlet_nosetests = None else: class gevent_nosetests(nosetests): description = "run nosetests with gevent monkey patching" class eventlet_nosetests(nosetests): description = "run nosetests with eventlet monkey patching" has_cqlengine = False if __name__ == '__main__' and sys.argv[1] == "install": try: import cqlengine has_cqlengine = True except ImportError: pass PROFILING = False class DocCommand(Command): description = "generate or test documentation" user_options = [("test", "t", "run doctests instead of generating documentation")] boolean_options = ["test"] def initialize_options(self): self.test = False def finalize_options(self): pass def run(self): if self.test: path = "docs/_build/doctest" mode = "doctest" else: path = "docs/_build/%s" % __version__ mode = "html" try: os.makedirs(path) except: pass if has_subprocess: # Prevent run with in-place extensions because cython-generated objects do not carry docstrings # http://docs.cython.org/src/userguide/special_methods.html#docstrings import glob for f in glob.glob("cassandra/*.so"): print("Removing '%s' to allow docs to run on pure python modules." %(f,)) os.unlink(f) # Build io extension to make import and docstrings work try: output = subprocess.check_output( ["python", "setup.py", "build_ext", "--inplace", "--force", "--no-murmur3", "--no-cython"], stderr=subprocess.STDOUT) except subprocess.CalledProcessError as exc: raise RuntimeError("Documentation step '%s' failed: %s: %s" % ("build_ext", exc, exc.output)) else: print(output) try: output = subprocess.check_output( ["sphinx-build", "-b", mode, "docs", path], stderr=subprocess.STDOUT) except subprocess.CalledProcessError as exc: raise RuntimeError("Documentation step '%s' failed: %s: %s" % (mode, exc, exc.output)) else: print(output) print("") print("Documentation step '%s' performed, results here:" % mode) print(" file://%s/%s/index.html" % (os.path.dirname(os.path.realpath(__file__)), path)) class BuildFailed(Exception): def __init__(self, ext): self.ext = ext murmur3_ext = Extension('cassandra.cmurmur3', sources=['cassandra/cmurmur3.c']) libev_ext = Extension('cassandra.io.libevwrapper', sources=['cassandra/io/libevwrapper.c'], include_dirs=['/usr/include/libev', '/usr/local/include', '/opt/local/include'], libraries=['ev'], library_dirs=['/usr/local/lib', '/opt/local/lib']) platform_unsupported_msg = \ """ =============================================================================== The optional C extensions are not supported on this platform. 
=============================================================================== """ arch_unsupported_msg = \ """ =============================================================================== The optional C extensions are not supported on big-endian systems. =============================================================================== """ pypy_unsupported_msg = \ """ ================================================================================= Some optional C extensions are not supported in PyPy. Only murmur3 will be built. ================================================================================= """ is_windows = os.name == 'nt' is_pypy = "PyPy" in sys.version if is_pypy: sys.stderr.write(pypy_unsupported_msg) is_supported_platform = sys.platform != "cli" and not sys.platform.startswith("java") is_supported_arch = sys.byteorder != "big" if not is_supported_platform: sys.stderr.write(platform_unsupported_msg) elif not is_supported_arch: sys.stderr.write(arch_unsupported_msg) try_extensions = "--no-extensions" not in sys.argv and is_supported_platform and is_supported_arch and not os.environ.get('CASS_DRIVER_NO_EXTENSIONS') try_murmur3 = try_extensions and "--no-murmur3" not in sys.argv try_libev = try_extensions and "--no-libev" not in sys.argv and not is_pypy and not is_windows try_cython = try_extensions and "--no-cython" not in sys.argv and not is_pypy and not os.environ.get('CASS_DRIVER_NO_CYTHON') try_cython &= 'egg_info' not in sys.argv # bypass setup_requires for pip egg_info calls, which will never have --install-option"--no-cython" coming fomr pip sys.argv = [a for a in sys.argv if a not in ("--no-murmur3", "--no-libev", "--no-cython", "--no-extensions")] build_concurrency = int(os.environ.get('CASS_DRIVER_BUILD_CONCURRENCY', '0')) class NoPatchExtension(Extension): # Older versions of setuptools.extension has a static flag which is set False before our # setup_requires lands Cython. It causes our *.pyx sources to be renamed to *.c in # the initializer. # The other workaround would be to manually generate sources, but that bypasses a lot # of the niceness cythonize embodies (setup build dir, conditional build, etc). # Newer setuptools does not have this problem because it checks for cython dynamically. # https://bitbucket.org/pypa/setuptools/commits/714c3144e08fd01a9f61d1c88411e76d2538b2e4 def __init__(self, *args, **kwargs): # bypass the patched init if possible if Extension.__bases__: base, = Extension.__bases__ base.__init__(self, *args, **kwargs) else: Extension.__init__(self, *args, **kwargs) class build_extensions(build_ext): error_message = """ =============================================================================== WARNING: could not compile %s. The C extensions are not required for the driver to run, but they add support for token-aware routing with the Murmur3Partitioner. On Windows, make sure Visual Studio or an SDK is installed, and your environment is configured to build for the appropriate architecture (matching your Python runtime). This is often a matter of using vcvarsall.bat from your install directory, or running from a command prompt in the Visual Studio Tools Start Menu. =============================================================================== """ if is_windows else """ =============================================================================== WARNING: could not compile %s. The C extensions are not required for the driver to run, but they add support for libev and token-aware routing with the Murmur3Partitioner. 
Linux users should ensure that GCC and the Python headers are available. On Ubuntu and Debian, this can be accomplished by running: $ sudo apt-get install build-essential python-dev On RedHat and RedHat-based systems like CentOS and Fedora: $ sudo yum install gcc python-devel On OSX, homebrew installations of Python should provide the necessary headers. libev Support ------------- For libev support, you will also need to install libev and its headers. On Debian/Ubuntu: $ sudo apt-get install libev4 libev-dev On RHEL/CentOS/Fedora: $ sudo yum install libev libev-devel On OSX, via homebrew: $ brew install libev =============================================================================== """ def run(self): try: self._setup_extensions() build_ext.run(self) except DistutilsPlatformError as exc: sys.stderr.write('%s\n' % str(exc)) warnings.warn(self.error_message % "C extensions.") def build_extensions(self): if build_concurrency > 1: self.check_extensions_list(self.extensions) import multiprocessing.pool multiprocessing.pool.ThreadPool(processes=build_concurrency).map(self.build_extension, self.extensions) else: build_ext.build_extensions(self) def build_extension(self, ext): try: build_ext.build_extension(self, ext) except (CCompilerError, DistutilsExecError, DistutilsPlatformError, IOError) as exc: sys.stderr.write('%s\n' % str(exc)) name = "The %s extension" % (ext.name,) warnings.warn(self.error_message % (name,)) def _setup_extensions(self): # We defer extension setup until this command to leveraage 'setup_requires' pulling in Cython before we # attempt to import anything self.extensions = [] if try_murmur3: self.extensions.append(murmur3_ext) if try_libev: self.extensions.append(libev_ext) if try_cython: try: from Cython.Build import cythonize cython_candidates = ['cluster', 'concurrent', 'connection', 'cqltypes', 'metadata', 'pool', 'protocol', 'query', 'util'] compile_args = [] if is_windows else ['-Wno-unused-function'] self.extensions.extend(cythonize( [Extension('cassandra.%s' % m, ['cassandra/%s.py' % m], extra_compile_args=compile_args) for m in cython_candidates], nthreads=build_concurrency, exclude_failures=True)) self.extensions.extend(cythonize(NoPatchExtension("*", ["cassandra/*.pyx"], extra_compile_args=compile_args), nthreads=build_concurrency)) except Exception: sys.stderr.write("Failed to cythonize one or more modules. 
These will not be compiled as extensions (optional).\n") def pre_build_check(): """ Try to verify build tools """ if os.environ.get('CASS_DRIVER_NO_PRE_BUILD_CHECK'): return True try: from distutils.ccompiler import new_compiler from distutils.sysconfig import customize_compiler from distutils.dist import Distribution # base build_ext just to emulate compiler option setup be = build_ext(Distribution()) be.initialize_options() be.finalize_options() # First, make sure we have a Python include directory have_python_include = any(os.path.isfile(os.path.join(p, 'Python.h')) for p in be.include_dirs) if not have_python_include: sys.stderr.write("Did not find 'Python.h' in %s.\n" % (be.include_dirs,)) return False compiler = new_compiler(compiler=be.compiler) customize_compiler(compiler) try: # We must be able to initialize the compiler if it has that method if hasattr(compiler, "initialize"): compiler.initialize() except: return False executables = [] if compiler.compiler_type in ('unix', 'cygwin'): executables = [compiler.executables[exe][0] for exe in ('compiler_so', 'linker_so')] elif compiler.compiler_type == 'nt': executables = [getattr(compiler, exe) for exe in ('cc', 'linker')] if executables: from distutils.spawn import find_executable for exe in executables: if not find_executable(exe): sys.stderr.write("Failed to find %s for compiler type %s.\n" % (exe, compiler.compiler_type)) return False except Exception as exc: sys.stderr.write('%s\n' % str(exc)) sys.stderr.write("Failed pre-build check. Attempting anyway.\n") # if we are unable to positively id the compiler type, or one of these assumptions fails, # just proceed as we would have without the check return True def run_setup(extensions): kw = {'cmdclass': {'doc': DocCommand}} if gevent_nosetests is not None: kw['cmdclass']['gevent_nosetests'] = gevent_nosetests if eventlet_nosetests is not None: kw['cmdclass']['eventlet_nosetests'] = eventlet_nosetests kw['cmdclass']['build_ext'] = build_extensions kw['ext_modules'] = [Extension('DUMMY', [])] # dummy extension makes sure build_ext is called for install if try_cython: # precheck compiler before adding to setup_requires # we don't actually negate try_cython because: # 1.) build_ext eats errors at compile time, letting the install complete while producing useful feedback # 2.) 
there could be a case where the python environment has cython installed but the system doesn't have build tools if pre_build_check(): - cython_dep = 'Cython>=0.20,!=0.25,<0.28' + cython_dep = 'Cython>=0.20,!=0.25,<0.29' user_specified_cython_version = os.environ.get('CASS_DRIVER_ALLOWED_CYTHON_VERSION') if user_specified_cython_version is not None: cython_dep = 'Cython==%s' % (user_specified_cython_version,) kw['setup_requires'] = [cython_dep] else: sys.stderr.write("Bypassing Cython setup requirement\n") dependencies = ['six >=1.9'] if not PY3: dependencies.append('futures') setup( name='cassandra-driver', version=__version__, description='Python driver for Cassandra', long_description=long_description, url='http://github.com/datastax/python-driver', author='Tyler Hobbs', author_email='tyler@datastax.com', packages=['cassandra', 'cassandra.io', 'cassandra.cqlengine'], keywords='cassandra,cql,orm', include_package_data=True, install_requires=dependencies, tests_require=['nose', 'mock>=2.0.0', 'PyYAML', 'pytz', 'sure'], classifiers=[ 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'License :: OSI Approved :: Apache Software License', 'Natural Language :: English', 'Operating System :: OS Independent', 'Programming Language :: Python', 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3.3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy', 'Topic :: Software Development :: Libraries :: Python Modules' ], **kw) run_setup(None) if has_cqlengine: warnings.warn("\n#######\n'cqlengine' package is present on path: %s\n" "cqlengine is now an integrated sub-package of this driver.\n" "It is recommended to remove this package to reduce the chance for conflicting usage" % cqlengine.__file__)