diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f433b1a --- /dev/null +++ b/LICENSE @@ -0,0 +1,177 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..1825f7b --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include setup.py README.rst MANIFEST.in LICENSE ez_setup.py diff --git a/PKG-INFO b/PKG-INFO new file mode 100644 index 0000000..6571faa --- /dev/null +++ b/PKG-INFO @@ -0,0 +1,99 @@ +Metadata-Version: 1.1 +Name: cassandra-driver +Version: 2.5.1 +Summary: Python driver for Cassandra +Home-page: http://github.com/datastax/python-driver +Author: Tyler Hobbs +Author-email: tyler@datastax.com +License: UNKNOWN +Description: DataStax Python Driver for Apache Cassandra + =========================================== + + .. 
image:: https://travis-ci.org/datastax/python-driver.png?branch=master + :target: https://travis-ci.org/datastax/python-driver + + A Python client driver for Apache Cassandra. This driver works exclusively + with the Cassandra Query Language v3 (CQL3) and Cassandra's native + protocol. Cassandra versions 1.2 through 2.1 are supported. + + The driver supports Python 2.6, 2.7, 3.3, and 3.4*. + + * cqlengine component presently supports Python 2.7+ + + Installation + ------------ + Installation through pip is recommended:: + + $ pip install cassandra-driver + + For more complete installation instructions, see the + `installation guide `_. + + Documentation + ------------- + The documentation can be found online `here `_. + + A couple of links for getting up to speed: + + * `Installation `_ + * `Getting started guide `_ + * `API docs `_ + * `Performance tips `_ + + Object Mapper + ------------- + cqlengine (originally developed by Blake Eggleston and Jon Haddad, with contributions from the + community) is now maintained as an integral part of this package. Refer to + `documentation here `_. + + Reporting Problems + ------------------ + Please report any bugs and make any feature requests on the + `JIRA `_ issue tracker. + + If you would like to contribute, please feel free to open a pull request. + + Getting Help + ------------ + Your two best options for getting help with the driver are the + `mailing list `_ + and the IRC channel. + + For IRC, use the #datastax-drivers channel on irc.freenode.net. If you don't have an IRC client, + you can use `freenode's web-based client `_. + + Features to be Added + -------------------- + * C extension for encoding/decoding messages + + License + ------- + Copyright 2013-2015 DataStax + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +Keywords: cassandra,cql,orm +Platform: UNKNOWN +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Natural Language :: English +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 2.6 +Classifier: Programming Language :: Python :: 2.7 +Classifier: Programming Language :: Python :: 3.3 +Classifier: Programming Language :: Python :: 3.4 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Topic :: Software Development :: Libraries :: Python Modules diff --git a/README.rst b/README.rst new file mode 100644 index 0000000..e976f93 --- /dev/null +++ b/README.rst @@ -0,0 +1,75 @@ +DataStax Python Driver for Apache Cassandra +=========================================== + +.. image:: https://travis-ci.org/datastax/python-driver.png?branch=master + :target: https://travis-ci.org/datastax/python-driver + +A Python client driver for Apache Cassandra. 
This driver works exclusively +with the Cassandra Query Language v3 (CQL3) and Cassandra's native +protocol. Cassandra versions 1.2 through 2.1 are supported. + +The driver supports Python 2.6, 2.7, 3.3, and 3.4*. + +* cqlengine component presently supports Python 2.7+ + +Installation +------------ +Installation through pip is recommended:: + + $ pip install cassandra-driver + +For more complete installation instructions, see the +`installation guide `_. + +Documentation +------------- +The documentation can be found online `here `_. + +A couple of links for getting up to speed: + +* `Installation `_ +* `Getting started guide `_ +* `API docs `_ +* `Performance tips `_ + +Object Mapper +------------- +cqlengine (originally developed by Blake Eggleston and Jon Haddad, with contributions from the +community) is now maintained as an integral part of this package. Refer to +`documentation here `_. + +Reporting Problems +------------------ +Please report any bugs and make any feature requests on the +`JIRA `_ issue tracker. + +If you would like to contribute, please feel free to open a pull request. + +Getting Help +------------ +Your two best options for getting help with the driver are the +`mailing list `_ +and the IRC channel. + +For IRC, use the #datastax-drivers channel on irc.freenode.net. If you don't have an IRC client, +you can use `freenode's web-based client `_. + +Features to be Added +-------------------- +* C extension for encoding/decoding messages + +License +------- +Copyright 2013-2015 DataStax + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/cassandra/__init__.py b/cassandra/__init__.py new file mode 100644 index 0000000..2be5700 --- /dev/null +++ b/cassandra/__init__.py @@ -0,0 +1,304 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + + +class NullHandler(logging.Handler): + + def emit(self, record): + pass + +logging.getLogger('cassandra').addHandler(NullHandler()) + + +__version_info__ = (2, 5, 1) +__version__ = '.'.join(map(str, __version_info__)) + + +class ConsistencyLevel(object): + """ + Specifies how many replicas must respond for an operation to be considered + a success. By default, ``ONE`` is used for all operations. + """ + + ANY = 0 + """ + Only requires that one replica receives the write *or* the coordinator + stores a hint to replay later. Valid only for writes.
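A minimal sketch of requesting a level other than the default for a single statement (the contact point, keyspace, and table below are placeholders, not part of this module)::

    from cassandra import ConsistencyLevel
    from cassandra.cluster import Cluster
    from cassandra.query import SimpleStatement

    session = Cluster(['127.0.0.1']).connect('mykeyspace')

    # request QUORUM for one read instead of the default ONE
    query = SimpleStatement("SELECT * FROM users WHERE id = %s",
                            consistency_level=ConsistencyLevel.QUORUM)
    rows = session.execute(query, (42,))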
+ """ + + ONE = 1 + """ + Only one replica needs to respond to consider the operation a success + """ + + TWO = 2 + """ + Two replicas must respond to consider the operation a success + """ + + THREE = 3 + """ + Three replicas must respond to consider the operation a success + """ + + QUORUM = 4 + """ + ``ceil(RF/2)`` replicas must respond to consider the operation a success + """ + + ALL = 5 + """ + All replicas must respond to consider the operation a success + """ + + LOCAL_QUORUM = 6 + """ + Requires a quorum of replicas in the local datacenter + """ + + EACH_QUORUM = 7 + """ + Requires a quorum of replicas in each datacenter + """ + + SERIAL = 8 + """ + For conditional inserts/updates that utilize Cassandra's lightweight + transactions, this requires consensus among all replicas for the + modified data. + """ + + LOCAL_SERIAL = 9 + """ + Like :attr:`~ConsistencyLevel.SERIAL`, but only requires consensus + among replicas in the local datacenter. + """ + + LOCAL_ONE = 10 + """ + Sends a request only to replicas in the local datacenter and waits for + one response. + """ + +ConsistencyLevel.value_to_name = { + ConsistencyLevel.ANY: 'ANY', + ConsistencyLevel.ONE: 'ONE', + ConsistencyLevel.TWO: 'TWO', + ConsistencyLevel.THREE: 'THREE', + ConsistencyLevel.QUORUM: 'QUORUM', + ConsistencyLevel.ALL: 'ALL', + ConsistencyLevel.LOCAL_QUORUM: 'LOCAL_QUORUM', + ConsistencyLevel.EACH_QUORUM: 'EACH_QUORUM', + ConsistencyLevel.SERIAL: 'SERIAL', + ConsistencyLevel.LOCAL_SERIAL: 'LOCAL_SERIAL', + ConsistencyLevel.LOCAL_ONE: 'LOCAL_ONE' +} + +ConsistencyLevel.name_to_value = { + 'ANY': ConsistencyLevel.ANY, + 'ONE': ConsistencyLevel.ONE, + 'TWO': ConsistencyLevel.TWO, + 'THREE': ConsistencyLevel.THREE, + 'QUORUM': ConsistencyLevel.QUORUM, + 'ALL': ConsistencyLevel.ALL, + 'LOCAL_QUORUM': ConsistencyLevel.LOCAL_QUORUM, + 'EACH_QUORUM': ConsistencyLevel.EACH_QUORUM, + 'SERIAL': ConsistencyLevel.SERIAL, + 'LOCAL_SERIAL': ConsistencyLevel.LOCAL_SERIAL, + 'LOCAL_ONE': ConsistencyLevel.LOCAL_ONE +} + + +def consistency_value_to_name(value): + return ConsistencyLevel.value_to_name[value] if value is not None else "Not Set" + + +class Unavailable(Exception): + """ + There were not enough live replicas to satisfy the requested consistency + level, so the coordinator node immediately failed the request without + forwarding it to any replicas. + """ + + consistency = None + """ The requested :class:`ConsistencyLevel` """ + + required_replicas = None + """ The number of replicas that needed to be live to complete the operation """ + + alive_replicas = None + """ The number of replicas that were actually alive """ + + def __init__(self, summary_message, consistency=None, required_replicas=None, alive_replicas=None): + self.consistency = consistency + self.required_replicas = required_replicas + self.alive_replicas = alive_replicas + Exception.__init__(self, summary_message + ' info=' + + repr({'consistency': consistency_value_to_name(consistency), + 'required_replicas': required_replicas, + 'alive_replicas': alive_replicas})) + + +class Timeout(Exception): + """ + Replicas failed to respond to the coordinator node before timing out. 
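A short sketch of how these errors surface to application code (the contact point and table are placeholders; in practice retries are usually delegated to a :class:`~.policies.RetryPolicy`)::

    from cassandra import Unavailable, ReadTimeout
    from cassandra.cluster import Cluster

    session = Cluster(['127.0.0.1']).connect('mykeyspace')

    try:
        rows = session.execute("SELECT * FROM users")
    except Unavailable as exc:
        # rejected up front: not enough live replicas for the requested level
        print("%d of %d required replicas alive" % (exc.alive_replicas, exc.required_replicas))
    except ReadTimeout as exc:
        # replicas did not answer in time; some data may still have been read
        print("read timed out, data_retrieved=%s" % exc.data_retrieved)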
+ """ + + consistency = None + """ The requested :class:`ConsistencyLevel` """ + + required_responses = None + """ The number of required replica responses """ + + received_responses = None + """ + The number of replicas that responded before the coordinator timed out + the operation + """ + + def __init__(self, summary_message, consistency=None, required_responses=None, received_responses=None): + self.consistency = consistency + self.required_responses = required_responses + self.received_responses = received_responses + Exception.__init__(self, summary_message + ' info=' + + repr({'consistency': consistency_value_to_name(consistency), + 'required_responses': required_responses, + 'received_responses': received_responses})) + + +class ReadTimeout(Timeout): + """ + A subclass of :exc:`Timeout` for read operations. + + This indicates that the replicas failed to respond to the coordinator + node before the configured timeout. This timeout is configured in + ``cassandra.yaml`` with the ``read_request_timeout_in_ms`` + and ``range_request_timeout_in_ms`` options. + """ + + data_retrieved = None + """ + A boolean indicating whether the requested data was retrieved + by the coordinator from any replicas before it timed out the + operation + """ + + def __init__(self, message, data_retrieved=None, **kwargs): + Timeout.__init__(self, message, **kwargs) + self.data_retrieved = data_retrieved + + +class WriteTimeout(Timeout): + """ + A subclass of :exc:`Timeout` for write operations. + + This indicates that the replicas failed to respond to the coordinator + node before the configured timeout. This timeout is configured in + ``cassandra.yaml`` with the ``write_request_timeout_in_ms`` + option. + """ + + write_type = None + """ + The type of write operation, enum on :class:`~cassandra.policies.WriteType` + """ + + def __init__(self, message, write_type=None, **kwargs): + Timeout.__init__(self, message, **kwargs) + self.write_type = write_type + + +class AlreadyExists(Exception): + """ + An attempt was made to create a keyspace or table that already exists. + """ + + keyspace = None + """ + The name of the keyspace that already exists, or, if an attempt was + made to create a new table, the keyspace that the table is in. + """ + + table = None + """ + The name of the table that already exists, or, if an attempt was + made to create a keyspace, :const:`None`. + """ + + def __init__(self, keyspace=None, table=None): + if table: + message = "Table '%s.%s' already exists" % (keyspace, table) + else: + message = "Keyspace '%s' already exists" % (keyspace,) + + Exception.__init__(self, message) + self.keyspace = keyspace + self.table = table + + +class InvalidRequest(Exception): + """ + A query was made that was invalid for some reason, such as trying to set + the keyspace for a connection to a nonexistent keyspace. + """ + pass + + +class Unauthorized(Exception): + """ + The current user is not authorized to perform the requested operation. + """ + pass + + +class AuthenticationFailed(Exception): + """ + Failed to authenticate. + """ + pass + + +class OperationTimedOut(Exception): + """ + The operation took longer than the specified (client-side) timeout + to complete. This is not an error generated by Cassandra, only + the driver. + """ + + errors = None + """ + A dict of errors keyed by the :class:`~.Host` against which they occurred. + """ + + last_host = None + """ + The last :class:`~.Host` this operation was attempted against.
+ """ + + def __init__(self, errors=None, last_host=None): + self.errors = errors + self.last_host = last_host + message = "errors=%s, last_host=%s" % (self.errors, self.last_host) + Exception.__init__(self, message) + + +class UnsupportedOperation(Exception): + """ + An attempt was made to use a feature that is not supported by the + selected protocol version. See :attr:`Cluster.protocol_version` + for more details. + """ + pass diff --git a/cassandra/auth.py b/cassandra/auth.py new file mode 100644 index 0000000..67d302a --- /dev/null +++ b/cassandra/auth.py @@ -0,0 +1,177 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + from puresasl.client import SASLClient +except ImportError: + SASLClient = None + +class AuthProvider(object): + """ + An abstract class that defines the interface that will be used for + creating :class:`~.Authenticator` instances when opening new + connections to Cassandra. + + .. versionadded:: 2.0.0 + """ + + def new_authenticator(self, host): + """ + Implementations of this class should return a new instance + of :class:`~.Authenticator` or one of its subclasses. + """ + raise NotImplementedError() + + +class Authenticator(object): + """ + An abstract class that handles SASL authentication with Cassandra servers. + + Each time a new connection is created and the server requires authentication, + a new instance of this class will be created by the corresponding + :class:`~.AuthProvider` to handle that authentication. The lifecycle of the + new :class:`~.Authenticator` will then be: + + 1) The :meth:`~.initial_response()` method will be called. The return + value will be sent to the server to initiate the handshake. + + 2) The server will respond to each client response by either issuing a + challenge or indicating that the authentication is complete (successful or not). + If a new challenge is issued, :meth:`~.evaluate_challenge()` + will be called to produce a response that will be sent to the + server. This challenge/response negotiation will continue until the server + responds that authentication is successful (or an :exc:`~.AuthenticationFailed` + is raised). + + 3) When the server indicates that authentication is successful, + :meth:`~.on_authentication_success` will be called with a token string + that the server may optionally have sent. + + The exact nature of the negotiation between the client and server is specific + to the authentication mechanism configured server-side. + + .. versionadded:: 2.0.0 + """ + + def initial_response(self): + """ + Returns a message to send to the server to initiate the SASL handshake. + :const:`None` may be returned to send an empty message. + """ + return None + + def evaluate_challenge(self, challenge): + """ + Called when the server sends a challenge message. Generally, this method + should return :const:`None` when authentication is complete from a + client perspective. Otherwise, a string should be returned.
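As an illustration of this lifecycle only, a hypothetical single-round implementation of the interface (not one of the authenticators shipped in this module) could look like::

    class StaticTokenAuthenticator(Authenticator):
        # sends a single opaque token and expects no further challenges

        def __init__(self, token):
            self.token = token

        def initial_response(self):
            return self.token

        def evaluate_challenge(self, challenge):
            # authentication is complete from the client's perspective
            return None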
+ """ + raise NotImplementedError() + + def on_authentication_success(self, token): + """ + Called when the server indicates that authentication was successful. + Depending on the authentication mechanism, `token` may be :const:`None` + or a string. + """ + pass + + +class PlainTextAuthProvider(AuthProvider): + """ + An :class:`~.AuthProvider` that works with Cassandra's PasswordAuthenticator. + + Example usage:: + + from cassandra.cluster import Cluster + from cassandra.auth import PlainTextAuthProvider + + auth_provider = PlainTextAuthProvider( + username='cassandra', password='cassandra') + cluster = Cluster(auth_provider=auth_provider) + + .. versionadded:: 2.0.0 + """ + + def __init__(self, username, password): + self.username = username + self.password = password + + def new_authenticator(self, host): + return PlainTextAuthenticator(self.username, self.password) + + +class PlainTextAuthenticator(Authenticator): + """ + An :class:`~.Authenticator` that works with Cassandra's PasswordAuthenticator. + + .. versionadded:: 2.0.0 + """ + + def __init__(self, username, password): + self.username = username + self.password = password + + def initial_response(self): + return "\x00%s\x00%s" % (self.username, self.password) + + def evaluate_challenge(self, challenge): + return None + + +class SaslAuthProvider(AuthProvider): + """ + An :class:`~.AuthProvider` supporting general SASL auth mechanisms + + Suitable for GSSAPI or other SASL mechanisms + + Example usage:: + + from cassandra.cluster import Cluster + from cassandra.auth import SaslAuthProvider + + sasl_kwargs = {'host': 'localhost', + 'service': 'dse', + 'mechanism': 'GSSAPI', + 'qops': 'auth'.split(',')} + auth_provider = SaslAuthProvider(**sasl_kwargs) + cluster = Cluster(auth_provider=auth_provider) + + .. versionadded:: 2.1.4 + """ + + def __init__(self, **sasl_kwargs): + if SASLClient is None: + raise ImportError('The puresasl library has not been installed') + self.sasl_kwargs = sasl_kwargs + + def new_authenticator(self, host): + return SaslAuthenticator(**self.sasl_kwargs) + +class SaslAuthenticator(Authenticator): + """ + A pass-through :class:`~.Authenticator` using the third party package + 'pure-sasl' for authentication + + .. versionadded:: 2.1.4 + """ + + def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs): + if SASLClient is None: + raise ImportError('The puresasl library has not been installed') + self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs) + + def initial_response(self): + return self.sasl.process() + + def evaluate_challenge(self, challenge): + return self.sasl.process(challenge) diff --git a/cassandra/cluster.py b/cassandra/cluster.py new file mode 100644 index 0000000..c7b2be7 --- /dev/null +++ b/cassandra/cluster.py @@ -0,0 +1,3152 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module houses the main classes you will interact with, +:class:`.Cluster` and :class:`.Session`. 
+ """ +from __future__ import absolute_import + +import atexit +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +import logging +from random import random +import socket +import sys +import time +from threading import Lock, RLock, Thread, Event + +import six +from six.moves import range +from six.moves import queue as Queue + +import weakref +from weakref import WeakValueDictionary +try: + from weakref import WeakSet +except ImportError: + from cassandra.util import WeakSet # NOQA + +from functools import partial, wraps +from itertools import groupby + +from cassandra import (ConsistencyLevel, AuthenticationFailed, + InvalidRequest, OperationTimedOut, + UnsupportedOperation, Unauthorized) +from cassandra.connection import (ConnectionException, ConnectionShutdown, + ConnectionHeartbeat) +from cassandra.cqltypes import UserType +from cassandra.encoder import Encoder +from cassandra.protocol import (QueryMessage, ResultMessage, + ErrorMessage, ReadTimeoutErrorMessage, + WriteTimeoutErrorMessage, + UnavailableErrorMessage, + OverloadedErrorMessage, + PrepareMessage, ExecuteMessage, + PreparedQueryNotFound, + IsBootstrappingErrorMessage, + BatchMessage, RESULT_KIND_PREPARED, + RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS, + RESULT_KIND_SCHEMA_CHANGE) +from cassandra.metadata import Metadata, protect_name +from cassandra.policies import (RoundRobinPolicy, SimpleConvictionPolicy, + ExponentialReconnectionPolicy, HostDistance, + RetryPolicy) +from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler, + HostConnectionPool, HostConnection, + NoConnectionsAvailable) +from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement, + BatchStatement, bind_params, QueryTrace, Statement, + named_tuple_factory, dict_factory, FETCH_SIZE_UNSET) + +def _is_eventlet_monkey_patched(): + if 'eventlet.patcher' not in sys.modules: + return False + import eventlet.patcher + return eventlet.patcher.is_monkey_patched('socket') + +# default to gevent when we are monkey patched with gevent, eventlet when +# monkey patched with eventlet, otherwise if libev is available, use that as +# the default because it's fastest. Otherwise, use asyncore. +if 'gevent.monkey' in sys.modules: + from cassandra.io.geventreactor import GeventConnection as DefaultConnection +elif _is_eventlet_monkey_patched(): + from cassandra.io.eventletreactor import EventletConnection as DefaultConnection +else: + try: + from cassandra.io.libevreactor import LibevConnection as DefaultConnection # NOQA + except ImportError: + from cassandra.io.asyncorereactor import AsyncoreConnection as DefaultConnection # NOQA + +# Forces load of utf8 encoding module to avoid deadlock that occurs +# if code that is being imported tries to import the module in a separate +# thread. +# See http://bugs.python.org/issue10923 +"".encode('utf8') + +log = logging.getLogger(__name__) + + +DEFAULT_MIN_REQUESTS = 5 +DEFAULT_MAX_REQUESTS = 100 + +DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST = 2 +DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST = 8 + +DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST = 1 +DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST = 2 + + +_NOT_SET = object() + + +class NoHostAvailable(Exception): + """ + Raised when an operation is attempted but all connections are + busy, defunct, closed, or resulted in errors when used. + """ + + errors = None + """ + A map of the form ``{ip: exception}`` which details the particular + Exception that was caught for each host the operation was attempted + against.
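A sketch of inspecting that map when initial contact fails (the addresses below are placeholders)::

    from cassandra.cluster import Cluster, NoHostAvailable

    cluster = Cluster(['10.0.0.1', '10.0.0.2'])
    try:
        session = cluster.connect()
    except NoHostAvailable as exc:
        # exc.errors maps each attempted host to the exception it raised
        for ip, error in exc.errors.items():
            print("%s failed with %r" % (ip, error))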
+ """ + + def __init__(self, message, errors): + Exception.__init__(self, message, errors) + self.errors = errors + + +def _future_completed(future): + """ Helper for run_in_executor() """ + exc = future.exception() + if exc: + log.debug("Failed to run task on executor", exc_info=exc) + + +def run_in_executor(f): + """ + A decorator to run the given method in the ThreadPoolExecutor. + """ + + @wraps(f) + def new_f(self, *args, **kwargs): + + if self.is_shutdown: + return + try: + future = self.executor.submit(f, self, *args, **kwargs) + future.add_done_callback(_future_completed) + except Exception: + log.exception("Failed to submit task to executor") + + return new_f + + +def _shutdown_cluster(cluster): + if cluster and not cluster.is_shutdown: + cluster.shutdown() + + +class Cluster(object): + """ + The main class to use when interacting with a Cassandra cluster. + Typically, one instance of this class will be created for each + separate Cassandra cluster that your application interacts with. + + Example usage:: + + >>> from cassandra.cluster import Cluster + >>> cluster = Cluster(['192.168.1.1', '192.168.1.2']) + >>> session = cluster.connect() + >>> session.execute("CREATE KEYSPACE ...") + >>> ... + >>> cluster.shutdown() + + """ + + contact_points = ['127.0.0.1'] + """ + The list of contact points to try connecting for cluster discovery. + + Defaults to loopback interface. + + Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit + local_dc set, the DC is chosen from an arbitrary host in contact_points. + In this case, contact_points should contain only nodes from a single, + local DC. + """ + + port = 9042 + """ + The server-side port to open connections to. Defaults to 9042. + """ + + cql_version = None + """ + If a specific version of CQL should be used, this may be set to that + string version. Otherwise, the highest CQL version supported by the + server will be automatically used. + """ + + protocol_version = 2 + """ + The version of the native protocol to use. + + Version 2 of the native protocol adds support for lightweight transactions, + batch operations, and automatic query paging. The v2 protocol is + supported by Cassandra 2.0+. + + Version 3 of the native protocol adds support for protocol-level + client-side timestamps (see :attr:`.Session.use_client_timestamp`), + serial consistency levels for :class:`~.BatchStatement`, and an + improved connection pool. + + The following table describes the native protocol versions that + are supported by each version of Cassandra: + + +-------------------+-------------------+ + | Cassandra Version | Protocol Versions | + +===================+===================+ + | 1.2 | 1 | + +-------------------+-------------------+ + | 2.0 | 1, 2 | + +-------------------+-------------------+ + | 2.1 | 1, 2, 3 | + +-------------------+-------------------+ + """ + + compression = True + """ + Controls compression for communications between the driver and Cassandra. + If left as the default of :const:`True`, either lz4 or snappy compression + may be used, depending on what is supported by both the driver + and Cassandra. If both are fully supported, lz4 will be preferred. + + You may also set this to 'snappy' or 'lz4' to request that specific + compression type. + + Setting this to :const:`False` disables compression. 
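For example, against Cassandra 2.1 nodes the newer protocol and a specific compressor can be requested explicitly (a sketch; the addresses are placeholders, and 'lz4' requires the optional lz4 package)::

    from cassandra.cluster import Cluster

    cluster = Cluster(['192.168.1.1', '192.168.1.2'],
                      protocol_version=3,
                      compression='lz4')
    session = cluster.connect()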
+ """ + + _auth_provider = None + _auth_provider_callable = None + + @property + def auth_provider(self): + """ + When :attr:`~.Cluster.protocol_version` is 2 or higher, this should + be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`, + such as :class:`~.PlainTextAuthProvider`. + + When :attr:`~.Cluster.protocol_version` is 1, this should be + a function that accepts one argument, the IP address of a node, + and returns a dict of credentials for that node. + + When not using authentication, this should be left as :const:`None`. + """ + return self._auth_provider + + @auth_provider.setter # noqa + def auth_provider(self, value): + if not value: + self._auth_provider = value + return + + try: + self._auth_provider_callable = value.new_authenticator + except AttributeError: + if self.protocol_version > 1: + raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider " + "interface when protocol_version >= 2") + elif not callable(value): + raise TypeError("auth_provider must be callable when protocol_version == 1") + self._auth_provider_callable = value + + self._auth_provider = value + + load_balancing_policy = None + """ + An instance of :class:`.policies.LoadBalancingPolicy` or + one of its subclasses. Defaults to :class:`~.RoundRobinPolicy`. + """ + + reconnection_policy = ExponentialReconnectionPolicy(1.0, 600.0) + """ + An instance of :class:`.policies.ReconnectionPolicy`. Defaults to an instance + of :class:`.ExponentialReconnectionPolicy` with a base delay of one second and + a max delay of ten minutes. + """ + + default_retry_policy = RetryPolicy() + """ + A default :class:`.policies.RetryPolicy` instance to use for all + :class:`.Statement` objects which do not have a :attr:`~.Statement.retry_policy` + explicitly set. + """ + + conviction_policy_factory = SimpleConvictionPolicy + """ + A factory function which creates instances of + :class:`.policies.ConvictionPolicy`. Defaults to + :class:`.policies.SimpleConvictionPolicy`. + """ + + connect_to_remote_hosts = True + """ + If left as :const:`True`, hosts that are considered :attr:`~.HostDistance.REMOTE` + by the :attr:`~.Cluster.load_balancing_policy` will have a connection + opened to them. Otherwise, they will not have a connection opened to them. + + .. versionadded:: 2.1.0 + """ + + metrics_enabled = False + """ + Whether or not metric collection is enabled. If enabled, :attr:`.metrics` + will be an instance of :class:`~cassandra.metrics.Metrics`. + """ + + metrics = None + """ + An instance of :class:`cassandra.metrics.Metrics` if :attr:`.metrics_enabled` is + :const:`True`, else :const:`None`. + """ + + ssl_options = None + """ + A optional dict which will be used as kwargs for ``ssl.wrap_socket()`` + when new sockets are created. This should be used when client encryption + is enabled in Cassandra. + + By default, a ``ca_certs`` value should be supplied (the value should be + a string pointing to the location of the CA certs file), and you probably + want to specify ``ssl_version`` as ``ssl.PROTOCOL_TLSv1`` to match + Cassandra's default protocol. + """ + + sockopts = None + """ + An optional list of tuples which will be used as arguments to + ``socket.setsockopt()`` for all created sockets. + """ + + max_schema_agreement_wait = 10 + """ + The maximum duration (in seconds) that the driver will wait for schema + agreement across the cluster. Defaults to ten seconds. + If set <= 0, the driver will bypass schema agreement waits altogether. 
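Putting several of these options together (a sketch only; the certificate path and credentials are placeholders)::

    import ssl
    from cassandra.cluster import Cluster
    from cassandra.auth import PlainTextAuthProvider

    ssl_opts = {'ca_certs': '/path/to/ca.pem',
                'ssl_version': ssl.PROTOCOL_TLSv1}
    cluster = Cluster(['10.0.0.1'],
                      ssl_options=ssl_opts,
                      auth_provider=PlainTextAuthProvider('cassandra', 'cassandra'),
                      max_schema_agreement_wait=30)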
+ """ + + metadata = None + """ + An instance of :class:`cassandra.metadata.Metadata`. + """ + + connection_class = DefaultConnection + """ + This determines what event loop system will be used for managing + I/O with Cassandra. These are the current options: + + * :class:`cassandra.io.asyncorereactor.AsyncoreConnection` + * :class:`cassandra.io.libevreactor.LibevConnection` + * :class:`cassandra.io.geventreactor.GeventConnection` (requires monkey-patching) + * :class:`cassandra.io.twistedreactor.TwistedConnection` + + By default, ``AsyncoreConnection`` will be used, which uses + the ``asyncore`` module in the Python standard library. The + performance is slightly worse than with ``libev``, but it is + supported on a wider range of systems. + + If ``libev`` is installed, ``LibevConnection`` will be used instead. + + If gevent monkey-patching of the standard library is detected, + GeventConnection will be used automatically. + """ + + control_connection_timeout = 2.0 + """ + A timeout, in seconds, for queries made by the control connection, such + as querying the current schema and information about nodes in the cluster. + If set to :const:`None`, there will be no timeout for these queries. + """ + + idle_heartbeat_interval = 30 + """ + Interval, in seconds, on which to heartbeat idle connections. This helps + keep connections open through network devices that expire idle connections. + It also helps discover bad connections early in low-traffic scenarios. + Setting to zero disables heartbeats. + """ + + schema_event_refresh_window = 2 + """ + Window, in seconds, within which a schema component will be refreshed after + receiving a schema_change event. + + The driver delays a random amount of time in the range [0.0, window) + before executing the refresh. This serves two purposes: + + 1.) Spread the refresh for deployments with large fanout from C* to client tier, + preventing a 'thundering herd' problem with many clients refreshing simultaneously. + + 2.) Remove redundant refreshes. Redundant events arriving within the delay period + are discarded, and only one refresh is executed. + + Setting this to zero will execute refreshes immediately. + + Setting this negative will disable schema refreshes in response to push events + (refreshes will still occur in response to schema change responses to DDL statements + executed by Sessions of this Cluster). + """ + + topology_event_refresh_window = 10 + """ + Window, in seconds, within which the node and token list will be refreshed after + receiving a topology_change event. + + Setting this to zero will execute refreshes immediately. + + Setting this negative will disable node refreshes in response to push events + (refreshes will still occur in response to new nodes observed on "UP" events). 
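These settings are all plain constructor arguments; for instance (a sketch, assuming the optional libev extension is installed)::

    from cassandra.cluster import Cluster
    from cassandra.io.libevreactor import LibevConnection

    cluster = Cluster(connection_class=LibevConnection,
                      idle_heartbeat_interval=60,
                      schema_event_refresh_window=0,
                      topology_event_refresh_window=0)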
+ + See :attr:`.schema_event_refresh_window` for discussion of rationale + """ + + sessions = None + control_connection = None + scheduler = None + executor = None + is_shutdown = False + _is_setup = False + _prepared_statements = None + _prepared_statement_lock = None + _idle_heartbeat = None + + _user_types = None + """ + A map of {keyspace: {type_name: UserType}} + """ + + _listeners = None + _listener_lock = None + + def __init__(self, + contact_points=["127.0.0.1"], + port=9042, + compression=True, + auth_provider=None, + load_balancing_policy=None, + reconnection_policy=None, + default_retry_policy=None, + conviction_policy_factory=None, + metrics_enabled=False, + connection_class=None, + ssl_options=None, + sockopts=None, + cql_version=None, + protocol_version=2, + executor_threads=2, + max_schema_agreement_wait=10, + control_connection_timeout=2.0, + idle_heartbeat_interval=30, + schema_event_refresh_window=2, + topology_event_refresh_window=10): + """ + Any of the mutable Cluster attributes may be set as keyword arguments + to the constructor. + """ + if contact_points is not None: + if isinstance(contact_points, six.string_types): + raise TypeError("contact_points should not be a string, it should be a sequence (e.g. list) of strings") + + self.contact_points = contact_points + + self.port = port + self.compression = compression + self.protocol_version = protocol_version + self.auth_provider = auth_provider + + if load_balancing_policy is not None: + if isinstance(load_balancing_policy, type): + raise TypeError("load_balancing_policy should not be a class, it should be an instance of that class") + + self.load_balancing_policy = load_balancing_policy + else: + self.load_balancing_policy = RoundRobinPolicy() + + if reconnection_policy is not None: + if isinstance(reconnection_policy, type): + raise TypeError("reconnection_policy should not be a class, it should be an instance of that class") + + self.reconnection_policy = reconnection_policy + + if default_retry_policy is not None: + if isinstance(default_retry_policy, type): + raise TypeError("default_retry_policy should not be a class, it should be an instance of that class") + + self.default_retry_policy = default_retry_policy + + if conviction_policy_factory is not None: + if not callable(conviction_policy_factory): + raise ValueError("conviction_policy_factory must be callable") + self.conviction_policy_factory = conviction_policy_factory + + if connection_class is not None: + self.connection_class = connection_class + + self.metrics_enabled = metrics_enabled + self.ssl_options = ssl_options + self.sockopts = sockopts + self.cql_version = cql_version + self.max_schema_agreement_wait = max_schema_agreement_wait + self.control_connection_timeout = control_connection_timeout + self.idle_heartbeat_interval = idle_heartbeat_interval + self.schema_event_refresh_window = schema_event_refresh_window + self.topology_event_refresh_window = topology_event_refresh_window + + self._listeners = set() + self._listener_lock = Lock() + + # let Session objects be GC'ed (and shutdown) when the user no longer + # holds a reference. 
+ self.sessions = WeakSet() + self.metadata = Metadata() + self.control_connection = None + self._prepared_statements = WeakValueDictionary() + self._prepared_statement_lock = Lock() + + self._user_types = defaultdict(dict) + + self._min_requests_per_connection = { + HostDistance.LOCAL: DEFAULT_MIN_REQUESTS, + HostDistance.REMOTE: DEFAULT_MIN_REQUESTS + } + + self._max_requests_per_connection = { + HostDistance.LOCAL: DEFAULT_MAX_REQUESTS, + HostDistance.REMOTE: DEFAULT_MAX_REQUESTS + } + + self._core_connections_per_host = { + HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST, + HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST + } + + self._max_connections_per_host = { + HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST, + HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST + } + + self.executor = ThreadPoolExecutor(max_workers=executor_threads) + self.scheduler = _Scheduler(self.executor) + + self._lock = RLock() + + if self.metrics_enabled: + from cassandra.metrics import Metrics + self.metrics = Metrics(weakref.proxy(self)) + + self.control_connection = ControlConnection( + self, self.control_connection_timeout, + self.schema_event_refresh_window, self.topology_event_refresh_window) + + def register_user_type(self, keyspace, user_type, klass): + """ + Registers a class to use to represent a particular user-defined type. + Query parameters for this user-defined type will be assumed to be + instances of `klass`. Result sets for this user-defined type will + be instances of `klass`. If no class is registered for a user-defined + type, a namedtuple will be used for result sets, and non-prepared + statements may not encode parameters for this type correctly. + + `keyspace` is the name of the keyspace that the UDT is defined in. + + `user_type` is the string name of the UDT to register the mapping + for. + + `klass` should be a class with attributes whose names match the + fields of the user-defined type. The constructor must accept kwargs + for each of the fields in the UDT. + + This method should only be called after the type has been created + within Cassandra.
+ + Example:: + + cluster = Cluster(protocol_version=3) + session = cluster.connect() + session.set_keyspace('mykeyspace') + session.execute("CREATE TYPE address (street text, zipcode int)") + session.execute("CREATE TABLE users (id int PRIMARY KEY, location address)") + + # create a class to map to the "address" UDT + class Address(object): + + def __init__(self, street, zipcode): + self.street = street + self.zipcode = zipcode + + cluster.register_user_type('mykeyspace', 'address', Address) + + # insert a row using an instance of Address + session.execute("INSERT INTO users (id, location) VALUES (%s, %s)", + (0, Address("123 Main St.", 78723))) + + # results will include Address instances + results = session.execute("SELECT * FROM users") + row = results[0] + print row.id, row.location.street, row.location.zipcode + + """ + self._user_types[keyspace][user_type] = klass + for session in self.sessions: + session.user_type_registered(keyspace, user_type, klass) + UserType.evict_udt_class(keyspace, user_type) + + def get_min_requests_per_connection(self, host_distance): + return self._min_requests_per_connection[host_distance] + + def set_min_requests_per_connection(self, host_distance, min_requests): + if self.protocol_version >= 3: + raise UnsupportedOperation( + "Cluster.set_min_requests_per_connection() only has an effect " + "when using protocol_version 1 or 2.") + self._min_requests_per_connection[host_distance] = min_requests + + def get_max_requests_per_connection(self, host_distance): + return self._max_requests_per_connection[host_distance] + + def set_max_requests_per_connection(self, host_distance, max_requests): + if self.protocol_version >= 3: + raise UnsupportedOperation( + "Cluster.set_max_requests_per_connection() only has an effect " + "when using protocol_version 1 or 2.") + self._max_requests_per_connection[host_distance] = max_requests + + def get_core_connections_per_host(self, host_distance): + """ + Gets the minimum number of connections per Session that will be opened + for each host with :class:`~.HostDistance` equal to `host_distance`. + The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for + :attr:`~HostDistance.REMOTE`. + + This property is ignored if :attr:`~.Cluster.protocol_version` is + 3 or higher. + """ + return self._core_connections_per_host[host_distance] + + def set_core_connections_per_host(self, host_distance, core_connections): + """ + Sets the minimum number of connections per Session that will be opened + for each host with :class:`~.HostDistance` equal to `host_distance`. + The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for + :attr:`~HostDistance.REMOTE`. + + If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this + is not supported (there is always one connection per host, unless + the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) + and using this will result in an :exc:`~.UnsupportedOperation`. + """ + if self.protocol_version >= 3: + raise UnsupportedOperation( + "Cluster.set_core_connections_per_host() only has an effect " + "when using protocol_version 1 or 2.") + old = self._core_connections_per_host[host_distance] + self._core_connections_per_host[host_distance] = core_connections + if old < core_connections: + self._ensure_core_connections() + + def get_max_connections_per_host(self, host_distance): + """ + Gets the maximum number of connections per Session that will be opened + for each host with :class:`~.HostDistance` equal to `host_distance`.
+ The default is 8 for :attr:`~HostDistance.LOCAL` and 2 for + :attr:`~HostDistance.REMOTE`. + + This property is ignored if :attr:`~.Cluster.protocol_version` is + 3 or higher. + """ + return self._max_connections_per_host[host_distance] + + def set_max_connections_per_host(self, host_distance, max_connections): + """ + Sets the maximum number of connections per Session that will be opened + for each host with :class:`~.HostDistance` equal to `host_distance`. + The default is 8 for :attr:`~HostDistance.LOCAL` and 2 for + :attr:`~HostDistance.REMOTE`. + + If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this + is not supported (there is always one connection per host, unless + the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) + and using this will result in an :exc:`~.UnsupportedOperation`. + """ + if self.protocol_version >= 3: + raise UnsupportedOperation( + "Cluster.set_max_connections_per_host() only has an effect " + "when using protocol_version 1 or 2.") + self._max_connections_per_host[host_distance] = max_connections + + def connection_factory(self, address, *args, **kwargs): + """ + Called to create a new connection with proper configuration. + Intended for internal use only. + """ + kwargs = self._make_connection_kwargs(address, kwargs) + return self.connection_class.factory(address, *args, **kwargs) + + def _make_connection_factory(self, host, *args, **kwargs): + kwargs = self._make_connection_kwargs(host.address, kwargs) + return partial(self.connection_class.factory, host.address, *args, **kwargs) + + def _make_connection_kwargs(self, address, kwargs_dict): + if self._auth_provider_callable: + kwargs_dict['authenticator'] = self._auth_provider_callable(address) + + kwargs_dict['port'] = self.port + kwargs_dict['compression'] = self.compression + kwargs_dict['sockopts'] = self.sockopts + kwargs_dict['ssl_options'] = self.ssl_options + kwargs_dict['cql_version'] = self.cql_version + kwargs_dict['protocol_version'] = self.protocol_version + kwargs_dict['user_type_map'] = self._user_types + + return kwargs_dict + + def connect(self, keyspace=None): + """ + Creates and returns a new :class:`~.Session` object. If `keyspace` + is specified, that keyspace will be the default keyspace for + operations on the ``Session``.
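A typical call sequence (a sketch; the contact point, keyspace, and query are illustrative)::

    cluster = Cluster(['192.168.1.1'])
    session = cluster.connect('mykeyspace')
    rows = session.execute('SELECT name, age FROM users')
    cluster.shutdown()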
+ """ + with self._lock: + if self.is_shutdown: + raise Exception("Cluster is already shut down") + + if not self._is_setup: + log.debug("Connecting to cluster, contact points: %s; protocol version: %s", + self.contact_points, self.protocol_version) + self.connection_class.initialize_reactor() + atexit.register(partial(_shutdown_cluster, self)) + for address in self.contact_points: + host, new = self.add_host(address, signal=False) + if new: + host.set_up() + for listener in self.listeners: + listener.on_add(host) + + self.load_balancing_policy.populate( + weakref.proxy(self), self.metadata.all_hosts()) + + try: + self.control_connection.connect() + log.debug("Control connection created") + except Exception: + log.exception("Control connection failed to connect, " + "shutting down Cluster:") + self.shutdown() + raise + + self.load_balancing_policy.check_supported() + + if self.idle_heartbeat_interval: + self._idle_heartbeat = ConnectionHeartbeat(self.idle_heartbeat_interval, self.get_connection_holders) + self._is_setup = True + + session = self._new_session() + if keyspace: + session.set_keyspace(keyspace) + return session + + def get_connection_holders(self): + holders = [] + for s in self.sessions: + holders.extend(s.get_pools()) + holders.append(self.control_connection) + return holders + + def shutdown(self): + """ + Closes all sessions and connection associated with this Cluster. + To ensure all connections are properly closed, **you should always + call shutdown() on a Cluster instance when you are done with it**. + + Once shutdown, a Cluster should not be used for any purpose. + """ + with self._lock: + if self.is_shutdown: + return + else: + self.is_shutdown = True + + if self._idle_heartbeat: + self._idle_heartbeat.stop() + + self.scheduler.shutdown() + + self.control_connection.shutdown() + + for session in self.sessions: + session.shutdown() + + self.executor.shutdown() + + def _new_session(self): + session = Session(self, self.metadata.all_hosts()) + for keyspace, type_map in six.iteritems(self._user_types): + for udt_name, klass in six.iteritems(type_map): + session.user_type_registered(keyspace, udt_name, klass) + self.sessions.add(session) + return session + + def _cleanup_failed_on_up_handling(self, host): + self.load_balancing_policy.on_down(host) + self.control_connection.on_down(host) + for session in self.sessions: + session.remove_pool(host) + + self._start_reconnector(host, is_host_addition=False) + + def _on_up_future_completed(self, host, futures, results, lock, finished_future): + with lock: + futures.discard(finished_future) + + try: + results.append(finished_future.result()) + except Exception as exc: + results.append(exc) + + if futures: + return + + try: + # all futures have completed at this point + for exc in [f for f in results if isinstance(f, Exception)]: + log.error("Unexpected failure while marking node %s up:", host, exc_info=exc) + self._cleanup_failed_on_up_handling(host) + return + + if not all(results): + log.debug("Connection pool could not be created, not marking node %s up", host) + self._cleanup_failed_on_up_handling(host) + return + + log.info("Connection pools established for node %s", host) + # mark the host as up and notify all listeners + host.set_up() + for listener in self.listeners: + listener.on_up(host) + finally: + with host.lock: + host._currently_handling_node_up = False + + # see if there are any pools to add or remove now that the host is marked up + for session in self.sessions: + session.update_created_pools() + + def 
on_up(self, host): + """ + Intended for internal use only. + """ + if self.is_shutdown: + return + + log.debug("Waiting to acquire lock for handling up status of node %s", host) + with host.lock: + if host._currently_handling_node_up: + log.debug("Another thread is already handling up status of node %s", host) + return + + if host.is_up: + log.debug("Host %s was already marked up", host) + return + + host._currently_handling_node_up = True + log.debug("Starting to handle up status of node %s", host) + + have_future = False + futures = set() + try: + log.info("Host %s may be up; will prepare queries and open connection pool", host) + + reconnector = host.get_and_set_reconnection_handler(None) + if reconnector: + log.debug("Now that host %s is up, cancelling the reconnection handler", host) + reconnector.cancel() + + self._prepare_all_queries(host) + log.debug("Done preparing all queries for host %s, ", host) + + for session in self.sessions: + session.remove_pool(host) + + log.debug("Signalling to load balancing policy that host %s is up", host) + self.load_balancing_policy.on_up(host) + + log.debug("Signalling to control connection that host %s is up", host) + self.control_connection.on_up(host) + + log.debug("Attempting to open new connection pools for host %s", host) + futures_lock = Lock() + futures_results = [] + callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock) + for session in self.sessions: + future = session.add_or_renew_pool(host, is_host_addition=False) + if future is not None: + have_future = True + future.add_done_callback(callback) + futures.add(future) + except Exception: + log.exception("Unexpected failure handling node %s being marked up:", host) + for future in futures: + future.cancel() + + self._cleanup_failed_on_up_handling(host) + + with host.lock: + host._currently_handling_node_up = False + raise + else: + if not have_future: + with host.lock: + host._currently_handling_node_up = False + + # for testing purposes + return futures + + def _start_reconnector(self, host, is_host_addition): + if self.load_balancing_policy.distance(host) == HostDistance.IGNORED: + return + + schedule = self.reconnection_policy.new_schedule() + + # in order to not hold references to this Cluster open and prevent + # proper shutdown when the program ends, we'll just make a closure + # of the current Cluster attributes to create new Connections with + conn_factory = self._make_connection_factory(host) + + reconnector = _HostReconnectionHandler( + host, conn_factory, is_host_addition, self.on_add, self.on_up, + self.scheduler, schedule, host.get_and_set_reconnection_handler, + new_handler=None) + + old_reconnector = host.get_and_set_reconnection_handler(reconnector) + if old_reconnector: + log.debug("Old host reconnector found for %s, cancelling", host) + old_reconnector.cancel() + + log.debug("Starting reconnector for host %s", host) + reconnector.start() + + @run_in_executor + def on_down(self, host, is_host_addition, expect_host_to_be_down=False): + """ + Intended for internal use only. 
+ """ + if self.is_shutdown: + return + + with host.lock: + if (not host.is_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): + return + + host.set_down() + + log.warning("Host %s has been marked down", host) + + self.load_balancing_policy.on_down(host) + self.control_connection.on_down(host) + for session in self.sessions: + session.on_down(host) + + for listener in self.listeners: + listener.on_down(host) + + self._start_reconnector(host, is_host_addition) + + def on_add(self, host, refresh_nodes=True): + if self.is_shutdown: + return + + log.debug("Handling new host %r and notifying listeners", host) + + distance = self.load_balancing_policy.distance(host) + if distance != HostDistance.IGNORED: + self._prepare_all_queries(host) + log.debug("Done preparing queries for new host %r", host) + + self.load_balancing_policy.on_add(host) + self.control_connection.on_add(host, refresh_nodes) + + if distance == HostDistance.IGNORED: + log.debug("Not adding connection pool for new host %r because the " + "load balancing policy has marked it as IGNORED", host) + self._finalize_add(host) + return + + futures_lock = Lock() + futures_results = [] + futures = set() + + def future_completed(future): + with futures_lock: + futures.discard(future) + + try: + futures_results.append(future.result()) + except Exception as exc: + futures_results.append(exc) + + if futures: + return + + log.debug('All futures have completed for added host %s', host) + + for exc in [f for f in futures_results if isinstance(f, Exception)]: + log.error("Unexpected failure while adding node %s, will not mark up:", host, exc_info=exc) + return + + if not all(futures_results): + log.warning("Connection pool could not be created, not marking node %s up", host) + return + + self._finalize_add(host) + + have_future = False + for session in self.sessions: + future = session.add_or_renew_pool(host, is_host_addition=True) + if future is not None: + have_future = True + futures.add(future) + future.add_done_callback(future_completed) + + if not have_future: + self._finalize_add(host) + + def _finalize_add(self, host): + # mark the host as up and notify all listeners + host.set_up() + for listener in self.listeners: + listener.on_add(host) + + # see if there are any pools to add or remove now that the host is marked up + for session in self.sessions: + session.update_created_pools() + + def on_remove(self, host): + if self.is_shutdown: + return + + log.debug("Removing host %s", host) + host.set_down() + self.load_balancing_policy.on_remove(host) + for session in self.sessions: + session.on_remove(host) + for listener in self.listeners: + listener.on_remove(host) + self.control_connection.on_remove(host) + + def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False): + is_down = host.signal_connection_failure(connection_exc) + if is_down: + self.on_down(host, is_host_addition, expect_host_to_be_down) + return is_down + + def add_host(self, address, datacenter=None, rack=None, signal=True, refresh_nodes=True): + """ + Called when adding initial contact points and when the control + connection subsequently discovers a new node. + Returns a Host instance, and a flag indicating whether it was new in + the metadata. + Intended for internal use only. 
+ """ + host, new = self.metadata.add_or_return_host(Host(address, self.conviction_policy_factory, datacenter, rack)) + if new and signal: + log.info("New Cassandra host %r discovered", host) + self.on_add(host, refresh_nodes) + + return host, new + + def remove_host(self, host): + """ + Called when the control connection observes that a node has left the + ring. Intended for internal use only. + """ + if host and self.metadata.remove_host(host): + log.info("Cassandra host %s removed", host) + self.on_remove(host) + + def register_listener(self, listener): + """ + Adds a :class:`cassandra.policies.HostStateListener` subclass instance to + the list of listeners to be notified when a host is added, removed, + marked up, or marked down. + """ + with self._listener_lock: + self._listeners.add(listener) + + def unregister_listener(self, listener): + """ Removes a registered listener. """ + with self._listener_lock: + self._listeners.remove(listener) + + @property + def listeners(self): + with self._listener_lock: + return self._listeners.copy() + + def _ensure_core_connections(self): + """ + If any host has fewer than the configured number of core connections + open, attempt to open connections until that number is met. + """ + for session in self.sessions: + for pool in session._pools.values(): + pool.ensure_core_connections() + + def refresh_schema(self, keyspace=None, table=None, usertype=None, max_schema_agreement_wait=None): + """ + Synchronously refresh the schema metadata. + + By default, the timeout for this operation is governed by :attr:`~.Cluster.max_schema_agreement_wait` + and :attr:`~.Cluster.control_connection_timeout`. + + Passing max_schema_agreement_wait here overrides :attr:`~.Cluster.max_schema_agreement_wait`. + + Setting max_schema_agreement_wait <= 0 will bypass schema agreement and refresh schema immediately. + + An Exception is raised if schema refresh fails for any reason. + """ + if not self.control_connection.refresh_schema(keyspace, table, usertype, max_schema_agreement_wait): + raise Exception("Schema was not refreshed. See log for details.") + + def submit_schema_refresh(self, keyspace=None, table=None, usertype=None): + """ + Schedule a refresh of the internal representation of the current + schema for this cluster. If `keyspace` is specified, only that + keyspace will be refreshed, and likewise for `table`. + """ + return self.executor.submit( + self.control_connection.refresh_schema, keyspace, table, usertype) + + def refresh_nodes(self): + """ + Synchronously refresh the node list and token metadata + + An Exception is raised if node refresh fails for any reason. + """ + if not self.control_connection.refresh_node_list_and_token_map(): + raise Exception("Node list was not refreshed. See log for details.") + + def set_meta_refresh_enabled(self, enabled): + """ + Sets a flag to enable (True) or disable (False) all metadata refresh queries. + This applies to both schema and node topology. + + Disabling this is useful to minimize refreshes during multiple changes. + + Meta refresh must be enabled for the driver to become aware of any cluster + topology changes or schema updates. 
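+
+        A sketch of the intended pattern (assuming ``cluster`` is an already
+        connected :class:`.Cluster` instance)::
+
+            cluster.set_meta_refresh_enabled(False)
+            # ... apply a batch of schema changes ...
+            cluster.set_meta_refresh_enabled(True)
+            cluster.refresh_schema()  # single refresh once the changes are complete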
+ """ + self.control_connection.set_meta_refresh_enabled(bool(enabled)) + + def _prepare_all_queries(self, host): + if not self._prepared_statements: + return + + log.debug("Preparing all known prepared statements against host %s", host) + connection = None + try: + connection = self.connection_factory(host.address) + try: + self.control_connection.wait_for_schema_agreement(connection) + except Exception: + log.debug("Error waiting for schema agreement before preparing statements against host %s", host, exc_info=True) + + statements = self._prepared_statements.values() + for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace): + if keyspace is not None: + connection.set_keyspace_blocking(keyspace) + + # prepare 10 statements at a time + ks_statements = list(ks_statements) + chunks = [] + for i in range(0, len(ks_statements), 10): + chunks.append(ks_statements[i:i + 10]) + + for ks_chunk in chunks: + messages = [PrepareMessage(query=s.query_string) for s in ks_chunk] + # TODO: make this timeout configurable somehow? + responses = connection.wait_for_responses(*messages, timeout=5.0) + for response in responses: + if (not isinstance(response, ResultMessage) or + response.kind != RESULT_KIND_PREPARED): + log.debug("Got unexpected response when preparing " + "statement on host %s: %r", host, response) + + log.debug("Done preparing all known prepared statements against host %s", host) + except OperationTimedOut as timeout: + log.warning("Timed out trying to prepare all statements on host %s: %s", host, timeout) + except (ConnectionException, socket.error) as exc: + log.warning("Error trying to prepare all statements on host %s: %r", host, exc) + except Exception: + log.exception("Error trying to prepare all statements on host %s", host) + finally: + if connection: + connection.close() + + def prepare_on_all_sessions(self, query_id, prepared_statement, excluded_host): + with self._prepared_statement_lock: + self._prepared_statements[query_id] = prepared_statement + for session in self.sessions: + session.prepare_on_all_hosts(prepared_statement.query_string, excluded_host) + + +class Session(object): + """ + A collection of connection pools for each host in the cluster. + Instances of this class should not be created directly, only + using :meth:`.Cluster.connect()`. + + Queries and statements can be executed through ``Session`` instances + using the :meth:`~.Session.execute()` and :meth:`~.Session.execute_async()` + methods. + + Example usage:: + + >>> session = cluster.connect() + >>> session.set_keyspace("mykeyspace") + >>> session.execute("SELECT * FROM mycf") + + """ + + cluster = None + hosts = None + keyspace = None + is_shutdown = False + + row_factory = staticmethod(named_tuple_factory) + """ + The format to return row results in. By default, each + returned row will be a named tuple. You can alternatively + use any of the following: + + - :func:`cassandra.query.tuple_factory` - return a result row as a tuple + - :func:`cassandra.query.named_tuple_factory` - return a result row as a named tuple + - :func:`cassandra.query.dict_factory` - return a result row as a dict + - :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict + + """ + + default_timeout = 10.0 + """ + A default timeout, measured in seconds, for queries executed through + :meth:`.execute()` or :meth:`.execute_async()`. This default may be + overridden with the `timeout` parameter for either of those methods + or the `timeout` parameter for :meth:`.ResponseFuture.result()`. 
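+
+    For example (a sketch; ``query`` and the timeout values shown are
+    illustrative)::
+
+        session.default_timeout = 30.0        # session-wide default
+        session.execute(query, timeout=5.0)   # per-query override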
+ + Setting this to :const:`None` will cause no timeouts to be set by default. + + **Important**: This timeout currently has no effect on callbacks registered + on a :class:`~.ResponseFuture` through :meth:`.ResponseFuture.add_callback` or + :meth:`.ResponseFuture.add_errback`; even if a query exceeds this default + timeout, neither the registered callback or errback will be called. + + .. versionadded:: 2.0.0 + """ + + default_consistency_level = ConsistencyLevel.ONE + """ + The default :class:`~ConsistencyLevel` for operations executed through + this session. This default may be overridden by setting the + :attr:`~.Statement.consistency_level` on individual statements. + + .. versionadded:: 1.2.0 + """ + + max_trace_wait = 2.0 + """ + The maximum amount of time (in seconds) the driver will wait for trace + details to be populated server-side for a query before giving up. + If the `trace` parameter for :meth:`~.execute()` or :meth:`~.execute_async()` + is :const:`True`, the driver will repeatedly attempt to fetch trace + details for the query (using exponential backoff) until this limit is + hit. If the limit is passed, an error will be logged and the + :attr:`.Statement.trace` will be left as :const:`None`. """ + + default_fetch_size = 5000 + """ + By default, this many rows will be fetched at a time. Setting + this to :const:`None` will disable automatic paging for large query + results. The fetch size can be also specified per-query through + :attr:`.Statement.fetch_size`. + + This only takes effect when protocol version 2 or higher is used. + See :attr:`.Cluster.protocol_version` for details. + + .. versionadded:: 2.0.0 + """ + + use_client_timestamp = True + """ + When using protocol version 3 or higher, write timestamps may be supplied + client-side at the protocol level. (Normally they are generated + server-side by the coordinator node.) Note that timestamps specified + within a CQL query will override this timestamp. + + .. versionadded:: 2.1.0 + """ + + encoder = None + """ + A :class:`~cassandra.encoder.Encoder` instance that will be used when + formatting query parameters for non-prepared statements. This is not used + for prepared statements (because prepared statements give the driver more + information about what CQL types are expected, allowing it to accept a + wider range of python types). + + The encoder uses a mapping from python types to encoder methods (for + specific CQL types). This mapping can be be modified by users as they see + fit. Methods of :class:`~cassandra.encoder.Encoder` should be used for mapping + values if possible, because they take precautions to avoid injections and + properly sanitize data. + + Example:: + + cluster = Cluster() + session = cluster.connect("mykeyspace") + session.encoder.mapping[tuple] = session.encoder.cql_encode_tuple + + session.execute("CREATE TABLE mytable (k int PRIMARY KEY, col tuple)") + session.execute("INSERT INTO mytable (k, col) VALUES (%s, %s)", [0, (123, 'abc')]) + + .. 
versionadded:: 2.1.0 + """ + + _lock = None + _pools = None + _load_balancer = None + _metrics = None + _protocol_version = None + + def __init__(self, cluster, hosts): + self.cluster = cluster + self.hosts = hosts + + self._lock = RLock() + self._pools = {} + self._load_balancer = cluster.load_balancing_policy + self._metrics = cluster.metrics + self._protocol_version = self.cluster.protocol_version + + self.encoder = Encoder() + + # create connection pools in parallel + futures = [] + for host in hosts: + future = self.add_or_renew_pool(host, is_host_addition=False) + if future is not None: + futures.append(future) + + for future in futures: + future.result() + + def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False): + """ + Execute the given query and synchronously wait for the response. + + If an error is encountered while executing the query, an Exception + will be raised. + + `query` may be a query string or an instance of :class:`cassandra.query.Statement`. + + `parameters` may be a sequence or dict of parameters to bind. If a + sequence is used, ``%s`` should be used the placeholder for each + argument. If a dict is used, ``%(name)s`` style placeholders must + be used. + + `timeout` should specify a floating-point timeout (in seconds) after + which an :exc:`.OperationTimedOut` exception will be raised if the query + has not completed. If not set, the timeout defaults to + :attr:`~.Session.default_timeout`. If set to :const:`None`, there is + no timeout. + + If `trace` is set to :const:`True`, an attempt will be made to + fetch the trace details and attach them to the `query`'s + :attr:`~.Statement.trace` attribute in the form of a :class:`.QueryTrace` + instance. This requires that `query` be a :class:`.Statement` subclass + instance and not just a string. If there is an error fetching the + trace details, the :attr:`~.Statement.trace` attribute will be left as + :const:`None`. + """ + if timeout is _NOT_SET: + timeout = self.default_timeout + + if trace and not isinstance(query, Statement): + raise TypeError( + "The query argument must be an instance of a subclass of " + "cassandra.query.Statement when trace=True") + + future = self.execute_async(query, parameters, trace) + try: + result = future.result(timeout) + finally: + if trace: + try: + query.trace = future.get_query_trace(self.max_trace_wait) + except Exception: + log.exception("Unable to fetch query trace:") + + return result + + def execute_async(self, query, parameters=None, trace=False): + """ + Execute the given query and return a :class:`~.ResponseFuture` object + which callbacks may be attached to for asynchronous response + delivery. You may also call :meth:`~.ResponseFuture.result()` + on the :class:`.ResponseFuture` to syncronously block for results at + any time. + + If `trace` is set to :const:`True`, you may call + :meth:`.ResponseFuture.get_query_trace()` after the request + completes to retrieve a :class:`.QueryTrace` instance. + + Example usage:: + + >>> session = cluster.connect() + >>> future = session.execute_async("SELECT * FROM mycf") + + >>> def log_results(results): + ... for row in results: + ... log.info("Results: %s", row) + + >>> def log_error(exc): + >>> log.error("Operation failed: %s", exc) + + >>> future.add_callbacks(log_results, log_error) + + Async execution with blocking wait for results:: + + >>> future = session.execute_async("SELECT * FROM mycf") + >>> # do other stuff... + + >>> try: + ... results = future.result() + ... except Exception: + ... 
log.exception("Operation failed:") + + """ + future = self._create_response_future(query, parameters, trace) + future.send_request() + return future + + def _create_response_future(self, query, parameters, trace): + """ Returns the ResponseFuture before calling send_request() on it """ + + prepared_statement = None + + if isinstance(query, six.string_types): + query = SimpleStatement(query) + elif isinstance(query, PreparedStatement): + query = query.bind(parameters) + + cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level + fetch_size = query.fetch_size + if fetch_size is FETCH_SIZE_UNSET and self._protocol_version >= 2: + fetch_size = self.default_fetch_size + elif self._protocol_version == 1: + fetch_size = None + + if self._protocol_version >= 3 and self.use_client_timestamp: + timestamp = int(time.time() * 1e6) + else: + timestamp = None + + if isinstance(query, SimpleStatement): + query_string = query.query_string + if six.PY2 and isinstance(query_string, six.text_type): + query_string = query_string.encode('utf-8') + if parameters: + query_string = bind_params(query_string, parameters, self.encoder) + message = QueryMessage( + query_string, cl, query.serial_consistency_level, + fetch_size, timestamp=timestamp) + elif isinstance(query, BoundStatement): + message = ExecuteMessage( + query.prepared_statement.query_id, query.values, cl, + query.serial_consistency_level, fetch_size, + timestamp=timestamp) + prepared_statement = query.prepared_statement + elif isinstance(query, BatchStatement): + if self._protocol_version < 2: + raise UnsupportedOperation( + "BatchStatement execution is only supported with protocol version " + "2 or higher (supported in Cassandra 2.0 and higher). Consider " + "setting Cluster.protocol_version to 2 to support this operation.") + message = BatchMessage( + query.batch_type, query._statements_and_parameters, cl, + query.serial_consistency_level, timestamp) + + if trace: + message.tracing = True + + return ResponseFuture( + self, message, query, self.default_timeout, metrics=self._metrics, + prepared_statement=prepared_statement) + + def prepare(self, query): + """ + Prepares a query string, returing a :class:`~cassandra.query.PreparedStatement` + instance which can be used as follows:: + + >>> session = cluster.connect("mykeyspace") + >>> query = "INSERT INTO users (id, name, age) VALUES (?, ?, ?)" + >>> prepared = session.prepare(query) + >>> session.execute(prepared, (user.id, user.name, user.age)) + + Or you may bind values to the prepared statement ahead of time:: + + >>> prepared = session.prepare(query) + >>> bound_stmt = prepared.bind((user.id, user.name, user.age)) + >>> session.execute(bound_stmt) + + Of course, prepared statements may (and should) be reused:: + + >>> prepared = session.prepare(query) + >>> for user in users: + ... bound = prepared.bind((user.id, user.name, user.age)) + ... session.execute(bound) + + **Important**: PreparedStatements should be prepared only once. + Preparing the same query more than once will likely affect performance. 
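+
+        Prepared statements also work with :meth:`~.Session.execute_async`
+        (a sketch; ``users`` and the bound values are illustrative)::
+
+            >>> prepared = session.prepare(query)
+            >>> futures = [session.execute_async(prepared.bind((user.id, user.name, user.age)))
+            ...            for user in users]
+            >>> results = [future.result() for future in futures]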
+ """ + message = PrepareMessage(query=query) + future = ResponseFuture(self, message, query=None) + try: + future.send_request() + query_id, column_metadata = future.result(self.default_timeout) + except Exception: + log.exception("Error preparing query:") + raise + + prepared_statement = PreparedStatement.from_message( + query_id, column_metadata, self.cluster.metadata, query, self.keyspace, + self._protocol_version) + + host = future._current_host + try: + self.cluster.prepare_on_all_sessions(query_id, prepared_statement, host) + except Exception: + log.exception("Error preparing query on all hosts:") + + return prepared_statement + + def prepare_on_all_hosts(self, query, excluded_host): + """ + Prepare the given query on all hosts, excluding ``excluded_host``. + Intended for internal use only. + """ + futures = [] + for host in self._pools.keys(): + if host != excluded_host and host.is_up: + future = ResponseFuture(self, PrepareMessage(query=query), None) + + # we don't care about errors preparing against specific hosts, + # since we can always prepare them as needed when the prepared + # statement is used. Just log errors and continue on. + try: + request_id = future._query(host) + except Exception: + log.exception("Error preparing query for host %s:", host) + continue + + if request_id is None: + # the error has already been logged by ResponsFuture + log.debug("Failed to prepare query for host %s: %r", + host, future._errors.get(host)) + continue + + futures.append((host, future)) + + for host, future in futures: + try: + future.result(self.default_timeout) + except Exception: + log.exception("Error preparing query for host %s:", host) + + def shutdown(self): + """ + Close all connections. ``Session`` instances should not be used + for any purpose after being shutdown. + """ + with self._lock: + if self.is_shutdown: + return + else: + self.is_shutdown = True + + for pool in self._pools.values(): + pool.shutdown() + + def add_or_renew_pool(self, host, is_host_addition): + """ + For internal use only. + """ + distance = self._load_balancer.distance(host) + if distance == HostDistance.IGNORED: + return None + + def run_add_or_renew_pool(): + try: + if self._protocol_version >= 3: + new_pool = HostConnection(host, distance, self) + else: + new_pool = HostConnectionPool(host, distance, self) + except AuthenticationFailed as auth_exc: + conn_exc = ConnectionException(str(auth_exc), host=host) + self.cluster.signal_connection_failure(host, conn_exc, is_host_addition) + return False + except Exception as conn_exc: + log.warning("Failed to create connection pool for new host %s:", + host, exc_info=conn_exc) + # the host itself will still be marked down, so we need to pass + # a special flag to make sure the reconnector is created + self.cluster.signal_connection_failure( + host, conn_exc, is_host_addition, expect_host_to_be_down=True) + return False + + previous = self._pools.get(host) + self._pools[host] = new_pool + log.debug("Added pool for host %s to session", host) + if previous: + previous.shutdown() + + return True + + return self.submit(run_add_or_renew_pool) + + def remove_pool(self, host): + pool = self._pools.pop(host, None) + if pool: + log.debug("Removed connection pool for %r", host) + return self.submit(pool.shutdown) + else: + return None + + def update_created_pools(self): + """ + When the set of live nodes change, the loadbalancer will change its + mind on host distances. 
It might change it on the node that came/left + but also on other nodes (for instance, if a node dies, another + previously ignored node may be now considered). + + This method ensures that all hosts for which a pool should exist + have one, and hosts that shouldn't don't. + + For internal use only. + """ + for host in self.cluster.metadata.all_hosts(): + distance = self._load_balancer.distance(host) + pool = self._pools.get(host) + + if not pool or pool.is_shutdown: + if distance != HostDistance.IGNORED and host.is_up: + self.add_or_renew_pool(host, False) + elif distance != pool.host_distance: + # the distance has changed + if distance == HostDistance.IGNORED: + self.remove_pool(host) + else: + pool.host_distance = distance + + def on_down(self, host): + """ + Called by the parent Cluster instance when a node is marked down. + Only intended for internal use. + """ + future = self.remove_pool(host) + if future: + future.add_done_callback(lambda f: self.update_created_pools()) + + def on_remove(self, host): + """ Internal """ + self.on_down(host) + + def set_keyspace(self, keyspace): + """ + Set the default keyspace for all queries made through this Session. + This operation blocks until complete. + """ + self.execute('USE %s' % (protect_name(keyspace),)) + + def _set_keyspace_for_all_pools(self, keyspace, callback): + """ + Asynchronously sets the keyspace on all pools. When all + pools have set all of their connections, `callback` will be + called with a dictionary of all errors that occurred, keyed + by the `Host` that they occurred against. + """ + self.keyspace = keyspace + + remaining_callbacks = set(self._pools.values()) + errors = {} + + if not remaining_callbacks: + callback(errors) + return + + def pool_finished_setting_keyspace(pool, host_errors): + remaining_callbacks.remove(pool) + if host_errors: + errors[pool.host] = host_errors + + if not remaining_callbacks: + callback(host_errors) + + for pool in self._pools.values(): + pool._set_keyspace_for_all_conns(keyspace, pool_finished_setting_keyspace) + + def user_type_registered(self, keyspace, user_type, klass): + """ + Called by the parent Cluster instance when the user registers a new + mapping from a user-defined type to a class. Intended for internal + use only. + """ + try: + ks_meta = self.cluster.metadata.keyspaces[keyspace] + except KeyError: + raise UserTypeDoesNotExist( + 'Keyspace %s does not exist or has not been discovered by the driver' % (keyspace,)) + + try: + type_meta = ks_meta.user_types[user_type] + except KeyError: + raise UserTypeDoesNotExist( + 'User type %s does not exist in keyspace %s' % (user_type, keyspace)) + + def encode(val): + return '{ %s }' % ' , '.join('%s : %s' % ( + field_name, + self.encoder.cql_encode_all_types(getattr(val, field_name, None)) + ) for field_name in type_meta.field_names) + + self.encoder.mapping[klass] = encode + + def submit(self, fn, *args, **kwargs): + """ Internal """ + if not self.is_shutdown: + return self.cluster.executor.submit(fn, *args, **kwargs) + + def get_pool_state(self): + return dict((host, pool.get_state()) for host, pool in self._pools.items()) + + def get_pools(self): + return self._pools.values() + + +class UserTypeDoesNotExist(Exception): + """ + An attempt was made to use a user-defined type that does not exist. + + .. 
versionadded:: 2.1.0 + """ + pass + + +class _ControlReconnectionHandler(_ReconnectionHandler): + """ + Internal + """ + + def __init__(self, control_connection, *args, **kwargs): + _ReconnectionHandler.__init__(self, *args, **kwargs) + self.control_connection = weakref.proxy(control_connection) + + def try_reconnect(self): + # we'll either get back a new Connection or a NoHostAvailable + return self.control_connection._reconnect_internal() + + def on_reconnection(self, connection): + self.control_connection._set_new_connection(connection) + + def on_exception(self, exc, next_delay): + # TODO only overridden to add logging, so add logging + if isinstance(exc, AuthenticationFailed): + return False + else: + log.debug("Error trying to reconnect control connection: %r", exc) + return True + + +def _watch_callback(obj_weakref, method_name, *args, **kwargs): + """ + A callback handler for the ControlConnection that tolerates + weak references. + """ + obj = obj_weakref() + if obj is None: + return + getattr(obj, method_name)(*args, **kwargs) + + +def _clear_watcher(conn, expiring_weakref): + """ + Called when the ControlConnection object is about to be finalized. + This clears watchers on the underlying Connection object. + """ + try: + conn.control_conn_disposed() + except ReferenceError: + pass + + +class ControlConnection(object): + """ + Internal + """ + + _SELECT_KEYSPACES = "SELECT * FROM system.schema_keyspaces" + _SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies" + _SELECT_COLUMNS = "SELECT * FROM system.schema_columns" + _SELECT_USERTYPES = "SELECT * FROM system.schema_usertypes" + _SELECT_TRIGGERS = "SELECT * FROM system.schema_triggers" + + _SELECT_PEERS = "SELECT peer, data_center, rack, tokens, rpc_address, schema_version FROM system.peers" + _SELECT_LOCAL = "SELECT cluster_name, data_center, rack, tokens, partitioner, schema_version FROM system.local WHERE key='local'" + + _SELECT_SCHEMA_PEERS = "SELECT peer, rpc_address, schema_version FROM system.peers" + _SELECT_SCHEMA_LOCAL = "SELECT schema_version FROM system.local WHERE key='local'" + + _is_shutdown = False + _timeout = None + _protocol_version = None + + _schema_event_refresh_window = None + _topology_event_refresh_window = None + + _meta_refresh_enabled = True + + # for testing purposes + _time = time + + def __init__(self, cluster, timeout, + schema_event_refresh_window, + topology_event_refresh_window): + # use a weak reference to allow the Cluster instance to be GC'ed (and + # shutdown) since implementing __del__ disables the cycle detector + self._cluster = weakref.proxy(cluster) + self._connection = None + self._timeout = timeout + + self._schema_event_refresh_window = schema_event_refresh_window + self._topology_event_refresh_window = topology_event_refresh_window + + self._lock = RLock() + self._schema_agreement_lock = Lock() + + self._reconnection_handler = None + self._reconnection_lock = RLock() + + def connect(self): + if self._is_shutdown: + return + + self._protocol_version = self._cluster.protocol_version + self._set_new_connection(self._reconnect_internal()) + + def _set_new_connection(self, conn): + """ + Replace existing connection (if there is one) and close it. + """ + with self._lock: + old = self._connection + self._connection = conn + + if old: + log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn) + old.close() + + def _reconnect_internal(self): + """ + Tries to connect to each host in the query plan until one succeeds + or every attempt fails. 
If successful, a new Connection will be + returned. Otherwise, :exc:`NoHostAvailable` will be raised + with an "errors" arg that is a dict mapping host addresses + to the exception that was raised when an attempt was made to open + a connection to that host. + """ + errors = {} + for host in self._cluster.load_balancing_policy.make_query_plan(): + try: + return self._try_connect(host) + except ConnectionException as exc: + errors[host.address] = exc + log.warning("[control connection] Error connecting to %s:", host, exc_info=True) + self._cluster.signal_connection_failure(host, exc, is_host_addition=False) + except Exception as exc: + errors[host.address] = exc + log.warning("[control connection] Error connecting to %s:", host, exc_info=True) + + raise NoHostAvailable("Unable to connect to any servers", errors) + + def _try_connect(self, host): + """ + Creates a new Connection, registers for pushed events, and refreshes + node/token and schema metadata. + """ + log.debug("[control connection] Opening new connection to %s", host) + connection = self._cluster.connection_factory(host.address, is_control_connection=True) + + log.debug("[control connection] Established new connection %r, " + "registering watchers and refreshing schema and topology", + connection) + + # use weak references in both directions + # _clear_watcher will be called when this ControlConnection is about to be finalized + # _watch_callback will get the actual callback from the Connection and relay it to + # this object (after a dereferencing a weakref) + self_weakref = weakref.ref(self, callback=partial(_clear_watcher, weakref.proxy(connection))) + try: + connection.register_watchers({ + "TOPOLOGY_CHANGE": partial(_watch_callback, self_weakref, '_handle_topology_change'), + "STATUS_CHANGE": partial(_watch_callback, self_weakref, '_handle_status_change'), + "SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change') + }, register_timeout=self._timeout) + + peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=ConsistencyLevel.ONE) + local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=ConsistencyLevel.ONE) + shared_results = connection.wait_for_responses( + peers_query, local_query, timeout=self._timeout) + + self._refresh_node_list_and_token_map(connection, preloaded_results=shared_results) + self._refresh_schema(connection, preloaded_results=shared_results) + if not self._cluster.metadata.keyspaces: + log.warning("[control connection] No schema built on connect; retrying without wait for schema agreement") + self._refresh_schema(connection, preloaded_results=shared_results, schema_agreement_wait=0) + except Exception: + connection.close() + raise + + return connection + + def reconnect(self): + if self._is_shutdown: + return + + self._submit(self._reconnect) + + def _reconnect(self): + log.debug("[control connection] Attempting to reconnect") + try: + self._set_new_connection(self._reconnect_internal()) + except NoHostAvailable: + # make a retry schedule (which includes backoff) + schedule = self.cluster.reconnection_policy.new_schedule() + + with self._reconnection_lock: + + # cancel existing reconnection attempts + if self._reconnection_handler: + self._reconnection_handler.cancel() + + # when a connection is successfully made, _set_new_connection + # will be called with the new connection and then our + # _reconnection_handler will be cleared out + self._reconnection_handler = _ControlReconnectionHandler( + self, self._cluster.scheduler, schedule, + 
self._get_and_set_reconnection_handler, + new_handler=None) + self._reconnection_handler.start() + except Exception: + log.debug("[control connection] error reconnecting", exc_info=True) + raise + + def _get_and_set_reconnection_handler(self, new_handler): + """ + Called by the _ControlReconnectionHandler when a new connection + is successfully created. Clears out the _reconnection_handler on + this ControlConnection. + """ + with self._reconnection_lock: + old = self._reconnection_handler + self._reconnection_handler = new_handler + return old + + def _submit(self, *args, **kwargs): + try: + if not self._cluster.is_shutdown: + return self._cluster.executor.submit(*args, **kwargs) + except ReferenceError: + pass + return None + + def shutdown(self): + with self._lock: + if self._is_shutdown: + return + else: + self._is_shutdown = True + + log.debug("Shutting down control connection") + # stop trying to reconnect (if we are) + if self._reconnection_handler: + self._reconnection_handler.cancel() + + if self._connection: + self._connection.close() + del self._connection + + def refresh_schema(self, keyspace=None, table=None, usertype=None, + schema_agreement_wait=None): + if not self._meta_refresh_enabled: + log.debug("[control connection] Skipping schema refresh because meta refresh is disabled") + return False + + try: + if self._connection: + return self._refresh_schema(self._connection, keyspace, table, usertype, + schema_agreement_wait=schema_agreement_wait) + except ReferenceError: + pass # our weak reference to the Cluster is no good + except Exception: + log.debug("[control connection] Error refreshing schema", exc_info=True) + self._signal_error() + return False + + def _refresh_schema(self, connection, keyspace=None, table=None, usertype=None, + preloaded_results=None, schema_agreement_wait=None): + if self._cluster.is_shutdown: + return False + + assert table is None or usertype is None + + agreed = self.wait_for_schema_agreement(connection, + preloaded_results=preloaded_results, + wait_time=schema_agreement_wait) + if not agreed: + log.debug("Skipping schema refresh due to lack of schema agreement") + return False + + cl = ConsistencyLevel.ONE + if table: + def _handle_results(success, result): + if success: + return dict_factory(*result.results) if result else {} + else: + raise result + + # a particular table changed + where_clause = " WHERE keyspace_name = '%s' AND columnfamily_name = '%s'" % (keyspace, table) + cf_query = QueryMessage(query=self._SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl) + col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl) + triggers_query = QueryMessage(query=self._SELECT_TRIGGERS + where_clause, consistency_level=cl) + (cf_success, cf_result), (col_success, col_result), (triggers_success, triggers_result) \ + = connection.wait_for_responses(cf_query, col_query, triggers_query, timeout=self._timeout, fail_on_error=False) + + log.debug("[control connection] Fetched table info for %s.%s, rebuilding metadata", keyspace, table) + cf_result = _handle_results(cf_success, cf_result) + col_result = _handle_results(col_success, col_result) + + # handle the triggers table not existing in Cassandra 1.2 + if not triggers_success and isinstance(triggers_result, InvalidRequest): + triggers_result = {} + else: + triggers_result = _handle_results(triggers_success, triggers_result) + + self._cluster.metadata.table_changed(keyspace, table, cf_result, col_result, triggers_result) + elif usertype: + # user defined types 
within this keyspace changed + where_clause = " WHERE keyspace_name = '%s' AND type_name = '%s'" % (keyspace, usertype) + types_query = QueryMessage(query=self._SELECT_USERTYPES + where_clause, consistency_level=cl) + types_result = connection.wait_for_response(types_query) + log.debug("[control connection] Fetched user type info for %s.%s, rebuilding metadata", keyspace, usertype) + types_result = dict_factory(*types_result.results) if types_result.results else {} + self._cluster.metadata.usertype_changed(keyspace, usertype, types_result) + elif keyspace: + # only the keyspace itself changed (such as replication settings) + where_clause = " WHERE keyspace_name = '%s'" % (keyspace,) + ks_query = QueryMessage(query=self._SELECT_KEYSPACES + where_clause, consistency_level=cl) + ks_result = connection.wait_for_response(ks_query) + log.debug("[control connection] Fetched keyspace info for %s, rebuilding metadata", keyspace) + ks_result = dict_factory(*ks_result.results) if ks_result.results else {} + self._cluster.metadata.keyspace_changed(keyspace, ks_result) + else: + # build everything from scratch + queries = [ + QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl), + QueryMessage(query=self._SELECT_COLUMN_FAMILIES, consistency_level=cl), + QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl), + QueryMessage(query=self._SELECT_USERTYPES, consistency_level=cl), + QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl) + ] + + responses = connection.wait_for_responses(*queries, timeout=self._timeout, fail_on_error=False) + (ks_success, ks_result), (cf_success, cf_result), \ + (col_success, col_result), (types_success, types_result), \ + (trigger_success, triggers_result) = responses + + if ks_success: + ks_result = dict_factory(*ks_result.results) + else: + raise ks_result + + if cf_success: + cf_result = dict_factory(*cf_result.results) + else: + raise cf_result + + if col_success: + col_result = dict_factory(*col_result.results) + else: + raise col_result + + # if we're connected to Cassandra < 2.0, the trigges table will not exist + if trigger_success: + triggers_result = dict_factory(*triggers_result.results) + else: + if isinstance(triggers_result, InvalidRequest): + log.debug("[control connection] triggers table not found") + triggers_result = {} + elif isinstance(triggers_result, Unauthorized): + log.warning("[control connection] this version of Cassandra does not allow access to schema_triggers metadata with authorization enabled (CASSANDRA-7967); " + "The driver will operate normally, but will not reflect triggers in the local metadata model, or schema strings.") + triggers_result = {} + else: + raise triggers_result + + # if we're connected to Cassandra < 2.1, the usertypes table will not exist + if types_success: + types_result = dict_factory(*types_result.results) if types_result.results else {} + else: + if isinstance(types_result, InvalidRequest): + log.debug("[control connection] user types table not found") + types_result = {} + else: + raise types_result + + log.debug("[control connection] Fetched schema, rebuilding metadata") + self._cluster.metadata.rebuild_schema(ks_result, types_result, cf_result, col_result, triggers_result) + return True + + def refresh_node_list_and_token_map(self, force_token_rebuild=False): + if not self._meta_refresh_enabled: + log.debug("[control connection] Skipping node list refresh because meta refresh is disabled") + return False + + try: + if self._connection: + 
self._refresh_node_list_and_token_map(self._connection, force_token_rebuild=force_token_rebuild) + return True + except ReferenceError: + pass # our weak reference to the Cluster is no good + except Exception: + log.debug("[control connection] Error refreshing node list and token map", exc_info=True) + self._signal_error() + return False + + def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, + force_token_rebuild=False): + + if preloaded_results: + log.debug("[control connection] Refreshing node list and token map using preloaded results") + peers_result = preloaded_results[0] + local_result = preloaded_results[1] + else: + log.debug("[control connection] Refreshing node list and token map") + cl = ConsistencyLevel.ONE + peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=cl) + local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=cl) + peers_result, local_result = connection.wait_for_responses( + peers_query, local_query, timeout=self._timeout) + + peers_result = dict_factory(*peers_result.results) + + partitioner = None + token_map = {} + + if local_result.results: + local_rows = dict_factory(*(local_result.results)) + local_row = local_rows[0] + cluster_name = local_row["cluster_name"] + self._cluster.metadata.cluster_name = cluster_name + + host = self._cluster.metadata.get_host(connection.host) + if host: + datacenter = local_row.get("data_center") + rack = local_row.get("rack") + self._update_location_info(host, datacenter, rack) + + partitioner = local_row.get("partitioner") + tokens = local_row.get("tokens") + if partitioner and tokens: + token_map[host] = tokens + + # Check metadata.partitioner to see if we haven't built anything yet. If + # every node in the cluster was in the contact points, we won't discover + # any new nodes, so we need this additional check. (See PYTHON-90) + should_rebuild_token_map = force_token_rebuild or self._cluster.metadata.partitioner is None + found_hosts = set() + for row in peers_result: + addr = row.get("rpc_address") + + if not addr or addr in ["0.0.0.0", "::"]: + addr = row.get("peer") + + tokens = row.get("tokens") + if not tokens: + log.warning("Excluding host (%s) with no tokens in system.peers table of %s." 
% (addr, connection.host)) + continue + + found_hosts.add(addr) + + host = self._cluster.metadata.get_host(addr) + datacenter = row.get("data_center") + rack = row.get("rack") + if host is None: + log.debug("[control connection] Found new host to connect to: %s", addr) + host, _ = self._cluster.add_host(addr, datacenter, rack, signal=True, refresh_nodes=False) + should_rebuild_token_map = True + else: + should_rebuild_token_map |= self._update_location_info(host, datacenter, rack) + + if partitioner and tokens: + token_map[host] = tokens + + for old_host in self._cluster.metadata.all_hosts(): + if old_host.address != connection.host and old_host.address not in found_hosts: + should_rebuild_token_map = True + if old_host.address not in self._cluster.contact_points: + log.debug("[control connection] Found host that has been removed: %r", old_host) + self._cluster.remove_host(old_host) + + log.debug("[control connection] Finished fetching ring info") + if partitioner and should_rebuild_token_map: + log.debug("[control connection] Rebuilding token map due to topology changes") + self._cluster.metadata.rebuild_token_map(partitioner, token_map) + + def _update_location_info(self, host, datacenter, rack): + if host.datacenter == datacenter and host.rack == rack: + return False + + # If the dc/rack information changes, we need to update the load balancing policy. + # For that, we remove and re-add the node against the policy. Not the most elegant, and assumes + # that the policy will update correctly, but in practice this should work. + self._cluster.load_balancing_policy.on_down(host) + host.set_location_info(datacenter, rack) + self._cluster.load_balancing_policy.on_up(host) + return True + + def _handle_topology_change(self, event): + change_type = event["change_type"] + addr, port = event["address"] + if change_type == "NEW_NODE" or change_type == "MOVED_NODE": + if self._topology_event_refresh_window >= 0: + delay = random() * self._topology_event_refresh_window + self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map) + elif change_type == "REMOVED_NODE": + host = self._cluster.metadata.get_host(addr) + self._cluster.scheduler.schedule_unique(0, self._cluster.remove_host, host) + + def _handle_status_change(self, event): + change_type = event["change_type"] + addr, port = event["address"] + host = self._cluster.metadata.get_host(addr) + if change_type == "UP": + delay = 1 + random() * 0.5 # randomness to avoid thundering herd problem on events + if host is None: + # this is the first time we've seen the node + self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map) + else: + self._cluster.scheduler.schedule_unique(delay, self._cluster.on_up, host) + elif change_type == "DOWN": + # Note that there is a slight risk we can receive the event late and thus + # mark the host down even though we already had reconnected successfully. + # But it is unlikely, and don't have too much consequence since we'll try reconnecting + # right away, so we favor the detection to make the Host.is_up more accurate. 
+ if host is not None: + # this will be run by the scheduler + self._cluster.on_down(host, is_host_addition=False) + + def _handle_schema_change(self, event): + if self._schema_event_refresh_window < 0: + return + + keyspace = event.get('keyspace') + table = event.get('table') + usertype = event.get('type') + delay = random() * self._schema_event_refresh_window + self._cluster.scheduler.schedule_unique(delay, self.refresh_schema, keyspace, table, usertype) + + def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None): + + total_timeout = wait_time if wait_time is not None else self._cluster.max_schema_agreement_wait + if total_timeout <= 0: + return True + + # Each schema change typically generates two schema refreshes, one + # from the response type and one from the pushed notification. Holding + # a lock is just a simple way to cut down on the number of schema queries + # we'll make. + with self._schema_agreement_lock: + if self._is_shutdown: + return + + if not connection: + connection = self._connection + + if preloaded_results: + log.debug("[control connection] Attempting to use preloaded results for schema agreement") + + peers_result = preloaded_results[0] + local_result = preloaded_results[1] + schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.host) + if schema_mismatches is None: + return True + + log.debug("[control connection] Waiting for schema agreement") + start = self._time.time() + elapsed = 0 + cl = ConsistencyLevel.ONE + schema_mismatches = None + while elapsed < total_timeout: + peers_query = QueryMessage(query=self._SELECT_SCHEMA_PEERS, consistency_level=cl) + local_query = QueryMessage(query=self._SELECT_SCHEMA_LOCAL, consistency_level=cl) + try: + timeout = min(self._timeout, total_timeout - elapsed) + peers_result, local_result = connection.wait_for_responses( + peers_query, local_query, timeout=timeout) + except OperationTimedOut as timeout: + log.debug("[control connection] Timed out waiting for " + "response during schema agreement check: %s", timeout) + elapsed = self._time.time() - start + continue + except ConnectionShutdown: + if self._is_shutdown: + log.debug("[control connection] Aborting wait for schema match due to shutdown") + return None + else: + raise + + schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.host) + if schema_mismatches is None: + return True + + log.debug("[control connection] Schemas mismatched, trying again") + self._time.sleep(0.2) + elapsed = self._time.time() - start + + log.warning("Node %s is reporting a schema disagreement: %s", + connection.host, schema_mismatches) + return False + + def _get_schema_mismatches(self, peers_result, local_result, local_address): + peers_result = dict_factory(*peers_result.results) + + versions = defaultdict(set) + if local_result.results: + local_row = dict_factory(*local_result.results)[0] + if local_row.get("schema_version"): + versions[local_row.get("schema_version")].add(local_address) + + for row in peers_result: + schema_ver = row.get('schema_version') + if not schema_ver: + continue + + addr = row.get("rpc_address") + if not addr or addr in ["0.0.0.0", "::"]: + addr = row.get("peer") + + peer = self._cluster.metadata.get_host(addr) + if peer and peer.is_up: + versions[schema_ver].add(addr) + + if len(versions) == 1: + log.debug("[control connection] Schemas match") + return None + + return dict((version, list(nodes)) for version, nodes in six.iteritems(versions)) + + def 
_signal_error(self): + # try just signaling the cluster, as this will trigger a reconnect + # as part of marking the host down + if self._connection and self._connection.is_defunct: + host = self._cluster.metadata.get_host(self._connection.host) + # host may be None if it's already been removed, but that indicates + # that errors have already been reported, so we're fine + if host: + self._cluster.signal_connection_failure( + host, self._connection.last_error, is_host_addition=False) + return + + # if the connection is not defunct or the host already left, reconnect + # manually + self.reconnect() + + def on_up(self, host): + pass + + def on_down(self, host): + + conn = self._connection + if conn and conn.host == host.address and \ + self._reconnection_handler is None: + log.debug("[control connection] Control connection host (%s) is " + "considered down, starting reconnection", host) + # this will result in a task being submitted to the executor to reconnect + self.reconnect() + + def on_add(self, host, refresh_nodes=True): + if refresh_nodes: + self.refresh_node_list_and_token_map(force_token_rebuild=True) + + def on_remove(self, host): + self.refresh_node_list_and_token_map(force_token_rebuild=True) + + def get_connections(self): + c = getattr(self, '_connection', None) + return [c] if c else [] + + def return_connection(self, connection): + if connection is self._connection and (connection.is_defunct or connection.is_closed): + self.reconnect() + + def set_meta_refresh_enabled(self, enabled): + self._meta_refresh_enabled = enabled + + +def _stop_scheduler(scheduler, thread): + try: + if not scheduler.is_shutdown: + scheduler.shutdown() + except ReferenceError: + pass + + thread.join() + + +class _Scheduler(object): + + _queue = None + _scheduled_tasks = None + _executor = None + is_shutdown = False + + def __init__(self, executor): + self._queue = Queue.PriorityQueue() + self._scheduled_tasks = set() + self._executor = executor + + t = Thread(target=self.run, name="Task Scheduler") + t.daemon = True + t.start() + + # although this runs on a daemonized thread, we prefer to stop + # it gracefully to avoid random errors during interpreter shutdown + atexit.register(partial(_stop_scheduler, weakref.proxy(self), t)) + + def shutdown(self): + try: + log.debug("Shutting down Cluster Scheduler") + except AttributeError: + # this can happen on interpreter shutdown + pass + self.is_shutdown = True + self._queue.put_nowait((0, None)) + + def schedule(self, delay, fn, *args): + self._insert_task(delay, (fn, args)) + + def schedule_unique(self, delay, fn, *args): + task = (fn, args) + if task not in self._scheduled_tasks: + self._insert_task(delay, task) + else: + log.debug("Ignoring schedule_unique for already-scheduled task: %r", task) + + def _insert_task(self, delay, task): + if not self.is_shutdown: + run_at = time.time() + delay + self._scheduled_tasks.add(task) + self._queue.put_nowait((run_at, task)) + else: + log.debug("Ignoring scheduled task after shutdown: %r", task) + + def run(self): + while True: + if self.is_shutdown: + return + + try: + while True: + run_at, task = self._queue.get(block=True, timeout=None) + if self.is_shutdown: + log.debug("Not executing scheduled task due to Scheduler shutdown") + return + if run_at <= time.time(): + self._scheduled_tasks.remove(task) + fn, args = task + future = self._executor.submit(fn, *args) + future.add_done_callback(self._log_if_failed) + else: + self._queue.put_nowait((run_at, task)) + break + except Queue.Empty: + pass + + time.sleep(0.1) 
+ + def _log_if_failed(self, future): + exc = future.exception() + if exc: + log.warning( + "An internally scheduled tasked failed with an unhandled exception:", + exc_info=exc) + + +def refresh_schema_and_set_result(keyspace, table, usertype, control_conn, response_future): + try: + if control_conn._meta_refresh_enabled: + log.debug("Refreshing schema in response to schema change. Keyspace: %s; Table: %s, Type: %s", + keyspace, table, usertype) + control_conn._refresh_schema(response_future._connection, keyspace, table, usertype) + else: + log.debug("Skipping schema refresh in response to schema change because meta refresh is disabled; " + "Keyspace: %s; Table: %s, Type: %s", keyspace, table, usertype) + except Exception: + log.exception("Exception refreshing schema in response to schema change:") + response_future.session.submit( + control_conn.refresh_schema, keyspace, table, usertype) + finally: + response_future._set_final_result(None) + + +class ResponseFuture(object): + """ + An asynchronous response delivery mechanism that is returned from calls + to :meth:`.Session.execute_async()`. + + There are two ways for results to be delivered: + - Synchronously, by calling :meth:`.result()` + - Asynchronously, by attaching callback and errback functions via + :meth:`.add_callback()`, :meth:`.add_errback()`, and + :meth:`.add_callbacks()`. + """ + + query = None + """ + The :class:`~.Statement` instance that is being executed through this + :class:`.ResponseFuture`. + """ + + session = None + row_factory = None + message = None + default_timeout = None + + _req_id = None + _final_result = _NOT_SET + _final_exception = None + _query_trace = None + _callbacks = None + _errbacks = None + _current_host = None + _current_pool = None + _connection = None + _query_retries = 0 + _start_time = None + _metrics = None + _paging_state = None + + def __init__(self, session, message, query, default_timeout=None, metrics=None, prepared_statement=None): + self.session = session + self.row_factory = session.row_factory + self.message = message + self.query = query + self.default_timeout = default_timeout + self._metrics = metrics + self.prepared_statement = prepared_statement + self._callback_lock = Lock() + if metrics is not None: + self._start_time = time.time() + self._make_query_plan() + self._event = Event() + self._errors = {} + self._callbacks = [] + self._errbacks = [] + + def _make_query_plan(self): + # convert the list/generator/etc to an iterator so that subsequent + # calls to send_request (which retries may do) will resume where + # they last left off + self.query_plan = iter(self.session._load_balancer.make_query_plan( + self.session.keyspace, self.query)) + + def send_request(self): + """ Internal """ + # query_plan is an iterator, so this will resume where we last left + # off if send_request() is called multiple times + for host in self.query_plan: + req_id = self._query(host) + if req_id is not None: + self._req_id = req_id + return + + self._set_final_exception(NoHostAvailable( + "Unable to complete the operation against any hosts", self._errors)) + + def _query(self, host, message=None, cb=None): + if message is None: + message = self.message + + if cb is None: + cb = self._set_result + + pool = self.session._pools.get(host) + if not pool: + self._errors[host] = ConnectionException("Host has been marked down or removed") + return None + elif pool.is_shutdown: + self._errors[host] = ConnectionException("Pool is shutdown") + return None + + self._current_host = host + self._current_pool = 
pool + + connection = None + try: + # TODO get connectTimeout from cluster settings + connection, request_id = pool.borrow_connection(timeout=2.0) + self._connection = connection + connection.send_msg(message, request_id, cb=cb) + return request_id + except NoConnectionsAvailable as exc: + log.debug("All connections for host %s are at capacity, moving to the next host", host) + self._errors[host] = exc + return None + except Exception as exc: + log.debug("Error querying host %s", host, exc_info=True) + self._errors[host] = exc + if self._metrics is not None: + self._metrics.on_connection_error() + if connection: + pool.return_connection(connection) + return None + + @property + def has_more_pages(self): + """ + Returns :const:`True` if there are more pages left in the + query results, :const:`False` otherwise. This should only + be checked after the first page has been returned. + + .. versionadded:: 2.0.0 + """ + return self._paging_state is not None + + def start_fetching_next_page(self): + """ + If there are more pages left in the query result, this asynchronously + starts fetching the next page. If there are no pages left, :exc:`.QueryExhausted` + is raised. Also see :attr:`.has_more_pages`. + + This should only be called after the first page has been returned. + + .. versionadded:: 2.0.0 + """ + if not self._paging_state: + raise QueryExhausted() + + self._make_query_plan() + self.message.paging_state = self._paging_state + self._event.clear() + self._final_result = _NOT_SET + self._final_exception = None + self.send_request() + + def _reprepare(self, prepare_message): + cb = partial(self.session.submit, self._execute_after_prepare) + request_id = self._query(self._current_host, prepare_message, cb=cb) + if request_id is None: + # try to submit the original prepared statement on some other host + self.send_request() + + def _set_result(self, response): + try: + if self._current_pool and self._connection: + self._current_pool.return_connection(self._connection) + + trace_id = getattr(response, 'trace_id', None) + if trace_id: + self._query_trace = QueryTrace(trace_id, self.session) + + if isinstance(response, ResultMessage): + if response.kind == RESULT_KIND_SET_KEYSPACE: + session = getattr(self, 'session', None) + # since we're running on the event loop thread, we need to + # use a non-blocking method for setting the keyspace on + # all connections in this session, otherwise the event + # loop thread will deadlock waiting for keyspaces to be + # set. This uses a callback chain which ends with + # self._set_keyspace_completed() being called in the + # event loop thread. 
+ if session: + session._set_keyspace_for_all_pools( + response.results, self._set_keyspace_completed) + elif response.kind == RESULT_KIND_SCHEMA_CHANGE: + # refresh the schema before responding, but do it in another + # thread instead of the event loop thread + self.session.submit( + refresh_schema_and_set_result, + response.results['keyspace'], + response.results.get('table'), + response.results.get('type'), + self.session.cluster.control_connection, + self) + else: + results = getattr(response, 'results', None) + if results is not None and response.kind == RESULT_KIND_ROWS: + self._paging_state = response.paging_state + results = self.row_factory(*results) + self._set_final_result(results) + elif isinstance(response, ErrorMessage): + retry_policy = None + if self.query: + retry_policy = self.query.retry_policy + if not retry_policy: + retry_policy = self.session.cluster.default_retry_policy + + if isinstance(response, ReadTimeoutErrorMessage): + if self._metrics is not None: + self._metrics.on_read_timeout() + retry = retry_policy.on_read_timeout( + self.query, retry_num=self._query_retries, **response.info) + elif isinstance(response, WriteTimeoutErrorMessage): + if self._metrics is not None: + self._metrics.on_write_timeout() + retry = retry_policy.on_write_timeout( + self.query, retry_num=self._query_retries, **response.info) + elif isinstance(response, UnavailableErrorMessage): + if self._metrics is not None: + self._metrics.on_unavailable() + retry = retry_policy.on_unavailable( + self.query, retry_num=self._query_retries, **response.info) + elif isinstance(response, OverloadedErrorMessage): + if self._metrics is not None: + self._metrics.on_other_error() + # need to retry against a different host here + log.warning("Host %s is overloaded, retrying against a different " + "host", self._current_host) + self._retry(reuse_connection=False, consistency_level=None) + return + elif isinstance(response, IsBootstrappingErrorMessage): + if self._metrics is not None: + self._metrics.on_other_error() + # need to retry against a different host here + self._retry(reuse_connection=False, consistency_level=None) + return + elif isinstance(response, PreparedQueryNotFound): + if self.prepared_statement: + query_id = self.prepared_statement.query_id + assert query_id == response.info, \ + "Got different query ID in server response (%s) than we " \ + "had before (%s)" % (response.info, query_id) + else: + query_id = response.info + + try: + prepared_statement = self.session.cluster._prepared_statements[query_id] + except KeyError: + if not self.prepared_statement: + log.error("Tried to execute unknown prepared statement: id=%s", + query_id.encode('hex')) + self._set_final_exception(response) + return + else: + prepared_statement = self.prepared_statement + self.session.cluster._prepared_statements[query_id] = prepared_statement + + current_keyspace = self._connection.keyspace + prepared_keyspace = prepared_statement.keyspace + if prepared_keyspace and current_keyspace != prepared_keyspace: + self._set_final_exception( + ValueError("The Session's current keyspace (%s) does " + "not match the keyspace the statement was " + "prepared with (%s)" % + (current_keyspace, prepared_keyspace))) + return + + log.debug("Re-preparing unrecognized prepared statement against host %s: %s", + self._current_host, prepared_statement.query_string) + prepare_message = PrepareMessage(query=prepared_statement.query_string) + # since this might block, run on the executor to avoid hanging + # the event loop thread + 
self.session.submit(self._reprepare, prepare_message) + return + else: + if hasattr(response, 'to_exception'): + self._set_final_exception(response.to_exception()) + else: + self._set_final_exception(response) + return + + retry_type, consistency = retry + if retry_type is RetryPolicy.RETRY: + self._query_retries += 1 + self._retry(reuse_connection=True, consistency_level=consistency) + elif retry_type is RetryPolicy.RETHROW: + self._set_final_exception(response.to_exception()) + else: # IGNORE + if self._metrics is not None: + self._metrics.on_ignore() + self._set_final_result(None) + elif isinstance(response, ConnectionException): + if self._metrics is not None: + self._metrics.on_connection_error() + if not isinstance(response, ConnectionShutdown): + self._connection.defunct(response) + self._retry(reuse_connection=False, consistency_level=None) + elif isinstance(response, Exception): + if hasattr(response, 'to_exception'): + self._set_final_exception(response.to_exception()) + else: + self._set_final_exception(response) + else: + # we got some other kind of response message + msg = "Got unexpected message: %r" % (response,) + exc = ConnectionException(msg, self._current_host) + self._connection.defunct(exc) + self._set_final_exception(exc) + except Exception as exc: + # almost certainly caused by a bug, but we need to set something here + log.exception("Unexpected exception while handling result in ResponseFuture:") + self._set_final_exception(exc) + + def _set_keyspace_completed(self, errors): + if not errors: + self._set_final_result(None) + else: + self._set_final_exception(ConnectionException( + "Failed to set keyspace on all hosts: %s" % (errors,))) + + def _execute_after_prepare(self, response): + """ + Handle the response to our attempt to prepare a statement. + If it succeeded, run the original query again against the same host. 
+ """ + if self._current_pool and self._connection: + self._current_pool.return_connection(self._connection) + + if self._final_exception: + return + + if isinstance(response, ResultMessage): + if response.kind == RESULT_KIND_PREPARED: + # use self._query to re-use the same host and + # at the same time properly borrow the connection + request_id = self._query(self._current_host) + if request_id is None: + # this host errored out, move on to the next + self.send_request() + else: + self._set_final_exception(ConnectionException( + "Got unexpected response when preparing statement " + "on host %s: %s" % (self._current_host, response))) + elif isinstance(response, ErrorMessage): + self._set_final_exception(response) + elif isinstance(response, ConnectionException): + log.debug("Connection error when preparing statement on host %s: %s", + self._current_host, response) + # try again on a different host, preparing again if necessary + self._errors[self._current_host] = response + self.send_request() + else: + self._set_final_exception(ConnectionException( + "Got unexpected response type when preparing " + "statement on host %s: %s" % (self._current_host, response))) + + def _set_final_result(self, response): + if self._metrics is not None: + self._metrics.request_timer.addValue(time.time() - self._start_time) + + with self._callback_lock: + self._final_result = response + + self._event.set() + + # apply each callback + for callback in self._callbacks: + fn, args, kwargs = callback + fn(response, *args, **kwargs) + + def _set_final_exception(self, response): + if self._metrics is not None: + self._metrics.request_timer.addValue(time.time() - self._start_time) + + with self._callback_lock: + self._final_exception = response + self._event.set() + + for errback in self._errbacks: + fn, args, kwargs = errback + fn(response, *args, **kwargs) + + def _retry(self, reuse_connection, consistency_level): + if self._final_exception: + # the connection probably broke while we were waiting + # to retry the operation + return + + if self._metrics is not None: + self._metrics.on_retry() + if consistency_level is not None: + self.message.consistency_level = consistency_level + + # don't retry on the event loop thread + self.session.submit(self._retry_task, reuse_connection) + + def _retry_task(self, reuse_connection): + if self._final_exception: + # the connection probably broke while we were waiting + # to retry the operation + return + + if reuse_connection and self._query(self._current_host) is not None: + return + + # otherwise, move onto another host + self.send_request() + + def result(self, timeout=_NOT_SET): + """ + Return the final result or raise an Exception if errors were + encountered. If the final result or error has not been set + yet, this method will block until that time. + + You may set a timeout (in seconds) with the `timeout` parameter. + By default, the :attr:`~.default_timeout` for the :class:`.Session` + this was created through will be used for the timeout on this + operation. If the timeout is exceeded, an + :exc:`cassandra.OperationTimedOut` will be raised. + + Example usage:: + + >>> future = session.execute_async("SELECT * FROM mycf") + >>> # do other stuff... + + >>> try: + ... rows = future.result() + ... for row in rows: + ... ... # process results + ... except Exception: + ... 
log.exception("Operation failed:") + + """ + if timeout is _NOT_SET: + timeout = self.default_timeout + + if self._final_result is not _NOT_SET: + if self._paging_state is None: + return self._final_result + else: + return PagedResult(self, self._final_result, timeout) + elif self._final_exception: + raise self._final_exception + else: + self._event.wait(timeout=timeout) + if self._final_result is not _NOT_SET: + if self._paging_state is None: + return self._final_result + else: + return PagedResult(self, self._final_result, timeout) + elif self._final_exception: + raise self._final_exception + else: + raise OperationTimedOut(errors=self._errors, last_host=self._current_host) + + def get_query_trace(self, max_wait=None): + """ + Returns the :class:`~.query.QueryTrace` instance representing a trace + of the last attempt for this operation, or :const:`None` if tracing was + not enabled for this query. Note that this may raise an exception if + there are problems retrieving the trace details from Cassandra. If the + trace is not available after `max_wait` seconds, + :exc:`cassandra.query.TraceUnavailable` will be raised. + """ + if not self._query_trace: + return None + + self._query_trace.populate(max_wait) + return self._query_trace + + def add_callback(self, fn, *args, **kwargs): + """ + Attaches a callback function to be called when the final results arrive. + + By default, `fn` will be called with the results as the first and only + argument. If `*args` or `**kwargs` are supplied, they will be passed + through as additional positional or keyword arguments to `fn`. + + If an error is hit while executing the operation, a callback attached + here will not be called. Use :meth:`.add_errback()` or :meth:`add_callbacks()` + if you wish to handle that case. + + If the final result has already been seen when this method is called, + the callback will be called immediately (before this method returns). + + **Important**: if the callback you attach results in an exception being + raised, **the exception will be ignored**, so please ensure your + callback handles all error cases that you care about. + + Usage example:: + + >>> session = cluster.connect("mykeyspace") + + >>> def handle_results(rows, start_time, should_log=False): + ... if should_log: + ... log.info("Total time: %f", time.time() - start_time) + ... ... + + >>> future = session.execute_async("SELECT * FROM users") + >>> future.add_callback(handle_results, time.time(), should_log=True) + + """ + run_now = False + with self._callback_lock: + if self._final_result is not _NOT_SET: + run_now = True + else: + self._callbacks.append((fn, args, kwargs)) + if run_now: + fn(self._final_result, *args, **kwargs) + return self + + def add_errback(self, fn, *args, **kwargs): + """ + Like :meth:`.add_callback()`, but handles error cases. + An Exception instance will be passed as the first positional argument + to `fn`. + """ + run_now = False + with self._callback_lock: + if self._final_exception: + run_now = True + else: + self._errbacks.append((fn, args, kwargs)) + if run_now: + fn(self._final_exception, *args, **kwargs) + return self + + def add_callbacks(self, callback, errback, + callback_args=(), callback_kwargs=None, + errback_args=(), errback_kwargs=None): + """ + A convenient combination of :meth:`.add_callback()` and + :meth:`.add_errback()`. + + Example usage:: + + >>> session = cluster.connect() + >>> query = "SELECT * FROM mycf" + >>> future = session.execute_async(query) + + >>> def log_results(results, level='debug'): + ... 
for row in results: + ... log.log(level, "Result: %s", row) + + >>> def log_error(exc, query): + ... log.error("Query '%s' failed: %s", query, exc) + + >>> future.add_callbacks( + ... callback=log_results, callback_kwargs={'level': 'info'}, + ... errback=log_error, errback_args=(query,)) + + """ + self.add_callback(callback, *callback_args, **(callback_kwargs or {})) + self.add_errback(errback, *errback_args, **(errback_kwargs or {})) + + def clear_callbacks(self): + with self._callback_lock: + self._callbacks = [] + self._errbacks = [] + + def __str__(self): + result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result + return "<ResponseFuture: query='%s' request_id=%s result=%s exception=%s host=%s>" \ + % (self.query, self._req_id, result, self._final_exception, self._current_host) + __repr__ = __str__ + + +class QueryExhausted(Exception): + """ + Raised when :meth:`.ResponseFuture.start_fetching_next_page()` is called and + there are no more pages. You can check :attr:`.ResponseFuture.has_more_pages` + before calling to avoid this. + + .. versionadded:: 2.0.0 + """ + pass + + +class PagedResult(object): + """ + An iterator over the rows from a paged query result. Whenever the number + of result rows for a query exceeds the :attr:`~.query.Statement.fetch_size` + (or :attr:`~.Session.default_fetch_size`, if not set) an instance of this + class will be returned. + + You can treat this as a normal iterator over rows:: + + >>> from cassandra.query import SimpleStatement + >>> statement = SimpleStatement("SELECT * FROM users", fetch_size=10) + >>> for user_row in session.execute(statement): + ... process_user(user_row) + + Whenever there are no more rows in the current page, the next page will + be fetched transparently. However, note that it *is* possible for + an :class:`Exception` to be raised while fetching the next page, just + like you might see on a normal call to ``session.execute()``. + + .. versionadded:: 2.0.0 + """ + + response_future = None + + def __init__(self, response_future, initial_response, timeout=_NOT_SET): + self.response_future = response_future + self.current_response = iter(initial_response) + self.timeout = timeout + + def __iter__(self): + return self + + def next(self): + try: + return next(self.current_response) + except StopIteration: + if not self.response_future.has_more_pages: + raise + + self.response_future.start_fetching_next_page() + result = self.response_future.result(self.timeout) + if self.response_future.has_more_pages: + self.current_response = result.current_response + else: + self.current_response = iter(result) + + return next(self.current_response) + + __next__ = next diff --git a/cassandra/concurrent.py b/cassandra/concurrent.py new file mode 100644 index 0000000..8d7743e --- /dev/null +++ b/cassandra/concurrent.py @@ -0,0 +1,196 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
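(Aside, not part of the patch: a hedged usage sketch of the ResponseFuture API defined above, showing both the blocking result() path and the callback path. The contact point, keyspace, table and column names are made up for illustration.)

from cassandra.cluster import Cluster

cluster = Cluster(['127.0.0.1'])
session = cluster.connect('mykeyspace')   # hypothetical keyspace

# Blocking path: result() waits for the rows or raises the stored exception.
future = session.execute_async("SELECT name FROM users WHERE id=%s", (42,))
try:
    for row in future.result(timeout=10.0):
        print(row)
except Exception as exc:
    print("query failed: %r" % (exc,))

# Callback path: handlers run on driver threads, so keep them short and
# never call blocking driver methods such as result() from inside them.
def on_rows(rows):
    print("got %d rows" % len(rows))

def on_error(exc):
    print("query failed: %r" % (exc,))

session.execute_async("SELECT name FROM users").add_callbacks(on_rows, on_error)
cluster.shutdown()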
+ +import six +import sys + +from itertools import count, cycle +import logging +from six.moves import xrange +from threading import Event + +from cassandra.cluster import PagedResult + +log = logging.getLogger(__name__) + + +def execute_concurrent(session, statements_and_parameters, concurrency=100, raise_on_first_error=True): + """ + Executes a sequence of (statement, parameters) tuples concurrently. Each + ``parameters`` item must be a sequence or :const:`None`. + + A sequence of ``(success, result_or_exc)`` tuples is returned in the same + order that the statements were passed in. If ``success`` is :const:`False`, + there was an error executing the statement, and ``result_or_exc`` will be + an :class:`Exception`. If ``success`` is :const:`True`, ``result_or_exc`` + will be the query result. + + If `raise_on_first_error` is left as :const:`True`, execution will stop + after the first failed statement and the corresponding exception will be + raised. + + The `concurrency` parameter controls how many statements will be executed + concurrently. When :attr:`.Cluster.protocol_version` is set to 1 or 2, + it is recommended that this be kept below 100 times the number of + core connections per host times the number of connected hosts (see + :meth:`.Cluster.set_core_connections_per_host`). If that amount is exceeded, + the event loop thread may attempt to block on new connection creation, + substantially impacting throughput. If :attr:`~.Cluster.protocol_version` + is 3 or higher, you can safely experiment with higher levels of concurrency. + + Example usage:: + + select_statement = session.prepare("SELECT * FROM users WHERE id=?") + + statements_and_params = [] + for user_id in user_ids: + params = (user_id, ) + statements_and_params.append((select_statement, params)) + + results = execute_concurrent( + session, statements_and_params, raise_on_first_error=False) + + for (success, result) in results: + if not success: + handle_error(result) # result will be an Exception + else: + process_user(result[0]) # result will be a list of rows + + """ + if concurrency <= 0: + raise ValueError("concurrency must be greater than 0") + + if not statements_and_parameters: + return [] + + # TODO handle iterators and generators naturally without converting the + # whole thing to a list. This would require not building a result + # list of Nones up front (we don't know how many results there will be), + # so a dict keyed by index should be used instead. The tricky part is + # knowing when you're the final statement to finish. + statements_and_parameters = list(statements_and_parameters) + + event = Event() + first_error = [] if raise_on_first_error else None + to_execute = len(statements_and_parameters) + results = [None] * to_execute + num_finished = count(1) + statements = enumerate(iter(statements_and_parameters)) + for i in xrange(min(concurrency, len(statements_and_parameters))): + _execute_next(_sentinel, i, event, session, statements, results, None, num_finished, to_execute, first_error) + + event.wait() + if first_error: + exc = first_error[0] + if six.PY2 and isinstance(exc, tuple): + (exc_type, value, traceback) = exc + six.reraise(exc_type, value, traceback) + else: + raise exc + else: + return results + + +def execute_concurrent_with_args(session, statement, parameters, *args, **kwargs): + """ + Like :meth:`~cassandra.concurrent.execute_concurrent()`, but takes a single + statement and a sequence of parameters. Each item in ``parameters`` + should be a sequence or :const:`None`. 
+ + Example usage:: + + statement = session.prepare("INSERT INTO mytable (a, b) VALUES (1, ?)") + parameters = [(x,) for x in range(1000)] + execute_concurrent_with_args(session, statement, parameters, concurrency=50) + """ + return execute_concurrent(session, list(zip(cycle((statement,)), parameters)), *args, **kwargs) + + +_sentinel = object() + + +def _handle_error(error, result_index, event, session, statements, results, + future, num_finished, to_execute, first_error): + if first_error is not None: + first_error.append(error) + event.set() + return + else: + results[result_index] = (False, error) + if next(num_finished) >= to_execute: + event.set() + return + + try: + (next_index, (statement, params)) = next(statements) + except StopIteration: + return + + try: + future = session.execute_async(statement, params) + args = (next_index, event, session, statements, results, future, num_finished, to_execute, first_error) + future.add_callbacks( + callback=_execute_next, callback_args=args, + errback=_handle_error, errback_args=args) + except Exception as exc: + if first_error is not None: + if six.PY2: + first_error.append(sys.exc_info()) + else: + first_error.append(exc) + event.set() + return + else: + results[next_index] = (False, exc) + if next(num_finished) >= to_execute: + event.set() + return + + +def _execute_next(result, result_index, event, session, statements, results, + future, num_finished, to_execute, first_error): + if result is not _sentinel: + if future.has_more_pages: + result = PagedResult(future, result) + future.clear_callbacks() + results[result_index] = (True, result) + finished = next(num_finished) + if finished >= to_execute: + event.set() + return + + try: + (next_index, (statement, params)) = next(statements) + except StopIteration: + return + + try: + future = session.execute_async(statement, params) + args = (next_index, event, session, statements, results, future, num_finished, to_execute, first_error) + future.add_callbacks( + callback=_execute_next, callback_args=args, + errback=_handle_error, errback_args=args) + except Exception as exc: + if first_error is not None: + if six.PY2: + first_error.append(sys.exc_info()) + else: + first_error.append(exc) + event.set() + return + else: + results[next_index] = (False, exc) + if next(num_finished) >= to_execute: + event.set() + return diff --git a/cassandra/connection.py b/cassandra/connection.py new file mode 100644 index 0000000..2a020c0 --- /dev/null +++ b/cassandra/connection.py @@ -0,0 +1,856 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
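(Aside, not part of the patch: a hedged sketch of driving execute_concurrent_with_args() defined above and inspecting per-statement failures when raise_on_first_error is False. The contact point, keyspace and events table are made up for illustration.)

from cassandra.cluster import Cluster
from cassandra.concurrent import execute_concurrent_with_args

session = Cluster(['127.0.0.1']).connect('mykeyspace')   # hypothetical keyspace
insert = session.prepare("INSERT INTO events (id, payload) VALUES (?, ?)")

params = [(i, 'payload-%d' % i) for i in range(1000)]
results = execute_concurrent_with_args(
    session, insert, params, concurrency=50, raise_on_first_error=False)

# With raise_on_first_error=False each entry is a (success, result_or_exc)
# pair in input order, so failures can be inspected individually.
failures = [result for success, result in results if not success]
print("%d of %d inserts failed" % (len(failures), len(params)))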
+ +from __future__ import absolute_import # to enable import io from stdlib +from collections import defaultdict, deque +import errno +from functools import wraps, partial +import io +import logging +import os +import sys +from threading import Thread, Event, RLock +import time + +if 'gevent.monkey' in sys.modules: + from gevent.queue import Queue, Empty +else: + from six.moves.queue import Queue, Empty # noqa + +import six +from six.moves import range + +from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut +from cassandra.marshal import int32_pack, header_unpack, v3_header_unpack, int32_unpack +from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage, + StartupMessage, ErrorMessage, CredentialsMessage, + QueryMessage, ResultMessage, decode_response, + InvalidRequestException, SupportedMessage, + AuthResponseMessage, AuthChallengeMessage, + AuthSuccessMessage, ProtocolException) +from cassandra.util import OrderedDict + + +log = logging.getLogger(__name__) + +# We use an ordered dictionary and specifically add lz4 before +# snappy so that lz4 will be preferred. Changing the order of this +# will change the compression preferences for the driver. +locally_supported_compressions = OrderedDict() + +try: + import lz4 +except ImportError: + pass +else: + + # Cassandra writes the uncompressed message length in big endian order, + # but the lz4 lib requires little endian order, so we wrap these + # functions to handle that + + def lz4_compress(byts): + # write length in big-endian instead of little-endian + return int32_pack(len(byts)) + lz4.compress(byts)[4:] + + def lz4_decompress(byts): + # flip from big-endian to little-endian + return lz4.decompress(byts[3::-1] + byts[4:]) + + locally_supported_compressions['lz4'] = (lz4_compress, lz4_decompress) + +try: + import snappy +except ImportError: + pass +else: + # work around apparently buggy snappy decompress + def decompress(byts): + if byts == '\x00': + return '' + return snappy.decompress(byts) + locally_supported_compressions['snappy'] = (snappy.compress, decompress) + + +PROTOCOL_VERSION_MASK = 0x7f + +HEADER_DIRECTION_FROM_CLIENT = 0x00 +HEADER_DIRECTION_TO_CLIENT = 0x80 +HEADER_DIRECTION_MASK = 0x80 + +NONBLOCKING = (errno.EAGAIN, errno.EWOULDBLOCK) + + +class ConnectionException(Exception): + """ + An unrecoverable error was hit when attempting to use a connection, + or the connection was already closed or defunct. + """ + + def __init__(self, message, host=None): + Exception.__init__(self, message) + self.host = host + + +class ConnectionShutdown(ConnectionException): + """ + Raised when a connection has been marked as defunct or has been closed. + """ + pass + + +class ConnectionBusy(Exception): + """ + An attempt was made to send a message through a :class:`.Connection` that + was already at the max number of in-flight operations. + """ + pass + + +class ProtocolError(Exception): + """ + Communication did not match the protocol that this driver expects. 
+ """ + pass + + +def defunct_on_error(f): + + @wraps(f) + def wrapper(self, *args, **kwargs): + try: + return f(self, *args, **kwargs) + except Exception as exc: + self.defunct(exc) + return wrapper + + +DEFAULT_CQL_VERSION = '3.0.0' + + +class Connection(object): + + in_buffer_size = 4096 + out_buffer_size = 4096 + + cql_version = None + protocol_version = 2 + + keyspace = None + compression = True + compressor = None + decompressor = None + + ssl_options = None + last_error = None + + # The current number of operations that are in flight. More precisely, + # the number of request IDs that are currently in use. + in_flight = 0 + + # A set of available request IDs. When using the v3 protocol or higher, + # this will not initially include all request IDs in order to save memory, + # but the set will grow if it is exhausted. + request_ids = None + + # Tracks the highest used request ID in order to help with growing the + # request_ids set + highest_request_id = 0 + + is_defunct = False + is_closed = False + lock = None + user_type_map = None + + msg_received = False + + is_control_connection = False + _iobuf = None + + def __init__(self, host='127.0.0.1', port=9042, authenticator=None, + ssl_options=None, sockopts=None, compression=True, + cql_version=None, protocol_version=2, is_control_connection=False, + user_type_map=None): + self.host = host + self.port = port + self.authenticator = authenticator + self.ssl_options = ssl_options + self.sockopts = sockopts + self.compression = compression + self.cql_version = cql_version + self.protocol_version = protocol_version + self.is_control_connection = is_control_connection + self.user_type_map = user_type_map + self._push_watchers = defaultdict(set) + self._iobuf = io.BytesIO() + if protocol_version >= 3: + self._header_unpack = v3_header_unpack + self._header_length = 5 + self.max_request_id = (2 ** 15) - 1 + # Don't fill the deque with 2**15 items right away. Start with 300 and add + # more if needed. + self.request_ids = deque(range(300)) + self.highest_request_id = 299 + else: + self._header_unpack = header_unpack + self._header_length = 4 + self.max_request_id = (2 ** 7) - 1 + self.request_ids = deque(range(self.max_request_id + 1)) + self.highest_request_id = self.max_request_id + + # 0 8 16 24 32 40 + # +---------+---------+---------+---------+---------+ + # | version | flags | stream | opcode | + # +---------+---------+---------+---------+---------+ + # | length | + # +---------+---------+---------+---------+ + # | | + # . ... body ... . + # . . + # . . + # +---------------------------------------- + self._full_header_length = self._header_length + 4 + + self.lock = RLock() + + @classmethod + def initialize_reactor(self): + """ + Called once by Cluster.connect(). This should be used by implementations + to set up any resources that will be shared across connections. + """ + pass + + @classmethod + def handle_fork(self): + """ + Called after a forking. This should cleanup any remaining reactor state + from the parent process. 
+ """ + pass + + def close(self): + raise NotImplementedError() + + def defunct(self, exc): + with self.lock: + if self.is_defunct or self.is_closed: + return + self.is_defunct = True + + log.debug("Defuncting connection (%s) to %s:", + id(self), self.host, exc_info=exc) + + self.last_error = exc + self.close() + self.error_all_callbacks(exc) + self.connected_event.set() + return exc + + def error_all_callbacks(self, exc): + with self.lock: + callbacks = self._callbacks + self._callbacks = {} + new_exc = ConnectionShutdown(str(exc)) + for cb in callbacks.values(): + try: + cb(new_exc) + except Exception: + log.warning("Ignoring unhandled exception while erroring callbacks for a " + "failed connection (%s) to host %s:", + id(self), self.host, exc_info=True) + + def get_request_id(self): + """ + This must be called while self.lock is held. + """ + try: + return self.request_ids.popleft() + except IndexError: + self.highest_request_id += 1 + # in_flight checks should guarantee this + assert self.highest_request_id <= self.max_request_id + return self.highest_request_id + + def handle_pushed(self, response): + log.debug("Message pushed from server: %r", response) + for cb in self._push_watchers.get(response.event_type, []): + try: + cb(response.event_args) + except Exception: + log.exception("Pushed event handler errored, ignoring:") + + def send_msg(self, msg, request_id, cb): + if self.is_defunct: + raise ConnectionShutdown("Connection to %s is defunct" % self.host) + elif self.is_closed: + raise ConnectionShutdown("Connection to %s is closed" % self.host) + + self._callbacks[request_id] = cb + self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor)) + return request_id + + def wait_for_response(self, msg, timeout=None): + return self.wait_for_responses(msg, timeout=timeout)[0] + + def wait_for_responses(self, *msgs, **kwargs): + """ + Returns a list of (success, response) tuples. If success + is False, response will be an Exception. Otherwise, response + will be the normal query response. + + If fail_on_error was left as True and one of the requests + failed, the corresponding Exception will be raised. 
+ """ + if self.is_closed or self.is_defunct: + raise ConnectionShutdown("Connection %s is already closed" % (self, )) + timeout = kwargs.get('timeout') + fail_on_error = kwargs.get('fail_on_error', True) + waiter = ResponseWaiter(self, len(msgs), fail_on_error) + + # busy wait for sufficient space on the connection + messages_sent = 0 + while True: + needed = len(msgs) - messages_sent + with self.lock: + available = min(needed, self.max_request_id - self.in_flight) + request_ids = [self.get_request_id() for _ in range(available)] + self.in_flight += available + + for i, request_id in enumerate(request_ids): + self.send_msg(msgs[messages_sent + i], + request_id, + partial(waiter.got_response, index=messages_sent + i)) + messages_sent += available + + if messages_sent == len(msgs): + break + else: + if timeout is not None: + timeout -= 0.01 + if timeout <= 0.0: + raise OperationTimedOut() + time.sleep(0.01) + + try: + return waiter.deliver(timeout) + except OperationTimedOut: + raise + except Exception as exc: + self.defunct(exc) + raise + + def register_watcher(self, event_type, callback): + raise NotImplementedError() + + def register_watchers(self, type_callback_dict): + raise NotImplementedError() + + def control_conn_disposed(self): + self.is_control_connection = False + self._push_watchers = {} + + def process_io_buffer(self): + while True: + pos = self._iobuf.tell() + if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes): + # we don't have a complete header yet or we + # already saw a header, but we don't have a + # complete message yet + return + else: + # have enough for header, read body len from header + self._iobuf.seek(self._header_length) + body_len = int32_unpack(self._iobuf.read(4)) + + # seek to end to get length of current buffer + self._iobuf.seek(0, os.SEEK_END) + pos = self._iobuf.tell() + + if pos >= body_len + self._full_header_length: + # read message header and body + self._iobuf.seek(0) + msg = self._iobuf.read(self._full_header_length + body_len) + + # leave leftover in current buffer + leftover = self._iobuf.read() + self._iobuf = io.BytesIO() + self._iobuf.write(leftover) + + self._total_reqd_bytes = 0 + self.process_msg(msg, body_len) + else: + self._total_reqd_bytes = body_len + self._full_header_length + return + + @defunct_on_error + def process_msg(self, msg, body_len): + version, flags, stream_id, opcode = self._header_unpack(msg[:self._header_length]) + if stream_id < 0: + callback = None + else: + callback = self._callbacks.pop(stream_id, None) + with self.lock: + self.request_ids.append(stream_id) + + self.msg_received = True + + body = None + try: + # check that the protocol version is supported + given_version = version & PROTOCOL_VERSION_MASK + if given_version != self.protocol_version: + msg = "Server protocol version (%d) does not match the specified driver protocol version (%d). " +\ + "Consider setting Cluster.protocol_version to %d." 
+ raise ProtocolError(msg % (given_version, self.protocol_version, given_version)) + + # check that the header direction is correct + if version & HEADER_DIRECTION_MASK != HEADER_DIRECTION_TO_CLIENT: + raise ProtocolError( + "Header direction in response is incorrect; opcode %04x, stream id %r" + % (opcode, stream_id)) + + if body_len > 0: + body = msg[self._full_header_length:] + elif body_len == 0: + body = six.binary_type() + else: + raise ProtocolError("Got negative body length: %r" % body_len) + + response = decode_response(given_version, self.user_type_map, stream_id, + flags, opcode, body, self.decompressor) + except Exception as exc: + log.exception("Error decoding response from Cassandra. " + "opcode: %04x; message contents: %r", opcode, msg) + if callback is not None: + callback(exc) + self.defunct(exc) + return + + try: + if stream_id >= 0: + if isinstance(response, ProtocolException): + log.error("Closing connection %s due to protocol error: %s", self, response.summary_msg()) + self.defunct(response) + if callback is not None: + callback(response) + else: + self.handle_pushed(response) + except Exception: + log.exception("Callback handler errored, ignoring:") + + @defunct_on_error + def _send_options_message(self): + if self.cql_version is None and (not self.compression or not locally_supported_compressions): + log.debug("Not sending options message for new connection(%s) to %s " + "because compression is disabled and a cql version was not " + "specified", id(self), self.host) + self._compressor = None + self.cql_version = DEFAULT_CQL_VERSION + self._send_startup_message() + else: + log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.host) + self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response) + + @defunct_on_error + def _handle_options_response(self, options_response): + if self.is_defunct: + return + + if not isinstance(options_response, SupportedMessage): + if isinstance(options_response, ConnectionException): + raise options_response + else: + log.error("Did not get expected SupportedMessage response; " + "instead, got: %s", options_response) + raise ConnectionException("Did not get expected SupportedMessage " + "response; instead, got: %s" + % (options_response,)) + + log.debug("Received options response on new connection (%s) from %s", + id(self), self.host) + supported_cql_versions = options_response.cql_versions + remote_supported_compressions = options_response.options['COMPRESSION'] + + if self.cql_version: + if self.cql_version not in supported_cql_versions: + raise ProtocolError( + "cql_version %r is not supported by remote (w/ native " + "protocol). Supported versions: %r" + % (self.cql_version, supported_cql_versions)) + else: + self.cql_version = supported_cql_versions[0] + + self._compressor = None + compression_type = None + if self.compression: + overlap = (set(locally_supported_compressions.keys()) & + set(remote_supported_compressions)) + if len(overlap) == 0: + log.debug("No available compression types supported on both ends." + " locally supported: %r. 
remotely supported: %r", + locally_supported_compressions.keys(), + remote_supported_compressions) + else: + compression_type = None + if isinstance(self.compression, six.string_types): + # the user picked a specific compression type ('snappy' or 'lz4') + if self.compression not in remote_supported_compressions: + raise ProtocolError( + "The requested compression type (%s) is not supported by the Cassandra server at %s" + % (self.compression, self.host)) + compression_type = self.compression + else: + # our locally supported compressions are ordered to prefer + # lz4, if available + for k in locally_supported_compressions.keys(): + if k in overlap: + compression_type = k + break + + # set the decompressor here, but set the compressor only after + # a successful Ready message + self._compressor, self.decompressor = \ + locally_supported_compressions[compression_type] + + self._send_startup_message(compression_type) + + @defunct_on_error + def _send_startup_message(self, compression=None): + log.debug("Sending StartupMessage on %s", self) + opts = {} + if compression: + opts['COMPRESSION'] = compression + sm = StartupMessage(cqlversion=self.cql_version, options=opts) + self.send_msg(sm, self.get_request_id(), cb=self._handle_startup_response) + log.debug("Sent StartupMessage on %s", self) + + @defunct_on_error + def _handle_startup_response(self, startup_response, did_authenticate=False): + if self.is_defunct: + return + if isinstance(startup_response, ReadyMessage): + log.debug("Got ReadyMessage on new connection (%s) from %s", id(self), self.host) + if self._compressor: + self.compressor = self._compressor + self.connected_event.set() + elif isinstance(startup_response, AuthenticateMessage): + log.debug("Got AuthenticateMessage on new connection (%s) from %s: %s", + id(self), self.host, startup_response.authenticator) + + if self.authenticator is None: + raise AuthenticationFailed('Remote end requires authentication.') + + self.authenticator_class = startup_response.authenticator + + if isinstance(self.authenticator, dict): + log.debug("Sending credentials-based auth response on %s", self) + cm = CredentialsMessage(creds=self.authenticator) + callback = partial(self._handle_startup_response, did_authenticate=True) + self.send_msg(cm, self.get_request_id(), cb=callback) + else: + log.debug("Sending SASL-based auth response on %s", self) + initial_response = self.authenticator.initial_response() + initial_response = "" if initial_response is None else initial_response + self.send_msg(AuthResponseMessage(initial_response), self.get_request_id(), self._handle_auth_response) + elif isinstance(startup_response, ErrorMessage): + log.debug("Received ErrorMessage on new connection (%s) from %s: %s", + id(self), self.host, startup_response.summary_msg()) + if did_authenticate: + raise AuthenticationFailed( + "Failed to authenticate to %s: %s" % + (self.host, startup_response.summary_msg())) + else: + raise ConnectionException( + "Failed to initialize new connection to %s: %s" + % (self.host, startup_response.summary_msg())) + elif isinstance(startup_response, ConnectionShutdown): + log.debug("Connection to %s was closed during the startup handshake", (self.host)) + raise startup_response + else: + msg = "Unexpected response during Connection setup: %r" + log.error(msg, startup_response) + raise ProtocolError(msg % (startup_response,)) + + @defunct_on_error + def _handle_auth_response(self, auth_response): + if self.is_defunct: + return + + if isinstance(auth_response, AuthSuccessMessage): + 
log.debug("Connection %s successfully authenticated", self) + self.authenticator.on_authentication_success(auth_response.token) + if self._compressor: + self.compressor = self._compressor + self.connected_event.set() + elif isinstance(auth_response, AuthChallengeMessage): + response = self.authenticator.evaluate_challenge(auth_response.challenge) + msg = AuthResponseMessage("" if response is None else response) + log.debug("Responding to auth challenge on %s", self) + self.send_msg(msg, self.get_request_id(), self._handle_auth_response) + elif isinstance(auth_response, ErrorMessage): + log.debug("Received ErrorMessage on new connection (%s) from %s: %s", + id(self), self.host, auth_response.summary_msg()) + raise AuthenticationFailed( + "Failed to authenticate to %s: %s" % + (self.host, auth_response.summary_msg())) + elif isinstance(auth_response, ConnectionShutdown): + log.debug("Connection to %s was closed during the authentication process", self.host) + raise auth_response + else: + msg = "Unexpected response during Connection authentication to %s: %r" + log.error(msg, self.host, auth_response) + raise ProtocolError(msg % (self.host, auth_response)) + + def set_keyspace_blocking(self, keyspace): + if not keyspace or keyspace == self.keyspace: + return + + query = QueryMessage(query='USE "%s"' % (keyspace,), + consistency_level=ConsistencyLevel.ONE) + try: + result = self.wait_for_response(query) + except InvalidRequestException as ire: + # the keyspace probably doesn't exist + raise ire.to_exception() + except Exception as exc: + conn_exc = ConnectionException( + "Problem while setting keyspace: %r" % (exc,), self.host) + self.defunct(conn_exc) + raise conn_exc + + if isinstance(result, ResultMessage): + self.keyspace = keyspace + else: + conn_exc = ConnectionException( + "Problem while setting keyspace: %r" % (result,), self.host) + self.defunct(conn_exc) + raise conn_exc + + def set_keyspace_async(self, keyspace, callback): + """ + Use this in order to avoid deadlocking the event loop thread. + When the operation completes, `callback` will be called with + two arguments: this connection and an Exception if an error + occurred, otherwise :const:`None`. 
+ """ + if not keyspace or keyspace == self.keyspace: + callback(self, None) + return + + query = QueryMessage(query='USE "%s"' % (keyspace,), + consistency_level=ConsistencyLevel.ONE) + + def process_result(result): + if isinstance(result, ResultMessage): + self.keyspace = keyspace + callback(self, None) + elif isinstance(result, InvalidRequestException): + callback(self, result.to_exception()) + else: + callback(self, self.defunct(ConnectionException( + "Problem while setting keyspace: %r" % (result,), self.host))) + + request_id = None + # we use a busy wait on the lock here because: + # - we'll only spin if the connection is at max capacity, which is very + # unlikely for a set_keyspace call + # - it allows us to avoid signaling a condition every time a request completes + while True: + with self.lock: + if self.in_flight < self.max_request_id: + request_id = self.get_request_id() + self.in_flight += 1 + break + + time.sleep(0.001) + + self.send_msg(query, request_id, process_result) + + @property + def is_idle(self): + return not self.msg_received + + def reset_idle(self): + self.msg_received = False + + def __str__(self): + status = "" + if self.is_defunct: + status = " (defunct)" + elif self.is_closed: + status = " (closed)" + + return "<%s(%r) %s:%d%s>" % (self.__class__.__name__, id(self), self.host, self.port, status) + __repr__ = __str__ + + +class ResponseWaiter(object): + + def __init__(self, connection, num_responses, fail_on_error): + self.connection = connection + self.pending = num_responses + self.fail_on_error = fail_on_error + self.error = None + self.responses = [None] * num_responses + self.event = Event() + + def got_response(self, response, index): + with self.connection.lock: + self.connection.in_flight -= 1 + if isinstance(response, Exception): + if hasattr(response, 'to_exception'): + response = response.to_exception() + if self.fail_on_error: + self.error = response + self.event.set() + else: + self.responses[index] = (False, response) + else: + if not self.fail_on_error: + self.responses[index] = (True, response) + else: + self.responses[index] = response + + self.pending -= 1 + if not self.pending: + self.event.set() + + def deliver(self, timeout=None): + """ + If fail_on_error was set to False, a list of (success, response) + tuples will be returned. If success is False, response will be + an Exception. Otherwise, response will be the normal query response. + + If fail_on_error was left as True and one of the requests + failed, the corresponding Exception will be raised. Otherwise, + the normal response will be returned. 
+ """ + self.event.wait(timeout) + if self.error: + raise self.error + elif not self.event.is_set(): + raise OperationTimedOut() + else: + return self.responses + + +class HeartbeatFuture(object): + def __init__(self, connection, owner): + self._exception = None + self._event = Event() + self.connection = connection + self.owner = owner + log.debug("Sending options message heartbeat on idle connection (%s) %s", + id(connection), connection.host) + with connection.lock: + if connection.in_flight < connection.max_request_id: + connection.in_flight += 1 + connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback) + else: + self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold") + self._event.set() + + def wait(self, timeout): + self._event.wait(timeout) + if self._event.is_set(): + if self._exception: + raise self._exception + else: + raise OperationTimedOut() + + def _options_callback(self, response): + if not isinstance(response, SupportedMessage): + if isinstance(response, ConnectionException): + self._exception = response + else: + self._exception = ConnectionException("Received unexpected response to OptionsMessage: %s" + % (response,)) + + log.debug("Received options response on connection (%s) from %s", + id(self.connection), self.connection.host) + self._event.set() + + +class ConnectionHeartbeat(Thread): + + def __init__(self, interval_sec, get_connection_holders): + Thread.__init__(self, name="Connection heartbeat") + self._interval = interval_sec + self._get_connection_holders = get_connection_holders + self._shutdown_event = Event() + self.daemon = True + self.start() + + class ShutdownException(Exception): + pass + + def run(self): + self._shutdown_event.wait(self._interval) + while not self._shutdown_event.is_set(): + start_time = time.time() + + futures = [] + failed_connections = [] + try: + for connections, owner in [(o.get_connections(), o) for o in self._get_connection_holders()]: + for connection in connections: + self._raise_if_stopped() + if not (connection.is_defunct or connection.is_closed): + if connection.is_idle: + try: + futures.append(HeartbeatFuture(connection, owner)) + except Exception: + log.warning("Failed sending heartbeat message on connection (%s) to %s", + id(connection), connection.host, exc_info=True) + failed_connections.append((connection, owner)) + else: + connection.reset_idle() + else: + # make sure the owner sees this defunt/closed connection + owner.return_connection(connection) + self._raise_if_stopped() + + for f in futures: + self._raise_if_stopped() + connection = f.connection + try: + f.wait(self._interval) + # TODO: move this, along with connection locks in pool, down into Connection + with connection.lock: + connection.in_flight -= 1 + connection.reset_idle() + except Exception: + log.warning("Heartbeat failed for connection (%s) to %s", + id(connection), connection.host, exc_info=True) + failed_connections.append((f.connection, f.owner)) + + for connection, owner in failed_connections: + self._raise_if_stopped() + connection.defunct(Exception('Connection heartbeat failure')) + owner.return_connection(connection) + except self.ShutdownException: + pass + except Exception: + log.error("Failed connection heartbeat", exc_info=True) + + elapsed = time.time() - start_time + self._shutdown_event.wait(max(self._interval - elapsed, 0.01)) + + def stop(self): + self._shutdown_event.set() + self.join() + + def _raise_if_stopped(self): + if self._shutdown_event.is_set(): 
+ raise self.ShutdownException() diff --git a/cassandra/cqlengine/__init__.py b/cassandra/cqlengine/__init__.py new file mode 100644 index 0000000..38a02fd --- /dev/null +++ b/cassandra/cqlengine/__init__.py @@ -0,0 +1,41 @@ +# Copyright 2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six + + +# compaction +SizeTieredCompactionStrategy = "SizeTieredCompactionStrategy" +LeveledCompactionStrategy = "LeveledCompactionStrategy" + +# Caching constants. +CACHING_ALL = "ALL" +CACHING_KEYS_ONLY = "KEYS_ONLY" +CACHING_ROWS_ONLY = "ROWS_ONLY" +CACHING_NONE = "NONE" + + +class CQLEngineException(Exception): + pass + + +class ValidationError(CQLEngineException): + pass + + +class UnicodeMixin(object): + if six.PY3: + __str__ = lambda x: x.__unicode__() + else: + __str__ = lambda x: six.text_type(x).encode('utf-8') diff --git a/cassandra/cqlengine/columns.py b/cassandra/cqlengine/columns.py new file mode 100644 index 0000000..ec6a468 --- /dev/null +++ b/cassandra/cqlengine/columns.py @@ -0,0 +1,926 @@ +# Copyright 2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy, copy +from datetime import date, datetime +import logging +import re +import six +import warnings + +from cassandra.cqltypes import DateType +from cassandra.encoder import cql_quote + +from cassandra.cqlengine import ValidationError + +log = logging.getLogger(__name__) + + +class BaseValueManager(object): + + def __init__(self, instance, column, value): + self.instance = instance + self.column = column + self.previous_value = deepcopy(value) + self.value = value + self.explicit = False + + @property + def deleted(self): + return self.value is None and self.previous_value is not None + + @property + def changed(self): + """ + Indicates whether or not this value has changed. 
+ + :rtype: boolean + + """ + return self.value != self.previous_value + + def reset_previous_value(self): + self.previous_value = copy(self.value) + + def getval(self): + return self.value + + def setval(self, val): + self.value = val + + def delval(self): + self.value = None + + def get_property(self): + _get = lambda slf: self.getval() + _set = lambda slf, val: self.setval(val) + _del = lambda slf: self.delval() + + if self.column.can_delete: + return property(_get, _set, _del) + else: + return property(_get, _set) + + +class ValueQuoter(object): + """ + contains a single value, which will quote itself for CQL insertion statements + """ + def __init__(self, value): + self.value = value + + def __str__(self): + raise NotImplementedError + + def __repr__(self): + return self.__str__() + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.value == other.value + return False + + +class Column(object): + + # the cassandra type this column maps to + db_type = None + value_manager = BaseValueManager + + instance_counter = 0 + + primary_key = False + """ + bool flag, indicates this column is a primary key. The first primary key defined + on a model is the partition key (unless partition keys are set), all others are cluster keys + """ + + partition_key = False + + """ + indicates that this column should be the partition key, defining + more than one partition key column creates a compound partition key + """ + + index = False + """ + bool flag, indicates an index should be created for this column + """ + + db_field = None + """ + the fieldname this field will map to in the database + """ + + default = None + """ + the default value, can be a value or a callable (no args) + """ + + required = False + """ + boolean, is the field required? Model validation will raise and + exception if required is set to True and there is a None value assigned + """ + + clustering_order = None + """ + only applicable on clustering keys (primary keys that are not partition keys) + determines the order that the clustering keys are sorted on disk + """ + + polymorphic_key = False + """ + *Deprecated* + + see :attr:`~.discriminator_column` + """ + + discriminator_column = False + """ + boolean, if set to True, this column will be used for discriminating records + of inherited models. + + Should only be set on a column of an abstract model being used for inheritance. + + There may only be one discriminator column per model. See :attr:`~.__discriminator_value__` + for how to specify the value of this column on specialized models. + """ + + static = False + """ + boolean, if set to True, this is a static column, with a single value per partition + """ + + def __init__(self, + primary_key=False, + partition_key=False, + index=False, + db_field=None, + default=None, + required=False, + clustering_order=None, + polymorphic_key=False, + discriminator_column=False, + static=False): + self.partition_key = partition_key + self.primary_key = partition_key or primary_key + self.index = index + self.db_field = db_field + self.default = default + self.required = required + self.clustering_order = clustering_order + + if polymorphic_key: + msg = "polymorphic_key is deprecated. Use discriminator_column instead." 
+ warnings.warn(msg, DeprecationWarning) + log.warning(msg) + + self.discriminator_column = discriminator_column or polymorphic_key + self.polymorphic_key = self.discriminator_column + + # the column name in the model definition + self.column_name = None + self.static = static + + self.value = None + + # keep track of instantiation order + self.position = Column.instance_counter + Column.instance_counter += 1 + + def validate(self, value): + """ + Returns a cleaned and validated value. Raises a ValidationError + if there's a problem + """ + if value is None: + if self.required: + raise ValidationError('{} - None values are not allowed'.format(self.column_name or self.db_field)) + return value + + def to_python(self, value): + """ + Converts data from the database into python values + raises a ValidationError if the value can't be converted + """ + return value + + def to_database(self, value): + """ + Converts python value into database value + """ + if value is None and self.has_default: + return self.get_default() + return value + + @property + def has_default(self): + return self.default is not None + + @property + def is_primary_key(self): + return self.primary_key + + @property + def can_delete(self): + return not self.primary_key + + def get_default(self): + if self.has_default: + if callable(self.default): + return self.default() + else: + return self.default + + def get_column_def(self): + """ + Returns a column definition for CQL table definition + """ + static = "static" if self.static else "" + return '{} {} {}'.format(self.cql, self.db_type, static) + + # TODO: make columns use cqltypes under the hood + # until then, this bridges the gap in using types along with cassandra.metadata for CQL generation + def cql_parameterized_type(self): + return self.db_type + + def set_column_name(self, name): + """ + Sets the column name during document class construction + This value will be ignored if db_field is set in __init__ + """ + self.column_name = name + + @property + def db_field_name(self): + """ Returns the name of the cql name of this column """ + return self.db_field or self.column_name + + @property + def db_index_name(self): + """ Returns the name of the cql index """ + return 'index_{}'.format(self.db_field_name) + + @property + def cql(self): + return self.get_cql() + + def get_cql(self): + return '"{}"'.format(self.db_field_name) + + def _val_is_null(self, val): + """ determines if the given value equates to a null value for the given column type """ + return val is None + + @property + def sub_columns(self): + return [] + + +class Blob(Column): + """ + Stores a raw binary value + """ + db_type = 'blob' + + def to_database(self, value): + + if not isinstance(value, (six.binary_type, bytearray)): + raise Exception("expecting a binary, got a %s" % type(value)) + + val = super(Bytes, self).to_database(value) + return bytearray(val) + + def to_python(self, value): + return value + +Bytes = Blob + + +class Ascii(Column): + """ + Stores a US-ASCII character string + """ + db_type = 'ascii' + + +class Inet(Column): + """ + Stores an IP address in IPv4 or IPv6 format + """ + db_type = 'inet' + + +class Text(Column): + """ + Stores a UTF-8 encoded string + """ + db_type = 'text' + + def __init__(self, min_length=None, max_length=None, **kwargs): + """ + :param int min_length: Sets the minimum length of this string, for validation purposes. + Defaults to 1 if this is a ``required`` column. Otherwise, None. 
+ :param int max_lemgth: Sets the maximum length of this string, for validation purposes. + """ + self.min_length = min_length or (1 if kwargs.get('required', False) else None) + self.max_length = max_length + super(Text, self).__init__(**kwargs) + + def validate(self, value): + value = super(Text, self).validate(value) + if value is None: + return + if not isinstance(value, (six.string_types, bytearray)) and value is not None: + raise ValidationError('{} {} is not a string'.format(self.column_name, type(value))) + if self.max_length: + if len(value) > self.max_length: + raise ValidationError('{} is longer than {} characters'.format(self.column_name, self.max_length)) + if self.min_length: + if len(value) < self.min_length: + raise ValidationError('{} is shorter than {} characters'.format(self.column_name, self.min_length)) + return value + + +class Integer(Column): + """ + Stores a 32-bit signed integer value + """ + + db_type = 'int' + + def validate(self, value): + val = super(Integer, self).validate(value) + if val is None: + return + try: + return int(val) + except (TypeError, ValueError): + raise ValidationError("{} {} can't be converted to integral value".format(self.column_name, value)) + + def to_python(self, value): + return self.validate(value) + + def to_database(self, value): + return self.validate(value) + + +class BigInt(Integer): + """ + Stores a 64-bit signed long value + """ + db_type = 'bigint' + + +class VarInt(Column): + """ + Stores an arbitrary-precision integer + """ + db_type = 'varint' + + def validate(self, value): + val = super(VarInt, self).validate(value) + if val is None: + return + try: + return int(val) + except (TypeError, ValueError): + raise ValidationError( + "{} {} can't be converted to integral value".format(self.column_name, value)) + + def to_python(self, value): + return self.validate(value) + + def to_database(self, value): + return self.validate(value) + + +class CounterValueManager(BaseValueManager): + def __init__(self, instance, column, value): + super(CounterValueManager, self).__init__(instance, column, value) + self.value = self.value or 0 + self.previous_value = self.previous_value or 0 + + +class Counter(Integer): + """ + Stores a counter that can be inremented and decremented + """ + db_type = 'counter' + + value_manager = CounterValueManager + + def __init__(self, + index=False, + db_field=None, + required=False): + super(Counter, self).__init__( + primary_key=False, + partition_key=False, + index=index, + db_field=db_field, + default=0, + required=required, + ) + + +class DateTime(Column): + """ + Stores a datetime value + """ + db_type = 'timestamp' + + def to_python(self, value): + if value is None: + return + if isinstance(value, datetime): + return value + elif isinstance(value, date): + return datetime(*(value.timetuple()[:6])) + try: + return datetime.utcfromtimestamp(value) + except TypeError: + return datetime.utcfromtimestamp(DateType.deserialize(value)) + + def to_database(self, value): + value = super(DateTime, self).to_database(value) + if value is None: + return + if not isinstance(value, datetime): + if isinstance(value, date): + value = datetime(value.year, value.month, value.day) + else: + raise ValidationError("{} '{}' is not a datetime object".format(self.column_name, value)) + epoch = datetime(1970, 1, 1, tzinfo=value.tzinfo) + offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 + + return int(((value - epoch).total_seconds() - offset) * 1000) + + +class Date(Column): + """ + *Note: this 
type is overloaded, and will likely be changed or removed to accommodate distinct date type + in a future version* + + Stores a date value, with no time-of-day + """ + db_type = 'timestamp' + + def to_python(self, value): + if value is None: + return + if isinstance(value, datetime): + return value.date() + elif isinstance(value, date): + return value + try: + return datetime.utcfromtimestamp(value).date() + except TypeError: + return datetime.utcfromtimestamp(DateType.deserialize(value)).date() + + def to_database(self, value): + value = super(Date, self).to_database(value) + if value is None: + return + if isinstance(value, datetime): + value = value.date() + if not isinstance(value, date): + raise ValidationError("{} '{}' is not a date object".format(self.column_name, repr(value))) + + return int((value - date(1970, 1, 1)).total_seconds() * 1000) + + +class UUID(Column): + """ + Stores a type 1 or 4 UUID + """ + db_type = 'uuid' + + re_uuid = re.compile(r'[0-9a-f]{8}-?[0-9a-f]{4}-?[0-9a-f]{4}-?[0-9a-f]{4}-?[0-9a-f]{12}') + + def validate(self, value): + val = super(UUID, self).validate(value) + if val is None: + return + from uuid import UUID as _UUID + if isinstance(val, _UUID): + return val + if isinstance(val, six.string_types) and self.re_uuid.match(val): + return _UUID(val) + raise ValidationError("{} {} is not a valid uuid".format(self.column_name, value)) + + def to_python(self, value): + return self.validate(value) + + def to_database(self, value): + return self.validate(value) + +from uuid import UUID as pyUUID, getnode + + +class TimeUUID(UUID): + """ + UUID containing timestamp + """ + + db_type = 'timeuuid' + + @classmethod + def from_datetime(self, dt): + """ + generates a UUID for a given datetime + + :param dt: datetime + :type dt: datetime + :return: + """ + global _last_timestamp + + epoch = datetime(1970, 1, 1, tzinfo=dt.tzinfo) + offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 + timestamp = (dt - epoch).total_seconds() - offset + + node = None + clock_seq = None + + nanoseconds = int(timestamp * 1e9) + timestamp = int(nanoseconds // 100) + 0x01b21dd213814000 + + if clock_seq is None: + import random + clock_seq = random.randrange(1 << 14) # instead of stable storage + time_low = timestamp & 0xffffffff + time_mid = (timestamp >> 32) & 0xffff + time_hi_version = (timestamp >> 48) & 0x0fff + clock_seq_low = clock_seq & 0xff + clock_seq_hi_variant = (clock_seq >> 8) & 0x3f + if node is None: + node = getnode() + return pyUUID(fields=(time_low, time_mid, time_hi_version, + clock_seq_hi_variant, clock_seq_low, node), version=1) + + +class Boolean(Column): + """ + Stores a boolean True or False value + """ + db_type = 'boolean' + + def validate(self, value): + """ Always returns a Python boolean. 
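+
+        For example (illustrative only)::
+
+            Boolean().validate(1)      # True
+            Boolean().validate(0)      # False
+            Boolean().validate(None)   # None passes through for optional columns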
""" + value = super(Boolean, self).validate(value) + + if value is not None: + value = bool(value) + + return value + + def to_python(self, value): + return self.validate(value) + + +class Float(Column): + """ + Stores a floating point value + """ + db_type = 'double' + + def __init__(self, double_precision=True, **kwargs): + self.db_type = 'double' if double_precision else 'float' + super(Float, self).__init__(**kwargs) + + def validate(self, value): + value = super(Float, self).validate(value) + if value is None: + return + try: + return float(value) + except (TypeError, ValueError): + raise ValidationError("{} {} is not a valid float".format(self.column_name, value)) + + def to_python(self, value): + return self.validate(value) + + def to_database(self, value): + return self.validate(value) + + +class Decimal(Column): + """ + Stores a variable precision decimal value + """ + db_type = 'decimal' + + def validate(self, value): + from decimal import Decimal as _Decimal + from decimal import InvalidOperation + val = super(Decimal, self).validate(value) + if val is None: + return + try: + return _Decimal(val) + except InvalidOperation: + raise ValidationError("{} '{}' can't be coerced to decimal".format(self.column_name, val)) + + def to_python(self, value): + return self.validate(value) + + def to_database(self, value): + return self.validate(value) + + +class BaseContainerColumn(Column): + """ + Base Container type for collection-like columns. + + https://cassandra.apache.org/doc/cql3/CQL.html#collections + """ + + def __init__(self, value_type, **kwargs): + """ + :param value_type: a column class indicating the types of the value + """ + inheritance_comparator = issubclass if isinstance(value_type, type) else isinstance + if not inheritance_comparator(value_type, Column): + raise ValidationError('value_type must be a column class') + if inheritance_comparator(value_type, BaseContainerColumn): + raise ValidationError('container types cannot be nested') + if value_type.db_type is None: + raise ValidationError('value_type cannot be an abstract column type') + + if isinstance(value_type, type): + self.value_type = value_type + self.value_col = self.value_type() + else: + self.value_col = value_type + self.value_type = self.value_col.__class__ + + super(BaseContainerColumn, self).__init__(**kwargs) + + def validate(self, value): + value = super(BaseContainerColumn, self).validate(value) + # It is dangerous to let collections have more than 65535. 
+ # See: https://issues.apache.org/jira/browse/CASSANDRA-5428 + if value is not None and len(value) > 65535: + raise ValidationError("{} Collection can't have more than 65535 elements.".format(self.column_name)) + return value + + def _val_is_null(self, val): + return not val + + @property + def sub_columns(self): + return [self.value_col] + + +class BaseContainerQuoter(ValueQuoter): + + def __nonzero__(self): + return bool(self.value) + + +class Set(BaseContainerColumn): + """ + Stores a set of unordered, unique values + + http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_set_t.html + """ + class Quoter(BaseContainerQuoter): + + def __str__(self): + cq = cql_quote + return '{' + ', '.join([cq(v) for v in self.value]) + '}' + + def __init__(self, value_type, strict=True, default=set, **kwargs): + """ + :param value_type: a column class indicating the types of the value + :param strict: sets whether non set values will be coerced to set + type on validation, or raise a validation error, defaults to True + """ + self.strict = strict + self.db_type = 'set<{}>'.format(value_type.db_type) + super(Set, self).__init__(value_type, default=default, **kwargs) + + def validate(self, value): + val = super(Set, self).validate(value) + if val is None: + return + types = (set,) if self.strict else (set, list, tuple) + if not isinstance(val, types): + if self.strict: + raise ValidationError('{} {} is not a set object'.format(self.column_name, val)) + else: + raise ValidationError('{} {} cannot be coerced to a set object'.format(self.column_name, val)) + + if None in val: + raise ValidationError("{} None not allowed in a set".format(self.column_name)) + + return {self.value_col.validate(v) for v in val} + + def to_python(self, value): + if value is None: + return set() + return {self.value_col.to_python(v) for v in value} + + def to_database(self, value): + if value is None: + return None + + if isinstance(value, self.Quoter): + return value + return self.Quoter({self.value_col.to_database(v) for v in value}) + + +class List(BaseContainerColumn): + """ + Stores a list of ordered values + + http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_list_t.html + """ + class Quoter(BaseContainerQuoter): + + def __str__(self): + cq = cql_quote + return '[' + ', '.join([cq(v) for v in self.value]) + ']' + + def __nonzero__(self): + return bool(self.value) + + def __init__(self, value_type, default=list, **kwargs): + """ + :param value_type: a column class indicating the types of the value + """ + self.db_type = 'list<{}>'.format(value_type.db_type) + return super(List, self).__init__(value_type=value_type, default=default, **kwargs) + + def validate(self, value): + val = super(List, self).validate(value) + if val is None: + return + if not isinstance(val, (set, list, tuple)): + raise ValidationError('{} {} is not a list object'.format(self.column_name, val)) + if None in val: + raise ValidationError("{} None is not allowed in a list".format(self.column_name)) + return [self.value_col.validate(v) for v in val] + + def to_python(self, value): + if value is None: + return [] + return [self.value_col.to_python(v) for v in value] + + def to_database(self, value): + if value is None: + return None + if isinstance(value, self.Quoter): + return value + return self.Quoter([self.value_col.to_database(v) for v in value]) + + +class Map(BaseContainerColumn): + """ + Stores a key -> value map (dictionary) + + http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_map_t.html + """ + class 
Quoter(BaseContainerQuoter): + + def __str__(self): + cq = cql_quote + return '{' + ', '.join([cq(k) + ':' + cq(v) for k, v in self.value.items()]) + '}' + + def get(self, key): + return self.value.get(key) + + def keys(self): + return self.value.keys() + + def items(self): + return self.value.items() + + def __init__(self, key_type, value_type, default=dict, **kwargs): + """ + :param key_type: a column class indicating the types of the key + :param value_type: a column class indicating the types of the value + """ + + self.db_type = 'map<{}, {}>'.format(key_type.db_type, value_type.db_type) + + inheritance_comparator = issubclass if isinstance(key_type, type) else isinstance + if not inheritance_comparator(key_type, Column): + raise ValidationError('key_type must be a column class') + if inheritance_comparator(key_type, BaseContainerColumn): + raise ValidationError('container types cannot be nested') + if key_type.db_type is None: + raise ValidationError('key_type cannot be an abstract column type') + + if isinstance(key_type, type): + self.key_type = key_type + self.key_col = self.key_type() + else: + self.key_col = key_type + self.key_type = self.key_col.__class__ + super(Map, self).__init__(value_type, default=default, **kwargs) + + def validate(self, value): + val = super(Map, self).validate(value) + if val is None: + return + if not isinstance(val, dict): + raise ValidationError('{} {} is not a dict object'.format(self.column_name, val)) + return {self.key_col.validate(k): self.value_col.validate(v) for k, v in val.items()} + + def to_python(self, value): + if value is None: + return {} + if value is not None: + return {self.key_col.to_python(k): self.value_col.to_python(v) for k, v in value.items()} + + def to_database(self, value): + if value is None: + return None + if isinstance(value, self.Quoter): + return value + return self.Quoter({self.key_col.to_database(k): self.value_col.to_database(v) for k, v in value.items()}) + + @property + def sub_columns(self): + return [self.key_col, self.value_col] + + +class UDTValueManager(BaseValueManager): + @property + def changed(self): + return self.value != self.previous_value or self.value.has_changed_fields() + + def reset_previous_value(self): + self.value.reset_changed_fields() + self.previous_value = copy(self.value) + + +class UserDefinedType(Column): + """ + User Defined Type column + + http://www.datastax.com/documentation/cql/3.1/cql/cql_using/cqlUseUDT.html + + These columns are represented by a specialization of :class:`cassandra.cqlengine.usertype.UserType`. + + Please see :ref:`user_types` for examples and discussion. + """ + + value_manager = UDTValueManager + + def __init__(self, user_type, **kwargs): + """ + :param type user_type: specifies the :class:`~.UserType` model of the column + """ + self.user_type = user_type + self.db_type = "frozen<%s>" % user_type.type_name() + super(UserDefinedType, self).__init__(**kwargs) + + @property + def sub_columns(self): + return list(self.user_type._fields.values()) + + +def resolve_udts(col_def, out_list): + for col in col_def.sub_columns: + resolve_udts(col, out_list) + if isinstance(col_def, UserDefinedType): + out_list.append(col_def.user_type) + + +class _PartitionKeysToken(Column): + """ + virtual column representing token of partition columns. 
+ Used by filter(pk__token=Token(...)) filters + """ + + def __init__(self, model): + self.partition_columns = model._partition_keys.values() + super(_PartitionKeysToken, self).__init__(partition_key=True) + + @property + def db_field_name(self): + return 'token({})'.format(', '.join(['"{}"'.format(c.db_field_name) for c in self.partition_columns])) + + def to_database(self, value): + from cqlengine.functions import Token + assert isinstance(value, Token) + value.set_columns(self.partition_columns) + return value + + def get_cql(self): + return "token({})".format(", ".join(c.cql for c in self.partition_columns)) diff --git a/cassandra/cqlengine/connection.py b/cassandra/cqlengine/connection.py new file mode 100644 index 0000000..ead241c --- /dev/null +++ b/cassandra/cqlengine/connection.py @@ -0,0 +1,215 @@ +# Copyright 2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +import logging +import six + +from cassandra import ConsistencyLevel +from cassandra.cluster import Cluster, _NOT_SET, NoHostAvailable, UserTypeDoesNotExist +from cassandra.query import SimpleStatement, Statement, dict_factory + +from cassandra.cqlengine import CQLEngineException +from cassandra.cqlengine.statements import BaseCQLStatement + + +log = logging.getLogger(__name__) + +NOT_SET = _NOT_SET # required for passing timeout to Session.execute + +Host = namedtuple('Host', ['name', 'port']) + +cluster = None +session = None +lazy_connect_args = None +default_consistency_level = ConsistencyLevel.ONE + + +# Because type models may be registered before a connection is present, +# and because sessions may be replaced, we must register UDTs here, in order +# to have them registered when a new session is established. +udt_by_keyspace = {} + + +class UndefinedKeyspaceException(CQLEngineException): + pass + + +def default(): + """ + Configures the global mapper connection to localhost, using the driver defaults + (except for row_factory) + """ + global cluster, session + + if session: + log.warning("configuring new connection for cqlengine when one was already set") + + cluster = Cluster() + session = cluster.connect() + session.row_factory = dict_factory + + _register_known_types(cluster) + + log.debug("cqlengine connection initialized with default session to localhost") + + +def set_session(s): + """ + Configures the global mapper connection with a preexisting :class:`cassandra.cluster.Session` + + Note: the mapper presently requires a Session :attr:`~.row_factory` set to ``dict_factory``. 
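+
+    A minimal way to satisfy that requirement (illustrative sketch)::
+
+        from cassandra.cluster import Cluster
+        from cassandra.query import dict_factory
+
+        session = Cluster().connect()
+        session.row_factory = dict_factory
+        set_session(session)
+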
+ This may be relaxed in the future + """ + global cluster, session + + if session: + log.warning("configuring new connection for cqlengine when one was already set") + + if s.row_factory is not dict_factory: + raise CQLEngineException("Failed to initialize: 'Session.row_factory' must be 'dict_factory'.") + session = s + cluster = s.cluster + + _register_known_types(cluster) + + log.debug("cqlengine connection initialized with %s", s) + + +def setup( + hosts, + default_keyspace, + consistency=ConsistencyLevel.ONE, + lazy_connect=False, + retry_connect=False, + **kwargs): + """ + Setup a the driver connection used by the mapper + + :param list hosts: list of hosts, (``contact_points`` for :class:`cassandra.cluster.Cluster`) + :param str default_keyspace: The default keyspace to use + :param int consistency: The global default :class:`~.ConsistencyLevel` + :param bool lazy_connect: True if should not connect until first use + :param bool retry_connect: True if we should retry to connect even if there was a connection failure initially + :param \*\*kwargs: Pass-through keyword arguments for :class:`cassandra.cluster.Cluster` + """ + global cluster, session, default_consistency_level, lazy_connect_args + + if 'username' in kwargs or 'password' in kwargs: + raise CQLEngineException("Username & Password are now handled by using the native driver's auth_provider") + + if not default_keyspace: + raise UndefinedKeyspaceException() + + from cassandra.cqlengine import models + models.DEFAULT_KEYSPACE = default_keyspace + + default_consistency_level = consistency + if lazy_connect: + kwargs['default_keyspace'] = default_keyspace + kwargs['consistency'] = consistency + kwargs['lazy_connect'] = False + kwargs['retry_connect'] = retry_connect + lazy_connect_args = (hosts, kwargs) + return + + cluster = Cluster(hosts, **kwargs) + try: + session = cluster.connect() + log.debug("cqlengine connection initialized with internally created session") + except NoHostAvailable: + if retry_connect: + log.warning("connect failed, setting up for re-attempt on first use") + kwargs['default_keyspace'] = default_keyspace + kwargs['consistency'] = consistency + kwargs['lazy_connect'] = False + kwargs['retry_connect'] = retry_connect + lazy_connect_args = (hosts, kwargs) + raise + session.row_factory = dict_factory + + _register_known_types(cluster) + + +def execute(query, params=None, consistency_level=None, timeout=NOT_SET): + + handle_lazy_connect() + + if not session: + raise CQLEngineException("It is required to setup() cqlengine before executing queries") + + if consistency_level is None: + consistency_level = default_consistency_level + + if isinstance(query, Statement): + pass + + elif isinstance(query, BaseCQLStatement): + params = query.get_context() + query = str(query) + query = SimpleStatement(query, consistency_level=consistency_level) + + elif isinstance(query, six.string_types): + query = SimpleStatement(query, consistency_level=consistency_level) + + log.debug(query.query_string) + + params = params or {} + result = session.execute(query, params, timeout=timeout) + + return result + + +def get_session(): + handle_lazy_connect() + return session + + +def get_cluster(): + handle_lazy_connect() + if not cluster: + raise CQLEngineException("%s.cluster is not configured. Call one of the setup or default functions first." 
% __name__) + return cluster + + +def handle_lazy_connect(): + global lazy_connect_args + if lazy_connect_args: + log.debug("lazy connect") + hosts, kwargs = lazy_connect_args + lazy_connect_args = None + setup(hosts, **kwargs) + + +def register_udt(keyspace, type_name, klass): + try: + udt_by_keyspace[keyspace][type_name] = klass + except KeyError: + udt_by_keyspace[keyspace] = {type_name: klass} + + global cluster + if cluster: + try: + cluster.register_user_type(keyspace, type_name, klass) + except UserTypeDoesNotExist: + pass # new types are covered in management sync functions + + +def _register_known_types(cluster): + for ks_name, name_type_map in udt_by_keyspace.items(): + for type_name, klass in name_type_map.items(): + try: + cluster.register_user_type(ks_name, type_name, klass) + except UserTypeDoesNotExist: + pass # new types are covered in management sync functions diff --git a/cassandra/cqlengine/functions.py b/cassandra/cqlengine/functions.py new file mode 100644 index 0000000..5c6e4c7 --- /dev/null +++ b/cassandra/cqlengine/functions.py @@ -0,0 +1,133 @@ +# Copyright 2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime + +from cassandra.cqlengine import UnicodeMixin, ValidationError + + +class QueryValue(UnicodeMixin): + """ + Base class for query filter values. Subclasses of these classes can + be passed into .filter() keyword args + """ + + format_string = '%({})s' + + def __init__(self, value): + self.value = value + self.context_id = None + + def __unicode__(self): + return self.format_string.format(self.context_id) + + def set_context_id(self, ctx_id): + self.context_id = ctx_id + + def get_context_size(self): + return 1 + + def update_context(self, ctx): + ctx[str(self.context_id)] = self.value + + +class BaseQueryFunction(QueryValue): + """ + Base class for filtering functions. 
Subclasses of these classes can + be passed into .filter() and will be translated into CQL functions in + the resulting query + """ + pass + + +class MinTimeUUID(BaseQueryFunction): + """ + return a fake timeuuid corresponding to the smallest possible timeuuid for the given timestamp + + http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun + """ + + format_string = 'MinTimeUUID(%({})s)' + + def __init__(self, value): + """ + :param value: the time to create a maximum time uuid from + :type value: datetime + """ + if not isinstance(value, datetime): + raise ValidationError('datetime instance is required') + super(MinTimeUUID, self).__init__(value) + + def to_database(self, val): + epoch = datetime(1970, 1, 1, tzinfo=val.tzinfo) + offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 + return int(((val - epoch).total_seconds() - offset) * 1000) + + def update_context(self, ctx): + ctx[str(self.context_id)] = self.to_database(self.value) + + +class MaxTimeUUID(BaseQueryFunction): + """ + return a fake timeuuid corresponding to the largest possible timeuuid for the given timestamp + + http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun + """ + + format_string = 'MaxTimeUUID(%({})s)' + + def __init__(self, value): + """ + :param value: the time to create a minimum time uuid from + :type value: datetime + """ + if not isinstance(value, datetime): + raise ValidationError('datetime instance is required') + super(MaxTimeUUID, self).__init__(value) + + def to_database(self, val): + epoch = datetime(1970, 1, 1, tzinfo=val.tzinfo) + offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 + return int(((val - epoch).total_seconds() - offset) * 1000) + + def update_context(self, ctx): + ctx[str(self.context_id)] = self.to_database(self.value) + + +class Token(BaseQueryFunction): + """ + compute the token for a given partition key + + http://cassandra.apache.org/doc/cql3/CQL.html#tokenFun + """ + + def __init__(self, *values): + if len(values) == 1 and isinstance(values[0], (list, tuple)): + values = values[0] + super(Token, self).__init__(values) + self._columns = None + + def set_columns(self, columns): + self._columns = columns + + def get_context_size(self): + return len(self.value) + + def __unicode__(self): + token_args = ', '.join('%({})s'.format(self.context_id + i) for i in range(self.get_context_size())) + return "token({})".format(token_args) + + def update_context(self, ctx): + for i, (col, val) in enumerate(zip(self._columns, self.value)): + ctx[str(self.context_id + i)] = col.to_database(val) diff --git a/cassandra/cqlengine/management.py b/cassandra/cqlengine/management.py new file mode 100644 index 0000000..7ea642a --- /dev/null +++ b/cassandra/cqlengine/management.py @@ -0,0 +1,538 @@ +# Copyright 2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
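+
+# Editorial usage sketch (not part of the original patch): the query functions
+# defined in cqlengine/functions.py above are passed as *values* to filter()
+# keyword arguments.  The model and column names below (EventLog, day,
+# created_at, pk, last_seen_pk) are hypothetical.
+#
+#     from datetime import datetime
+#     from cassandra.cqlengine.functions import MinTimeUUID, MaxTimeUUID, Token
+#
+#     start, end = datetime(2015, 1, 1), datetime(2015, 1, 2)
+#     EventLog.objects.filter(day='2015-01-01',
+#                             created_at__gt=MinTimeUUID(start),
+#                             created_at__lt=MaxTimeUUID(end))
+#
+#     # token() paging: resume scanning partitions after the last key seen
+#     EventLog.objects.filter(pk__token__gt=Token(last_seen_pk))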
+ +from collections import namedtuple +import json +import logging +import os +import six +import warnings + +from cassandra import metadata +from cassandra.cqlengine import CQLEngineException, SizeTieredCompactionStrategy, LeveledCompactionStrategy +from cassandra.cqlengine import columns +from cassandra.cqlengine.connection import execute, get_cluster +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.named import NamedTable +from cassandra.cqlengine.usertype import UserType + +CQLENG_ALLOW_SCHEMA_MANAGEMENT = 'CQLENG_ALLOW_SCHEMA_MANAGEMENT' + +Field = namedtuple('Field', ['name', 'type']) + +log = logging.getLogger(__name__) + +# system keyspaces +schema_columnfamilies = NamedTable('system', 'schema_columnfamilies') + + +def create_keyspace(name, strategy_class, replication_factor, durable_writes=True, **replication_values): + """ + *Deprecated - use :func:`create_keyspace_simple` or :func:`create_keyspace_network_topology` instead* + + Creates a keyspace + + If the keyspace already exists, it will not be modified. + + **This function should be used with caution, especially in production environments. + Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** + + *There are plans to guard schema-modifying functions with an environment-driven conditional.* + + :param str name: name of keyspace to create + :param str strategy_class: keyspace replication strategy class (:attr:`~.SimpleStrategy` or :attr:`~.NetworkTopologyStrategy` + :param int replication_factor: keyspace replication factor, used with :attr:`~.SimpleStrategy` + :param bool durable_writes: Write log is bypassed if set to False + :param \*\*replication_values: Additional values to ad to the replication options map + """ + if not _allow_schema_modification(): + return + + msg = "Deprecated. Use create_keyspace_simple or create_keyspace_network_topology instead" + warnings.warn(msg, DeprecationWarning) + log.warning(msg) + + cluster = get_cluster() + + if name not in cluster.metadata.keyspaces: + # try the 1.2 method + replication_map = { + 'class': strategy_class, + 'replication_factor': replication_factor + } + replication_map.update(replication_values) + if strategy_class.lower() != 'simplestrategy': + # Although the Cassandra documentation states for `replication_factor` + # that it is "Required if class is SimpleStrategy; otherwise, + # not used." we get an error if it is present. + replication_map.pop('replication_factor', None) + + query = """ + CREATE KEYSPACE {} + WITH REPLICATION = {} + """.format(name, json.dumps(replication_map).replace('"', "'")) + + if strategy_class != 'SimpleStrategy': + query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false') + + execute(query) + + +def create_keyspace_simple(name, replication_factor, durable_writes=True): + """ + Creates a keyspace with SimpleStrategy for replica placement + + If the keyspace already exists, it will not be modified. + + **This function should be used with caution, especially in production environments. + Take care to execute schema modifications in a single context (i.e. 
not concurrently with other clients).** + + *There are plans to guard schema-modifying functions with an environment-driven conditional.* + + :param str name: name of keyspace to create + :param int replication_factor: keyspace replication factor, used with :attr:`~.SimpleStrategy` + :param bool durable_writes: Write log is bypassed if set to False + """ + _create_keyspace(name, durable_writes, 'SimpleStrategy', + {'replication_factor': replication_factor}) + + +def create_keyspace_network_topology(name, dc_replication_map, durable_writes=True): + """ + Creates a keyspace with NetworkTopologyStrategy for replica placement + + If the keyspace already exists, it will not be modified. + + **This function should be used with caution, especially in production environments. + Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** + + *There are plans to guard schema-modifying functions with an environment-driven conditional.* + + :param str name: name of keyspace to create + :param dict dc_replication_map: map of dc_names: replication_factor + :param bool durable_writes: Write log is bypassed if set to False + """ + _create_keyspace(name, durable_writes, 'NetworkTopologyStrategy', dc_replication_map) + + +def _create_keyspace(name, durable_writes, strategy_class, strategy_options): + if not _allow_schema_modification(): + return + + cluster = get_cluster() + + if name not in cluster.metadata.keyspaces: + log.info("Creating keyspace %s ", name) + ks_meta = metadata.KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) + execute(ks_meta.as_cql_query()) + else: + log.info("Not creating keyspace %s because it already exists", name) + + +def delete_keyspace(name): + msg = "Deprecated. Use drop_keyspace instead" + warnings.warn(msg, DeprecationWarning) + log.warning(msg) + drop_keyspace(name) + + +def drop_keyspace(name): + """ + Drops a keyspace, if it exists. + + *There are plans to guard schema-modifying functions with an environment-driven conditional.* + + **This function should be used with caution, especially in production environments. + Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** + + :param str name: name of keyspace to drop + """ + if not _allow_schema_modification(): + return + + cluster = get_cluster() + if name in cluster.metadata.keyspaces: + execute("DROP KEYSPACE {}".format(name)) + + +def sync_table(model): + """ + Inspects the model and creates / updates the corresponding table and columns. + + Any User Defined Types used in the table are implicitly synchronized. + + This function can only add fields that are not part of the primary key. + + Note that the attributes removed from the model are not deleted on the database. + They become effectively ignored by (will not show up on) the model. + + **This function should be used with caution, especially in production environments. + Take care to execute schema modifications in a single context (i.e. 
not concurrently with other clients).** + + *There are plans to guard schema-modifying functions with an environment-driven conditional.* + """ + if not _allow_schema_modification(): + return + + if not issubclass(model, Model): + raise CQLEngineException("Models must be derived from base Model.") + + if model.__abstract__: + raise CQLEngineException("cannot create table from abstract model") + + # construct query string + cf_name = model.column_family_name() + raw_cf_name = model.column_family_name(include_keyspace=False) + + ks_name = model._get_keyspace() + + cluster = get_cluster() + + keyspace = cluster.metadata.keyspaces[ks_name] + tables = keyspace.tables + + syncd_types = set() + for col in model._columns.values(): + udts = [] + columns.resolve_udts(col, udts) + for udt in [u for u in udts if u not in syncd_types]: + _sync_type(ks_name, udt, syncd_types) + + # check for an existing column family + if raw_cf_name not in tables: + log.debug("sync_table creating new table %s", cf_name) + qs = get_create_table(model) + + try: + execute(qs) + except CQLEngineException as ex: + # 1.2 doesn't return cf names, so we have to examine the exception + # and ignore if it says the column family already exists + if "Cannot add already existing column family" not in unicode(ex): + raise + else: + log.debug("sync_table checking existing table %s", cf_name) + # see if we're missing any columns + fields = get_fields(model) + field_names = [x.name for x in fields] + model_fields = set() + # # TODO: does this work with db_name?? + for name, col in model._columns.items(): + if col.primary_key or col.partition_key: + continue # we can't mess with the PK + model_fields.add(name) + if col.db_field_name in field_names: + continue # skip columns already defined + + # add missing column using the column def + query = "ALTER TABLE {} add {}".format(cf_name, col.get_column_def()) + execute(query) + + db_fields_not_in_model = model_fields.symmetric_difference(field_names) + if db_fields_not_in_model: + log.info("Table %s has fields not referenced by model: %s", cf_name, db_fields_not_in_model) + + update_compaction(model) + + table = cluster.metadata.keyspaces[ks_name].tables[raw_cf_name] + + indexes = [c for n, c in model._columns.items() if c.index] + + for column in indexes: + if table.columns[column.db_field_name].index: + continue + + qs = ['CREATE INDEX index_{}_{}'.format(raw_cf_name, column.db_field_name)] + qs += ['ON {}'.format(cf_name)] + qs += ['("{}")'.format(column.db_field_name)] + qs = ' '.join(qs) + execute(qs) + + +def sync_type(ks_name, type_model): + """ + Inspects the type_model and creates / updates the corresponding type. + + Note that the attributes removed from the type_model are not deleted on the database (this operation is not supported). + They become effectively ignored by (will not show up on) the type_model. + + **This function should be used with caution, especially in production environments. + Take care to execute schema modifications in a single context (i.e. 
not concurrently with other clients).** + + *There are plans to guard schema-modifying functions with an environment-driven conditional.* + """ + if not _allow_schema_modification(): + return + + if not issubclass(type_model, UserType): + raise CQLEngineException("Types must be derived from base UserType.") + + _sync_type(ks_name, type_model) + + +def _sync_type(ks_name, type_model, omit_subtypes=None): + + syncd_sub_types = omit_subtypes or set() + for field in type_model._fields.values(): + udts = [] + columns.resolve_udts(field, udts) + for udt in [u for u in udts if u not in syncd_sub_types]: + _sync_type(ks_name, udt, syncd_sub_types) + syncd_sub_types.add(udt) + + type_name = type_model.type_name() + type_name_qualified = "%s.%s" % (ks_name, type_name) + + cluster = get_cluster() + + keyspace = cluster.metadata.keyspaces[ks_name] + defined_types = keyspace.user_types + + if type_name not in defined_types: + log.debug("sync_type creating new type %s", type_name_qualified) + cql = get_create_type(type_model, ks_name) + execute(cql) + cluster.refresh_schema(keyspace=ks_name, usertype=type_name) + type_model.register_for_keyspace(ks_name) + else: + defined_fields = defined_types[type_name].field_names + model_fields = set() + for field in type_model._fields.values(): + model_fields.add(field.db_field_name) + if field.db_field_name not in defined_fields: + execute("ALTER TYPE {} ADD {}".format(type_name_qualified, field.get_column_def())) + + if len(defined_fields) == len(model_fields): + log.info("Type %s did not require synchronization", type_name_qualified) + return + + db_fields_not_in_model = model_fields.symmetric_difference(defined_fields) + if db_fields_not_in_model: + log.info("Type %s has fields not referenced by model: %s", type_name_qualified, db_fields_not_in_model) + + type_model.register_for_keyspace(ks_name) + + +def get_create_type(type_model, keyspace): + type_meta = metadata.UserType(keyspace, + type_model.type_name(), + (f.db_field_name for f in type_model._fields.values()), + type_model._fields.values()) + return type_meta.as_cql_query() + + +def get_create_table(model): + cf_name = model.column_family_name() + qs = ['CREATE TABLE {}'.format(cf_name)] + + # add column types + pkeys = [] # primary keys + ckeys = [] # clustering keys + qtypes = [] # field types + + def add_column(col): + s = col.get_column_def() + if col.primary_key: + keys = (pkeys if col.partition_key else ckeys) + keys.append('"{}"'.format(col.db_field_name)) + qtypes.append(s) + + for name, col in model._columns.items(): + add_column(col) + + qtypes.append('PRIMARY KEY (({}){})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or '')) + + qs += ['({})'.format(', '.join(qtypes))] + + with_qs = [] + + table_properties = ['bloom_filter_fp_chance', 'caching', 'comment', + 'dclocal_read_repair_chance', 'default_time_to_live', 'gc_grace_seconds', + 'index_interval', 'memtable_flush_period_in_ms', 'populate_io_cache_on_flush', + 'read_repair_chance', 'replicate_on_write'] + for prop_name in table_properties: + prop_value = getattr(model, '__{}__'.format(prop_name), None) + if prop_value is not None: + # Strings needs to be single quoted + if isinstance(prop_value, six.string_types): + prop_value = "'{}'".format(prop_value) + with_qs.append("{} = {}".format(prop_name, prop_value)) + + _order = ['"{}" {}'.format(c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()] + if _order: + with_qs.append('clustering order by ({})'.format(', '.join(_order))) + + 
compaction_options = get_compaction_options(model) + if compaction_options: + compaction_options = json.dumps(compaction_options).replace('"', "'") + with_qs.append("compaction = {}".format(compaction_options)) + + # Add table properties. + if with_qs: + qs += ['WITH {}'.format(' AND '.join(with_qs))] + + qs = ' '.join(qs) + return qs + + +def get_compaction_options(model): + """ + Generates dictionary (later converted to a string) for creating and altering + tables with compaction strategy + + :param model: + :return: + """ + if not model.__compaction__: + return {} + + result = {'class': model.__compaction__} + + def setter(key, limited_to_strategy=None): + """ + sets key in result, checking if the key is limited to either SizeTiered or Leveled + :param key: one of the compaction options, like "bucket_high" + :param limited_to_strategy: SizeTieredCompactionStrategy, LeveledCompactionStrategy + :return: + """ + mkey = "__compaction_{}__".format(key) + tmp = getattr(model, mkey) + if tmp and limited_to_strategy and limited_to_strategy != model.__compaction__: + raise CQLEngineException("{} is limited to {}".format(key, limited_to_strategy)) + + if tmp: + # Explicitly cast the values to strings to be able to compare the + # values against introspected values from Cassandra. + result[key] = str(tmp) + + setter('tombstone_compaction_interval') + setter('tombstone_threshold') + + setter('bucket_high', SizeTieredCompactionStrategy) + setter('bucket_low', SizeTieredCompactionStrategy) + setter('max_threshold', SizeTieredCompactionStrategy) + setter('min_threshold', SizeTieredCompactionStrategy) + setter('min_sstable_size', SizeTieredCompactionStrategy) + + setter('sstable_size_in_mb', LeveledCompactionStrategy) + + return result + + +def get_fields(model): + # returns all fields that aren't part of the PK + ks_name = model._get_keyspace() + col_family = model.column_family_name(include_keyspace=False) + field_types = ['regular', 'static'] + query = "select * from system.schema_columns where keyspace_name = %s and columnfamily_name = %s" + tmp = execute(query, [ks_name, col_family]) + + # Tables containing only primary keys do not appear to create + # any entries in system.schema_columns, as only non-primary-key attributes + # appear to be inserted into the schema_columns table + try: + return [Field(x['column_name'], x['validator']) for x in tmp if x['type'] in field_types] + except KeyError: + return [Field(x['column_name'], x['validator']) for x in tmp] + # convert to Field named tuples + + +def get_table_settings(model): + # returns the table as provided by the native driver for a given model + cluster = get_cluster() + ks = model._get_keyspace() + table = model.column_family_name(include_keyspace=False) + table = cluster.metadata.keyspaces[ks].tables[table] + return table + + +def update_compaction(model): + """Updates the compaction options for the given model if necessary. + + :param model: The model to update. + + :return: `True`, if the compaction options were modified in Cassandra, + `False` otherwise. 
+ :rtype: bool + """ + log.debug("Checking %s for compaction differences", model) + table = get_table_settings(model) + + existing_options = table.options.copy() + + existing_compaction_strategy = existing_options['compaction_strategy_class'] + + existing_options = json.loads(existing_options['compaction_strategy_options']) + + desired_options = get_compaction_options(model) + + desired_compact_strategy = desired_options.get('class', SizeTieredCompactionStrategy) + + desired_options.pop('class', None) + + do_update = False + + if desired_compact_strategy not in existing_compaction_strategy: + do_update = True + + for k, v in desired_options.items(): + val = existing_options.pop(k, None) + if val != v: + do_update = True + + # check compaction_strategy_options + if do_update: + options = get_compaction_options(model) + # jsonify + options = json.dumps(options).replace('"', "'") + cf_name = model.column_family_name() + query = "ALTER TABLE {} with compaction = {}".format(cf_name, options) + execute(query) + return True + + return False + + +def drop_table(model): + """ + Drops the table indicated by the model, if it exists. + + **This function should be used with caution, especially in production environments. + Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** + + *There are plans to guard schema-modifying functions with an environment-driven conditional.* + """ + if not _allow_schema_modification(): + return + + # don't try to delete non existant tables + meta = get_cluster().metadata + + ks_name = model._get_keyspace() + raw_cf_name = model.column_family_name(include_keyspace=False) + + try: + meta.keyspaces[ks_name].tables[raw_cf_name] + execute('drop table {};'.format(model.column_family_name(include_keyspace=True))) + except KeyError: + pass + + +def _allow_schema_modification(): + if not os.getenv(CQLENG_ALLOW_SCHEMA_MANAGEMENT): + msg = CQLENG_ALLOW_SCHEMA_MANAGEMENT + " environment variable is not set. Future versions of this package will require this variable to enable management functions." + warnings.warn(msg) + log.warning(msg) + + return True diff --git a/cassandra/cqlengine/models.py b/cassandra/cqlengine/models.py new file mode 100644 index 0000000..35d10bb --- /dev/null +++ b/cassandra/cqlengine/models.py @@ -0,0 +1,976 @@ +# Copyright 2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
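+
+# Editorial usage sketch (not part of the original patch): a typical end-to-end
+# mapper workflow tying together connection.py, management.py and the Model
+# class defined below.  The keyspace and model names are hypothetical.
+#
+#     import os
+#     import uuid
+#
+#     from cassandra.cqlengine import columns, connection
+#     from cassandra.cqlengine.management import create_keyspace_simple, sync_table
+#     from cassandra.cqlengine.models import Model
+#
+#     class User(Model):
+#         __keyspace__ = 'example'
+#         user_id = columns.UUID(primary_key=True, default=uuid.uuid4)
+#         name = columns.Text(required=True)
+#
+#     # opt in to schema management (see _allow_schema_modification above)
+#     os.environ['CQLENG_ALLOW_SCHEMA_MANAGEMENT'] = '1'
+#     connection.setup(['127.0.0.1'], default_keyspace='example')
+#     create_keyspace_simple('example', replication_factor=1)
+#     sync_table(User)
+#
+#     user = User.create(name='Jon')
+#     print(User.objects(user_id=user.user_id).first())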
+ +import logging +import re +import six +import warnings + +from cassandra.cqlengine import CQLEngineException, ValidationError +from cassandra.cqlengine import columns +from cassandra.cqlengine import connection +from cassandra.cqlengine import query +from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist +from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned +from cassandra.util import OrderedDict + +log = logging.getLogger(__name__) + + +class ModelException(CQLEngineException): + pass + + +class ModelDefinitionException(ModelException): + pass + + +class PolymorphicModelException(ModelException): + pass + + +class UndefinedKeyspaceWarning(Warning): + pass + +DEFAULT_KEYSPACE = None + + +class hybrid_classmethod(object): + """ + Allows a method to behave as both a class method and + normal instance method depending on how it's called + """ + def __init__(self, clsmethod, instmethod): + self.clsmethod = clsmethod + self.instmethod = instmethod + + def __get__(self, instance, owner): + if instance is None: + return self.clsmethod.__get__(owner, owner) + else: + return self.instmethod.__get__(instance, owner) + + def __call__(self, *args, **kwargs): + """ + Just a hint to IDEs that it's ok to call this + """ + raise NotImplementedError + + +class QuerySetDescriptor(object): + """ + returns a fresh queryset for the given model + it's declared on everytime it's accessed + """ + + def __get__(self, obj, model): + """ :rtype: ModelQuerySet """ + if model.__abstract__: + raise CQLEngineException('cannot execute queries against abstract models') + queryset = model.__queryset__(model) + + # if this is a concrete polymorphic model, and the discriminator + # key is an indexed column, add a filter clause to only return + # logical rows of the proper type + if model._is_polymorphic and not model._is_polymorphic_base: + name, column = model._discriminator_column_name, model._discriminator_column + if column.partition_key or column.index: + # look for existing poly types + return queryset.filter(**{name: model.__discriminator_value__}) + + return queryset + + def __call__(self, *args, **kwargs): + """ + Just a hint to IDEs that it's ok to call this + + :rtype: ModelQuerySet + """ + raise NotImplementedError + + +class TransactionDescriptor(object): + """ + returns a query set descriptor + """ + def __get__(self, instance, model): + if instance: + def transaction_setter(*prepared_transaction, **unprepared_transactions): + if len(prepared_transaction) > 0: + transactions = prepared_transaction[0] + else: + transactions = instance.objects.iff(**unprepared_transactions)._transaction + instance._transaction = transactions + return instance + + return transaction_setter + qs = model.__queryset__(model) + + def transaction_setter(**unprepared_transactions): + transactions = model.objects.iff(**unprepared_transactions)._transaction + qs._transaction = transactions + return qs + return transaction_setter + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + +class TTLDescriptor(object): + """ + returns a query set descriptor + """ + def __get__(self, instance, model): + if instance: + # instance = copy.deepcopy(instance) + # instance method + def ttl_setter(ts): + instance._ttl = ts + return instance + return ttl_setter + + qs = model.__queryset__(model) + + def ttl_setter(ts): + qs._ttl = ts + return qs + + return ttl_setter + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + +class TimestampDescriptor(object): + """ + 
returns a query set descriptor with a timestamp specified + """ + def __get__(self, instance, model): + if instance: + # instance method + def timestamp_setter(ts): + instance._timestamp = ts + return instance + return timestamp_setter + + return model.objects.timestamp + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + +class IfNotExistsDescriptor(object): + """ + return a query set descriptor with a if_not_exists flag specified + """ + def __get__(self, instance, model): + if instance: + # instance method + def ifnotexists_setter(ife): + instance._if_not_exists = ife + return instance + return ifnotexists_setter + + return model.objects.if_not_exists + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + +class ConsistencyDescriptor(object): + """ + returns a query set descriptor if called on Class, instance if it was an instance call + """ + def __get__(self, instance, model): + if instance: + # instance = copy.deepcopy(instance) + def consistency_setter(consistency): + instance.__consistency__ = consistency + return instance + return consistency_setter + + qs = model.__queryset__(model) + + def consistency_setter(consistency): + qs._consistency = consistency + return qs + + return consistency_setter + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + +class ColumnQueryEvaluator(query.AbstractQueryableColumn): + """ + Wraps a column and allows it to be used in comparator + expressions, returning query operators + + ie: + Model.column == 5 + """ + + def __init__(self, column): + self.column = column + + def __unicode__(self): + return self.column.db_field_name + + def _get_column(self): + """ :rtype: ColumnQueryEvaluator """ + return self.column + + +class ColumnDescriptor(object): + """ + Handles the reading and writing of column values to and from + a model instance's value manager, as well as creating + comparator queries + """ + + def __init__(self, column): + """ + :param column: + :type column: columns.Column + :return: + """ + self.column = column + self.query_evaluator = ColumnQueryEvaluator(self.column) + + def __get__(self, instance, owner): + """ + Returns either the value or column, depending + on if an instance is provided or not + + :param instance: the model instance + :type instance: Model + """ + try: + return instance._values[self.column.column_name].getval() + except AttributeError: + return self.query_evaluator + + def __set__(self, instance, value): + """ + Sets the value on an instance, raises an exception with classes + TODO: use None instance to create update statements + """ + if instance: + return instance._values[self.column.column_name].setval(value) + else: + raise AttributeError('cannot reassign column values') + + def __delete__(self, instance): + """ + Sets the column value to None, if possible + """ + if instance: + if self.column.can_delete: + instance._values[self.column.column_name].delval() + else: + raise AttributeError('cannot delete {} columns'.format(self.column.column_name)) + + +class BaseModel(object): + """ + The base model class, don't inherit from this, inherit from Model, defined below + """ + + class DoesNotExist(_DoesNotExist): + pass + + class MultipleObjectsReturned(_MultipleObjectsReturned): + pass + + objects = QuerySetDescriptor() + ttl = TTLDescriptor() + consistency = ConsistencyDescriptor() + iff = TransactionDescriptor() + + # custom timestamps, see USING TIMESTAMP X + timestamp = TimestampDescriptor() + + if_not_exists = IfNotExistsDescriptor() + + # _len is lazily 
created by __len__ + + __table_name__ = None + + __keyspace__ = None + + __default_ttl__ = None + + __polymorphic_key__ = None # DEPRECATED + __discriminator_value__ = None + + # compaction options + __compaction__ = None + __compaction_tombstone_compaction_interval__ = None + __compaction_tombstone_threshold__ = None + + # compaction - size tiered options + __compaction_bucket_high__ = None + __compaction_bucket_low__ = None + __compaction_max_threshold__ = None + __compaction_min_threshold__ = None + __compaction_min_sstable_size__ = None + + # compaction - leveled options + __compaction_sstable_size_in_mb__ = None + + # end compaction + # the queryset class used for this class + __queryset__ = query.ModelQuerySet + __dmlquery__ = query.DMLQuery + + __consistency__ = None # can be set per query + + # Additional table properties + __bloom_filter_fp_chance__ = None + __caching__ = None + __comment__ = None + __dclocal_read_repair_chance__ = None + __default_time_to_live__ = None + __gc_grace_seconds__ = None + __index_interval__ = None + __memtable_flush_period_in_ms__ = None + __populate_io_cache_on_flush__ = None + __read_repair_chance__ = None + __replicate_on_write__ = None + + _timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP) + + _if_not_exists = False # optional if_not_exists flag to check existence before insertion + + def __init__(self, **values): + self._values = {} + self._ttl = self.__default_ttl__ + self._timestamp = None + self._transaction = None + + for name, column in self._columns.items(): + value = values.get(name, None) + if value is not None or isinstance(column, columns.BaseContainerColumn): + value = column.to_python(value) + value_mngr = column.value_manager(self, column, value) + if name in values: + value_mngr.explicit = True + self._values[name] = value_mngr + + # a flag set by the deserializer to indicate + # that update should be used when persisting changes + self._is_persisted = False + self._batch = None + self._timeout = connection.NOT_SET + + def __repr__(self): + return '{}({})'.format(self.__class__.__name__, + ', '.join('{}={!r}'.format(k, getattr(self, k)) + for k in self._defined_columns.keys() + if k != self._discriminator_column_name)) + + def __str__(self): + """ + Pretty printing of models by their primary key + """ + return '{} <{}>'.format(self.__class__.__name__, + ', '.join('{}={}'.format(k, getattr(self, k)) for k in self._primary_keys.keys())) + + @classmethod + def _discover_polymorphic_submodels(cls): + if not cls._is_polymorphic_base: + raise ModelException('_discover_polymorphic_submodels can only be called on polymorphic base classes') + + def _discover(klass): + if not klass._is_polymorphic_base and klass.__discriminator_value__ is not None: + cls._discriminator_map[klass.__discriminator_value__] = klass + for subklass in klass.__subclasses__(): + _discover(subklass) + _discover(cls) + + @classmethod + def _get_model_by_discriminator_value(cls, key): + if not cls._is_polymorphic_base: + raise ModelException('_get_model_by_discriminator_value can only be called on polymorphic base classes') + return cls._discriminator_map.get(key) + + @classmethod + def _construct_instance(cls, values): + """ + method used to construct instances from query results + this is where polymorphic deserialization occurs + """ + # we're going to take the values, which is from the DB as a dict + # and translate that into our local fields + # the db_map is a db_field -> model field map + items = values.items() + field_dict = 
dict([(cls._db_map.get(k, k), v) for k, v in items]) + + if cls._is_polymorphic: + disc_key = field_dict.get(cls._discriminator_column_name) + + if disc_key is None: + raise PolymorphicModelException('discriminator value was not found in values') + + poly_base = cls if cls._is_polymorphic_base else cls._polymorphic_base + + klass = poly_base._get_model_by_discriminator_value(disc_key) + if klass is None: + poly_base._discover_polymorphic_submodels() + klass = poly_base._get_model_by_discriminator_value(disc_key) + if klass is None: + raise PolymorphicModelException( + 'unrecognized discriminator column {} for class {}'.format(disc_key, poly_base.__name__) + ) + + if not issubclass(klass, cls): + raise PolymorphicModelException( + '{} is not a subclass of {}'.format(klass.__name__, cls.__name__) + ) + + field_dict = {k: v for k, v in field_dict.items() if k in klass._columns.keys()} + + else: + klass = cls + + instance = klass(**field_dict) + instance._is_persisted = True + return instance + + def _can_update(self): + """ + Called by the save function to check if this should be + persisted with update or insert + + :return: + """ + if not self._is_persisted: + return False + + return all([not self._values[k].changed for k in self._primary_keys]) + + @classmethod + def _get_keyspace(cls): + """ + Returns the manual keyspace, if set, otherwise the default keyspace + """ + return cls.__keyspace__ or DEFAULT_KEYSPACE + + @classmethod + def _get_column(cls, name): + """ + Returns the column matching the given name, raising a key error if + it doesn't exist + + :param name: the name of the column to return + :rtype: Column + """ + return cls._columns[name] + + def __eq__(self, other): + if self.__class__ != other.__class__: + return False + + # check attribute keys + keys = set(self._columns.keys()) + other_keys = set(other._columns.keys()) + if keys != other_keys: + return False + + # check that all of the attributes match + for key in other_keys: + if getattr(self, key, None) != getattr(other, key, None): + return False + + return True + + def __ne__(self, other): + return not self.__eq__(other) + + @classmethod + def column_family_name(cls, include_keyspace=True): + """ + Returns the column family name if it's been defined + otherwise, it creates it from the module and class name + """ + cf_name = '' + if cls.__table_name__: + cf_name = cls.__table_name__.lower() + else: + # get polymorphic base table names if model is polymorphic + if cls._is_polymorphic and not cls._is_polymorphic_base: + return cls._polymorphic_base.column_family_name(include_keyspace=include_keyspace) + + camelcase = re.compile(r'([a-z])([A-Z])') + ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2).lower()), s) + + cf_name += ccase(cls.__name__) + # trim to less than 48 characters or cassandra will complain + cf_name = cf_name[-48:] + cf_name = cf_name.lower() + cf_name = re.sub(r'^_+', '', cf_name) + + if not include_keyspace: + return cf_name + + return '{}.{}'.format(cls._get_keyspace(), cf_name) + + def validate(self): + """ + Cleans and validates the field values + """ + for name, col in self._columns.items(): + v = getattr(self, name) + if v is None and not self._values[name].explicit and col.has_default: + v = col.get_default() + val = col.validate(v) + setattr(self, name, val) + + # Let an instance be used like a dict of its columns keys/values + def __iter__(self): + """ Iterate over column ids. 
""" + for column_id in self._columns.keys(): + yield column_id + + def __getitem__(self, key): + """ Returns column's value. """ + if not isinstance(key, six.string_types): + raise TypeError + if key not in self._columns.keys(): + raise KeyError + return getattr(self, key) + + def __setitem__(self, key, val): + """ Sets a column's value. """ + if not isinstance(key, six.string_types): + raise TypeError + if key not in self._columns.keys(): + raise KeyError + return setattr(self, key, val) + + def __len__(self): + """ + Returns the number of columns defined on that model. + """ + try: + return self._len + except: + self._len = len(self._columns.keys()) + return self._len + + def keys(self): + """ Returns a list of column IDs. """ + return [k for k in self] + + def values(self): + """ Returns list of column values. """ + return [self[k] for k in self] + + def items(self): + """ Returns a list of column ID/value tuples. """ + return [(k, self[k]) for k in self] + + def _as_dict(self): + """ Returns a map of column names to cleaned values """ + values = self._dynamic_columns or {} + for name, col in self._columns.items(): + values[name] = col.to_database(getattr(self, name, None)) + return values + + @classmethod + def create(cls, **kwargs): + """ + Create an instance of this model in the database. + + Takes the model column values as keyword arguments. + + Returns the instance. + """ + extra_columns = set(kwargs.keys()) - set(cls._columns.keys()) + if extra_columns: + raise ValidationError("Incorrect columns passed: {}".format(extra_columns)) + return cls.objects.create(**kwargs) + + @classmethod + def all(cls): + """ + Returns a queryset representing all stored objects + + This is a pass-through to the model objects().all() + """ + return cls.objects.all() + + @classmethod + def filter(cls, *args, **kwargs): + """ + Returns a queryset based on filter parameters. + + This is a pass-through to the model objects().:method:`~cqlengine.queries.filter`. + """ + return cls.objects.filter(*args, **kwargs) + + @classmethod + def get(cls, *args, **kwargs): + """ + Returns a single object based on the passed filter constraints. + + This is a pass-through to the model objects().:method:`~cqlengine.queries.get`. + """ + return cls.objects.get(*args, **kwargs) + + def timeout(self, timeout): + """ + Sets a timeout for use in :meth:`~.save`, :meth:`~.update`, and :meth:`~.delete` + operations + """ + assert self._batch is None, 'Setting both timeout and batch is not supported' + self._timeout = timeout + return self + + def save(self): + """ + Saves an object to the database. + + .. 
code-block:: python
+
+            # create a person instance
+            person = Person(first_name='Kimberly', last_name='Eggleston')
+            # save it to Cassandra
+            person.save()
+        """
+
+        # handle polymorphic models
+        if self._is_polymorphic:
+            if self._is_polymorphic_base:
+                raise PolymorphicModelException('cannot save polymorphic base model')
+            else:
+                setattr(self, self._discriminator_column_name, self.__discriminator_value__)
+
+        self.validate()
+        self.__dmlquery__(self.__class__, self,
+                          batch=self._batch,
+                          ttl=self._ttl,
+                          timestamp=self._timestamp,
+                          consistency=self.__consistency__,
+                          if_not_exists=self._if_not_exists,
+                          transaction=self._transaction,
+                          timeout=self._timeout).save()
+
+        # reset the value managers
+        for v in self._values.values():
+            v.reset_previous_value()
+        self._is_persisted = True
+
+        self._ttl = self.__default_ttl__
+        self._timestamp = None
+
+        return self
+
+    def update(self, **values):
+        """
+        Performs an update on the model instance. You can pass in values to set on the model
+        for updating, or you can call without values to execute an update against any modified
+        fields. If no fields on the model have been modified since loading, no query will be
+        performed. Model validation is performed normally.
+
+        It is possible to do a blind update, that is, to update a field without having first
+        selected the object out of the database. See the documentation on blind updates for details.
+        """
+        for k, v in values.items():
+            col = self._columns.get(k)
+
+            # check for nonexistent columns
+            if col is None:
+                raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.__class__.__name__, k))
+
+            # check for primary key update attempts
+            if col.is_primary_key:
+                raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(k, self.__module__, self.__class__.__name__))
+
+            setattr(self, k, v)
+
+        # handle polymorphic models
+        if self._is_polymorphic:
+            if self._is_polymorphic_base:
+                raise PolymorphicModelException('cannot update polymorphic base model')
+            else:
+                setattr(self, self._discriminator_column_name, self.__discriminator_value__)
+
+        self.validate()
+        self.__dmlquery__(self.__class__, self,
+                          batch=self._batch,
+                          ttl=self._ttl,
+                          timestamp=self._timestamp,
+                          consistency=self.__consistency__,
+                          transaction=self._transaction,
+                          timeout=self._timeout).update()
+
+        # reset the value managers
+        for v in self._values.values():
+            v.reset_previous_value()
+        self._is_persisted = True
+
+        self._ttl = self.__default_ttl__
+        self._timestamp = None
+
+        return self
+
+    def delete(self):
+        """
+        Deletes the object from the database
+        """
+        self.__dmlquery__(self.__class__, self,
+                          batch=self._batch,
+                          timestamp=self._timestamp,
+                          consistency=self.__consistency__,
+                          timeout=self._timeout).delete()
+
+    def get_changed_columns(self):
+        """
+        Returns a list of the columns that have been updated since instantiation or save
+        """
+        return [k for k, v in self._values.items() if v.changed]
+
+    @classmethod
+    def _class_batch(cls, batch):
+        return cls.objects.batch(batch)
+
+    def _inst_batch(self, batch):
+        assert self._timeout is connection.NOT_SET, 'Setting both timeout and batch is not supported'
+        self._batch = batch
+        return self
+
+    batch = hybrid_classmethod(_class_batch, _inst_batch)
+
+
+class ModelMetaClass(type):
+
+    def __new__(cls, name, bases, attrs):
+        # move column definitions into columns dict
+        # and set default column names
+        column_dict = OrderedDict()
+        primary_keys = OrderedDict()
+        pk_name = None
+
+        # get inherited properties
+        inherited_columns = OrderedDict()
+        for 
base in bases: + for k, v in getattr(base, '_defined_columns', {}).items(): + inherited_columns.setdefault(k, v) + + # short circuit __abstract__ inheritance + is_abstract = attrs['__abstract__'] = attrs.get('__abstract__', False) + + # short circuit __discriminator_value__ inheritance + # __polymorphic_key__ is deprecated + poly_key = attrs.get('__polymorphic_key__', None) + if poly_key: + msg = '__polymorphic_key__ is deprecated. Use __discriminator_value__ instead' + warnings.warn(msg, DeprecationWarning) + log.warning(msg) + attrs['__discriminator_value__'] = attrs.get('__discriminator_value__', poly_key) + attrs['__polymorphic_key__'] = attrs['__discriminator_value__'] + + def _transform_column(col_name, col_obj): + column_dict[col_name] = col_obj + if col_obj.primary_key: + primary_keys[col_name] = col_obj + col_obj.set_column_name(col_name) + # set properties + attrs[col_name] = ColumnDescriptor(col_obj) + + column_definitions = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)] + column_definitions = sorted(column_definitions, key=lambda x: x[1].position) + + is_polymorphic_base = any([c[1].discriminator_column for c in column_definitions]) + + column_definitions = [x for x in inherited_columns.items()] + column_definitions + discriminator_columns = [c for c in column_definitions if c[1].discriminator_column] + is_polymorphic = len(discriminator_columns) > 0 + if len(discriminator_columns) > 1: + raise ModelDefinitionException('only one discriminator_column (polymorphic_key (deprecated)) can be defined in a model, {} found'.format(len(discriminator_columns))) + + if attrs['__discriminator_value__'] and not is_polymorphic: + raise ModelDefinitionException('__discriminator_value__ specified, but no base columns defined with discriminator_column=True') + + discriminator_column_name, discriminator_column = discriminator_columns[0] if discriminator_columns else (None, None) + + if isinstance(discriminator_column, (columns.BaseContainerColumn, columns.Counter)): + raise ModelDefinitionException('counter and container columns cannot be used as discriminator columns (polymorphic_key (deprecated)) ') + + # find polymorphic base class + polymorphic_base = None + if is_polymorphic and not is_polymorphic_base: + def _get_polymorphic_base(bases): + for base in bases: + if getattr(base, '_is_polymorphic_base', False): + return base + klass = _get_polymorphic_base(base.__bases__) + if klass: + return klass + polymorphic_base = _get_polymorphic_base(bases) + + defined_columns = OrderedDict(column_definitions) + + # check for primary key + if not is_abstract and not any([v.primary_key for k, v in column_definitions]): + raise ModelDefinitionException("At least 1 primary key is required.") + + counter_columns = [c for c in defined_columns.values() if isinstance(c, columns.Counter)] + data_columns = [c for c in defined_columns.values() if not c.primary_key and not isinstance(c, columns.Counter)] + if counter_columns and data_columns: + raise ModelDefinitionException('counter models may not have data columns') + + has_partition_keys = any(v.partition_key for (k, v) in column_definitions) + + # transform column definitions + for k, v in column_definitions: + # don't allow a column with the same name as a built-in attribute or method + if k in BaseModel.__dict__: + raise ModelDefinitionException("column '{}' conflicts with built-in attribute/method".format(k)) + + # counter column primary keys are not allowed + if (v.primary_key or v.partition_key) and isinstance(v, (columns.Counter, 
columns.BaseContainerColumn)):
+                raise ModelDefinitionException('counter columns and container columns cannot be used as primary keys')
+
+            # this will mark the first primary key column as a partition
+            # key, if one hasn't been set already
+            if not has_partition_keys and v.primary_key:
+                v.partition_key = True
+                has_partition_keys = True
+            _transform_column(k, v)
+
+        partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key)
+        clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key)
+
+        # setup partition key shortcut
+        if len(partition_keys) == 0:
+            if not is_abstract:
+                raise ModelException("at least one partition key must be defined")
+        if len(partition_keys) == 1:
+            pk_name = [x for x in partition_keys.keys()][0]
+            attrs['pk'] = attrs[pk_name]
+        else:
+            # composite partition key case, get/set a tuple of values
+            _get = lambda self: tuple(self._values[c].getval() for c in partition_keys.keys())
+            _set = lambda self, val: tuple(self._values[c].setval(v) for (c, v) in zip(partition_keys.keys(), val))
+            attrs['pk'] = property(_get, _set)
+
+        # some validation
+        col_names = set()
+        for v in column_dict.values():
+            # check for duplicate column names
+            if v.db_field_name in col_names:
+                raise ModelException("{} defines the column {} more than once".format(name, v.db_field_name))
+            if v.clustering_order and not (v.primary_key and not v.partition_key):
+                raise ModelException("clustering_order may be specified only for clustering primary keys")
+            if v.clustering_order and v.clustering_order.lower() not in ('asc', 'desc'):
+                raise ModelException("invalid clustering order {} for column {}".format(repr(v.clustering_order), v.db_field_name))
+            col_names.add(v.db_field_name)
+
+        # create db_name -> model name map for loading
+        db_map = {}
+        for field_name, col in column_dict.items():
+            db_map[col.db_field_name] = field_name
+
+        # add management members to the class
+        attrs['_columns'] = column_dict
+        attrs['_primary_keys'] = primary_keys
+        attrs['_defined_columns'] = defined_columns
+
+        # maps the database field to the models key
+        attrs['_db_map'] = db_map
+        attrs['_pk_name'] = pk_name
+        attrs['_dynamic_columns'] = {}
+
+        attrs['_partition_keys'] = partition_keys
+        attrs['_clustering_keys'] = clustering_keys
+        attrs['_has_counter'] = len(counter_columns) > 0
+
+        # add polymorphic management attributes
+        attrs['_is_polymorphic_base'] = is_polymorphic_base
+        attrs['_is_polymorphic'] = is_polymorphic
+        attrs['_polymorphic_base'] = polymorphic_base
+        attrs['_discriminator_column'] = discriminator_column
+        attrs['_discriminator_column_name'] = discriminator_column_name
+        attrs['_discriminator_map'] = {} if is_polymorphic_base else None
+
+        # setup class exceptions
+        DoesNotExistBase = None
+        for base in bases:
+            DoesNotExistBase = getattr(base, 'DoesNotExist', None)
+            if DoesNotExistBase is not None:
+                break
+
+        DoesNotExistBase = DoesNotExistBase or attrs.pop('DoesNotExist', BaseModel.DoesNotExist)
+        attrs['DoesNotExist'] = type('DoesNotExist', (DoesNotExistBase,), {})
+
+        MultipleObjectsReturnedBase = None
+        for base in bases:
+            MultipleObjectsReturnedBase = getattr(base, 'MultipleObjectsReturned', None)
+            if MultipleObjectsReturnedBase is not None:
+                break
+
+        MultipleObjectsReturnedBase = MultipleObjectsReturnedBase or attrs.pop('MultipleObjectsReturned', BaseModel.MultipleObjectsReturned)
+        attrs['MultipleObjectsReturned'] = type('MultipleObjectsReturned', (MultipleObjectsReturnedBase,), {})
+
+        # create the class and add a QuerySet to it
+        klass = 
super(ModelMetaClass, cls).__new__(cls, name, bases, attrs) + + udts = [] + for col in column_dict.values(): + columns.resolve_udts(col, udts) + + for user_type in set(udts): + user_type.register_for_keyspace(klass._get_keyspace()) + + return klass + + +@six.add_metaclass(ModelMetaClass) +class Model(BaseModel): + __abstract__ = True + """ + *Optional.* Indicates that this model is only intended to be used as a base class for other models. + You can't create tables for abstract models, but checks around schema validity are skipped during class construction. + """ + + __table_name__ = None + """ + *Optional.* Sets the name of the CQL table for this model. If left blank, the table name will be the name of the model, with it's module name as it's prefix. Manually defined table names are not inherited. + """ + + __keyspace__ = None + """ + Sets the name of the keyspace used by this model. + """ + + __default_ttl__ = None + """ + *Optional* The default ttl used by this model. + + This can be overridden by using the :meth:`~.ttl` method. + """ + + __polymorphic_key__ = None + """ + *Deprecated.* + + see :attr:`~.__discriminator_value__` + """ + + __discriminator_value__ = None + """ + *Optional* Specifies a value for the discriminator column when using model inheritance. + """ diff --git a/cassandra/cqlengine/named.py b/cassandra/cqlengine/named.py new file mode 100644 index 0000000..b38c078 --- /dev/null +++ b/cassandra/cqlengine/named.py @@ -0,0 +1,138 @@ +# Copyright 2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
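+
+# The classes in this module query existing tables and keyspaces that are not
+# mapped to Model classes. A minimal usage sketch follows; the keyspace name
+# 'examples', the table name 'users', and the column 'user_id' are hypothetical,
+# and a cqlengine connection is assumed to be configured first:
+#
+#     from cassandra.cqlengine import connection
+#     from cassandra.cqlengine.named import NamedKeyspace
+#
+#     connection.setup(['127.0.0.1'], 'examples')   # hypothetical contact point / keyspace
+#     ks = NamedKeyspace('examples')
+#     users = ks.table('users')
+#     # rows come back as ResultObject dicts with attribute access
+#     for row in users.filter(user_id=5):
+#         print(row.user_id)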
+ +from cassandra.cqlengine import CQLEngineException +from cassandra.cqlengine.query import AbstractQueryableColumn, SimpleQuerySet +from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist +from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned + + +class QuerySetDescriptor(object): + """ + returns a fresh queryset for the given model + it's declared on everytime it's accessed + """ + + def __get__(self, obj, model): + """ :rtype: ModelQuerySet """ + if model.__abstract__: + raise CQLEngineException('cannot execute queries against abstract models') + return SimpleQuerySet(obj) + + def __call__(self, *args, **kwargs): + """ + Just a hint to IDEs that it's ok to call this + + :rtype: ModelQuerySet + """ + raise NotImplementedError + + +class NamedColumn(AbstractQueryableColumn): + """ + A column that is not coupled to a model class, or type + """ + + def __init__(self, name): + self.name = name + + def __unicode__(self): + return self.name + + def _get_column(self): + """ :rtype: NamedColumn """ + return self + + @property + def db_field_name(self): + return self.name + + @property + def cql(self): + return self.get_cql() + + def get_cql(self): + return '"{}"'.format(self.name) + + def to_database(self, val): + return val + + +class NamedTable(object): + """ + A Table that is not coupled to a model class + """ + + __abstract__ = False + + objects = QuerySetDescriptor() + + class DoesNotExist(_DoesNotExist): + pass + + class MultipleObjectsReturned(_MultipleObjectsReturned): + pass + + def __init__(self, keyspace, name): + self.keyspace = keyspace + self.name = name + + def column(self, name): + return NamedColumn(name) + + def column_family_name(self, include_keyspace=True): + """ + Returns the column family name if it's been defined + otherwise, it creates it from the module and class name + """ + if include_keyspace: + return '{}.{}'.format(self.keyspace, self.name) + else: + return self.name + + def _get_column(self, name): + """ + Returns the column matching the given name + + :rtype: Column + """ + return self.column(name) + + # def create(self, **kwargs): + # return self.objects.create(**kwargs) + + def all(self): + return self.objects.all() + + def filter(self, *args, **kwargs): + return self.objects.filter(*args, **kwargs) + + def get(self, *args, **kwargs): + return self.objects.get(*args, **kwargs) + + +class NamedKeyspace(object): + """ + A keyspace + """ + + def __init__(self, name): + self.name = name + + def table(self, name): + """ + returns a table descriptor with the given + name that belongs to this keyspace + """ + return NamedTable(self.name, name) diff --git a/cassandra/cqlengine/operators.py b/cassandra/cqlengine/operators.py new file mode 100644 index 0000000..08c8ccb --- /dev/null +++ b/cassandra/cqlengine/operators.py @@ -0,0 +1,99 @@ +# Copyright 2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
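+
+# These operator classes map the suffixes used in filter() keyword arguments
+# (for example 'age__gte=21', where 'age' is a hypothetical column) onto their
+# CQL comparison symbols. A small illustrative lookup, no connection required:
+#
+#     from cassandra.cqlengine.operators import BaseWhereOperator
+#
+#     op_class = BaseWhereOperator.get_operator('GTE')
+#     op_class().cql_symbol   # '>='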
+ +from cassandra.cqlengine import UnicodeMixin + + +class QueryOperatorException(Exception): + pass + + +class BaseQueryOperator(UnicodeMixin): + # The symbol that identifies this operator in kwargs + # ie: colname__ + symbol = None + + # The comparator symbol this operator uses in cql + cql_symbol = None + + def __unicode__(self): + if self.cql_symbol is None: + raise QueryOperatorException("cql symbol is None") + return self.cql_symbol + + @classmethod + def get_operator(cls, symbol): + if cls == BaseQueryOperator: + raise QueryOperatorException("get_operator can only be called from a BaseQueryOperator subclass") + if not hasattr(cls, 'opmap'): + cls.opmap = {} + + def _recurse(klass): + if klass.symbol: + cls.opmap[klass.symbol.upper()] = klass + for subklass in klass.__subclasses__(): + _recurse(subklass) + pass + + _recurse(cls) + try: + return cls.opmap[symbol.upper()] + except KeyError: + raise QueryOperatorException("{} doesn't map to a QueryOperator".format(symbol)) + + +class BaseWhereOperator(BaseQueryOperator): + """ base operator used for where clauses """ + + +class EqualsOperator(BaseWhereOperator): + symbol = 'EQ' + cql_symbol = '=' + + +class InOperator(EqualsOperator): + symbol = 'IN' + cql_symbol = 'IN' + + +class GreaterThanOperator(BaseWhereOperator): + symbol = "GT" + cql_symbol = '>' + + +class GreaterThanOrEqualOperator(BaseWhereOperator): + symbol = "GTE" + cql_symbol = '>=' + + +class LessThanOperator(BaseWhereOperator): + symbol = "LT" + cql_symbol = '<' + + +class LessThanOrEqualOperator(BaseWhereOperator): + symbol = "LTE" + cql_symbol = '<=' + + +class BaseAssignmentOperator(BaseQueryOperator): + """ base operator used for insert and delete statements """ + + +class AssignmentOperator(BaseAssignmentOperator): + cql_symbol = "=" + + +class AddSymbol(BaseAssignmentOperator): + cql_symbol = "+" diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py new file mode 100644 index 0000000..2a5c974 --- /dev/null +++ b/cassandra/cqlengine/query.py @@ -0,0 +1,1207 @@ +# Copyright 2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from datetime import datetime, timedelta +import time +import six + +from cassandra.cqlengine import columns, CQLEngineException, ValidationError, UnicodeMixin +from cassandra.cqlengine import connection +from cassandra.cqlengine.functions import Token, BaseQueryFunction, QueryValue +from cassandra.cqlengine.operators import (InOperator, EqualsOperator, GreaterThanOperator, + GreaterThanOrEqualOperator, LessThanOperator, + LessThanOrEqualOperator, BaseWhereOperator) +# import * ? 
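+
+# Rough sketch of how a filter() keyword argument is turned into a WHERE clause
+# by the queryset classes below (the model 'MyModel' and column 'age' are
+# hypothetical):
+#
+#     MyModel.objects.filter(age__gte=21)
+#       -> _parse_filter_arg('age__gte')           # ('age', 'gte')
+#       -> BaseWhereOperator.get_operator('gte')   # GreaterThanOrEqualOperator
+#       -> WhereClause('age', GreaterThanOrEqualOperator(), 21)
+#
+# The accumulated clauses are rendered into a SelectStatement when the queryset
+# is executed.
+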
+from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement, + UpdateStatement, AssignmentClause, InsertStatement, + BaseCQLStatement, MapUpdateClause, MapDeleteClause, + ListUpdateClause, SetUpdateClause, CounterUpdateClause, + TransactionClause) + + +class QueryException(CQLEngineException): + pass + + +class IfNotExistsWithCounterColumn(CQLEngineException): + pass + + +class LWTException(CQLEngineException): + pass + + +class DoesNotExist(QueryException): + pass + + +class MultipleObjectsReturned(QueryException): + pass + + +def check_applied(result): + """ + check if result contains some column '[applied]' with false value, + if that value is false, it means our light-weight transaction didn't + applied to database. + """ + if result and '[applied]' in result[0] and not result[0]['[applied]']: + raise LWTException('') + + +class AbstractQueryableColumn(UnicodeMixin): + """ + exposes cql query operators through pythons + builtin comparator symbols + """ + + def _get_column(self): + raise NotImplementedError + + def __unicode__(self): + raise NotImplementedError + + def _to_database(self, val): + if isinstance(val, QueryValue): + return val + else: + return self._get_column().to_database(val) + + def in_(self, item): + """ + Returns an in operator + + used where you'd typically want to use python's `in` operator + """ + return WhereClause(six.text_type(self), InOperator(), item) + + def __eq__(self, other): + return WhereClause(six.text_type(self), EqualsOperator(), self._to_database(other)) + + def __gt__(self, other): + return WhereClause(six.text_type(self), GreaterThanOperator(), self._to_database(other)) + + def __ge__(self, other): + return WhereClause(six.text_type(self), GreaterThanOrEqualOperator(), self._to_database(other)) + + def __lt__(self, other): + return WhereClause(six.text_type(self), LessThanOperator(), self._to_database(other)) + + def __le__(self, other): + return WhereClause(six.text_type(self), LessThanOrEqualOperator(), self._to_database(other)) + + +class BatchType(object): + Unlogged = 'UNLOGGED' + Counter = 'COUNTER' + + +class BatchQuery(object): + """ + Handles the batching of queries + + http://www.datastax.com/docs/1.2/cql_cli/cql/BATCH + """ + _consistency = None + + def __init__(self, batch_type=None, timestamp=None, consistency=None, execute_on_exception=False, + timeout=connection.NOT_SET): + """ + :param batch_type: (optional) One of batch type values available through BatchType enum + :type batch_type: str or None + :param timestamp: (optional) A datetime or timedelta object with desired timestamp to be applied + to the batch transaction. + :type timestamp: datetime or timedelta or None + :param consistency: (optional) One of consistency values ("ANY", "ONE", "QUORUM" etc) + :type consistency: The :class:`.ConsistencyLevel` to be used for the batch query, or None. + :param execute_on_exception: (Defaults to False) Indicates that when the BatchQuery instance is used + as a context manager the queries accumulated within the context must be executed despite + encountering an error within the context. By default, any exception raised from within + the context scope will cause the batched queries not to be executed. 
+ :type execute_on_exception: bool + :param timeout: (optional) Timeout for the entire batch (in seconds), if not specified fallback + to default session timeout + :type timeout: float or None + """ + self.queries = [] + self.batch_type = batch_type + if timestamp is not None and not isinstance(timestamp, (datetime, timedelta)): + raise CQLEngineException('timestamp object must be an instance of datetime') + self.timestamp = timestamp + self._consistency = consistency + self._execute_on_exception = execute_on_exception + self._timeout = timeout + self._callbacks = [] + + def add_query(self, query): + if not isinstance(query, BaseCQLStatement): + raise CQLEngineException('only BaseCQLStatements can be added to a batch query') + self.queries.append(query) + + def consistency(self, consistency): + self._consistency = consistency + + def _execute_callbacks(self): + for callback, args, kwargs in self._callbacks: + callback(*args, **kwargs) + + # trying to clear up the ref counts for objects mentioned in the set + del self._callbacks + + def add_callback(self, fn, *args, **kwargs): + """Add a function and arguments to be passed to it to be executed after the batch executes. + + A batch can support multiple callbacks. + + Note, that if the batch does not execute, the callbacks are not executed. + A callback, thus, is an "on batch success" handler. + + :param fn: Callable object + :type fn: callable + :param *args: Positional arguments to be passed to the callback at the time of execution + :param **kwargs: Named arguments to be passed to the callback at the time of execution + """ + if not callable(fn): + raise ValueError("Value for argument 'fn' is {} and is not a callable object.".format(type(fn))) + self._callbacks.append((fn, args, kwargs)) + + def execute(self): + if len(self.queries) == 0: + # Empty batch is a no-op + # except for callbacks + self._execute_callbacks() + return + + opener = 'BEGIN ' + (self.batch_type + ' ' if self.batch_type else '') + ' BATCH' + if self.timestamp: + + if isinstance(self.timestamp, six.integer_types): + ts = self.timestamp + elif isinstance(self.timestamp, (datetime, timedelta)): + ts = self.timestamp + if isinstance(self.timestamp, timedelta): + ts += datetime.now() # Apply timedelta + ts = int(time.mktime(ts.timetuple()) * 1e+6 + ts.microsecond) + else: + raise ValueError("Batch expects a long, a timedelta, or a datetime") + + opener += ' USING TIMESTAMP {}'.format(ts) + + query_list = [opener] + parameters = {} + ctx_counter = 0 + for query in self.queries: + query.update_context_id(ctx_counter) + ctx = query.get_context() + ctx_counter += len(ctx) + query_list.append(' ' + str(query)) + parameters.update(ctx) + + query_list.append('APPLY BATCH;') + + tmp = connection.execute('\n'.join(query_list), parameters, self._consistency, self._timeout) + check_applied(tmp) + + self.queries = [] + self._execute_callbacks() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # don't execute if there was an exception by default + if exc_type is not None and not self._execute_on_exception: + return + self.execute() + + +class AbstractQuerySet(object): + + def __init__(self, model): + super(AbstractQuerySet, self).__init__() + self.model = model + + # Where clause filters + self._where = [] + + # Transaction clause filters + self._transaction = [] + + # ordering arguments + self._order = [] + + self._allow_filtering = False + + # CQL has a default limit of 10000, it's defined here + # because explicit is better than implicit + 
self._limit = 10000 + + # see the defer and only methods + self._defer_fields = [] + self._only_fields = [] + + self._values_list = False + self._flat_values_list = False + + # results cache + self._result_cache = None + self._result_idx = None + + self._batch = None + self._ttl = getattr(model, '__default_ttl__', None) + self._consistency = None + self._timestamp = None + self._if_not_exists = False + self._timeout = connection.NOT_SET + + @property + def column_family_name(self): + return self.model.column_family_name() + + def _execute(self, q): + if self._batch: + return self._batch.add_query(q) + else: + result = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout) + if self._transaction: + check_applied(result) + return result + + def __unicode__(self): + return six.text_type(self._select_query()) + + def __str__(self): + return str(self.__unicode__()) + + def __call__(self, *args, **kwargs): + return self.filter(*args, **kwargs) + + def __deepcopy__(self, memo): + clone = self.__class__(self.model) + for k, v in self.__dict__.items(): + if k in ['_con', '_cur', '_result_cache', '_result_idx']: # don't clone these + clone.__dict__[k] = None + elif k == '_batch': + # we need to keep the same batch instance across + # all queryset clones, otherwise the batched queries + # fly off into other batch instances which are never + # executed, thx @dokai + clone.__dict__[k] = self._batch + elif k == '_timeout': + clone.__dict__[k] = self._timeout + else: + clone.__dict__[k] = copy.deepcopy(v, memo) + + return clone + + def __len__(self): + self._execute_query() + return len(self._result_cache) + + # ----query generation / execution---- + + def _select_fields(self): + """ returns the fields to select """ + return [] + + def _validate_select_where(self): + """ put select query validation here """ + + def _select_query(self): + """ + Returns a select clause based on the given filter args + """ + if self._where: + self._validate_select_where() + return SelectStatement( + self.column_family_name, + fields=self._select_fields(), + where=self._where, + order_by=self._order, + limit=self._limit, + allow_filtering=self._allow_filtering + ) + + # ----Reads------ + + def _execute_query(self): + if self._batch: + raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode") + if self._result_cache is None: + self._result_cache = list(self._execute(self._select_query())) + self._construct_result = self._get_result_constructor() + + def _fill_result_cache_to_idx(self, idx): + self._execute_query() + if self._result_idx is None: + self._result_idx = -1 + + qty = idx - self._result_idx + if qty < 1: + return + else: + for idx in range(qty): + self._result_idx += 1 + self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx]) + + def __iter__(self): + self._execute_query() + + for idx in range(len(self._result_cache)): + instance = self._result_cache[idx] + if isinstance(instance, dict): + self._fill_result_cache_to_idx(idx) + yield self._result_cache[idx] + + def __getitem__(self, s): + self._execute_query() + + num_results = len(self._result_cache) + + if isinstance(s, slice): + # calculate the amount of results that need to be loaded + end = num_results if s.step is None else s.step + if end < 0: + end += num_results + else: + end -= 1 + self._fill_result_cache_to_idx(end) + return self._result_cache[s.start:s.stop:s.step] + else: + # return the object at this index + s = int(s) + + # handle negative indexing + 
if s < 0: + s += num_results + + if s >= num_results: + raise IndexError + else: + self._fill_result_cache_to_idx(s) + return self._result_cache[s] + + def _get_result_constructor(self): + """ + Returns a function that will be used to instantiate query results + """ + raise NotImplementedError + + def batch(self, batch_obj): + """ + Set a batch object to run the query on. + + Note: running a select query with a batch object will raise an exception + """ + if batch_obj is not None and not isinstance(batch_obj, BatchQuery): + raise CQLEngineException('batch_obj must be a BatchQuery instance or None') + clone = copy.deepcopy(self) + clone._batch = batch_obj + return clone + + def first(self): + try: + return six.next(iter(self)) + except StopIteration: + return None + + def all(self): + """ + Returns a queryset matching all rows + + .. code-block:: python + + for user in User.objects().all(): + print(user) + """ + return copy.deepcopy(self) + + def consistency(self, consistency): + """ + Sets the consistency level for the operation. See :class:`.ConsistencyLevel`. + + .. code-block:: python + + for user in User.objects(id=3).consistency(CL.ONE): + print(user) + """ + clone = copy.deepcopy(self) + clone._consistency = consistency + return clone + + def _parse_filter_arg(self, arg): + """ + Parses a filter arg in the format: + __ + :returns: colname, op tuple + """ + statement = arg.rsplit('__', 1) + if len(statement) == 1: + return arg, None + elif len(statement) == 2: + return statement[0], statement[1] + else: + raise QueryException("Can't parse '{}'".format(arg)) + + def iff(self, *args, **kwargs): + """Adds IF statements to queryset""" + if len([x for x in kwargs.values() if x is None]): + raise CQLEngineException("None values on iff are not allowed") + + clone = copy.deepcopy(self) + for operator in args: + if not isinstance(operator, TransactionClause): + raise QueryException('{} is not a valid query operator'.format(operator)) + clone._transaction.append(operator) + + for col_name, val in kwargs.items(): + exists = False + try: + column = self.model._get_column(col_name) + except KeyError: + if col_name == 'pk__token': + if not isinstance(val, Token): + raise QueryException("Virtual column 'pk__token' may only be compared to Token() values") + column = columns._PartitionKeysToken(self.model) + else: + raise QueryException("Can't resolve column name: '{}'".format(col_name)) + + if isinstance(val, Token): + if col_name != 'pk__token': + raise QueryException("Token() values may only be compared to the 'pk__token' virtual column") + partition_columns = column.partition_columns + if len(partition_columns) != len(val.value): + raise QueryException( + 'Token() received {} arguments but model has {} partition keys'.format( + len(val.value), len(partition_columns))) + val.set_columns(partition_columns) + + if isinstance(val, BaseQueryFunction) or exists is True: + query_val = val + else: + query_val = column.to_database(val) + + clone._transaction.append(TransactionClause(col_name, query_val)) + + return clone + + def filter(self, *args, **kwargs): + """ + Adds WHERE arguments to the queryset, returning a new queryset + + See :ref:`retrieving-objects-with-filters` + + Returns a QuerySet filtered on the keyword arguments + """ + # add arguments to the where clause filters + if len([x for x in kwargs.values() if x is None]): + raise CQLEngineException("None values on filter are not allowed") + + clone = copy.deepcopy(self) + for operator in args: + if not isinstance(operator, WhereClause): + raise 
QueryException('{} is not a valid query operator'.format(operator)) + clone._where.append(operator) + + for arg, val in kwargs.items(): + col_name, col_op = self._parse_filter_arg(arg) + quote_field = True + # resolve column and operator + try: + column = self.model._get_column(col_name) + except KeyError: + if col_name == 'pk__token': + if not isinstance(val, Token): + raise QueryException("Virtual column 'pk__token' may only be compared to Token() values") + column = columns._PartitionKeysToken(self.model) + quote_field = False + else: + raise QueryException("Can't resolve column name: '{}'".format(col_name)) + + if isinstance(val, Token): + if col_name != 'pk__token': + raise QueryException("Token() values may only be compared to the 'pk__token' virtual column") + partition_columns = column.partition_columns + if len(partition_columns) != len(val.value): + raise QueryException( + 'Token() received {} arguments but model has {} partition keys'.format( + len(val.value), len(partition_columns))) + val.set_columns(partition_columns) + + # get query operator, or use equals if not supplied + operator_class = BaseWhereOperator.get_operator(col_op or 'EQ') + operator = operator_class() + + if isinstance(operator, InOperator): + if not isinstance(val, (list, tuple)): + raise QueryException('IN queries must use a list/tuple value') + query_val = [column.to_database(v) for v in val] + elif isinstance(val, BaseQueryFunction): + query_val = val + else: + query_val = column.to_database(val) + + clone._where.append(WhereClause(column.db_field_name, operator, query_val, quote_field=quote_field)) + + return clone + + def get(self, *args, **kwargs): + """ + Returns a single instance matching this query, optionally with additional filter kwargs. + + See :ref:`retrieving-objects-with-filters` + + Returns a single object matching the QuerySet. + + .. code-block:: python + + user = User.get(id=1) + + If no objects are matched, a :class:`~.DoesNotExist` exception is raised. + + If more than one object is found, a :class:`~.MultipleObjectsReturned` exception is raised. + """ + if args or kwargs: + return self.filter(*args, **kwargs).get() + + self._execute_query() + if len(self._result_cache) == 0: + raise self.model.DoesNotExist + elif len(self._result_cache) > 1: + raise self.model.MultipleObjectsReturned('{} objects found'.format(len(self._result_cache))) + else: + return self[0] + + def _get_ordering_condition(self, colname): + order_type = 'DESC' if colname.startswith('-') else 'ASC' + colname = colname.replace('-', '') + + return colname, order_type + + def order_by(self, *colnames): + """ + Sets the column(s) to be used for ordering + + Default order is ascending, prepend a '-' to any column name for descending + + *Note: column names must be a clustering key* + + .. 
code-block:: python + + from uuid import uuid1,uuid4 + + class Comment(Model): + photo_id = UUID(primary_key=True) + comment_id = TimeUUID(primary_key=True, default=uuid1) # second primary key component is a clustering key + comment = Text() + + sync_table(Comment) + + u = uuid4() + for x in range(5): + Comment.create(photo_id=u, comment="test %d" % x) + + print("Normal") + for comment in Comment.objects(photo_id=u): + print comment.comment_id + + print("Reversed") + for comment in Comment.objects(photo_id=u).order_by("-comment_id"): + print comment.comment_id + """ + if len(colnames) == 0: + clone = copy.deepcopy(self) + clone._order = [] + return clone + + conditions = [] + for colname in colnames: + conditions.append('"{}" {}'.format(*self._get_ordering_condition(colname))) + + clone = copy.deepcopy(self) + clone._order.extend(conditions) + return clone + + def count(self): + """ + Returns the number of rows matched by this query + """ + if self._batch: + raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode") + + if self._result_cache is None: + query = self._select_query() + query.count = True + result = self._execute(query) + return result[0]['count'] + else: + return len(self._result_cache) + + def limit(self, v): + """ + Limits the number of results returned by Cassandra. + + *Note that CQL's default limit is 10,000, so all queries without a limit set explicitly will have an implicit limit of 10,000* + + .. code-block:: python + + for user in User.objects().limit(100): + print(user) + """ + if not (v is None or isinstance(v, six.integer_types)): + raise TypeError + if v == self._limit: + return self + + if v < 0: + raise QueryException("Negative limit is not allowed") + + clone = copy.deepcopy(self) + clone._limit = v + return clone + + def allow_filtering(self): + """ + Enables the (usually) unwise practive of querying on a clustering key without also defining a partition key + """ + clone = copy.deepcopy(self) + clone._allow_filtering = True + return clone + + def _only_or_defer(self, action, fields): + clone = copy.deepcopy(self) + if clone._defer_fields or clone._only_fields: + raise QueryException("QuerySet alread has only or defer fields defined") + + # check for strange fields + missing_fields = [f for f in fields if f not in self.model._columns.keys()] + if missing_fields: + raise QueryException( + "Can't resolve fields {} in {}".format( + ', '.join(missing_fields), self.model.__name__)) + + if action == 'defer': + clone._defer_fields = fields + elif action == 'only': + clone._only_fields = fields + else: + raise ValueError + + return clone + + def only(self, fields): + """ Load only these fields for the returned query """ + return self._only_or_defer('only', fields) + + def defer(self, fields): + """ Don't load these fields for the returned query """ + return self._only_or_defer('defer', fields) + + def create(self, **kwargs): + return self.model(**kwargs).batch(self._batch).ttl(self._ttl).\ + consistency(self._consistency).if_not_exists(self._if_not_exists).\ + timestamp(self._timestamp).save() + + def delete(self): + """ + Deletes the contents of a query + """ + # validate where clause + partition_key = [x for x in self.model._primary_keys.values()][0] + if not any([c.field == partition_key.column_name for c in self._where]): + raise QueryException("The partition key must be defined on delete queries") + + dq = DeleteStatement( + self.column_family_name, + where=self._where, + timestamp=self._timestamp + ) + self._execute(dq) + + def 
__eq__(self, q):
+        if len(self._where) == len(q._where):
+            return all([w in q._where for w in self._where])
+        return False
+
+    def __ne__(self, q):
+        return not (self == q)
+
+    def timeout(self, timeout):
+        """
+        :param timeout: Timeout for the query (in seconds)
+        :type timeout: float or None
+        """
+        clone = copy.deepcopy(self)
+        clone._timeout = timeout
+        return clone
+
+
+class ResultObject(dict):
+    """
+    adds attribute access to a dictionary
+    """
+
+    def __getattr__(self, item):
+        try:
+            return self[item]
+        except KeyError:
+            raise AttributeError
+
+
+class SimpleQuerySet(AbstractQuerySet):
+    """
+    queryset returning rows as ResultObject dicts, used for tables not mapped to a model
+    """
+
+    def _get_result_constructor(self):
+        """
+        Returns a function that will be used to instantiate query results
+        """
+        def _construct_instance(values):
+            return ResultObject(values)
+        return _construct_instance
+
+
+class ModelQuerySet(AbstractQuerySet):
+    """
+    queryset returning model instances
+    """
+    def _validate_select_where(self):
+        """ Checks that a filterset will not create an invalid select statement """
+        # check that there's either a = or IN relationship with a primary key or indexed field
+        equal_ops = [self.model._columns.get(w.field) for w in self._where if isinstance(w.operator, EqualsOperator)]
+        token_comparison = any([w for w in self._where if isinstance(w.value, Token)])
+        if not any([w.primary_key or w.index for w in equal_ops]) and not token_comparison and not self._allow_filtering:
+            raise QueryException('Where clauses require either a "=" or "IN" comparison with either a primary key or indexed field')
+
+        if not self._allow_filtering:
+            # if the query is not on an indexed field
+            if not any([w.index for w in equal_ops]):
+                if not any([w.partition_key for w in equal_ops]) and not token_comparison:
+                    raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the queryset')
+
+    def _select_fields(self):
+        if self._defer_fields or self._only_fields:
+            fields = self.model._columns.keys()
+            if self._defer_fields:
+                fields = [f for f in fields if f not in self._defer_fields]
+            elif self._only_fields:
+                fields = self._only_fields
+            return [self.model._columns[f].db_field_name for f in fields]
+        return super(ModelQuerySet, self)._select_fields()
+
+    def _get_result_constructor(self):
+        """ Returns a function that will be used to instantiate query results """
+        if not self._values_list:  # we want models
+            return lambda rows: self.model._construct_instance(rows)
+        elif self._flat_values_list:  # the user has requested flattened list (1 value per row)
+            return lambda row: row.popitem()[1]
+        else:
+            return lambda row: self._get_row_value_list(self._only_fields, row)
+
+    def _get_row_value_list(self, fields, row):
+        result = []
+        for x in fields:
+            result.append(row[x])
+        return result
+
+    def _get_ordering_condition(self, colname):
+        colname, order_type = super(ModelQuerySet, self)._get_ordering_condition(colname)
+
+        column = self.model._columns.get(colname)
+        if column is None:
+            raise QueryException("Can't resolve the column name: '{}'".format(colname))
+
+        # validate the column selection
+        if not column.primary_key:
+            raise QueryException(
+                "Can't order on '{}', can only order on (clustered) primary keys".format(colname))
+
+        pks = [v for k, v in self.model._columns.items() if v.primary_key]
+        if column == pks[0]:
+            raise QueryException(
+                "Can't order by the first primary key (partition key), clustering (secondary) keys only")
+
+        return column.db_field_name, order_type
+
+    def values_list(self, *fields, **kwargs):
+        """ Instructs the 
query set to return tuples, not model instances """
+        flat = kwargs.pop('flat', False)
+        if kwargs:
+            raise TypeError('Unexpected keyword arguments to values_list: %s'
+                            % (kwargs.keys(),))
+        if flat and len(fields) > 1:
+            raise TypeError("'flat' is not valid when values_list is called with more than one field.")
+        clone = self.only(fields)
+        clone._values_list = True
+        clone._flat_values_list = flat
+        return clone
+
+    def ttl(self, ttl):
+        """
+        Sets the ttl (in seconds) for modified data.
+
+        *Note that running a select query with a ttl value will raise an exception*
+        """
+        clone = copy.deepcopy(self)
+        clone._ttl = ttl
+        return clone
+
+    def timestamp(self, timestamp):
+        """
+        Allows for custom timestamps to be saved with the record.
+        """
+        clone = copy.deepcopy(self)
+        clone._timestamp = timestamp
+        return clone
+
+    def if_not_exists(self):
+        if self.model._has_counter:
+            raise IfNotExistsWithCounterColumn('if_not_exists cannot be used with tables containing counter columns')
+        clone = copy.deepcopy(self)
+        clone._if_not_exists = True
+        return clone
+
+    def update(self, **values):
+        """
+        Performs an update on the row selected by the queryset. Include values to update in the
+        update like so:
+
+        .. code-block:: python
+
+            Model.objects(key=n).update(value='x')
+
+        Passing in updates for columns which are not part of the model will raise a ValidationError.
+
+        Per column validation will be performed, but instance level validation will not
+        (i.e., `Model.validate` is not called). This is sometimes referred to as a blind update.
+
+        For example:
+
+        .. code-block:: python
+
+            class User(Model):
+                id = Integer(primary_key=True)
+                name = Text()
+
+            setup(["localhost"], "test")
+            sync_table(User)
+
+            u = User.create(id=1, name="jon")
+
+            User.objects(id=1).update(name="Steve")
+
+            # sets name to null
+            User.objects(id=1).update(name=None)
+
+
+        Also supported is blindly adding and removing elements from container columns,
+        without loading a model instance from Cassandra.
+
+        Using the syntax `.update(column_name={x, y, z})` will overwrite the contents of the container, like updating a
+        non container column. However, adding an operation suffix (such as `__add`) to the end of the keyword argument
+        makes the update call add or remove items from the collection, without overwriting the entire column.
+
+        Given the model below, here are the operations that can be performed on the different container columns:
+
+        .. code-block:: python
+
+            class Row(Model):
+                row_id = columns.Integer(primary_key=True)
+                set_column = columns.Set(Integer)
+                list_column = columns.List(Integer)
+                map_column = columns.Map(Integer, Integer)
+
+        :class:`~cqlengine.columns.Set`
+
+        - `add`: adds the elements of the given set to the column
+        - `remove`: removes the elements of the given set from the column
+
+
+        .. code-block:: python
+
+            # add elements to a set
+            Row.objects(row_id=5).update(set_column__add={6})
+
+            # remove elements from a set
+            Row.objects(row_id=5).update(set_column__remove={4})
+
+        :class:`~cqlengine.columns.List`
+
+        - `append`: appends the elements of the given list to the end of the column
+        - `prepend`: prepends the elements of the given list to the beginning of the column
+
+        .. 
code-block:: python + + # append items to a list + Row.objects(row_id=5).update(list_column__append=[6, 7]) + + # prepend items to a list + Row.objects(row_id=5).update(list_column__prepend=[1, 2]) + + + :class:`~cqlengine.columns.Map` + + - `update`: adds the given keys/values to the columns, creating new entries if they didn't exist, and overwriting old ones if they did + + .. code-block:: python + + # add items to a map + Row.objects(row_id=5).update(map_column__update={1: 2, 3: 4}) + """ + if not values: + return + + nulled_columns = set() + us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, + timestamp=self._timestamp, transactions=self._transaction) + for name, val in values.items(): + col_name, col_op = self._parse_filter_arg(name) + col = self.model._columns.get(col_name) + # check for nonexistant columns + if col is None: + raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.model.__name__, col_name)) + # check for primary key update attempts + if col.is_primary_key: + raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(col_name, self.__module__, self.model.__name__)) + + # we should not provide default values in this use case. + val = col.validate(val) + + if val is None: + nulled_columns.add(col_name) + continue + + # add the update statements + if isinstance(col, columns.Counter): + # TODO: implement counter updates + raise NotImplementedError + elif isinstance(col, (columns.List, columns.Set, columns.Map)): + if isinstance(col, columns.List): + klass = ListUpdateClause + elif isinstance(col, columns.Set): + klass = SetUpdateClause + elif isinstance(col, columns.Map): + klass = MapUpdateClause + else: + raise RuntimeError + us.add_assignment_clause(klass(col_name, col.to_database(val), operation=col_op)) + else: + us.add_assignment_clause(AssignmentClause( + col_name, col.to_database(val))) + + if us.assignments: + self._execute(us) + + if nulled_columns: + ds = DeleteStatement(self.column_family_name, fields=nulled_columns, where=self._where) + self._execute(ds) + + +class DMLQuery(object): + """ + A query object used for queries performing inserts, updates, or deletes + + this is usually instantiated by the model instance to be modified + + unlike the read query object, this is mutable + """ + _ttl = None + _consistency = None + _timestamp = None + _if_not_exists = False + + def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, + if_not_exists=False, transaction=None, timeout=connection.NOT_SET): + self.model = model + self.column_family_name = self.model.column_family_name() + self.instance = instance + self._batch = batch + self._ttl = ttl + self._consistency = consistency + self._timestamp = timestamp + self._if_not_exists = if_not_exists + self._transaction = transaction + self._timeout = timeout + + def _execute(self, q): + if self._batch: + return self._batch.add_query(q) + else: + tmp = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout) + if self._if_not_exists or self._transaction: + check_applied(tmp) + return tmp + + def batch(self, batch_obj): + if batch_obj is not None and not isinstance(batch_obj, BatchQuery): + raise CQLEngineException('batch_obj must be a BatchQuery instance or None') + self._batch = batch_obj + return self + + def _delete_null_columns(self): + """ + executes a delete query to remove columns that have changed to null + """ + ds = DeleteStatement(self.column_family_name) + 
deleted_fields = False + for _, v in self.instance._values.items(): + col = v.column + if v.deleted: + ds.add_field(col.db_field_name) + deleted_fields = True + elif isinstance(col, columns.Map): + uc = MapDeleteClause(col.db_field_name, v.value, v.previous_value) + if uc.get_context_size() > 0: + ds.add_field(uc) + deleted_fields = True + + if deleted_fields: + for name, col in self.model._primary_keys.items(): + ds.add_where_clause(WhereClause( + col.db_field_name, + EqualsOperator(), + col.to_database(getattr(self.instance, name)) + )) + self._execute(ds) + + def update(self): + """ + updates a row. + This is a blind update call. + All validation and cleaning needs to happen + prior to calling this. + """ + if self.instance is None: + raise CQLEngineException("DML Query instance attribute is None") + assert type(self.instance) == self.model + null_clustering_key = False if len(self.instance._clustering_keys) == 0 else True + static_changed_only = True + statement = UpdateStatement(self.column_family_name, ttl=self._ttl, + timestamp=self._timestamp, transactions=self._transaction) + for name, col in self.instance._clustering_keys.items(): + null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None)) + # get defined fields and their column names + for name, col in self.model._columns.items(): + # if clustering key is null, don't include non static columns + if null_clustering_key and not col.static and not col.partition_key: + continue + if not col.is_primary_key: + val = getattr(self.instance, name, None) + val_mgr = self.instance._values[name] + + # don't update something that is null + if val is None: + continue + + # don't update something if it hasn't changed + if not val_mgr.changed and not isinstance(col, columns.Counter): + continue + + static_changed_only = static_changed_only and col.static + if isinstance(col, (columns.BaseContainerColumn, columns.Counter)): + # get appropriate clause + if isinstance(col, columns.List): + klass = ListUpdateClause + elif isinstance(col, columns.Map): + klass = MapUpdateClause + elif isinstance(col, columns.Set): + klass = SetUpdateClause + elif isinstance(col, columns.Counter): + klass = CounterUpdateClause + else: + raise RuntimeError + + # build the container/counter update clause and only add it if it has work to do + clause = klass(col.db_field_name, val, + previous=val_mgr.previous_value, column=col) + if clause.get_context_size() > 0: + statement.add_assignment_clause(clause) + else: + statement.add_assignment_clause(AssignmentClause( + col.db_field_name, + col.to_database(val) + )) + + if statement.get_context_size() > 0 or self.instance._has_counter: + for name, col in self.model._primary_keys.items(): + # only include clustering key if clustering key is not null, and non static columns are changed to avoid cql error + if (null_clustering_key or static_changed_only) and (not col.partition_key): + continue + statement.add_where_clause(WhereClause( + col.db_field_name, + EqualsOperator(), + col.to_database(getattr(self.instance, name)) + )) + self._execute(statement) + + if not null_clustering_key: + self._delete_null_columns() + + def save(self): + """ + Creates / updates a row. + This is a blind insert call. + All validation and cleaning needs to happen + prior to calling this.
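+
+        Roughly speaking (an illustrative note): when the instance has a counter column
+        or reports ``_can_update()``, this call delegates to :meth:`update`; otherwise an
+        INSERT statement is built, honoring the ttl, timestamp and if_not_exists options.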
+ """ + if self.instance is None: + raise CQLEngineException("DML Query intance attribute is None") + assert type(self.instance) == self.model + + nulled_fields = set() + if self.instance._has_counter or self.instance._can_update(): + return self.update() + else: + insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists) + static_save_only = False if len(self.instance._clustering_keys) == 0 else True + for name, col in self.instance._clustering_keys.items(): + static_save_only = static_save_only and col._val_is_null(getattr(self.instance, name, None)) + for name, col in self.instance._columns.items(): + if static_save_only and not col.static and not col.partition_key: + continue + val = getattr(self.instance, name, None) + if col._val_is_null(val): + if self.instance._values[name].changed: + nulled_fields.add(col.db_field_name) + continue + insert.add_assignment_clause(AssignmentClause( + col.db_field_name, + col.to_database(getattr(self.instance, name, None)) + )) + + # skip query execution if it's empty + # caused by pointless update queries + if not insert.is_empty: + self._execute(insert) + # delete any nulled columns + if not static_save_only: + self._delete_null_columns() + + def delete(self): + """ Deletes one instance """ + if self.instance is None: + raise CQLEngineException("DML Query instance attribute is None") + + ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp) + for name, col in self.model._primary_keys.items(): + if (not col.partition_key) and (getattr(self.instance, name) is None): + continue + + ds.add_where_clause(WhereClause( + col.db_field_name, + EqualsOperator(), + col.to_database(getattr(self.instance, name)) + )) + self._execute(ds) diff --git a/cassandra/cqlengine/statements.py b/cassandra/cqlengine/statements.py new file mode 100644 index 0000000..f64cc7a --- /dev/null +++ b/cassandra/cqlengine/statements.py @@ -0,0 +1,837 @@ +# Copyright 2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from datetime import datetime, timedelta +import logging +import time +import six +import warnings + +from cassandra.cqlengine import UnicodeMixin +from cassandra.cqlengine.functions import QueryValue +from cassandra.cqlengine.operators import BaseWhereOperator, InOperator + +log = logging.getLogger(__name__) + + +class StatementException(Exception): + pass + + +class ValueQuoter(UnicodeMixin): + + def __init__(self, value): + self.value = value + + def __unicode__(self): + from cassandra.encoder import cql_quote + if isinstance(self.value, bool): + return 'true' if self.value else 'false' + elif isinstance(self.value, (list, tuple)): + return '[' + ', '.join([cql_quote(v) for v in self.value]) + ']' + elif isinstance(self.value, dict): + return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k, v in self.value.items()]) + '}' + elif isinstance(self.value, set): + return '{' + ', '.join([cql_quote(v) for v in self.value]) + '}' + return cql_quote(self.value) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.value == other.value + return False + + +class InQuoter(ValueQuoter): + + def __unicode__(self): + from cassandra.encoder import cql_quote + return '(' + ', '.join([cql_quote(v) for v in self.value]) + ')' + + +class BaseClause(UnicodeMixin): + + def __init__(self, field, value): + self.field = field + self.value = value + self.context_id = None + + def __unicode__(self): + raise NotImplementedError + + def __hash__(self): + return hash(self.field) ^ hash(self.value) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.field == other.field and self.value == other.value + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def get_context_size(self): + """ returns the number of entries this clause will add to the query context """ + return 1 + + def set_context_id(self, i): + """ sets the value placeholder that will be used in the query """ + self.context_id = i + + def update_context(self, ctx): + """ updates the query context with this clauses values """ + assert isinstance(ctx, dict) + ctx[str(self.context_id)] = self.value + + +class WhereClause(BaseClause): + """ a single where statement used in queries """ + + def __init__(self, field, operator, value, quote_field=True): + """ + + :param field: + :param operator: + :param value: + :param quote_field: hack to get the token function rendering properly + :return: + """ + if not isinstance(operator, BaseWhereOperator): + raise StatementException( + "operator must be of type {}, got {}".format(BaseWhereOperator, type(operator)) + ) + super(WhereClause, self).__init__(field, value) + self.operator = operator + self.query_value = self.value if isinstance(self.value, QueryValue) else QueryValue(self.value) + self.quote_field = quote_field + + def __unicode__(self): + field = ('"{}"' if self.quote_field else '{}').format(self.field) + return u'{} {} {}'.format(field, self.operator, six.text_type(self.query_value)) + + def __hash__(self): + return super(WhereClause, self).__hash__() ^ hash(self.operator) + + def __eq__(self, other): + if super(WhereClause, self).__eq__(other): + return self.operator.__class__ == other.operator.__class__ + return False + + def get_context_size(self): + return self.query_value.get_context_size() + + def set_context_id(self, i): + super(WhereClause, self).set_context_id(i) + self.query_value.set_context_id(i) + + def update_context(self, ctx): + if isinstance(self.operator, InOperator): + ctx[str(self.context_id)] 
= InQuoter(self.value) + else: + self.query_value.update_context(ctx) + + +class AssignmentClause(BaseClause): + """ a single variable st statement """ + + def __unicode__(self): + return u'"{}" = %({})s'.format(self.field, self.context_id) + + def insert_tuple(self): + return self.field, self.context_id + + +class TransactionClause(BaseClause): + """ A single variable iff statement """ + + def __unicode__(self): + return u'"{}" = %({})s'.format(self.field, self.context_id) + + def insert_tuple(self): + return self.field, self.context_id + + +class ContainerUpdateClause(AssignmentClause): + + def __init__(self, field, value, operation=None, previous=None, column=None): + super(ContainerUpdateClause, self).__init__(field, value) + self.previous = previous + self._assignments = None + self._operation = operation + self._analyzed = False + self._column = column + + def _to_database(self, val): + return self._column.to_database(val) if self._column else val + + def _analyze(self): + raise NotImplementedError + + def get_context_size(self): + raise NotImplementedError + + def update_context(self, ctx): + raise NotImplementedError + + +class SetUpdateClause(ContainerUpdateClause): + """ updates a set collection """ + + def __init__(self, field, value, operation=None, previous=None, column=None): + super(SetUpdateClause, self).__init__(field, value, operation, previous, column=column) + self._additions = None + self._removals = None + + def __unicode__(self): + qs = [] + ctx_id = self.context_id + if (self.previous is None and + self._assignments is None and + self._additions is None and + self._removals is None): + qs += ['"{}" = %({})s'.format(self.field, ctx_id)] + if self._assignments is not None: + qs += ['"{}" = %({})s'.format(self.field, ctx_id)] + ctx_id += 1 + if self._additions is not None: + qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)] + ctx_id += 1 + if self._removals is not None: + qs += ['"{0}" = "{0}" - %({1})s'.format(self.field, ctx_id)] + + return ', '.join(qs) + + def _analyze(self): + """ works out the updates to be performed """ + if self.value is None or self.value == self.previous: + pass + elif self._operation == "add": + self._additions = self.value + elif self._operation == "remove": + self._removals = self.value + elif self.previous is None: + self._assignments = self.value + else: + # partial update time + self._additions = (self.value - self.previous) or None + self._removals = (self.previous - self.value) or None + self._analyzed = True + + def get_context_size(self): + if not self._analyzed: + self._analyze() + if (self.previous is None and + not self._assignments and + self._additions is None and + self._removals is None): + return 1 + return int(bool(self._assignments)) + int(bool(self._additions)) + int(bool(self._removals)) + + def update_context(self, ctx): + if not self._analyzed: + self._analyze() + ctx_id = self.context_id + if (self.previous is None and + self._assignments is None and + self._additions is None and + self._removals is None): + ctx[str(ctx_id)] = self._to_database({}) + if self._assignments is not None: + ctx[str(ctx_id)] = self._to_database(self._assignments) + ctx_id += 1 + if self._additions is not None: + ctx[str(ctx_id)] = self._to_database(self._additions) + ctx_id += 1 + if self._removals is not None: + ctx[str(ctx_id)] = self._to_database(self._removals) + + +class ListUpdateClause(ContainerUpdateClause): + """ updates a list collection """ + + def __init__(self, field, value, operation=None, previous=None, 
column=None): + super(ListUpdateClause, self).__init__(field, value, operation, previous, column=column) + self._append = None + self._prepend = None + + def __unicode__(self): + if not self._analyzed: + self._analyze() + qs = [] + ctx_id = self.context_id + if self._assignments is not None: + qs += ['"{}" = %({})s'.format(self.field, ctx_id)] + ctx_id += 1 + + if self._prepend is not None: + qs += ['"{0}" = %({1})s + "{0}"'.format(self.field, ctx_id)] + ctx_id += 1 + + if self._append is not None: + qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)] + + return ', '.join(qs) + + def get_context_size(self): + if not self._analyzed: + self._analyze() + return int(self._assignments is not None) + int(bool(self._append)) + int(bool(self._prepend)) + + def update_context(self, ctx): + if not self._analyzed: + self._analyze() + ctx_id = self.context_id + if self._assignments is not None: + ctx[str(ctx_id)] = self._to_database(self._assignments) + ctx_id += 1 + if self._prepend is not None: + msg = "Previous versions of cqlengine implicitly reversed prepended lists to account for CASSANDRA-8733. " \ + "THIS VERSION DOES NOT. This warning will be removed in a future release." + warnings.warn(msg) + log.warning(msg) + + ctx[str(ctx_id)] = self._to_database(self._prepend) + ctx_id += 1 + if self._append is not None: + ctx[str(ctx_id)] = self._to_database(self._append) + + def _analyze(self): + """ works out the updates to be performed """ + if self.value is None or self.value == self.previous: + pass + + elif self._operation == "append": + self._append = self.value + + elif self._operation == "prepend": + self._prepend = self.value + + elif self.previous is None: + self._assignments = self.value + + elif len(self.value) < len(self.previous): + # if elements have been removed, + # rewrite the whole list + self._assignments = self.value + + elif len(self.previous) == 0: + # if we're updating from an empty + # list, do a complete insert + self._assignments = self.value + else: + + # the max start idx we want to compare + search_space = len(self.value) - max(0, len(self.previous) - 1) + + # the size of the sub lists we want to look at + search_size = len(self.previous) + + for i in range(search_space): + # slice boundary + j = i + search_size + sub = self.value[i:j] + idx_cmp = lambda idx: self.previous[idx] == sub[idx] + if idx_cmp(0) and idx_cmp(-1) and self.previous == sub: + self._prepend = self.value[:i] or None + self._append = self.value[j:] or None + break + + # if both append and prepend are still None after looking + # at both lists, an insert statement will be created + if self._prepend is self._append is None: + self._assignments = self.value + + self._analyzed = True + + +class MapUpdateClause(ContainerUpdateClause): + """ updates a map collection """ + + def __init__(self, field, value, operation=None, previous=None, column=None): + super(MapUpdateClause, self).__init__(field, value, operation, previous, column=column) + self._updates = None + + def _analyze(self): + if self._operation == "update": + self._updates = self.value.keys() + else: + if self.previous is None: + self._updates = sorted([k for k, v in self.value.items()]) + else: + self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None + self._analyzed = True + + def get_context_size(self): + if not self._analyzed: + self._analyze() + if self.previous is None and not self._updates: + return 1 + return len(self._updates or []) * 2 + + def update_context(self, ctx): + if not 
self._analyzed: + self._analyze() + ctx_id = self.context_id + if self.previous is None and not self._updates: + ctx[str(ctx_id)] = {} + else: + for key in self._updates or []: + val = self.value.get(key) + ctx[str(ctx_id)] = self._column.key_col.to_database(key) if self._column else key + ctx[str(ctx_id + 1)] = self._column.value_col.to_database(val) if self._column else val + ctx_id += 2 + + def __unicode__(self): + if not self._analyzed: + self._analyze() + qs = [] + + ctx_id = self.context_id + if self.previous is None and not self._updates: + qs += ['"{}" = %({})s'.format(self.field, ctx_id)] + else: + for _ in self._updates or []: + qs += ['"{}"[%({})s] = %({})s'.format(self.field, ctx_id, ctx_id + 1)] + ctx_id += 2 + + return ', '.join(qs) + + +class CounterUpdateClause(ContainerUpdateClause): + + def __init__(self, field, value, previous=None, column=None): + super(CounterUpdateClause, self).__init__(field, value, previous=previous, column=column) + self.previous = self.previous or 0 + + def get_context_size(self): + return 1 + + def update_context(self, ctx): + ctx[str(self.context_id)] = self._to_database(abs(self.value - self.previous)) + + def __unicode__(self): + delta = self.value - self.previous + sign = '-' if delta < 0 else '+' + return '"{0}" = "{0}" {1} %({2})s'.format(self.field, sign, self.context_id) + + +class BaseDeleteClause(BaseClause): + pass + + +class FieldDeleteClause(BaseDeleteClause): + """ deletes a field from a row """ + + def __init__(self, field): + super(FieldDeleteClause, self).__init__(field, None) + + def __unicode__(self): + return '"{}"'.format(self.field) + + def update_context(self, ctx): + pass + + def get_context_size(self): + return 0 + + +class MapDeleteClause(BaseDeleteClause): + """ removes keys from a map """ + + def __init__(self, field, value, previous=None): + super(MapDeleteClause, self).__init__(field, value) + self.value = self.value or {} + self.previous = previous or {} + self._analyzed = False + self._removals = None + + def _analyze(self): + self._removals = sorted([k for k in self.previous if k not in self.value]) + self._analyzed = True + + def update_context(self, ctx): + if not self._analyzed: + self._analyze() + for idx, key in enumerate(self._removals): + ctx[str(self.context_id + idx)] = key + + def get_context_size(self): + if not self._analyzed: + self._analyze() + return len(self._removals) + + def __unicode__(self): + if not self._analyzed: + self._analyze() + return ', '.join(['"{}"[%({})s]'.format(self.field, self.context_id + i) for i in range(len(self._removals))]) + + +class BaseCQLStatement(UnicodeMixin): + """ The base cql statement class """ + + def __init__(self, table, consistency=None, timestamp=None, where=None): + super(BaseCQLStatement, self).__init__() + self.table = table + self.consistency = consistency + self.context_id = 0 + self.context_counter = self.context_id + self.timestamp = timestamp + + self.where_clauses = [] + for clause in where or []: + self.add_where_clause(clause) + + def add_where_clause(self, clause): + """ + adds a where clause to this statement + :param clause: the clause to add + :type clause: WhereClause + """ + if not isinstance(clause, WhereClause): + raise StatementException("only instances of WhereClause can be added to statements") + clause.set_context_id(self.context_counter) + self.context_counter += clause.get_context_size() + self.where_clauses.append(clause) + + def get_context(self): + """ + returns the context dict for this statement + :rtype: dict + """ + ctx = {} + 
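+        # each where clause writes its bound value(s) into ctx, keyed by its stringified context id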
for clause in self.where_clauses or []: + clause.update_context(ctx) + return ctx + + def get_context_size(self): + return len(self.get_context()) + + def update_context_id(self, i): + self.context_id = i + self.context_counter = self.context_id + for clause in self.where_clauses: + clause.set_context_id(self.context_counter) + self.context_counter += clause.get_context_size() + + @property + def timestamp_normalized(self): + """ + we're expecting self.timestamp to be either a long, int, a datetime, or a timedelta + :return: + """ + if not self.timestamp: + return None + + if isinstance(self.timestamp, six.integer_types): + return self.timestamp + + if isinstance(self.timestamp, timedelta): + tmp = datetime.now() + self.timestamp + else: + tmp = self.timestamp + + return int(time.mktime(tmp.timetuple()) * 1e+6 + tmp.microsecond) + + def __unicode__(self): + raise NotImplementedError + + def __repr__(self): + return self.__unicode__() + + @property + def _where(self): + return 'WHERE {}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses])) + + +class SelectStatement(BaseCQLStatement): + """ a cql select statement """ + + def __init__(self, + table, + fields=None, + count=False, + consistency=None, + where=None, + order_by=None, + limit=None, + allow_filtering=False): + + """ + :param where + :type where list of cqlengine.statements.WhereClause + """ + super(SelectStatement, self).__init__( + table, + consistency=consistency, + where=where + ) + + self.fields = [fields] if isinstance(fields, six.string_types) else (fields or []) + self.count = count + self.order_by = [order_by] if isinstance(order_by, six.string_types) else order_by + self.limit = limit + self.allow_filtering = allow_filtering + + def __unicode__(self): + qs = ['SELECT'] + if self.count: + qs += ['COUNT(*)'] + else: + qs += [', '.join(['"{}"'.format(f) for f in self.fields]) if self.fields else '*'] + qs += ['FROM', self.table] + + if self.where_clauses: + qs += [self._where] + + if self.order_by and not self.count: + qs += ['ORDER BY {}'.format(', '.join(six.text_type(o) for o in self.order_by))] + + if self.limit: + qs += ['LIMIT {}'.format(self.limit)] + + if self.allow_filtering: + qs += ['ALLOW FILTERING'] + + return ' '.join(qs) + + +class AssignmentStatement(BaseCQLStatement): + """ value assignment statements """ + + def __init__(self, + table, + assignments=None, + consistency=None, + where=None, + ttl=None, + timestamp=None): + super(AssignmentStatement, self).__init__( + table, + consistency=consistency, + where=where, + ) + self.ttl = ttl + self.timestamp = timestamp + + # add assignments + self.assignments = [] + for assignment in assignments or []: + self.add_assignment_clause(assignment) + + def update_context_id(self, i): + super(AssignmentStatement, self).update_context_id(i) + for assignment in self.assignments: + assignment.set_context_id(self.context_counter) + self.context_counter += assignment.get_context_size() + + def add_assignment_clause(self, clause): + """ + adds an assignment clause to this statement + :param clause: the clause to add + :type clause: AssignmentClause + """ + if not isinstance(clause, AssignmentClause): + raise StatementException("only instances of AssignmentClause can be added to statements") + clause.set_context_id(self.context_counter) + self.context_counter += clause.get_context_size() + self.assignments.append(clause) + + @property + def is_empty(self): + return len(self.assignments) == 0 + + def get_context(self): + ctx = super(AssignmentStatement, 
self).get_context() + for clause in self.assignments: + clause.update_context(ctx) + return ctx + + +class InsertStatement(AssignmentStatement): + """ a cql insert statement """ + + def __init__(self, + table, + assignments=None, + consistency=None, + where=None, + ttl=None, + timestamp=None, + if_not_exists=False): + super(InsertStatement, self).__init__(table, + assignments=assignments, + consistency=consistency, + where=where, + ttl=ttl, + timestamp=timestamp) + + self.if_not_exists = if_not_exists + + def add_where_clause(self, clause): + raise StatementException("Cannot add where clauses to insert statements") + + def __unicode__(self): + qs = ['INSERT INTO {}'.format(self.table)] + + # get column names and context placeholders + fields = [a.insert_tuple() for a in self.assignments] + columns, values = zip(*fields) + + qs += ["({})".format(', '.join(['"{}"'.format(c) for c in columns]))] + qs += ['VALUES'] + qs += ["({})".format(', '.join(['%({})s'.format(v) for v in values]))] + + if self.if_not_exists: + qs += ["IF NOT EXISTS"] + + if self.ttl: + qs += ["USING TTL {}".format(self.ttl)] + + if self.timestamp: + qs += ["USING TIMESTAMP {}".format(self.timestamp_normalized)] + + return ' '.join(qs) + + +class UpdateStatement(AssignmentStatement): + """ a cql update statement """ + + def __init__(self, + table, + assignments=None, + consistency=None, + where=None, + ttl=None, + timestamp=None, + transactions=None): + super(UpdateStatement, self).__init__(table, + assignments=assignments, + consistency=consistency, + where=where, + ttl=ttl, + timestamp=timestamp) + + # Add iff statements + self.transactions = [] + for transaction in transactions or []: + self.add_transaction_clause(transaction) + + def __unicode__(self): + qs = ['UPDATE', self.table] + + using_options = [] + + if self.ttl: + using_options += ["TTL {}".format(self.ttl)] + + if self.timestamp: + using_options += ["TIMESTAMP {}".format(self.timestamp_normalized)] + + if using_options: + qs += ["USING {}".format(" AND ".join(using_options))] + + qs += ['SET'] + qs += [', '.join([six.text_type(c) for c in self.assignments])] + + if self.where_clauses: + qs += [self._where] + + if len(self.transactions) > 0: + qs += [self._get_transactions()] + + return ' '.join(qs) + + def add_transaction_clause(self, clause): + """ + Adds an iff clause to this statement + + :param clause: The clause that will be added to the iff statement + :type clause: TransactionClause + """ + if not isinstance(clause, TransactionClause): + raise StatementException('only instances of TransactionClause can be added to statements') + clause.set_context_id(self.context_counter) + self.context_counter += clause.get_context_size() + self.transactions.append(clause) + + def get_context(self): + ctx = super(UpdateStatement, self).get_context() + for clause in self.transactions or []: + clause.update_context(ctx) + return ctx + + def _get_transactions(self): + return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.transactions])) + + def update_context_id(self, i): + super(UpdateStatement, self).update_context_id(i) + for transaction in self.transactions: + transaction.set_context_id(self.context_counter) + self.context_counter += transaction.get_context_size() + + +class DeleteStatement(BaseCQLStatement): + """ a cql delete statement """ + + def __init__(self, table, fields=None, consistency=None, where=None, timestamp=None): + super(DeleteStatement, self).__init__( + table, + consistency=consistency, + where=where, +
timestamp=timestamp + ) + self.fields = [] + if isinstance(fields, six.string_types): + fields = [fields] + for field in fields or []: + self.add_field(field) + + def update_context_id(self, i): + super(DeleteStatement, self).update_context_id(i) + for field in self.fields: + field.set_context_id(self.context_counter) + self.context_counter += field.get_context_size() + + def get_context(self): + ctx = super(DeleteStatement, self).get_context() + for field in self.fields: + field.update_context(ctx) + return ctx + + def add_field(self, field): + if isinstance(field, six.string_types): + field = FieldDeleteClause(field) + if not isinstance(field, BaseClause): + raise StatementException("only instances of AssignmentClause can be added to statements") + field.set_context_id(self.context_counter) + self.context_counter += field.get_context_size() + self.fields.append(field) + + def __unicode__(self): + qs = ['DELETE'] + if self.fields: + qs += [', '.join(['{}'.format(f) for f in self.fields])] + qs += ['FROM', self.table] + + delete_option = [] + + if self.timestamp: + delete_option += ["TIMESTAMP {}".format(self.timestamp_normalized)] + + if delete_option: + qs += [" USING {} ".format(" AND ".join(delete_option))] + + if self.where_clauses: + qs += [self._where] + + return ' '.join(qs) diff --git a/cassandra/cqlengine/usertype.py b/cassandra/cqlengine/usertype.py new file mode 100644 index 0000000..1e30fc8 --- /dev/null +++ b/cassandra/cqlengine/usertype.py @@ -0,0 +1,204 @@ +import re +import six + +from cassandra.util import OrderedDict +from cassandra.cqlengine import CQLEngineException +from cassandra.cqlengine import columns +from cassandra.cqlengine import connection +from cassandra.cqlengine import models + + +class UserTypeException(CQLEngineException): + pass + + +class UserTypeDefinitionException(UserTypeException): + pass + + +class BaseUserType(object): + """ + The base type class; don't inherit from this, inherit from UserType, defined below + """ + __type_name__ = None + + _fields = None + _db_map = None + + def __init__(self, **values): + self._values = {} + + for name, field in self._fields.items(): + value = values.get(name, None) + if value is not None or isinstance(field, columns.BaseContainerColumn): + value = field.to_python(value) + value_mngr = field.value_manager(self, field, value) + if name in values: + value_mngr.explicit = True + self._values[name] = value_mngr + + def __eq__(self, other): + if self.__class__ != other.__class__: + return False + + keys = set(self._fields.keys()) + other_keys = set(other._fields.keys()) + if keys != other_keys: + return False + + for key in other_keys: + if getattr(self, key, None) != getattr(other, key, None): + return False + + return True + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return "{{{}}}".format(', '.join("'{}': {}".format(k, getattr(self, k)) for k, v in six.iteritems(self._values))) + + def has_changed_fields(self): + return any(v.changed for v in self._values.values()) + + def reset_changed_fields(self): + for v in self._values.values(): + v.reset_previous_value() + + def __iter__(self): + for field in self._fields.keys(): + yield field + + def __getitem__(self, key): + if not isinstance(key, six.string_types): + raise TypeError + if key not in self._fields.keys(): + raise KeyError + return getattr(self, key) + + def __setitem__(self, key, val): + if not isinstance(key, six.string_types): + raise TypeError + if key not in self._fields.keys(): + raise KeyError + return 
setattr(self, key, val) + + def __len__(self): + try: + return self._len + except: + self._len = len(self._columns.keys()) + return self._len + + def keys(self): + """ Returns a list of column IDs. """ + return [k for k in self] + + def values(self): + """ Returns list of column values. """ + return [self[k] for k in self] + + def items(self): + """ Returns a list of column ID/value tuples. """ + return [(k, self[k]) for k in self] + + @classmethod + def register_for_keyspace(cls, keyspace): + connection.register_udt(keyspace, cls.type_name(), cls) + + @classmethod + def type_name(cls): + """ + Returns the type name if it's been defined + otherwise, it creates it from the class name + """ + if cls.__type_name__: + type_name = cls.__type_name__.lower() + else: + camelcase = re.compile(r'([a-z])([A-Z])') + ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2)), s) + + type_name = ccase(cls.__name__) + # trim to less than 48 characters or cassandra will complain + type_name = type_name[-48:] + type_name = type_name.lower() + type_name = re.sub(r'^_+', '', type_name) + cls.__type_name__ = type_name + + return type_name + + def validate(self): + """ + Cleans and validates the field values + """ + for name, field in self._fields.items(): + v = getattr(self, name) + if v is None and not self._values[name].explicit and field.has_default: + v = field.get_default() + val = field.validate(v) + setattr(self, name, val) + + +class UserTypeMetaClass(type): + + def __new__(cls, name, bases, attrs): + field_dict = OrderedDict() + + field_defs = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)] + field_defs = sorted(field_defs, key=lambda x: x[1].position) + + def _transform_column(field_name, field_obj): + field_dict[field_name] = field_obj + field_obj.set_column_name(field_name) + attrs[field_name] = models.ColumnDescriptor(field_obj) + + # transform field definitions + for k, v in field_defs: + # don't allow a field with the same name as a built-in attribute or method + if k in BaseUserType.__dict__: + raise UserTypeDefinitionException("field '{}' conflicts with built-in attribute/method".format(k)) + _transform_column(k, v) + + # create db_name -> model name map for loading + db_map = {} + for field_name, field in field_dict.items(): + db_map[field.db_field_name] = field_name + + attrs['_fields'] = field_dict + attrs['_db_map'] = db_map + + klass = super(UserTypeMetaClass, cls).__new__(cls, name, bases, attrs) + + return klass + + +@six.add_metaclass(UserTypeMetaClass) +class UserType(BaseUserType): + """ + This class is used to model User Defined Types. To define a type, declare a class inheriting from this, + and assign field types as class attributes: + + .. code-block:: python + + # connect with default keyspace ... + + from cassandra.cqlengine.columns import Text, Integer + from cassandra.cqlengine.usertype import UserType + + class address(UserType): + street = Text() + zipcode = Integer() + + from cassandra.cqlengine import management + management.sync_type(address) + + Please see :ref:`user_types` for a complete example and discussion. + """ + + __type_name__ = None + """ + *Optional.* Sets the name of the CQL type for this type. + + If not specified, the type name will be the name of the class, with its module name as its prefix. + """ diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py new file mode 100644 index 0000000..5b95d4c --- /dev/null +++ b/cassandra/cqltypes.py @@ -0,0 +1,1075 @@ +# Copyright 2013-2015 DataStax, Inc.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Representation of Cassandra data types. These classes should make it simple for +the library (and caller software) to deal with Cassandra-style Java class type +names and CQL type specifiers, and convert between them cleanly. Parameterized +types are fully supported in both flavors. Once you have the right Type object +for the type you want, you can use it to serialize, deserialize, or retrieve +the corresponding CQL or Cassandra type strings. +""" + +# NOTE: +# If/when the need arises for interpret types from CQL string literals in +# different ways (for https://issues.apache.org/jira/browse/CASSANDRA-3799, +# for example), these classes would be a good place to tack on +# .from_cql_literal() and .as_cql_literal() classmethods (or whatever). + +from __future__ import absolute_import # to enable import io from stdlib +from binascii import unhexlify +import calendar +from collections import namedtuple +from decimal import Decimal +import io +import logging +import re +import socket +import time +import six +from six.moves import range +import sys +from uuid import UUID +import warnings + + +from cassandra.marshal import (int8_pack, int8_unpack, + uint16_pack, uint16_unpack, uint32_pack, uint32_unpack, + int32_pack, int32_unpack, int64_pack, int64_unpack, + float_pack, float_unpack, double_pack, double_unpack, + varint_pack, varint_unpack) +from cassandra import util + +apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.' + +log = logging.getLogger(__name__) + +if six.PY3: + _number_types = frozenset((int, float)) + long = int + + def _name_from_hex_string(encoded_name): + bin_str = unhexlify(encoded_name) + return bin_str.decode('ascii') +else: + _number_types = frozenset((int, long, float)) + _name_from_hex_string = unhexlify + + +def trim_if_startswith(s, prefix): + if s.startswith(prefix): + return s[len(prefix):] + return s + + +def unix_time_from_uuid1(u): + msg = "'cassandra.cqltypes.unix_time_from_uuid1' has moved to 'cassandra.util'. This entry point will be removed in the next major version." + warnings.warn(msg, DeprecationWarning) + log.warning(msg) + return util.unix_time_from_uuid1(u) + + +def datetime_from_timestamp(timestamp): + msg = "'cassandra.cqltypes.datetime_from_timestamp' has moved to 'cassandra.util'. This entry point will be removed in the next major version." + warnings.warn(msg, DeprecationWarning) + log.warning(msg) + return util.datetime_from_timestamp(timestamp) + + +_casstypes = {} + + +class CassandraTypeType(type): + """ + The CassandraType objects in this module will normally be used directly, + rather than through instances of those types. They can be instantiated, + of course, but the type information is what this driver mainly needs. + + This metaclass registers CassandraType classes in the global + by-cassandra-typename and by-cql-typename registries, unless their class + name starts with an underscore. 
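+
+    For example (with a hypothetical subclass): defining ``class FooType(_CassandraType)``
+    registers it under the name ``FooType``, so it can later be resolved with
+    ``lookup_casstype('FooType')``.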
+ """ + + def __new__(metacls, name, bases, dct): + dct.setdefault('cassname', name) + cls = type.__new__(metacls, name, bases, dct) + if not name.startswith('_'): + _casstypes[name] = cls + return cls + + +casstype_scanner = re.Scanner(( + (r'[()]', lambda s, t: t), + (r'[a-zA-Z0-9_.:=>]+', lambda s, t: t), + (r'[\s,]', None), +)) + + +def lookup_casstype_simple(casstype): + """ + Given a Cassandra type name (either fully distinguished or not), hand + back the CassandraType class responsible for it. If a name is not + recognized, a custom _UnrecognizedType subclass will be created for it. + + This function does not handle complex types (so no type parameters-- + nothing with parentheses). Use lookup_casstype() instead if you might need + that. + """ + shortname = trim_if_startswith(casstype, apache_cassandra_type_prefix) + try: + typeclass = _casstypes[shortname] + except KeyError: + typeclass = mkUnrecognizedType(casstype) + return typeclass + + +def parse_casstype_args(typestring): + tokens, remainder = casstype_scanner.scan(typestring) + if remainder: + raise ValueError("weird characters %r at end" % remainder) + + # use a stack of (types, names) lists + args = [([], [])] + for tok in tokens: + if tok == '(': + args.append(([], [])) + elif tok == ')': + types, names = args.pop() + prev_types, prev_names = args[-1] + prev_types[-1] = prev_types[-1].apply_parameters(types, names) + else: + types, names = args[-1] + parts = re.split(':|=>', tok) + tok = parts.pop() + if parts: + names.append(parts[0]) + else: + names.append(None) + + ctype = lookup_casstype_simple(tok) + types.append(ctype) + + # return the first (outer) type, which will have all parameters applied + return args[0][0][0] + + +def lookup_casstype(casstype): + """ + Given a Cassandra type as a string (possibly including parameters), hand + back the CassandraType class responsible for it. If a name is not + recognized, a custom _UnrecognizedType subclass will be created for it. + + Example: + + >>> lookup_casstype('org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.Int32Type)') + + + """ + if isinstance(casstype, (CassandraType, CassandraTypeType)): + return casstype + try: + return parse_casstype_args(casstype) + except (ValueError, AssertionError, IndexError) as e: + raise ValueError("Don't know how to parse type string %r: %s" % (casstype, e)) + + +class EmptyValue(object): + """ See _CassandraType.support_empty_values """ + + def __str__(self): + return "EMPTY" + __repr__ = __str__ + +EMPTY = EmptyValue() + + +@six.add_metaclass(CassandraTypeType) +class _CassandraType(object): + subtypes = () + num_subtypes = 0 + empty_binary_ok = False + + support_empty_values = False + """ + Back in the Thrift days, empty strings were used for "null" values of + all types, including non-string types. For most users, an empty + string value in an int column is the same as being null/not present, + so the driver normally returns None in this case. (For string-like + types, it *will* return an empty string by default instead of None.) + + To avoid this behavior, set this to :const:`True`. Instead of returning + None for empty string values, the EMPTY singleton (an instance + of EmptyValue) will be returned. 
+ """ + + def __init__(self, val): + self.val = self.validate(val) + + def __repr__(self): + return '<%s( %r )>' % (self.cql_parameterized_type(), self.val) + + @staticmethod + def validate(val): + """ + Called to transform an input value into one of a suitable type + for this class. As an example, the BooleanType class uses this + to convert an incoming value to True or False. + """ + return val + + @classmethod + def from_binary(cls, byts, protocol_version): + """ + Deserialize a bytestring into a value. See the deserialize() method + for more information. This method differs in that if None or the empty + string is passed in, None may be returned. + """ + if byts is None: + return None + elif len(byts) == 0 and not cls.empty_binary_ok: + return EMPTY if cls.support_empty_values else None + return cls.deserialize(byts, protocol_version) + + @classmethod + def to_binary(cls, val, protocol_version): + """ + Serialize a value into a bytestring. See the serialize() method for + more information. This method differs in that if None is passed in, + the result is the empty string. + """ + return b'' if val is None else cls.serialize(val, protocol_version) + + @staticmethod + def deserialize(byts, protocol_version): + """ + Given a bytestring, deserialize into a value according to the protocol + for this type. Note that this does not create a new instance of this + class; it merely gives back a value that would be appropriate to go + inside an instance of this class. + """ + return byts + + @staticmethod + def serialize(val, protocol_version): + """ + Given a value appropriate for this class, serialize it according to the + protocol for this type and return the corresponding bytestring. + """ + return val + + @classmethod + def cass_parameterized_type_with(cls, subtypes, full=False): + """ + Return the name of this type as it would be expressed by Cassandra, + optionally fully qualified. If subtypes is not None, it is expected + to be a list of other CassandraType subclasses, and the output + string includes the Cassandra names for those subclasses as well, + as parameters to this one. + + Example: + + >>> LongType.cass_parameterized_type_with(()) + 'LongType' + >>> LongType.cass_parameterized_type_with((), full=True) + 'org.apache.cassandra.db.marshal.LongType' + >>> SetType.cass_parameterized_type_with([DecimalType], full=True) + 'org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.DecimalType)' + """ + cname = cls.cassname + if full and '.' not in cname: + cname = apache_cassandra_type_prefix + cname + if not subtypes: + return cname + sublist = ', '.join(styp.cass_parameterized_type(full=full) for styp in subtypes) + return '%s(%s)' % (cname, sublist) + + @classmethod + def apply_parameters(cls, subtypes, names=None): + """ + Given a set of other CassandraTypes, create a new subtype of this type + using them as parameters. This is how composite types are constructed. + + >>> MapType.apply_parameters(DateType, BooleanType) + + + `subtypes` will be a sequence of CassandraTypes. If provided, `names` + will be an equally long sequence of column names or Nones. 
+ """ + if cls.num_subtypes != 'UNKNOWN' and len(subtypes) != cls.num_subtypes: + raise ValueError("%s types require %d subtypes (%d given)" + % (cls.typename, cls.num_subtypes, len(subtypes))) + newname = cls.cass_parameterized_type_with(subtypes) + if six.PY2 and isinstance(newname, unicode): + newname = newname.encode('utf-8') + return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname, 'fieldnames': names}) + + @classmethod + def cql_parameterized_type(cls): + """ + Return a CQL type specifier for this type. If this type has parameters, + they are included in standard CQL <> notation. + """ + if not cls.subtypes: + return cls.typename + return '%s<%s>' % (cls.typename, ', '.join(styp.cql_parameterized_type() for styp in cls.subtypes)) + + @classmethod + def cass_parameterized_type(cls, full=False): + """ + Return a Cassandra type specifier for this type. If this type has + parameters, they are included in the standard () notation. + """ + return cls.cass_parameterized_type_with(cls.subtypes, full=full) + + +# it's initially named with a _ to avoid registering it as a real type, but +# client programs may want to use the name still for isinstance(), etc +CassandraType = _CassandraType + + +class _UnrecognizedType(_CassandraType): + num_subtypes = 'UNKNOWN' + + +if six.PY3: + def mkUnrecognizedType(casstypename): + return CassandraTypeType(casstypename, + (_UnrecognizedType,), + {'typename': "'%s'" % casstypename}) +else: + def mkUnrecognizedType(casstypename): # noqa + return CassandraTypeType(casstypename.encode('utf8'), + (_UnrecognizedType,), + {'typename': "'%s'" % casstypename}) + + +class BytesType(_CassandraType): + typename = 'blob' + empty_binary_ok = True + + @staticmethod + def validate(val): + return bytearray(val) + + @staticmethod + def serialize(val, protocol_version): + return six.binary_type(val) + + +class DecimalType(_CassandraType): + typename = 'decimal' + + @staticmethod + def validate(val): + return Decimal(val) + + @staticmethod + def deserialize(byts, protocol_version): + scale = int32_unpack(byts[:4]) + unscaled = varint_unpack(byts[4:]) + return Decimal('%de%d' % (unscaled, -scale)) + + @staticmethod + def serialize(dec, protocol_version): + try: + sign, digits, exponent = dec.as_tuple() + except AttributeError: + raise TypeError("Non-Decimal type received for Decimal value") + unscaled = int(''.join([str(digit) for digit in digits])) + if sign: + unscaled *= -1 + scale = int32_pack(-exponent) + unscaled = varint_pack(unscaled) + return scale + unscaled + + +class UUIDType(_CassandraType): + typename = 'uuid' + + @staticmethod + def deserialize(byts, protocol_version): + return UUID(bytes=byts) + + @staticmethod + def serialize(uuid, protocol_version): + try: + return uuid.bytes + except AttributeError: + raise TypeError("Got a non-UUID object for a UUID value") + + +class BooleanType(_CassandraType): + typename = 'boolean' + + @staticmethod + def validate(val): + return bool(val) + + @staticmethod + def deserialize(byts, protocol_version): + return bool(int8_unpack(byts)) + + @staticmethod + def serialize(truth, protocol_version): + return int8_pack(truth) + + +if six.PY2: + class AsciiType(_CassandraType): + typename = 'ascii' + empty_binary_ok = True +else: + class AsciiType(_CassandraType): + typename = 'ascii' + empty_binary_ok = True + + @staticmethod + def deserialize(byts, protocol_version): + return byts.decode('ascii') + + @staticmethod + def serialize(var, protocol_version): + try: + return var.encode('ascii') + except 
UnicodeDecodeError: + return var + + +class FloatType(_CassandraType): + typename = 'float' + + @staticmethod + def deserialize(byts, protocol_version): + return float_unpack(byts) + + @staticmethod + def serialize(byts, protocol_version): + return float_pack(byts) + + +class DoubleType(_CassandraType): + typename = 'double' + + @staticmethod + def deserialize(byts, protocol_version): + return double_unpack(byts) + + @staticmethod + def serialize(byts, protocol_version): + return double_pack(byts) + + +class LongType(_CassandraType): + typename = 'bigint' + + @staticmethod + def deserialize(byts, protocol_version): + return int64_unpack(byts) + + @staticmethod + def serialize(byts, protocol_version): + return int64_pack(byts) + + +class Int32Type(_CassandraType): + typename = 'int' + + @staticmethod + def deserialize(byts, protocol_version): + return int32_unpack(byts) + + @staticmethod + def serialize(byts, protocol_version): + return int32_pack(byts) + + +class IntegerType(_CassandraType): + typename = 'varint' + + @staticmethod + def deserialize(byts, protocol_version): + return varint_unpack(byts) + + @staticmethod + def serialize(byts, protocol_version): + return varint_pack(byts) + + +have_ipv6_packing = hasattr(socket, 'inet_ntop') + + +class InetAddressType(_CassandraType): + typename = 'inet' + + # TODO: implement basic ipv6 support for Windows? + # inet_ntop and inet_pton aren't available on Windows + + @staticmethod + def deserialize(byts, protocol_version): + if len(byts) == 16: + if not have_ipv6_packing: + raise Exception( + "IPv6 addresses cannot currently be handled on Windows") + return socket.inet_ntop(socket.AF_INET6, byts) + else: + return socket.inet_ntoa(byts) + + @staticmethod + def serialize(addr, protocol_version): + if ':' in addr: + fam = socket.AF_INET6 + if not have_ipv6_packing: + raise Exception( + "IPv6 addresses cannot currently be handled on Windows") + return socket.inet_pton(fam, addr) + else: + fam = socket.AF_INET + return socket.inet_aton(addr) + + +class CounterColumnType(LongType): + typename = 'counter' + +cql_timestamp_formats = ( + '%Y-%m-%d %H:%M', + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%dT%H:%M', + '%Y-%m-%dT%H:%M:%S', + '%Y-%m-%d' +) + +_have_warned_about_timestamps = False + + +class DateType(_CassandraType): + typename = 'timestamp' + + @classmethod + def validate(cls, val): + if isinstance(val, six.string_types): + val = cls.interpret_datestring(val) + return val + + @staticmethod + def interpret_datestring(val): + if val[-5] in ('+', '-'): + offset = (int(val[-4:-2]) * 3600 + int(val[-2:]) * 60) * int(val[-5] + '1') + val = val[:-5] + else: + offset = -time.timezone + for tformat in cql_timestamp_formats: + try: + tval = time.strptime(val, tformat) + except ValueError: + continue + # scale seconds to millis for the raw value + return (calendar.timegm(tval) + offset) * 1e3 + else: + raise ValueError("can't interpret %r as a date" % (val,)) + + def my_timestamp(self): + return self.val + + @staticmethod + def deserialize(byts, protocol_version): + timestamp = int64_unpack(byts) / 1000.0 + return util.datetime_from_timestamp(timestamp) + + @staticmethod + def serialize(v, protocol_version): + try: + # v is datetime + timestamp_seconds = calendar.timegm(v.utctimetuple()) + timestamp = timestamp_seconds * 1e3 + getattr(v, 'microsecond', 0) / 1e3 + except AttributeError: + # Ints and floats are valid timestamps too + if type(v) not in _number_types: + raise TypeError('DateType arguments must be a datetime or timestamp') + timestamp = v + + return 
int64_pack(long(timestamp)) + + +class TimestampType(DateType): + pass + + +class TimeUUIDType(DateType): + typename = 'timeuuid' + + def my_timestamp(self): + return unix_time_from_uuid1(self.val) + + @staticmethod + def deserialize(byts, protocol_version): + return UUID(bytes=byts) + + @staticmethod + def serialize(timeuuid, protocol_version): + try: + return timeuuid.bytes + except AttributeError: + raise TypeError("Got a non-UUID object for a UUID value") + + +class SimpleDateType(_CassandraType): + typename = 'date' + date_format = "%Y-%m-%d" + + @classmethod + def validate(cls, val): + if not isinstance(val, util.Date): + val = util.Date(val) + return val + + @staticmethod + def serialize(val, protocol_version): + # Values of the 'date'` type are encoded as 32-bit unsigned integers + # representing a number of days with epoch (January 1st, 1970) at the center of the + # range (2^31). + try: + days = val.days_from_epoch + except AttributeError: + days = util.Date(val).days_from_epoch + return uint32_pack(days + 2 ** 31) + + @staticmethod + def deserialize(byts, protocol_version): + days = uint32_unpack(byts) - 2 ** 31 + return util.Date(days) + + +class TimeType(_CassandraType): + typename = 'time' + + @classmethod + def validate(cls, val): + if not isinstance(val, util.Time): + val = util.Time(val) + return val + + @staticmethod + def serialize(val, protocol_version): + try: + nano = val.nanosecond_time + except AttributeError: + nano = util.Time(val).nanosecond_time + return int64_pack(nano) + + @staticmethod + def deserialize(byts, protocol_version): + return util.Time(int64_unpack(byts)) + + +class UTF8Type(_CassandraType): + typename = 'text' + empty_binary_ok = True + + @staticmethod + def deserialize(byts, protocol_version): + return byts.decode('utf8') + + @staticmethod + def serialize(ustr, protocol_version): + try: + return ustr.encode('utf-8') + except UnicodeDecodeError: + # already utf-8 + return ustr + + +class VarcharType(UTF8Type): + typename = 'varchar' + + +class _ParameterizedType(_CassandraType): + def __init__(self, val): + if not self.subtypes: + raise ValueError("%s type with no parameters can't be instantiated" % (self.typename,)) + _CassandraType.__init__(self, val) + + @classmethod + def deserialize(cls, byts, protocol_version): + if not cls.subtypes: + raise NotImplementedError("can't deserialize unparameterized %s" + % cls.typename) + return cls.deserialize_safe(byts, protocol_version) + + @classmethod + def serialize(cls, val, protocol_version): + if not cls.subtypes: + raise NotImplementedError("can't serialize unparameterized %s" + % cls.typename) + return cls.serialize_safe(val, protocol_version) + + +class _SimpleParameterizedType(_ParameterizedType): + @classmethod + def validate(cls, val): + subtype, = cls.subtypes + return cls.adapter([subtype.validate(subval) for subval in val]) + + @classmethod + def deserialize_safe(cls, byts, protocol_version): + subtype, = cls.subtypes + if protocol_version >= 3: + unpack = int32_unpack + length = 4 + else: + unpack = uint16_unpack + length = 2 + numelements = unpack(byts[:length]) + p = length + result = [] + for _ in range(numelements): + itemlen = unpack(byts[p:p + length]) + p += length + item = byts[p:p + itemlen] + p += itemlen + result.append(subtype.from_binary(item, protocol_version)) + return cls.adapter(result) + + @classmethod + def serialize_safe(cls, items, protocol_version): + if isinstance(items, six.string_types): + raise TypeError("Received a string for a type that expects a sequence") + + 
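+        # wire format: element count, then each element as a length-prefixed blob
+        # (int32 lengths for protocol v3+, uint16 lengths for older protocol versions)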
subtype, = cls.subtypes + pack = int32_pack if protocol_version >= 3 else uint16_pack + buf = io.BytesIO() + buf.write(pack(len(items))) + for item in items: + itembytes = subtype.to_binary(item, protocol_version) + buf.write(pack(len(itembytes))) + buf.write(itembytes) + return buf.getvalue() + + +class ListType(_SimpleParameterizedType): + typename = 'list' + num_subtypes = 1 + adapter = list + + +class SetType(_SimpleParameterizedType): + typename = 'set' + num_subtypes = 1 + adapter = util.sortedset + + +class MapType(_ParameterizedType): + typename = 'map' + num_subtypes = 2 + + @classmethod + def validate(cls, val): + key_type, value_type = cls.subtypes + return dict((key_type.validate(k), value_type.validate(v)) for (k, v) in six.iteritems(val)) + + @classmethod + def deserialize_safe(cls, byts, protocol_version): + key_type, value_type = cls.subtypes + if protocol_version >= 3: + unpack = int32_unpack + length = 4 + else: + unpack = uint16_unpack + length = 2 + numelements = unpack(byts[:length]) + p = length + themap = util.OrderedMapSerializedKey(key_type, protocol_version) + for _ in range(numelements): + key_len = unpack(byts[p:p + length]) + p += length + keybytes = byts[p:p + key_len] + p += key_len + val_len = unpack(byts[p:p + length]) + p += length + valbytes = byts[p:p + val_len] + p += val_len + key = key_type.from_binary(keybytes, protocol_version) + val = value_type.from_binary(valbytes, protocol_version) + themap._insert_unchecked(key, keybytes, val) + return themap + + @classmethod + def serialize_safe(cls, themap, protocol_version): + key_type, value_type = cls.subtypes + pack = int32_pack if protocol_version >= 3 else uint16_pack + buf = io.BytesIO() + buf.write(pack(len(themap))) + try: + items = six.iteritems(themap) + except AttributeError: + raise TypeError("Got a non-map object for a map value") + for key, val in items: + keybytes = key_type.to_binary(key, protocol_version) + valbytes = value_type.to_binary(val, protocol_version) + buf.write(pack(len(keybytes))) + buf.write(keybytes) + buf.write(pack(len(valbytes))) + buf.write(valbytes) + return buf.getvalue() + + +class TupleType(_ParameterizedType): + typename = 'tuple' + num_subtypes = 'UNKNOWN' + + @classmethod + def deserialize_safe(cls, byts, protocol_version): + proto_version = max(3, protocol_version) + p = 0 + values = [] + for col_type in cls.subtypes: + if p == len(byts): + break + itemlen = int32_unpack(byts[p:p + 4]) + p += 4 + if itemlen >= 0: + item = byts[p:p + itemlen] + p += itemlen + else: + item = None + # collections inside UDTs are always encoded with at least the + # version 3 format + values.append(col_type.from_binary(item, proto_version)) + + if len(values) < len(cls.subtypes): + nones = [None] * (len(cls.subtypes) - len(values)) + values = values + nones + + return tuple(values) + + @classmethod + def serialize_safe(cls, val, protocol_version): + if len(val) > len(cls.subtypes): + raise ValueError("Expected %d items in a tuple, but got %d: %s" % + (len(cls.subtypes), len(val), val)) + + proto_version = max(3, protocol_version) + buf = io.BytesIO() + for item, subtype in zip(val, cls.subtypes): + if item is not None: + packed_item = subtype.to_binary(item, proto_version) + buf.write(int32_pack(len(packed_item))) + buf.write(packed_item) + else: + buf.write(int32_pack(-1)) + return buf.getvalue() + + @classmethod + def cql_parameterized_type(cls): + subtypes_string = ', '.join(sub.cql_parameterized_type() for sub in cls.subtypes) + return 'frozen<tuple<%s>>' % (subtypes_string,) + + +class 
UserType(TupleType): + typename = "'org.apache.cassandra.db.marshal.UserType'" + + _cache = {} + _module = sys.modules[__name__] + + @classmethod + def make_udt_class(cls, keyspace, udt_name, names_and_types, mapped_class): + if six.PY2 and isinstance(udt_name, unicode): + udt_name = udt_name.encode('utf-8') + try: + return cls._cache[(keyspace, udt_name)] + except KeyError: + field_names, types = zip(*names_and_types) + instance = type(udt_name, (cls,), {'subtypes': types, + 'cassname': cls.cassname, + 'typename': udt_name, + 'fieldnames': field_names, + 'keyspace': keyspace, + 'mapped_class': mapped_class, + 'tuple_type': cls._make_registered_udt_namedtuple(keyspace, udt_name, field_names)}) + cls._cache[(keyspace, udt_name)] = instance + return instance + + @classmethod + def evict_udt_class(cls, keyspace, udt_name): + if six.PY2 and isinstance(udt_name, unicode): + udt_name = udt_name.encode('utf-8') + try: + del cls._cache[(keyspace, udt_name)] + except KeyError: + pass + + @classmethod + def apply_parameters(cls, subtypes, names): + keyspace = subtypes[0] + udt_name = _name_from_hex_string(subtypes[1].cassname) + field_names = [_name_from_hex_string(encoded_name) for encoded_name in names[2:]] + assert len(field_names) == len(subtypes[2:]) + return type(udt_name, (cls,), {'subtypes': subtypes[2:], + 'cassname': cls.cassname, + 'typename': udt_name, + 'fieldnames': field_names, + 'keyspace': keyspace, + 'mapped_class': None, + 'tuple_type': namedtuple(udt_name, field_names)}) + + @classmethod + def cql_parameterized_type(cls): + return "frozen<%s>" % (cls.typename,) + + @classmethod + def deserialize_safe(cls, byts, protocol_version): + proto_version = max(3, protocol_version) + p = 0 + values = [] + for col_type in cls.subtypes: + if p == len(byts): + break + itemlen = int32_unpack(byts[p:p + 4]) + p += 4 + if itemlen >= 0: + item = byts[p:p + itemlen] + p += itemlen + else: + item = None + # collections inside UDTs are always encoded with at least the + # version 3 format + values.append(col_type.from_binary(item, proto_version)) + + if len(values) < len(cls.subtypes): + nones = [None] * (len(cls.subtypes) - len(values)) + values = values + nones + + if cls.mapped_class: + return cls.mapped_class(**dict(zip(cls.fieldnames, values))) + else: + return cls.tuple_type(*values) + + @classmethod + def serialize_safe(cls, val, protocol_version): + proto_version = max(3, protocol_version) + buf = io.BytesIO() + for fieldname, subtype in zip(cls.fieldnames, cls.subtypes): + item = getattr(val, fieldname) + if item is not None: + packed_item = subtype.to_binary(getattr(val, fieldname), proto_version) + buf.write(int32_pack(len(packed_item))) + buf.write(packed_item) + else: + buf.write(int32_pack(-1)) + return buf.getvalue() + + @classmethod + def _make_registered_udt_namedtuple(cls, keyspace, name, field_names): + # this is required to make the type resolvable via this module... + # required when unregistered udts are pickled for use as keys in + # util.OrderedMap + qualified_name = "%s_%s" % (keyspace, name) + nt = getattr(cls._module, qualified_name, None) + if not nt: + nt = namedtuple(qualified_name, field_names) + setattr(cls._module, qualified_name, nt) + return nt + + +class CompositeType(_ParameterizedType): + typename = "'org.apache.cassandra.db.marshal.CompositeType'" + num_subtypes = 'UNKNOWN' + + @classmethod + def cql_parameterized_type(cls): + """ + There is no CQL notation for Composites, so we override this. 
+ """ + typestring = cls.cass_parameterized_type(full=True) + return "'%s'" % (typestring,) + + @classmethod + def deserialize_safe(cls, byts, protocol_version): + result = [] + for subtype in cls.subtypes: + if not byts: + # CompositeType can have missing elements at the end + break + + element_length = uint16_unpack(byts[:2]) + element = byts[2:2 + element_length] + + # skip element length, element, and the EOC (one byte) + byts = byts[2 + element_length + 1:] + result.append(subtype.from_binary(element, protocol_version)) + + return tuple(result) + + +class DynamicCompositeType(CompositeType): + typename = "'org.apache.cassandra.db.marshal.DynamicCompositeType'" + + +class ColumnToCollectionType(_ParameterizedType): + """ + This class only really exists so that we can cleanly evaluate types when + Cassandra includes this. We don't actually need or want the extra + information. + """ + typename = "'org.apache.cassandra.db.marshal.ColumnToCollectionType'" + num_subtypes = 'UNKNOWN' + + +class ReversedType(_ParameterizedType): + typename = "'org.apache.cassandra.db.marshal.ReversedType'" + num_subtypes = 1 + + @classmethod + def deserialize_safe(cls, byts, protocol_version): + subtype, = cls.subtypes + return subtype.from_binary(byts) + + @classmethod + def serialize_safe(cls, val, protocol_version): + subtype, = cls.subtypes + return subtype.to_binary(val, protocol_version) + + +class FrozenType(_ParameterizedType): + typename = "frozen" + num_subtypes = 1 + + @classmethod + def deserialize_safe(cls, byts, protocol_version): + subtype, = cls.subtypes + return subtype.from_binary(byts) + + @classmethod + def serialize_safe(cls, val, protocol_version): + subtype, = cls.subtypes + return subtype.to_binary(val, protocol_version) + + +def is_counter_type(t): + if isinstance(t, six.string_types): + t = lookup_casstype(t) + return issubclass(t, CounterColumnType) + + +def cql_typename(casstypename): + """ + Translate a Cassandra-style type specifier (optionally-fully-distinguished + Java class names for data types, along with optional parameters) into a + CQL-style type specifier. + + >>> cql_typename('DateType') + 'timestamp' + >>> cql_typename('org.apache.cassandra.db.marshal.ListType(IntegerType)') + 'list' + """ + return lookup_casstype(casstypename).cql_parameterized_type() diff --git a/cassandra/decoder.py b/cassandra/decoder.py new file mode 100644 index 0000000..5c44d5b --- /dev/null +++ b/cassandra/decoder.py @@ -0,0 +1,58 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import wraps +import warnings + +import cassandra.query + +import logging +log = logging.getLogger(__name__) + +_have_warned = False + + +def warn_once(f): + + @wraps(f) + def new_f(*args, **kwargs): + global _have_warned + if not _have_warned: + msg = "cassandra.decoder.%s has moved to cassandra.query.%s" % (f.__name__, f.__name__) + warnings.warn(msg, DeprecationWarning) + log.warning(msg) + _have_warned = True + return f(*args, **kwargs) + + return new_f + +tuple_factory = warn_once(cassandra.query.tuple_factory) +""" +Deprecated: use :meth:`cassandra.query.tuple_factory()` +""" + +named_tuple_factory = warn_once(cassandra.query.named_tuple_factory) +""" +Deprecated: use :meth:`cassandra.query.named_tuple_factory()` +""" + +dict_factory = warn_once(cassandra.query.dict_factory) +""" +Deprecated: use :meth:`cassandra.query.dict_factory()` +""" + +ordered_dict_factory = warn_once(cassandra.query.ordered_dict_factory) +""" +Deprecated: use :meth:`cassandra.query.ordered_dict_factory()` +""" diff --git a/cassandra/encoder.py b/cassandra/encoder.py new file mode 100644 index 0000000..02eed2a --- /dev/null +++ b/cassandra/encoder.py @@ -0,0 +1,212 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +These functions are used to convert Python objects into CQL strings. +When non-prepared statements are executed, these encoder functions are +called on each query parameter. +""" + +import logging +log = logging.getLogger(__name__) + +from binascii import hexlify +import calendar +import datetime +import sys +import types +from uuid import UUID +import six + +from cassandra.util import OrderedDict, OrderedMap, sortedset, Time + +if six.PY3: + long = int + + +def cql_quote(term): + # The ordering of this method is important for the result of this method to + # be a native str type (for both Python 2 and 3) + + # Handle quoting of native str and bool types + if isinstance(term, (str, bool)): + return "'%s'" % str(term).replace("'", "''") + # This branch of the if statement will only be used by Python 2 to catch + # unicode strings, text_type is used to prevent type errors with Python 3. + elif isinstance(term, six.text_type): + return "'%s'" % term.encode('utf8').replace("'", "''") + else: + return str(term) + + +class ValueSequence(list): + pass + + +class Encoder(object): + """ + A container for mapping python types to CQL string literals when working + with non-prepared statements. The type :attr:`~.Encoder.mapping` can be + directly customized by users. + """ + + mapping = None + """ + A map of python types to encoder functions. 
+ """ + + def __init__(self): + self.mapping = { + float: self.cql_encode_float, + bytearray: self.cql_encode_bytes, + str: self.cql_encode_str, + int: self.cql_encode_object, + UUID: self.cql_encode_object, + datetime.datetime: self.cql_encode_datetime, + datetime.date: self.cql_encode_date, + datetime.time: self.cql_encode_time, + Time: self.cql_encode_time, + dict: self.cql_encode_map_collection, + OrderedDict: self.cql_encode_map_collection, + OrderedMap: self.cql_encode_map_collection, + list: self.cql_encode_list_collection, + tuple: self.cql_encode_list_collection, + set: self.cql_encode_set_collection, + sortedset: self.cql_encode_set_collection, + frozenset: self.cql_encode_set_collection, + types.GeneratorType: self.cql_encode_list_collection, + ValueSequence: self.cql_encode_sequence + } + + if six.PY2: + self.mapping.update({ + unicode: self.cql_encode_unicode, + buffer: self.cql_encode_bytes, + long: self.cql_encode_object, + types.NoneType: self.cql_encode_none, + }) + else: + self.mapping.update({ + memoryview: self.cql_encode_bytes, + bytes: self.cql_encode_bytes, + type(None): self.cql_encode_none, + }) + + def cql_encode_none(self, val): + """ + Converts :const:`None` to the string 'NULL'. + """ + return 'NULL' + + def cql_encode_unicode(self, val): + """ + Converts :class:`unicode` objects to UTF-8 encoded strings with quote escaping. + """ + return cql_quote(val.encode('utf-8')) + + def cql_encode_str(self, val): + """ + Escapes quotes in :class:`str` objects. + """ + return cql_quote(val) + + if six.PY3: + def cql_encode_bytes(self, val): + return (b'0x' + hexlify(val)).decode('utf-8') + elif sys.version_info >= (2, 7): + def cql_encode_bytes(self, val): # noqa + return b'0x' + hexlify(val) + else: + # python 2.6 requires string or read-only buffer for hexlify + def cql_encode_bytes(self, val): # noqa + return b'0x' + hexlify(buffer(val)) + + def cql_encode_object(self, val): + """ + Default encoder for all objects that do not have a specific encoder function + registered. This function simply calls :meth:`str()` on the object. + """ + return str(val) + + def cql_encode_float(self, val): + """ + Encode floats using repr to preserve precision + """ + return repr(val) + + def cql_encode_datetime(self, val): + """ + Converts a :class:`datetime.datetime` object to a (string) integer timestamp + with millisecond precision. + """ + timestamp = calendar.timegm(val.utctimetuple()) + return str(long(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3)) + + def cql_encode_date(self, val): + """ + Converts a :class:`datetime.date` object to a string with format + ``YYYY-MM-DD``. + """ + return "'%s'" % val.strftime('%Y-%m-%d') + + def cql_encode_time(self, val): + """ + Converts a :class:`datetime.date` object to a string with format + ``HH:MM:SS.mmmuuunnn``. + """ + return "'%s'" % val + + def cql_encode_sequence(self, val): + """ + Converts a sequence to a string of the form ``(item1, item2, ...)``. This + is suitable for ``IN`` value lists. + """ + return '( %s )' % ' , '.join(self.mapping.get(type(v), self.cql_encode_object)(v) + for v in val) + + cql_encode_tuple = cql_encode_sequence + """ + Converts a sequence to a string of the form ``(item1, item2, ...)``. This + is suitable for ``tuple`` type columns. + """ + + def cql_encode_map_collection(self, val): + """ + Converts a dict into a string of the form ``{key1: val1, key2: val2, ...}``. + This is suitable for ``map`` type columns. 
+ """ + return '{ %s }' % ' , '.join('%s : %s' % ( + self.mapping.get(type(k), self.cql_encode_object)(k), + self.mapping.get(type(v), self.cql_encode_object)(v) + ) for k, v in six.iteritems(val)) + + def cql_encode_list_collection(self, val): + """ + Converts a sequence to a string of the form ``[item1, item2, ...]``. This + is suitable for ``list`` type columns. + """ + return '[ %s ]' % ' , '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val) + + def cql_encode_set_collection(self, val): + """ + Converts a sequence to a string of the form ``{item1, item2, ...}``. This + is suitable for ``set`` type columns. + """ + return '{ %s }' % ' , '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val) + + def cql_encode_all_types(self, val): + """ + Converts any type into a CQL string, defaulting to ``cql_encode_object`` + if :attr:`~Encoder.mapping` does not contain an entry for the type. + """ + return self.mapping.get(type(val), self.cql_encode_object)(val) diff --git a/cassandra/io/__init__.py b/cassandra/io/__init__.py new file mode 100644 index 0000000..e4b89e5 --- /dev/null +++ b/cassandra/io/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py new file mode 100644 index 0000000..2ac7156 --- /dev/null +++ b/cassandra/io/asyncorereactor.py @@ -0,0 +1,349 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import atexit +from collections import deque +from functools import partial +import logging +import os +import socket +import sys +from threading import Event, Lock, Thread +import weakref + +from six.moves import range + +from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL, EISCONN, errorcode +try: + from weakref import WeakSet +except ImportError: + from cassandra.util import WeakSet # noqa + +import asyncore + +try: + import ssl +except ImportError: + ssl = None # NOQA + +from cassandra import OperationTimedOut +from cassandra.connection import (Connection, ConnectionShutdown, + ConnectionException, NONBLOCKING) +from cassandra.protocol import RegisterMessage + +log = logging.getLogger(__name__) + + +def _cleanup(loop_weakref): + try: + loop = loop_weakref() + except ReferenceError: + return + + loop._cleanup() + + +class AsyncoreLoop(object): + + def __init__(self): + self._pid = os.getpid() + self._loop_lock = Lock() + self._started = False + self._shutdown = False + + self._conns_lock = Lock() + self._conns = WeakSet() + self._thread = None + atexit.register(partial(_cleanup, weakref.ref(self))) + + def maybe_start(self): + should_start = False + did_acquire = False + try: + did_acquire = self._loop_lock.acquire(False) + if did_acquire and not self._started: + self._started = True + should_start = True + finally: + if did_acquire: + self._loop_lock.release() + + if should_start: + self._thread = Thread(target=self._run_loop, name="cassandra_driver_event_loop") + self._thread.daemon = True + self._thread.start() + + def _run_loop(self): + log.debug("Starting asyncore event loop") + with self._loop_lock: + while True: + try: + asyncore.loop(timeout=0.001, use_poll=True, count=1000) + except Exception: + log.debug("Asyncore event loop stopped unexepectedly", exc_info=True) + break + + if self._shutdown: + break + + with self._conns_lock: + if len(self._conns) == 0: + break + + self._started = False + + log.debug("Asyncore event loop ended") + + def _cleanup(self): + self._shutdown = True + if not self._thread: + return + + log.debug("Waiting for event loop thread to join...") + self._thread.join(timeout=1.0) + if self._thread.is_alive(): + log.warning( + "Event loop thread could not be joined, so shutdown may not be clean. " + "Please call Cluster.shutdown() to avoid this.") + + log.debug("Event loop thread was joined") + + def connection_created(self, connection): + with self._conns_lock: + self._conns.add(connection) + + def connection_destroyed(self, connection): + with self._conns_lock: + self._conns.discard(connection) + + +class AsyncoreConnection(Connection, asyncore.dispatcher): + """ + An implementation of :class:`.Connection` that uses the ``asyncore`` + module in the Python standard library for its event loop. 
+ """ + + _loop = None + + _total_reqd_bytes = 0 + _writable = False + _readable = False + + @classmethod + def initialize_reactor(cls): + if not cls._loop: + cls._loop = AsyncoreLoop() + else: + current_pid = os.getpid() + if cls._loop._pid != current_pid: + log.debug("Detected fork, clearing and reinitializing reactor state") + cls.handle_fork() + cls._loop = AsyncoreLoop() + + @classmethod + def handle_fork(cls): + if cls._loop: + cls._loop._cleanup() + cls._loop = None + + @classmethod + def factory(cls, *args, **kwargs): + timeout = kwargs.pop('timeout', 5.0) + conn = cls(*args, **kwargs) + conn.connected_event.wait(timeout) + if conn.last_error: + raise conn.last_error + elif not conn.connected_event.is_set(): + conn.close() + raise OperationTimedOut("Timed out creating connection") + else: + return conn + + def __init__(self, *args, **kwargs): + Connection.__init__(self, *args, **kwargs) + asyncore.dispatcher.__init__(self) + + self.connected_event = Event() + + self._callbacks = {} + self.deque = deque() + self.deque_lock = Lock() + + self._loop.connection_created(self) + + sockerr = None + addresses = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) + for (af, socktype, proto, canonname, sockaddr) in addresses: + try: + self.create_socket(af, socktype) + self.connect(sockaddr) + sockerr = None + break + except socket.error as err: + sockerr = err + if sockerr: + raise socket.error(sockerr.errno, "Tried connecting to %s. Last error: %s" % ([a[4] for a in addresses], sockerr.strerror)) + + self.add_channel() + + if self.sockopts: + for args in self.sockopts: + self.socket.setsockopt(*args) + + self._writable = True + self._readable = True + + # start the event loop if needed + self._loop.maybe_start() + + def set_socket(self, sock): + # Overrides the same method in asyncore. We deliberately + # do not call add_channel() in this method so that we can call + # it later, after connect() has completed. 
+ self.socket = sock + self._fileno = sock.fileno() + + def create_socket(self, family, type): + # copied from asyncore, but with the line to set the socket in + # non-blocking mode removed (we will do that after connecting) + self.family_and_type = family, type + sock = socket.socket(family, type) + if self.ssl_options: + if not ssl: + raise Exception("This version of Python was not compiled with SSL support") + sock = ssl.wrap_socket(sock, **self.ssl_options) + self.set_socket(sock) + + def connect(self, address): + # this is copied directly from asyncore.py, except that + # a timeout is set before connecting + self.connected = False + self.connecting = True + self.socket.settimeout(1.0) + err = self.socket.connect_ex(address) + if err in (EINPROGRESS, EALREADY, EWOULDBLOCK) \ + or err == EINVAL and os.name in ('nt', 'ce'): + raise ConnectionException("Timed out connecting to %s" % (address[0])) + if err in (0, EISCONN): + self.addr = address + self.socket.setblocking(0) + self.handle_connect_event() + else: + raise socket.error(err, os.strerror(err)) + + def close(self): + with self.lock: + if self.is_closed: + return + self.is_closed = True + + log.debug("Closing connection (%s) to %s", id(self), self.host) + self._writable = False + self._readable = False + asyncore.dispatcher.close(self) + log.debug("Closed socket to %s", self.host) + + self._loop.connection_destroyed(self) + + if not self.is_defunct: + self.error_all_callbacks( + ConnectionShutdown("Connection to %s was closed" % self.host)) + # don't leave in-progress operations hanging + self.connected_event.set() + + def handle_connect(self): + self._send_options_message() + + def handle_error(self): + self.defunct(sys.exc_info()[1]) + + def handle_close(self): + log.debug("Connection %s closed by server", self) + self.close() + + def handle_write(self): + while True: + with self.deque_lock: + try: + next_msg = self.deque.popleft() + except IndexError: + self._writable = False + return + + try: + sent = self.send(next_msg) + self._readable = True + except socket.error as err: + if (err.args[0] in NONBLOCKING): + with self.deque_lock: + self.deque.appendleft(next_msg) + else: + self.defunct(err) + return + else: + if sent < len(next_msg): + with self.deque_lock: + self.deque.appendleft(next_msg[sent:]) + if sent == 0: + return + + def handle_read(self): + try: + while True: + buf = self.recv(self.in_buffer_size) + self._iobuf.write(buf) + if len(buf) < self.in_buffer_size: + break + except socket.error as err: + if ssl and isinstance(err, ssl.SSLError): + if err.args[0] not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + self.defunct(err) + return + elif err.args[0] not in NONBLOCKING: + self.defunct(err) + return + + if self._iobuf.tell(): + self.process_io_buffer() + if not self._callbacks and not self.is_control_connection: + self._readable = False + + def push(self, data): + sabs = self.out_buffer_size + if len(data) > sabs: + chunks = [] + for i in range(0, len(data), sabs): + chunks.append(data[i:i + sabs]) + else: + chunks = [data] + + with self.deque_lock: + self.deque.extend(chunks) + self._writable = True + + def writable(self): + return self._writable + + def readable(self): + return self._readable or (self.is_control_connection and not (self.is_defunct or self.is_closed)) + + def register_watcher(self, event_type, callback, register_timeout=None): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=[event_type]), timeout=register_timeout) + + def 
register_watchers(self, type_callback_dict, register_timeout=None): + for event_type, callback in type_callback_dict.items(): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=type_callback_dict.keys()), timeout=register_timeout) diff --git a/cassandra/io/eventletreactor.py b/cassandra/io/eventletreactor.py new file mode 100644 index 0000000..ceac6a9 --- /dev/null +++ b/cassandra/io/eventletreactor.py @@ -0,0 +1,193 @@ +# Copyright 2014 Symantec Corporation +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Originally derived from MagnetoDB source: +# https://github.com/stackforge/magnetodb/blob/2015.1.0b1/magnetodb/common/cassandra/io/eventletreactor.py + +from collections import defaultdict +from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL +import eventlet +from eventlet.green import select, socket +from eventlet.queue import Queue +from functools import partial +import logging +import os +from threading import Event + +from six.moves import xrange + +from cassandra import OperationTimedOut +from cassandra.connection import Connection, ConnectionShutdown +from cassandra.protocol import RegisterMessage + + +log = logging.getLogger(__name__) + + +def is_timeout(err): + return ( + err in (EINPROGRESS, EALREADY, EWOULDBLOCK) or + (err == EINVAL and os.name in ('nt', 'ce')) + ) + + +class EventletConnection(Connection): + """ + An implementation of :class:`.Connection` that utilizes ``eventlet``. + """ + + _total_reqd_bytes = 0 + _read_watcher = None + _write_watcher = None + _socket = None + + @classmethod + def initialize_reactor(cls): + eventlet.monkey_patch() + + @classmethod + def factory(cls, *args, **kwargs): + timeout = kwargs.pop('timeout', 5.0) + conn = cls(*args, **kwargs) + conn.connected_event.wait(timeout) + if conn.last_error: + raise conn.last_error + elif not conn.connected_event.is_set(): + conn.close() + raise OperationTimedOut("Timed out creating connection") + else: + return conn + + def __init__(self, *args, **kwargs): + Connection.__init__(self, *args, **kwargs) + + self.connected_event = Event() + self._write_queue = Queue() + + self._callbacks = {} + self._push_watchers = defaultdict(set) + + sockerr = None + addresses = socket.getaddrinfo( + self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM + ) + for (af, socktype, proto, canonname, sockaddr) in addresses: + try: + self._socket = socket.socket(af, socktype, proto) + self._socket.settimeout(1.0) + self._socket.connect(sockaddr) + sockerr = None + break + except socket.error as err: + sockerr = err + if sockerr: + raise socket.error( + sockerr.errno, + "Tried connecting to %s. 
Last error: %s" % ( + [a[4] for a in addresses], sockerr.strerror) + ) + + if self.sockopts: + for args in self.sockopts: + self._socket.setsockopt(*args) + + self._read_watcher = eventlet.spawn(lambda: self.handle_read()) + self._write_watcher = eventlet.spawn(lambda: self.handle_write()) + self._send_options_message() + + def close(self): + with self.lock: + if self.is_closed: + return + self.is_closed = True + + log.debug("Closing connection (%s) to %s" % (id(self), self.host)) + + cur_gthread = eventlet.getcurrent() + + if self._read_watcher and self._read_watcher != cur_gthread: + self._read_watcher.kill() + if self._write_watcher and self._write_watcher != cur_gthread: + self._write_watcher.kill() + if self._socket: + self._socket.close() + log.debug("Closed socket to %s" % (self.host,)) + + if not self.is_defunct: + self.error_all_callbacks( + ConnectionShutdown("Connection to %s was closed" % self.host)) + # don't leave in-progress operations hanging + self.connected_event.set() + + def handle_close(self): + log.debug("connection closed by server") + self.close() + + def handle_write(self): + while True: + try: + next_msg = self._write_queue.get() + self._socket.sendall(next_msg) + except socket.error as err: + log.debug("Exception during socket send for %s: %s", self, err) + self.defunct(err) + return # Leave the write loop + + def handle_read(self): + run_select = partial(select.select, (self._socket,), (), ()) + while True: + try: + run_select() + except Exception as exc: + if not self.is_closed: + log.debug("Exception during read select() for %s: %s", + self, exc) + self.defunct(exc) + return + + try: + buf = self._socket.recv(self.in_buffer_size) + self._iobuf.write(buf) + except socket.error as err: + if not is_timeout(err): + log.debug("Exception during socket recv for %s: %s", + self, err) + self.defunct(err) + return # leave the read loop + + if self._iobuf.tell(): + self.process_io_buffer() + else: + log.debug("Connection %s closed by server", self) + self.close() + return + + def push(self, data): + chunk_size = self.out_buffer_size + for i in xrange(0, len(data), chunk_size): + self._write_queue.put(data[i:i + chunk_size]) + + def register_watcher(self, event_type, callback, register_timeout=None): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=[event_type]), + timeout=register_timeout) + + def register_watchers(self, type_callback_dict, register_timeout=None): + for event_type, callback in type_callback_dict.items(): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=type_callback_dict.keys()), + timeout=register_timeout) diff --git a/cassandra/io/geventreactor.py b/cassandra/io/geventreactor.py new file mode 100644 index 0000000..4cd9c68 --- /dev/null +++ b/cassandra/io/geventreactor.py @@ -0,0 +1,189 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import gevent +from gevent import select, socket, ssl +from gevent.event import Event +from gevent.queue import Queue + +from collections import defaultdict +from functools import partial +import logging +import os + +from six.moves import xrange + +from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL + +from cassandra import OperationTimedOut +from cassandra.connection import Connection, ConnectionShutdown +from cassandra.protocol import RegisterMessage + + +log = logging.getLogger(__name__) + + +def is_timeout(err): + return ( + err in (EINPROGRESS, EALREADY, EWOULDBLOCK) or + (err == EINVAL and os.name in ('nt', 'ce')) + ) + + +class GeventConnection(Connection): + """ + An implementation of :class:`.Connection` that utilizes ``gevent``. + """ + + _total_reqd_bytes = 0 + _read_watcher = None + _write_watcher = None + _socket = None + + @classmethod + def factory(cls, *args, **kwargs): + timeout = kwargs.pop('timeout', 5.0) + conn = cls(*args, **kwargs) + conn.connected_event.wait(timeout) + if conn.last_error: + raise conn.last_error + elif not conn.connected_event.is_set(): + conn.close() + raise OperationTimedOut("Timed out creating connection") + else: + return conn + + def __init__(self, *args, **kwargs): + Connection.__init__(self, *args, **kwargs) + + self.connected_event = Event() + self._write_queue = Queue() + + self._callbacks = {} + self._push_watchers = defaultdict(set) + + sockerr = None + addresses = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) + for (af, socktype, proto, canonname, sockaddr) in addresses: + try: + self._socket = socket.socket(af, socktype, proto) + if self.ssl_options: + self._socket = ssl.wrap_socket(self._socket, **self.ssl_options) + self._socket.settimeout(1.0) + self._socket.connect(sockaddr) + sockerr = None + break + except socket.error as err: + sockerr = err + if sockerr: + raise socket.error(sockerr.errno, "Tried connecting to %s. 
Last error: %s" % ([a[4] for a in addresses], sockerr.strerror)) + + if self.sockopts: + for args in self.sockopts: + self._socket.setsockopt(*args) + + self._read_watcher = gevent.spawn(self.handle_read) + self._write_watcher = gevent.spawn(self.handle_write) + self._send_options_message() + + def close(self): + with self.lock: + if self.is_closed: + return + self.is_closed = True + + log.debug("Closing connection (%s) to %s" % (id(self), self.host)) + if self._read_watcher: + self._read_watcher.kill(block=False) + if self._write_watcher: + self._write_watcher.kill(block=False) + if self._socket: + self._socket.close() + log.debug("Closed socket to %s" % (self.host,)) + + if not self.is_defunct: + self.error_all_callbacks( + ConnectionShutdown("Connection to %s was closed" % self.host)) + # don't leave in-progress operations hanging + self.connected_event.set() + + def handle_close(self): + log.debug("connection closed by server") + self.close() + + def handle_write(self): + run_select = partial(select.select, (), (self._socket,), ()) + while True: + try: + next_msg = self._write_queue.get() + run_select() + except Exception as exc: + if not self.is_closed: + log.debug("Exception during write select() for %s: %s", self, exc) + self.defunct(exc) + return + + try: + self._socket.sendall(next_msg) + except socket.error as err: + log.debug("Exception during socket sendall for %s: %s", self, err) + self.defunct(err) + return # Leave the write loop + + def handle_read(self): + run_select = partial(select.select, (self._socket,), (), ()) + while True: + try: + run_select() + except Exception as exc: + if not self.is_closed: + log.debug("Exception during read select() for %s: %s", self, exc) + self.defunct(exc) + return + + try: + while True: + buf = self._socket.recv(self.in_buffer_size) + self._iobuf.write(buf) + if len(buf) < self.in_buffer_size: + break + except socket.error as err: + if not is_timeout(err): + log.debug("Exception during socket recv for %s: %s", self, err) + self.defunct(err) + return # leave the read loop + + if self._iobuf.tell(): + self.process_io_buffer() + else: + log.debug("Connection %s closed by server", self) + self.close() + return + + def push(self, data): + chunk_size = self.out_buffer_size + for i in xrange(0, len(data), chunk_size): + self._write_queue.put(data[i:i + chunk_size]) + + def register_watcher(self, event_type, callback, register_timeout=None): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=[event_type]), + timeout=register_timeout) + + def register_watchers(self, type_callback_dict, register_timeout=None): + for event_type, callback in type_callback_dict.items(): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=type_callback_dict.keys()), + timeout=register_timeout) diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py new file mode 100644 index 0000000..db11eaf --- /dev/null +++ b/cassandra/io/libevreactor.py @@ -0,0 +1,395 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import atexit +from collections import deque +from functools import partial +import logging +import os +import socket +from threading import Event, Lock, Thread +import weakref + +from six.moves import xrange + +from cassandra import OperationTimedOut +from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING +from cassandra.protocol import RegisterMessage +try: + import cassandra.io.libevwrapper as libev +except ImportError: + raise ImportError( + "The C extension needed to use libev was not found. This " + "probably means that you didn't have the required build dependencies " + "when installing the driver. See " + "http://datastax.github.io/python-driver/installation.html#c-extensions " + "for instructions on installing build dependencies and building " + "the C extension.") + + +try: + import ssl +except ImportError: + ssl = None # NOQA + +log = logging.getLogger(__name__) + + +def _cleanup(loop_weakref): + try: + loop = loop_weakref() + except ReferenceError: + return + + loop._cleanup() + + +class LibevLoop(object): + + def __init__(self): + self._pid = os.getpid() + self._loop = libev.Loop() + self._notifier = libev.Async(self._loop) + self._notifier.start() + + # prevent _notifier from keeping the loop from returning + self._loop.unref() + + self._started = False + self._shutdown = False + self._lock = Lock() + + self._thread = None + + # set of all connections; only replaced with a new copy + # while holding _conn_set_lock, never modified in place + self._live_conns = set() + # newly created connections that need their write/read watcher started + self._new_conns = set() + # recently closed connections that need their write/read watcher stopped + self._closed_conns = set() + self._conn_set_lock = Lock() + + self._preparer = libev.Prepare(self._loop, self._loop_will_run) + # prevent _preparer from keeping the loop from returning + self._loop.unref() + self._preparer.start() + + atexit.register(partial(_cleanup, weakref.ref(self))) + + def notify(self): + self._notifier.send() + + def maybe_start(self): + should_start = False + with self._lock: + if not self._started: + log.debug("Starting libev event loop") + self._started = True + should_start = True + + if should_start: + self._thread = Thread(target=self._run_loop, name="event_loop") + self._thread.daemon = True + self._thread.start() + + self._notifier.send() + + def _run_loop(self): + while True: + end_condition = self._loop.start() + # there are still active watchers, no deadlock + with self._lock: + if not self._shutdown and (end_condition or self._live_conns): + log.debug("Restarting event loop") + continue + else: + # all Connections have been closed, no active watchers + log.debug("All Connections currently closed, event loop ended") + self._started = False + break + + def _cleanup(self): + self._shutdown = True + if not self._thread: + return + + for conn in self._live_conns | self._new_conns | self._closed_conns: + conn.close() + if conn._write_watcher: + conn._write_watcher.stop() + del conn._write_watcher + if conn._read_watcher: + conn._read_watcher.stop() + del conn._read_watcher + + log.debug("Waiting for event loop thread to join...") + self._thread.join(timeout=1.0) + if self._thread.is_alive(): + log.warning( + "Event loop thread could not be joined, so shutdown may not be clean. 
" + "Please call Cluster.shutdown() to avoid this.") + + log.debug("Event loop thread was joined") + self._loop = None + + def connection_created(self, conn): + with self._conn_set_lock: + new_live_conns = self._live_conns.copy() + new_live_conns.add(conn) + self._live_conns = new_live_conns + + new_new_conns = self._new_conns.copy() + new_new_conns.add(conn) + self._new_conns = new_new_conns + + def connection_destroyed(self, conn): + with self._conn_set_lock: + new_live_conns = self._live_conns.copy() + new_live_conns.discard(conn) + self._live_conns = new_live_conns + + new_closed_conns = self._closed_conns.copy() + new_closed_conns.add(conn) + self._closed_conns = new_closed_conns + + self._notifier.send() + + def _loop_will_run(self, prepare): + changed = False + for conn in self._live_conns: + if not conn.deque and conn._write_watcher_is_active: + if conn._write_watcher: + conn._write_watcher.stop() + conn._write_watcher_is_active = False + changed = True + elif conn.deque and not conn._write_watcher_is_active: + conn._write_watcher.start() + conn._write_watcher_is_active = True + changed = True + + if self._new_conns: + with self._conn_set_lock: + to_start = self._new_conns + self._new_conns = set() + + for conn in to_start: + conn._read_watcher.start() + + changed = True + + if self._closed_conns: + with self._conn_set_lock: + to_stop = self._closed_conns + self._closed_conns = set() + + for conn in to_stop: + if conn._write_watcher: + conn._write_watcher.stop() + # clear reference cycles from IO callback + del conn._write_watcher + if conn._read_watcher: + conn._read_watcher.stop() + # clear reference cycles from IO callback + del conn._read_watcher + + changed = True + + if changed: + self._notifier.send() + + +class LibevConnection(Connection): + """ + An implementation of :class:`.Connection` that uses libev for its event loop. 
+ """ + _libevloop = None + _write_watcher_is_active = False + _total_reqd_bytes = 0 + _read_watcher = None + _write_watcher = None + _socket = None + + @classmethod + def initialize_reactor(cls): + if not cls._libevloop: + cls._libevloop = LibevLoop() + else: + if cls._libevloop._pid != os.getpid(): + log.debug("Detected fork, clearing and reinitializing reactor state") + cls.handle_fork() + cls._libevloop = LibevLoop() + + @classmethod + def handle_fork(cls): + if cls._libevloop: + cls._libevloop._cleanup() + cls._libevloop = None + + @classmethod + def factory(cls, *args, **kwargs): + timeout = kwargs.pop('timeout', 5.0) + conn = cls(*args, **kwargs) + conn.connected_event.wait(timeout) + if conn.last_error: + raise conn.last_error + elif not conn.connected_event.is_set(): + conn.close() + raise OperationTimedOut("Timed out creating new connection") + else: + return conn + + def __init__(self, *args, **kwargs): + Connection.__init__(self, *args, **kwargs) + + self.connected_event = Event() + + self._callbacks = {} + self.deque = deque() + self._deque_lock = Lock() + + sockerr = None + addresses = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) + for (af, socktype, proto, canonname, sockaddr) in addresses: + try: + self._socket = socket.socket(af, socktype, proto) + if self.ssl_options: + if not ssl: + raise Exception("This version of Python was not compiled with SSL support") + self._socket = ssl.wrap_socket(self._socket, **self.ssl_options) + self._socket.settimeout(1.0) # TODO potentially make this value configurable + self._socket.connect(sockaddr) + sockerr = None + break + except socket.error as err: + sockerr = err + if sockerr: + raise socket.error(sockerr.errno, "Tried connecting to %s. Last error: %s" % ([a[4] for a in addresses], sockerr.strerror)) + + self._socket.setblocking(0) + + if self.sockopts: + for args in self.sockopts: + self._socket.setsockopt(*args) + + with self._libevloop._lock: + self._read_watcher = libev.IO(self._socket.fileno(), libev.EV_READ, self._libevloop._loop, self.handle_read) + self._write_watcher = libev.IO(self._socket.fileno(), libev.EV_WRITE, self._libevloop._loop, self.handle_write) + + self._send_options_message() + + self._libevloop.connection_created(self) + + # start the global event loop if needed + self._libevloop.maybe_start() + + def close(self): + with self.lock: + if self.is_closed: + return + self.is_closed = True + + log.debug("Closing connection (%s) to %s", id(self), self.host) + self._libevloop.connection_destroyed(self) + self._socket.close() + log.debug("Closed socket to %s", self.host) + + # don't leave in-progress operations hanging + if not self.is_defunct: + self.error_all_callbacks( + ConnectionShutdown("Connection to %s was closed" % self.host)) + + def handle_write(self, watcher, revents, errno=None): + if revents & libev.EV_ERROR: + if errno: + exc = IOError(errno, os.strerror(errno)) + else: + exc = Exception("libev reported an error") + + self.defunct(exc) + return + + while True: + try: + with self._deque_lock: + next_msg = self.deque.popleft() + except IndexError: + return + + try: + sent = self._socket.send(next_msg) + except socket.error as err: + if (err.args[0] in NONBLOCKING): + with self._deque_lock: + self.deque.appendleft(next_msg) + else: + self.defunct(err) + return + else: + if sent < len(next_msg): + with self._deque_lock: + self.deque.appendleft(next_msg[sent:]) + + def handle_read(self, watcher, revents, errno=None): + if revents & libev.EV_ERROR: + if errno: + exc = 
IOError(errno, os.strerror(errno)) + else: + exc = Exception("libev reported an error") + + self.defunct(exc) + return + try: + while True: + buf = self._socket.recv(self.in_buffer_size) + self._iobuf.write(buf) + if len(buf) < self.in_buffer_size: + break + except socket.error as err: + if ssl and isinstance(err, ssl.SSLError): + if err.args[0] not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + self.defunct(err) + return + elif err.args[0] not in NONBLOCKING: + self.defunct(err) + return + + if self._iobuf.tell(): + self.process_io_buffer() + else: + log.debug("Connection %s closed by server", self) + self.close() + + def push(self, data): + sabs = self.out_buffer_size + if len(data) > sabs: + chunks = [] + for i in xrange(0, len(data), sabs): + chunks.append(data[i:i + sabs]) + else: + chunks = [data] + + with self._deque_lock: + self.deque.extend(chunks) + self._libevloop.notify() + + def register_watcher(self, event_type, callback, register_timeout=None): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=[event_type]), timeout=register_timeout) + + def register_watchers(self, type_callback_dict, register_timeout=None): + for event_type, callback in type_callback_dict.items(): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=type_callback_dict.keys()), timeout=register_timeout) diff --git a/cassandra/io/libevwrapper.c b/cassandra/io/libevwrapper.c new file mode 100644 index 0000000..cbac83b --- /dev/null +++ b/cassandra/io/libevwrapper.c @@ -0,0 +1,542 @@ +#include +#include + +typedef struct libevwrapper_Loop { + PyObject_HEAD + struct ev_loop *loop; +} libevwrapper_Loop; + +static void +Loop_dealloc(libevwrapper_Loop *self) { + ev_loop_destroy(self->loop); + Py_TYPE(self)->tp_free((PyObject *)self); +}; + +static PyObject* +Loop_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { + libevwrapper_Loop *self; + + self = (libevwrapper_Loop *)type->tp_alloc(type, 0); + if (self != NULL) { + self->loop = ev_loop_new(EVBACKEND_SELECT); + if (!self->loop) { + PyErr_SetString(PyExc_Exception, "Error getting new ev loop"); + Py_DECREF(self); + return NULL; + } + } + return (PyObject *)self; +}; + +static int +Loop_init(libevwrapper_Loop *self, PyObject *args, PyObject *kwds) { + if (!PyArg_ParseTuple(args, "")) { + PyErr_SetString(PyExc_TypeError, "Loop.__init__() takes no arguments"); + return -1; + } + return 0; +}; + +static PyObject * +Loop_start(libevwrapper_Loop *self, PyObject *args) { + Py_BEGIN_ALLOW_THREADS + ev_run(self->loop, 0); + Py_END_ALLOW_THREADS + Py_RETURN_NONE; +}; + +static PyObject * +Loop_unref(libevwrapper_Loop *self, PyObject *args) { + ev_unref(self->loop); + Py_RETURN_NONE; +} + +static PyMethodDef Loop_methods[] = { + {"start", (PyCFunction)Loop_start, METH_NOARGS, "Start the event loop"}, + {"unref", (PyCFunction)Loop_unref, METH_NOARGS, "Unrefrence the event loop"}, + {NULL} /* Sentinel */ +}; + +static +PyTypeObject libevwrapper_LoopType = { + PyVarObject_HEAD_INIT(NULL, 0) + "cassandra.io.libevwrapper.Loop",/*tp_name*/ + sizeof(libevwrapper_Loop), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)Loop_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT | 
Py_TPFLAGS_BASETYPE, /*tp_flags*/ + "Loop objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + Loop_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)Loop_init, /* tp_init */ + 0, /* tp_alloc */ + Loop_new, /* tp_new */ +}; + +typedef struct libevwrapper_IO { + PyObject_HEAD + struct ev_io io; + struct libevwrapper_Loop *loop; + PyObject *callback; +} libevwrapper_IO; + +static void +IO_dealloc(libevwrapper_IO *self) { + Py_XDECREF(self->loop); + Py_XDECREF(self->callback); + Py_TYPE(self)->tp_free((PyObject *)self); +}; + +static void io_callback(struct ev_loop *loop, ev_io *watcher, int revents) { + libevwrapper_IO *self = watcher->data; + PyObject *result; + PyGILState_STATE gstate = PyGILState_Ensure(); + if (revents & EV_ERROR && errno) { + result = PyObject_CallFunction(self->callback, "Obi", self, revents, errno); + } else { + result = PyObject_CallFunction(self->callback, "Ob", self, revents); + } + if (!result) { + PyErr_WriteUnraisable(self->callback); + } + Py_XDECREF(result); + PyGILState_Release(gstate); +}; + +static int +IO_init(libevwrapper_IO *self, PyObject *args, PyObject *kwds) { + PyObject *socket; + PyObject *callback; + PyObject *loop; + int io_flags = 0, fd = -1; + struct ev_io *io = NULL; + + if (!PyArg_ParseTuple(args, "OiOO", &socket, &io_flags, &loop, &callback)) { + return -1; + } + + if (loop) { + Py_INCREF(loop); + self->loop = (libevwrapper_Loop *)loop; + } + + if (callback) { + if (!PyCallable_Check(callback)) { + PyErr_SetString(PyExc_TypeError, "callback parameter must be callable"); + Py_XDECREF(loop); + return -1; + } + Py_INCREF(callback); + self->callback = callback; + } + + fd = PyObject_AsFileDescriptor(socket); + if (fd == -1) { + PyErr_SetString(PyExc_TypeError, "unable to get file descriptor from socket"); + Py_XDECREF(callback); + Py_XDECREF(loop); + return -1; + } + io = &(self->io); + ev_io_init(io, io_callback, fd, io_flags); + self->io.data = self; + return 0; +} + +static PyObject* +IO_start(libevwrapper_IO *self, PyObject *args) { + ev_io_start(self->loop->loop, &self->io); + Py_RETURN_NONE; +} + +static PyObject* +IO_stop(libevwrapper_IO *self, PyObject *args) { + ev_io_stop(self->loop->loop, &self->io); + Py_RETURN_NONE; +} + +static PyObject* +IO_is_active(libevwrapper_IO *self, PyObject *args) { + struct ev_io *io = &(self->io); + return PyBool_FromLong(ev_is_active(io)); +} + +static PyObject* +IO_is_pending(libevwrapper_IO *self, PyObject *args) { + struct ev_io *io = &(self->io); + return PyBool_FromLong(ev_is_pending(io)); +} + +static PyMethodDef IO_methods[] = { + {"start", (PyCFunction)IO_start, METH_NOARGS, "Start the watcher"}, + {"stop", (PyCFunction)IO_stop, METH_NOARGS, "Stop the watcher"}, + {"is_active", (PyCFunction)IO_is_active, METH_NOARGS, "Is the watcher active?"}, + {"is_pending", (PyCFunction)IO_is_pending, METH_NOARGS, "Is the watcher pending?"}, + {NULL} /* Sentinal */ +}; + +static PyTypeObject libevwrapper_IOType = { + PyVarObject_HEAD_INIT(NULL, 0) + "cassandra.io.libevwrapper.IO", /*tp_name*/ + sizeof(libevwrapper_IO), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)IO_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 
0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ + "IO objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + IO_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)IO_init, /* tp_init */ +}; + +typedef struct libevwrapper_Async { + PyObject_HEAD + struct ev_async async; + struct libevwrapper_Loop *loop; +} libevwrapper_Async; + +static void +Async_dealloc(libevwrapper_Async *self) { + Py_XDECREF(self->loop); + Py_TYPE(self)->tp_free((PyObject *)self); +}; + +static void async_callback(EV_P_ ev_async *watcher, int revents) {}; + +static int +Async_init(libevwrapper_Async *self, PyObject *args, PyObject *kwds) { + PyObject *loop; + static char *kwlist[] = {"loop", NULL}; + struct ev_async *async = NULL; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &loop)) { + PyErr_SetString(PyExc_TypeError, "unable to get file descriptor from socket"); + return -1; + } + + if (loop) { + Py_INCREF(loop); + self->loop = (libevwrapper_Loop *)loop; + } else { + return -1; + } + async = &(self->async); + ev_async_init(async, async_callback); + return 0; +}; + +static PyObject * +Async_start(libevwrapper_Async *self, PyObject *args) { + ev_async_start(self->loop->loop, &self->async); + Py_RETURN_NONE; +} + +static PyObject * +Async_send(libevwrapper_Async *self, PyObject *args) { + ev_async_send(self->loop->loop, &self->async); + Py_RETURN_NONE; +}; + +static PyMethodDef Async_methods[] = { + {"start", (PyCFunction)Async_start, METH_NOARGS, "Start the watcher"}, + {"send", (PyCFunction)Async_send, METH_NOARGS, "Notify the event loop"}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject libevwrapper_AsyncType = { + PyVarObject_HEAD_INIT(NULL, 0) + "cassandra.io.libevwrapper.Async", /*tp_name*/ + sizeof(libevwrapper_Async), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)Async_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ + "Async objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + Async_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)Async_init, /* tp_init */ +}; + +typedef struct libevwrapper_Prepare { + PyObject_HEAD + struct ev_prepare prepare; + struct libevwrapper_Loop *loop; + PyObject *callback; +} libevwrapper_Prepare; + +static void +Prepare_dealloc(libevwrapper_Prepare *self) { + Py_XDECREF(self->loop); + Py_XDECREF(self->callback); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static void prepare_callback(struct ev_loop *loop, ev_prepare *watcher, int revents) { + libevwrapper_Prepare *self = watcher->data; + PyObject *result = NULL; + PyGILState_STATE gstate; + + gstate = PyGILState_Ensure(); + result = PyObject_CallFunction(self->callback, "O", self); 
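+    /* A NULL result means the Python callback raised; report it with
+       PyErr_WriteUnraisable below rather than letting it propagate into libev. */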
+ if (!result) { + PyErr_WriteUnraisable(self->callback); + } + Py_XDECREF(result); + + PyGILState_Release(gstate); +} + +static int +Prepare_init(libevwrapper_Prepare *self, PyObject *args, PyObject *kwds) { + PyObject *callback; + PyObject *loop; + struct ev_prepare *prepare = NULL; + + if (!PyArg_ParseTuple(args, "OO", &loop, &callback)) { + return -1; + } + + if (loop) { + Py_INCREF(loop); + self->loop = (libevwrapper_Loop *)loop; + } else { + return -1; + } + + if (callback) { + if (!PyCallable_Check(callback)) { + PyErr_SetString(PyExc_TypeError, "callback parameter must be callable"); + Py_XDECREF(loop); + return -1; + } + Py_INCREF(callback); + self->callback = callback; + } + prepare = &(self->prepare); + ev_prepare_init(prepare, prepare_callback); + self->prepare.data = self; + return 0; +} + +static PyObject * +Prepare_start(libevwrapper_Prepare *self, PyObject *args) { + ev_prepare_start(self->loop->loop, &self->prepare); + Py_RETURN_NONE; +} + +static PyObject * +Prepare_stop(libevwrapper_Prepare *self, PyObject *args) { + ev_prepare_stop(self->loop->loop, &self->prepare); + Py_RETURN_NONE; +} + +static PyMethodDef Prepare_methods[] = { + {"start", (PyCFunction)Prepare_start, METH_NOARGS, "Start the Prepare watcher"}, + {"stop", (PyCFunction)Prepare_stop, METH_NOARGS, "Stop the Prepare watcher"}, + {NULL} /* Sentinal */ +}; + +static PyTypeObject libevwrapper_PrepareType = { + PyVarObject_HEAD_INIT(NULL, 0) + "cassandra.io.libevwrapper.Prepare", /*tp_name*/ + sizeof(libevwrapper_Prepare), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)Prepare_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ + "Prepare objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + Prepare_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)Prepare_init, /* tp_init */ +}; + +static PyMethodDef module_methods[] = { + {NULL} /* Sentinal */ +}; + +PyDoc_STRVAR(module_doc, +"libev wrapper methods"); + +#if PY_MAJOR_VERSION >= 3 +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "libevwrapper", + module_doc, + -1, + module_methods, + NULL, + NULL, + NULL, + NULL +}; + +#define INITERROR return NULL + +PyObject * +PyInit_libevwrapper(void) + +# else +# define INITERROR return + +void +initlibevwrapper(void) +#endif +{ + PyObject *module = NULL; + + if (PyType_Ready(&libevwrapper_LoopType) < 0) + INITERROR; + + libevwrapper_IOType.tp_new = PyType_GenericNew; + if (PyType_Ready(&libevwrapper_IOType) < 0) + INITERROR; + + libevwrapper_PrepareType.tp_new = PyType_GenericNew; + if (PyType_Ready(&libevwrapper_PrepareType) < 0) + INITERROR; + + libevwrapper_AsyncType.tp_new = PyType_GenericNew; + if (PyType_Ready(&libevwrapper_AsyncType) < 0) + INITERROR; + +# if PY_MAJOR_VERSION >= 3 + module = PyModule_Create(&moduledef); +# else + module = Py_InitModule3("libevwrapper", module_methods, module_doc); +# endif + + if (module == NULL) + INITERROR; + + if (PyModule_AddIntConstant(module, "EV_READ", EV_READ) == -1) + INITERROR; + if 
(PyModule_AddIntConstant(module, "EV_WRITE", EV_WRITE) == -1) + INITERROR; + if (PyModule_AddIntConstant(module, "EV_ERROR", EV_ERROR) == -1) + INITERROR; + + Py_INCREF(&libevwrapper_LoopType); + if (PyModule_AddObject(module, "Loop", (PyObject *)&libevwrapper_LoopType) == -1) + INITERROR; + + Py_INCREF(&libevwrapper_IOType); + if (PyModule_AddObject(module, "IO", (PyObject *)&libevwrapper_IOType) == -1) + INITERROR; + + Py_INCREF(&libevwrapper_PrepareType); + if (PyModule_AddObject(module, "Prepare", (PyObject *)&libevwrapper_PrepareType) == -1) + INITERROR; + + Py_INCREF(&libevwrapper_AsyncType); + if (PyModule_AddObject(module, "Async", (PyObject *)&libevwrapper_AsyncType) == -1) + INITERROR; + + if (!PyEval_ThreadsInitialized()) { + PyEval_InitThreads(); + } + +#if PY_MAJOR_VERSION >= 3 + return module; +#endif +} diff --git a/cassandra/io/twistedreactor.py b/cassandra/io/twistedreactor.py new file mode 100644 index 0000000..1a5a64e --- /dev/null +++ b/cassandra/io/twistedreactor.py @@ -0,0 +1,259 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Module that implements an event loop based on twisted +( https://twistedmatrix.com ). +""" +from twisted.internet import reactor, protocol +from threading import Event, Thread, Lock +from functools import partial +import logging +import weakref +import atexit + +from cassandra import OperationTimedOut +from cassandra.connection import Connection, ConnectionShutdown +from cassandra.protocol import RegisterMessage + + +log = logging.getLogger(__name__) + + +def _cleanup(cleanup_weakref): + try: + cleanup_weakref()._cleanup() + except ReferenceError: + return + + +class TwistedConnectionProtocol(protocol.Protocol): + """ + Twisted Protocol class for handling data received and connection + made events. + """ + + def dataReceived(self, data): + """ + Callback function that is called when data has been received + on the connection. + + Reaches back to the Connection object and queues the data for + processing. + """ + self.transport.connector.factory.conn._iobuf.write(data) + self.transport.connector.factory.conn.handle_read() + + def connectionMade(self): + """ + Callback function that is called when a connection has succeeded. + + Reaches back to the Connection object and confirms that the connection + is ready. + """ + self.transport.connector.factory.conn.client_connection_made() + + def connectionLost(self, reason): + # reason is a Failure instance + self.transport.connector.factory.conn.defunct(reason.value) + + +class TwistedConnectionClientFactory(protocol.ClientFactory): + + def __init__(self, connection): + # ClientFactory does not define __init__() in parent classes + # and does not inherit from object. + self.conn = connection + + def buildProtocol(self, addr): + """ + Twisted function that defines which kind of protocol to use + in the ClientFactory. 
+ """ + return TwistedConnectionProtocol() + + def clientConnectionFailed(self, connector, reason): + """ + Overridden twisted callback which is called when the + connection attempt fails. + """ + log.debug("Connect failed: %s", reason) + self.conn.defunct(reason.value) + + def clientConnectionLost(self, connector, reason): + """ + Overridden twisted callback which is called when the + connection goes away (cleanly or otherwise). + + It should be safe to call defunct() here instead of just close, because + we can assume that if the connection was closed cleanly, there are no + callbacks to error out. If this assumption turns out to be false, we + can call close() instead of defunct() when "reason" is an appropriate + type. + """ + log.debug("Connect lost: %s", reason) + self.conn.defunct(reason.value) + + +class TwistedLoop(object): + + _lock = None + _thread = None + + def __init__(self): + self._lock = Lock() + + def maybe_start(self): + with self._lock: + if not reactor.running: + self._thread = Thread(target=reactor.run, + name="cassandra_driver_event_loop", + kwargs={'installSignalHandlers': False}) + self._thread.daemon = True + self._thread.start() + atexit.register(partial(_cleanup, weakref.ref(self))) + + def _cleanup(self): + if self._thread: + reactor.callFromThread(reactor.stop) + self._thread.join(timeout=1.0) + if self._thread.is_alive(): + log.warning("Event loop thread could not be joined, so " + "shutdown may not be clean. Please call " + "Cluster.shutdown() to avoid this.") + log.debug("Event loop thread was joined") + + +class TwistedConnection(Connection): + """ + An implementation of :class:`.Connection` that utilizes the + Twisted event loop. + """ + + _loop = None + _total_reqd_bytes = 0 + + @classmethod + def initialize_reactor(cls): + if not cls._loop: + cls._loop = TwistedLoop() + + @classmethod + def factory(cls, *args, **kwargs): + """ + A factory function which returns connections which have + succeeded in connecting and are ready for service (or + raises an exception otherwise). + """ + timeout = kwargs.pop('timeout', 5.0) + conn = cls(*args, **kwargs) + conn.connected_event.wait(timeout) + if conn.last_error: + raise conn.last_error + elif not conn.connected_event.is_set(): + conn.close() + raise OperationTimedOut("Timed out creating connection") + else: + return conn + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + Note that we can't call reactor methods directly here because + it's not thread-safe, so we schedule the reactor/connection + stuff to be run from the event loop thread when it gets the + chance. + """ + Connection.__init__(self, *args, **kwargs) + + self.connected_event = Event() + self.is_closed = True + self.connector = None + + self._callbacks = {} + reactor.callFromThread(self.add_connection) + self._loop.maybe_start() + + def add_connection(self): + """ + Convenience function to connect and store the resulting + connector. + """ + self.connector = reactor.connectTCP( + host=self.host, port=self.port, + factory=TwistedConnectionClientFactory(self)) + + def client_connection_made(self): + """ + Called by twisted protocol when a connection attempt has + succeeded. + """ + with self.lock: + self.is_closed = False + self._send_options_message() + + def close(self): + """ + Disconnect and error-out all callbacks. 
+ """ + with self.lock: + if self.is_closed: + return + self.is_closed = True + + log.debug("Closing connection (%s) to %s", id(self), self.host) + self.connector.disconnect() + log.debug("Closed socket to %s", self.host) + + if not self.is_defunct: + self.error_all_callbacks( + ConnectionShutdown("Connection to %s was closed" % self.host)) + # don't leave in-progress operations hanging + self.connected_event.set() + + def handle_read(self): + """ + Process the incoming data buffer. + """ + self.process_io_buffer() + + def push(self, data): + """ + This function is called when outgoing data should be queued + for sending. + + Note that we can't call transport.write() directly because + it is not thread-safe, so we schedule it to run from within + the event loop when it gets the chance. + """ + reactor.callFromThread(self.connector.transport.write, data) + + def register_watcher(self, event_type, callback, register_timeout=None): + """ + Register a callback for a given event type. + """ + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=[event_type]), + timeout=register_timeout) + + def register_watchers(self, type_callback_dict, register_timeout=None): + """ + Register multiple callback/event type pairs, expressed as a dict. + """ + for event_type, callback in type_callback_dict.items(): + self._push_watchers[event_type].add(callback) + self.wait_for_response( + RegisterMessage(event_list=type_callback_dict.keys()), + timeout=register_timeout) diff --git a/cassandra/marshal.py b/cassandra/marshal.py new file mode 100644 index 0000000..6451ab0 --- /dev/null +++ b/cassandra/marshal.py @@ -0,0 +1,84 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import six +import struct + + +def _make_packer(format_string): + packer = struct.Struct(format_string) + pack = packer.pack + unpack = lambda s: packer.unpack(s)[0] + return pack, unpack + +int64_pack, int64_unpack = _make_packer('>q') +int32_pack, int32_unpack = _make_packer('>i') +int16_pack, int16_unpack = _make_packer('>h') +int8_pack, int8_unpack = _make_packer('>b') +uint64_pack, uint64_unpack = _make_packer('>Q') +uint32_pack, uint32_unpack = _make_packer('>I') +uint16_pack, uint16_unpack = _make_packer('>H') +uint8_pack, uint8_unpack = _make_packer('>B') +float_pack, float_unpack = _make_packer('>f') +double_pack, double_unpack = _make_packer('>d') + +# Special case for cassandra header +header_struct = struct.Struct('>BBbB') +header_pack = header_struct.pack +header_unpack = header_struct.unpack + +# in protocol version 3 and higher, the stream ID is two bytes +v3_header_struct = struct.Struct('>BBhB') +v3_header_pack = v3_header_struct.pack +v3_header_unpack = v3_header_struct.unpack + + +if six.PY3: + def varint_unpack(term): + val = int(''.join("%02x" % i for i in term), 16) + if (term[0] & 128) != 0: + val -= 1 << (len(term) * 8) + return val +else: + def varint_unpack(term): # noqa + val = int(term.encode('hex'), 16) + if (ord(term[0]) & 128) != 0: + val = val - (1 << (len(term) * 8)) + return val + + +def bitlength(n): + bitlen = 0 + while n > 0: + n >>= 1 + bitlen += 1 + return bitlen + + +def varint_pack(big): + pos = True + if big == 0: + return b'\x00' + if big < 0: + bytelength = bitlength(abs(big) - 1) // 8 + 1 + big = (1 << bytelength * 8) + big + pos = False + revbytes = bytearray() + while big > 0: + revbytes.append(big & 0xff) + big >>= 8 + if pos and revbytes[-1] & 0x80: + revbytes.append(0) + revbytes.reverse() + return six.binary_type(revbytes) diff --git a/cassandra/metadata.py b/cassandra/metadata.py new file mode 100644 index 0000000..8ca84fc --- /dev/null +++ b/cassandra/metadata.py @@ -0,0 +1,1436 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
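A quick illustration of the varint helpers in cassandra/marshal.py above: varint_pack and varint_unpack convert between Python integers and the big-endian two's-complement byte strings Cassandra uses for varint and decimal values. Sketch only, assuming the module is importable as cassandra.marshal:

    from cassandra.marshal import varint_pack, varint_unpack

    assert varint_pack(-129) == b'\xff\x7f'            # sign bit set in the leading byte
    assert varint_pack(128) == b'\x00\x80'             # zero byte added so the value stays positive
    assert varint_unpack(varint_pack(-129)) == -129    # round-trips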
+ +from bisect import bisect_right +from collections import defaultdict +from hashlib import md5 +from itertools import islice, cycle +import json +import logging +import re +from threading import RLock +import six + +murmur3 = None +try: + from cassandra.murmur3 import murmur3 +except ImportError as e: + pass + +import cassandra.cqltypes as types +from cassandra.marshal import varint_unpack +from cassandra.util import OrderedDict + +log = logging.getLogger(__name__) + +_keywords = set(( + 'select', 'from', 'where', 'and', 'key', 'insert', 'update', 'with', + 'limit', 'using', 'use', 'count', 'set', + 'begin', 'apply', 'batch', 'truncate', 'delete', 'in', 'create', + 'keyspace', 'schema', 'columnfamily', 'table', 'index', 'on', 'drop', + 'primary', 'into', 'values', 'timestamp', 'ttl', 'alter', 'add', 'type', + 'compact', 'storage', 'order', 'by', 'asc', 'desc', 'clustering', + 'token', 'writetime', 'map', 'list', 'to' +)) + +_unreserved_keywords = set(( + 'key', 'clustering', 'ttl', 'compact', 'storage', 'type', 'values' +)) + + +class Metadata(object): + """ + Holds a representation of the cluster schema and topology. + """ + + cluster_name = None + """ The string name of the cluster. """ + + keyspaces = None + """ + A map from keyspace names to matching :class:`~.KeyspaceMetadata` instances. + """ + + partitioner = None + """ + The string name of the partitioner for the cluster. + """ + + token_map = None + """ A :class:`~.TokenMap` instance describing the ring topology. """ + + def __init__(self): + self.keyspaces = {} + self._hosts = {} + self._hosts_lock = RLock() + + def export_schema_as_string(self): + """ + Returns a string that can be executed as a query in order to recreate + the entire schema. The string is formatted to be human readable. + """ + return "\n".join(ks.export_as_string() for ks in self.keyspaces.values()) + + def rebuild_schema(self, ks_results, type_results, cf_results, col_results, triggers_result): + """ + Rebuild the view of the current schema from a fresh set of rows from + the system schema tables. + + For internal use only. 
+ """ + cf_def_rows = defaultdict(list) + col_def_rows = defaultdict(lambda: defaultdict(list)) + usertype_rows = defaultdict(list) + trigger_rows = defaultdict(lambda: defaultdict(list)) + + for row in cf_results: + cf_def_rows[row["keyspace_name"]].append(row) + + for row in col_results: + ksname = row["keyspace_name"] + cfname = row["columnfamily_name"] + col_def_rows[ksname][cfname].append(row) + + for row in type_results: + usertype_rows[row["keyspace_name"]].append(row) + + for row in triggers_result: + ksname = row["keyspace_name"] + cfname = row["columnfamily_name"] + trigger_rows[ksname][cfname].append(row) + + current_keyspaces = set() + for row in ks_results: + keyspace_meta = self._build_keyspace_metadata(row) + keyspace_col_rows = col_def_rows.get(keyspace_meta.name, {}) + keyspace_trigger_rows = trigger_rows.get(keyspace_meta.name, {}) + for table_row in cf_def_rows.get(keyspace_meta.name, []): + table_meta = self._build_table_metadata( + keyspace_meta, table_row, keyspace_col_rows, + keyspace_trigger_rows) + keyspace_meta.tables[table_meta.name] = table_meta + + for usertype_row in usertype_rows.get(keyspace_meta.name, []): + usertype = self._build_usertype(keyspace_meta.name, usertype_row) + keyspace_meta.user_types[usertype.name] = usertype + + current_keyspaces.add(keyspace_meta.name) + old_keyspace_meta = self.keyspaces.get(keyspace_meta.name, None) + self.keyspaces[keyspace_meta.name] = keyspace_meta + if old_keyspace_meta: + self._keyspace_updated(keyspace_meta.name) + else: + self._keyspace_added(keyspace_meta.name) + + # remove not-just-added keyspaces + removed_keyspaces = [ksname for ksname in self.keyspaces.keys() + if ksname not in current_keyspaces] + self.keyspaces = dict((name, meta) for name, meta in self.keyspaces.items() + if name in current_keyspaces) + for ksname in removed_keyspaces: + self._keyspace_removed(ksname) + + def keyspace_changed(self, keyspace, ks_results): + if not ks_results: + if keyspace in self.keyspaces: + del self.keyspaces[keyspace] + self._keyspace_removed(keyspace) + return + + keyspace_meta = self._build_keyspace_metadata(ks_results[0]) + old_keyspace_meta = self.keyspaces.get(keyspace, None) + self.keyspaces[keyspace] = keyspace_meta + if old_keyspace_meta: + keyspace_meta.tables = old_keyspace_meta.tables + keyspace_meta.user_types = old_keyspace_meta.user_types + if (keyspace_meta.replication_strategy != old_keyspace_meta.replication_strategy): + self._keyspace_updated(keyspace) + else: + self._keyspace_added(keyspace) + + def usertype_changed(self, keyspace, name, type_results): + if type_results: + new_usertype = self._build_usertype(keyspace, type_results[0]) + self.keyspaces[keyspace].user_types[name] = new_usertype + else: + # the type was deleted + self.keyspaces[keyspace].user_types.pop(name, None) + + def table_changed(self, keyspace, table, cf_results, col_results, triggers_result): + try: + keyspace_meta = self.keyspaces[keyspace] + except KeyError: + # we're trying to update a table in a keyspace we don't know about + log.error("Tried to update schema for table '%s' in unknown keyspace '%s'", + table, keyspace) + return + + if not cf_results: + # the table was removed + keyspace_meta.tables.pop(table, None) + else: + assert len(cf_results) == 1 + keyspace_meta.tables[table] = self._build_table_metadata( + keyspace_meta, cf_results[0], {table: col_results}, + {table: triggers_result}) + + def _keyspace_added(self, ksname): + if self.token_map: + self.token_map.rebuild_keyspace(ksname, build_if_absent=False) + + def 
_keyspace_updated(self, ksname): + if self.token_map: + self.token_map.rebuild_keyspace(ksname, build_if_absent=False) + + def _keyspace_removed(self, ksname): + if self.token_map: + self.token_map.remove_keyspace(ksname) + + def _build_keyspace_metadata(self, row): + name = row["keyspace_name"] + durable_writes = row["durable_writes"] + strategy_class = row["strategy_class"] + strategy_options = json.loads(row["strategy_options"]) + return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) + + def _build_usertype(self, keyspace, usertype_row): + type_classes = list(map(types.lookup_casstype, usertype_row['field_types'])) + return UserType(usertype_row['keyspace_name'], usertype_row['type_name'], + usertype_row['field_names'], type_classes) + + def _build_table_metadata(self, keyspace_metadata, row, col_rows, trigger_rows): + cfname = row["columnfamily_name"] + cf_col_rows = col_rows.get(cfname, []) + + if not cf_col_rows: # CASSANDRA-8487 + log.warning("Building table metadata with no column meta for %s.%s", + keyspace_metadata.name, cfname) + + comparator = types.lookup_casstype(row["comparator"]) + + if issubclass(comparator, types.CompositeType): + column_name_types = comparator.subtypes + is_composite_comparator = True + else: + column_name_types = (comparator,) + is_composite_comparator = False + + num_column_name_components = len(column_name_types) + last_col = column_name_types[-1] + + column_aliases = row.get("column_aliases", None) + + clustering_rows = [r for r in cf_col_rows + if r.get('type', None) == "clustering_key"] + if len(clustering_rows) > 1: + clustering_rows = sorted(clustering_rows, key=lambda row: row.get('component_index')) + + if column_aliases is not None: + column_aliases = json.loads(column_aliases) + else: + column_aliases = [r.get('column_name') for r in clustering_rows] + + if is_composite_comparator: + if issubclass(last_col, types.ColumnToCollectionType): + # collections + is_compact = False + has_value = False + clustering_size = num_column_name_components - 2 + elif (len(column_aliases) == num_column_name_components - 1 + and issubclass(last_col, types.UTF8Type)): + # aliases? + is_compact = False + has_value = False + clustering_size = num_column_name_components - 1 + else: + # compact table + is_compact = True + has_value = column_aliases or not cf_col_rows + clustering_size = num_column_name_components + + # Some thrift tables define names in composite types (see PYTHON-192) + if not column_aliases and hasattr(comparator, 'fieldnames'): + column_aliases = comparator.fieldnames + else: + is_compact = True + if column_aliases or not cf_col_rows: + has_value = True + clustering_size = num_column_name_components + else: + has_value = False + clustering_size = 0 + + table_meta = TableMetadata(keyspace_metadata, cfname) + table_meta.comparator = comparator + + # partition key + partition_rows = [r for r in cf_col_rows + if r.get('type', None) == "partition_key"] + + if len(partition_rows) > 1: + partition_rows = sorted(partition_rows, key=lambda row: row.get('component_index')) + + key_aliases = row.get("key_aliases") + if key_aliases is not None: + key_aliases = json.loads(key_aliases) if key_aliases else [] + else: + # In 2.0+, we can use the 'type' column. In 3.0+, we have to use it. 
+ key_aliases = [r.get('column_name') for r in partition_rows] + + key_validator = row.get("key_validator") + if key_validator is not None: + key_type = types.lookup_casstype(key_validator) + key_types = key_type.subtypes if issubclass(key_type, types.CompositeType) else [key_type] + else: + key_types = [types.lookup_casstype(r.get('validator')) for r in partition_rows] + + for i, col_type in enumerate(key_types): + if len(key_aliases) > i: + column_name = key_aliases[i] + elif i == 0: + column_name = "key" + else: + column_name = "key%d" % i + + col = ColumnMetadata(table_meta, column_name, col_type) + table_meta.columns[column_name] = col + table_meta.partition_key.append(col) + + # clustering key + for i in range(clustering_size): + if len(column_aliases) > i: + column_name = column_aliases[i] + else: + column_name = "column%d" % i + + col = ColumnMetadata(table_meta, column_name, column_name_types[i]) + table_meta.columns[column_name] = col + table_meta.clustering_key.append(col) + + # value alias (if present) + if has_value: + value_alias_rows = [r for r in cf_col_rows + if r.get('type', None) == "compact_value"] + + if not key_aliases: # TODO are we checking the right thing here? + value_alias = "value" + else: + value_alias = row.get("value_alias", None) + if value_alias is None and value_alias_rows: # CASSANDRA-8487 + # In 2.0+, we can use the 'type' column. In 3.0+, we have to use it. + value_alias = value_alias_rows[0].get('column_name') + + default_validator = row.get("default_validator") + if default_validator: + validator = types.lookup_casstype(default_validator) + else: + if value_alias_rows: # CASSANDRA-8487 + validator = types.lookup_casstype(value_alias_rows[0].get('validator')) + + col = ColumnMetadata(table_meta, value_alias, validator) + if value_alias: # CASSANDRA-8487 + table_meta.columns[value_alias] = col + + # other normal columns + for col_row in cf_col_rows: + column_meta = self._build_column_metadata(table_meta, col_row) + table_meta.columns[column_meta.name] = column_meta + + if trigger_rows: + for trigger_row in trigger_rows[cfname]: + trigger_meta = self._build_trigger_metadata(table_meta, trigger_row) + table_meta.triggers[trigger_meta.name] = trigger_meta + + table_meta.options = self._build_table_options(row) + table_meta.is_compact_storage = is_compact + + return table_meta + + def _build_table_options(self, row): + """ Setup the mostly-non-schema table options, like caching settings """ + options = dict((o, row.get(o)) for o in TableMetadata.recognized_options if o in row) + + # the option name when creating tables is "dclocal_read_repair_chance", + # but the column name in system.schema_columnfamilies is + # "local_read_repair_chance". We'll store this as dclocal_read_repair_chance, + # since that's probably what users are expecting (and we need it for the + # CREATE TABLE statement anyway). 
+ if "local_read_repair_chance" in options: + val = options.pop("local_read_repair_chance") + options["dclocal_read_repair_chance"] = val + + return options + + def _build_column_metadata(self, table_metadata, row): + name = row["column_name"] + data_type = types.lookup_casstype(row["validator"]) + is_static = row.get("type", None) == "static" + column_meta = ColumnMetadata(table_metadata, name, data_type, is_static=is_static) + index_meta = self._build_index_metadata(column_meta, row) + column_meta.index = index_meta + return column_meta + + def _build_index_metadata(self, column_metadata, row): + index_name = row.get("index_name") + index_type = row.get("index_type") + if index_name or index_type: + options = row.get("index_options") + index_options = json.loads(options) if options else {} + return IndexMetadata(column_metadata, index_name, index_type, index_options) + else: + return None + + def _build_trigger_metadata(self, table_metadata, row): + name = row["trigger_name"] + options = row["trigger_options"] + trigger_meta = TriggerMetadata(table_metadata, name, options) + return trigger_meta + + def rebuild_token_map(self, partitioner, token_map): + """ + Rebuild our view of the topology from fresh rows from the + system topology tables. + For internal use only. + """ + self.partitioner = partitioner + if partitioner.endswith('RandomPartitioner'): + token_class = MD5Token + elif partitioner.endswith('Murmur3Partitioner'): + token_class = Murmur3Token + elif partitioner.endswith('ByteOrderedPartitioner'): + token_class = BytesToken + else: + self.token_map = None + return + + token_to_host_owner = {} + ring = [] + for host, token_strings in six.iteritems(token_map): + for token_string in token_strings: + token = token_class(token_string) + ring.append(token) + token_to_host_owner[token] = host + + all_tokens = sorted(ring) + self.token_map = TokenMap( + token_class, token_to_host_owner, all_tokens, self) + + def get_replicas(self, keyspace, key): + """ + Returns a list of :class:`.Host` instances that are replicas for a given + partition key. + """ + t = self.token_map + if not t: + return [] + try: + return t.get_replicas(keyspace, t.token_class.from_key(key)) + except NoMurmur3: + return [] + + def can_support_partitioner(self): + if self.partitioner.endswith('Murmur3Partitioner') and murmur3 is None: + return False + else: + return True + + def add_or_return_host(self, host): + """ + Returns a tuple (host, new), where ``host`` is a Host + instance, and ``new`` is a bool indicating whether + the host was newly added. + """ + with self._hosts_lock: + try: + return self._hosts[host.address], False + except KeyError: + self._hosts[host.address] = host + return host, True + + + def remove_host(self, host): + with self._hosts_lock: + return bool(self._hosts.pop(host.address, False)) + + def get_host(self, address): + return self._hosts.get(address) + + def all_hosts(self): + """ + Returns a list of all known :class:`.Host` instances in the cluster. + """ + with self._hosts_lock: + return self._hosts.values() + + +REPLICATION_STRATEGY_CLASS_PREFIX = "org.apache.cassandra.locator." 
+ + +def trim_if_startswith(s, prefix): + if s.startswith(prefix): + return s[len(prefix):] + return s + + +_replication_strategies = {} + + +class ReplicationStrategyTypeType(type): + def __new__(metacls, name, bases, dct): + dct.setdefault('name', name) + cls = type.__new__(metacls, name, bases, dct) + if not name.startswith('_'): + _replication_strategies[name] = cls + return cls + + +@six.add_metaclass(ReplicationStrategyTypeType) +class _ReplicationStrategy(object): + options_map = None + + @classmethod + def create(cls, strategy_class, options_map): + if not strategy_class: + return None + + strategy_name = trim_if_startswith(strategy_class, REPLICATION_STRATEGY_CLASS_PREFIX) + + rs_class = _replication_strategies.get(strategy_name, None) + if rs_class is None: + rs_class = _UnknownStrategyBuilder(strategy_name) + _replication_strategies[strategy_name] = rs_class + + try: + rs_instance = rs_class(options_map) + except Exception as exc: + log.warning("Failed creating %s with options %s: %s", strategy_name, options_map, exc) + return None + + return rs_instance + + def make_token_replica_map(self, token_to_host_owner, ring): + raise NotImplementedError() + + def export_for_schema(self): + raise NotImplementedError() + + +ReplicationStrategy = _ReplicationStrategy + + +class _UnknownStrategyBuilder(object): + def __init__(self, name): + self.name = name + + def __call__(self, options_map): + strategy_instance = _UnknownStrategy(self.name, options_map) + return strategy_instance + + +class _UnknownStrategy(ReplicationStrategy): + def __init__(self, name, options_map): + self.name = name + self.options_map = options_map.copy() if options_map is not None else dict() + self.options_map['class'] = self.name + + def __eq__(self, other): + return (isinstance(other, _UnknownStrategy) + and self.name == other.name + and self.options_map == other.options_map) + + def export_for_schema(self): + """ + Returns a string version of these replication options which are + suitable for use in a CREATE KEYSPACE statement. + """ + if self.options_map: + return dict((str(key), str(value)) for key, value in self.options_map.items()) + return "{'class': '%s'}" % (self.name, ) + + def make_token_replica_map(self, token_to_host_owner, ring): + return {} + + +class SimpleStrategy(ReplicationStrategy): + + replication_factor = None + """ + The replication factor for this keyspace. + """ + + def __init__(self, options_map): + try: + self.replication_factor = int(options_map['replication_factor']) + except Exception: + raise ValueError("SimpleStrategy requires an integer 'replication_factor' option") + + def make_token_replica_map(self, token_to_host_owner, ring): + replica_map = {} + for i in range(len(ring)): + j, hosts = 0, list() + while len(hosts) < self.replication_factor and j < len(ring): + token = ring[(i + j) % len(ring)] + host = token_to_host_owner[token] + if host not in hosts: + hosts.append(host) + j += 1 + + replica_map[ring[i]] = hosts + return replica_map + + def export_for_schema(self): + """ + Returns a string version of these replication options which are + suitable for use in a CREATE KEYSPACE statement. 
+ """ + return "{'class': 'SimpleStrategy', 'replication_factor': '%d'}" \ + % (self.replication_factor,) + + def __eq__(self, other): + if not isinstance(other, SimpleStrategy): + return False + + return self.replication_factor == other.replication_factor + + +class NetworkTopologyStrategy(ReplicationStrategy): + + dc_replication_factors = None + """ + A map of datacenter names to the replication factor for that DC. + """ + + def __init__(self, dc_replication_factors): + self.dc_replication_factors = dict( + (str(k), int(v)) for k, v in dc_replication_factors.items()) + + def make_token_replica_map(self, token_to_host_owner, ring): + # note: this does not account for hosts having different racks + replica_map = defaultdict(list) + ring_len = len(ring) + ring_len_range = range(ring_len) + dc_rf_map = dict((dc, int(rf)) + for dc, rf in self.dc_replication_factors.items() if rf > 0) + dcs = dict((h, h.datacenter) for h in set(token_to_host_owner.values())) + + # build a map of DCs to lists of indexes into `ring` for tokens that + # belong to that DC + dc_to_token_offset = defaultdict(list) + for i, token in enumerate(ring): + host = token_to_host_owner[token] + dc_to_token_offset[dcs[host]].append(i) + + # A map of DCs to an index into the dc_to_token_offset value for that dc. + # This is how we keep track of advancing around the ring for each DC. + dc_to_current_index = defaultdict(int) + + for i in ring_len_range: + remaining = dc_rf_map.copy() + replicas = replica_map[ring[i]] + + # go through each DC and find the replicas in that DC + for dc in dc_to_token_offset.keys(): + if dc not in remaining: + continue + + # advance our per-DC index until we're up to at least the + # current token in the ring + token_offsets = dc_to_token_offset[dc] + index = dc_to_current_index[dc] + num_tokens = len(token_offsets) + while index < num_tokens and token_offsets[index] < i: + index += 1 + dc_to_current_index[dc] = index + + # now add the next RF distinct token owners to the set of + # replicas for this DC + for token_offset in islice(cycle(token_offsets), index, index + num_tokens): + host = token_to_host_owner[ring[token_offset]] + if host in replicas: + continue + + replicas.append(host) + dc_remaining = remaining[dc] - 1 + if dc_remaining == 0: + del remaining[dc] + break + else: + remaining[dc] = dc_remaining + + return replica_map + + def export_for_schema(self): + """ + Returns a string version of these replication options which are + suitable for use in a CREATE KEYSPACE statement. + """ + ret = "{'class': 'NetworkTopologyStrategy'" + for dc, repl_factor in sorted(self.dc_replication_factors.items()): + ret += ", '%s': '%d'" % (dc, repl_factor) + return ret + "}" + + def __eq__(self, other): + if not isinstance(other, NetworkTopologyStrategy): + return False + + return self.dc_replication_factors == other.dc_replication_factors + + +class LocalStrategy(ReplicationStrategy): + def __init__(self, options_map): + pass + + def make_token_replica_map(self, token_to_host_owner, ring): + return {} + + def export_for_schema(self): + """ + Returns a string version of these replication options which are + suitable for use in a CREATE KEYSPACE statement. + """ + return "{'class': 'LocalStrategy'}" + + def __eq__(self, other): + return isinstance(other, LocalStrategy) + + +class KeyspaceMetadata(object): + """ + A representation of the schema for a single keyspace. + """ + + name = None + """ The string name of the keyspace. 
""" + + durable_writes = True + """ + A boolean indicating whether durable writes are enabled for this keyspace + or not. + """ + + replication_strategy = None + """ + A :class:`.ReplicationStrategy` subclass object. + """ + + tables = None + """ + A map from table names to instances of :class:`~.TableMetadata`. + """ + + user_types = None + """ + A map from user-defined type names to instances of :class:`~cassandra.metadata..UserType`. + + .. versionadded:: 2.1.0 + """ + + def __init__(self, name, durable_writes, strategy_class, strategy_options): + self.name = name + self.durable_writes = durable_writes + self.replication_strategy = ReplicationStrategy.create(strategy_class, strategy_options) + self.tables = {} + self.user_types = {} + + def export_as_string(self): + """ + Returns a CQL query string that can be used to recreate the entire keyspace, + including user-defined types and tables. + """ + return "\n\n".join([self.as_cql_query()] + self.user_type_strings() + [t.export_as_string() for t in self.tables.values()]) + + def as_cql_query(self): + """ + Returns a CQL query string that can be used to recreate just this keyspace, + not including user-defined types and tables. + """ + ret = "CREATE KEYSPACE %s WITH replication = %s " % ( + protect_name(self.name), + self.replication_strategy.export_for_schema()) + return ret + (' AND durable_writes = %s;' % ("true" if self.durable_writes else "false")) + + def user_type_strings(self): + user_type_strings = [] + types = self.user_types.copy() + keys = sorted(types.keys()) + for k in keys: + if k in types: + self.resolve_user_types(k, types, user_type_strings) + return user_type_strings + + def resolve_user_types(self, key, types, user_type_strings): + user_type = types.pop(key) + for field_type in user_type.field_types: + if field_type.cassname == 'UserType' and field_type.typename in types: + self.resolve_user_types(field_type.typename, types, user_type_strings) + user_type_strings.append(user_type.as_cql_query(formatted=True)) + + +class UserType(object): + """ + A user defined type, as created by ``CREATE TYPE`` statements. + + User-defined types were introduced in Cassandra 2.1. + + .. versionadded:: 2.1.0 + """ + + keyspace = None + """ + The string name of the keyspace in which this type is defined. + """ + + name = None + """ + The name of this type. + """ + + field_names = None + """ + An ordered list of the names for each field in this user-defined type. + """ + + field_types = None + """ + An ordered list of the types for each field in this user-defined type. + """ + + def __init__(self, keyspace, name, field_names, field_types): + self.keyspace = keyspace + self.name = name + self.field_names = field_names + self.field_types = field_types + + def as_cql_query(self, formatted=False): + """ + Returns a CQL query that can be used to recreate this type. + If `formatted` is set to :const:`True`, extra whitespace will + be added to make the query more readable. 
+ """ + ret = "CREATE TYPE %s.%s (%s" % ( + protect_name(self.keyspace), + protect_name(self.name), + "\n" if formatted else "") + + if formatted: + field_join = ",\n" + padding = " " + else: + field_join = ", " + padding = "" + + fields = [] + for field_name, field_type in zip(self.field_names, self.field_types): + fields.append("%s %s" % (protect_name(field_name), field_type.cql_parameterized_type())) + + ret += field_join.join("%s%s" % (padding, field) for field in fields) + ret += "\n);" if formatted else ");" + return ret + + +class TableMetadata(object): + """ + A representation of the schema for a single table. + """ + + keyspace = None + """ An instance of :class:`~.KeyspaceMetadata`. """ + + name = None + """ The string name of the table. """ + + partition_key = None + """ + A list of :class:`.ColumnMetadata` instances representing the columns in + the partition key for this table. This will always hold at least one + column. + """ + + clustering_key = None + """ + A list of :class:`.ColumnMetadata` instances representing the columns + in the clustering key for this table. These are all of the + :attr:`.primary_key` columns that are not in the :attr:`.partition_key`. + + Note that a table may have no clustering keys, in which case this will + be an empty list. + """ + + @property + def primary_key(self): + """ + A list of :class:`.ColumnMetadata` representing the components of + the primary key for this table. + """ + return self.partition_key + self.clustering_key + + columns = None + """ + A dict mapping column names to :class:`.ColumnMetadata` instances. + """ + + is_compact_storage = False + + options = None + """ + A dict mapping table option names to their specific settings for this + table. + """ + + recognized_options = ( + "comment", + "read_repair_chance", + "dclocal_read_repair_chance", # kept to be safe, but see _build_table_options() + "local_read_repair_chance", + "replicate_on_write", + "gc_grace_seconds", + "bloom_filter_fp_chance", + "caching", + "compaction_strategy_class", + "compaction_strategy_options", + "min_compaction_threshold", + "max_compaction_threshold", + "compression_parameters", + "min_index_interval", + "max_index_interval", + "index_interval", + "speculative_retry", + "rows_per_partition_to_cache", + "memtable_flush_period_in_ms", + "populate_io_cache_on_flush", + "compaction", + "compression", + "default_time_to_live") + + compaction_options = { + "min_compaction_threshold": "min_threshold", + "max_compaction_threshold": "max_threshold", + "compaction_strategy_class": "class"} + + triggers = None + """ + A dict mapping trigger names to :class:`.TriggerMetadata` instances. 
+ """ + + @property + def is_cql_compatible(self): + """ + A boolean indicating if this table can be represented as CQL in export + """ + # no such thing as DCT in CQL + incompatible = issubclass(self.comparator, types.DynamicCompositeType) + # no compact storage with more than one column beyond PK + incompatible |= self.is_compact_storage and len(self.columns) > len(self.primary_key) + 1 + + return not incompatible + + def __init__(self, keyspace_metadata, name, partition_key=None, clustering_key=None, columns=None, triggers=None, options=None): + self.keyspace = keyspace_metadata + self.name = name + self.partition_key = [] if partition_key is None else partition_key + self.clustering_key = [] if clustering_key is None else clustering_key + self.columns = OrderedDict() if columns is None else columns + self.options = options + self.comparator = None + self.triggers = OrderedDict() if triggers is None else triggers + + def export_as_string(self): + """ + Returns a string of CQL queries that can be used to recreate this table + along with all indexes on it. The returned string is formatted to + be human readable. + """ + if self.is_cql_compatible: + ret = self.all_as_cql() + else: + # If we can't produce this table with CQL, comment inline + ret = "/*\nWarning: Table %s.%s omitted because it has constructs not compatible with CQL (was created via legacy API).\n" % \ + (self.keyspace.name, self.name) + ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s" % self.all_as_cql() + ret += "\n*/" + + return ret + + def all_as_cql(self): + ret = self.as_cql_query(formatted=True) + ret += ";" + + for col_meta in self.columns.values(): + if col_meta.index: + ret += "\n%s;" % (col_meta.index.as_cql_query(),) + + for trigger_meta in self.triggers.values(): + ret += "\n%s;" % (trigger_meta.as_cql_query(),) + return ret + + def as_cql_query(self, formatted=False): + """ + Returns a CQL query that can be used to recreate this table (index + creations are not included). If `formatted` is set to :const:`True`, + extra whitespace will be added to make the query human readable. 
+ """ + ret = "CREATE TABLE %s.%s (%s" % ( + protect_name(self.keyspace.name), + protect_name(self.name), + "\n" if formatted else "") + + if formatted: + column_join = ",\n" + padding = " " + else: + column_join = ", " + padding = "" + + columns = [] + for col in self.columns.values(): + columns.append("%s %s%s" % (protect_name(col.name), col.typestring, ' static' if col.is_static else '')) + + if len(self.partition_key) == 1 and not self.clustering_key: + columns[0] += " PRIMARY KEY" + + ret += column_join.join("%s%s" % (padding, col) for col in columns) + + # primary key + if len(self.partition_key) > 1 or self.clustering_key: + ret += "%s%sPRIMARY KEY (" % (column_join, padding) + + if len(self.partition_key) > 1: + ret += "(%s)" % ", ".join(protect_name(col.name) for col in self.partition_key) + else: + ret += self.partition_key[0].name + + if self.clustering_key: + ret += ", %s" % ", ".join(protect_name(col.name) for col in self.clustering_key) + + ret += ")" + + # options + ret += "%s) WITH " % ("\n" if formatted else "") + + option_strings = [] + if self.is_compact_storage: + option_strings.append("COMPACT STORAGE") + + if self.clustering_key: + cluster_str = "CLUSTERING ORDER BY " + + clustering_names = protect_names([c.name for c in self.clustering_key]) + + if self.is_compact_storage and \ + not issubclass(self.comparator, types.CompositeType): + subtypes = [self.comparator] + else: + subtypes = self.comparator.subtypes + + inner = [] + for colname, coltype in zip(clustering_names, subtypes): + ordering = "DESC" if issubclass(coltype, types.ReversedType) else "ASC" + inner.append("%s %s" % (colname, ordering)) + + cluster_str += "(%s)" % ", ".join(inner) + option_strings.append(cluster_str) + + option_strings.extend(self._make_option_strings()) + + join_str = "\n AND " if formatted else " AND " + ret += join_str.join(option_strings) + + return ret + + def _make_option_strings(self): + ret = [] + options_copy = dict(self.options.items()) + if not options_copy.get('compaction'): + options_copy.pop('compaction', None) + + actual_options = json.loads(options_copy.pop('compaction_strategy_options', '{}')) + for system_table_name, compact_option_name in self.compaction_options.items(): + value = options_copy.pop(system_table_name, None) + if value: + actual_options.setdefault(compact_option_name, value) + + compaction_option_strings = ["'%s': '%s'" % (k, v) for k, v in actual_options.items()] + ret.append('compaction = {%s}' % ', '.join(compaction_option_strings)) + + for system_table_name in self.compaction_options.keys(): + options_copy.pop(system_table_name, None) # delete if present + options_copy.pop('compaction_strategy_option', None) + + if not options_copy.get('compression'): + params = json.loads(options_copy.pop('compression_parameters', '{}')) + param_strings = ["'%s': '%s'" % (k, v) for k, v in params.items()] + ret.append('compression = {%s}' % ', '.join(param_strings)) + + for name, value in options_copy.items(): + if value is not None: + if name == "comment": + value = value or "" + ret.append("%s = %s" % (name, protect_value(value))) + + return list(sorted(ret)) + + +if six.PY3: + def protect_name(name): + return maybe_escape_name(name) +else: + def protect_name(name): # NOQA + if isinstance(name, six.text_type): + name = name.encode('utf8') + return maybe_escape_name(name) + + +def protect_names(names): + return [protect_name(n) for n in names] + + +def protect_value(value): + if value is None: + return 'NULL' + if isinstance(value, (int, float, bool)): + return 
str(value).lower() + return "'%s'" % value.replace("'", "''") + + +valid_cql3_word_re = re.compile(r'^[a-z][0-9a-z_]*$') + + +def is_valid_name(name): + if name is None: + return False + if name.lower() in _keywords - _unreserved_keywords: + return False + return valid_cql3_word_re.match(name) is not None + + +def maybe_escape_name(name): + if is_valid_name(name): + return name + return escape_name(name) + + +def escape_name(name): + return '"%s"' % (name.replace('"', '""'),) + + +class ColumnMetadata(object): + """ + A representation of a single column in a table. + """ + + table = None + """ The :class:`.TableMetadata` this column belongs to. """ + + name = None + """ The string name of this column. """ + + data_type = None + """ + The data type for the column in the form of an instance of one of + the type classes in :mod:`cassandra.cqltypes`. + """ + + index = None + """ + If an index exists on this column, this is an instance of + :class:`.IndexMetadata`, otherwise :const:`None`. + """ + + is_static = False + """ + If this column is static (available in Cassandra 2.1+), this will + be :const:`True`, otherwise :const:`False`. + """ + + def __init__(self, table_metadata, column_name, data_type, index_metadata=None, is_static=False): + self.table = table_metadata + self.name = column_name + self.data_type = data_type + self.index = index_metadata + self.is_static = is_static + + @property + def typestring(self): + """ + A string representation of the type for this column, such as "varchar" + or "map". + """ + if issubclass(self.data_type, types.ReversedType): + return self.data_type.subtypes[0].cql_parameterized_type() + else: + return self.data_type.cql_parameterized_type() + + def __str__(self): + return "%s %s" % (self.name, self.data_type) + + +class IndexMetadata(object): + """ + A representation of a secondary index on a column. + """ + + column = None + """ + The column (:class:`.ColumnMetadata`) this index is on. + """ + + name = None + """ A string name for the index. """ + + index_type = None + """ A string representing the type of index. """ + + index_options = {} + """ A dict of index options. """ + + def __init__(self, column_metadata, index_name=None, index_type=None, index_options={}): + self.column = column_metadata + self.name = index_name + self.index_type = index_type + self.index_options = index_options + + def as_cql_query(self): + """ + Returns a CQL query that can be used to recreate this index. + """ + table = self.column.table + if self.index_type != "CUSTOM": + index_target = protect_name(self.column.name) + if self.index_options is not None: + option_keys = self.index_options.keys() + if "index_keys" in option_keys: + index_target = 'keys(%s)' % (index_target,) + elif "index_values" in option_keys: + # don't use any "function" for collection values + pass + else: + # it might be a "full" index on a frozen collection, but + # we need to check the data type to verify that, because + # there is no special index option for full-collection + # indexes. 
+ data_type = self.column.data_type + collection_types = ('map', 'set', 'list') + if data_type.typename == "frozen" and data_type.subtypes[0].typename in collection_types: + # no index option for full-collection index + index_target = 'full(%s)' % (index_target,) + + return "CREATE INDEX %s ON %s.%s (%s)" % ( + self.name, # Cassandra doesn't like quoted index names for some reason + protect_name(table.keyspace.name), + protect_name(table.name), + index_target) + else: + return "CREATE CUSTOM INDEX %s ON %s.%s (%s) USING '%s'" % ( + self.name, # Cassandra doesn't like quoted index names for some reason + protect_name(table.keyspace.name), + protect_name(table.name), + protect_name(self.column.name), + self.index_options["class_name"]) + + +class TokenMap(object): + """ + Information about the layout of the ring. + """ + + token_class = None + """ + A subclass of :class:`.Token`, depending on what partitioner the cluster uses. + """ + + token_to_host_owner = None + """ + A map of :class:`.Token` objects to the :class:`.Host` that owns that token. + """ + + tokens_to_hosts_by_ks = None + """ + A map of keyspace names to a nested map of :class:`.Token` objects to + sets of :class:`.Host` objects. + """ + + ring = None + """ + An ordered list of :class:`.Token` instances in the ring. + """ + + _metadata = None + + def __init__(self, token_class, token_to_host_owner, all_tokens, metadata): + self.token_class = token_class + self.ring = all_tokens + self.token_to_host_owner = token_to_host_owner + + self.tokens_to_hosts_by_ks = {} + self._metadata = metadata + self._rebuild_lock = RLock() + + def rebuild_keyspace(self, keyspace, build_if_absent=False): + with self._rebuild_lock: + current = self.tokens_to_hosts_by_ks.get(keyspace, None) + if (build_if_absent and current is None) or (not build_if_absent and current is not None): + replica_map = self.replica_map_for_keyspace(self._metadata.keyspaces[keyspace]) + self.tokens_to_hosts_by_ks[keyspace] = replica_map + + def replica_map_for_keyspace(self, ks_metadata): + strategy = ks_metadata.replication_strategy + if strategy: + return strategy.make_token_replica_map(self.token_to_host_owner, self.ring) + else: + return None + + def remove_keyspace(self, keyspace): + self.tokens_to_hosts_by_ks.pop(keyspace, None) + + def get_replicas(self, keyspace, token): + """ + Get a set of :class:`.Host` instances representing all of the + replica nodes for a given :class:`.Token`. + """ + tokens_to_hosts = self.tokens_to_hosts_by_ks.get(keyspace, None) + if tokens_to_hosts is None: + self.rebuild_keyspace(keyspace, build_if_absent=True) + tokens_to_hosts = self.tokens_to_hosts_by_ks.get(keyspace, None) + if not tokens_to_hosts: + return [] + + # token range ownership is exclusive on the LHS (the start token), so + # we use bisect_right, which, in the case of a tie/exact match, + # picks an insertion point to the right of the existing match + point = bisect_right(self.ring, token) + if point == len(self.ring): + return tokens_to_hosts[self.ring[0]] + else: + return tokens_to_hosts[self.ring[point]] + + +class Token(object): + """ + Abstract class representing a token. 
+ """ + + @classmethod + def hash_fn(cls, key): + return key + + @classmethod + def from_key(cls, key): + return cls(cls.hash_fn(key)) + + def __cmp__(self, other): + if self.value < other.value: + return -1 + elif self.value == other.value: + return 0 + else: + return 1 + + def __eq__(self, other): + return self.value == other.value + + def __lt__(self, other): + return self.value < other.value + + def __hash__(self): + return hash(self.value) + + def __repr__(self): + return "<%s: %s>" % (self.__class__.__name__, self.value) + __str__ = __repr__ + +MIN_LONG = -(2 ** 63) +MAX_LONG = (2 ** 63) - 1 + + +class NoMurmur3(Exception): + pass + + +class Murmur3Token(Token): + """ + A token for ``Murmur3Partitioner``. + """ + + @classmethod + def hash_fn(cls, key): + if murmur3 is not None: + h = int(murmur3(key)) + return h if h != MIN_LONG else MAX_LONG + else: + raise NoMurmur3() + + def __init__(self, token): + """ `token` should be an int or string representing the token. """ + self.value = int(token) + + +class MD5Token(Token): + """ + A token for ``RandomPartitioner``. + """ + + @classmethod + def hash_fn(cls, key): + if isinstance(key, six.text_type): + key = key.encode('UTF-8') + return abs(varint_unpack(md5(key).digest())) + + def __init__(self, token): + """ `token` should be an int or string representing the token. """ + self.value = int(token) + + +class BytesToken(Token): + """ + A token for ``ByteOrderedPartitioner``. + """ + + def __init__(self, token_string): + """ `token_string` should be string representing the token. """ + if not isinstance(token_string, six.string_types): + raise TypeError( + "Tokens for ByteOrderedPartitioner should be strings (got %s)" + % (type(token_string),)) + self.value = token_string + + +class TriggerMetadata(object): + """ + A representation of a trigger for a table. + """ + + table = None + """ The :class:`.TableMetadata` this trigger belongs to. """ + + name = None + """ The string name of this trigger. """ + + options = None + """ + A dict mapping trigger option names to their specific settings for this + table. + """ + def __init__(self, table_metadata, trigger_name, options=None): + self.table = table_metadata + self.name = trigger_name + self.options = options + + def as_cql_query(self): + ret = "CREATE TRIGGER %s ON %s.%s USING %s" % ( + protect_name(self.name), + protect_name(self.table.keyspace.name), + protect_name(self.table.name), + protect_value(self.options['class']) + ) + return ret diff --git a/cassandra/metrics.py b/cassandra/metrics.py new file mode 100644 index 0000000..77ed896 --- /dev/null +++ b/cassandra/metrics.py @@ -0,0 +1,166 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from itertools import chain
+import logging
+
+try:
+    from greplin import scales
+except ImportError:
+    raise ImportError(
+        "The scales library is required for metrics support: "
+        "https://pypi.python.org/pypi/scales")
+
+log = logging.getLogger(__name__)
+
+
+class Metrics(object):
+    """
+    A collection of timers and counters for various performance metrics.
+    """
+
+    request_timer = None
+    """
+    A :class:`greplin.scales.PmfStat` timer for requests. This is a dict-like
+    object with the following keys:
+
+    * count - number of requests that have been timed
+    * min - min latency
+    * max - max latency
+    * mean - mean latency
+    * stdev - standard deviation for latencies
+    * median - median latency
+    * 75percentile - 75th percentile latencies
+    * 97percentile - 97th percentile latencies
+    * 98percentile - 98th percentile latencies
+    * 99percentile - 99th percentile latencies
+    * 999percentile - 99.9th percentile latencies
+    """
+
+    connection_errors = None
+    """
+    A :class:`greplin.scales.IntStat` count of the number of times that a
+    request to a Cassandra node has failed due to a connection problem.
+    """
+
+    write_timeouts = None
+    """
+    A :class:`greplin.scales.IntStat` count of write requests that resulted
+    in a timeout.
+    """
+
+    read_timeouts = None
+    """
+    A :class:`greplin.scales.IntStat` count of read requests that resulted
+    in a timeout.
+    """
+
+    unavailables = None
+    """
+    A :class:`greplin.scales.IntStat` count of write or read requests that
+    failed due to an insufficient number of replicas being alive to meet
+    the requested :class:`.ConsistencyLevel`.
+    """
+
+    other_errors = None
+    """
+    A :class:`greplin.scales.IntStat` count of all other request failures,
+    including failures caused by invalid requests, bootstrapping nodes,
+    overloaded nodes, etc.
+    """
+
+    retries = None
+    """
+    A :class:`greplin.scales.IntStat` count of the number of times a
+    request was retried based on the :class:`.RetryPolicy` decision.
+    """
+
+    ignores = None
+    """
+    A :class:`greplin.scales.IntStat` count of the number of times a
+    failed request was ignored based on the :class:`.RetryPolicy` decision.
+    """
+
+    known_hosts = None
+    """
+    A :class:`greplin.scales.IntStat` count of the number of nodes in
+    the cluster that the driver is aware of, regardless of whether any
+    connections are opened to those nodes.
+    """
+
+    connected_to = None
+    """
+    A :class:`greplin.scales.IntStat` count of the number of nodes that
+    the driver currently has at least one connection open to.
+    """
+
+    open_connections = None
+    """
+    A :class:`greplin.scales.IntStat` count of the number of connections
+    the driver currently has open.
+ """ + + def __init__(self, cluster_proxy): + log.debug("Starting metric capture") + + self.stats = scales.collection('/cassandra', + scales.PmfStat('request_timer'), + scales.IntStat('connection_errors'), + scales.IntStat('write_timeouts'), + scales.IntStat('read_timeouts'), + scales.IntStat('unavailables'), + scales.IntStat('other_errors'), + scales.IntStat('retries'), + scales.IntStat('ignores'), + + # gauges + scales.Stat('known_hosts', + lambda: len(cluster_proxy.metadata.all_hosts())), + scales.Stat('connected_to', + lambda: len(set(chain.from_iterable(s._pools.keys() for s in cluster_proxy.sessions)))), + scales.Stat('open_connections', + lambda: sum(sum(p.open_count for p in s._pools.values()) for s in cluster_proxy.sessions))) + + self.request_timer = self.stats.request_timer + self.connection_errors = self.stats.connection_errors + self.write_timeouts = self.stats.write_timeouts + self.read_timeouts = self.stats.read_timeouts + self.unavailables = self.stats.unavailables + self.other_errors = self.stats.other_errors + self.retries = self.stats.retries + self.ignores = self.stats.ignores + self.known_hosts = self.stats.known_hosts + self.connected_to = self.stats.connected_to + self.open_connections = self.stats.open_connections + + def on_connection_error(self): + self.stats.connection_errors += 1 + + def on_write_timeout(self): + self.stats.write_timeouts += 1 + + def on_read_timeout(self): + self.stats.read_timeouts += 1 + + def on_unavailable(self): + self.stats.unavailables += 1 + + def on_other_error(self): + self.stats.other_errors += 1 + + def on_ignore(self): + self.stats.ignores += 1 + + def on_retry(self): + self.stats.retries += 1 diff --git a/cassandra/murmur3.c b/cassandra/murmur3.c new file mode 100644 index 0000000..bdcb972 --- /dev/null +++ b/cassandra/murmur3.c @@ -0,0 +1,268 @@ +/* + * The majority of this code was taken from the python-smhasher library, + * which can be found here: https://github.com/phensley/python-smhasher + * + * That library is under the MIT license with the following copyright: + * + * Copyright (c) 2011 Austin Appleby (Murmur3 routine) + * Copyright (c) 2011 Patrick Hensley (Python wrapper, packaging) + * Copyright 2013 DataStax (Minor modifications to match Cassandra's MM3 hashes) + * + */ + +#define PY_SSIZE_T_CLEAN 1 +#include +#include + +#if PY_VERSION_HEX < 0x02050000 +typedef int Py_ssize_t; +#define PY_SSIZE_T_MAX INT_MAX +#define PY_SSIZE_T_MIN INT_MIN +#endif + +//----------------------------------------------------------------------------- +// Platform-specific functions and macros + +// Microsoft Visual Studio + +#if defined(_MSC_VER) + +typedef unsigned char uint8_t; +typedef unsigned long uint32_t; +typedef unsigned __int64 uint64_t; + +typedef char int8_t; +typedef long int32_t; +typedef __int64 int64_t; + +#define FORCE_INLINE __forceinline + +#include + +#define ROTL32(x,y) _rotl(x,y) +#define ROTL64(x,y) _rotl64(x,y) + +#define BIG_CONSTANT(x) (x) + +// Other compilers + +#else // defined(_MSC_VER) + +#include + +#define FORCE_INLINE inline __attribute__((always_inline)) + +inline uint32_t rotl32 ( int32_t x, int8_t r ) +{ + // cast to unsigned for logical right bitshift (to match C* MM3 implementation) + return (x << r) | ((int32_t) (((uint32_t) x) >> (32 - r))); +} + +inline int64_t rotl64 ( int64_t x, int8_t r ) +{ + // cast to unsigned for logical right bitshift (to match C* MM3 implementation) + return (x << r) | ((int64_t) (((uint64_t) x) >> (64 - r))); +} + +#define ROTL32(x,y) rotl32(x,y) +#define ROTL64(x,y) 
rotl64(x,y) + +#define BIG_CONSTANT(x) (x##LL) + +#endif // !defined(_MSC_VER) + +//----------------------------------------------------------------------------- +// Block read - if your platform needs to do endian-swapping or can only +// handle aligned reads, do the conversion here + +// TODO 32bit? + +FORCE_INLINE int64_t getblock ( const int64_t * p, int i ) +{ + return p[i]; +} + +//----------------------------------------------------------------------------- +// Finalization mix - force all bits of a hash block to avalanche + +FORCE_INLINE int64_t fmix ( int64_t k ) +{ + // cast to unsigned for logical right bitshift (to match C* MM3 implementation) + k ^= ((uint64_t) k) >> 33; + k *= BIG_CONSTANT(0xff51afd7ed558ccd); + k ^= ((uint64_t) k) >> 33; + k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53); + k ^= ((uint64_t) k) >> 33; + + return k; +} + +int64_t MurmurHash3_x64_128 (const void * key, const int len, + const uint32_t seed) +{ + const int8_t * data = (const int8_t*)key; + const int nblocks = len / 16; + + int64_t h1 = seed; + int64_t h2 = seed; + + int64_t c1 = BIG_CONSTANT(0x87c37b91114253d5); + int64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f); + int64_t k1 = 0; + int64_t k2 = 0; + + const int64_t * blocks = (const int64_t *)(data); + const int8_t * tail = (const int8_t*)(data + nblocks*16); + + //---------- + // body + + int i; + for(i = 0; i < nblocks; i++) + { + int64_t k1 = getblock(blocks,i*2+0); + int64_t k2 = getblock(blocks,i*2+1); + + k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1; + + h1 = ROTL64(h1,27); h1 += h2; h1 = h1*5+0x52dce729; + + k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; + + h2 = ROTL64(h2,31); h2 += h1; h2 = h2*5+0x38495ab5; + + } + + //---------- + // tail + switch(len & 15) + { + case 15: k2 ^= ((int64_t) (tail[14])) << 48; + case 14: k2 ^= ((int64_t) (tail[13])) << 40; + case 13: k2 ^= ((int64_t) (tail[12])) << 32; + case 12: k2 ^= ((int64_t) (tail[11])) << 24; + case 11: k2 ^= ((int64_t) (tail[10])) << 16; + case 10: k2 ^= ((int64_t) (tail[ 9])) << 8; + case 9: k2 ^= ((int64_t) (tail[ 8])) << 0; + k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; + + case 8: k1 ^= ((int64_t) (tail[ 7])) << 56; + case 7: k1 ^= ((int64_t) (tail[ 6])) << 48; + case 6: k1 ^= ((int64_t) (tail[ 5])) << 40; + case 5: k1 ^= ((int64_t) (tail[ 4])) << 32; + case 4: k1 ^= ((int64_t) (tail[ 3])) << 24; + case 3: k1 ^= ((int64_t) (tail[ 2])) << 16; + case 2: k1 ^= ((int64_t) (tail[ 1])) << 8; + case 1: k1 ^= ((int64_t) (tail[ 0])) << 0; + k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; h2 ^= len; + + h1 += h2; + h2 += h1; + + h1 = fmix(h1); + h2 = fmix(h2); + + h1 += h2; + h2 += h1; + + return h1; +} + + +struct module_state { + PyObject *error; +}; + +#if PY_MAJOR_VERSION >= 3 +#define GETSTATE(m) ((struct module_state*)PyModule_GetState(m)) +#else +#define GETSTATE(m) (&_state) +static struct module_state _state; +#endif + +static PyObject * +murmur3(PyObject *self, PyObject *args) +{ + const char *key; + Py_ssize_t len; + uint32_t seed = 0; + int64_t result = 0; + + + if (!PyArg_ParseTuple(args, "s#|I", &key, &len, &seed)) { + return NULL; + } + + // TODO handle x86 version? 
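+    // Note: only the first 64 bits (h1) of the 128-bit MurmurHash3 result are
+    // returned below, which is the same half that Cassandra's Murmur3Partitioner
+    // uses for tokens.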
+ result = MurmurHash3_x64_128((void *)key, len, seed); + return (PyObject *) PyLong_FromLong((long int)result); +} + +static PyMethodDef murmur3_methods[] = { + {"murmur3", murmur3, METH_VARARGS, "Make an x64 murmur3 64-bit hash value"}, + {NULL, NULL, 0, NULL} +}; + +#if PY_MAJOR_VERSION >= 3 + +static int murmur3_traverse(PyObject *m, visitproc visit, void *arg) { + Py_VISIT(GETSTATE(m)->error); + return 0; +} + +static int murmur3_clear(PyObject *m) { + Py_CLEAR(GETSTATE(m)->error); + return 0; +} + +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "murmur3", + NULL, + sizeof(struct module_state), + murmur3_methods, + NULL, + murmur3_traverse, + murmur3_clear, + NULL +}; + +#define INITERROR return NULL + +PyObject * +PyInit_murmur3(void) + +#else +#define INITERROR return + +void +initmurmur3(void) +#endif +{ +#if PY_MAJOR_VERSION >= 3 + PyObject *module = PyModule_Create(&moduledef); +#else + PyObject *module = Py_InitModule("murmur3", murmur3_methods); +#endif + struct module_state *st = NULL; + + if (module == NULL) + INITERROR; + st = GETSTATE(module); + + st->error = PyErr_NewException("murmur3.Error", NULL, NULL); + if (st->error == NULL) { + Py_DECREF(module); + INITERROR; + } + +#if PY_MAJOR_VERSION >= 3 + return module; +#endif +} diff --git a/cassandra/policies.py b/cassandra/policies.py new file mode 100644 index 0000000..244df24 --- /dev/null +++ b/cassandra/policies.py @@ -0,0 +1,825 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from itertools import islice, cycle, groupby, repeat +import logging +from random import randint +from threading import Lock +import six + +from cassandra import ConsistencyLevel + +from six.moves import range + +log = logging.getLogger(__name__) + + +class HostDistance(object): + """ + A measure of how "distant" a node is from the client, which + may influence how the load balancer distributes requests + and how many connections are opened to the node. + """ + + IGNORED = -1 + """ + A node with this distance should never be queried or have + connections opened to it. + """ + + LOCAL = 0 + """ + Nodes with ``LOCAL`` distance will be preferred for operations + under some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`) + and will have a greater number of connections opened against + them by default. + + This distance is typically used for nodes within the same + datacenter as the client. + """ + + REMOTE = 1 + """ + Nodes with ``REMOTE`` distance will be treated as a last resort + by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`) + and will have a smaller number of connections opened against + them by default. + + This distance is typically used for nodes outside of the + datacenter that the client is running in. + """ + + +class HostStateListener(object): + + def on_up(self, host): + """ Called when a node is marked up. """ + raise NotImplementedError() + + def on_down(self, host): + """ Called when a node is marked down. 
""" + raise NotImplementedError() + + def on_add(self, host): + """ + Called when a node is added to the cluster. The newly added node + should be considered up. + """ + raise NotImplementedError() + + def on_remove(self, host): + """ Called when a node is removed from the cluster. """ + raise NotImplementedError() + + +class LoadBalancingPolicy(HostStateListener): + """ + Load balancing policies are used to decide how to distribute + requests among all possible coordinator nodes in the cluster. + + In particular, they may focus on querying "near" nodes (those + in a local datacenter) or on querying nodes who happen to + be replicas for the requested data. + + You may also use subclasses of :class:`.LoadBalancingPolicy` for + custom behavior. + """ + + _hosts_lock = None + + def __init__(self): + self._hosts_lock = Lock() + + def distance(self, host): + """ + Returns a measure of how remote a :class:`~.pool.Host` is in + terms of the :class:`.HostDistance` enums. + """ + raise NotImplementedError() + + def populate(self, cluster, hosts): + """ + This method is called to initialize the load balancing + policy with a set of :class:`.Host` instances before its + first use. The `cluster` parameter is an instance of + :class:`.Cluster`. + """ + raise NotImplementedError() + + def make_query_plan(self, working_keyspace=None, query=None): + """ + Given a :class:`~.query.Statement` instance, return a iterable + of :class:`.Host` instances which should be queried in that + order. A generator may work well for custom implementations + of this method. + + Note that the `query` argument may be :const:`None` when preparing + statements. + + `working_keyspace` should be the string name of the current keyspace, + as set through :meth:`.Session.set_keyspace()` or with a ``USE`` + statement. + """ + raise NotImplementedError() + + def check_supported(self): + """ + This will be called after the cluster Metadata has been initialized. + If the load balancing policy implementation cannot be supported for + some reason (such as a missing C extension), this is the point at + which it should raise an exception. + """ + pass + + +class RoundRobinPolicy(LoadBalancingPolicy): + """ + A subclass of :class:`.LoadBalancingPolicy` which evenly + distributes queries across all nodes in the cluster, + regardless of what datacenter the nodes may be in. + + This load balancing policy is used by default. 
+ """ + _live_hosts = frozenset(()) + + def populate(self, cluster, hosts): + self._live_hosts = frozenset(hosts) + if len(hosts) <= 1: + self._position = 0 + else: + self._position = randint(0, len(hosts) - 1) + + def distance(self, host): + return HostDistance.LOCAL + + def make_query_plan(self, working_keyspace=None, query=None): + # not thread-safe, but we don't care much about lost increments + # for the purposes of load balancing + pos = self._position + self._position += 1 + + hosts = self._live_hosts + length = len(hosts) + if length: + pos %= length + return list(islice(cycle(hosts), pos, pos + length)) + else: + return [] + + def on_up(self, host): + with self._hosts_lock: + self._live_hosts = self._live_hosts.union((host, )) + + def on_down(self, host): + with self._hosts_lock: + self._live_hosts = self._live_hosts.difference((host, )) + + def on_add(self, host): + with self._hosts_lock: + self._live_hosts = self._live_hosts.union((host, )) + + def on_remove(self, host): + with self._hosts_lock: + self._live_hosts = self._live_hosts.difference((host, )) + + +class DCAwareRoundRobinPolicy(LoadBalancingPolicy): + """ + Similar to :class:`.RoundRobinPolicy`, but prefers hosts + in the local datacenter and only uses nodes in remote + datacenters as a last resort. + """ + + local_dc = None + used_hosts_per_remote_dc = 0 + + def __init__(self, local_dc='', used_hosts_per_remote_dc=0): + """ + The `local_dc` parameter should be the name of the datacenter + (such as is reported by ``nodetool ring``) that should + be considered local. If not specified, the driver will choose + a local_dc based on the first host among :attr:`.Cluster.contact_points` + having a valid DC. If relying on this mechanism, all specified + contact points should be nodes in a single, local DC. + + `used_hosts_per_remote_dc` controls how many nodes in + each remote datacenter will have connections opened + against them. In other words, `used_hosts_per_remote_dc` hosts + will be considered :attr:`~.HostDistance.REMOTE` and the + rest will be considered :attr:`~.HostDistance.IGNORED`. + By default, all remote hosts are ignored. 
+ """ + self.local_dc = local_dc + self.used_hosts_per_remote_dc = used_hosts_per_remote_dc + self._dc_live_hosts = {} + self._position = 0 + self._contact_points = [] + LoadBalancingPolicy.__init__(self) + + def _dc(self, host): + return host.datacenter or self.local_dc + + def populate(self, cluster, hosts): + for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)): + self._dc_live_hosts[dc] = tuple(set(dc_hosts)) + + if not self.local_dc: + self._contact_points = cluster.contact_points + + self._position = randint(0, len(hosts) - 1) if hosts else 0 + + def distance(self, host): + dc = self._dc(host) + if dc == self.local_dc: + return HostDistance.LOCAL + + if not self.used_hosts_per_remote_dc: + return HostDistance.IGNORED + else: + dc_hosts = self._dc_live_hosts.get(dc) + if not dc_hosts: + return HostDistance.IGNORED + + if host in list(dc_hosts)[:self.used_hosts_per_remote_dc]: + return HostDistance.REMOTE + else: + return HostDistance.IGNORED + + def make_query_plan(self, working_keyspace=None, query=None): + # not thread-safe, but we don't care much about lost increments + # for the purposes of load balancing + pos = self._position + self._position += 1 + + local_live = self._dc_live_hosts.get(self.local_dc, ()) + pos = (pos % len(local_live)) if local_live else 0 + for host in islice(cycle(local_live), pos, pos + len(local_live)): + yield host + + # the dict can change, so get candidate DCs iterating over keys of a copy + other_dcs = [dc for dc in self._dc_live_hosts.copy().keys() if dc != self.local_dc] + for dc in other_dcs: + remote_live = self._dc_live_hosts.get(dc, ()) + for host in remote_live[:self.used_hosts_per_remote_dc]: + yield host + + def on_up(self, host): + # not worrying about threads because this will happen during + # control connection startup/refresh + if not self.local_dc and host.datacenter: + if host.address in self._contact_points: + self.local_dc = host.datacenter + log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); " + "if incorrect, please specify a local_dc to the constructor, " + "or limit contact points to local cluster nodes" % + (self.local_dc, host.address)) + del self._contact_points + + dc = self._dc(host) + with self._hosts_lock: + current_hosts = self._dc_live_hosts.get(dc, ()) + if host not in current_hosts: + self._dc_live_hosts[dc] = current_hosts + (host, ) + + def on_down(self, host): + dc = self._dc(host) + with self._hosts_lock: + current_hosts = self._dc_live_hosts.get(dc, ()) + if host in current_hosts: + hosts = tuple(h for h in current_hosts if h != host) + if hosts: + self._dc_live_hosts[dc] = hosts + else: + del self._dc_live_hosts[dc] + + def on_add(self, host): + self.on_up(host) + + def on_remove(self, host): + self.on_down(host) + + +class TokenAwarePolicy(LoadBalancingPolicy): + """ + A :class:`.LoadBalancingPolicy` wrapper that adds token awareness to + a child policy. + + This alters the child policy's behavior so that it first attempts to + send queries to :attr:`~.HostDistance.LOCAL` replicas (as determined + by the child policy) based on the :class:`.Statement`'s + :attr:`~.Statement.routing_key`. Once those hosts are exhausted, the + remaining hosts in the child policy's query plan will be used. + + If no :attr:`~.Statement.routing_key` is set on the query, the child + policy's query plan will be used as is. 
+ """ + + _child_policy = None + _cluster_metadata = None + + def __init__(self, child_policy): + self._child_policy = child_policy + + def populate(self, cluster, hosts): + self._cluster_metadata = cluster.metadata + self._child_policy.populate(cluster, hosts) + + def check_supported(self): + if not self._cluster_metadata.can_support_partitioner(): + raise Exception( + '%s cannot be used with the cluster partitioner (%s) because ' + 'the relevant C extension for this driver was not compiled. ' + 'See the installation instructions for details on building ' + 'and installing the C extensions.' % + (self.__class__.__name__, self._cluster_metadata.partitioner)) + + def distance(self, *args, **kwargs): + return self._child_policy.distance(*args, **kwargs) + + def make_query_plan(self, working_keyspace=None, query=None): + if query and query.keyspace: + keyspace = query.keyspace + else: + keyspace = working_keyspace + + child = self._child_policy + if query is None: + for host in child.make_query_plan(keyspace, query): + yield host + else: + routing_key = query.routing_key + if routing_key is None or keyspace is None: + for host in child.make_query_plan(keyspace, query): + yield host + else: + replicas = self._cluster_metadata.get_replicas(keyspace, routing_key) + for replica in replicas: + if replica.is_up and \ + child.distance(replica) == HostDistance.LOCAL: + yield replica + + for host in child.make_query_plan(keyspace, query): + # skip if we've already listed this host + if host not in replicas or \ + child.distance(host) == HostDistance.REMOTE: + yield host + + def on_up(self, *args, **kwargs): + return self._child_policy.on_up(*args, **kwargs) + + def on_down(self, *args, **kwargs): + return self._child_policy.on_down(*args, **kwargs) + + def on_add(self, *args, **kwargs): + return self._child_policy.on_add(*args, **kwargs) + + def on_remove(self, *args, **kwargs): + return self._child_policy.on_remove(*args, **kwargs) + + +class WhiteListRoundRobinPolicy(RoundRobinPolicy): + """ + A subclass of :class:`.RoundRobinPolicy` which evenly + distributes queries across all nodes in the cluster, + regardless of what datacenter the nodes may be in, but + only if that node exists in the list of allowed nodes + + This policy is addresses the issue described in + https://datastax-oss.atlassian.net/browse/JAVA-145 + Where connection errors occur when connection + attempts are made to private IP addresses remotely + """ + def __init__(self, hosts): + """ + The `hosts` parameter should be a sequence of hosts to permit + connections to. + """ + self._allowed_hosts = hosts + RoundRobinPolicy.__init__(self) + + def populate(self, cluster, hosts): + self._live_hosts = frozenset(h for h in hosts if h.address in self._allowed_hosts) + + if len(hosts) <= 1: + self._position = 0 + else: + self._position = randint(0, len(hosts) - 1) + + def distance(self, host): + if host.address in self._allowed_hosts: + return HostDistance.LOCAL + else: + return HostDistance.IGNORED + + def on_up(self, host): + if host.address in self._allowed_hosts: + RoundRobinPolicy.on_up(self, host) + + def on_add(self, host): + if host.address in self._allowed_hosts: + RoundRobinPolicy.on_add(self, host) + + +class ConvictionPolicy(object): + """ + A policy which decides when hosts should be considered down + based on the types of failures and the number of failures. + + If custom behavior is needed, this class may be subclassed. + """ + + def __init__(self, host): + """ + `host` is an instance of :class:`.Host`. 
+ """ + self.host = host + + def add_failure(self, connection_exc): + """ + Implementations should return :const:`True` if the host should be + convicted, :const:`False` otherwise. + """ + raise NotImplementedError() + + def reset(self): + """ + Implementations should clear out any convictions or state regarding + the host. + """ + raise NotImplementedError() + + +class SimpleConvictionPolicy(ConvictionPolicy): + """ + The default implementation of :class:`ConvictionPolicy`, + which simply marks a host as down after the first failure + of any kind. + """ + + def add_failure(self, connection_exc): + return True + + def reset(self): + pass + + +class ReconnectionPolicy(object): + """ + This class and its subclasses govern how frequently an attempt is made + to reconnect to nodes that are marked as dead. + + If custom behavior is needed, this class may be subclassed. + """ + + def new_schedule(self): + """ + This should return a finite or infinite iterable of delays (each as a + floating point number of seconds) inbetween each failed reconnection + attempt. Note that if the iterable is finite, reconnection attempts + will cease once the iterable is exhausted. + """ + raise NotImplementedError() + + +class ConstantReconnectionPolicy(ReconnectionPolicy): + """ + A :class:`.ReconnectionPolicy` subclass which sleeps for a fixed delay + inbetween each reconnection attempt. + """ + + def __init__(self, delay, max_attempts=64): + """ + `delay` should be a floating point number of seconds to wait inbetween + each attempt. + + `max_attempts` should be a total number of attempts to be made before + giving up, or :const:`None` to continue reconnection attempts forever. + The default is 64. + """ + if delay < 0: + raise ValueError("delay must not be negative") + if max_attempts < 0: + raise ValueError("max_attempts must not be negative") + + self.delay = delay + self.max_attempts = max_attempts + + def new_schedule(self): + return repeat(self.delay, self.max_attempts) + + +class ExponentialReconnectionPolicy(ReconnectionPolicy): + """ + A :class:`.ReconnectionPolicy` subclass which exponentially increases + the length of the delay inbetween each reconnection attempt up to + a set maximum delay. + """ + + def __init__(self, base_delay, max_delay): + """ + `base_delay` and `max_delay` should be in floating point units of + seconds. + """ + if base_delay < 0 or max_delay < 0: + raise ValueError("Delays may not be negative") + + if max_delay < base_delay: + raise ValueError("Max delay must be greater than base delay") + + self.base_delay = base_delay + self.max_delay = max_delay + + def new_schedule(self): + return (min(self.base_delay * (2 ** i), self.max_delay) for i in range(64)) + + +class WriteType(object): + """ + For usage with :class:`.RetryPolicy`, this describe a type + of write operation. + """ + + SIMPLE = 0 + """ + A write to a single partition key. Such writes are guaranteed to be atomic + and isolated. + """ + + BATCH = 1 + """ + A write to multiple partition keys that used the distributed batch log to + ensure atomicity. + """ + + UNLOGGED_BATCH = 2 + """ + A write to multiple partition keys that did not use the distributed batch + log. Atomicity for such writes is not guaranteed. + """ + + COUNTER = 3 + """ + A counter write (for one or multiple partition keys). Such writes should + not be replayed in order to avoid overcount. + """ + + BATCH_LOG = 4 + """ + The initial write to the distributed batch log that Cassandra performs + internally before a BATCH write. 
+ """ + + CAS = 5 + """ + A lighweight-transaction write, such as "DELETE ... IF EXISTS". + """ + +WriteType.name_to_value = { + 'SIMPLE': WriteType.SIMPLE, + 'BATCH': WriteType.BATCH, + 'UNLOGGED_BATCH': WriteType.UNLOGGED_BATCH, + 'COUNTER': WriteType.COUNTER, + 'BATCH_LOG': WriteType.BATCH_LOG, + 'CAS': WriteType.CAS +} + + +class RetryPolicy(object): + """ + A policy that describes whether to retry, rethrow, or ignore timeout + and unavailable failures. + + To specify a default retry policy, set the + :attr:`.Cluster.default_retry_policy` attribute to an instance of this + class or one of its subclasses. + + To specify a retry policy per query, set the :attr:`.Statement.retry_policy` + attribute to an instance of this class or one of its subclasses. + + If custom behavior is needed for retrying certain operations, + this class may be subclassed. + """ + + RETRY = 0 + """ + This should be returned from the below methods if the operation + should be retried on the same connection. + """ + + RETHROW = 1 + """ + This should be returned from the below methods if the failure + should be propagated and no more retries attempted. + """ + + IGNORE = 2 + """ + This should be returned from the below methods if the failure + should be ignored but no more retries should be attempted. + """ + + def on_read_timeout(self, query, consistency, required_responses, + received_responses, data_retrieved, retry_num): + """ + This is called when a read operation times out from the coordinator's + perspective (i.e. a replica did not respond to the coordinator in time). + It should return a tuple with two items: one of the class enums (such + as :attr:`.RETRY`) and a :class:`.ConsistencyLevel` to retry the + operation at or :const:`None` to keep the same consistency level. + + `query` is the :class:`.Statement` that timed out. + + `consistency` is the :class:`.ConsistencyLevel` that the operation was + attempted at. + + The `required_responses` and `received_responses` parameters describe + how many replicas needed to respond to meet the requested consistency + level and how many actually did respond before the coordinator timed + out the request. `data_retrieved` is a boolean indicating whether + any of those responses contained data (as opposed to just a digest). + + `retry_num` counts how many times the operation has been retried, so + the first time this method is called, `retry_num` will be 0. + + By default, operations will be retried at most once, and only if + a sufficient number of replicas responded (with data digests). + """ + if retry_num != 0: + return (self.RETHROW, None) + elif received_responses >= required_responses and not data_retrieved: + return (self.RETRY, consistency) + else: + return (self.RETHROW, None) + + def on_write_timeout(self, query, consistency, write_type, + required_responses, received_responses, retry_num): + """ + This is called when a write operation times out from the coordinator's + perspective (i.e. a replica did not respond to the coordinator in time). + + `query` is the :class:`.Statement` that timed out. + + `consistency` is the :class:`.ConsistencyLevel` that the operation was + attempted at. + + `write_type` is one of the :class:`.WriteType` enums describing the + type of write operation. + + The `required_responses` and `received_responses` parameters describe + how many replicas needed to acknowledge the write to meet the requested + consistency level and how many replicas actually did acknowledge the + write before the coordinator timed out the request. 
+ + `retry_num` counts how many times the operation has been retried, so + the first time this method is called, `retry_num` will be 0. + + By default, failed write operations will retried at most once, and + they will only be retried if the `write_type` was + :attr:`~.WriteType.BATCH_LOG`. + """ + if retry_num != 0: + return (self.RETHROW, None) + elif write_type == WriteType.BATCH_LOG: + return (self.RETRY, consistency) + else: + return (self.RETHROW, None) + + def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num): + """ + This is called when the coordinator node determines that a read or + write operation cannot be successful because the number of live + replicas are too low to meet the requested :class:`.ConsistencyLevel`. + This means that the read or write operation was never forwared to + any replicas. + + `query` is the :class:`.Statement` that failed. + + `consistency` is the :class:`.ConsistencyLevel` that the operation was + attempted at. + + `required_replicas` is the number of replicas that would have needed to + acknowledge the operation to meet the requested consistency level. + `alive_replicas` is the number of replicas that the coordinator + considered alive at the time of the request. + + `retry_num` counts how many times the operation has been retried, so + the first time this method is called, `retry_num` will be 0. + + By default, no retries will be attempted and the error will be re-raised. + """ + return (self.RETHROW, None) + + +class FallthroughRetryPolicy(RetryPolicy): + """ + A retry policy that never retries and always propagates failures to + the application. + """ + + def on_read_timeout(self, *args, **kwargs): + return (self.RETHROW, None) + + def on_write_timeout(self, *args, **kwargs): + return (self.RETHROW, None) + + def on_unavailable(self, *args, **kwargs): + return (self.RETHROW, None) + + +class DowngradingConsistencyRetryPolicy(RetryPolicy): + """ + A retry policy that sometimes retries with a lower consistency level than + the one initially requested. + + **BEWARE**: This policy may retry queries using a lower consistency + level than the one initially requested. By doing so, it may break + consistency guarantees. In other words, if you use this retry policy, + there are cases (documented below) where a read at :attr:`~.QUORUM` + *may not* see a preceding write at :attr:`~.QUORUM`. Do not use this + policy unless you have understood the cases where this can happen and + are ok with that. It is also recommended to subclass this class so + that queries that required a consistency level downgrade can be + recorded (so that repairs can be made later, etc). + + This policy implements the same retries as :class:`.RetryPolicy`, + but on top of that, it also retries in the following cases: + + * On a read timeout: if the number of replicas that responded is + greater than one but lower than is required by the requested + consistency level, the operation is retried at a lower consistency + level. + * On a write timeout: if the operation is an :attr:`~.UNLOGGED_BATCH` + and at least one replica acknowledged the write, the operation is + retried at a lower consistency level. Furthermore, for other + write types, if at least one replica acknowledged the write, the + timeout is ignored. + * On an unavailable exception: if at least one replica is alive, the + operation is retried at a lower consistency level. 
+ + The reasoning behind this retry policy is as follows: if, based + on the information the Cassandra coordinator node returns, retrying the + operation with the initially requested consistency has a chance to + succeed, do it. Otherwise, if based on that information we know the + initially requested consistency level cannot be achieved currently, then: + + * For writes, ignore the exception (thus silently failing the + consistency requirement) if we know the write has been persisted on at + least one replica. + * For reads, try reading at a lower consistency level (thus silently + failing the consistency requirement). + + In other words, this policy implements the idea that if the requested + consistency level cannot be achieved, the next best thing for writes is + to make sure the data is persisted, and that reading something is better + than reading nothing, even if there is a risk of reading stale data. + """ + def _pick_consistency(self, num_responses): + if num_responses >= 3: + return (self.RETRY, ConsistencyLevel.THREE) + elif num_responses >= 2: + return (self.RETRY, ConsistencyLevel.TWO) + elif num_responses >= 1: + return (self.RETRY, ConsistencyLevel.ONE) + else: + return (self.RETHROW, None) + + def on_read_timeout(self, query, consistency, required_responses, + received_responses, data_retrieved, retry_num): + if retry_num != 0: + return (self.RETHROW, None) + elif received_responses < required_responses: + return self._pick_consistency(received_responses) + elif not data_retrieved: + return (self.RETRY, consistency) + else: + return (self.RETHROW, None) + + def on_write_timeout(self, query, consistency, write_type, + required_responses, received_responses, retry_num): + if retry_num != 0: + return (self.RETHROW, None) + elif write_type in (WriteType.SIMPLE, WriteType.BATCH, WriteType.COUNTER): + return (self.IGNORE, None) + elif write_type == WriteType.UNLOGGED_BATCH: + return self._pick_consistency(received_responses) + elif write_type == WriteType.BATCH_LOG: + return (self.RETRY, consistency) + else: + return (self.RETHROW, None) + + def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num): + if retry_num != 0: + return (self.RETHROW, None) + else: + return self._pick_consistency(alive_replicas) diff --git a/cassandra/pool.py b/cassandra/pool.py new file mode 100644 index 0000000..3057fcd --- /dev/null +++ b/cassandra/pool.py @@ -0,0 +1,712 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Connection pooling and host management. 
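+
+Two pool implementations are provided here: :class:`.HostConnection`, which uses
+a single connection per host (native protocol v3 and higher), and
+:class:`.HostConnectionPool`, which maintains multiple connections per host
+(protocol v1 and v2).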
+""" + +import logging +import socket +import time +from threading import Lock, RLock, Condition +import weakref +try: + from weakref import WeakSet +except ImportError: + from cassandra.util import WeakSet # NOQA + +from cassandra import AuthenticationFailed +from cassandra.connection import ConnectionException +from cassandra.policies import HostDistance + +log = logging.getLogger(__name__) + + +class NoConnectionsAvailable(Exception): + """ + All existing connections to a given host are busy, or there are + no open connections. + """ + pass + + +class Host(object): + """ + Represents a single Cassandra node. + """ + + address = None + """ + The IP address or hostname of the node. + """ + + conviction_policy = None + """ + A :class:`~.ConvictionPolicy` instance for determining when this node should + be marked up or down. + """ + + is_up = None + """ + :const:`True` if the node is considered up, :const:`False` if it is + considered down, and :const:`None` if it is not known if the node is + up or down. + """ + + _datacenter = None + _rack = None + _reconnection_handler = None + lock = None + + _currently_handling_node_up = False + + def __init__(self, inet_address, conviction_policy_factory, datacenter=None, rack=None): + if inet_address is None: + raise ValueError("inet_address may not be None") + if conviction_policy_factory is None: + raise ValueError("conviction_policy_factory may not be None") + + self.address = inet_address + self.conviction_policy = conviction_policy_factory(self) + self.set_location_info(datacenter, rack) + self.lock = RLock() + + @property + def datacenter(self): + """ The datacenter the node is in. """ + return self._datacenter + + @property + def rack(self): + """ The rack the node is in. """ + return self._rack + + def set_location_info(self, datacenter, rack): + """ + Sets the datacenter and rack for this node. Intended for internal + use (by the control connection, which periodically checks the + ring topology) only. + """ + self._datacenter = datacenter + self._rack = rack + + def set_up(self): + if not self.is_up: + log.debug("Host %s is now marked up", self.address) + self.conviction_policy.reset() + self.is_up = True + + def set_down(self): + self.is_up = False + + def signal_connection_failure(self, connection_exc): + return self.conviction_policy.add_failure(connection_exc) + + def is_currently_reconnecting(self): + return self._reconnection_handler is not None + + def get_and_set_reconnection_handler(self, new_handler): + """ + Atomically replaces the reconnection handler for this + host. Intended for internal use only. + """ + with self.lock: + old = self._reconnection_handler + self._reconnection_handler = new_handler + return old + + def __eq__(self, other): + return self.address == other.address + + def __hash__(self): + return hash(self.address) + + def __lt__(self, other): + return self.address < other.address + + def __str__(self): + return str(self.address) + + def __repr__(self): + dc = (" %s" % (self._datacenter,)) if self._datacenter else "" + return "<%s: %s%s>" % (self.__class__.__name__, self.address, dc) + + +class _ReconnectionHandler(object): + """ + Abstract class for attempting reconnections with a given + schedule and scheduler. 
+ """ + + _cancelled = False + + def __init__(self, scheduler, schedule, callback, *callback_args, **callback_kwargs): + self.scheduler = scheduler + self.schedule = schedule + self.callback = callback + self.callback_args = callback_args + self.callback_kwargs = callback_kwargs + + def start(self): + if self._cancelled: + log.debug("Reconnection handler was cancelled before starting") + return + + first_delay = next(self.schedule) + self.scheduler.schedule(first_delay, self.run) + + def run(self): + if self._cancelled: + return + + conn = None + try: + conn = self.try_reconnect() + except Exception as exc: + try: + next_delay = next(self.schedule) + except StopIteration: + # the schedule has been exhausted + next_delay = None + + # call on_exception for logging purposes even if next_delay is None + if self.on_exception(exc, next_delay): + if next_delay is None: + log.warning( + "Will not continue to retry reconnection attempts " + "due to an exhausted retry schedule") + else: + self.scheduler.schedule(next_delay, self.run) + else: + if not self._cancelled: + self.on_reconnection(conn) + self.callback(*(self.callback_args), **(self.callback_kwargs)) + finally: + if conn: + conn.close() + + def cancel(self): + self._cancelled = True + + def try_reconnect(self): + """ + Subclasses must implement this method. It should attempt to + open a new Connection and return it; if a failure occurs, an + Exception should be raised. + """ + raise NotImplementedError() + + def on_reconnection(self, connection): + """ + Called when a new Connection is successfully opened. Nothing is + done by default. + """ + pass + + def on_exception(self, exc, next_delay): + """ + Called when an Exception is raised when trying to connect. + `exc` is the Exception that was raised and `next_delay` is the + number of seconds (as a float) that the handler will wait before + attempting to connect again. + + Subclasses should return :const:`False` if no more attempts to + connection should be made, :const:`True` otherwise. The default + behavior is to always retry unless the error is an + :exc:`.AuthenticationFailed` instance. + """ + if isinstance(exc, AuthenticationFailed): + return False + else: + return True + + +class _HostReconnectionHandler(_ReconnectionHandler): + + def __init__(self, host, connection_factory, is_host_addition, on_add, on_up, *args, **kwargs): + _ReconnectionHandler.__init__(self, *args, **kwargs) + self.is_host_addition = is_host_addition + self.on_add = on_add + self.on_up = on_up + self.host = host + self.connection_factory = connection_factory + + def try_reconnect(self): + return self.connection_factory() + + def on_reconnection(self, connection): + log.info("Successful reconnection to %s, marking node up if it isn't already", self.host) + if self.is_host_addition: + self.on_add(self.host) + else: + self.on_up(self.host) + + def on_exception(self, exc, next_delay): + if isinstance(exc, AuthenticationFailed): + return False + else: + log.warning("Error attempting to reconnect to %s, scheduling retry in %s seconds: %s", + self.host, next_delay, exc) + log.debug("Reconnection error details", exc_info=True) + return True + + +class HostConnection(object): + """ + When using v3 of the native protocol, this is used instead of a connection + pool per host (HostConnectionPool) due to the increased in-flight capacity + of individual connections. 
+ """ + + host = None + host_distance = None + is_shutdown = False + + _session = None + _connection = None + _lock = None + + def __init__(self, host, host_distance, session): + self.host = host + self.host_distance = host_distance + self._session = weakref.proxy(session) + self._lock = Lock() + + if host_distance == HostDistance.IGNORED: + log.debug("Not opening connection to ignored host %s", self.host) + return + elif host_distance == HostDistance.REMOTE and not session.cluster.connect_to_remote_hosts: + log.debug("Not opening connection to remote host %s", self.host) + return + + log.debug("Initializing connection for host %s", self.host) + self._connection = session.cluster.connection_factory(host.address) + if session.keyspace: + self._connection.set_keyspace_blocking(session.keyspace) + log.debug("Finished initializing connection for host %s", self.host) + + def borrow_connection(self, timeout): + if self.is_shutdown: + raise ConnectionException( + "Pool for %s is shutdown" % (self.host,), self.host) + + conn = self._connection + if not conn: + raise NoConnectionsAvailable() + + with conn.lock: + if conn.in_flight < conn.max_request_id: + conn.in_flight += 1 + return conn, conn.get_request_id() + + raise NoConnectionsAvailable("All request IDs are currently in use") + + def return_connection(self, connection): + with connection.lock: + connection.in_flight -= 1 + + if connection.is_defunct or connection.is_closed: + log.debug("Defunct or closed connection (%s) returned to pool, potentially " + "marking host %s as down", id(connection), self.host) + is_down = self._session.cluster.signal_connection_failure( + self.host, connection.last_error, is_host_addition=False) + if is_down: + self.shutdown() + else: + self._connection = None + with self._lock: + if self._is_replacing: + return + self._is_replacing = True + self._session.submit(self._replace, connection) + + def _replace(self, connection): + log.debug("Replacing connection (%s) to %s", id(connection), self.host) + conn = self._session.cluster.connection_factory(self.host.address) + if self._session.keyspace: + conn.set_keyspace_blocking(self._session.keyspace) + self._connection = conn + with self._lock: + self._is_replacing = False + + def shutdown(self): + with self._lock: + if self.is_shutdown: + return + else: + self.is_shutdown = True + + if self._connection: + self._connection.close() + + def _set_keyspace_for_all_conns(self, keyspace, callback): + if self.is_shutdown or not self._connection: + return + + def connection_finished_setting_keyspace(conn, error): + self.return_connection(conn) + errors = [] if not error else [error] + callback(self, errors) + + self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace) + + def get_connections(self): + c = self._connection + return [c] if c else [] + + def get_state(self): + connection = self._connection + open_count = 1 if connection and not (connection.is_closed or connection.is_defunct) else 0 + in_flights = [connection.in_flight] if connection else [] + return {'shutdown': self.is_shutdown, 'open_count': open_count, 'in_flights': in_flights} + + @property + def open_count(self): + connection = self._connection + return 1 if connection and not (connection.is_closed or connection.is_defunct) else 0 + +_MAX_SIMULTANEOUS_CREATION = 1 +_MIN_TRASH_INTERVAL = 10 + + +class HostConnectionPool(object): + """ + Used to pool connections to a host for v1 and v2 native protocol. 
+ """ + + host = None + host_distance = None + + is_shutdown = False + open_count = 0 + _scheduled_for_creation = 0 + _next_trash_allowed_at = 0 + + def __init__(self, host, host_distance, session): + self.host = host + self.host_distance = host_distance + + self._session = weakref.proxy(session) + self._lock = RLock() + self._conn_available_condition = Condition() + + log.debug("Initializing new connection pool for host %s", self.host) + core_conns = session.cluster.get_core_connections_per_host(host_distance) + self._connections = [session.cluster.connection_factory(host.address) + for i in range(core_conns)] + + if session.keyspace: + for conn in self._connections: + conn.set_keyspace_blocking(session.keyspace) + + self._trash = set() + self._next_trash_allowed_at = time.time() + self.open_count = core_conns + log.debug("Finished initializing new connection pool for host %s", self.host) + + def borrow_connection(self, timeout): + if self.is_shutdown: + raise ConnectionException( + "Pool for %s is shutdown" % (self.host,), self.host) + + conns = self._connections + if not conns: + # handled specially just for simpler code + log.debug("Detected empty pool, opening core conns to %s", self.host) + core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance) + with self._lock: + # we check the length of self._connections again + # along with self._scheduled_for_creation while holding the lock + # in case multiple threads hit this condition at the same time + to_create = core_conns - (len(self._connections) + self._scheduled_for_creation) + for i in range(to_create): + self._scheduled_for_creation += 1 + self._session.submit(self._create_new_connection) + + # in_flight is incremented by wait_for_conn + conn = self._wait_for_conn(timeout) + return conn + else: + # note: it would be nice to push changes to these config settings + # to pools instead of doing a new lookup on every + # borrow_connection() call + max_reqs = self._session.cluster.get_max_requests_per_connection(self.host_distance) + max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance) + + least_busy = min(conns, key=lambda c: c.in_flight) + request_id = None + # to avoid another thread closing this connection while + # trashing it (through the return_connection process), hold + # the connection lock from this point until we've incremented + # its in_flight count + need_to_wait = False + with least_busy.lock: + if least_busy.in_flight < least_busy.max_request_id: + least_busy.in_flight += 1 + request_id = least_busy.get_request_id() + else: + # once we release the lock, wait for another connection + need_to_wait = True + + if need_to_wait: + # wait_for_conn will increment in_flight on the conn + least_busy, request_id = self._wait_for_conn(timeout) + + # if we have too many requests on this connection but we still + # have space to open a new connection against this host, go ahead + # and schedule the creation of a new connection + if least_busy.in_flight >= max_reqs and len(self._connections) < max_conns: + self._maybe_spawn_new_connection() + + return least_busy, request_id + + def _maybe_spawn_new_connection(self): + with self._lock: + if self._scheduled_for_creation >= _MAX_SIMULTANEOUS_CREATION: + return + if self.open_count >= self._session.cluster.get_max_connections_per_host(self.host_distance): + return + self._scheduled_for_creation += 1 + + log.debug("Submitting task for creation of new Connection to %s", self.host) + self._session.submit(self._create_new_connection) + 
+ def _create_new_connection(self): + try: + self._add_conn_if_under_max() + except (ConnectionException, socket.error) as exc: + log.warning("Failed to create new connection to %s: %s", self.host, exc) + except Exception: + log.exception("Unexpectedly failed to create new connection") + finally: + with self._lock: + self._scheduled_for_creation -= 1 + + def _add_conn_if_under_max(self): + max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance) + with self._lock: + if self.is_shutdown: + return False + + if self.open_count >= max_conns: + return False + + self.open_count += 1 + + log.debug("Going to open new connection to host %s", self.host) + try: + conn = self._session.cluster.connection_factory(self.host.address) + if self._session.keyspace: + conn.set_keyspace_blocking(self._session.keyspace) + self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL + with self._lock: + new_connections = self._connections[:] + [conn] + self._connections = new_connections + log.debug("Added new connection (%s) to pool for host %s, signaling availablility", + id(conn), self.host) + self._signal_available_conn() + return True + except (ConnectionException, socket.error) as exc: + log.warning("Failed to add new connection to pool for host %s: %s", self.host, exc) + with self._lock: + self.open_count -= 1 + if self._session.cluster.signal_connection_failure(self.host, exc, is_host_addition=False): + self.shutdown() + return False + except AuthenticationFailed: + with self._lock: + self.open_count -= 1 + return False + + def _await_available_conn(self, timeout): + with self._conn_available_condition: + self._conn_available_condition.wait(timeout) + + def _signal_available_conn(self): + with self._conn_available_condition: + self._conn_available_condition.notify() + + def _signal_all_available_conn(self): + with self._conn_available_condition: + self._conn_available_condition.notify_all() + + def _wait_for_conn(self, timeout): + start = time.time() + remaining = timeout + + while remaining > 0: + # wait on our condition for the possibility that a connection + # is useable + self._await_available_conn(remaining) + + # self.shutdown() may trigger the above Condition + if self.is_shutdown: + raise ConnectionException("Pool is shutdown") + + conns = self._connections + if conns: + least_busy = min(conns, key=lambda c: c.in_flight) + with least_busy.lock: + if least_busy.in_flight < least_busy.max_request_id: + least_busy.in_flight += 1 + return least_busy, least_busy.get_request_id() + + remaining = timeout - (time.time() - start) + + raise NoConnectionsAvailable() + + def return_connection(self, connection): + with connection.lock: + connection.in_flight -= 1 + in_flight = connection.in_flight + + if connection.is_defunct or connection.is_closed: + log.debug("Defunct or closed connection (%s) returned to pool, potentially " + "marking host %s as down", id(connection), self.host) + is_down = self._session.cluster.signal_connection_failure( + self.host, connection.last_error, is_host_addition=False) + if is_down: + self.shutdown() + else: + self._replace(connection) + else: + if connection in self._trash: + with connection.lock: + if connection.in_flight == 0: + with self._lock: + if connection in self._trash: + self._trash.remove(connection) + log.debug("Closing trashed connection (%s) to %s", id(connection), self.host) + connection.close() + return + + core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance) + min_reqs = 
self._session.cluster.get_min_requests_per_connection(self.host_distance) + # we can use in_flight here without holding the connection lock + # because the fact that in_flight dipped below the min at some + # point is enough to start the trashing procedure + if len(self._connections) > core_conns and in_flight <= min_reqs and \ + time.time() >= self._next_trash_allowed_at: + self._maybe_trash_connection(connection) + else: + self._signal_available_conn() + + def _maybe_trash_connection(self, connection): + core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance) + did_trash = False + with self._lock: + if connection not in self._connections: + return + + if self.open_count > core_conns: + did_trash = True + self.open_count -= 1 + new_connections = self._connections[:] + new_connections.remove(connection) + self._connections = new_connections + + with connection.lock: + if connection.in_flight == 0: + log.debug("Skipping trash and closing unused connection (%s) to %s", id(connection), self.host) + connection.close() + + # skip adding it to the trash if we're already closing it + return + + self._trash.add(connection) + + if did_trash: + self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL + log.debug("Trashed connection (%s) to %s", id(connection), self.host) + + def _replace(self, connection): + should_replace = False + with self._lock: + if connection in self._connections: + new_connections = self._connections[:] + new_connections.remove(connection) + self._connections = new_connections + self.open_count -= 1 + should_replace = True + + if should_replace: + log.debug("Replacing connection (%s) to %s", id(connection), self.host) + + def close_and_replace(): + connection.close() + self._add_conn_if_under_max() + + self._session.submit(close_and_replace) + else: + # just close it + log.debug("Closing connection (%s) to %s", id(connection), self.host) + connection.close() + + def shutdown(self): + with self._lock: + if self.is_shutdown: + return + else: + self.is_shutdown = True + + self._signal_all_available_conn() + for conn in self._connections: + conn.close() + self.open_count -= 1 + + for conn in self._trash: + conn.close() + + def ensure_core_connections(self): + if self.is_shutdown: + return + + core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance) + with self._lock: + to_create = core_conns - (len(self._connections) + self._scheduled_for_creation) + for i in range(to_create): + self._scheduled_for_creation += 1 + self._session.submit(self._create_new_connection) + + def _set_keyspace_for_all_conns(self, keyspace, callback): + """ + Asynchronously sets the keyspace for all connections. When all + connections have been set, `callback` will be called with two + arguments: this pool, and a list of any errors that occurred. 
+ """ + remaining_callbacks = set(self._connections) + errors = [] + + if not remaining_callbacks: + callback(self, errors) + return + + def connection_finished_setting_keyspace(conn, error): + self.return_connection(conn) + remaining_callbacks.remove(conn) + if error: + errors.append(error) + + if not remaining_callbacks: + callback(self, errors) + + for conn in self._connections: + conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace) + + def get_connections(self): + return self._connections + + def get_state(self): + in_flights = [c.in_flight for c in self._connections] + return {'shutdown': self.is_shutdown, 'open_count': self.open_count, 'in_flights': in_flights} diff --git a/cassandra/protocol.py b/cassandra/protocol.py new file mode 100644 index 0000000..bd929da --- /dev/null +++ b/cassandra/protocol.py @@ -0,0 +1,998 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import # to enable import io from stdlib +import logging +import socket +from uuid import UUID + +import six +from six.moves import range +import io + +from cassandra import (Unavailable, WriteTimeout, ReadTimeout, + AlreadyExists, InvalidRequest, Unauthorized, + UnsupportedOperation) +from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack, + int8_pack, int8_unpack, uint64_pack, header_pack, + v3_header_pack) +from cassandra.cqltypes import (AsciiType, BytesType, BooleanType, + CounterColumnType, DateType, DecimalType, + DoubleType, FloatType, Int32Type, + InetAddressType, IntegerType, ListType, + LongType, MapType, SetType, TimeUUIDType, + UTF8Type, UUIDType, UserType, + TupleType, lookup_casstype, SimpleDateType, + TimeType) +from cassandra.policies import WriteType + +log = logging.getLogger(__name__) + + +class NotSupportedError(Exception): + pass + + +class InternalError(Exception): + pass + + +HEADER_DIRECTION_FROM_CLIENT = 0x00 +HEADER_DIRECTION_TO_CLIENT = 0x80 +HEADER_DIRECTION_MASK = 0x80 + +COMPRESSED_FLAG = 0x01 +TRACING_FLAG = 0x02 + +_message_types_by_name = {} +_message_types_by_opcode = {} + + +class _RegisterMessageType(type): + def __init__(cls, name, bases, dct): + if not name.startswith('_'): + _message_types_by_name[cls.name] = cls + _message_types_by_opcode[cls.opcode] = cls + + +@six.add_metaclass(_RegisterMessageType) +class _MessageType(object): + + tracing = False + + def to_binary(self, stream_id, protocol_version, compression=None): + body = io.BytesIO() + self.send_body(body, protocol_version) + body = body.getvalue() + + flags = 0 + if compression and len(body) > 0: + body = compression(body) + flags |= COMPRESSED_FLAG + if self.tracing: + flags |= TRACING_FLAG + + msg = io.BytesIO() + write_header(msg, protocol_version, flags, stream_id, self.opcode, len(body)) + msg.write(body) + + return msg.getvalue() + + def __repr__(self): + return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self))) + + +def _get_params(message_obj): + base_attrs = 
dir(_MessageType) + return ( + (n, a) for n, a in message_obj.__dict__.items() + if n not in base_attrs and not n.startswith('_') and not callable(a) + ) + + +def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, body, + decompressor=None): + if flags & COMPRESSED_FLAG: + if decompressor is None: + raise Exception("No de-compressor available for compressed frame!") + body = decompressor(body) + flags ^= COMPRESSED_FLAG + + body = io.BytesIO(body) + if flags & TRACING_FLAG: + trace_id = UUID(bytes=body.read(16)) + flags ^= TRACING_FLAG + else: + trace_id = None + + if flags: + log.warning("Unknown protocol flags set: %02x. May cause problems.", flags) + + msg_class = _message_types_by_opcode[opcode] + msg = msg_class.recv_body(body, protocol_version, user_type_map) + msg.stream_id = stream_id + msg.trace_id = trace_id + return msg + + +error_classes = {} + + +class ErrorMessage(_MessageType, Exception): + opcode = 0x00 + name = 'ERROR' + summary = 'Unknown' + + def __init__(self, code, message, info): + self.code = code + self.message = message + self.info = info + + @classmethod + def recv_body(cls, f, protocol_version, user_type_map): + code = read_int(f) + msg = read_string(f) + subcls = error_classes.get(code, cls) + extra_info = subcls.recv_error_info(f) + return subcls(code=code, message=msg, info=extra_info) + + def summary_msg(self): + msg = 'code=%04x [%s] message="%s"' \ + % (self.code, self.summary, self.message) + return msg + + def __str__(self): + return '' % self.summary_msg() + __repr__ = __str__ + + @staticmethod + def recv_error_info(f): + pass + + def to_exception(self): + return self + + +class ErrorMessageSubclass(_RegisterMessageType): + def __init__(cls, name, bases, dct): + if cls.error_code is not None: # Server has an error code of 0. 
+ error_classes[cls.error_code] = cls + + +@six.add_metaclass(ErrorMessageSubclass) +class ErrorMessageSub(ErrorMessage): + error_code = None + + +class RequestExecutionException(ErrorMessageSub): + pass + + +class RequestValidationException(ErrorMessageSub): + pass + + +class ServerError(ErrorMessageSub): + summary = 'Server error' + error_code = 0x0000 + + +class ProtocolException(ErrorMessageSub): + summary = 'Protocol error' + error_code = 0x000A + + +class BadCredentials(ErrorMessageSub): + summary = 'Bad credentials' + error_code = 0x0100 + + +class UnavailableErrorMessage(RequestExecutionException): + summary = 'Unavailable exception' + error_code = 0x1000 + + @staticmethod + def recv_error_info(f): + return { + 'consistency': read_consistency_level(f), + 'required_replicas': read_int(f), + 'alive_replicas': read_int(f), + } + + def to_exception(self): + return Unavailable(self.summary_msg(), **self.info) + + +class OverloadedErrorMessage(RequestExecutionException): + summary = 'Coordinator node overloaded' + error_code = 0x1001 + + +class IsBootstrappingErrorMessage(RequestExecutionException): + summary = 'Coordinator node is bootstrapping' + error_code = 0x1002 + + +class TruncateError(RequestExecutionException): + summary = 'Error during truncate' + error_code = 0x1003 + + +class WriteTimeoutErrorMessage(RequestExecutionException): + summary = "Coordinator node timed out waiting for replica nodes' responses" + error_code = 0x1100 + + @staticmethod + def recv_error_info(f): + return { + 'consistency': read_consistency_level(f), + 'received_responses': read_int(f), + 'required_responses': read_int(f), + 'write_type': WriteType.name_to_value[read_string(f)], + } + + def to_exception(self): + return WriteTimeout(self.summary_msg(), **self.info) + + +class ReadTimeoutErrorMessage(RequestExecutionException): + summary = "Coordinator node timed out waiting for replica nodes' responses" + error_code = 0x1200 + + @staticmethod + def recv_error_info(f): + return { + 'consistency': read_consistency_level(f), + 'received_responses': read_int(f), + 'required_responses': read_int(f), + 'data_retrieved': bool(read_byte(f)), + } + + def to_exception(self): + return ReadTimeout(self.summary_msg(), **self.info) + + +class SyntaxException(RequestValidationException): + summary = 'Syntax error in CQL query' + error_code = 0x2000 + + +class UnauthorizedErrorMessage(RequestValidationException): + summary = 'Unauthorized' + error_code = 0x2100 + + def to_exception(self): + return Unauthorized(self.summary_msg()) + + +class InvalidRequestException(RequestValidationException): + summary = 'Invalid query' + error_code = 0x2200 + + def to_exception(self): + return InvalidRequest(self.summary_msg()) + + +class ConfigurationException(RequestValidationException): + summary = 'Query invalid because of configuration issue' + error_code = 0x2300 + + +class PreparedQueryNotFound(RequestValidationException): + summary = 'Matching prepared statement not found on this node' + error_code = 0x2500 + + @staticmethod + def recv_error_info(f): + # return the query ID + return read_binary_string(f) + + +class AlreadyExistsException(ConfigurationException): + summary = 'Item already exists' + error_code = 0x2400 + + @staticmethod + def recv_error_info(f): + return { + 'keyspace': read_string(f), + 'table': read_string(f), + } + + def to_exception(self): + return AlreadyExists(**self.info) + + +class StartupMessage(_MessageType): + opcode = 0x01 + name = 'STARTUP' + + KNOWN_OPTION_KEYS = set(( + 'CQL_VERSION', + 
'COMPRESSION', + )) + + def __init__(self, cqlversion, options): + self.cqlversion = cqlversion + self.options = options + + def send_body(self, f, protocol_version): + optmap = self.options.copy() + optmap['CQL_VERSION'] = self.cqlversion + write_stringmap(f, optmap) + + +class ReadyMessage(_MessageType): + opcode = 0x02 + name = 'READY' + + @classmethod + def recv_body(cls, f, protocol_version, user_type_map): + return cls() + + +class AuthenticateMessage(_MessageType): + opcode = 0x03 + name = 'AUTHENTICATE' + + def __init__(self, authenticator): + self.authenticator = authenticator + + @classmethod + def recv_body(cls, f, protocol_version, user_type_map): + authname = read_string(f) + return cls(authenticator=authname) + + +class CredentialsMessage(_MessageType): + opcode = 0x04 + name = 'CREDENTIALS' + + def __init__(self, creds): + self.creds = creds + + def send_body(self, f, protocol_version): + if protocol_version > 1: + raise UnsupportedOperation( + "Credentials-based authentication is not supported with " + "protocol version 2 or higher. Use the SASL authentication " + "mechanism instead.") + write_short(f, len(self.creds)) + for credkey, credval in self.creds.items(): + write_string(f, credkey) + write_string(f, credval) + + +class AuthChallengeMessage(_MessageType): + opcode = 0x0E + name = 'AUTH_CHALLENGE' + + def __init__(self, challenge): + self.challenge = challenge + + @classmethod + def recv_body(cls, f, protocol_version, user_type_map): + return cls(read_binary_longstring(f)) + + +class AuthResponseMessage(_MessageType): + opcode = 0x0F + name = 'AUTH_RESPONSE' + + def __init__(self, response): + self.response = response + + def send_body(self, f, protocol_version): + write_longstring(f, self.response) + + +class AuthSuccessMessage(_MessageType): + opcode = 0x10 + name = 'AUTH_SUCCESS' + + def __init__(self, token): + self.token = token + + @classmethod + def recv_body(cls, f, protocol_version, user_type_map): + return cls(read_longstring(f)) + + +class OptionsMessage(_MessageType): + opcode = 0x05 + name = 'OPTIONS' + + def send_body(self, f, protocol_version): + pass + + +class SupportedMessage(_MessageType): + opcode = 0x06 + name = 'SUPPORTED' + + def __init__(self, cql_versions, options): + self.cql_versions = cql_versions + self.options = options + + @classmethod + def recv_body(cls, f, protocol_version, user_type_map): + options = read_stringmultimap(f) + cql_versions = options.pop('CQL_VERSION') + return cls(cql_versions=cql_versions, options=options) + + +# used for QueryMessage and ExecuteMessage +_VALUES_FLAG = 0x01 +_SKIP_METADATA_FLAG = 0x01 +_PAGE_SIZE_FLAG = 0x04 +_WITH_PAGING_STATE_FLAG = 0x08 +_WITH_SERIAL_CONSISTENCY_FLAG = 0x10 +_PROTOCOL_TIMESTAMP = 0x20 + + +class QueryMessage(_MessageType): + opcode = 0x07 + name = 'QUERY' + + def __init__(self, query, consistency_level, serial_consistency_level=None, + fetch_size=None, paging_state=None, timestamp=None): + self.query = query + self.consistency_level = consistency_level + self.serial_consistency_level = serial_consistency_level + self.fetch_size = fetch_size + self.paging_state = paging_state + self.timestamp = timestamp + + def send_body(self, f, protocol_version): + write_longstring(f, self.query) + write_consistency_level(f, self.consistency_level) + flags = 0x00 + if self.serial_consistency_level: + if protocol_version >= 2: + flags |= _WITH_SERIAL_CONSISTENCY_FLAG + else: + raise UnsupportedOperation( + "Serial consistency levels require the use of protocol version " + "2 or higher. 
Consider setting Cluster.protocol_version to 2 " + "to support serial consistency levels.") + + if self.fetch_size: + if protocol_version >= 2: + flags |= _PAGE_SIZE_FLAG + else: + raise UnsupportedOperation( + "Automatic query paging may only be used with protocol version " + "2 or higher. Consider setting Cluster.protocol_version to 2.") + + if self.paging_state: + if protocol_version >= 2: + flags |= _WITH_PAGING_STATE_FLAG + else: + raise UnsupportedOperation( + "Automatic query paging may only be used with protocol version " + "2 or higher. Consider setting Cluster.protocol_version to 2.") + + if self.timestamp is not None: + flags |= _PROTOCOL_TIMESTAMP + + write_byte(f, flags) + if self.fetch_size: + write_int(f, self.fetch_size) + if self.paging_state: + write_longstring(f, self.paging_state) + if self.serial_consistency_level: + write_consistency_level(f, self.serial_consistency_level) + if self.timestamp is not None: + write_long(f, self.timestamp) + +CUSTOM_TYPE = object() + +RESULT_KIND_VOID = 0x0001 +RESULT_KIND_ROWS = 0x0002 +RESULT_KIND_SET_KEYSPACE = 0x0003 +RESULT_KIND_PREPARED = 0x0004 +RESULT_KIND_SCHEMA_CHANGE = 0x0005 + + +class ResultMessage(_MessageType): + opcode = 0x08 + name = 'RESULT' + + kind = None + results = None + paging_state = None + + _type_codes = { + 0x0000: CUSTOM_TYPE, + 0x0001: AsciiType, + 0x0002: LongType, + 0x0003: BytesType, + 0x0004: BooleanType, + 0x0005: CounterColumnType, + 0x0006: DecimalType, + 0x0007: DoubleType, + 0x0008: FloatType, + 0x0009: Int32Type, + 0x000A: UTF8Type, + 0x000B: DateType, + 0x000C: UUIDType, + 0x000D: UTF8Type, + 0x000E: IntegerType, + 0x000F: TimeUUIDType, + 0x0010: InetAddressType, + 0x0011: SimpleDateType, + 0x0012: TimeType, + 0x0020: ListType, + 0x0021: MapType, + 0x0022: SetType, + 0x0030: UserType, + 0x0031: TupleType, + } + + _FLAGS_GLOBAL_TABLES_SPEC = 0x0001 + _HAS_MORE_PAGES_FLAG = 0x0002 + _NO_METADATA_FLAG = 0x0004 + + def __init__(self, kind, results, paging_state=None): + self.kind = kind + self.results = results + self.paging_state = paging_state + + @classmethod + def recv_body(cls, f, protocol_version, user_type_map): + kind = read_int(f) + paging_state = None + if kind == RESULT_KIND_VOID: + results = None + elif kind == RESULT_KIND_ROWS: + paging_state, results = cls.recv_results_rows( + f, protocol_version, user_type_map) + elif kind == RESULT_KIND_SET_KEYSPACE: + ksname = read_string(f) + results = ksname + elif kind == RESULT_KIND_PREPARED: + results = cls.recv_results_prepared(f, user_type_map) + elif kind == RESULT_KIND_SCHEMA_CHANGE: + results = cls.recv_results_schema_change(f, protocol_version) + return cls(kind, results, paging_state) + + @classmethod + def recv_results_rows(cls, f, protocol_version, user_type_map): + paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map) + rowcount = read_int(f) + rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)] + colnames = [c[2] for c in column_metadata] + coltypes = [c[3] for c in column_metadata] + parsed_rows = [ + tuple(ctype.from_binary(val, protocol_version) + for ctype, val in zip(coltypes, row)) + for row in rows] + return (paging_state, (colnames, parsed_rows)) + + @classmethod + def recv_results_prepared(cls, f, user_type_map): + query_id = read_binary_string(f) + _, column_metadata = cls.recv_results_metadata(f, user_type_map) + return (query_id, column_metadata) + + @classmethod + def recv_results_metadata(cls, f, user_type_map): + flags = read_int(f) + glob_tblspec = bool(flags & 
cls._FLAGS_GLOBAL_TABLES_SPEC) + colcount = read_int(f) + if flags & cls._HAS_MORE_PAGES_FLAG: + paging_state = read_binary_longstring(f) + else: + paging_state = None + if glob_tblspec: + ksname = read_string(f) + cfname = read_string(f) + column_metadata = [] + for _ in range(colcount): + if glob_tblspec: + colksname = ksname + colcfname = cfname + else: + colksname = read_string(f) + colcfname = read_string(f) + colname = read_string(f) + coltype = cls.read_type(f, user_type_map) + column_metadata.append((colksname, colcfname, colname, coltype)) + return paging_state, column_metadata + + @classmethod + def recv_results_schema_change(cls, f, protocol_version): + return EventMessage.recv_schema_change(f, protocol_version) + + @classmethod + def read_type(cls, f, user_type_map): + optid = read_short(f) + try: + typeclass = cls._type_codes[optid] + except KeyError: + raise NotSupportedError("Unknown data type code 0x%04x. Have to skip" + " entire result set." % (optid,)) + if typeclass in (ListType, SetType): + subtype = cls.read_type(f, user_type_map) + typeclass = typeclass.apply_parameters((subtype,)) + elif typeclass == MapType: + keysubtype = cls.read_type(f, user_type_map) + valsubtype = cls.read_type(f, user_type_map) + typeclass = typeclass.apply_parameters((keysubtype, valsubtype)) + elif typeclass == TupleType: + num_items = read_short(f) + types = tuple(cls.read_type(f, user_type_map) for _ in range(num_items)) + typeclass = typeclass.apply_parameters(types) + elif typeclass == UserType: + ks = read_string(f) + udt_name = read_string(f) + num_fields = read_short(f) + names_and_types = tuple((read_string(f), cls.read_type(f, user_type_map)) + for _ in range(num_fields)) + mapped_class = user_type_map.get(ks, {}).get(udt_name) + typeclass = typeclass.make_udt_class( + ks, udt_name, names_and_types, mapped_class) + elif typeclass == CUSTOM_TYPE: + classname = read_string(f) + typeclass = lookup_casstype(classname) + + return typeclass + + @staticmethod + def recv_row(f, colcount): + return [read_value(f) for _ in range(colcount)] + + +class PrepareMessage(_MessageType): + opcode = 0x09 + name = 'PREPARE' + + def __init__(self, query): + self.query = query + + def send_body(self, f, protocol_version): + write_longstring(f, self.query) + + +class ExecuteMessage(_MessageType): + opcode = 0x0A + name = 'EXECUTE' + + def __init__(self, query_id, query_params, consistency_level, + serial_consistency_level=None, fetch_size=None, + paging_state=None, timestamp=None): + self.query_id = query_id + self.query_params = query_params + self.consistency_level = consistency_level + self.serial_consistency_level = serial_consistency_level + self.fetch_size = fetch_size + self.paging_state = paging_state + self.timestamp = timestamp + + def send_body(self, f, protocol_version): + write_string(f, self.query_id) + if protocol_version == 1: + if self.serial_consistency_level: + raise UnsupportedOperation( + "Serial consistency levels require the use of protocol version " + "2 or higher. Consider setting Cluster.protocol_version to 2 " + "to support serial consistency levels.") + if self.fetch_size or self.paging_state: + raise UnsupportedOperation( + "Automatic query paging may only be used with protocol version " + "2 or higher. 
Consider setting Cluster.protocol_version to 2.") + write_short(f, len(self.query_params)) + for param in self.query_params: + write_value(f, param) + write_consistency_level(f, self.consistency_level) + else: + write_consistency_level(f, self.consistency_level) + flags = _VALUES_FLAG + if self.serial_consistency_level: + flags |= _WITH_SERIAL_CONSISTENCY_FLAG + if self.fetch_size: + flags |= _PAGE_SIZE_FLAG + if self.paging_state: + flags |= _WITH_PAGING_STATE_FLAG + if self.timestamp is not None: + if protocol_version >= 3: + flags |= _PROTOCOL_TIMESTAMP + else: + raise UnsupportedOperation( + "Protocol-level timestamps may only be used with protocol version " + "3 or higher. Consider setting Cluster.protocol_version to 3.") + write_byte(f, flags) + write_short(f, len(self.query_params)) + for param in self.query_params: + write_value(f, param) + if self.fetch_size: + write_int(f, self.fetch_size) + if self.paging_state: + write_longstring(f, self.paging_state) + if self.serial_consistency_level: + write_consistency_level(f, self.serial_consistency_level) + if self.timestamp is not None: + write_long(f, self.timestamp) + + +class BatchMessage(_MessageType): + opcode = 0x0D + name = 'BATCH' + + def __init__(self, batch_type, queries, consistency_level, + serial_consistency_level=None, timestamp=None): + self.batch_type = batch_type + self.queries = queries + self.consistency_level = consistency_level + self.serial_consistency_level = serial_consistency_level + self.timestamp = timestamp + + def send_body(self, f, protocol_version): + write_byte(f, self.batch_type.value) + write_short(f, len(self.queries)) + for prepared, string_or_query_id, params in self.queries: + if not prepared: + write_byte(f, 0) + write_longstring(f, string_or_query_id) + else: + write_byte(f, 1) + write_short(f, len(string_or_query_id)) + f.write(string_or_query_id) + write_short(f, len(params)) + for param in params: + write_value(f, param) + + write_consistency_level(f, self.consistency_level) + if protocol_version >= 3: + flags = 0 + if self.serial_consistency_level: + flags |= _WITH_SERIAL_CONSISTENCY_FLAG + if self.timestamp is not None: + flags |= _PROTOCOL_TIMESTAMP + write_byte(f, flags) + + if self.serial_consistency_level: + write_consistency_level(f, self.serial_consistency_level) + if self.timestamp is not None: + write_long(f, self.timestamp) + + +known_event_types = frozenset(( + 'TOPOLOGY_CHANGE', + 'STATUS_CHANGE', + 'SCHEMA_CHANGE' +)) + + +class RegisterMessage(_MessageType): + opcode = 0x0B + name = 'REGISTER' + + def __init__(self, event_list): + self.event_list = event_list + + def send_body(self, f, protocol_version): + write_stringlist(f, self.event_list) + + +class EventMessage(_MessageType): + opcode = 0x0C + name = 'EVENT' + + def __init__(self, event_type, event_args): + self.event_type = event_type + self.event_args = event_args + + @classmethod + def recv_body(cls, f, protocol_version, user_type_map): + event_type = read_string(f).upper() + if event_type in known_event_types: + read_method = getattr(cls, 'recv_' + event_type.lower()) + return cls(event_type=event_type, event_args=read_method(f, protocol_version)) + raise NotSupportedError('Unknown event type %r' % event_type) + + @classmethod + def recv_topology_change(cls, f, protocol_version): + # "NEW_NODE" or "REMOVED_NODE" + change_type = read_string(f) + address = read_inet(f) + return dict(change_type=change_type, address=address) + + @classmethod + def recv_status_change(cls, f, protocol_version): + # "UP" or "DOWN" + 
change_type = read_string(f) + address = read_inet(f) + return dict(change_type=change_type, address=address) + + @classmethod + def recv_schema_change(cls, f, protocol_version): + # "CREATED", "DROPPED", or "UPDATED" + change_type = read_string(f) + if protocol_version >= 3: + target = read_string(f) + keyspace = read_string(f) + if target != "KEYSPACE": + table_or_type = read_string(f) + return {'change_type': change_type, 'keyspace': keyspace, target.lower(): table_or_type} + else: + return {'change_type': change_type, 'keyspace': keyspace} + else: + keyspace = read_string(f) + table = read_string(f) + return {'change_type': change_type, 'keyspace': keyspace, 'table': table} + + +def write_header(f, version, flags, stream_id, opcode, length): + """ + Write a CQL protocol frame header. + """ + pack = v3_header_pack if version >= 3 else header_pack + f.write(pack(version | HEADER_DIRECTION_FROM_CLIENT, flags, stream_id, opcode)) + write_int(f, length) + + +def read_byte(f): + return int8_unpack(f.read(1)) + + +def write_byte(f, b): + f.write(int8_pack(b)) + + +def read_int(f): + return int32_unpack(f.read(4)) + + +def write_int(f, i): + f.write(int32_pack(i)) + + +def write_long(f, i): + f.write(uint64_pack(i)) + + +def read_short(f): + return uint16_unpack(f.read(2)) + + +def write_short(f, s): + f.write(uint16_pack(s)) + + +def read_consistency_level(f): + return read_short(f) + + +def write_consistency_level(f, cl): + write_short(f, cl) + + +def read_string(f): + size = read_short(f) + contents = f.read(size) + return contents.decode('utf8') + + +def read_binary_string(f): + size = read_short(f) + contents = f.read(size) + return contents + + +def write_string(f, s): + if isinstance(s, six.text_type): + s = s.encode('utf8') + write_short(f, len(s)) + f.write(s) + + +def read_binary_longstring(f): + size = read_int(f) + contents = f.read(size) + return contents + + +def read_longstring(f): + return read_binary_longstring(f).decode('utf8') + + +def write_longstring(f, s): + if isinstance(s, six.text_type): + s = s.encode('utf8') + write_int(f, len(s)) + f.write(s) + + +def read_stringlist(f): + numstrs = read_short(f) + return [read_string(f) for _ in range(numstrs)] + + +def write_stringlist(f, stringlist): + write_short(f, len(stringlist)) + for s in stringlist: + write_string(f, s) + + +def read_stringmap(f): + numpairs = read_short(f) + strmap = {} + for _ in range(numpairs): + k = read_string(f) + strmap[k] = read_string(f) + return strmap + + +def write_stringmap(f, strmap): + write_short(f, len(strmap)) + for k, v in strmap.items(): + write_string(f, k) + write_string(f, v) + + +def read_stringmultimap(f): + numkeys = read_short(f) + strmmap = {} + for _ in range(numkeys): + k = read_string(f) + strmmap[k] = read_stringlist(f) + return strmmap + + +def write_stringmultimap(f, strmmap): + write_short(f, len(strmmap)) + for k, v in strmmap.items(): + write_string(f, k) + write_stringlist(f, v) + + +def read_value(f): + size = read_int(f) + if size < 0: + return None + return f.read(size) + + +def write_value(f, v): + if v is None: + write_int(f, -1) + else: + write_int(f, len(v)) + f.write(v) + + +def read_inet(f): + size = read_byte(f) + addrbytes = f.read(size) + port = read_int(f) + if size == 4: + addrfam = socket.AF_INET + elif size == 16: + addrfam = socket.AF_INET6 + else: + raise InternalError("bad inet address: %r" % (addrbytes,)) + return (socket.inet_ntop(addrfam, addrbytes), port) + + +def write_inet(f, addrtuple): + addr, port = addrtuple + if ':' in addr: + addrfam = 
socket.AF_INET6 + else: + addrfam = socket.AF_INET + addrbytes = socket.inet_pton(addrfam, addr) + write_byte(f, len(addrbytes)) + f.write(addrbytes) + write_int(f, port) diff --git a/cassandra/query.py b/cassandra/query.py new file mode 100644 index 0000000..8bc156f --- /dev/null +++ b/cassandra/query.py @@ -0,0 +1,901 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module holds classes for working with prepared statements and +specifying consistency levels and retry policies for individual +queries. +""" + +from collections import namedtuple +from datetime import datetime, timedelta +import re +import struct +import time +import six + +from cassandra import ConsistencyLevel, OperationTimedOut +from cassandra.cqltypes import unix_time_from_uuid1 +from cassandra.encoder import Encoder +import cassandra.encoder +from cassandra.util import OrderedDict + +import logging +log = logging.getLogger(__name__) + + +NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]') +START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*') +END_BADCHAR_REGEX = re.compile('[^a-zA-Z0-9_]*$') + +_clean_name_cache = {} + + +def _clean_column_name(name): + try: + return _clean_name_cache[name] + except KeyError: + clean = NON_ALPHA_REGEX.sub("_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name))) + _clean_name_cache[name] = clean + return clean + + +def tuple_factory(colnames, rows): + """ + Returns each row as a tuple + + Example:: + + >>> from cassandra.query import tuple_factory + >>> session = cluster.connect('mykeyspace') + >>> session.row_factory = tuple_factory + >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") + >>> print rows[0] + ('Bob', 42) + + .. versionchanged:: 2.0.0 + moved from ``cassandra.decoder`` to ``cassandra.query`` + """ + return rows + + +def named_tuple_factory(colnames, rows): + """ + Returns each row as a `namedtuple `_. + This is the default row factory. + + Example:: + + >>> from cassandra.query import named_tuple_factory + >>> session = cluster.connect('mykeyspace') + >>> session.row_factory = named_tuple_factory + >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") + >>> user = rows[0] + + >>> # you can access field by their name: + >>> print "name: %s, age: %d" % (user.name, user.age) + name: Bob, age: 42 + + >>> # or you can access fields by their position (like a tuple) + >>> name, age = user + >>> print "name: %s, age: %d" % (name, age) + name: Bob, age: 42 + >>> name = user[0] + >>> age = user[1] + >>> print "name: %s, age: %d" % (name, age) + name: Bob, age: 42 + + .. versionchanged:: 2.0.0 + moved from ``cassandra.decoder`` to ``cassandra.query`` + """ + clean_column_names = map(_clean_column_name, colnames) + try: + Row = namedtuple('Row', clean_column_names) + except Exception: + log.warning("Failed creating named tuple for results with column names %s (cleaned: %s) " + "(see Python 'namedtuple' documentation for details on name rules). " + "Results will be returned with positional names. 
" + "Avoid this by choosing different names, using SELECT \"\" AS aliases, " + "or specifying a different row_factory on your Session" % + (colnames, clean_column_names)) + Row = namedtuple('Row', clean_column_names, rename=True) + + return [Row(*row) for row in rows] + + +def dict_factory(colnames, rows): + """ + Returns each row as a dict. + + Example:: + + >>> from cassandra.query import dict_factory + >>> session = cluster.connect('mykeyspace') + >>> session.row_factory = dict_factory + >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") + >>> print rows[0] + {u'age': 42, u'name': u'Bob'} + + .. versionchanged:: 2.0.0 + moved from ``cassandra.decoder`` to ``cassandra.query`` + """ + return [dict(zip(colnames, row)) for row in rows] + + +def ordered_dict_factory(colnames, rows): + """ + Like :meth:`~cassandra.query.dict_factory`, but returns each row as an OrderedDict, + so the order of the columns is preserved. + + .. versionchanged:: 2.0.0 + moved from ``cassandra.decoder`` to ``cassandra.query`` + """ + return [OrderedDict(zip(colnames, row)) for row in rows] + + +FETCH_SIZE_UNSET = object() + + +class Statement(object): + """ + An abstract class representing a single query. There are three subclasses: + :class:`.SimpleStatement`, :class:`.BoundStatement`, and :class:`.BatchStatement`. + These can be passed to :meth:`.Session.execute()`. + """ + + retry_policy = None + """ + An instance of a :class:`cassandra.policies.RetryPolicy` or one of its + subclasses. This controls when a query will be retried and how it + will be retried. + """ + + trace = None + """ + If :meth:`.Session.execute()` is run with `trace` set to :const:`True`, + this will be set to a :class:`.QueryTrace` instance. + """ + + consistency_level = None + """ + The :class:`.ConsistencyLevel` to be used for this operation. Defaults + to :const:`None`, which means that the default consistency level for + the Session this is executed in will be used. + """ + + fetch_size = FETCH_SIZE_UNSET + """ + How many rows will be fetched at a time. This overrides the default + of :attr:`.Session.default_fetch_size` + + This only takes effect when protocol version 2 or higher is used. + See :attr:`.Cluster.protocol_version` for details. + + .. versionadded:: 2.0.0 + """ + + keyspace = None + """ + The string name of the keyspace this query acts on. This is used when + :class:`~.TokenAwarePolicy` is configured for + :attr:`.Cluster.load_balancing_policy` + + It is set implicitly on :class:`.BoundStatement`, and :class:`.BatchStatement`, + but must be set explicitly on :class:`.SimpleStatement`. + + .. 
versionadded:: 2.1.3 + """ + + _serial_consistency_level = None + _routing_key = None + + def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, + serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None): + self.retry_policy = retry_policy + if consistency_level is not None: + self.consistency_level = consistency_level + self._routing_key = routing_key + if serial_consistency_level is not None: + self.serial_consistency_level = serial_consistency_level + if fetch_size is not FETCH_SIZE_UNSET: + self.fetch_size = fetch_size + if keyspace is not None: + self.keyspace = keyspace + + def _get_routing_key(self): + return self._routing_key + + def _set_routing_key(self, key): + if isinstance(key, (list, tuple)): + self._routing_key = b"".join(struct.pack("HsB", len(component), component, 0) + for component in key) + else: + self._routing_key = key + + def _del_routing_key(self): + self._routing_key = None + + routing_key = property( + _get_routing_key, + _set_routing_key, + _del_routing_key, + """ + The :attr:`~.TableMetadata.partition_key` portion of the primary key, + which can be used to determine which nodes are replicas for the query. + + If the partition key is a composite, a list or tuple must be passed in. + Each key component should be in its packed (binary) format, so all + components should be strings. + """) + + def _get_serial_consistency_level(self): + return self._serial_consistency_level + + def _set_serial_consistency_level(self, serial_consistency_level): + acceptable = (None, ConsistencyLevel.SERIAL, ConsistencyLevel.LOCAL_SERIAL) + if serial_consistency_level not in acceptable: + raise ValueError( + "serial_consistency_level must be either ConsistencyLevel.SERIAL " + "or ConsistencyLevel.LOCAL_SERIAL") + self._serial_consistency_level = serial_consistency_level + + def _del_serial_consistency_level(self): + self._serial_consistency_level = None + + serial_consistency_level = property( + _get_serial_consistency_level, + _set_serial_consistency_level, + _del_serial_consistency_level, + """ + The serial consistency level is only used by conditional updates + (``INSERT``, ``UPDATE`` and ``DELETE`` with an ``IF`` condition). For + those, the ``serial_consistency_level`` defines the consistency level of + the serial phase (or "paxos" phase) while the normal + :attr:`~.consistency_level` defines the consistency for the "learn" phase, + i.e. what type of reads will be guaranteed to see the update right away. + For example, if a conditional write has a :attr:`~.consistency_level` of + :attr:`~.ConsistencyLevel.QUORUM` (and is successful), then a + :attr:`~.ConsistencyLevel.QUORUM` read is guaranteed to see that write. + But if the regular :attr:`~.consistency_level` of that write is + :attr:`~.ConsistencyLevel.ANY`, then only a read with a + :attr:`~.consistency_level` of :attr:`~.ConsistencyLevel.SERIAL` is + guaranteed to see it (even a read with consistency + :attr:`~.ConsistencyLevel.ALL` is not guaranteed to be enough). + + The serial consistency can only be one of :attr:`~.ConsistencyLevel.SERIAL` + or :attr:`~.ConsistencyLevel.LOCAL_SERIAL`. While ``SERIAL`` guarantees full + linearizability (with other ``SERIAL`` updates), ``LOCAL_SERIAL`` only + guarantees it in the local data center. + + The serial consistency level is ignored for any query that is not a + conditional update. Serial reads should use the regular + :attr:`consistency_level`. 
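+
+        For example, a conditional update whose paxos phase only requires
+        agreement within the local data center (a sketch, assuming an existing
+        ``session`` and a hypothetical ``users`` table)::
+
+            from cassandra import ConsistencyLevel
+            from cassandra.query import SimpleStatement
+
+            statement = SimpleStatement(
+                "UPDATE users SET age = 42 WHERE name = 'Bob' IF age = 41",
+                serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL)
+            session.execute(statement)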
+ + Serial consistency levels may only be used against Cassandra 2.0+ + and the :attr:`~.Cluster.protocol_version` must be set to 2 or higher. + + .. versionadded:: 2.0.0 + """) + + +class SimpleStatement(Statement): + """ + A simple, un-prepared query. All attributes of :class:`Statement` apply + to this class as well. + """ + + def __init__(self, query_string, *args, **kwargs): + """ + `query_string` should be a literal CQL statement with the exception + of parameter placeholders that will be filled through the + `parameters` argument of :meth:`.Session.execute()`. + """ + Statement.__init__(self, *args, **kwargs) + self._query_string = query_string + + @property + def query_string(self): + return self._query_string + + def __str__(self): + consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') + return (u'' % + (self.query_string, consistency)) + __repr__ = __str__ + + +class PreparedStatement(object): + """ + A statement that has been prepared against at least one Cassandra node. + Instances of this class should not be created directly, but through + :meth:`.Session.prepare()`. + + A :class:`.PreparedStatement` should be prepared only once. Re-preparing a statement + may affect performance (as the operation requires a network roundtrip). + """ + + column_metadata = None + query_id = None + query_string = None + keyspace = None # change to prepared_keyspace in major release + + routing_key_indexes = None + + consistency_level = None + serial_consistency_level = None + + protocol_version = None + + fetch_size = FETCH_SIZE_UNSET + + def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace, + protocol_version, consistency_level=None, serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET): + self.column_metadata = column_metadata + self.query_id = query_id + self.routing_key_indexes = routing_key_indexes + self.query_string = query + self.keyspace = keyspace + self.protocol_version = protocol_version + self.consistency_level = consistency_level + self.serial_consistency_level = serial_consistency_level + if fetch_size is not FETCH_SIZE_UNSET: + self.fetch_size = fetch_size + + @classmethod + def from_message(cls, query_id, column_metadata, cluster_metadata, query, prepared_keyspace, protocol_version): + if not column_metadata: + return PreparedStatement(column_metadata, query_id, None, query, prepared_keyspace, protocol_version) + + partition_key_columns = None + routing_key_indexes = None + + ks_name, table_name, _, _ = column_metadata[0] + ks_meta = cluster_metadata.keyspaces.get(ks_name) + if ks_meta: + table_meta = ks_meta.tables.get(table_name) + if table_meta: + partition_key_columns = table_meta.partition_key + + # make a map of {column_name: index} for each column in the statement + statement_indexes = dict((c[2], i) for i, c in enumerate(column_metadata)) + + # a list of which indexes in the statement correspond to partition key items + try: + routing_key_indexes = [statement_indexes[c.name] + for c in partition_key_columns] + except KeyError: # we're missing a partition key component in the prepared + pass # statement; just leave routing_key_indexes as None + + return PreparedStatement(column_metadata, query_id, routing_key_indexes, + query, prepared_keyspace, protocol_version) + + def bind(self, values): + """ + Creates and returns a :class:`BoundStatement` instance using `values`. + The `values` parameter **must** be a sequence, such as a tuple or list, + even if there is only one value to bind. 
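+
+        For example (a minimal sketch, assuming an existing ``session`` and a
+        hypothetical ``users`` table)::
+
+            prepared = session.prepare("INSERT INTO users (name, age) VALUES (?, ?)")
+            bound = prepared.bind(('Bob', 42))
+            session.execute(bound)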
+ """ + return BoundStatement(self).bind(values) + + def __str__(self): + consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') + return (u'' % + (self.query_string, consistency)) + __repr__ = __str__ + + +class BoundStatement(Statement): + """ + A prepared statement that has been bound to a particular set of values. + These may be created directly or through :meth:`.PreparedStatement.bind()`. + + All attributes of :class:`Statement` apply to this class as well. + """ + + prepared_statement = None + """ + The :class:`PreparedStatement` instance that this was created from. + """ + + values = None + """ + The sequence of values that were bound to the prepared statement. + """ + + def __init__(self, prepared_statement, *args, **kwargs): + """ + `prepared_statement` should be an instance of :class:`PreparedStatement`. + All other ``*args`` and ``**kwargs`` will be passed to :class:`.Statement`. + """ + self.prepared_statement = prepared_statement + + self.consistency_level = prepared_statement.consistency_level + self.serial_consistency_level = prepared_statement.serial_consistency_level + self.fetch_size = prepared_statement.fetch_size + self.values = [] + + meta = prepared_statement.column_metadata + if meta: + self.keyspace = meta[0][0] + + Statement.__init__(self, *args, **kwargs) + + def bind(self, values): + """ + Binds a sequence of values for the prepared statement parameters + and returns this instance. Note that `values` *must* be: + * a sequence, even if you are only binding one value, or + * a dict that relates 1-to-1 between dict keys and columns + """ + if values is None: + values = () + col_meta = self.prepared_statement.column_metadata + + proto_version = self.prepared_statement.protocol_version + + # special case for binding dicts + if isinstance(values, dict): + dict_values = values + values = [] + + # sort values accordingly + for col in col_meta: + try: + values.append(dict_values[col[2]]) + except KeyError: + raise KeyError( + 'Column name `%s` not found in bound dict.' % + (col[2])) + + # ensure a 1-to-1 dict keys to columns relationship + if len(dict_values) != len(col_meta): + # find expected columns + columns = set() + for col in col_meta: + columns.add(col[2]) + + # generate error message + if len(dict_values) > len(col_meta): + difference = set(dict_values.keys()).difference(columns) + msg = "Too many arguments provided to bind() (got %d, expected %d). " + \ + "Unexpected keys %s." + else: + difference = set(columns).difference(dict_values.keys()) + msg = "Too few arguments provided to bind() (got %d, expected %d). " + \ + "Expected keys %s." 
+ + # exit with error message + msg = msg % (len(values), len(col_meta), difference) + raise ValueError(msg) + + if len(values) > len(col_meta): + raise ValueError( + "Too many arguments provided to bind() (got %d, expected %d)" % + (len(values), len(col_meta))) + + if self.prepared_statement.routing_key_indexes and \ + len(values) < len(self.prepared_statement.routing_key_indexes): + raise ValueError( + "Too few arguments provided to bind() (got %d, required %d for routing key)" % + (len(values), len(self.prepared_statement.routing_key_indexes))) + + self.raw_values = values + self.values = [] + for value, col_spec in zip(values, col_meta): + if value is None: + self.values.append(None) + else: + col_type = col_spec[-1] + + try: + self.values.append(col_type.serialize(value, proto_version)) + except (TypeError, struct.error) as exc: + col_name = col_spec[2] + expected_type = col_type + actual_type = type(value) + + message = ('Received an argument of invalid type for column "%s". ' + 'Expected: %s, Got: %s; (%s)' % (col_name, expected_type, actual_type, exc)) + raise TypeError(message) + + return self + + @property + def routing_key(self): + if not self.prepared_statement.routing_key_indexes: + return None + + if self._routing_key is not None: + return self._routing_key + + routing_indexes = self.prepared_statement.routing_key_indexes + if len(routing_indexes) == 1: + self._routing_key = self.values[routing_indexes[0]] + else: + components = [] + for statement_index in routing_indexes: + val = self.values[statement_index] + l = len(val) + components.append(struct.pack(">H%dsB" % l, l, val, 0)) + + self._routing_key = b"".join(components) + + return self._routing_key + + def __str__(self): + consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') + return (u'' % + (self.prepared_statement.query_string, self.raw_values, consistency)) + __repr__ = __str__ + + +class BatchType(object): + """ + A BatchType is used with :class:`.BatchStatement` instances to control + the atomicity of the batch operation. + + .. versionadded:: 2.0.0 + """ + + LOGGED = None + """ + Atomic batch operation. + """ + + UNLOGGED = None + """ + Non-atomic batch operation. + """ + + COUNTER = None + """ + Batches of counter operations. + """ + + def __init__(self, name, value): + self.name = name + self.value = value + + def __str__(self): + return self.name + + def __repr__(self): + return "BatchType.%s" % (self.name, ) + + +BatchType.LOGGED = BatchType("LOGGED", 0) +BatchType.UNLOGGED = BatchType("UNLOGGED", 1) +BatchType.COUNTER = BatchType("COUNTER", 2) + + +class BatchStatement(Statement): + """ + A protocol-level batch of operations which are applied atomically + by default. + + .. versionadded:: 2.0.0 + """ + + batch_type = None + """ + The :class:`.BatchType` for the batch operation. Defaults to + :attr:`.BatchType.LOGGED`. + """ + + serial_consistency_level = None + """ + The same as :attr:`.Statement.serial_consistency_level`, but is only + supported when using protocol version 3 or higher. + """ + + _statements_and_parameters = None + _session = None + + def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, + consistency_level=None, serial_consistency_level=None, session=None): + """ + `batch_type` specifies The :class:`.BatchType` for the batch operation. + Defaults to :attr:`.BatchType.LOGGED`. + + `retry_policy` should be a :class:`~.RetryPolicy` instance for + controlling retries on the operation. 
+ + `consistency_level` should be a :class:`~.ConsistencyLevel` value + to be used for all operations in the batch. + + Example usage: + + .. code-block:: python + + insert_user = session.prepare("INSERT INTO users (name, age) VALUES (?, ?)") + batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM) + + for (name, age) in users_to_insert: + batch.add(insert_user, (name, age)) + + session.execute(batch) + + You can also mix different types of operations within a batch: + + .. code-block:: python + + batch = BatchStatement() + batch.add(SimpleStatement("INSERT INTO users (name, age) VALUES (%s, %s)"), (name, age)) + batch.add(SimpleStatement("DELETE FROM pending_users WHERE name=%s"), (name,)) + session.execute(batch) + + .. versionadded:: 2.0.0 + + .. versionchanged:: 2.1.0 + Added `serial_consistency_level` as a parameter + """ + self.batch_type = batch_type + self._statements_and_parameters = [] + self._session = session + Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level, + serial_consistency_level=serial_consistency_level) + + def add(self, statement, parameters=None): + """ + Adds a :class:`.Statement` and optional sequence of parameters + to be used with the statement to the batch. + + Like with other statements, parameters must be a sequence, even + if there is only one item. + """ + if isinstance(statement, six.string_types): + if parameters: + encoder = Encoder() if self._session is None else self._session.encoder + statement = bind_params(statement, parameters, encoder) + self._statements_and_parameters.append((False, statement, ())) + elif isinstance(statement, PreparedStatement): + query_id = statement.query_id + bound_statement = statement.bind(() if parameters is None else parameters) + self._maybe_set_routing_attributes(bound_statement) + self._statements_and_parameters.append( + (True, query_id, bound_statement.values)) + elif isinstance(statement, BoundStatement): + if parameters: + raise ValueError( + "Parameters cannot be passed with a BoundStatement " + "to BatchStatement.add()") + self._maybe_set_routing_attributes(statement) + self._statements_and_parameters.append( + (True, statement.prepared_statement.query_id, statement.values)) + else: + # it must be a SimpleStatement + query_string = statement.query_string + if parameters: + encoder = Encoder() if self._session is None else self._session.encoder + query_string = bind_params(query_string, parameters, encoder) + self._maybe_set_routing_attributes(statement) + self._statements_and_parameters.append((False, query_string, ())) + return self + + def add_all(self, statements, parameters): + """ + Adds a sequence of :class:`.Statement` objects and a matching sequence + of parameters to the batch. :const:`None` can be used in place of + parameters when no parameters are needed. 
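+
+        For example (a sketch, assuming ``session`` and the prepared
+        ``insert_user`` statement from the :class:`.BatchStatement` example)::
+
+            batch = BatchStatement()
+            batch.add_all([insert_user, insert_user],
+                          [('Alice', 30), ('Bob', 42)])
+            session.execute(batch)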
+ """ + for statement, value in zip(statements, parameters): + self.add(statement, parameters) + + def _maybe_set_routing_attributes(self, statement): + if self.routing_key is None: + if statement.keyspace and statement.routing_key: + self.routing_key = statement.routing_key + self.keyspace = statement.keyspace + + def __str__(self): + consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') + return (u'' % + (self.batch_type, len(self._statements_and_parameters), consistency)) + __repr__ = __str__ + + +ValueSequence = cassandra.encoder.ValueSequence +""" +A wrapper class that is used to specify that a sequence of values should +be treated as a CQL list of values instead of a single column collection when used +as part of the `parameters` argument for :meth:`.Session.execute()`. + +This is typically needed when supplying a list of keys to select. +For example:: + + >>> my_user_ids = ('alice', 'bob', 'charles') + >>> query = "SELECT * FROM users WHERE user_id IN %s" + >>> session.execute(query, parameters=[ValueSequence(my_user_ids)]) + +""" + + +def bind_params(query, params, encoder): + if isinstance(params, dict): + return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in six.iteritems(params)) + else: + return query % tuple(encoder.cql_encode_all_types(v) for v in params) + + +class TraceUnavailable(Exception): + """ + Raised when complete trace details cannot be fetched from Cassandra. + """ + pass + + +class QueryTrace(object): + """ + A trace of the duration and events that occurred when executing + an operation. + """ + + trace_id = None + """ + :class:`uuid.UUID` unique identifier for this tracing session. Matches + the ``session_id`` column in ``system_traces.sessions`` and + ``system_traces.events``. + """ + + request_type = None + """ + A string that very generally describes the traced operation. + """ + + duration = None + """ + A :class:`datetime.timedelta` measure of the duration of the query. + """ + + coordinator = None + """ + The IP address of the host that acted as coordinator for this request. + """ + + parameters = None + """ + A :class:`dict` of parameters for the traced operation, such as the + specific query string. + """ + + started_at = None + """ + A UTC :class:`datetime.datetime` object describing when the operation + was started. + """ + + events = None + """ + A chronologically sorted list of :class:`.TraceEvent` instances + representing the steps the traced operation went through. This + corresponds to the rows in ``system_traces.events`` for this tracing + session. + """ + + _session = None + + _SELECT_SESSIONS_FORMAT = "SELECT * FROM system_traces.sessions WHERE session_id = %s" + _SELECT_EVENTS_FORMAT = "SELECT * FROM system_traces.events WHERE session_id = %s" + _BASE_RETRY_SLEEP = 0.003 + + def __init__(self, trace_id, session): + self.trace_id = trace_id + self._session = session + + def populate(self, max_wait=2.0): + """ + Retrieves the actual tracing details from Cassandra and populates the + attributes of this instance. Because tracing details are stored + asynchronously by Cassandra, this may need to retry the session + detail fetch. If the trace is still not available after `max_wait` + seconds, :exc:`.TraceUnavailable` will be raised; if `max_wait` is + :const:`None`, this will retry forever. 
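+
+        A minimal sketch, assuming ``trace`` is a :class:`.QueryTrace` taken from
+        a statement that was executed with ``trace=True``::
+
+            trace.populate(max_wait=5.0)
+            print trace.duration
+            for event in trace.events:
+                print event.description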
+ """ + attempt = 0 + start = time.time() + while True: + time_spent = time.time() - start + if max_wait is not None and time_spent >= max_wait: + raise TraceUnavailable( + "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." % (max_wait,)) + + log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id) + session_results = self._execute( + self._SELECT_SESSIONS_FORMAT, (self.trace_id,), time_spent, max_wait) + + if not session_results or session_results[0].duration is None: + time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt)) + attempt += 1 + continue + log.debug("Fetched trace info for trace ID: %s", self.trace_id) + + session_row = session_results[0] + self.request_type = session_row.request + self.duration = timedelta(microseconds=session_row.duration) + self.started_at = session_row.started_at + self.coordinator = session_row.coordinator + self.parameters = session_row.parameters + + log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id) + time_spent = time.time() - start + event_results = self._execute( + self._SELECT_EVENTS_FORMAT, (self.trace_id,), time_spent, max_wait) + log.debug("Fetched trace events for trace ID: %s", self.trace_id) + self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) + for r in event_results) + break + + def _execute(self, query, parameters, time_spent, max_wait): + # in case the user switched the row factory, set it to namedtuple for this query + future = self._session._create_response_future(query, parameters, trace=False) + future.row_factory = named_tuple_factory + future.send_request() + + timeout = (max_wait - time_spent) if max_wait is not None else None + try: + return future.result(timeout=timeout) + except OperationTimedOut: + raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) + + def __str__(self): + return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \ + % (self.request_type, self.trace_id, self.coordinator, self.started_at, + self.duration, self.parameters) + + +class TraceEvent(object): + """ + Representation of a single event within a query trace. + """ + + description = None + """ + A brief description of the event. + """ + + datetime = None + """ + A UTC :class:`datetime.datetime` marking when the event occurred. + """ + + source = None + """ + The IP address of the node this event occurred on. + """ + + source_elapsed = None + """ + A :class:`datetime.timedelta` measuring the amount of time until + this event occurred starting from when :attr:`.source` first + received the query. + """ + + thread_name = None + """ + The name of the thread that this event occurred on. 
+ """ + + def __init__(self, description, timeuuid, source, source_elapsed, thread_name): + self.description = description + self.datetime = datetime.utcfromtimestamp(unix_time_from_uuid1(timeuuid)) + self.source = source + if source_elapsed is not None: + self.source_elapsed = timedelta(microseconds=source_elapsed) + else: + self.source_elapsed = None + self.thread_name = thread_name + + def __str__(self): + return "%s on %s[%s] at %s" % (self.description, self.source, self.thread_name, self.datetime) diff --git a/cassandra/util.py b/cassandra/util.py new file mode 100644 index 0000000..35e1e5e --- /dev/null +++ b/cassandra/util.py @@ -0,0 +1,1001 @@ +from __future__ import with_statement +import calendar +import datetime +import random +import six +import uuid + +DATETIME_EPOC = datetime.datetime(1970, 1, 1) + + +def datetime_from_timestamp(timestamp): + """ + Creates a timezone-agnostic datetime from timestamp (in seconds) in a consistent manner. + Works around a Windows issue with large negative timestamps (PYTHON-119), + and rounding differences in Python 3.4 (PYTHON-340). + + :param timestamp: a unix timestamp, in seconds + + :rtype: datetime + """ + dt = DATETIME_EPOC + datetime.timedelta(seconds=timestamp) + return dt + + +def unix_time_from_uuid1(uuid_arg): + """ + Converts a version 1 :class:`uuid.UUID` to a timestamp with the same precision + as :meth:`time.time()` returns. This is useful for examining the + results of queries returning a v1 :class:`~uuid.UUID`. + + :param uuid_arg: a version 1 :class:`~uuid.UUID` + + :rtype: timestamp + + """ + return (uuid_arg.time - 0x01B21DD213814000) / 1e7 + + +def datetime_from_uuid1(uuid_arg): + """ + Creates a timezone-agnostic datetime from the timestamp in the + specified type-1 UUID. + + :param uuid_arg: a version 1 :class:`~uuid.UUID` + + :rtype: timestamp + + """ + return datetime_from_timestamp(unix_time_from_uuid1(uuid_arg)) + + +def min_uuid_from_time(timestamp): + """ + Generates the minimum TimeUUID (type 1) for a given timestamp, as compared by Cassandra. + + See :func:`uuid_from_time` for argument and return types. + """ + return uuid_from_time(timestamp, 0x808080808080, 0x80) # Cassandra does byte-wise comparison; fill with min signed bytes (0x80 = -128) + + +def max_uuid_from_time(timestamp): + """ + Generates the maximum TimeUUID (type 1) for a given timestamp, as compared by Cassandra. + + See :func:`uuid_from_time` for argument and return types. + """ + return uuid_from_time(timestamp, 0x7f7f7f7f7f7f, 0x3f7f) # Max signed bytes (0x7f = 127) + + +def uuid_from_time(time_arg, node=None, clock_seq=None): + """ + Converts a datetime or timestamp to a type 1 :class:`uuid.UUID`. + + :param time_arg: + The time to use for the timestamp portion of the UUID. + This can either be a :class:`datetime` object or a timestamp + in seconds (as returned from :meth:`time.time()`). + :type datetime: :class:`datetime` or timestamp + + :param node: + None integer for the UUID (up to 48 bits). If not specified, this + field is randomized. + :type node: long + + :param clock_seq: + Clock sequence field for the UUID (up to 14 bits). If not specified, + a random sequence is generated. 
+ :type clock_seq: int + + :rtype: :class:`uuid.UUID` + + """ + if hasattr(time_arg, 'utctimetuple'): + seconds = int(calendar.timegm(time_arg.utctimetuple())) + microseconds = (seconds * 1e6) + time_arg.time().microsecond + else: + microseconds = int(time_arg * 1e6) + + # 0x01b21dd213814000 is the number of 100-ns intervals between the + # UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00. + intervals = int(microseconds * 10) + 0x01b21dd213814000 + + time_low = intervals & 0xffffffff + time_mid = (intervals >> 32) & 0xffff + time_hi_version = (intervals >> 48) & 0x0fff + + if clock_seq is None: + clock_seq = random.getrandbits(14) + else: + if clock_seq > 0x3fff: + raise ValueError('clock_seq is out of range (need a 14-bit value)') + + clock_seq_low = clock_seq & 0xff + clock_seq_hi_variant = 0x80 | ((clock_seq >> 8) & 0x3f) + + if node is None: + node = random.getrandbits(48) + + return uuid.UUID(fields=(time_low, time_mid, time_hi_version, + clock_seq_hi_variant, clock_seq_low, node), version=1) + +LOWEST_TIME_UUID = uuid.UUID('00000000-0000-1000-8080-808080808080') +""" The lowest possible TimeUUID, as sorted by Cassandra. """ + +HIGHEST_TIME_UUID = uuid.UUID('ffffffff-ffff-1fff-bf7f-7f7f7f7f7f7f') +""" The highest possible TimeUUID, as sorted by Cassandra. """ + + +try: + from collections import OrderedDict +except ImportError: + # OrderedDict from Python 2.7+ + + # Copyright (c) 2009 Raymond Hettinger + # + # Permission is hereby granted, free of charge, to any person + # obtaining a copy of this software and associated documentation files + # (the "Software"), to deal in the Software without restriction, + # including without limitation the rights to use, copy, modify, merge, + # publish, distribute, sublicense, and/or sell copies of the Software, + # and to permit persons to whom the Software is furnished to do so, + # subject to the following conditions: + # + # The above copyright notice and this permission notice shall be + # included in all copies or substantial portions of the Software. + # + # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + # HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + # OTHER DEALINGS IN THE SOFTWARE. + from UserDict import DictMixin + + class OrderedDict(dict, DictMixin): # noqa + """ A dictionary which maintains the insertion order of keys. """ + + def __init__(self, *args, **kwds): + """ A dictionary which maintains the insertion order of keys. 
""" + + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + try: + self.__end + except AttributeError: + self.clear() + self.update(*args, **kwds) + + def clear(self): + self.__end = end = [] + end += [None, end, end] # sentinel node for doubly linked list + self.__map = {} # key --> [key, prev, next] + dict.clear(self) + + def __setitem__(self, key, value): + if key not in self: + end = self.__end + curr = end[1] + curr[2] = end[1] = self.__map[key] = [key, curr, end] + dict.__setitem__(self, key, value) + + def __delitem__(self, key): + dict.__delitem__(self, key) + key, prev, next = self.__map.pop(key) + prev[2] = next + next[1] = prev + + def __iter__(self): + end = self.__end + curr = end[2] + while curr is not end: + yield curr[0] + curr = curr[2] + + def __reversed__(self): + end = self.__end + curr = end[1] + while curr is not end: + yield curr[0] + curr = curr[1] + + def popitem(self, last=True): + if not self: + raise KeyError('dictionary is empty') + if last: + key = next(reversed(self)) + else: + key = next(iter(self)) + value = self.pop(key) + return key, value + + def __reduce__(self): + items = [[k, self[k]] for k in self] + tmp = self.__map, self.__end + del self.__map, self.__end + inst_dict = vars(self).copy() + self.__map, self.__end = tmp + if inst_dict: + return (self.__class__, (items,), inst_dict) + return self.__class__, (items,) + + def keys(self): + return list(self) + + setdefault = DictMixin.setdefault + update = DictMixin.update + pop = DictMixin.pop + values = DictMixin.values + items = DictMixin.items + iterkeys = DictMixin.iterkeys + itervalues = DictMixin.itervalues + iteritems = DictMixin.iteritems + + def __repr__(self): + if not self: + return '%s()' % (self.__class__.__name__,) + return '%s(%r)' % (self.__class__.__name__, self.items()) + + def copy(self): + return self.__class__(self) + + @classmethod + def fromkeys(cls, iterable, value=None): + d = cls() + for key in iterable: + d[key] = value + return d + + def __eq__(self, other): + if isinstance(other, OrderedDict): + if len(self) != len(other): + return False + for p, q in zip(self.items(), other.items()): + if p != q: + return False + return True + return dict.__eq__(self, other) + + def __ne__(self, other): + return not self == other + + +# WeakSet from Python 2.7+ (https://code.google.com/p/weakrefset) + +from _weakref import ref + + +class _IterationGuard(object): + # This context manager registers itself in the current iterators of the + # weak container, such as to delay all removals until the context manager + # exits. + # This technique should be relatively thread-safe (since sets are). 
+ + def __init__(self, weakcontainer): + # Don't create cycles + self.weakcontainer = ref(weakcontainer) + + def __enter__(self): + w = self.weakcontainer() + if w is not None: + w._iterating.add(self) + return self + + def __exit__(self, e, t, b): + w = self.weakcontainer() + if w is not None: + s = w._iterating + s.remove(self) + if not s: + w._commit_removals() + + +class WeakSet(object): + def __init__(self, data=None): + self.data = set() + + def _remove(item, selfref=ref(self)): + self = selfref() + if self is not None: + if self._iterating: + self._pending_removals.append(item) + else: + self.data.discard(item) + + self._remove = _remove + # A list of keys to be removed + self._pending_removals = [] + self._iterating = set() + if data is not None: + self.update(data) + + def _commit_removals(self): + l = self._pending_removals + discard = self.data.discard + while l: + discard(l.pop()) + + def __iter__(self): + with _IterationGuard(self): + for itemref in self.data: + item = itemref() + if item is not None: + yield item + + def __len__(self): + return sum(x() is not None for x in self.data) + + def __contains__(self, item): + return ref(item) in self.data + + def __reduce__(self): + return (self.__class__, (list(self),), + getattr(self, '__dict__', None)) + + __hash__ = None + + def add(self, item): + if self._pending_removals: + self._commit_removals() + self.data.add(ref(item, self._remove)) + + def clear(self): + if self._pending_removals: + self._commit_removals() + self.data.clear() + + def copy(self): + return self.__class__(self) + + def pop(self): + if self._pending_removals: + self._commit_removals() + while True: + try: + itemref = self.data.pop() + except KeyError: + raise KeyError('pop from empty WeakSet') + item = itemref() + if item is not None: + return item + + def remove(self, item): + if self._pending_removals: + self._commit_removals() + self.data.remove(ref(item)) + + def discard(self, item): + if self._pending_removals: + self._commit_removals() + self.data.discard(ref(item)) + + def update(self, other): + if self._pending_removals: + self._commit_removals() + if isinstance(other, self.__class__): + self.data.update(other.data) + else: + for element in other: + self.add(element) + + def __ior__(self, other): + self.update(other) + return self + + # Helper functions for simple delegating methods. 
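+    #
+    # For illustration: _apply funnels a plain-set operation over the
+    # underlying sets of weak references, so difference() below is roughly
+    #
+    #     new = WeakSet(); new.data = self.data.difference(other.data)
+    #
+    # with `other` first coerced to a WeakSet if it is any other iterable.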
+ def _apply(self, other, method): + if not isinstance(other, self.__class__): + other = self.__class__(other) + newdata = method(other.data) + newset = self.__class__() + newset.data = newdata + return newset + + def difference(self, other): + return self._apply(other, self.data.difference) + __sub__ = difference + + def difference_update(self, other): + if self._pending_removals: + self._commit_removals() + if self is other: + self.data.clear() + else: + self.data.difference_update(ref(item) for item in other) + + def __isub__(self, other): + if self._pending_removals: + self._commit_removals() + if self is other: + self.data.clear() + else: + self.data.difference_update(ref(item) for item in other) + return self + + def intersection(self, other): + return self._apply(other, self.data.intersection) + __and__ = intersection + + def intersection_update(self, other): + if self._pending_removals: + self._commit_removals() + self.data.intersection_update(ref(item) for item in other) + + def __iand__(self, other): + if self._pending_removals: + self._commit_removals() + self.data.intersection_update(ref(item) for item in other) + return self + + def issubset(self, other): + return self.data.issubset(ref(item) for item in other) + __lt__ = issubset + + def __le__(self, other): + return self.data <= set(ref(item) for item in other) + + def issuperset(self, other): + return self.data.issuperset(ref(item) for item in other) + __gt__ = issuperset + + def __ge__(self, other): + return self.data >= set(ref(item) for item in other) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self.data == set(ref(item) for item in other) + + def symmetric_difference(self, other): + return self._apply(other, self.data.symmetric_difference) + __xor__ = symmetric_difference + + def symmetric_difference_update(self, other): + if self._pending_removals: + self._commit_removals() + if self is other: + self.data.clear() + else: + self.data.symmetric_difference_update(ref(item) for item in other) + + def __ixor__(self, other): + if self._pending_removals: + self._commit_removals() + if self is other: + self.data.clear() + else: + self.data.symmetric_difference_update(ref(item) for item in other) + return self + + def union(self, other): + return self._apply(other, self.data.union) + __or__ = union + + def isdisjoint(self, other): + return len(self.intersection(other)) == 0 + +try: + from blist import sortedset +except ImportError: + + import warnings + + warnings.warn( + "The blist library is not available, so a pure python list-based set will " + "be used in place of blist.sortedset for set collection values. " + "You can find the blist library here: https://pypi.python.org/pypi/blist/") + + from bisect import bisect_left + + class sortedset(object): + ''' + A sorted set based on sorted list + + This set is used in place of blist.sortedset in Python environments + where blist module/extension is not available. + + A sorted set implementation is used in this case because it does not + require its elements to be immutable/hashable. 
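+
+        Illustrative example (behavior follows from the bisect-based
+        implementation below)::
+
+            >>> s = sortedset([3, 1, 2, 2])
+            >>> list(s)
+            [1, 2, 3]
+            >>> 2 in s
+            True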
+ + #Not implemented: update functions, inplace operators + + ''' + + def __init__(self, iterable=()): + self._items = [] + for i in iterable: + self.add(i) + + def __len__(self): + return len(self._items) + + def __iter__(self): + return iter(self._items) + + def __reversed__(self): + return reversed(self._items) + + def __repr__(self): + return '%s(%r)' % ( + self.__class__.__name__, + self._items) + + def __reduce__(self): + return self.__class__, (self._items,) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self._items == other._items + else: + if not isinstance(other, set): + return False + + return len(other) == len(self._items) and all(item in other for item in self._items) + + def __ne__(self, other): + if isinstance(other, self.__class__): + return self._items != other._items + else: + if not isinstance(other, set): + return True + + return len(other) != len(self._items) or any(item not in other for item in self._items) + + def __le__(self, other): + return self.issubset(other) + + def __lt__(self, other): + return len(other) > len(self._items) and self.issubset(other) + + def __ge__(self, other): + return self.issuperset(other) + + def __gt__(self, other): + return len(self._items) > len(other) and self.issuperset(other) + + def __and__(self, other): + return self._intersect(other) + + def __or__(self, other): + return self.union(other) + + def __sub__(self, other): + return self._diff(other) + + def __xor__(self, other): + return self.symmetric_difference(other) + + def __contains__(self, item): + i = bisect_left(self._items, item) + return i < len(self._items) and self._items[i] == item + + def add(self, item): + i = bisect_left(self._items, item) + if i < len(self._items): + if self._items[i] != item: + self._items.insert(i, item) + else: + self._items.append(item) + + def clear(self): + del self._items[:] + + def copy(self): + new = sortedset() + new._items = list(self._items) + return new + + def isdisjoint(self, other): + return len(self._intersect(other)) == 0 + + def issubset(self, other): + return len(self._intersect(other)) == len(self._items) + + def issuperset(self, other): + return len(self._intersect(other)) == len(other) + + def pop(self): + if not self._items: + raise KeyError("pop from empty set") + return self._items.pop() + + def remove(self, item): + i = bisect_left(self._items, item) + if i < len(self._items): + if self._items[i] == item: + self._items.pop(i) + return + raise KeyError('%r' % item) + + def union(self, *others): + union = sortedset() + union._items = list(self._items) + for other in others: + if isinstance(other, self.__class__): + i = 0 + for item in other._items: + i = bisect_left(union._items, item, i) + if i < len(union._items): + if item != union._items[i]: + union._items.insert(i, item) + else: + union._items.append(item) + else: + for item in other: + union.add(item) + return union + + def intersection(self, *others): + isect = self.copy() + for other in others: + isect = isect._intersect(other) + if not isect: + break + return isect + + def difference(self, *others): + diff = self.copy() + for other in others: + diff = diff._diff(other) + if not diff: + break + return diff + + def symmetric_difference(self, other): + diff_self_other = self._diff(other) + diff_other_self = other.difference(self) + return diff_self_other.union(diff_other_self) + + def _diff(self, other): + diff = sortedset() + if isinstance(other, self.__class__): + i = 0 + for item in self._items: + i = bisect_left(other._items, item, i) + 
if i < len(other._items):
+                        if item != other._items[i]:
+                            diff._items.append(item)
+                    else:
+                        diff._items.append(item)
+            else:
+                for item in self._items:
+                    if item not in other:
+                        diff.add(item)
+            return diff
+
+        def _intersect(self, other):
+            isect = sortedset()
+            if isinstance(other, self.__class__):
+                i = 0
+                for item in self._items:
+                    i = bisect_left(other._items, item, i)
+                    if i < len(other._items):
+                        if item == other._items[i]:
+                            isect._items.append(item)
+                    else:
+                        break
+            else:
+                for item in self._items:
+                    if item in other:
+                        isect.add(item)
+            return isect
+
+
+from collections import Mapping
+from six.moves import cPickle
+
+
+class OrderedMap(Mapping):
+    '''
+    An ordered map that accepts non-hashable types for keys. It also maintains the
+    insertion order of items, behaving as OrderedDict in that regard. These maps
+    are constructed and read just as normal mapping types, except that they may
+    contain arbitrary collections and other non-hashable items as keys::
+
+        >>> od = OrderedMap([({'one': 1, 'two': 2}, 'value'),
+        ...                  ({'three': 3, 'four': 4}, 'value2')])
+        >>> list(od.keys())
+        [{'two': 2, 'one': 1}, {'three': 3, 'four': 4}]
+        >>> list(od.values())
+        ['value', 'value2']
+
+    These constructs are needed to support nested collections in Cassandra 2.1.3+,
+    where frozen collections can be specified as parameters to others\*::
+
+        CREATE TABLE example (
+            ...
+            value map<frozen<map<int, int>>, double>
+            ...
+        )
+
+    This class derives from the (immutable) Mapping API. Objects in these maps
+    are not intended to be modified.
+
+    \* Note: Because of the way Cassandra encodes nested types, when using the
+    driver with nested collections, :attr:`~.Cluster.protocol_version` must be 3
+    or higher.
+
+    '''
+
+    def __init__(self, *args, **kwargs):
+        if len(args) > 1:
+            raise TypeError('expected at most 1 arguments, got %d' % len(args))
+
+        self._items = []
+        self._index = {}
+        if args:
+            e = args[0]
+            if callable(getattr(e, 'keys', None)):
+                for k in e.keys():
+                    self._insert(k, e[k])
+            else:
+                for k, v in e:
+                    self._insert(k, v)
+
+        for k, v in six.iteritems(kwargs):
+            self._insert(k, v)
+
+    def _insert(self, key, value):
+        flat_key = self._serialize_key(key)
+        i = self._index.get(flat_key, -1)
+        if i >= 0:
+            self._items[i] = (key, value)
+        else:
+            self._items.append((key, value))
+            self._index[flat_key] = len(self._items) - 1
+
+    def __getitem__(self, key):
+        try:
+            index = self._index[self._serialize_key(key)]
+            return self._items[index][1]
+        except KeyError:
+            raise KeyError(str(key))
+
+    def __iter__(self):
+        for i in self._items:
+            yield i[0]
+
+    def __len__(self):
+        return len(self._items)
+
+    def __eq__(self, other):
+        if isinstance(other, OrderedMap):
+            return self._items == other._items
+        try:
+            d = dict(other)
+            return len(d) == len(self._items) and all(i[1] == d[i[0]] for i in self._items)
+        except KeyError:
+            return False
+        except TypeError:
+            pass
+        return NotImplemented
+
+    def __repr__(self):
+        return '%s([%s])' % (
+            self.__class__.__name__,
+            ', '.join("(%r, %r)" % (k, v) for k, v in self._items))
+
+    def __str__(self):
+        return '{%s}' % ', '.join("%r: %r" % (k, v) for k, v in self._items)
+
+    def _serialize_key(self, key):
+        return cPickle.dumps(key)
+
+
+class OrderedMapSerializedKey(OrderedMap):
+
+    def __init__(self, cass_type, protocol_version):
+        super(OrderedMapSerializedKey, self).__init__()
+        self.cass_key_type = cass_type
+        self.protocol_version = protocol_version
+
+    def _insert_unchecked(self, key, flat_key, value):
+        self._items.append((key, value))
+        self._index[flat_key] =
len(self._items) - 1 + + def _serialize_key(self, key): + return self.cass_key_type.serialize(key, self.protocol_version) + + +import datetime +import time + +if six.PY3: + long = int + + +class Time(object): + ''' + Idealized time, independent of day. + + Up to nanosecond resolution + ''' + + MICRO = 1000 + MILLI = 1000 * MICRO + SECOND = 1000 * MILLI + MINUTE = 60 * SECOND + HOUR = 60 * MINUTE + DAY = 24 * HOUR + + nanosecond_time = 0 + + def __init__(self, value): + """ + Initializer value can be: + + - integer_type: absolute nanoseconds in the day + - datetime.time: built-in time + - string_type: a string time of the form "HH:MM:SS[.mmmuuunnn]" + """ + if isinstance(value, six.integer_types): + self._from_timestamp(value) + elif isinstance(value, datetime.time): + self._from_time(value) + elif isinstance(value, six.string_types): + self._from_timestring(value) + else: + raise TypeError('Time arguments must be a whole number, datetime.time, or string') + + @property + def hour(self): + """ + The hour component of this time (0-23) + """ + return self.nanosecond_time // Time.HOUR + + @property + def minute(self): + """ + The minute component of this time (0-59) + """ + minutes = self.nanosecond_time // Time.MINUTE + return minutes % 60 + + @property + def second(self): + """ + The second component of this time (0-59) + """ + seconds = self.nanosecond_time // Time.SECOND + return seconds % 60 + + @property + def nanosecond(self): + """ + The fractional seconds component of the time, in nanoseconds + """ + return self.nanosecond_time % Time.SECOND + + def _from_timestamp(self, t): + if t >= Time.DAY: + raise ValueError("value must be less than number of nanoseconds in a day (%d)" % Time.DAY) + self.nanosecond_time = t + + def _from_timestring(self, s): + try: + parts = s.split('.') + base_time = time.strptime(parts[0], "%H:%M:%S") + self.nanosecond_time = (base_time.tm_hour * Time.HOUR + + base_time.tm_min * Time.MINUTE + + base_time.tm_sec * Time.SECOND) + + if len(parts) > 1: + # right pad to 9 digits + nano_time_str = parts[1] + "0" * (9 - len(parts[1])) + self.nanosecond_time += int(nano_time_str) + + except ValueError: + raise ValueError("can't interpret %r as a time" % (s,)) + + def _from_time(self, t): + self.nanosecond_time = (t.hour * Time.HOUR + + t.minute * Time.MINUTE + + t.second * Time.SECOND + + t.microsecond * Time.MICRO) + + def __eq__(self, other): + if isinstance(other, Time): + return self.nanosecond_time == other.nanosecond_time + + if isinstance(other, six.integer_types): + return self.nanosecond_time == other + + return self.nanosecond_time % Time.MICRO == 0 and \ + datetime.time(hour=self.hour, minute=self.minute, second=self.second, + microsecond=self.nanosecond // Time.MICRO) == other + + def __repr__(self): + return "Time(%s)" % self.nanosecond_time + + def __str__(self): + return "%02d:%02d:%02d.%09d" % (self.hour, self.minute, + self.second, self.nanosecond) + + +class Date(object): + ''' + Idealized naive date: year, month, day + + Offers wider year range than datetime.date. For Dates that cannot be represented + as a datetime.date (because datetime.MINYEAR, datetime.MAXYEAR), this type falls back + to printing days_from_epoch offset. + ''' + + MINUTE = 60 + HOUR = 60 * MINUTE + DAY = 24 * HOUR + + date_format = "%Y-%m-%d" + + days_from_epoch = 0 + + def __init__(self, value): + """ + Initializer value can be: + + - integer_type: absolute days from epoch (1970, 1, 1). Can be negative. 
+        - datetime.date: built-in date
+        - string_type: a string date of the form "yyyy-mm-dd"
+        """
+        if isinstance(value, six.integer_types):
+            self.days_from_epoch = value
+        elif isinstance(value, (datetime.date, datetime.datetime)):
+            self._from_timetuple(value.timetuple())
+        elif isinstance(value, six.string_types):
+            self._from_datestring(value)
+        else:
+            raise TypeError('Date arguments must be a whole number, datetime.date, or string')
+
+    @property
+    def seconds(self):
+        """
+        Absolute seconds from epoch (can be negative)
+        """
+        return self.days_from_epoch * Date.DAY
+
+    def date(self):
+        """
+        Return a built-in datetime.date for Dates falling in the years [datetime.MINYEAR, datetime.MAXYEAR]
+
+        ValueError is raised for Dates outside this range.
+        """
+        try:
+            dt = datetime_from_timestamp(self.seconds)
+            return datetime.date(dt.year, dt.month, dt.day)
+        except Exception:
+            raise ValueError("%r exceeds ranges for built-in datetime.date" % self)
+
+    def _from_timetuple(self, t):
+        self.days_from_epoch = calendar.timegm(t) // Date.DAY
+
+    def _from_datestring(self, s):
+        if s[0] == '+':
+            s = s[1:]
+        dt = datetime.datetime.strptime(s, self.date_format)
+        self._from_timetuple(dt.timetuple())
+
+    def __eq__(self, other):
+        if isinstance(other, Date):
+            return self.days_from_epoch == other.days_from_epoch
+
+        if isinstance(other, six.integer_types):
+            return self.days_from_epoch == other
+
+        try:
+            return self.date() == other
+        except Exception:
+            return False
+
+    def __repr__(self):
+        return "Date(%s)" % self.days_from_epoch
+
+    def __str__(self):
+        try:
+            dt = datetime_from_timestamp(self.seconds)
+            return "%04d-%02d-%02d" % (dt.year, dt.month, dt.day)
+        except:
+            # If we overflow datetime.[MIN|MAX]YEAR, fall back to the day offset
+            return str(self.days_from_epoch)
diff --git a/cassandra_driver.egg-info/PKG-INFO b/cassandra_driver.egg-info/PKG-INFO
new file mode 100644
index 0000000..6571faa
--- /dev/null
+++ b/cassandra_driver.egg-info/PKG-INFO
@@ -0,0 +1,99 @@
+Metadata-Version: 1.1
+Name: cassandra-driver
+Version: 2.5.1
+Summary: Python driver for Cassandra
+Home-page: http://github.com/datastax/python-driver
+Author: Tyler Hobbs
+Author-email: tyler@datastax.com
+License: UNKNOWN
+Description: DataStax Python Driver for Apache Cassandra
+        ===========================================
+
+        .. image:: https://travis-ci.org/datastax/python-driver.png?branch=master
+           :target: https://travis-ci.org/datastax/python-driver
+
+        A Python client driver for Apache Cassandra. This driver works exclusively
+        with the Cassandra Query Language v3 (CQL3) and Cassandra's native
+        protocol. Cassandra versions 1.2 through 2.1 are supported.
+
+        The driver supports Python 2.6, 2.7, 3.3, and 3.4*.
+
+        * cqlengine component presently supports Python 2.7+
+
+        Installation
+        ------------
+        Installation through pip is recommended::
+
+            $ pip install cassandra-driver
+
+        For more complete installation instructions, see the
+        `installation guide `_.
+
+        Documentation
+        -------------
+        The documentation can be found online `here `_.
+
+        A couple of links for getting up to speed:
+
+        * `Installation `_
+        * `Getting started guide `_
+        * `API docs `_
+        * `Performance tips `_
+
+        Object Mapper
+        -------------
+        cqlengine (originally developed by Blake Eggleston and Jon Haddad, with contributions from the
+        community) is now maintained as an integral part of this package. Refer to
+        `documentation here `_.
+
+        Reporting Problems
+        ------------------
+        Please report any bugs and make any feature requests on the
+        `JIRA `_ issue tracker.
+ + If you would like to contribute, please feel free to open a pull request. + + Getting Help + ------------ + Your two best options for getting help with the driver are the + `mailing list `_ + and the IRC channel. + + For IRC, use the #datastax-drivers channel on irc.freenode.net. If you don't have an IRC client, + you can use `freenode's web-based client `_. + + Features to be Added + -------------------- + * C extension for encoding/decoding messages + + License + ------- + Copyright 2013-2015 DataStax + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +Keywords: cassandra,cql,orm +Platform: UNKNOWN +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Natural Language :: English +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 2.6 +Classifier: Programming Language :: Python :: 2.7 +Classifier: Programming Language :: Python :: 3.3 +Classifier: Programming Language :: Python :: 3.4 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Topic :: Software Development :: Libraries :: Python Modules diff --git a/cassandra_driver.egg-info/SOURCES.txt b/cassandra_driver.egg-info/SOURCES.txt new file mode 100644 index 0000000..4c0d83c --- /dev/null +++ b/cassandra_driver.egg-info/SOURCES.txt @@ -0,0 +1,45 @@ +LICENSE +MANIFEST.in +README.rst +ez_setup.py +setup.py +cassandra/__init__.py +cassandra/auth.py +cassandra/cluster.py +cassandra/concurrent.py +cassandra/connection.py +cassandra/cqltypes.py +cassandra/decoder.py +cassandra/encoder.py +cassandra/marshal.py +cassandra/metadata.py +cassandra/metrics.py +cassandra/murmur3.c +cassandra/policies.py +cassandra/pool.py +cassandra/protocol.py +cassandra/query.py +cassandra/util.py +cassandra/cqlengine/__init__.py +cassandra/cqlengine/columns.py +cassandra/cqlengine/connection.py +cassandra/cqlengine/functions.py +cassandra/cqlengine/management.py +cassandra/cqlengine/models.py +cassandra/cqlengine/named.py +cassandra/cqlengine/operators.py +cassandra/cqlengine/query.py +cassandra/cqlengine/statements.py +cassandra/cqlengine/usertype.py +cassandra/io/__init__.py +cassandra/io/asyncorereactor.py +cassandra/io/eventletreactor.py +cassandra/io/geventreactor.py +cassandra/io/libevreactor.py +cassandra/io/libevwrapper.c +cassandra/io/twistedreactor.py +cassandra_driver.egg-info/PKG-INFO +cassandra_driver.egg-info/SOURCES.txt +cassandra_driver.egg-info/dependency_links.txt +cassandra_driver.egg-info/requires.txt +cassandra_driver.egg-info/top_level.txt \ No newline at end of file diff --git a/cassandra_driver.egg-info/dependency_links.txt b/cassandra_driver.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/cassandra_driver.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git 
a/cassandra_driver.egg-info/requires.txt b/cassandra_driver.egg-info/requires.txt new file mode 100644 index 0000000..c067e65 --- /dev/null +++ b/cassandra_driver.egg-info/requires.txt @@ -0,0 +1,2 @@ +futures +six >=1.6 diff --git a/cassandra_driver.egg-info/top_level.txt b/cassandra_driver.egg-info/top_level.txt new file mode 100644 index 0000000..6a6d6b4 --- /dev/null +++ b/cassandra_driver.egg-info/top_level.txt @@ -0,0 +1 @@ +cassandra diff --git a/ez_setup.py b/ez_setup.py new file mode 100644 index 0000000..2535472 --- /dev/null +++ b/ez_setup.py @@ -0,0 +1,258 @@ +#!python +"""Bootstrap setuptools installation + +If you want to use setuptools in your package's setup.py, just include this +file in the same directory with it, and add this to the top of your setup.py:: + + from ez_setup import use_setuptools + use_setuptools() + +If you want to require a specific version of setuptools, set a download +mirror, or use an alternate download directory, you can do so by supplying +the appropriate options to ``use_setuptools()``. + +This file can also be run as a script to install or upgrade setuptools. +""" +import os +import shutil +import sys +import tempfile +import tarfile +import optparse +import subprocess + +from distutils import log + +try: + from site import USER_SITE +except ImportError: + USER_SITE = None + +DEFAULT_VERSION = "0.9.6" +DEFAULT_URL = "https://pypi.python.org/packages/source/s/setuptools/" + +def _python_cmd(*args): + args = (sys.executable,) + args + return subprocess.call(args) == 0 + +def _install(tarball, install_args=()): + # extracting the tarball + tmpdir = tempfile.mkdtemp() + log.warn('Extracting in %s', tmpdir) + old_wd = os.getcwd() + try: + os.chdir(tmpdir) + tar = tarfile.open(tarball) + _extractall(tar) + tar.close() + + # going in the directory + subdir = os.path.join(tmpdir, os.listdir(tmpdir)[0]) + os.chdir(subdir) + log.warn('Now working in %s', subdir) + + # installing + log.warn('Installing Setuptools') + if not _python_cmd('setup.py', 'install', *install_args): + log.warn('Something went wrong during the installation.') + log.warn('See the error message above.') + # exitcode will be 2 + return 2 + finally: + os.chdir(old_wd) + shutil.rmtree(tmpdir) + + +def _build_egg(egg, tarball, to_dir): + # extracting the tarball + tmpdir = tempfile.mkdtemp() + log.warn('Extracting in %s', tmpdir) + old_wd = os.getcwd() + try: + os.chdir(tmpdir) + tar = tarfile.open(tarball) + _extractall(tar) + tar.close() + + # going in the directory + subdir = os.path.join(tmpdir, os.listdir(tmpdir)[0]) + os.chdir(subdir) + log.warn('Now working in %s', subdir) + + # building an egg + log.warn('Building a Setuptools egg in %s', to_dir) + _python_cmd('setup.py', '-q', 'bdist_egg', '--dist-dir', to_dir) + + finally: + os.chdir(old_wd) + shutil.rmtree(tmpdir) + # returning the result + log.warn(egg) + if not os.path.exists(egg): + raise IOError('Could not build the egg.') + + +def _do_download(version, download_base, to_dir, download_delay): + egg = os.path.join(to_dir, 'setuptools-%s-py%d.%d.egg' + % (version, sys.version_info[0], sys.version_info[1])) + if not os.path.exists(egg): + tarball = download_setuptools(version, download_base, + to_dir, download_delay) + _build_egg(egg, tarball, to_dir) + sys.path.insert(0, egg) + import setuptools + setuptools.bootstrap_install_from = egg + + +def use_setuptools(version=DEFAULT_VERSION, download_base=DEFAULT_URL, + to_dir=os.curdir, download_delay=15): + # making sure we use the absolute path + to_dir = os.path.abspath(to_dir) + 
was_imported = 'pkg_resources' in sys.modules or \ + 'setuptools' in sys.modules + try: + import pkg_resources + except ImportError: + return _do_download(version, download_base, to_dir, download_delay) + try: + pkg_resources.require("setuptools>=" + version) + return + except pkg_resources.VersionConflict: + e = sys.exc_info()[1] + if was_imported: + sys.stderr.write( + "The required version of setuptools (>=%s) is not available,\n" + "and can't be installed while this script is running. Please\n" + "install a more recent version first, using\n" + "'easy_install -U setuptools'." + "\n\n(Currently using %r)\n" % (version, e.args[0])) + sys.exit(2) + else: + del pkg_resources, sys.modules['pkg_resources'] # reload ok + return _do_download(version, download_base, to_dir, + download_delay) + except pkg_resources.DistributionNotFound: + return _do_download(version, download_base, to_dir, + download_delay) + + +def download_setuptools(version=DEFAULT_VERSION, download_base=DEFAULT_URL, + to_dir=os.curdir, delay=15): + """Download setuptools from a specified location and return its filename + + `version` should be a valid setuptools version number that is available + as an egg for download under the `download_base` URL (which should end + with a '/'). `to_dir` is the directory where the egg will be downloaded. + `delay` is the number of seconds to pause before an actual download + attempt. + """ + # making sure we use the absolute path + to_dir = os.path.abspath(to_dir) + try: + from urllib.request import urlopen + except ImportError: + from urllib2 import urlopen + tgz_name = "setuptools-%s.tar.gz" % version + url = download_base + tgz_name + saveto = os.path.join(to_dir, tgz_name) + src = dst = None + if not os.path.exists(saveto): # Avoid repeated downloads + try: + log.warn("Downloading %s", url) + src = urlopen(url) + # Read/write all in one block, so we don't create a corrupt file + # if the download is interrupted. + data = src.read() + dst = open(saveto, "wb") + dst.write(data) + finally: + if src: + src.close() + if dst: + dst.close() + return os.path.realpath(saveto) + + +def _extractall(self, path=".", members=None): + """Extract all members from the archive to the current working + directory and set owner, modification time and permissions on + directories afterwards. `path' specifies a different directory + to extract to. `members' is optional and must be a subset of the + list returned by getmembers(). + """ + import copy + import operator + from tarfile import ExtractError + directories = [] + + if members is None: + members = self + + for tarinfo in members: + if tarinfo.isdir(): + # Extract directories with a safe mode. + directories.append(tarinfo) + tarinfo = copy.copy(tarinfo) + tarinfo.mode = 448 # decimal for oct 0700 + self.extract(tarinfo, path) + + # Reverse sort directories. + if sys.version_info < (2, 4): + def sorter(dir1, dir2): + return cmp(dir1.name, dir2.name) + directories.sort(sorter) + directories.reverse() + else: + directories.sort(key=operator.attrgetter('name'), reverse=True) + + # Set correct owner, mtime and filemode on directories. 
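+    # Processing-order note: the reverse sort above puts the deepest paths
+    # first, so each directory's contents are finalized before the parent's
+    # own (possibly restrictive) permissions and mtime are restored.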
+ for tarinfo in directories: + dirpath = os.path.join(path, tarinfo.name) + try: + self.chown(tarinfo, dirpath) + self.utime(tarinfo, dirpath) + self.chmod(tarinfo, dirpath) + except ExtractError: + e = sys.exc_info()[1] + if self.errorlevel > 1: + raise + else: + self._dbg(1, "tarfile: %s" % e) + + +def _build_install_args(options): + """ + Build the arguments to 'python setup.py install' on the setuptools package + """ + install_args = [] + if options.user_install: + if sys.version_info < (2, 6): + log.warn("--user requires Python 2.6 or later") + raise SystemExit(1) + install_args.append('--user') + return install_args + +def _parse_args(): + """ + Parse the command line for options + """ + parser = optparse.OptionParser() + parser.add_option( + '--user', dest='user_install', action='store_true', default=False, + help='install in user site package (requires Python 2.6 or later)') + parser.add_option( + '--download-base', dest='download_base', metavar="URL", + default=DEFAULT_URL, + help='alternative URL from where to download the setuptools package') + options, args = parser.parse_args() + # positional arguments are ignored + return options + +def main(version=DEFAULT_VERSION): + """Install or upgrade setuptools and EasyInstall""" + options = _parse_args() + tarball = download_setuptools(download_base=options.download_base) + return _install(tarball, _build_install_args(options)) + +if __name__ == '__main__': + sys.exit(main()) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..861a9f5 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,5 @@ +[egg_info] +tag_build = +tag_date = 0 +tag_svn_revision = 0 + diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..df3aef2 --- /dev/null +++ b/setup.py @@ -0,0 +1,287 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function +import sys +import warnings + +if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests": + print("Running gevent tests") + from gevent.monkey import patch_all + patch_all() + +if __name__ == '__main__' and sys.argv[1] == "eventlet_nosetests": + print("Running eventlet tests") + from eventlet import monkey_patch + monkey_patch() + +import ez_setup +ez_setup.use_setuptools() + +from setuptools import setup +from distutils.command.build_ext import build_ext +from distutils.core import Extension +from distutils.errors import (CCompilerError, DistutilsPlatformError, + DistutilsExecError) +from distutils.cmd import Command + + +import os +import warnings + +try: + import subprocess + has_subprocess = True +except ImportError: + has_subprocess = False + +from cassandra import __version__ + +long_description = "" +with open("README.rst") as f: + long_description = f.read() + + +try: + from nose.commands import nosetests +except ImportError: + gevent_nosetests = None + eventlet_nosetests = None +else: + class gevent_nosetests(nosetests): + description = "run nosetests with gevent monkey patching" + + class eventlet_nosetests(nosetests): + description = "run nosetests with eventlet monkey patching" + +has_cqlengine = False +if __name__ == '__main__' and sys.argv[1] == "install": + try: + import cqlengine + has_cqlengine = True + except ImportError: + pass + + +class DocCommand(Command): + + description = "generate or test documentation" + + user_options = [("test", "t", + "run doctests instead of generating documentation")] + + boolean_options = ["test"] + + def initialize_options(self): + self.test = False + + def finalize_options(self): + pass + + def run(self): + if self.test: + path = "docs/_build/doctest" + mode = "doctest" + else: + path = "docs/_build/%s" % __version__ + mode = "html" + + try: + os.makedirs(path) + except: + pass + + if has_subprocess: + try: + output = subprocess.check_output( + ["sphinx-build", "-b", mode, "docs", path], + stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as exc: + raise RuntimeError("Documentation step '%s' failed: %s: %s" % (mode, exc, exc.output)) + else: + print(output) + + print("") + print("Documentation step '%s' performed, results here:" % mode) + print(" file://%s/%s/index.html" % (os.path.dirname(os.path.realpath(__file__)), path)) + + +class BuildFailed(Exception): + + def __init__(self, ext): + self.ext = ext + + +murmur3_ext = Extension('cassandra.murmur3', + sources=['cassandra/murmur3.c']) + +libev_ext = Extension('cassandra.io.libevwrapper', + sources=['cassandra/io/libevwrapper.c'], + include_dirs=['/usr/include/libev', '/usr/local/include', '/opt/local/include'], + libraries=['ev'], + library_dirs=['/usr/local/lib', '/opt/local/lib']) + + +class build_extensions(build_ext): + + error_message = """ +=============================================================================== +WARNING: could not compile %s. + +The C extensions are not required for the driver to run, but they add support +for libev and token-aware routing with the Murmur3Partitioner. + +Linux users should ensure that GCC and the Python headers are available. + +On Ubuntu and Debian, this can be accomplished by running: + + $ sudo apt-get install build-essential python-dev + +On RedHat and RedHat-based systems like CentOS and Fedora: + + $ sudo yum install gcc python-devel + +On OSX, homebrew installations of Python should provide the necessary headers. 
+ +libev Support +------------- +For libev support, you will also need to install libev and its headers. + +On Debian/Ubuntu: + + $ sudo apt-get install libev4 libev-dev + +On RHEL/CentOS/Fedora: + + $ sudo yum install libev libev-devel + +On OSX, via homebrew: + + $ brew install libev + +=============================================================================== + """ + + def run(self): + try: + build_ext.run(self) + except DistutilsPlatformError as exc: + sys.stderr.write('%s\n' % str(exc)) + warnings.warn(self.error_message % "C extensions.") + + def build_extension(self, ext): + try: + build_ext.build_extension(self, ext) + except (CCompilerError, DistutilsExecError, + DistutilsPlatformError, IOError) as exc: + sys.stderr.write('%s\n' % str(exc)) + name = "The %s extension" % (ext.name,) + warnings.warn(self.error_message % (name,)) + raise BuildFailed(ext) + + +def run_setup(extensions): + + kw = {'cmdclass': {'doc': DocCommand}} + if gevent_nosetests is not None: + kw['cmdclass']['gevent_nosetests'] = gevent_nosetests + + if eventlet_nosetests is not None: + kw['cmdclass']['eventlet_nosetests'] = eventlet_nosetests + + if extensions: + kw['cmdclass']['build_ext'] = build_extensions + kw['ext_modules'] = extensions + + dependencies = ['futures', 'six >=1.6'] + + setup( + name='cassandra-driver', + version=__version__, + description='Python driver for Cassandra', + long_description=long_description, + url='http://github.com/datastax/python-driver', + author='Tyler Hobbs', + author_email='tyler@datastax.com', + packages=['cassandra', 'cassandra.io', 'cassandra.cqlengine'], + keywords='cassandra,cql,orm', + include_package_data=True, + install_requires=dependencies, + tests_require=['nose', 'mock', 'PyYAML', 'pytz', 'sure'], + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: Apache Software License', + 'Natural Language :: English', + 'Operating System :: OS Independent', + 'Programming Language :: Python', + 'Programming Language :: Python :: 2.6', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3.3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: Implementation :: CPython', + 'Programming Language :: Python :: Implementation :: PyPy', + 'Topic :: Software Development :: Libraries :: Python Modules' + ], + **kw) + +extensions = [murmur3_ext, libev_ext] +if "--no-extensions" in sys.argv: + sys.argv = [a for a in sys.argv if a != "--no-extensions"] + extensions = [] +elif "--no-murmur3" in sys.argv: + sys.argv = [a for a in sys.argv if a != "--no-murmur3"] + extensions.remove(murmur3_ext) +elif "--no-libev" in sys.argv: + sys.argv = [a for a in sys.argv if a != "--no-libev"] + extensions.remove(libev_ext) + + +platform_unsupported_msg = \ +""" +=============================================================================== +The optional C extensions are not supported on this platform. +=============================================================================== +""" + +arch_unsupported_msg = \ +""" +=============================================================================== +The optional C extensions are not supported on big-endian systems. 
+=============================================================================== +""" + +if extensions: + if (sys.platform.startswith("java") + or sys.platform == "cli" + or "PyPy" in sys.version): + sys.stderr.write(platform_unsupported_msg) + extensions = () + elif sys.byteorder == "big": + sys.stderr.write(arch_unsupported_msg) + extensions = () + +while True: + # try to build as many of the extensions as we can + try: + run_setup(extensions) + except BuildFailed as failure: + extensions.remove(failure.ext) + else: + break + +if has_cqlengine: + warnings.warn("\n#######\n'cqlengine' package is present on path: %s\n" + "cqlengine is now an integrated sub-package of this driver.\n" + "It is recommended to remove this package to reduce the chance for conflicting usage" % cqlengine.__file__)
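+
+# Usage note (illustrative): the optional C extensions handled above can be
+# skipped at install time with the flags this script strips from sys.argv,
+# for example
+#
+#     python setup.py install --no-extensions
+#
+# or selectively with --no-murmur3 / --no-libev; the loop above retries
+# run_setup() without any extension whose build raised BuildFailed.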