diff --git a/.travis.yml b/.travis.yml
index f1fff4b..b485e21 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,32 +1,33 @@
dist: xenial
sudo: false
language: python
python:
- "2.7"
- "3.5"
- "3.6"
- "3.7"
- "pypy2.7-6.0"
- "pypy3.5"
env:
- CASS_DRIVER_NO_CYTHON=1
addons:
apt:
packages:
- build-essential
- python-dev
- pypy-dev
- libc-ares-dev
- libev4
- libev-dev
install:
- - pip install tox-travis lz4
+ - pip install tox-travis
+ - if [[ $TRAVIS_PYTHON_VERSION != pypy3.5 ]]; then pip install lz4; fi
script:
- tox
- tox -e gevent_loop
- tox -e eventlet_loop
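The install step above now skips lz4 on the pypy3.5 job; the driver treats lz4
compression as optional and only prefers it "if available" (see the 1.1.2 notes
below), so omitting the package simply disables that codec. A rough, illustrative
sketch of such an optional-import guard (not the driver's actual code)::

    # Illustrative only: guard an optional lz4 dependency.
    try:
        import lz4.block as lz4_block
    except ImportError:
        lz4_block = None

    def maybe_compress(payload: bytes) -> bytes:
        # Fall back to the raw payload when lz4 is not installed.
        return payload if lz4_block is None else lz4_block.compress(payload)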
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 0ac2aeb..d2d577c 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,1494 +1,1678 @@
+3.25.0
+======
+March 18, 2021
+
+Features
+--------
+* Ensure the driver can connect when invalid peer hosts are in system.peers (PYTHON-1260)
+* Implement protocol v5 checksumming (PYTHON-1258)
+* Fix the default cqlengine connection mechanism to work with Astra (PYTHON-1265)
+
+Bug Fixes
+---------
+* Asyncore race condition causes logging exception on shutdown (PYTHON-1266)
+* Update list of reserved keywords (PYTHON-1269)
+
+Others
+------
+* Drop Python 3.4 support (PYTHON-1220)
+* Update security documentation and examples to use PROTOCOL_TLS (PYTHON-1264)
+
+3.24.0
+======
+June 18, 2020
+
+Features
+--------
+* Make geomet an optional dependency at runtime (PYTHON-1237)
+* Add use_default_tempdir cloud config options (PYTHON-1245)
+* TCP flow control for libevreactor (PYTHON-1248)
+
+Bug Fixes
+---------
+* Unable to connect to a cloud cluster using Ubuntu 20.04 (PYTHON-1238)
+* PlainTextAuthProvider fails with unicode chars and Python3 (PYTHON-1241)
+* [GRAPH] Graph execution profile consistency levels are not set to LOCAL_QUORUM with a cloud cluster (PYTHON-1240)
+* [GRAPH] Can't write data in a Boolean field using the Fluent API (PYTHON-1239)
+* [GRAPH] Fix elementMap() result deserialization (PYTHON-1233)
+
+Others
+------
+* Bump geomet dependency version to 0.2 (PYTHON-1243)
+* Bump gremlinpython dependency version to 3.4.6 (PYTHON-1212)
+* Improve fluent graph documentation for core graphs (PYTHON-1244)
+
+3.23.0
+======
+April 6, 2020
+
+Features
+--------
+* Transient Replication Support (PYTHON-1207)
+* Support system.peers_v2 and port discovery for C* 4.0 (PYTHON-700)
+
+Bug Fixes
+---------
+* Asyncore logging exception on shutdown (PYTHON-1228)
+
+3.22.0
+======
+February 26, 2020
+
+Features
+--------
+
+* Add all() function to the ResultSet API (PYTHON-1203)
+* Parse new schema metadata in NGDG and generate table edges CQL syntax (PYTHON-996)
+* Add GraphSON3 support (PYTHON-788)
+* Use GraphSON3 as default for Native graphs (PYTHON-1004)
+* Add Tuple and UDT types for native graph (PYTHON-1005)
+* Add Duration type for native graph (PYTHON-1000)
+* Add gx:ByteBuffer graphson type support for Blob field (PYTHON-1027)
+* Enable Paging Through DSE Driver for Gremlin Traversals (PYTHON-1045)
+* Provide numerical wrappers to ensure proper graphson schema definition (PYTHON-1051)
+* Resolve the row_factory automatically for native graphs (PYTHON-1056)
+* Add g:TraversalMetrics/g:Metrics graph deserializers (PYTHON-1057)
+* Add g:BulkSet graph deserializers (PYTHON-1060)
+* Update Graph Engine names and the way to create a Classic/Native Graph (PYTHON-1090)
+* Update Native to Core Graph Engine
+* Add graphson3 and native graph support (PYTHON-1039)
+* Expose filter predicates for cql collections (PYTHON-1019)
+* Make graph metadata handling more robust (PYTHON-1204)
+
+Bug Fixes
+---------
+* Make sure to only query the native_transport_address column with DSE (PYTHON-1205)
+
+3.21.0
+======
+January 15, 2020
+
+Features
+--------
+* Unified driver: merge core and DSE drivers into a single package (PYTHON-1130)
+* Add Python 3.8 support (PYTHON-1189)
+* Allow passing ssl context for Twisted (PYTHON-1161)
+* SSL context and cloud support for Eventlet (PYTHON-1162)
+* Cloud Twisted support (PYTHON-1163)
+* Add additional_write_policy and read_repair to system schema parsing (PYTHON-1048)
+* Flexible version parsing (PYTHON-1174)
+* Support NULL in collection deserializer (PYTHON-1123)
+* [GRAPH] Ability to execute Fluent Graph queries asynchronously (PYTHON-1129)
+
+Bug Fixes
+---------
+* Handle prepared id mismatch when repreparing on the fly (PYTHON-1124)
+* Re-raising the CQLEngineException will fail on Python 3 (PYTHON-1166)
+* asyncio message chunks can be processed discontinuously (PYTHON-1185)
+* Reconnect attempts persist after downed node removed from peers (PYTHON-1181)
+* Connection fails to validate ssl certificate hostname when SSLContext.check_hostname is set (PYTHON-1186)
+* ResponseFuture._set_result crashes on connection error when used with PrepareMessage (PYTHON-1187)
+* Insights fail to serialize the startup message when the SSL Context is from PyOpenSSL (PYTHON-1192)
+
+Others
+------
+* The driver has a new dependency: geomet. It comes from the dse-driver unification and
+ is used to support DSE geo types.
+* Remove *read_repair_chance table options (PYTHON-1140)
+* Avoid warnings about unspecified load balancing policy when connecting to a cloud cluster (PYTHON-1177)
+* Add new DSE CQL keywords (PYTHON-1122)
+* Publish binary wheel distributions (PYTHON-1013)
+
+Deprecations
+------------
+
+* DSELoadBalancingPolicy will be removed in the next major release; consider
+  using the DefaultLoadBalancingPolicy.
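+A minimal configuration sketch of setting a load balancing policy through an
+ExecutionProfile, as the deprecation note above suggests; the policy composition
+and contact point shown here are illustrative, not the definition of
+DefaultLoadBalancingPolicy::
+
+    from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
+    from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy
+
+    # Token-aware, DC-aware routing wired into the default execution profile.
+    lbp = TokenAwarePolicy(DCAwareRoundRobinPolicy(local_dc="dc1"))  # "dc1" is a placeholder
+    profile = ExecutionProfile(load_balancing_policy=lbp)
+    cluster = Cluster(["127.0.0.1"], execution_profiles={EXEC_PROFILE_DEFAULT: profile})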
+
+Merged from dse-driver:
+
+Features
+--------
+
+* Insights integration (PYTHON-1047)
+* Graph execution profiles should preserve their graph_source when graph_options is overridden (PYTHON-1021)
+* Add NodeSync metadata (PYTHON-799)
+* Add new NodeSync failure values (PYTHON-934)
+* DETERMINISTIC and MONOTONIC Clauses for Functions and Aggregates (PYTHON-955)
+* GraphOptions should show a warning for unknown parameters (PYTHON-819)
+* DSE protocol version 2 and continuous paging backpressure (PYTHON-798)
+* GraphSON2 Serialization/Deserialization Support (PYTHON-775)
+* Add graph-results payload option for GraphSON format (PYTHON-773)
+* Create an AuthProvider for the DSE transitional mode (PYTHON-831)
+* Implement serializers for the Graph String API (PYTHON-778)
+* Provide deserializers for GraphSON types (PYTHON-782)
+* Add Graph DurationType support (PYTHON-607)
+* Support DSE DateRange type (PYTHON-668)
+* RLAC CQL output for materialized views (PYTHON-682)
+* Add Geom Types WKT deserializer
+* DSE Graph Client timeouts in custom payload (PYTHON-589)
+* Make DSEGSSAPIAuthProvider accept principal name (PYTHON-574)
+* Add config profiles to DSE graph execution (PYTHON-570)
+* DSE Driver version checking (PYTHON-568)
+* Distinct default timeout for graph queries (PYTHON-477)
+* Graph result parsing for known types (PYTHON-479,487)
+* Distinct read/write CL for graph execution (PYTHON-509)
+* Target graph analytics query to spark master when available (PYTHON-510)
+
+Bug Fixes
+---------
+
+* Continuous paging sessions raise RuntimeError when results are not entirely consumed (PYTHON-1054)
+* GraphSON Property deserializer should return a dict instead of a set (PYTHON-1033)
+* ResponseFuture.has_more_pages may hold the wrong value (PYTHON-946)
+* DETERMINISTIC clause in AGGREGATE misplaced in CQL generation (PYTHON-963)
+* graph module import causes a DLL issue on Windows due to a cythonization failure (PYTHON-900)
+* Update date serialization to isoformat in graph (PYTHON-805)
+* DateRange Parse Error (PYTHON-729)
+* MonotonicTimestampGenerator.__init__ ignores class defaults (PYTHON-728)
+* metadata.get_host returning None unexpectedly (PYTHON-709)
+* Sockets associated with sessions not getting cleaned up on session.shutdown() (PYTHON-673)
+* Resolve FQDN from ip address and use that as host passed to SASLClient (PYTHON-566)
+* Geospatial type implementations don't handle 'EMPTY' values. (PYTHON-481)
+* Correctly handle other types in geo type equality (PYTHON-508)
+
+Other
+-----
+* Add tests around cqlengine and continuous paging (PYTHON-872)
+* Add an abstract GraphStatement to handle different graph statements (PYTHON-789)
+* Write documentation examples for DSE 2.0 features (PYTHON-732)
+* DSE_V1 protocol should not include all of protocol v5 (PYTHON-694)
+
3.20.2
======
November 19, 2019
Bug Fixes
---------
* Fix import error for old python installation without SSLContext (PYTHON-1183)
3.20.1
======
November 6, 2019
Bug Fixes
---------
* "ValueError: too many values to unpack (expected 2)" when there are two dashes in the server version number (PYTHON-1172)
3.20.0
======
October 28, 2019
Features
--------
-* DataStax Apollo Support (PYTHON-1074)
+* DataStax Astra Support (PYTHON-1074)
* Use 4.0 schema parser in 4 alpha and snapshot builds (PYTHON-1158)
Bug Fixes
---------
* Connection setup methods prevent using ExecutionProfile in cqlengine (PYTHON-1009)
* Driver deadlock if all connections dropped by heartbeat whilst request in flight and request times out (PYTHON-1044)
* Exception when using pk__token__gt filter in Python 3.7 (PYTHON-1121)
3.19.0
======
August 26, 2019
Features
--------
* Add Python 3.7 support (PYTHON-1016)
* Future-proof Mapping imports (PYTHON-1023)
* Include param values in cqlengine logging (PYTHON-1105)
* NTS Token Replica Map Generation is slow (PYTHON-622)
Bug Fixes
---------
* as_cql_query UDF/UDA parameters incorrectly includes "frozen" if arguments are collections (PYTHON-1031)
* cqlengine does not currently support combining TTL and TIMESTAMP on INSERT (PYTHON-1093)
* Fix incorrect metadata for compact counter tables (PYTHON-1100)
* Call ConnectionException with correct kwargs (PYTHON-1117)
* Can't connect to clusters built from source because version parsing doesn't handle 'x.y-SNAPSHOT' (PYTHON-1118)
* Discovered node doesn't honor the configured Cluster port on connection (PYTHON-1127)
Other
-----
* Remove invalid warning in set_session when we initialize a default connection (PYTHON-1104)
* Set the proper default ExecutionProfile.row_factory value (PYTHON-1119)
3.18.0
======
May 27, 2019
Features
--------
* Abstract Host Connection information (PYTHON-1079)
* Improve version parsing to support a non-integer 4th component (PYTHON-1091)
* Expose on_request_error method in the RetryPolicy (PYTHON-1064)
* Add jitter to ExponentialReconnectionPolicy (PYTHON-1065)
Bug Fixes
---------
* Fix error when preparing queries with beta protocol v5 (PYTHON-1081)
* Accept legacy empty strings as column names (PYTHON-1082)
* Let util.SortedSet handle uncomparable elements (PYTHON-1087)
3.17.1
======
May 2, 2019
Bug Fixes
---------
* Socket errors EAGAIN/EWOULDBLOCK are not handled properly and cause timeouts (PYTHON-1089)
3.17.0
======
February 19, 2019
Features
--------
* Send driver name and version in startup message (PYTHON-1068)
* Add Cluster ssl_context option to enable SSL (PYTHON-995)
* Allow encrypted private keys for 2-way SSL cluster connections (PYTHON-995)
* Introduce new method ConsistencyLevel.is_serial (PYTHON-1067)
* Add Session.get_execution_profile (PYTHON-932)
* Add host kwarg to Session.execute/execute_async APIs to send a query to a specific node (PYTHON-993)
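A small illustration of the new host keyword argument (PYTHON-993); it assumes an
existing cluster and session, and the chosen host and query are arbitrary examples::

    # Pin a query to a specific node chosen from the cluster metadata.
    host = cluster.metadata.all_hosts()[0]
    row = session.execute("SELECT release_version FROM system.local", host=host).one()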
Bug Fixes
---------
* NoHostAvailable when all hosts are up and connectable (PYTHON-891)
* Serial consistency level is not used (PYTHON-1007)
Other
-----
* Fail faster on incorrect lz4 import (PYTHON-1042)
* Bump Cython dependency version to 0.29 (PYTHON-1036)
* Expand Driver SSL Documentation (PYTHON-740)
Deprecations
------------
* Using Cluster.ssl_options to enable SSL is deprecated and will be removed in
  the next major release; use ssl_context instead (see the sketch after this list).
* DowngradingConsistencyRetryPolicy is deprecated and will be
removed in the next major release. (PYTHON-937)
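A quick sketch of the ssl_context replacement mentioned in the first deprecation
above; the contact point and certificate path are placeholders, and verification
settings should be adapted to the deployment::

    import ssl
    from cassandra.cluster import Cluster

    # Build a TLS context; the CA bundle path is a placeholder.
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS)
    ssl_context.verify_mode = ssl.CERT_REQUIRED
    ssl_context.load_verify_locations("/path/to/rootca.crt")

    # Pass the context instead of the deprecated ssl_options dict.
    cluster = Cluster(["127.0.0.1"], ssl_context=ssl_context)
    session = cluster.connect()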
3.16.0
======
November 12, 2018
Bug Fixes
---------
* Improve and fix socket error-catching code in nonblocking-socket reactors (PYTHON-1024)
* Non-ASCII characters in schema break CQL string generation (PYTHON-1008)
* Fix OSS driver's virtual table support against DSE 6.0.X and future server releases (PYTHON-1020)
* ResultSet.one() fails if the row_factory is using a generator (PYTHON-1026)
* Log profile name on attempt to create existing profile (PYTHON-944)
* Cluster instantiation fails if any contact points' hostname resolution fails (PYTHON-895)
Other
-----
* Fix tests when RF is not maintained if we decommission a node (PYTHON-1017)
* Fix wrong use of ResultSet indexing (PYTHON-1015)
3.15.1
======
September 6, 2018
Bug Fixes
---------
* C* 4.0 schema-parsing logic breaks running against DSE 6.0.X (PYTHON-1018)
3.15.0
======
August 30, 2018
Features
--------
* Parse Virtual Keyspace Metadata (PYTHON-992)
Bug Fixes
---------
* Tokenmap.get_replicas returns the wrong value if token coincides with the end of the range (PYTHON-978)
* Python Driver fails with "more than 255 arguments" python exception when > 255 columns specified in query response (PYTHON-893)
* Hang in integration.standard.test_cluster.ClusterTests.test_set_keyspace_twice (PYTHON-998)
* Asyncore reactors should use a global variable instead of a class variable for the event loop (PYTHON-697)
Other
-----
* Use global variable for libev loops so it can be subclassed (PYTHON-973)
* Update SchemaParser for V4 (PYTHON-1006)
* Bump Cython dependency version to 0.28 (PYTHON-1012)
3.14.0
======
April 17, 2018
Features
--------
* Add one() function to the ResultSet API (PYTHON-947)
* Create a utility function to concurrently fetch many keys from the same replica (PYTHON-647)
* Allow filter queries with fields that have an index managed outside of cqlengine (PYTHON-966)
* Twisted SSL Support (PYTHON-343)
* Support IS NOT NULL operator in cqlengine (PYTHON-968)
Other
-----
* Fix Broken Links in Docs (PYTHON-916)
* Reevaluate MONKEY_PATCH_LOOP in test codebase (PYTHON-903)
* Remove CASS_SERVER_VERSION and replace it for CASSANDRA_VERSION in tests (PYTHON-910)
* Refactor CASSANDRA_VERSION to some kind of version object (PYTHON-915)
* Log warning when driver configures an authenticator, but server does not request authentication (PYTHON-940)
* Warn users when using the deprecated Session.default_consistency_level (PYTHON-953)
* Add DSE smoke test to OSS driver tests (PYTHON-894)
* Document long compilation times and workarounds (PYTHON-868)
* Improve error for batch WriteTimeouts (PYTHON-941)
* Deprecate ResultSet indexing (PYTHON-945)
3.13.0
======
January 30, 2018
Features
--------
* cqlengine: LIKE filter operator (PYTHON-512)
* Support cassandra.query.BatchType with cqlengine BatchQuery (PYTHON-888)
Bug Fixes
---------
* AttributeError: 'NoneType' object has no attribute 'add_timer' (PYTHON-862)
* Support retry_policy in PreparedStatement (PYTHON-861)
* __del__ method in Session is throwing an exception (PYTHON-813)
* LZ4 import issue with recent versions (PYTHON-897)
* ResponseFuture._connection can be None when returning request_id (PYTHON-853)
* ResultSet.was_applied doesn't support batch with LWT statements (PYTHON-848)
Other
-----
* cqlengine: avoid warning when unregistering connection on shutdown (PYTHON-865)
* Fix DeprecationWarning of log.warn (PYTHON-846)
* Fix example_mapper.py for python3 (PYTHON-860)
* Possible deadlock on cassandra.concurrent.execute_concurrent (PYTHON-768)
* Add some known deprecated warnings for 4.x (PYTHON-877)
* Remove copyright dates from copyright notices (PYTHON-863)
* Remove "Experimental" tag from execution profiles documentation (PYTHON-840)
* request_timer metrics descriptions are slightly incorrect (PYTHON-885)
* Remove "Experimental" tag from cqlengine connections documentation (PYTHON-892)
* Document that the default consistency level for operations is LOCAL_ONE (PYTHON-901)
3.12.0
======
November 6, 2017
Features
--------
* Send keyspace in QUERY, PREPARE, and BATCH messages (PYTHON-678)
* Add IPv4Address/IPv6Address support for inet types (PYTHON-751)
* WriteType.CDC and VIEW missing (PYTHON-794)
* Warn on Cluster init if contact points are specified but LBP isn't (legacy mode) (PYTHON-812)
* Warn on Cluster init if contact points are specified but LBP isn't (execution profile mode) (PYTHON-838)
* Include hash of result set metadata in prepared stmt id (PYTHON-808)
* Add NO_COMPACT startup option (PYTHON-839)
* Add new exception type for CDC (PYTHON-837)
* Allow 0ms in ConstantSpeculativeExecutionPolicy (PYTHON-836)
* Add asyncio reactor (PYTHON-507)
Bug Fixes
---------
* Both _set_final_exception/result called for the same ResponseFuture (PYTHON-630)
* Use of DCAwareRoundRobinPolicy raises NoHostAvailable exception (PYTHON-781)
* Don't create two sessions by default in CQLEngine (PYTHON-814)
* Bug when subclassing AsyncoreConnection (PYTHON-827)
* Error at cleanup when closing the asyncore connections (PYTHON-829)
* Fix sites where `sessions` can change during iteration (PYTHON-793)
* cqlengine: allow min_length=0 for Ascii and Text column types (PYTHON-735)
* Rare exception when calling "sys.exit(0)" after query timeouts (PYTHON-752)
* Don't set the session keyspace when preparing statements (PYTHON-843)
Other
------
* Remove DeprecationWarning when using WhiteListRoundRobinPolicy (PYTHON-810)
* Bump Cython dependency version to 0.27 (PYTHON-833)
3.11.0
======
July 24, 2017
Features
--------
* Add idle_heartbeat_timeout cluster option to tune how long to wait for heartbeat responses. (PYTHON-762)
* Add HostFilterPolicy (PYTHON-761)
Bug Fixes
---------
* is_idempotent flag is not propagated from PreparedStatement to BoundStatement (PYTHON-736)
* Fix asyncore hang on exit (PYTHON-767)
* Driver takes several minutes to remove a bad host from session (PYTHON-762)
* Installation doesn't always fall back to the no-Cython build on Windows (PYTHON-763)
* Avoid replacing a connection that is supposed to shut down (PYTHON-772)
* request_ids may not be returned to the pool (PYTHON-739)
* Fix murmur3 on big-endian systems (PYTHON-653)
* Ensure unused connections are closed if a Session is deleted by the GC (PYTHON-774)
* Fix .values_list by using db names internally (cqlengine) (PYTHON-785)
Other
-----
* Bump Cython dependency version to 0.25.2 (PYTHON-754)
* Fix DeprecationWarning when using lz4 (PYTHON-769)
* Deprecate WhiteListRoundRobinPolicy (PYTHON-759)
* Improve upgrade guide for materializing pages (PYTHON-464)
* Documentation for time/date specifies timestamp input as microseconds (PYTHON-717)
* Point to DSA Slack, not IRC, in docs index
3.10.0
======
May 24, 2017
Features
--------
* Add Duration type to cqlengine (PYTHON-750)
* Community PR review: Raise error on primary key update only if its value changed (PYTHON-705)
* get_query_trace() contract is ambiguous (PYTHON-196)
Bug Fixes
---------
* Queries using speculative execution policy timeout prematurely (PYTHON-755)
* Fix `map` where results are not consumed (PYTHON-749)
* Driver fails to encode Duration's with large values (PYTHON-747)
* UDT values are not updated correctly in CQLEngine (PYTHON-743)
* UDT types are not validated in CQLEngine (PYTHON-742)
* to_python is not implemented for types columns.Type and columns.Date in CQLEngine (PYTHON-741)
* Clients spin infinitely trying to connect to a host that is drained (PYTHON-734)
* ResultSet.get_query_trace returns empty trace sometimes (PYTHON-730)
* Memory grows and doesn't get removed (PYTHON-720)
* Fix RuntimeError caused by changing dict size during iteration (PYTHON-708)
* Fix possible OverflowError in ExponentialReconnectionPolicy (PYTHON-707)
* Avoid using nonexistent prepared statement in ResponseFuture (PYTHON-706)
Other
-----
* Update README (PYTHON-746)
* Test python versions 3.5 and 3.6 (PYTHON-737)
* Docs Warning About Prepare "select *" (PYTHON-626)
* Increase Coverage in CqlEngine Test Suite (PYTHON-505)
* Example SSL connection code does not verify server certificates (PYTHON-469)
3.9.0
=====
Features
--------
* cqlengine: remove elements by key from a map (PYTHON-688)
Bug Fixes
---------
* improve error handling when connecting to non-existent keyspace (PYTHON-665)
* Sockets associated with sessions not getting cleaned up on session.shutdown() (PYTHON-673)
* rare flake on integration.standard.test_cluster.ClusterTests.test_clone_shared_lbp (PYTHON-727)
* MonotonicTimestampGenerator.__init__ ignores class defaults (PYTHON-728)
* race where callback or errback for request may not be called (PYTHON-733)
* cqlengine: model.update() should not update columns with a default value that hasn't changed (PYTHON-657)
* cqlengine: field value manager's explicit flag is True when queried back from cassandra (PYTHON-719)
Other
-----
* Connection not closed in example_mapper (PYTHON-723)
* Remove mention of pre-2.0 C* versions from OSS 3.0+ docs (PYTHON-710)
3.8.1
=====
March 16, 2017
Bug Fixes
---------
* implement __le__/__ge__/__ne__ on some custom types (PYTHON-714)
* Fix bug in eventlet and gevent reactors that could cause hangs (PYTHON-721)
* Fix DecimalType regression (PYTHON-724)
3.8.0
=====
Features
--------
* Quote index names in metadata CQL generation (PYTHON-616)
* On column deserialization failure, keep error message consistent between python and cython (PYTHON-631)
* TokenAwarePolicy always sends requests to the same replica for a given key (PYTHON-643)
* Added cql types to result set (PYTHON-648)
* Add __len__ to BatchStatement (PYTHON-650)
* Duration Type for Cassandra (PYTHON-655)
* Send flags with PREPARE message in v5 (PYTHON-684)
Bug Fixes
---------
* Potential Timing issue if application exits prior to session pool initialization (PYTHON-636)
* "Host X.X.X.X has been marked down" without any exceptions (PYTHON-640)
* NoHostAvailable or OperationTimedOut when using execute_concurrent with a generator that inserts into more than one table (PYTHON-642)
* ResponseFuture creates Timers and doesn't cancel them even when the result is received, which leads to memory leaks (PYTHON-644)
* Driver cannot connect to Cassandra version > 3 (PYTHON-646)
* Unable to import model using UserType without setting up a connection since 3.7 (PYTHON-649)
* Don't prepare queries on ignored hosts on_up (PYTHON-669)
* Sockets associated with sessions not getting cleaned up on session.shutdown() (PYTHON-673)
* Make client timestamps strictly monotonic (PYTHON-676)
* cassandra.cqlengine.connection.register_connection broken when hosts=None (PYTHON-692)
Other
-----
* Create a cqlengine doc section explaining None semantics (PYTHON-623)
* Resolve warnings in documentation generation (PYTHON-645)
* Cython dependency (PYTHON-686)
* Drop Support for Python 2.6 (PYTHON-690)
3.7.1
=====
October 26, 2016
Bug Fixes
---------
* Cython upgrade has broken stable version of cassandra-driver (PYTHON-656)
3.7.0
=====
September 13, 2016
Features
--------
* Add v5 protocol failure map (PYTHON-619)
* Don't return from initial connect on first error (PYTHON-617)
* Indicate failed column when deserialization fails (PYTHON-361)
* Let Cluster.refresh_nodes force a token map rebuild (PYTHON-349)
* Refresh UDTs after "keyspace updated" event with v1/v2 protocol (PYTHON-106)
* EC2 Address Resolver (PYTHON-198)
* Speculative query retries (PYTHON-218)
* Expose paging state in API (PYTHON-200)
* Don't mark host down while one connection is active (PYTHON-498)
* Query request size information (PYTHON-284)
* Avoid quadratic ring processing with invalid replication factors (PYTHON-379)
* Improve Connection/Pool creation concurrency on startup (PYTHON-82)
* Add beta version native protocol flag (PYTHON-614)
* cqlengine: Connections: support of multiple keyspaces and sessions (PYTHON-613)
Bug Fixes
---------
* Race when adding a pool while setting keyspace (PYTHON-628)
* Update results_metadata when prepared statement is reprepared (PYTHON-621)
* CQL Export for Thrift Tables (PYTHON-213)
* cqlengine: default value not applied to UserDefinedType (PYTHON-606)
* cqlengine: columns are no longer hashable (PYTHON-618)
* cqlengine: remove clustering keys from where clause when deleting only static columns (PYTHON-608)
3.6.0
=====
August 1, 2016
Features
--------
* Handle null values in NumpyProtocolHandler (PYTHON-553)
* Collect greplin scales stats per cluster (PYTHON-561)
* Update mock unit test dependency requirement (PYTHON-591)
* Handle Missing CompositeType metadata following C* upgrade (PYTHON-562)
* Improve Host.is_up state for HostDistance.IGNORED hosts (PYTHON-551)
* Utilize v2 protocol's ability to skip result set metadata for prepared statement execution (PYTHON-71)
* Return from Cluster.connect() when first contact point connection(pool) is opened (PYTHON-105)
* cqlengine: Add ContextQuery to allow cqlengine models to switch the keyspace context easily (PYTHON-598)
* Standardize Validation between Ascii and Text types in Cqlengine (PYTHON-609)
Bug Fixes
---------
* Fix geventreactor with SSL support (PYTHON-600)
* Don't downgrade protocol version if explicitly set (PYTHON-537)
* Nonexistent contact point tries to connect indefinitely (PYTHON-549)
* Execute_concurrent can exceed max recursion depth in failure mode (PYTHON-585)
* Libev loop shutdown race (PYTHON-578)
* Include aliases in DCT type string (PYTHON-579)
* cqlengine: Comparison operators for Columns (PYTHON-595)
* cqlengine: disentangle default_time_to_live table option from model query default TTL (PYTHON-538)
* cqlengine: pk__token column name issue with the equality operator (PYTHON-584)
* cqlengine: Fix "__in" filtering operator converts True to string "True" automatically (PYTHON-596)
* cqlengine: Avoid LWTExceptions when updating columns that are part of the condition (PYTHON-580)
* cqlengine: Cannot execute a query when the filter contains all columns (PYTHON-599)
* cqlengine: routing key computation issue when a primary key column is overridden by model inheritance (PYTHON-576)
3.5.0
=====
June 27, 2016
Features
--------
* Optional Execution Profiles for the core driver (PYTHON-569)
* API to get the host metadata associated with the control connection node (PYTHON-583)
* Expose CDC option in table metadata CQL (PYTHON-593)
Bug Fixes
---------
* Clean up Asyncore socket map when fork is detected (PYTHON-577)
* cqlengine: QuerySet only() is not respected when there are deferred fields (PYTHON-560)
3.4.1
=====
May 26, 2016
Bug Fixes
---------
* Gevent connection closes on IO timeout (PYTHON-573)
* "dictionary changed size during iteration" with Python 3 (PYTHON-572)
3.4.0
=====
May 24, 2016
Features
--------
* Include DSE version and workload in Host data (PYTHON-555)
* Add a context manager to Cluster and Session (PYTHON-521); a usage sketch follows this list
* Better Error Message for Unsupported Protocol Version (PYTHON-157)
* Make the error message explicitly state when an error comes from the server (PYTHON-412)
* Short Circuit meta refresh on topo change if NEW_NODE already exists (PYTHON-557)
* Show warning when the wrong config is passed to SimpleStatement (PYTHON-219)
* Return namedtuple result pairs from execute_concurrent (PYTHON-362)
* BatchStatement should enforce batch size limit in a better way (PYTHON-151)
* Validate min/max request thresholds for connection pool scaling (PYTHON-220)
* Handle or warn about multiple hosts with the same rpc_address (PYTHON-365)
* Write docs around working with datetime and timezones (PYTHON-394)
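As referenced above for PYTHON-521, a usage sketch of the new context-manager
support; the contact point and query are placeholders::

    from cassandra.cluster import Cluster

    # Both Cluster and Session now work as context managers; shutdown happens
    # automatically when the block exits.
    with Cluster(["127.0.0.1"]) as cluster:
        with cluster.connect() as session:
            session.execute("SELECT release_version FROM system.local")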
Bug Fixes
---------
* High CPU utilization when using asyncore event loop (PYTHON-239)
* Fix CQL Export for non-ASCII Identifiers (PYTHON-447)
* Make stress scripts Python 2.6 compatible (PYTHON-434)
* UnicodeDecodeError when unicode characters in key in BOP (PYTHON-559)
* WhiteListRoundRobinPolicy should resolve hosts (PYTHON-565)
* Cluster and Session do not GC after leaving scope (PYTHON-135)
* Don't wait for schema agreement on ignored nodes (PYTHON-531)
* Reprepare on_up with many clients causes node overload (PYTHON-556)
* None inserted into host map when control connection node is decommissioned (PYTHON-548)
* weakref.ref does not accept keyword arguments (github #585)
3.3.0
=====
May 2, 2016
Features
--------
* Add an AddressTranslator interface (PYTHON-69)
* New Retry Policy Decision - try next host (PYTHON-285)
* Don't mark host down on timeout (PYTHON-286)
* SSL hostname verification (PYTHON-296)
* Add C* version to metadata or cluster objects (PYTHON-301)
* Options to Disable Schema, Token Metadata Processing (PYTHON-327)
* Expose listen_address of node we get ring information from (PYTHON-332)
* Use A-record with multiple IPs for contact points (PYTHON-415)
* Custom consistency level for populating query traces (PYTHON-435)
* Normalize Server Exception Types (PYTHON-443)
* Propagate exception message when DDL schema agreement fails (PYTHON-444)
* Specialized exceptions for metadata refresh methods failure (PYTHON-527)
Bug Fixes
---------
* Resolve contact point hostnames to avoid duplicate hosts (PYTHON-103)
* GeventConnection stalls requests when read is a multiple of the input buffer size (PYTHON-429)
* named_tuple_factory breaks with duplicate "cleaned" col names (PYTHON-467)
* Connection leak if Cluster.shutdown() happens during reconnection (PYTHON-482)
* HostConnection.borrow_connection does not block when all request ids are used (PYTHON-514)
* Empty field not being handled by the NumpyProtocolHandler (PYTHON-550)
3.2.2
=====
April 19, 2016
* Fix counter save-after-no-update (PYTHON-547)
3.2.1
=====
April 13, 2016
* Introduced an update to allow deserializer compilation with recently released Cython 0.24 (PYTHON-542)
3.2.0
=====
April 12, 2016
Features
--------
* cqlengine: Warn on sync_schema type mismatch (PYTHON-260)
* cqlengine: Automatically defer fields with the '=' operator (and immutable values) in select queries (PYTHON-520)
* cqlengine: support non-equal conditions for LWT (PYTHON-528)
* cqlengine: sync_table should validate the primary key composition (PYTHON-532)
* cqlengine: token-aware routing for mapper statements (PYTHON-535)
Bug Fixes
---------
* Deleting a column in a lightweight transaction raises a SyntaxException #325 (PYTHON-249)
* cqlengine: make Token function work with named tables/columns #86 (PYTHON-272)
* comparing models with datetime fields fails #79 (PYTHON-273)
* cython date deserializer integer math should be aligned with CPython (PYTHON-480)
* db_field is not always respected with UpdateStatement (PYTHON-530)
* Sync_table fails on column.Set with secondary index (PYTHON-533)
3.1.1
=====
March 14, 2016
Bug Fixes
---------
* cqlengine: Fix performance issue related to additional "COUNT" queries (PYTHON-522)
3.1.0
=====
March 10, 2016
Features
--------
* Pass name of server auth class to AuthProvider (PYTHON-454)
* Surface schema agreed flag for DDL statements (PYTHON-458)
* Automatically convert float and int to Decimal on serialization (PYTHON-468)
* Eventlet Reactor IO improvement (PYTHON-495)
* Make pure Python ProtocolHandler available even when Cython is present (PYTHON-501)
* Optional Cython deserializer for bytes as bytearray (PYTHON-503)
* Add Session.default_serial_consistency_level (github #510)
* cqlengine: Expose prior state information via cqlengine LWTException (github #343, PYTHON-336)
* cqlengine: Collection datatype "contains" operators support (Cassandra 2.1) #278 (PYTHON-258)
* cqlengine: Add DISTINCT query operator (PYTHON-266)
* cqlengine: Tuple cqlengine api (PYTHON-306)
* cqlengine: Add support for UPDATE/DELETE ... IF EXISTS statements (PYTHON-432)
* cqlengine: Allow nested container types (PYTHON-478)
* cqlengine: Add ability to set query's fetch_size and limit (PYTHON-323)
* cqlengine: Internalize default keyspace from successive set_session (PYTHON-486)
* cqlengine: Warn when Model.create() on Counters (to be deprecated) (PYTHON-333)
Bug Fixes
---------
* Bus error (alignment issues) when running cython on some ARM platforms (PYTHON-450)
* Overflow when decoding large collections (cython) (PYTHON-459)
* Timer heap comparison issue with Python 3 (github #466)
* Cython deserializer date overflow at 2^31 - 1 (PYTHON-452)
* Decode error encountered when cython deserializing large map results (PYTHON-459)
* Don't require Cython for build if compiler or Python header not present (PYTHON-471)
* Unorderable types in task scheduling with Python 3 (PYTHON-473)
* cqlengine: Fix crash when updating a UDT column with a None value (github #467)
* cqlengine: Race condition in ..connection.execute with lazy_connect (PYTHON-310)
* cqlengine: doesn't support case sensitive column family names (PYTHON-337)
* cqlengine: UserDefinedType mandatory in create or update (PYTHON-344)
* cqlengine: db_field breaks UserType (PYTHON-346)
* cqlengine: UDT badly quoted (PYTHON-347)
* cqlengine: Use of db_field on primary key prevents querying except while tracing. (PYTHON-351)
* cqlengine: DateType.deserialize being called with one argument vs two (PYTHON-354)
* cqlengine: Querying without setting up connection now throws AttributeError and not CQLEngineException (PYTHON-395)
* cqlengine: BatchQuery executes statements multiple times (PYTHON-445)
* cqlengine: Better error for management functions when no connection set (PYTHON-451)
* cqlengine: Handle None values for UDT attributes in cqlengine (PYTHON-470)
* cqlengine: Fix inserting None for model save (PYTHON-475)
* cqlengine: EQ doesn't map to a QueryOperator (setup race condition) (PYTHON-476)
* cqlengine: class.MultipleObjectsReturned has DoesNotExist as base class (PYTHON-489)
* cqlengine: Typo in cqlengine UserType __len__ breaks attribute assignment (PYTHON-502)
Other
-----
* cqlengine: a major improvement to queryset has been introduced. It
  is a lot more efficient to iterate over large datasets: the rows are
  now fetched on demand using the driver's pagination.
* cqlengine: the queryset len() and count() behaviors have changed. They
  now execute a "SELECT COUNT(*)" for the query rather than returning
  the size of the internal result_cache (loaded rows). On large
  querysets, you might want to avoid using them due to the performance
  cost. Note that trying to access objects using list indexing/slicing
  with negative indices also requires a count to be
  executed.
3.0.0
=====
November 24, 2015
Features
--------
* Support datetime.date objects as a DateType (PYTHON-212)
* Add Cluster.update_view_metadata (PYTHON-407)
* QueryTrace option to populate partial trace sessions (PYTHON-438)
* Attach column names to ResultSet (PYTHON-439)
* Change default consistency level to LOCAL_ONE
Bug Fixes
---------
* Properly SerDes nested collections when protocol_version < 3 (PYTHON-215)
* Evict UDTs from UserType cache on change (PYTHON-226)
* Make sure query strings are always encoded UTF-8 (PYTHON-334)
* Track previous value of columns at instantiation in CQLengine (PYTHON-348)
* UDT CQL encoding does not work for unicode values (PYTHON-353)
* NetworkTopologyStrategy#make_token_replica_map does not account for multiple racks in a DC (PYTHON-378)
* Cython integer overflow on decimal type deserialization (PYTHON-433)
* Query trace: if session hasn't been logged, query trace can throw exception (PYTHON-442)
3.0.0rc1
========
November 9, 2015
Features
--------
* Process Modernized Schema Tables for Cassandra 3.0 (PYTHON-276, PYTHON-408, PYTHON-400, PYTHON-422)
* Remove deprecated features (PYTHON-292)
* Don't assign trace data to Statements (PYTHON-318)
* Normalize results return (PYTHON-368)
* Process Materialized View Metadata/Events (PYTHON-371)
* Remove blist as soft dependency (PYTHON-385)
* Change default consistency level to LOCAL_QUORUM (PYTHON-416)
* Normalize CQL query/export in metadata model (PYTHON-405)
Bug Fixes
---------
* Implementation of named arguments bind is non-pythonic (PYTHON-178)
* CQL encoding is incorrect for NaN and Infinity floats (PYTHON-282)
* Protocol downgrade issue with C* 2.0.x, 2.1.x, and python3, with non-default logging (PYTHON-409)
* ValueError when accessing usertype with non-alphanumeric field names (PYTHON-413)
* NumpyProtocolHandler does not play well with PagedResult (PYTHON-430)
2.7.2
=====
September 14, 2015
Bug Fixes
---------
* Resolve CQL export error for UDF with zero parameters (PYTHON-392)
* Remove futures dep. for Python 3 (PYTHON-393)
* Avoid Python closure in cdef (supports earlier Cython compiler) (PYTHON-396)
* Unit test runtime issues (PYTHON-397,398)
2.7.1
=====
August 25, 2015
Bug Fixes
---------
* Explicitly include extension source files in Manifest
2.7.0
=====
August 25, 2015
Cython is introduced, providing compiled extensions for core modules, and
extensions for optimized results deserialization.
Features
--------
* General Performance Improvements for Throughput (PYTHON-283)
* Improve synchronous request performance with Timers (PYTHON-108)
* Enable C Extensions for PyPy Runtime (PYTHON-357)
* Refactor SerDes functionality for pluggable interface (PYTHON-313)
* Cython SerDes Extension (PYTHON-377)
* Accept iterators/generators for execute_concurrent() (PYTHON-123)
* cythonize existing modules (PYTHON-342)
* Pure Python murmur3 implementation (PYTHON-363)
* Make driver tolerant of inconsistent metadata (PYTHON-370)
Bug Fixes
---------
* Drop Events out-of-order Cause KeyError on Processing (PYTHON-358)
* DowngradingConsistencyRetryPolicy doesn't check response count on write timeouts (PYTHON-338)
* Blocking connect does not use connect_timeout (PYTHON-381)
* Properly protect partition key in CQL export (PYTHON-375)
* Trigger error callbacks on timeout (PYTHON-294)
2.6.0
=====
July 20, 2015
Bug Fixes
---------
* Output proper CQL for compact tables with no clustering columns (PYTHON-360)
2.6.0c2
=======
June 24, 2015
Features
--------
* Automatic Protocol Version Downgrade (PYTHON-240)
* cqlengine Python 2.6 compatibility (PYTHON-288)
* Double-dollar string quote UDF body (PYTHON-345)
* Set models.DEFAULT_KEYSPACE when calling set_session (github #352)
Bug Fixes
---------
* Avoid stall while connecting to mixed version cluster (PYTHON-303)
* Make SSL work with AsyncoreConnection in python 2.6.9 (PYTHON-322)
* Fix Murmur3Token.from_key() on Windows (PYTHON-331)
* Fix cqlengine TimeUUID rounding error for Windows (PYTHON-341)
* Avoid invalid compaction options in CQL export for non-SizeTiered (PYTHON-352)
2.6.0c1
=======
June 4, 2015
This release adds support for Cassandra 2.2 features, including version
4 of the native protocol.
Features
--------
* Default load balancing policy to TokenAware(DCAware) (PYTHON-160)
* Configuration option for connection timeout (PYTHON-206)
* Support User Defined Function and Aggregate metadata in C* 2.2 (PYTHON-211)
* Surface request client in QueryTrace for C* 2.2+ (PYTHON-235)
* Implement new request failure messages in protocol v4+ (PYTHON-238)
* Metadata model now maps index meta by index name (PYTHON-241)
* Support new types in C* 2.2: date, time, smallint, tinyint (PYTHON-245, 295)
* cqle: add Double column type and remove Float overload (PYTHON-246)
* Use partition key column information in prepared response for protocol v4+ (PYTHON-277)
* Support message custom payloads in protocol v4+ (PYTHON-280, PYTHON-329)
* Deprecate refresh_schema and replace with functions for specific entities (PYTHON-291)
* Save trace id even when trace complete times out (PYTHON-302)
* Warn when registering client UDT class for protocol < v3 (PYTHON-305)
* Support client warnings returned with messages in protocol v4+ (PYTHON-315)
* Ability to distinguish between NULL and UNSET values in protocol v4+ (PYTHON-317)
* Expose CQL keywords in API (PYTHON-324)
Bug Fixes
---------
* IPv6 address support on Windows (PYTHON-20)
* Convert exceptions during automatic re-preparation to nice exceptions (PYTHON-207)
* cqle: Quote keywords properly in table management functions (PYTHON-244)
* Don't default to GeventConnection when gevent is loaded, but not monkey-patched (PYTHON-289)
* Pass dynamic host from SaslAuthProvider to SaslAuthenticator (PYTHON-300)
* Make protocol read_inet work for Windows (PYTHON-309)
* cqle: Correct encoding for nested types (PYTHON-311)
* Update list of CQL keywords used when quoting identifiers (PYTHON-319)
* Make ConstantReconnectionPolicy work with infinite retries (github #327, PYTHON-325)
* Accept UUIDs with uppercase hex as valid in cqlengine (github #335)
2.5.1
=====
April 23, 2015
Bug Fixes
---------
* Fix thread safety in DC-aware load balancing policy (PYTHON-297)
* Fix race condition in node/token rebuild (PYTHON-298)
* Set and send serial consistency parameter (PYTHON-299)
2.5.0
=====
March 30, 2015
Features
--------
* Integrated cqlengine object mapping package
* Utility functions for converting timeuuids and datetime (PYTHON-99)
* Schema metadata fetch window randomized, config options added (PYTHON-202)
* Support for new Date and Time Cassandra types (PYTHON-190)
Bug Fixes
---------
* Fix index target for collection indexes (full(), keys()) (PYTHON-222)
* Thread exception during GIL cleanup (PYTHON-229)
* Workaround for rounding anomaly in datetime.utcfromtimestamp (Python 3.4) (PYTHON-230)
* Normalize text serialization for lookup in OrderedMap (PYTHON-231)
* Support reading CompositeType data (PYTHON-234)
* Preserve float precision in CQL encoding (PYTHON-243)
2.1.4
=====
January 26, 2015
Features
--------
* SaslAuthenticator for Kerberos support (PYTHON-109)
* Heartbeat for network device keepalive and detecting failures on idle connections (PYTHON-197)
* Support nested, frozen collections for Cassandra 2.1.3+ (PYTHON-186)
* Schema agreement wait bypass config, new call for synchronous schema refresh (PYTHON-205)
* Add eventlet connection support (PYTHON-194)
Bug Fixes
---------
* Schema meta fix for complex thrift tables (PYTHON-191)
* Support for 'unknown' replica placement strategies in schema meta (PYTHON-192)
* Resolve stream ID leak on set_keyspace (PYTHON-195)
* Remove implicit timestamp scaling on serialization of numeric timestamps (PYTHON-204)
* Resolve stream id collision when using SASL auth (PYTHON-210)
* Correct unhexlify usage for user defined type meta in Python3 (PYTHON-208)
2.1.3
=====
December 16, 2014
Features
--------
* INFO-level log confirmation that a connection was opened to a node that was marked up (PYTHON-116)
* Avoid connecting to peer with incomplete metadata (PYTHON-163)
* Add SSL support to gevent reactor (PYTHON-174)
* Use control connection timeout in wait for schema agreement (PYTHON-175)
* Better consistency level representation in unavailable+timeout exceptions (PYTHON-180)
* Update schema metadata processing to accommodate coming schema modernization (PYTHON-185)
Bug Fixes
---------
* Support large negative timestamps on Windows (PYTHON-119)
* Fix schema agreement for clusters with peer rpc_address 0.0.0.0 (PYTHON-166)
* Retain table metadata following keyspace meta refresh (PYTHON-173)
* Use a timeout when preparing a statement for all nodes (PYTHON-179)
* Make TokenAware routing tolerant of statements with no keyspace (PYTHON-181)
* Update add_callback to store/invoke multiple callbacks (PYTHON-182)
* Correct routing key encoding for composite keys (PYTHON-184)
* Include compression option in schema export string when disabled (PYTHON-187)
2.1.2
=====
October 16, 2014
Features
--------
* Allow DCAwareRoundRobinPolicy to be constructed without a local_dc, defaulting
instead to the DC of a contact_point (PYTHON-126)
* Set routing key in BatchStatement.add() if none specified in batch (PYTHON-148)
* Improved feedback on ValueError using named_tuple_factory with invalid column names (PYTHON-122)
Bug Fixes
---------
* Make execute_concurrent compatible with Python 2.6 (PYTHON-159)
* Handle Unauthorized message on schema_triggers query (PYTHON-155)
* Pure Python sorted set in support of UDTs nested in collections (PYTHON-167)
* Support CUSTOM index metadata and string export (PYTHON-165)
2.1.1
=====
September 11, 2014
Features
--------
* Detect triggers and include them in CQL queries generated to recreate
the schema (github-189)
* Support IPv6 addresses (PYTHON-144) (note: basic functionality added; Windows
platform not addressed (PYTHON-20))
Bug Fixes
---------
* Fix NetworkTopologyStrategy.export_for_schema (PYTHON-120)
* Keep timeout for paged results (PYTHON-150)
Other
-----
* Add frozen<> type modifier to UDTs and tuples to handle CASSANDRA-7857
2.1.0
=====
August 7, 2014
Bug Fixes
---------
* Correctly serialize and deserialize null values in tuples and
user-defined types (PYTHON-110)
* Include additional header and lib dirs, allowing libevwrapper to build
against Homebrew and Mac Ports installs of libev (PYTHON-112 and 804dea3)
2.1.0c1
=======
July 25, 2014
Bug Fixes
---------
* Properly specify UDTs for columns in CREATE TABLE statements
* Avoid moving retries to a new host when using request ID zero (PYTHON-88)
* Don't ignore fetch_size arguments to Statement constructors (github-151)
* Allow disabling automatic paging on a per-statement basis when it's
enabled by default for the session (PYTHON-93)
* Raise ValueError when tuple query parameters for prepared statements
have extra items (PYTHON-98)
* Correctly encode nested tuples and UDTs for non-prepared statements (PYTHON-100)
* Raise TypeError when a string is used for contact_points (github #164)
* Include User Defined Types in KeyspaceMetadata.export_as_string() (PYTHON-96)
Other
-----
* Return list collection columns as python lists instead of tuples
now that tuples are a specific Cassandra type
2.1.0b1
=======
July 11, 2014
This release adds support for Cassandra 2.1 features, including version
3 of the native protocol.
Features
--------
* When using the v3 protocol, only one connection is opened per-host, and
throughput is improved due to reduced pooling overhead and lock contention.
* Support for user-defined types (Cassandra 2.1+)
* Support for tuple type (limited usage in Cassandra 2.0.9, full usage
  in Cassandra 2.1)
* Protocol-level client-side timestamps (see Session.use_client_timestamp)
* Overridable type encoding for non-prepared statements (see Session.encoders)
* Configurable serial consistency levels for batch statements
* Use io.BytesIO for reduced CPU consumption (github #143)
* Support Twisted as a reactor. Note that a Twisted-compatible
API is not exposed (so no Deferreds), this is just a reactor
implementation. (github #135, PYTHON-8)
Bug Fixes
---------
* Fix references to xrange that do not go through "six" in libevreactor and
geventreactor (github #138)
* Make BoundStatements inherit fetch_size from their parent
PreparedStatement (PYTHON-80)
* Clear reactor state in child process after forking to prevent errors with
multiprocessing when the parent process has connected a Cluster before
forking (github #141)
* Don't share prepared statement lock across Cluster instances
* Format CompositeType and DynamicCompositeType columns correctly in
CREATE TABLE statements.
* Fix cassandra.concurrent behavior when dealing with automatic paging
(PYTHON-81)
* Properly defunct connections after protocol errors
* Avoid UnicodeDecodeError when query string is unicode (PYTHON-76)
* Correctly capture dclocal_read_repair_chance for tables and
use it when generating CREATE TABLE statements (PYTHON-84)
* Avoid race condition with AsyncoreConnection that may cause messages
to fail to be written until a new message is pushed
* Make sure cluster.metadata.partitioner and cluster.metadata.token_map
are populated when all nodes in the cluster are included in the
contact points (PYTHON-90)
* Make Murmur3 hash match Cassandra's hash for all values (PYTHON-89,
github #147)
* Don't attempt to reconnect to hosts that should be ignored (according
to the load balancing policy) when a notification is received that the
host is down.
* Add CAS WriteType, avoiding KeyError on CAS write timeout (PYTHON-91)
2.0.2
=====
June 10, 2014
Bug Fixes
---------
* Add six to requirements.txt
* Avoid KeyError during schema refresh when a keyspace is dropped
and TokenAwarePolicy is not in use
* Avoid registering multiple atexit cleanup functions when the
asyncore event loop is restarted multiple times
* Delay initialization of reactors in order to avoid problems
with shared state when using multiprocessing (PYTHON-60)
* Add python-six to debian dependencies, move python-blist to recommends
* Fix memory leak when libev connections are created and
destroyed (github #93)
* Ensure token map is rebuilt when hosts are removed from the cluster
2.0.1
=====
May 28, 2014
Bug Fixes
---------
* Fix check for Cluster.is_shutdown in @run_in_executor
decorator
2.0.0
=====
May 28, 2014
Features
--------
* Make libev C extension Python3-compatible (PYTHON-70)
* Support v2 protocol authentication (PYTHON-73, github #125)
Bug Fixes
---------
* Fix murmur3 C extension compilation under Python3.4 (github #124)
Merged From 1.x
---------------
Features
^^^^^^^^
* Add Session.default_consistency_level (PYTHON-14)
Bug Fixes
^^^^^^^^^
* Don't strip trailing underscores from column names when using the
named_tuple_factory (PYTHON-56)
* Ensure replication factors are ints for NetworkTopologyStrategy
to avoid TypeErrors (github #120)
* Pass WriteType instance to RetryPolicy.on_write_timeout() instead
of the string name of the write type. This caused write timeout
errors to always be rethrown instead of retrying. (github #123)
* Avoid submitting tasks to the ThreadPoolExecutor after shutdown. With
retries enabled, this could cause Cluster.shutdown() to hang under
some circumstances.
* Fix unintended rebuild of token replica map when keyspaces are
discovered (on startup), added, or updated and TokenAwarePolicy is not
in use.
* Avoid rebuilding token metadata when cluster topology has not
actually changed
* Avoid preparing queries for hosts that should be ignored (such as
remote hosts when using the DCAwareRoundRobinPolicy) (PYTHON-75)
Other
^^^^^
* Add 1 second timeout to join() call on event loop thread during
interpreter shutdown. This can help to prevent the process from
hanging during shutdown.
2.0.0b1
=======
May 6, 2014
Upgrading from 1.x
------------------
Cluster.shutdown() should always be called when you are done with a
Cluster instance. If it is not called, there are no guarantees that the
driver will not hang. However, if you *do* have a reproducible case
where Cluster.shutdown() is not called and the driver hangs, please
report it so that we can attempt to fix it.
If you're using the 2.0 driver against Cassandra 1.2, you will need
to set your protocol version to 1. For example:
cluster = Cluster(..., protocol_version=1)
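Spelled out with a placeholder contact point, that looks like::

    from cassandra.cluster import Cluster

    # Cassandra 1.2 only speaks protocol v1, so pin the version explicitly.
    cluster = Cluster(["10.0.0.1"], protocol_version=1)
    session = cluster.connect()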
Features
--------
* Support v2 of Cassandra's native protocol, which includes the following
new features: automatic query paging support, protocol-level batch statements,
and lightweight transactions
* Support for Python 3.3 and 3.4
* Allow a default query timeout to be set per-Session
Bug Fixes
---------
* Avoid errors during interpreter shutdown (the driver attempts to cleanup
daemonized worker threads before interpreter shutdown)
Deprecations
------------
The following functions have moved from cassandra.decoder to cassandra.query.
The original functions have been left in place with a DeprecationWarning for
now:
* cassandra.decoder.tuple_factory has moved to cassandra.query.tuple_factory
* cassandra.decoder.named_tuple_factory has moved to cassandra.query.named_tuple_factory
* cassandra.decoder.dict_factory has moved to cassandra.query.dict_factory
* cassandra.decoder.ordered_dict_factory has moved to cassandra.query.ordered_dict_factory
Exceptions that were in cassandra.decoder have been moved to cassandra.protocol. If
you handle any of these exceptions, you must adjust the code accordingly.
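For example, a row factory import migrates like this (the session object is
assumed to already exist)::

    # Deprecated location (still importable, with a DeprecationWarning):
    # from cassandra.decoder import dict_factory

    # New location:
    from cassandra.query import dict_factory

    session.row_factory = dict_factory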
1.1.2
=====
May 8, 2014
Features
--------
* Allow a specific compression type to be requested for communications with
Cassandra and prefer lz4 if available
Bug Fixes
---------
* Update token metadata (for TokenAware calculations) when a node is removed
from the ring
* Fix file handle leak with gevent reactor due to blocking Greenlet kills when
closing excess connections
* Avoid handling a node coming up multiple times due to a reconnection attempt
succeeding close to the same time that an UP notification is pushed
* Fix duplicate node-up handling, which could result in multiple reconnectors
being started as well as the executor threads becoming deadlocked, preventing
future node up or node down handling from being executed.
* Handle exhausted ReconnectionPolicy schedule correctly
Other
-----
* Don't log at ERROR when a connection is closed during the startup
communications
* Make scales and blist optional dependencies
1.1.1
=====
April 16, 2014
Bug Fixes
---------
* Fix unconditional import of nose in setup.py (github #111)
1.1.0
=====
April 16, 2014
Features
--------
* Gevent is now supported through monkey-patching the stdlib (PYTHON-7,
github issue #46)
* Support static columns in schemas, which are available starting in
Cassandra 2.1. (github issue #91)
* Add debian packaging (github issue #101)
* Add utility methods for easy concurrent execution of statements. See
the new cassandra.concurrent module. (github issue #7)
Bug Fixes
---------
* Correctly supply compaction and compression parameters in CREATE statements
for tables when working with Cassandra 2.0+
* Lowercase boolean literals when generating schemas
* Ignore SSL_ERROR_WANT_READ and SSL_ERROR_WANT_WRITE socket errors. Previously,
these resulted in the connection being defuncted, but they can safely be
ignored by the driver.
* Don't reconnect the control connection every time Cluster.connect() is
called
* Avoid race condition that could leave ResponseFuture callbacks uncalled
if the callback was added outside of the event loop thread (github issue #95)
* Properly escape keyspace name in Session.set_keyspace(). Previously, the
keyspace name was quoted, but any quotes in the string were not escaped.
* Avoid adding hosts to the load balancing policy before their datacenter
and rack information has been set, if possible.
* Avoid KeyError when updating metadata after dropping a table (github issues
#97, #98)
* Use tuples instead of sets for DCAwareRoundRobinPolicy to ensure equal
distribution of requests
Other
-----
* Don't ignore column names when parsing typestrings. This is needed for
user-defined type support. (github issue #90)
* Better error message when libevwrapper is not found
* Only try to import scales when metrics are enabled (github issue #92)
* Cut down on the number of queries executing when a new Cluster
connects and when the control connection has to reconnect (github issue #104,
PYTHON-59)
* Issue warning log when schema versions do not match
1.0.2
=====
March 4, 2014
Bug Fixes
---------
* With asyncorereactor, correctly handle EAGAIN/EWOULDBLOCK when the message from
Cassandra is a multiple of the read buffer size. Previously, if no more data
became available to read on the socket, the message would never be processed,
resulting in an OperationTimedOut error.
* Double quote keyspace, table and column names that require them (those using
uppercase characters or keywords) when generating CREATE statements through
KeyspaceMetadata and TableMetadata.
* Decode TimestampType as DateType. (Cassandra replaced DateType with
TimestampType to fix sorting of pre-unix epoch dates in CASSANDRA-5723.)
* Handle latest table options when parsing the schema and generating
CREATE statements.
* Avoid 'Set changed size during iteration' during query plan generation
when hosts go up or down
Other
-----
* Remove ignored ``tracing_enabled`` parameter for ``SimpleStatement``. The
correct way to trace a query is by setting the ``trace`` argument to ``True``
in ``Session.execute()`` and ``Session.execute_async()``.
* Raise TypeError instead of cassandra.query.InvalidParameterTypeError when
a parameter for a prepared statement has the wrong type; remove
cassandra.query.InvalidParameterTypeError.
* More consistent type checking for query parameters
* Add option to a return special object for empty string values for non-string
columns
1.0.1
=====
Feb 19, 2014
Bug Fixes
---------
* Include table indexes in ``KeyspaceMetadata.export_as_string()``
* Fix broken token awareness on ByteOrderedPartitioner
* Always close socket when defuncting error'ed connections to avoid a potential
file descriptor leak
* Handle "custom" types (such as the replaced DateType) correctly
* With libevreactor, correctly handle EAGAIN/EWOULDBLOCK when the message from
Cassandra is a multiple of the read buffer size. Previously, if no more data
became available to read on the socket, the message would never be processed,
resulting in an OperationTimedOut error.
* Don't break tracing when a Session's row_factory is not the default
namedtuple_factory.
* Handle data that is already utf8-encoded for UTF8Type values
* Fix token-aware routing for tokens that fall before the first node token in
the ring and tokens that exactly match a node's token
* Tolerate null source_elapsed values for Trace events. These may not be
set when events complete after the main operation has already completed.
Other
-----
* Skip sending OPTIONS message on connection creation if compression is
disabled or not available and a CQL version has not been explicitly
set
* Add details about errors and the last queried host to ``OperationTimedOut``
1.0.0 Final
===========
Jan 29, 2014
Bug Fixes
---------
* Prevent leak of Scheduler thread (even with proper shutdown)
* Correctly handle ignored hosts, which are common with the
DCAwareRoundRobinPolicy
* Hold strong reference to prepared statement while executing it to avoid
garbage collection
* Add NullHandler logging handler to the cassandra package to avoid
warnings about there being no configured logger
* Fix bad handling of nodes that have been removed from the cluster
* Properly escape string types within cql collections
* Handle setting the same keyspace twice in a row
* Avoid race condition during schema agreement checks that could result
in schema update queries returning before all nodes had seen the change
* Preserve millisecond-level precision in datetimes when performing inserts
with simple (non-prepared) statements
* Properly defunct connections when libev reports an error by setting
errno instead of simply logging the error
* Fix endless hanging of some requests when using the libev reactor
* Always start a reconnection process when we fail to connect to
a newly bootstrapped node
* Generators map to CQL lists, not key sequences
* Always defunct connections when an internal operation fails
* Correctly break from handle_write() if nothing was sent (asyncore
reactor only)
* Avoid potential double-erroring of callbacks when a connection
becomes defunct
Features
--------
* Add default query timeout to ``Session`` (a usage sketch follows this list)
* Add timeout parameter to ``Session.execute()``
* Add ``WhiteListRoundRobinPolicy`` as a load balancing policy option
* Support for consistency level ``LOCAL_ONE``
* Make the backoff for fetching traces exponentially increasing and
configurable
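As noted above, a minimal usage sketch of these options (illustrative only;
the host addresses are placeholders)::

    from cassandra.cluster import Cluster
    from cassandra.policies import WhiteListRoundRobinPolicy

    # Restrict the driver to the listed hosts.
    policy = WhiteListRoundRobinPolicy(['10.0.0.1', '10.0.0.2'])
    cluster = Cluster(['10.0.0.1'], load_balancing_policy=policy)
    session = cluster.connect()

    # Session-wide default timeout (in seconds)...
    session.default_timeout = 15.0
    # ...or a per-request timeout passed to execute().
    session.execute("SELECT release_version FROM system.local", timeout=5.0)
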
Other
-----
* Raise Exception if ``TokenAwarePolicy`` is used against a cluster using the
``Murmur3Partitioner`` if the murmur3 C extension has not been compiled
* Add encoder mapping for ``OrderedDict``
* Use timeouts on all control connection queries
* Benchmark improvements, including command line options and easy
multithreading support
* Reduced lock contention when using the asyncore reactor
* Warn when non-datetimes are used for 'timestamp' column values in
prepared statements
* Add requirements.txt and test-requirements.txt
* TravisCI integration for running unit tests against Python 2.6,
Python 2.7, and PyPy
1.0.0b7
=======
Nov 12, 2013
This release makes many stability improvements, especially around
prepared statements and node failure handling. In particular,
several cases where a request would never be completed (and as a
result, leave the application hanging) have been resolved.
Features
--------
* Add `timeout` kwarg to ``ResponseFuture.result()``
* Create connection pools to all hosts in parallel when initializing
new Sessions.
Bug Fixes
---------
* Properly set exception on ResponseFuture when a query fails
against all hosts
* Improved cleanup and reconnection efforts when reconnection fails
on a node that has recently come up
* Use correct consistency level when retrying failed operations
against a different host. (An invalid consistency level was being
used, causing the retry to fail.)
* Better error messages for failed ``Session.prepare()`` operations
* Prepare new statements against all hosts in parallel (formerly
sequential)
* Fix failure to save the new current keyspace on connections. (This
could cause problems for prepared statements and lead to extra
operations to continuously re-set the keyspace.)
* Avoid sharing ``LoadBalancingPolicies`` across ``Cluster`` instances. (When
a second ``Cluster`` was connected, it effectively marked nodes down for the
first ``Cluster``.)
* Better handling of failures during the re-preparation sequence for
unrecognized prepared statements
* Throttle trashing of underutilized connections to avoid trashing newly
created connections
* Fix race condition which could result in trashed connections being closed
before the last operations had completed
* Avoid preparing statements on the event loop thread (which could lead to
deadlock)
* Correctly mark up non-contact point nodes discovered by the control
connection. (This led to prepared statements not being prepared
against those hosts, generating extra traffic later when the
statements were executed and unrecognized.)
* Correctly handle large messages through libev
* Add timeout to schema agreement check queries
* More complete (and less contended) locking around manipulation of the
pending message deque for libev connections
Other
-----
* Prepare statements in batches of 10. (When many prepared statements
are in use, this allows the driver to start utilizing nodes that
were restarted more quickly.)
* Better debug logging around connection management
* Don't retain unreferenced prepared statements in the local cache.
(If many different prepared statements were created, this would
increase memory usage and greatly increase the amount of time
required to begin utilizing a node that was added or marked
up.)
1.0.0b6
=======
Oct 22, 2013
Bug Fixes
---------
* Use lazy string formatting when logging
* Avoid several deadlock scenarios, especially when nodes go down
* Avoid trashing newly created connections due to insufficient traffic
* Gracefully handle un-handled Exceptions when erroring callbacks
Other
-----
* Node state listeners (which are called when a node is added, removed,
goes down, or comes up) should now be registered through
Cluster.register_listener() instead of through a host's HealthMonitor
(which has been removed), as sketched below
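A minimal sketch of the new registration mechanism (the listener class is
illustrative; any object exposing these callbacks can be registered)::

    from cassandra.cluster import Cluster

    class LoggingListener(object):
        # Callbacks invoked by the driver as host state changes.
        def on_add(self, host):
            print("added: %s" % host)
        def on_remove(self, host):
            print("removed: %s" % host)
        def on_up(self, host):
            print("up: %s" % host)
        def on_down(self, host):
            print("down: %s" % host)

    cluster = Cluster(['127.0.0.1'])
    cluster.register_listener(LoggingListener())
    session = cluster.connect()
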
1.0.0b5
========
Oct 10, 2013
Features
--------
* SSL support
Bug Fixes
---------
* Avoid KeyError when building replica map for NetworkTopologyStrategy
* Work around python bug which causes deadlock when a thread imports
the utf8 module
* Handle no blist library, which is not compatible with pypy
* Avoid deadlock triggered by a keyspace being set on a connection (which
may happen automatically for new connections)
Other
-----
* Switch packaging from Distribute to setuptools, improved C extension
support
* Use PEP 386 compliant beta and post-release versions
1.0.0-beta4
===========
Sep 24, 2013
Features
--------
* Handle new blob syntax in Cassandra 2.0 by accepting bytearray
objects for blob values (a brief sketch follows this list)
* Add cql_version kwarg to Cluster.__init__
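A brief sketch of both additions (illustrative only; the keyspace, table and
CQL version shown are placeholders)::

    from cassandra.cluster import Cluster

    # cql_version is optional; the driver negotiates it when omitted.
    cluster = Cluster(['127.0.0.1'], cql_version='3.1.1')
    session = cluster.connect('example_ks')

    prepared = session.prepare("INSERT INTO blobs (id, data) VALUES (?, ?)")
    # bytearray values are accepted for blob columns.
    session.execute(prepared, (1, bytearray(b'\x00\x01\x02')))
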
Bug Fixes
---------
* Fix KeyError when building token map with NetworkTopologyStrategy
keyspaces (this prevented a Cluster from successfully connecting
at all).
* Don't lose default consistency level from parent PreparedStatement
when creating BoundStatements
1.0.0-beta3
===========
Sep 20, 2013
Features
--------
* Support for LZ4 compression (Cassandra 2.0+)
* Token-aware routing will now utilize all replicas for a query instead
of just the first replica
Bug Fixes
---------
* Fix libev include path for CentOS
* Fix varint packing of the value 0
* Correctly pack unicode values
* Don't attempt to return failed connections to the pool when a final result
is set
* Fix bad iteration of connection credentials
* Use blist's orderedset for set collections and OrderedDict for map
collections so that Cassandra's ordering is preserved
* Fix connection failure on Windows due to unavailability of inet_pton
and inet_ntop. (Note that IPv6 inet_address values are still not
supported on Windows.)
* Boolean constants shouldn't be surrounded by single quotes
* Avoid a potential loss of precision on float constants due to string
formatting
* Actually utilize non-standard ports set on Cluster objects
* Fix export of schema as a set of CQL queries
Other
-----
* Use cStringIO for connection buffer for better performance
* Add __repr__ method for Statement classes
* Raise InvalidTypeParameterError when parameters of the wrong
type are used with statements
* Make all tests compatible with Python 2.6
* Add 1s timeout for opening new connections
1.0.0-beta2
===========
Aug 19, 2013
Bug Fixes
---------
* Fix pip packaging
1.0.0-beta
==========
Aug 16, 2013
Initial release
diff --git a/Jenkinsfile b/Jenkinsfile
new file mode 100644
index 0000000..abb6092
--- /dev/null
+++ b/Jenkinsfile
@@ -0,0 +1,675 @@
+#!groovy
+/*
+
+There are multiple combinations to test the python driver.
+
+Test Profiles:
+
+ Full: Execute all unit and integration tests, including long tests.
+ Standard: Execute unit and integration tests.
+ Smoke Tests: Execute a small subset of tests.
+ EVENT_LOOP: Execute a small subset of tests selected to test EVENT_LOOPs.
+
+Matrix Types:
+
+ Full: All server versions, python runtimes tested with and without Cython.
+ Develop: Smaller matrix for dev purposes.
+ Cassandra: All cassandra server versions.
+ Dse: All dse server versions.
+
+Parameters:
+
+ EVENT_LOOP: 'LIBEV' (Default), 'GEVENT', 'EVENTLET', 'ASYNCIO', 'ASYNCORE', 'TWISTED'
+ CYTHON: Default, 'True', 'False'
+
+*/
+
+@Library('dsdrivers-pipeline-lib@develop')
+import com.datastax.jenkins.drivers.python.Slack
+
+slack = new Slack()
+
+// Define our predefined matrices
+matrices = [
+ "FULL": [
+ "SERVER": ['2.1', '2.2', '3.0', '3.11', '4.0', 'dse-5.0', 'dse-5.1', 'dse-6.0', 'dse-6.7', 'dse-6.8'],
+ "RUNTIME": ['2.7.18', '3.5.9', '3.6.10', '3.7.7', '3.8.3'],
+ "CYTHON": ["True", "False"]
+ ],
+ "DEVELOP": [
+ "SERVER": ['2.1', '3.11', 'dse-6.8'],
+ "RUNTIME": ['2.7.18', '3.6.10'],
+ "CYTHON": ["True", "False"]
+ ],
+ "CASSANDRA": [
+ "SERVER": ['2.1', '2.2', '3.0', '3.11', '4.0'],
+ "RUNTIME": ['2.7.18', '3.5.9', '3.6.10', '3.7.7', '3.8.3'],
+ "CYTHON": ["True", "False"]
+ ],
+ "DSE": [
+ "SERVER": ['dse-5.0', 'dse-5.1', 'dse-6.0', 'dse-6.7', 'dse-6.8'],
+ "RUNTIME": ['2.7.18', '3.5.9', '3.6.10', '3.7.7', '3.8.3'],
+ "CYTHON": ["True", "False"]
+ ]
+]
+
+def getBuildContext() {
+ /*
+ Based on schedule, parameters and branch name, configure the build context and env vars.
+ */
+
+ def driver_display_name = 'Cassandra Python Driver'
+ if (env.GIT_URL.contains('riptano/python-driver')) {
+ driver_display_name = 'private ' + driver_display_name
+ } else if (env.GIT_URL.contains('python-dse-driver')) {
+ driver_display_name = 'DSE Python Driver'
+ }
+
+ def git_sha = "${env.GIT_COMMIT.take(7)}"
+ def github_project_url = "https://${GIT_URL.replaceFirst(/(git@|http:\/\/|https:\/\/)/, '').replace(':', '/').replace('.git', '')}"
+ def github_branch_url = "${github_project_url}/tree/${env.BRANCH_NAME}"
+ def github_commit_url = "${github_project_url}/commit/${env.GIT_COMMIT}"
+
+ def profile = "${params.PROFILE}"
+ def EVENT_LOOP = "${params.EVENT_LOOP.toLowerCase()}"
+ matrixType = "FULL"
+ developBranchPattern = ~"((dev|long)-)?python-.*"
+
+ if (developBranchPattern.matcher(env.BRANCH_NAME).matches()) {
+ matrixType = "DEVELOP"
+ if (env.BRANCH_NAME.contains("long")) {
+ profile = "FULL"
+ }
+ }
+
+ // Check if parameters were set explicitly
+ if (params.MATRIX != "DEFAULT") {
+ matrixType = params.MATRIX
+ }
+
+ matrix = matrices[matrixType].clone()
+ if (params.CYTHON != "DEFAULT") {
+ matrix["CYTHON"] = [params.CYTHON]
+ }
+
+ if (params.SERVER_VERSION != "DEFAULT") {
+ matrix["SERVER"] = [params.SERVER_VERSION]
+ }
+
+ if (params.PYTHON_VERSION != "DEFAULT") {
+ matrix["RUNTIME"] = [params.PYTHON_VERSION]
+ }
+
+ if (params.CI_SCHEDULE == "WEEKNIGHTS") {
+ matrix["SERVER"] = params.CI_SCHEDULE_SERVER_VERSION.split(' ')
+ matrix["RUNTIME"] = params.CI_SCHEDULE_PYTHON_VERSION.split(' ')
+ }
+
+ context = [
+ vars: [
+ "PROFILE=${profile}",
+ "EVENT_LOOP=${EVENT_LOOP}",
+ "DRIVER_DISPLAY_NAME=${driver_display_name}", "GIT_SHA=${git_sha}", "GITHUB_PROJECT_URL=${github_project_url}",
+ "GITHUB_BRANCH_URL=${github_branch_url}", "GITHUB_COMMIT_URL=${github_commit_url}"
+ ],
+ matrix: matrix
+ ]
+
+ return context
+}
+
+def buildAndTest(context) {
+ initializeEnvironment()
+ installDriverAndCompileExtensions()
+
+ try {
+ executeTests()
+ } finally {
+ junit testResults: '*_results.xml'
+ }
+}
+
+def getMatrixBuilds(buildContext) {
+ def tasks = [:]
+ matrix = buildContext.matrix
+
+ matrix["SERVER"].each { serverVersion ->
+ matrix["RUNTIME"].each { runtimeVersion ->
+ matrix["CYTHON"].each { cythonFlag ->
+ def taskVars = [
+ "CASSANDRA_VERSION=${serverVersion}",
+ "PYTHON_VERSION=${runtimeVersion}",
+ "CYTHON_ENABLED=${cythonFlag}"
+ ]
+ def cythonDesc = cythonFlag == "True" ? ", Cython": ""
+ tasks["${serverVersion}, py${runtimeVersion}${cythonDesc}"] = {
+ node("${OS_VERSION}") {
+ checkout scm
+
+ withEnv(taskVars) {
+ buildAndTest(context)
+ }
+ }
+ }
+ }
+ }
+ }
+ return tasks
+}
+
+def initializeEnvironment() {
+ sh label: 'Initialize the environment', script: '''#!/bin/bash -lex
+ pyenv global ${PYTHON_VERSION}
+ sudo apt-get install socat
+ pip install --upgrade pip
+ pip install -U setuptools
+ pip install ${HOME}/ccm
+ '''
+
+ // Determine if server version is Apache CassandraⓇ or DataStax Enterprise
+ if (env.CASSANDRA_VERSION.split('-')[0] == 'dse') {
+ sh label: 'Install DataStax Enterprise requirements', script: '''#!/bin/bash -lex
+ pip install -r test-datastax-requirements.txt
+ '''
+ } else {
+ sh label: 'Install Apache CassandraⓇ requirements', script: '''#!/bin/bash -lex
+ pip install -r test-requirements.txt
+ '''
+
+ sh label: 'Uninstall the geomet dependency since it is not required for Cassandra', script: '''#!/bin/bash -lex
+ pip uninstall -y geomet
+ '''
+ }
+
+ sh label: 'Install unit test modules', script: '''#!/bin/bash -lex
+ pip install nose-ignore-docstring nose-exclude service_identity
+ '''
+
+ if (env.CYTHON_ENABLED == 'True') {
+ sh label: 'Install cython modules', script: '''#!/bin/bash -lex
+ pip install cython numpy
+ '''
+ }
+
+ sh label: 'Download Apache CassandraⓇ or DataStax Enterprise', script: '''#!/bin/bash -lex
+ . ${CCM_ENVIRONMENT_SHELL} ${CASSANDRA_VERSION}
+ '''
+
+ sh label: 'Display Python and environment information', script: '''#!/bin/bash -le
+ # Load CCM environment variables
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ python --version
+ pip --version
+ pip freeze
+ printenv | sort
+ '''
+}
+
+def installDriverAndCompileExtensions() {
+ if (env.CYTHON_ENABLED == 'True') {
+ sh label: 'Install the driver and compile with C extensions with Cython', script: '''#!/bin/bash -lex
+ python setup.py build_ext --inplace
+ '''
+ } else {
+ sh label: 'Install the driver and compile with C extensions without Cython', script: '''#!/bin/bash -lex
+ python setup.py build_ext --inplace --no-cython
+ '''
+ }
+}
+
+def executeStandardTests() {
+
+ sh label: 'Execute unit tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variables
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP=${EVENT_LOOP} VERIFY_CYTHON=${CYTHON_ENABLED} nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=unit_results.xml tests/unit/ || true
+ EVENT_LOOP=eventlet VERIFY_CYTHON=${CYTHON_ENABLED} nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=unit_eventlet_results.xml tests/unit/io/test_eventletreactor.py || true
+ EVENT_LOOP=gevent VERIFY_CYTHON=${CYTHON_ENABLED} nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=unit_gevent_results.xml tests/unit/io/test_geventreactor.py || true
+ '''
+
+ sh label: 'Execute Simulacron integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variables
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ SIMULACRON_JAR="${HOME}/simulacron.jar"
+ SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --exclude test_backpressure.py --xunit-file=simulacron_results.xml tests/integration/simulacron/ || true
+
+ # Run backpressure tests separately to avoid memory issue
+ SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --exclude test_backpressure.py --xunit-file=simulacron_backpressure_1_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_paused_connections || true
+ SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --exclude test_backpressure.py --xunit-file=simulacron_backpressure_2_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_queued_requests_timeout || true
+ SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --exclude test_backpressure.py --xunit-file=simulacron_backpressure_3_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_cluster_busy || true
+ SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --exclude test_backpressure.py --xunit-file=simulacron_backpressure_4_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_node_busy || true
+ '''
+
+ sh label: 'Execute CQL engine integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variables
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=cqle_results.xml tests/integration/cqlengine/ || true
+ '''
+
+ sh label: 'Execute Apache CassandraⓇ integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variables
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml tests/integration/standard/ || true
+ '''
+
+ if (env.CASSANDRA_VERSION.split('-')[0] == 'dse' && env.CASSANDRA_VERSION.split('-')[1] != '4.8') {
+ sh label: 'Execute DataStax Enterprise integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variable
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} DSE_VERSION=${DSE_VERSION} ADS_HOME="${HOME}/" VERIFY_CYTHON=${CYTHON_ENABLED} nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=dse_results.xml tests/integration/advanced/ || true
+ '''
+ }
+
+ sh label: 'Execute DataStax Constellation integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variable
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP=${EVENT_LOOP} CLOUD_PROXY_PATH="${HOME}/proxy/" CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=advanced_results.xml tests/integration/cloud/ || true
+ '''
+
+ if (env.PROFILE == 'FULL') {
+ sh label: 'Execute long running integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variable
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --exclude-dir=tests/integration/long/upgrade --with-ignore-docstrings --with-xunit --xunit-file=long_results.xml tests/integration/long/ || true
+ '''
+ }
+}
+
+def executeDseSmokeTests() {
+ sh label: 'Execute profile DataStax Enterprise smoke test integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variable
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} DSE_VERSION=${DSE_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml tests/integration/standard/test_dse.py || true
+ '''
+}
+
+def executeEventLoopTests() {
+ sh label: 'Execute profile event loop manager integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variable
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP_TESTS=(
+ "tests/integration/standard/test_cluster.py"
+ "tests/integration/standard/test_concurrent.py"
+ "tests/integration/standard/test_connection.py"
+ "tests/integration/standard/test_control_connection.py"
+ "tests/integration/standard/test_metrics.py"
+ "tests/integration/standard/test_query.py"
+ "tests/integration/simulacron/test_endpoint.py"
+ "tests/integration/long/test_ssl.py"
+ )
+ EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml ${EVENT_LOOP_TESTS[@]} || true
+ '''
+}
+
+def executeTests() {
+ switch(env.PROFILE) {
+ case 'DSE-SMOKE-TEST':
+ executeDseSmokeTests()
+ break
+ case 'EVENT_LOOP':
+ executeEventLoopTests()
+ break
+ default:
+ executeStandardTests()
+ break
+ }
+}
+
+
+// TODO move this in the shared lib
+def getDriverMetricType() {
+ metric_type = 'oss'
+ if (env.GIT_URL.contains('riptano/python-driver')) {
+ metric_type = 'oss-private'
+ } else if (env.GIT_URL.contains('python-dse-driver')) {
+ metric_type = 'dse'
+ }
+ return metric_type
+}
+
+def submitCIMetrics(buildType) {
+ long durationMs = currentBuild.duration
+ long durationSec = durationMs / 1000
+ long nowSec = (currentBuild.startTimeInMillis + durationMs) / 1000
+ def branchNameNoPeriods = env.BRANCH_NAME.replaceAll('\\.', '_')
+ metric_type = getDriverMetricType()
+ def durationMetric = "okr.ci.python.${metric_type}.${buildType}.${branchNameNoPeriods} ${durationSec} ${nowSec}"
+
+ timeout(time: 1, unit: 'MINUTES') {
+ withCredentials([string(credentialsId: 'lab-grafana-address', variable: 'LAB_GRAFANA_ADDRESS'),
+ string(credentialsId: 'lab-grafana-port', variable: 'LAB_GRAFANA_PORT')]) {
+ withEnv(["DURATION_METRIC=${durationMetric}"]) {
+ sh label: 'Send runtime metrics to labgrafana', script: '''#!/bin/bash -lex
+ echo "${DURATION_METRIC}" | nc -q 5 ${LAB_GRAFANA_ADDRESS} ${LAB_GRAFANA_PORT}
+ '''
+ }
+ }
+ }
+}
+
+def describeBuild(buildContext) {
+ script {
+ def runtimes = buildContext.matrix["RUNTIME"]
+ def serverVersions = buildContext.matrix["SERVER"]
+ def numBuilds = runtimes.size() * serverVersions.size() * buildContext.matrix["CYTHON"].size()
+ currentBuild.displayName = "${env.PROFILE} (${env.EVENT_LOOP} | ${numBuilds} builds)"
+ currentBuild.description = "${env.PROFILE} build testing servers (${serverVersions.join(', ')}) against Python (${runtimes.join(', ')}) using ${env.EVENT_LOOP} event loop manager"
+ }
+}
+
+def scheduleTriggerJobName = "drivers/python/oss/master/disabled"
+
+pipeline {
+ agent none
+
+ // Global pipeline timeout
+ options {
+ timeout(time: 10, unit: 'HOURS') // TODO timeout should be per build
+ buildDiscarder(logRotator(artifactNumToKeepStr: '10', // Keep only the last 10 artifacts
+ numToKeepStr: '50')) // Keep only the last 50 build records
+ }
+
+ parameters {
+ choice(
+ name: 'ADHOC_BUILD_TYPE',
+ choices: ['BUILD', 'BUILD-AND-EXECUTE-TESTS'],
+ description: '''
Perform an adhoc build operation
+
+
+
+
+
Choice
+
Description
+
+
+
BUILD
+
Performs a Per-Commit build
+
+
+
BUILD-AND-EXECUTE-TESTS
+
Performs a build and executes the integration and unit tests
+
+
''')
+ choice(
+ name: 'ADHOC_BUILD_AND_EXECUTE_TESTS_PYTHON_VERSION',
+ choices: ['2.7.18', '3.4.10', '3.5.9', '3.6.10', '3.7.7', '3.8.3'],
+        description: 'Python version to use for adhoc BUILD-AND-EXECUTE-TESTS ONLY!')
+ choice(
+ name: 'ADHOC_BUILD_AND_EXECUTE_TESTS_SERVER_VERSION',
+ choices: ['2.1', // Legacy Apache CassandraⓇ
+ '2.2', // Legacy Apache CassandraⓇ
+ '3.0', // Previous Apache CassandraⓇ
+ '3.11', // Current Apache CassandraⓇ
+ '4.0', // Development Apache CassandraⓇ
+ 'dse-5.0', // Long Term Support DataStax Enterprise
+ 'dse-5.1', // Legacy DataStax Enterprise
+ 'dse-6.0', // Previous DataStax Enterprise
+ 'dse-6.7', // Previous DataStax Enterprise
+ 'dse-6.8', // Current DataStax Enterprise
+ 'ALL'],
+        description: '''Apache CassandraⓇ and DataStax Enterprise server version to use for adhoc BUILD-AND-EXECUTE-TESTS ONLY!
+
+
+
+
+
Choice
+
Description
+
+
+
2.1
+
Apache CassandraⓇ v2.1.x
+
+
+
2.2
+
Apache CassandraⓇ v2.2.x
+
+
+
3.0
+
Apache CassandraⓇ v3.0.x
+
+
+
3.11
+
Apache CassandraⓇ v3.11.x
+
+
+
4.0
+
Apache CassandraⓇ v4.x (CURRENTLY UNDER DEVELOPMENT)
+
+
+
dse-5.0
+
DataStax Enterprise v5.0.x (Long Term Support)
+
+
+
dse-5.1
+
DataStax Enterprise v5.1.x
+
+
+
dse-6.0
+
DataStax Enterprise v6.0.x
+
+
+
dse-6.7
+
DataStax Enterprise v6.7.x
+
+
+
dse-6.8
+
DataStax Enterprise v6.8.x (CURRENTLY UNDER DEVELOPMENT)
+
+
''')
+ booleanParam(
+ name: 'CYTHON',
+ defaultValue: false,
+ description: 'Flag to determine if Cython should be enabled for scheduled or adhoc builds')
+ booleanParam(
+ name: 'EXECUTE_LONG_TESTS',
+ defaultValue: false,
+ description: 'Flag to determine if long integration tests should be executed for scheduled or adhoc builds')
+ choice(
+ name: 'EVENT_LOOP_MANAGER',
+ choices: ['LIBEV', 'GEVENT', 'EVENTLET', 'ASYNCIO', 'ASYNCORE', 'TWISTED'],
+ description: '''
Event loop manager to utilize for scheduled or adhoc builds
+
+
+
+
+
Choice
+
Description
+
+
+
LIBEV
+
A full-featured and high-performance event loop that is loosely modeled after libevent, but without its limitations and bugs
+
+
+
GEVENT
+
A coroutine-based Python networking library that uses greenlet to provide a high-level synchronous API on top of the libev or libuv event loop
+
+
+
EVENTLET
+
A concurrent networking library for Python that allows you to change how you run your code, not how you write it
+
+
+
ASYNCIO
+
A library to write concurrent code using the async/await syntax
+
+
+
ASYNCORE
+
A module that provides the basic infrastructure for writing asynchronous socket service clients and servers
+
+
+
TWISTED
+
An event-driven networking engine written in Python and licensed under the open source MIT license
Execute only the event loop tests for the specified event loop manager (see: EVENT_LOOP_MANAGER)
+
+
+
UPGRADE
+
Execute only the upgrade tests
+
+
''')
+ choice(
+ name: 'CI_SCHEDULE',
+ choices: ['DO-NOT-CHANGE-THIS-SELECTION', 'WEEKNIGHTS', 'WEEKENDS'],
+ description: 'CI testing schedule to execute periodically scheduled builds and tests of the driver (DO NOT CHANGE THIS SELECTION)')
+ string(
+ name: 'CI_SCHEDULE_PYTHON_VERSION',
+ defaultValue: 'DO-NOT-CHANGE-THIS-SELECTION',
+ description: 'CI testing python version to utilize for scheduled test runs of the driver (DO NOT CHANGE THIS SELECTION)')
+ string(
+ name: 'CI_SCHEDULE_SERVER_VERSION',
+ defaultValue: 'DO-NOT-CHANGE-THIS-SELECTION',
+ description: 'CI testing server version to utilize for scheduled test runs of the driver (DO NOT CHANGE THIS SELECTION)')
+ }
+
+ triggers {
+ parameterizedCron((branchPatternCron.matcher(env.BRANCH_NAME).matches() && !riptanoPatternCron.matcher(GIT_URL).find()) ? """
+ # Every weeknight (Monday - Friday) around 4:00 AM
+ # These schedules will run with and without Cython enabled for Python v2.7.18 and v3.5.9
+ H 4 * * 1-5 %CI_SCHEDULE=WEEKNIGHTS;EVENT_LOOP_MANAGER=LIBEV;CI_SCHEDULE_PYTHON_VERSION=2.7.18;CI_SCHEDULE_SERVER_VERSION=2.2 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 4 * * 1-5 %CI_SCHEDULE=WEEKNIGHTS;EVENT_LOOP_MANAGER=LIBEV;CI_SCHEDULE_PYTHON_VERSION=3.5.9;CI_SCHEDULE_SERVER_VERSION=2.2 3.11 dse-5.1 dse-6.0 dse-6.7
+
+ # Every Saturday around 12:00, 4:00 and 8:00 PM
+ # These schedules are for weekly libev event manager runs with and without Cython for most of the Python versions (excludes v3.5.9.x)
+ H 12 * * 6 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=LIBEV;CI_SCHEDULE_PYTHON_VERSION=2.7.18;CI_SCHEDULE_SERVER_VERSION=2.1 3.0 dse-5.1 dse-6.0 dse-6.7
+ H 12 * * 6 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=LIBEV;CI_SCHEDULE_PYTHON_VERSION=3.4.10;CI_SCHEDULE_SERVER_VERSION=2.1 3.0 dse-5.1 dse-6.0 dse-6.7
+ H 12 * * 6 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=LIBEV;CI_SCHEDULE_PYTHON_VERSION=3.6.10;CI_SCHEDULE_SERVER_VERSION=2.1 3.0 dse-5.1 dse-6.0 dse-6.7
+ H 12 * * 6 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=LIBEV;CI_SCHEDULE_PYTHON_VERSION=3.7.7;CI_SCHEDULE_SERVER_VERSION=2.1 3.0 dse-5.1 dse-6.0 dse-6.7
+ H 12 * * 6 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=LIBEV;CI_SCHEDULE_PYTHON_VERSION=3.8.3;CI_SCHEDULE_SERVER_VERSION=2.1 3.0 dse-5.1 dse-6.0 dse-6.7
+ # These schedules are for weekly gevent event manager event loop only runs with and without Cython for most of the Python versions (excludes v3.4.10.x)
+ H 16 * * 6 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=GEVENT;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=2.7.18;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 16 * * 6 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=GEVENT;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.5.9;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 16 * * 6 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=GEVENT;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.6.10;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 16 * * 6 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=GEVENT;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.7.7;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 16 * * 6 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=GEVENT;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.8.3;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ # These schedules are for weekly eventlet event manager event loop only runs with and without Cython for most of the Python versions (excludes v3.4.10.x)
+ H 20 * * 6 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=EVENTLET;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=2.7.18;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 20 * * 6 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=EVENTLET;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.5.9;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 20 * * 6 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=EVENTLET;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.6.10;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 20 * * 6 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=EVENTLET;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.7.7;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 20 * * 6 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=EVENTLET;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.8.3;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+
+ # Every Sunday around 12:00 and 4:00 AM
+ # These schedules are for weekly asyncore event manager event loop only runs with and without Cython for most of the Python versions (excludes v3.4.10.x)
+ H 0 * * 7 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=ASYNCORE;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=2.7.18;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 0 * * 7 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=ASYNCORE;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.5.9;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 0 * * 7 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=ASYNCORE;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.6.10;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 0 * * 7 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=ASYNCORE;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.7.7;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 0 * * 7 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=ASYNCORE;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.8.3;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ # These schedules are for weekly twisted event manager event loop only runs with and without Cython for most of the Python versions (excludes v3.4.10.x)
+ H 4 * * 7 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=TWISTED;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=2.7.18;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 4 * * 7 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=TWISTED;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.5.9;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 4 * * 7 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=TWISTED;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.6.10;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 4 * * 7 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=TWISTED;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.7.7;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ H 4 * * 7 %CI_SCHEDULE=WEEKENDS;EVENT_LOOP_MANAGER=TWISTED;PROFILE=EVENT-LOOP;CI_SCHEDULE_PYTHON_VERSION=3.8.3;CI_SCHEDULE_SERVER_VERSION=2.1 2.2 3.0 3.11 dse-5.1 dse-6.0 dse-6.7
+ """ : "")
+ }
+
+ environment {
+ OS_VERSION = 'ubuntu/bionic64/python-driver'
+ CYTHON_ENABLED = "${params.CYTHON ? 'True' : 'False'}"
+ EVENT_LOOP_MANAGER = "${params.EVENT_LOOP_MANAGER.toLowerCase()}"
+ EXECUTE_LONG_TESTS = "${params.EXECUTE_LONG_TESTS ? 'True' : 'False'}"
+ CCM_ENVIRONMENT_SHELL = '/usr/local/bin/ccm_environment.sh'
+ CCM_MAX_HEAP_SIZE = '1536M'
+ }
+
+ stages {
+ stage ('Per-Commit') {
+ options {
+ timeout(time: 2, unit: 'HOURS')
+ }
+ when {
+ beforeAgent true
+ branch pattern: '((dev|long)-)?python-.*', comparator: 'REGEXP'
+ allOf {
+ expression { params.ADHOC_BUILD_TYPE == 'BUILD' }
+ expression { params.CI_SCHEDULE == 'DO-NOT-CHANGE-THIS-SELECTION' }
+ not { buildingTag() }
+ }
+ }
+
+ matrix {
+ axes {
+ axis {
+ name 'CASSANDRA_VERSION'
+ values '3.11', // Current Apache Cassandra
+ 'dse-6.8' // Current DataStax Enterprise
+ }
+ axis {
+ name 'PYTHON_VERSION'
+ values '2.7.18', '3.5.9'
+ }
+ axis {
+ name 'CYTHON_ENABLED'
+ values 'False'
+ }
+ }
+
+ agent {
+ label "${OS_VERSION}"
+ }
+
+ stages {
+ stage('Initialize-Environment') {
+ steps {
+ initializeEnvironment()
+ script {
+ if (env.BUILD_STATED_SLACK_NOTIFIED != 'true') {
+ notifySlack()
+ }
+ }
+ }
+ }
+ stage('Describe-Build') {
+ steps {
+ describePerCommitStage()
+ }
+ }
+ stage('Install-Driver-And-Compile-Extensions') {
+ steps {
+ installDriverAndCompileExtensions()
+ }
+ }
+ stage('Execute-Tests') {
+ steps {
+
+ script {
+ if (env.BRANCH_NAME ==~ /long-python.*/) {
+ withEnv(["EXECUTE_LONG_TESTS=True"]) {
+ executeTests()
+ }
+ }
+ else {
+ executeTests()
+ }
+ }
+ }
+ post {
+ always {
+ junit testResults: '*_results.xml'
+ }
+ }
+ }
+ }
+ }
+ post {
+ always {
+ node('master') {
+ submitCIMetrics('commit')
+ }
+ }
+ aborted {
+ notifySlack('aborted')
+ }
+ success {
+ notifySlack('completed')
+ }
+ unstable {
+ notifySlack('unstable')
+ }
+ failure {
+ notifySlack('FAILED')
+ }
+ }
+ }
+
+ stage ('Scheduled-Testing') {
+ when {
+ beforeAgent true
+ allOf {
+ expression { params.ADHOC_BUILD_TYPE == 'BUILD' }
+ expression { params.CI_SCHEDULE != 'DO-NOT-CHANGE-THIS-SELECTION' }
+ not { buildingTag() }
+ }
+ }
+ matrix {
+ axes {
+ axis {
+ name 'CASSANDRA_VERSION'
+ values '2.1', // Legacy Apache Cassandra
+ '2.2', // Legacy Apache Cassandra
+ '3.0', // Previous Apache Cassandra
+ '3.11', // Current Apache Cassandra
+ 'dse-5.1', // Legacy DataStax Enterprise
+ 'dse-6.0', // Previous DataStax Enterprise
+ 'dse-6.7' // Current DataStax Enterprise
+ }
+ axis {
+ name 'CYTHON_ENABLED'
+ values 'True', 'False'
+ }
+ }
+ when {
+ beforeAgent true
+ allOf {
+ expression { return params.CI_SCHEDULE_SERVER_VERSION.split(' ').any { it =~ /(ALL|${env.CASSANDRA_VERSION})/ } }
+ }
+ }
+
+ environment {
+ PYTHON_VERSION = "${params.CI_SCHEDULE_PYTHON_VERSION}"
+ }
+ agent {
+ label "${OS_VERSION}"
+ }
+
+ stages {
+ stage('Initialize-Environment') {
+ steps {
+ initializeEnvironment()
+ script {
+ if (env.BUILD_STATED_SLACK_NOTIFIED != 'true') {
+ notifySlack()
+ }
+ }
+ }
+ }
+ stage('Describe-Build') {
+ steps {
+ describeScheduledTestingStage()
+ }
+ }
+ stage('Install-Driver-And-Compile-Extensions') {
+ steps {
+ installDriverAndCompileExtensions()
+ }
+ }
+ stage('Execute-Tests') {
+ steps {
+ executeTests()
+ }
+ post {
+ always {
+ junit testResults: '*_results.xml'
+ }
+ }
+ }
+ }
+ }
+ post {
+ aborted {
+ notifySlack('aborted')
+ }
+ success {
+ notifySlack('completed')
+ }
+ unstable {
+ notifySlack('unstable')
+ }
+ failure {
+ notifySlack('FAILED')
+ }
+ }
+ }
+
+
+ stage('Adhoc-Testing') {
+ when {
+ beforeAgent true
+ allOf {
+ expression { params.ADHOC_BUILD_TYPE == 'BUILD-AND-EXECUTE-TESTS' }
+ not { buildingTag() }
+ }
+ }
+
+ environment {
+ CYTHON_ENABLED = "${params.CYTHON ? 'True' : 'False'}"
+ PYTHON_VERSION = "${params.ADHOC_BUILD_AND_EXECUTE_TESTS_PYTHON_VERSION}"
+ }
+
+ matrix {
+ axes {
+ axis {
+ name 'CASSANDRA_VERSION'
+ values '2.1', // Legacy Apache Cassandra
+ '2.2', // Legacy Apache Cassandra
+ '3.0', // Previous Apache Cassandra
+ '3.11', // Current Apache Cassandra
+ '4.0', // Development Apache Cassandra
+ 'dse-5.0', // Long Term Support DataStax Enterprise
+ 'dse-5.1', // Legacy DataStax Enterprise
+ 'dse-6.0', // Previous DataStax Enterprise
+ 'dse-6.7', // Current DataStax Enterprise
+ 'dse-6.8' // Development DataStax Enterprise
+ }
+ }
+ when {
+ beforeAgent true
+ allOf {
+ expression { params.ADHOC_BUILD_AND_EXECUTE_TESTS_SERVER_VERSION ==~ /(ALL|${env.CASSANDRA_VERSION})/ }
+ }
+ }
+
+ agent {
+ label "${OS_VERSION}"
+ }
+
+ stages {
+ stage('Describe-Build') {
+ steps {
+ describeAdhocTestingStage()
+ }
+ }
+ stage('Initialize-Environment') {
+ steps {
+ initializeEnvironment()
+ }
+ }
+ stage('Install-Driver-And-Compile-Extensions') {
+ steps {
+ installDriverAndCompileExtensions()
+ }
+ }
+ stage('Execute-Tests') {
+ steps {
+ executeTests()
+ }
+ post {
+ always {
+ junit testResults: '*_results.xml'
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/README-dev.rst b/README-dev.rst
index c10aed2..85a722c 100644
--- a/README-dev.rst
+++ b/README-dev.rst
@@ -1,171 +1,251 @@
Releasing
=========
* Run the tests and ensure they all pass
* Update CHANGELOG.rst
-
* Check for any missing entries
* Add today's date to the release section
* Update the version in ``cassandra/__init__.py``
-
* For beta releases, use a version like ``(2, 1, '0b1')``
* For release candidates, use a version like ``(2, 1, '0rc1')``
* When in doubt, follow PEP 440 versioning
* Add the new version in ``docs.yaml``
-
* Commit the changelog and version changes, e.g. ``git commit -m'version 1.0.0'``
* Tag the release. For example: ``git tag -a 1.0.0 -m 'version 1.0.0'``
* Push the tag and new ``master``: ``git push origin 1.0.0 ; git push origin master``
-* Upload the package to pypi::
+* Update the `python-driver` submodule of `python-driver-wheels`,
+  commit, then push. This will trigger TravisCI and the wheel builds.
+* For a GA release, upload the package to pypi::
+
+ # Clean the working directory
+ python setup.py clean
+ rm dist/*
- python setup.py register
- python setup.py sdist upload
+ # Build the source distribution
+ python setup.py sdist
+
+ # Download all wheels from the jfrog repository and copy them in
+ # the dist/ directory
+ cp /path/to/wheels/*.whl dist/
+
+ # Upload all files
+ twine upload dist/*
* On pypi, make the latest GA the only visible version
* Update the docs (see below)
* Append a 'postN' string to the version tuple in ``cassandra/__init__.py``
so that it looks like ``(x, y, z, 'postN')``
* After a beta or rc release, this should look like ``(2, 1, '0b1', 'post0')``
+* After the release has been tagged, add a section to docs.yaml with the new tag ref::
+
+ versions:
+ - name:
+ ref:
+
* Commit and push
* Update 'cassandra-test' branch to reflect new release
* this is typically a matter of merging or rebasing onto master
* test and push updated branch to origin
* Update the JIRA versions: https://datastax-oss.atlassian.net/plugins/servlet/project-config/PYTHON/versions
* add release dates and set version as "released"
* Make an announcement on the mailing list
Building the Docs
=================
Sphinx is required to build the docs. You probably want to install through apt,
if possible::
sudo apt-get install python-sphinx
pip may also work::
sudo pip install -U Sphinx
To build the docs, run::
python setup.py doc
Upload the Docs
=================
This is deprecated. The docs are now only published on https://docs.datastax.com.
To upload the docs, checkout the ``gh-pages`` branch and copy the entire
contents all of ``docs/_build/X.Y.Z/*`` into the root of the ``gh-pages`` branch
and then push that branch to github.
For example::
git checkout 1.0.0
python setup.py doc
git checkout gh-pages
cp -R docs/_build/1.0.0/* .
git add --update # add modified files
# Also make sure to add any new documentation files!
git commit -m 'Update docs (version 1.0.0)'
git push origin gh-pages
If docs build includes errors, those errors may not show up in the next build unless
you have changed the files with errors. It's good to occasionally clear the build
directory and build from scratch::
rm -rf docs/_build/*
-Running the Tests
-=================
-In order for the extensions to be built and used in the test, run::
+Documentor
+==========
+We now also use another tool called Documentor along with the Sphinx sources to build docs.
+This gives us versioned docs with nicely integrated search. This is a private tool
+of DataStax.
- nosetests
+Dependencies
+------------
+Sphinx
+~~~~~~
+Installed as described above
-You can run a specific test module or package like so::
+Documentor
+~~~~~~~~~~
+Clone and set up Documentor as specified in `the project `_.
+This tool assumes Ruby, bundler, and npm are present.
- nosetests -w tests/unit/
+Building
+--------
+The setup script expects documentor to be in the system path. You can either add it permanently or run with something
+like this::
-You can run a specific test method like so::
+ PATH=$PATH:/bin python setup.py doc
- nosetests -w tests/unit/test_connection.py:ConnectionTest.test_bad_protocol_version
+The docs will not display properly when just browsing the filesystem in a browser. To view the docs as they would appear on most
+web servers, use the SimpleHTTPServer module::
-Seeing Test Logs in Real Time
------------------------------
-Sometimes it's useful to output logs for the tests as they run::
+ cd docs/_build/
+ python -m SimpleHTTPServer
- nosetests -w tests/unit/ --nocapture --nologcapture
+Then, browse to `localhost:8000 `_.
-Use tee to capture logs and see them on your terminal::
+Tests
+=====
- nosetests -w tests/unit/ --nocapture --nologcapture 2>&1 | tee test.log
+Running Unit Tests
+------------------
+Unit tests can be run like so::
-Specifying a Cassandra Version for Integration Tests
-----------------------------------------------------
-You can specify a cassandra version with the ``CASSANDRA_VERSION`` environment variable::
+ nosetests -w tests/unit/
+
+You can run a specific test method like so::
+
+ nosetests -w tests/unit/test_connection.py:ConnectionTest.test_bad_protocol_version
+
+Running Integration Tests
+-------------------------
+In order to run integration tests, you must specify a version to run using the ``CASSANDRA_VERSION`` or ``DSE_VERSION`` environment variable::
CASSANDRA_VERSION=2.0.9 nosetests -w tests/integration/standard
-You can also specify a cassandra directory (to test unreleased versions)::
+Or you can specify a cassandra directory (to test unreleased versions)::
- CASSANDRA_DIR=/home/thobbs/cassandra nosetests -w tests/integration/standard
+ CASSANDRA_DIR=/home/thobbs/cassandra nosetests -w tests/integration/standard/
Specifying the usage of an already running Cassandra cluster
-----------------------------------------------------
-The test will start the appropriate Cassandra clusters when necessary but if you don't want this to happen because a Cassandra cluster is already running the flag ``USE_CASS_EXTERNAL`` can be used, for example:
+------------------------------------------------------------
+The tests will start the appropriate Cassandra clusters when necessary, but if you don't want this to happen because a Cassandra cluster is already running, the ``USE_CASS_EXTERNAL`` flag can be used, for example::
- USE_CASS_EXTERNAL=1 python setup.py nosetests -w tests/integration/standard
+ USE_CASS_EXTERNAL=1 CASSANDRA_VERSION=2.0.9 nosetests -w tests/integration/standard
Specify a Protocol Version for Tests
------------------------------------
The protocol version defaults to 1 for cassandra 1.2 and 2 otherwise. You can explicitly set
it with the ``PROTOCOL_VERSION`` environment variable::
PROTOCOL_VERSION=3 nosetests -w tests/integration/standard
+Seeing Test Logs in Real Time
+-----------------------------
+Sometimes it's useful to output logs for the tests as they run::
+
+ nosetests -w tests/unit/ --nocapture --nologcapture
+
+Use tee to capture logs and see them on your terminal::
+
+ nosetests -w tests/unit/ --nocapture --nologcapture 2>&1 | tee test.log
+
Testing Multiple Python Versions
--------------------------------
-If you want to test all of python 2.7, 3.4, 3.5, 3.6 and pypy, use tox (this is what
+If you want to test all of python 2.7, 3.5, 3.6, 3.7, and pypy, use tox (this is what
TravisCI runs)::
tox
-By default, tox only runs the unit tests because I haven't put in the effort
-to get the integration tests to run on TravicCI. However, the integration
-tests should work locally. To run them, edit the following line in tox.ini::
-
- commands = {envpython} setup.py build_ext --inplace nosetests --verbosity=2 tests/unit/
-
-and change ``tests/unit/`` to ``tests/``.
+By default, tox only runs the unit tests.
Running the Benchmarks
======================
A version of Cassandra needs to be running locally before the benchmarks are run. If ccm is installed, one can be started with:
ccm create benchmark_cluster -v 3.0.1 -n 1 -s
To run the benchmarks, pick one of the files under the ``benchmarks/`` dir and run it::
python benchmarks/future_batches.py
There are a few options. Use ``--help`` to see them all::
python benchmarks/future_batches.py --help
Packaging for Cassandra
=======================
A source distribution is included in Cassandra, which uses the driver internally for ``cqlsh``.
To package a released version, checkout the tag and build a source zip archive::
python setup.py sdist --formats=zip
If packaging a pre-release (untagged) version, it is useful to include a commit hash in the archive
name to specify the built version::
python setup.py egg_info -b-`git rev-parse --short HEAD` sdist --formats=zip
The file (``dist/cassandra-driver-.zip``) is packaged with Cassandra in ``cassandra/lib/cassandra-driver-internal-only*zip``.
+
+Releasing an EAP
+================
+
+An EAP release is only uploaded to a private server and is not published on pypi.
+
+* Clean the environment::
+
+ python setup.py clean
+
+* Package the source distribution::
+
+ python setup.py sdist
+
+* Test the source distribution::
+
+ pip install dist/cassandra-driver-.tar.gz
+
+* Upload the package to the EAP download server.
+* Build the documentation::
+
+ python setup.py doc
+
+* Upload the docs to the EAP download server.
+
+Adding a New Python Runtime Support
+===================================
+
+* Add the new python version to our jenkins image:
+ https://github.com/riptano/openstack-jenkins-drivers/
+
+* Add the new python version in job-creator:
+ https://github.com/riptano/job-creator/
+
+* Run the tests and ensure they all pass
+ * also test all event loops
+
+* Update the wheels building repo to support that version:
+ https://github.com/riptano/python-dse-driver-wheels
diff --git a/README.rst b/README.rst
index b98463c..7c5bf1e 100644
--- a/README.rst
+++ b/README.rst
@@ -1,89 +1,87 @@
-DataStax Python Driver for Apache Cassandra
-===========================================
+DataStax Driver for Apache Cassandra
+====================================
-.. image:: https://travis-ci.org/datastax/python-driver.png?branch=master
- :target: https://travis-ci.org/datastax/python-driver
+.. image:: https://travis-ci.com/datastax/python-driver.png?branch=master
+ :target: https://travis-ci.com/github/datastax/python-driver
-A modern, `feature-rich `_ and highly-tunable Python client library for Apache Cassandra (2.1+) using exclusively Cassandra's binary protocol and Cassandra Query Language v3.
+A modern, `feature-rich `_ and highly-tunable Python client library for Apache Cassandra (2.1+) and
+DataStax Enterprise (4.7+) using exclusively Cassandra's binary protocol and Cassandra Query Language v3.
-The driver supports Python 2.7, 3.4, 3.5, 3.6 and 3.7.
-
-If you require compatibility with DataStax Enterprise, use the `DataStax Enterprise Python Driver `_.
+The driver supports Python 2.7, 3.5, 3.6, 3.7 and 3.8.
**Note:** DataStax products do not support big-endian systems.
-Feedback Requested
-------------------
-**Help us focus our efforts!** Provide your input on the `Platform and Runtime Survey `_ (we kept it short).
-
Features
--------
* `Synchronous `_ and `Asynchronous `_ APIs
* `Simple, Prepared, and Batch statements `_
* Asynchronous IO, parallel execution, request pipelining
* `Connection pooling `_
* Automatic node discovery
* `Automatic reconnection `_
* Configurable `load balancing `_ and `retry policies `_
* `Concurrent execution utilities `_
* `Object mapper `_
-* `Connecting to DataStax Apollo database (cloud) `_
+* `Connecting to DataStax Astra database (cloud) `_
+* DSE Graph execution API
+* DSE Geometric type serialization
+* DSE PlainText and GSSAPI authentication
Installation
------------
Installation through pip is recommended::
$ pip install cassandra-driver
For more complete installation instructions, see the
`installation guide `_.
Documentation
-------------
The documentation can be found online `here `_.
A couple of links for getting up to speed:
* `Installation `_
* `Getting started guide `_
* `API docs `_
* `Performance tips `_
Object Mapper
-------------
cqlengine (originally developed by Blake Eggleston and Jon Haddad, with contributions from the
community) is now maintained as an integral part of this package. Refer to
`documentation here `_.
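A minimal, illustrative cqlengine sketch (the contact point and keyspace are
placeholders; the keyspace is assumed to already exist)::

    import os
    import uuid

    from cassandra.cqlengine import columns, connection
    from cassandra.cqlengine.management import sync_table
    from cassandra.cqlengine.models import Model

    # Allow sync_table() to manage schema from application code.
    os.environ.setdefault('CQLENG_ALLOW_SCHEMA_MANAGEMENT', '1')

    class ExampleModel(Model):
        id = columns.UUID(primary_key=True, default=uuid.uuid4)
        description = columns.Text()

    connection.setup(['127.0.0.1'], 'example_keyspace', protocol_version=3)
    sync_table(ExampleModel)
    ExampleModel.create(description='hello from cqlengine')
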
Contributing
------------
See `CONTRIBUTING.md `_.
Reporting Problems
------------------
Please report any bugs and make any feature requests on the
`JIRA `_ issue tracker.
If you would like to contribute, please feel free to open a pull request.
Getting Help
------------
Your best options for getting help with the driver are the
`mailing list `_
-and the ``#datastax-drivers`` channel in the `DataStax Academy Slack `_.
+and the `DataStax Community `_.
License
-------
Copyright DataStax, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
diff --git a/appveyor/appveyor.ps1 b/appveyor/appveyor.ps1
index cc1e6aa..5f6840e 100644
--- a/appveyor/appveyor.ps1
+++ b/appveyor/appveyor.ps1
@@ -1,80 +1,80 @@
$env:JAVA_HOME="C:\Program Files\Java\jdk1.8.0"
$env:PATH="$($env:JAVA_HOME)\bin;$($env:PATH)"
$env:CCM_PATH="C:\Users\appveyor\ccm"
$env:CASSANDRA_VERSION=$env:cassandra_version
$env:EVENT_LOOP_MANAGER="asyncore"
$env:SIMULACRON_JAR="C:\Users\appveyor\simulacron-standalone-0.7.0.jar"
python --version
python -c "import platform; print(platform.architecture())"
# Install Ant
Start-Process cinst -ArgumentList @("-y","ant") -Wait -NoNewWindow
# Workaround for ccm, link ant.exe -> ant.bat
If (!(Test-Path C:\ProgramData\chocolatey\bin\ant.bat)) {
cmd /c mklink C:\ProgramData\chocolatey\bin\ant.bat C:\ProgramData\chocolatey\bin\ant.exe
}
$target = "$($env:JAVA_HOME)\jre\lib\security"
$jce_indicator = "$target\README.txt"
# Install Java Cryptographic Extensions, needed for SSL.
If (!(Test-Path $jce_indicator)) {
$zip = "C:\Users\appveyor\jce_policy-$($env:java_version).zip"
# If this file doesn't exist we know JCE hasn't been installed.
$url = "https://www.dropbox.com/s/po4308hlwulpvep/UnlimitedJCEPolicyJDK7.zip?dl=1"
$extract_folder = "UnlimitedJCEPolicy"
If ($env:java_version -eq "1.8.0") {
$url = "https://www.dropbox.com/s/al1e6e92cjdv7m7/jce_policy-8.zip?dl=1"
$extract_folder = "UnlimitedJCEPolicyJDK8"
}
# Download zip to staging area if it doesn't exist, we do this because
# we extract it to the directory based on the platform and we want to cache
# this file so it can apply to all platforms.
if(!(Test-Path $zip)) {
(new-object System.Net.WebClient).DownloadFile($url, $zip)
}
Add-Type -AssemblyName System.IO.Compression.FileSystem
[System.IO.Compression.ZipFile]::ExtractToDirectory($zip, $target)
$jcePolicyDir = "$target\$extract_folder"
Move-Item $jcePolicyDir\* $target\ -force
Remove-Item $jcePolicyDir
}
# Download simulacron
$simulacron_url = "https://github.com/datastax/simulacron/releases/download/0.7.0/simulacron-standalone-0.7.0.jar"
$simulacron_jar = $env:SIMULACRON_JAR
if(!(Test-Path $simulacron_jar)) {
(new-object System.Net.WebClient).DownloadFile($simulacron_url, $simulacron_jar)
}
# Install Python Dependencies for CCM.
Start-Process python -ArgumentList "-m pip install psutil pyYaml six numpy" -Wait -NoNewWindow
# Clone ccm from git and use master.
If (!(Test-Path $env:CCM_PATH)) {
- Start-Process git -ArgumentList "clone https://github.com/pcmanus/ccm.git $($env:CCM_PATH)" -Wait -NoNewWindow
+ Start-Process git -ArgumentList "clone -b cassandra-test https://github.com/pcmanus/ccm.git $($env:CCM_PATH)" -Wait -NoNewWindow
}
# Copy ccm -> ccm.py so windows knows to run it.
If (!(Test-Path $env:CCM_PATH\ccm.py)) {
Copy-Item "$env:CCM_PATH\ccm" "$env:CCM_PATH\ccm.py"
}
$env:PYTHONPATH="$($env:CCM_PATH);$($env:PYTHONPATH)"
$env:PATH="$($env:CCM_PATH);$($env:PATH)"
# Predownload cassandra version for CCM if it isn't already downloaded.
# This is necessary because otherwise ccm fails
If (!(Test-Path C:\Users\appveyor\.ccm\repository\$env:cassandra_version)) {
Start-Process python -ArgumentList "$($env:CCM_PATH)\ccm.py create -v $($env:cassandra_version) -n 1 predownload" -Wait -NoNewWindow
echo "Checking status of download"
python $env:CCM_PATH\ccm.py status
Start-Process python -ArgumentList "$($env:CCM_PATH)\ccm.py remove predownload" -Wait -NoNewWindow
echo "Downloaded version $env:cassandra_version"
}
Start-Process python -ArgumentList "-m pip install -r test-requirements.txt" -Wait -NoNewWindow
Start-Process python -ArgumentList "-m pip install nose-ignore-docstring" -Wait -NoNewWindow
diff --git a/build.yaml b/build.yaml.bak
similarity index 68%
rename from build.yaml
rename to build.yaml.bak
index 335de1e..100c865 100644
--- a/build.yaml
+++ b/build.yaml.bak
@@ -1,255 +1,264 @@
schedules:
nightly_master:
schedule: nightly
+ disable_pull_requests: true
branches:
include: [master]
env_vars: |
EVENT_LOOP_MANAGER='libev'
matrix:
exclude:
- - python: [3.4, 3.6, 3.7]
- - cassandra: ['2.1', '3.0', 'test-dse']
+ - python: [3.6, 3.7, 3.8]
+ - cassandra: ['2.1', '3.0', '4.0', 'test-dse']
commit_long_test:
schedule: per_commit
disable_pull_requests: true
branches:
include: [/long-python.*/]
env_vars: |
EVENT_LOOP_MANAGER='libev'
matrix:
exclude:
- - python: [3.4, 3.6, 3.7]
+ - python: [3.6, 3.7, 3.8]
- cassandra: ['2.1', '3.0', 'test-dse']
commit_branches:
schedule: per_commit
disable_pull_requests: true
branches:
include: [/python.*/]
env_vars: |
EVENT_LOOP_MANAGER='libev'
EXCLUDE_LONG=1
matrix:
exclude:
- - python: [3.4, 3.6, 3.7]
+ - python: [3.6, 3.7, 3.8]
- cassandra: ['2.1', '3.0', 'test-dse']
commit_branches_dev:
schedule: per_commit
disable_pull_requests: true
branches:
include: [/dev-python.*/]
env_vars: |
EVENT_LOOP_MANAGER='libev'
EXCLUDE_LONG=1
matrix:
exclude:
- - python: [2.7, 3.4, 3.6, 3.7]
- - cassandra: ['2.0', '2.1', '2.2', '3.0', 'test-dse']
+ - python: [2.7, 3.7, 3.6, 3.8]
+ - cassandra: ['2.0', '2.1', '2.2', '3.0', '4.0', 'test-dse', 'dse-4.8', 'dse-5.0', 'dse-6.0', 'dse-6.8']
release_test:
schedule: per_commit
disable_pull_requests: true
branches:
include: [/release-.+/]
env_vars: |
EVENT_LOOP_MANAGER='libev'
weekly_master:
schedule: 0 10 * * 6
disable_pull_requests: true
branches:
include: [master]
env_vars: |
EVENT_LOOP_MANAGER='libev'
matrix:
exclude:
- python: [3.5]
- cassandra: ['2.2', '3.1']
weekly_gevent:
schedule: 0 14 * * 6
disable_pull_requests: true
branches:
include: [master]
env_vars: |
EVENT_LOOP_MANAGER='gevent'
JUST_EVENT_LOOP=1
- matrix:
- exclude:
- - python: [3.4]
weekly_eventlet:
schedule: 0 18 * * 6
disable_pull_requests: true
branches:
include: [master]
env_vars: |
EVENT_LOOP_MANAGER='eventlet'
JUST_EVENT_LOOP=1
- matrix:
- exclude:
- - python: [3.4]
weekly_asyncio:
schedule: 0 22 * * 6
disable_pull_requests: true
branches:
include: [master]
env_vars: |
EVENT_LOOP_MANAGER='asyncio'
JUST_EVENT_LOOP=1
matrix:
exclude:
- python: [2.7]
weekly_async:
schedule: 0 10 * * 7
disable_pull_requests: true
branches:
include: [master]
env_vars: |
EVENT_LOOP_MANAGER='asyncore'
JUST_EVENT_LOOP=1
- matrix:
- exclude:
- - python: [3.4]
weekly_twister:
schedule: 0 14 * * 7
disable_pull_requests: true
branches:
include: [master]
env_vars: |
EVENT_LOOP_MANAGER='twisted'
JUST_EVENT_LOOP=1
- matrix:
- exclude:
- - python: [3.4]
upgrade_tests:
schedule: adhoc
branches:
include: [master, python-546]
env_vars: |
EVENT_LOOP_MANAGER='libev'
JUST_UPGRADE=True
matrix:
exclude:
- - python: [3.4, 3.6, 3.7]
- - cassandra: ['2.0', '2.1', '2.2', '3.0', 'test-dse']
+ - python: [3.6, 3.7, 3.8]
+ - cassandra: ['2.0', '2.1', '2.2', '3.0', '4.0', 'test-dse']
python:
- 2.7
- - 3.4
- 3.5
- 3.6
- 3.7
+ - 3.8
os:
- ubuntu/bionic64/python-driver
cassandra:
- '2.1'
- '2.2'
- '3.0'
- '3.11'
- - 'test-dse'
+ - '4.0'
+ - 'dse-4.8'
+ - 'dse-5.0'
+ - 'dse-5.1'
+ - 'dse-6.0'
+ - 'dse-6.7'
+ - 'dse-6.8.0'
env:
CYTHON:
- CYTHON
- NO_CYTHON
build:
- script: |
export JAVA_HOME=$CCM_JAVA_HOME
export PATH=$JAVA_HOME/bin:$PATH
export PYTHONPATH=""
+ export CCM_MAX_HEAP_SIZE=1024M
# Required for unix socket tests
sudo apt-get install socat
# Install latest setuptools
pip install --upgrade pip
pip install -U setuptools
- pip install git+ssh://git@github.com/riptano/ccm-private.git
+ pip install git+ssh://git@github.com/riptano/ccm-private.git@cassandra-7544-native-ports-with-dse-fix
+
+ #pip install $HOME/ccm
+
+ if [ -n "$CCM_IS_DSE" ]; then
+ pip install -r test-datastax-requirements.txt
+ else
+ pip install -r test-requirements.txt
+ fi
- pip install -r test-requirements.txt
pip install nose-ignore-docstring
pip install nose-exclude
pip install service_identity
FORCE_CYTHON=False
if [[ $CYTHON == 'CYTHON' ]]; then
FORCE_CYTHON=True
pip install cython
pip install numpy
# Install the driver & compile C extensions
python setup.py build_ext --inplace
else
# Install the driver & compile C extensions with no cython
python setup.py build_ext --inplace --no-cython
fi
echo "JUST_UPGRADE: $JUST_UPGRADE"
if [[ $JUST_UPGRADE == 'True' ]]; then
EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=upgrade_results.xml tests/integration/upgrade || true
exit 0
fi
- if [[ $CCM_IS_DSE == 'true' ]]; then
- # We only use a DSE version for unreleased DSE versions, so we only need to run the smoke tests here
- echo "CCM_IS_DSE: $CCM_IS_DSE"
+ if [[ $JUST_SMOKE == 'true' ]]; then
+ # When we ONLY want to run the smoke tests
+ echo "JUST_SMOKE: $JUST_SMOKE"
echo "==========RUNNING SMOKE TESTS==========="
EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CCM_ARGS="$CCM_ARGS" CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION DSE_VERSION='6.7.0' MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml tests/integration/standard/test_dse.py || true
exit 0
fi
# Run the unit tests, this is not done in travis because
# it takes too much time for the whole matrix to build with cython
if [[ $CYTHON == 'CYTHON' ]]; then
EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER VERIFY_CYTHON=1 nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=unit_results.xml tests/unit/ || true
EVENT_LOOP_MANAGER=eventlet VERIFY_CYTHON=1 nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=unit_eventlet_results.xml tests/unit/io/test_eventletreactor.py || true
EVENT_LOOP_MANAGER=gevent VERIFY_CYTHON=1 nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=unit_gevent_results.xml tests/unit/io/test_geventreactor.py || true
fi
if [ -n "$JUST_EVENT_LOOP" ]; then
echo "Running integration event loop subset with $EVENT_LOOP_MANAGER"
EVENT_LOOP_TESTS=(
"tests/integration/standard/test_cluster.py"
"tests/integration/standard/test_concurrent.py"
"tests/integration/standard/test_connection.py"
"tests/integration/standard/test_control_connection.py"
"tests/integration/standard/test_metrics.py"
"tests/integration/standard/test_query.py"
"tests/integration/simulacron/test_endpoint.py"
+ "tests/integration/long/test_ssl.py"
)
- EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CCM_ARGS="$CCM_ARGS" CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml ${EVENT_LOOP_TESTS[@]} || true
+ EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CCM_ARGS="$CCM_ARGS" DSE_VERSION=$DSE_VERSION CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml ${EVENT_LOOP_TESTS[@]} || true
exit 0
fi
echo "Running with event loop manager: $EVENT_LOOP_MANAGER"
echo "==========RUNNING SIMULACRON TESTS=========="
SIMULACRON_JAR="$HOME/simulacron.jar"
- SIMULACRON_JAR=$SIMULACRON_JAR EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CASSANDRA_DIR=$CCM_INSTALL_DIR CCM_ARGS="$CCM_ARGS" DSE_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=simulacron_results.xml tests/integration/simulacron/ || true
+ SIMULACRON_JAR=$SIMULACRON_JAR EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CASSANDRA_DIR=$CCM_INSTALL_DIR CCM_ARGS="$CCM_ARGS" DSE_VERSION=$DSE_VERSION CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=simulacron_results.xml tests/integration/simulacron/ || true
echo "Running with event loop manager: $EVENT_LOOP_MANAGER"
echo "==========RUNNING CQLENGINE TESTS=========="
- EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CCM_ARGS="$CCM_ARGS" CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=cqle_results.xml tests/integration/cqlengine/ || true
+ EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CCM_ARGS="$CCM_ARGS" DSE_VERSION=$DSE_VERSION CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=cqle_results.xml tests/integration/cqlengine/ || true
echo "==========RUNNING INTEGRATION TESTS=========="
- EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CCM_ARGS="$CCM_ARGS" CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml tests/integration/standard/ || true
+ EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CCM_ARGS="$CCM_ARGS" DSE_VERSION=$DSE_VERSION CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml tests/integration/standard/ || true
+
+ if [ -n "$DSE_VERSION" ] && ! [[ $DSE_VERSION == "4.8"* ]]; then
+ echo "==========RUNNING DSE INTEGRATION TESTS=========="
+ EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CASSANDRA_DIR=$CCM_INSTALL_DIR DSE_VERSION=$DSE_VERSION ADS_HOME=$HOME/ VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=dse_results.xml tests/integration/advanced/ || true
+ fi
- echo "==========RUNNING ADVANCED AND CLOUD TESTS=========="
- EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CLOUD_PROXY_PATH="$HOME/proxy/" CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=advanced_results.xml tests/integration/advanced/ || true
+ echo "==========RUNNING CLOUD TESTS=========="
+ EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CLOUD_PROXY_PATH="$HOME/proxy/" CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=advanced_results.xml tests/integration/cloud/ || true
if [ -z "$EXCLUDE_LONG" ]; then
echo "==========RUNNING LONG INTEGRATION TESTS=========="
- EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CCM_ARGS="$CCM_ARGS" CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --exclude-dir=tests/integration/long/upgrade --with-ignore-docstrings --with-xunit --xunit-file=long_results.xml tests/integration/long/ || true
+ EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CCM_ARGS="$CCM_ARGS" DSE_VERSION=$DSE_VERSION CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --exclude-dir=tests/integration/long/upgrade --with-ignore-docstrings --with-xunit --xunit-file=long_results.xml tests/integration/long/ || true
fi
- xunit:
- "*_results.xml"
diff --git a/cassandra/__init__.py b/cassandra/__init__.py
index 38aef2f..5739d5d 100644
--- a/cassandra/__init__.py
+++ b/cassandra/__init__.py
@@ -1,703 +1,730 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
class NullHandler(logging.Handler):
def emit(self, record):
pass
logging.getLogger('cassandra').addHandler(NullHandler())
-__version_info__ = (3, 20, 2)
+__version_info__ = (3, 25, 0)
__version__ = '.'.join(map(str, __version_info__))
class ConsistencyLevel(object):
"""
Specifies how many replicas must respond for an operation to be considered
a success. By default, ``ONE`` is used for all operations.
"""
ANY = 0
"""
Only requires that one replica receives the write *or* the coordinator
stores a hint to replay later. Valid only for writes.
"""
ONE = 1
"""
Only one replica needs to respond to consider the operation a success
"""
TWO = 2
"""
Two replicas must respond to consider the operation a success
"""
THREE = 3
"""
Three replicas must respond to consider the operation a success
"""
QUORUM = 4
"""
``floor(RF/2) + 1`` replicas must respond to consider the operation a success
"""
ALL = 5
"""
All replicas must respond to consider the operation a success
"""
LOCAL_QUORUM = 6
"""
Requires a quorum of replicas in the local datacenter
"""
EACH_QUORUM = 7
"""
Requires a quorum of replicas in each datacenter
"""
SERIAL = 8
"""
For conditional inserts/updates that utilize Cassandra's lightweight
transactions, this requires consensus among all replicas for the
modified data.
"""
LOCAL_SERIAL = 9
"""
Like :attr:`~ConsistencyLevel.SERIAL`, but only requires consensus
among replicas in the local datacenter.
"""
LOCAL_ONE = 10
"""
Sends a request only to replicas in the local datacenter and waits for
one response.
"""
@staticmethod
def is_serial(cl):
return cl == ConsistencyLevel.SERIAL or cl == ConsistencyLevel.LOCAL_SERIAL
ConsistencyLevel.value_to_name = {
ConsistencyLevel.ANY: 'ANY',
ConsistencyLevel.ONE: 'ONE',
ConsistencyLevel.TWO: 'TWO',
ConsistencyLevel.THREE: 'THREE',
ConsistencyLevel.QUORUM: 'QUORUM',
ConsistencyLevel.ALL: 'ALL',
ConsistencyLevel.LOCAL_QUORUM: 'LOCAL_QUORUM',
ConsistencyLevel.EACH_QUORUM: 'EACH_QUORUM',
ConsistencyLevel.SERIAL: 'SERIAL',
ConsistencyLevel.LOCAL_SERIAL: 'LOCAL_SERIAL',
ConsistencyLevel.LOCAL_ONE: 'LOCAL_ONE'
}
ConsistencyLevel.name_to_value = {
'ANY': ConsistencyLevel.ANY,
'ONE': ConsistencyLevel.ONE,
'TWO': ConsistencyLevel.TWO,
'THREE': ConsistencyLevel.THREE,
'QUORUM': ConsistencyLevel.QUORUM,
'ALL': ConsistencyLevel.ALL,
'LOCAL_QUORUM': ConsistencyLevel.LOCAL_QUORUM,
'EACH_QUORUM': ConsistencyLevel.EACH_QUORUM,
'SERIAL': ConsistencyLevel.SERIAL,
'LOCAL_SERIAL': ConsistencyLevel.LOCAL_SERIAL,
'LOCAL_ONE': ConsistencyLevel.LOCAL_ONE
}
def consistency_value_to_name(value):
return ConsistencyLevel.value_to_name[value] if value is not None else "Not Set"
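# A minimal sketch (not part of the driver source; the helper name below is
# hypothetical) showing how the maps above round-trip between consistency
# level names and values.
def _consistency_level_roundtrip_example():
    assert ConsistencyLevel.name_to_value['LOCAL_QUORUM'] == ConsistencyLevel.LOCAL_QUORUM
    assert ConsistencyLevel.value_to_name[ConsistencyLevel.QUORUM] == 'QUORUM'
    assert consistency_value_to_name(None) == 'Not Set'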
class ProtocolVersion(object):
"""
Defines native protocol versions supported by this driver.
"""
V1 = 1
"""
v1, supported in Cassandra 1.2-->2.2
"""
V2 = 2
"""
v2, supported in Cassandra 2.0-->2.2;
added support for lightweight transactions, batch operations, and automatic query paging.
"""
V3 = 3
"""
v3, supported in Cassandra 2.1-->3.x+;
added support for protocol-level client-side timestamps (see :attr:`.Session.use_client_timestamp`),
serial consistency levels for :class:`~.BatchStatement`, and an improved connection pool.
"""
V4 = 4
"""
v4, supported in Cassandra 2.2-->3.x+;
added a number of new types, server warnings, new failure messages, and custom payloads. Details in the
`project docs `_
"""
V5 = 5
"""
- v5, in beta from 3.x+
+ v5, in beta from 3.x+. Finalised in 4.0-beta5
"""
- SUPPORTED_VERSIONS = (V5, V4, V3, V2, V1)
+ V6 = 6
+ """
+ v6, in beta from 4.0-beta5
+ """
+
+ DSE_V1 = 0x41
+ """
+ DSE private protocol v1, supported in DSE 5.1+
+ """
+
+ DSE_V2 = 0x42
+ """
+ DSE private protocol v2, supported in DSE 6.0+
+ """
+
+ SUPPORTED_VERSIONS = (DSE_V2, DSE_V1, V6, V5, V4, V3, V2, V1)
"""
A tuple of all supported protocol versions
"""
- BETA_VERSIONS = (V5,)
+ BETA_VERSIONS = (V6,)
"""
A tuple of all beta protocol versions
"""
MIN_SUPPORTED = min(SUPPORTED_VERSIONS)
"""
Minimum protocol version supported by this driver.
"""
MAX_SUPPORTED = max(SUPPORTED_VERSIONS)
"""
- Maximum protocol versioni supported by this driver.
+ Maximum protocol version supported by this driver.
"""
@classmethod
def get_lower_supported(cls, previous_version):
"""
Return the lower supported protocol version. Beta versions are omitted.
"""
try:
version = next(v for v in sorted(ProtocolVersion.SUPPORTED_VERSIONS, reverse=True) if
v not in ProtocolVersion.BETA_VERSIONS and v < previous_version)
except StopIteration:
version = 0
return version
@classmethod
def uses_int_query_flags(cls, version):
return version >= cls.V5
@classmethod
def uses_prepare_flags(cls, version):
- return version >= cls.V5
+ return version >= cls.V5 and version != cls.DSE_V1
@classmethod
def uses_prepared_metadata(cls, version):
- return version >= cls.V5
+ return version >= cls.V5 and version != cls.DSE_V1
@classmethod
def uses_error_code_map(cls, version):
return version >= cls.V5
@classmethod
def uses_keyspace_flag(cls, version):
- return version >= cls.V5
+ return version >= cls.V5 and version != cls.DSE_V1
+
+ @classmethod
+ def has_continuous_paging_support(cls, version):
+ return version >= cls.DSE_V1
+
+ @classmethod
+ def has_continuous_paging_next_pages(cls, version):
+ return version >= cls.DSE_V2
+
+ @classmethod
+ def has_checksumming_support(cls, version):
+ return cls.V5 <= version < cls.DSE_V1
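# A minimal sketch (not part of the driver source) exercising the feature
# checks above: protocol-level checksumming applies to native v5/v6 but not to
# the DSE private protocols, and get_lower_supported() skips beta versions
# such as V6 when stepping down.
def _protocol_version_feature_example():
    assert ProtocolVersion.has_checksumming_support(ProtocolVersion.V5)
    assert not ProtocolVersion.has_checksumming_support(ProtocolVersion.V4)
    assert not ProtocolVersion.has_checksumming_support(ProtocolVersion.DSE_V2)
    assert ProtocolVersion.has_continuous_paging_support(ProtocolVersion.DSE_V1)
    # V6 is still beta, so stepping down from DSE_V1 lands on V5, not V6
    assert ProtocolVersion.get_lower_supported(ProtocolVersion.DSE_V1) == ProtocolVersion.V5
    assert ProtocolVersion.get_lower_supported(ProtocolVersion.V5) == ProtocolVersion.V4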
class WriteType(object):
"""
For usage with :class:`.RetryPolicy`, this describes a type
of write operation.
"""
SIMPLE = 0
"""
A write to a single partition key. Such writes are guaranteed to be atomic
and isolated.
"""
BATCH = 1
"""
A write to multiple partition keys that used the distributed batch log to
ensure atomicity.
"""
UNLOGGED_BATCH = 2
"""
A write to multiple partition keys that did not use the distributed batch
log. Atomicity for such writes is not guaranteed.
"""
COUNTER = 3
"""
A counter write (for one or multiple partition keys). Such writes should
not be replayed in order to avoid over-counting.
"""
BATCH_LOG = 4
"""
The initial write to the distributed batch log that Cassandra performs
internally before a BATCH write.
"""
CAS = 5
"""
A lightweight-transaction write, such as "DELETE ... IF EXISTS".
"""
VIEW = 6
"""
This WriteType is only seen in results for requests that were unable to
complete MV operations.
"""
CDC = 7
"""
This WriteType is only seen in results for requests that were unable to
complete CDC operations.
"""
WriteType.name_to_value = {
'SIMPLE': WriteType.SIMPLE,
'BATCH': WriteType.BATCH,
'UNLOGGED_BATCH': WriteType.UNLOGGED_BATCH,
'COUNTER': WriteType.COUNTER,
'BATCH_LOG': WriteType.BATCH_LOG,
'CAS': WriteType.CAS,
'VIEW': WriteType.VIEW,
'CDC': WriteType.CDC
}
WriteType.value_to_name = {v: k for k, v in WriteType.name_to_value.items()}
class SchemaChangeType(object):
DROPPED = 'DROPPED'
CREATED = 'CREATED'
UPDATED = 'UPDATED'
class SchemaTargetType(object):
KEYSPACE = 'KEYSPACE'
TABLE = 'TABLE'
TYPE = 'TYPE'
FUNCTION = 'FUNCTION'
AGGREGATE = 'AGGREGATE'
class SignatureDescriptor(object):
def __init__(self, name, argument_types):
self.name = name
self.argument_types = argument_types
@property
def signature(self):
"""
function signature string in the form 'name([type0[,type1[...]]])'
can be used to uniquely identify overloaded function names within a keyspace
"""
return self.format_signature(self.name, self.argument_types)
@staticmethod
def format_signature(name, argument_types):
return "%s(%s)" % (name, ','.join(t for t in argument_types))
def __repr__(self):
return "%s(%s, %s)" % (self.__class__.__name__, self.name, self.argument_types)
class UserFunctionDescriptor(SignatureDescriptor):
"""
Describes a User function by name and argument signature
"""
name = None
"""
name of the function
"""
argument_types = None
"""
Ordered list of CQL argument type names comprising the type signature
"""
class UserAggregateDescriptor(SignatureDescriptor):
"""
Describes a User aggregate function by name and argument signature
"""
name = None
"""
name of the aggregate
"""
argument_types = None
"""
Ordered list of CQL argument type names comprising the type signature
"""
class DriverException(Exception):
"""
Base for all exceptions explicitly raised by the driver.
"""
pass
class RequestExecutionException(DriverException):
"""
Base for request execution exceptions returned from the server.
"""
pass
class Unavailable(RequestExecutionException):
"""
There were not enough live replicas to satisfy the requested consistency
level, so the coordinator node immediately failed the request without
forwarding it to any replicas.
"""
consistency = None
""" The requested :class:`ConsistencyLevel` """
required_replicas = None
""" The number of replicas that needed to be live to complete the operation """
alive_replicas = None
""" The number of replicas that were actually alive """
def __init__(self, summary_message, consistency=None, required_replicas=None, alive_replicas=None):
self.consistency = consistency
self.required_replicas = required_replicas
self.alive_replicas = alive_replicas
Exception.__init__(self, summary_message + ' info=' +
repr({'consistency': consistency_value_to_name(consistency),
'required_replicas': required_replicas,
'alive_replicas': alive_replicas}))
class Timeout(RequestExecutionException):
"""
Replicas failed to respond to the coordinator node before timing out.
"""
consistency = None
""" The requested :class:`ConsistencyLevel` """
required_responses = None
""" The number of required replica responses """
received_responses = None
"""
The number of replicas that responded before the coordinator timed out
the operation
"""
def __init__(self, summary_message, consistency=None, required_responses=None,
received_responses=None, **kwargs):
self.consistency = consistency
self.required_responses = required_responses
self.received_responses = received_responses
if "write_type" in kwargs:
kwargs["write_type"] = WriteType.value_to_name[kwargs["write_type"]]
info = {'consistency': consistency_value_to_name(consistency),
'required_responses': required_responses,
'received_responses': received_responses}
info.update(kwargs)
Exception.__init__(self, summary_message + ' info=' + repr(info))
class ReadTimeout(Timeout):
"""
A subclass of :exc:`Timeout` for read operations.
This indicates that the replicas failed to respond to the coordinator
node before the configured timeout. This timeout is configured in
``cassandra.yaml`` with the ``read_request_timeout_in_ms``
and ``range_request_timeout_in_ms`` options.
"""
data_retrieved = None
"""
A boolean indicating whether the requested data was retrieved
by the coordinator from any replicas before it timed out the
operation
"""
def __init__(self, message, data_retrieved=None, **kwargs):
Timeout.__init__(self, message, **kwargs)
self.data_retrieved = data_retrieved
class WriteTimeout(Timeout):
"""
A subclass of :exc:`Timeout` for write operations.
This indicates that the replicas failed to respond to the coordinator
node before the configured timeout. This timeout is configured in
``cassandra.yaml`` with the ``write_request_timeout_in_ms``
option.
"""
write_type = None
"""
The type of write operation, enum on :class:`~cassandra.policies.WriteType`
"""
def __init__(self, message, write_type=None, **kwargs):
kwargs["write_type"] = write_type
Timeout.__init__(self, message, **kwargs)
self.write_type = write_type
class CDCWriteFailure(RequestExecutionException):
"""
Hit limit on data in CDC folder, writes are rejected
"""
def __init__(self, message):
Exception.__init__(self, message)
class CoordinationFailure(RequestExecutionException):
"""
Replicas sent a failure to the coordinator.
"""
consistency = None
""" The requested :class:`ConsistencyLevel` """
required_responses = None
""" The number of required replica responses """
received_responses = None
"""
The number of replicas that responded before the coordinator timed out
the operation
"""
failures = None
"""
The number of replicas that sent a failure message
"""
error_code_map = None
"""
A map of inet addresses to error codes representing replicas that sent
a failure message. Only set when `protocol_version` is 5 or higher.
"""
def __init__(self, summary_message, consistency=None, required_responses=None,
received_responses=None, failures=None, error_code_map=None):
self.consistency = consistency
self.required_responses = required_responses
self.received_responses = received_responses
self.failures = failures
self.error_code_map = error_code_map
info_dict = {
'consistency': consistency_value_to_name(consistency),
'required_responses': required_responses,
'received_responses': received_responses,
'failures': failures
}
if error_code_map is not None:
# make error codes look like "0x002a"
formatted_map = dict((addr, '0x%04x' % err_code)
for (addr, err_code) in error_code_map.items())
info_dict['error_code_map'] = formatted_map
Exception.__init__(self, summary_message + ' info=' + repr(info_dict))
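# A minimal sketch (not part of the driver source) of how error_code_map is
# rendered in the exception message: each code is formatted as "0x%04x".
def _coordination_failure_message_example():
    err = CoordinationFailure("Replica failure", consistency=ConsistencyLevel.QUORUM,
                              required_responses=2, received_responses=1, failures=1,
                              error_code_map={'127.0.0.1': 0x002a})
    assert '0x002a' in str(err)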
class ReadFailure(CoordinationFailure):
"""
A subclass of :exc:`CoordinationFailure` for read operations.
This indicates that the replicas sent a failure message to the coordinator.
"""
data_retrieved = None
"""
A boolean indicating whether the requested data was retrieved
by the coordinator from any replicas before it timed out the
operation
"""
def __init__(self, message, data_retrieved=None, **kwargs):
CoordinationFailure.__init__(self, message, **kwargs)
self.data_retrieved = data_retrieved
class WriteFailure(CoordinationFailure):
"""
A subclass of :exc:`CoordinationFailure` for write operations.
This indicates that the replicas sent a failure message to the coordinator.
"""
write_type = None
"""
The type of write operation, enum on :class:`~cassandra.policies.WriteType`
"""
def __init__(self, message, write_type=None, **kwargs):
CoordinationFailure.__init__(self, message, **kwargs)
self.write_type = write_type
class FunctionFailure(RequestExecutionException):
"""
User Defined Function failed during execution
"""
keyspace = None
"""
Keyspace of the function
"""
function = None
"""
Name of the function
"""
arg_types = None
"""
List of argument type names of the function
"""
def __init__(self, summary_message, keyspace, function, arg_types):
self.keyspace = keyspace
self.function = function
self.arg_types = arg_types
Exception.__init__(self, summary_message)
class RequestValidationException(DriverException):
"""
Server request validation failed
"""
pass
class ConfigurationException(RequestValidationException):
"""
Server indicated a request error due to the current configuration
"""
pass
class AlreadyExists(ConfigurationException):
"""
An attempt was made to create a keyspace or table that already exists.
"""
keyspace = None
"""
The name of the keyspace that already exists, or, if an attempt was
made to create a new table, the keyspace that the table is in.
"""
table = None
"""
The name of the table that already exists, or, if an attempt was
made to create a keyspace, :const:`None`.
"""
def __init__(self, keyspace=None, table=None):
if table:
message = "Table '%s.%s' already exists" % (keyspace, table)
else:
message = "Keyspace '%s' already exists" % (keyspace,)
Exception.__init__(self, message)
self.keyspace = keyspace
self.table = table
class InvalidRequest(RequestValidationException):
"""
A query was made that was invalid for some reason, such as trying to set
the keyspace for a connection to a nonexistent keyspace.
"""
pass
class Unauthorized(RequestValidationException):
"""
The current user is not authorized to perform the requested operation.
"""
pass
class AuthenticationFailed(DriverException):
"""
Failed to authenticate.
"""
pass
class OperationTimedOut(DriverException):
"""
The operation took longer than the specified (client-side) timeout
to complete. This is not an error generated by Cassandra, only
the driver.
"""
errors = None
"""
A dict of errors keyed by the :class:`~.Host` against which they occurred.
"""
last_host = None
"""
The last :class:`~.Host` this operation was attempted against.
"""
def __init__(self, errors=None, last_host=None):
self.errors = errors
self.last_host = last_host
message = "errors=%s, last_host=%s" % (self.errors, self.last_host)
Exception.__init__(self, message)
class UnsupportedOperation(DriverException):
"""
An attempt was made to use a feature that is not supported by the
selected protocol version. See :attr:`Cluster.protocol_version`
for more details.
"""
pass
class UnresolvableContactPoints(DriverException):
"""
The driver was unable to resolve any provided hostnames.
Note that this is *not* raised when a :class:`.Cluster` is created with no
contact points, only when lookup fails for all hosts
"""
pass
diff --git a/cassandra/auth.py b/cassandra/auth.py
index 1d94817..3d2f751 100644
--- a/cassandra/auth.py
+++ b/cassandra/auth.py
@@ -1,182 +1,309 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import socket
+import logging
+
+try:
+ import kerberos
+ _have_kerberos = True
+except ImportError:
+ _have_kerberos = False
+
+try:
+ from puresasl.client import SASLClient
+ _have_puresasl = True
+except ImportError:
+ _have_puresasl = False
+
try:
from puresasl.client import SASLClient
except ImportError:
SASLClient = None
+import six
+
+log = logging.getLogger(__name__)
+
+# Custom payload keys related to DSE Unified Auth
+_proxy_execute_key = 'ProxyExecute'
+
+
class AuthProvider(object):
"""
An abstract class that defines the interface that will be used for
creating :class:`~.Authenticator` instances when opening new
connections to Cassandra.
.. versionadded:: 2.0.0
"""
def new_authenticator(self, host):
"""
Implementations of this class should return a new instance
of :class:`~.Authenticator` or one of its subclasses.
"""
raise NotImplementedError()
class Authenticator(object):
"""
An abstract class that handles SASL authentication with Cassandra servers.
Each time a new connection is created and the server requires authentication,
a new instance of this class will be created by the corresponding
:class:`~.AuthProvider` to handle that authentication. The lifecycle of the
new :class:`~.Authenticator` will then be:
1) The :meth:`~.initial_response()` method will be called. The return
value will be sent to the server to initiate the handshake.
2) The server will respond to each client response by either issuing a
challenge or indicating that the authentication is complete (successful or not).
If a new challenge is issued, :meth:`~.evaluate_challenge()`
will be called to produce a response that will be sent to the
server. This challenge/response negotiation will continue until the server
responds that authentication is successful (or an :exc:`~.AuthenticationFailed`
is raised).
3) When the server indicates that authentication is successful,
:meth:`~.on_authentication_success` will be called with a token string that
the server may optionally have sent.
The exact nature of the negotiation between the client and server is specific
to the authentication mechanism configured server-side.
.. versionadded:: 2.0.0
"""
server_authenticator_class = None
""" Set during the connection AUTHENTICATE phase """
def initial_response(self):
"""
Returns a message to send to the server to initiate the SASL handshake.
:const:`None` may be returned to send an empty message.
"""
return None
def evaluate_challenge(self, challenge):
"""
Called when the server sends a challenge message. Generally, this method
should return :const:`None` when authentication is complete from a
client perspective. Otherwise, a string should be returned.
"""
raise NotImplementedError()
def on_authentication_success(self, token):
"""
Called when the server indicates that authentication was successful.
Depending on the authentication mechanism, `token` may be :const:`None`
or a string.
"""
pass
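# A minimal sketch (not part of the driver source) of the handshake loop the
# connection machinery drives against an Authenticator; send_to_server is a
# hypothetical callable standing in for the wire protocol and is assumed to
# return ('CHALLENGE', bytes) or ('SUCCESS', token).
def _run_sasl_handshake(authenticator, send_to_server):
    response = authenticator.initial_response()
    kind, payload = send_to_server(response)
    while kind == 'CHALLENGE':
        response = authenticator.evaluate_challenge(payload)
        kind, payload = send_to_server(response)
    authenticator.on_authentication_success(payload)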
class PlainTextAuthProvider(AuthProvider):
"""
An :class:`~.AuthProvider` that works with Cassandra's PasswordAuthenticator.
Example usage::
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
auth_provider = PlainTextAuthProvider(
username='cassandra', password='cassandra')
cluster = Cluster(auth_provider=auth_provider)
.. versionadded:: 2.0.0
"""
def __init__(self, username, password):
self.username = username
self.password = password
def new_authenticator(self, host):
return PlainTextAuthenticator(self.username, self.password)
-class PlainTextAuthenticator(Authenticator):
+class TransitionalModePlainTextAuthProvider(object):
"""
- An :class:`~.Authenticator` that works with Cassandra's PasswordAuthenticator.
+ An :class:`~.AuthProvider` that works with DSE TransitionalModePlainTextAuthenticator.
- .. versionadded:: 2.0.0
- """
+ Example usage::
- def __init__(self, username, password):
- self.username = username
- self.password = password
+ from cassandra.cluster import Cluster
+ from cassandra.auth import TransitionalModePlainTextAuthProvider
- def initial_response(self):
- return "\x00%s\x00%s" % (self.username, self.password)
+ auth_provider = TransitionalModePlainTextAuthProvider()
+ cluster = Cluster(auth_provider=auth_provider)
- def evaluate_challenge(self, challenge):
- return None
+ .. warning:: TransitionalModePlainTextAuthProvider will be removed in cassandra-driver
+ 4.0. The transitional mode will be handled internally without the need
+ of any auth provider.
+ """
+
+ def __init__(self):
+ # TODO remove next major
+ log.warning("TransitionalModePlainTextAuthProvider will be removed in cassandra-driver "
+ "4.0. The transitional mode will be handled internally without the need "
+ "of any auth provider.")
+
+ def new_authenticator(self, host):
+ return TransitionalModePlainTextAuthenticator()
class SaslAuthProvider(AuthProvider):
"""
An :class:`~.AuthProvider` supporting general SASL auth mechanisms
Suitable for GSSAPI or other SASL mechanisms
Example usage::
from cassandra.cluster import Cluster
from cassandra.auth import SaslAuthProvider
sasl_kwargs = {'service': 'something',
'mechanism': 'GSSAPI',
'qops': 'auth'.split(',')}
auth_provider = SaslAuthProvider(**sasl_kwargs)
cluster = Cluster(auth_provider=auth_provider)
.. versionadded:: 2.1.4
"""
def __init__(self, **sasl_kwargs):
if SASLClient is None:
raise ImportError('The puresasl library has not been installed')
if 'host' in sasl_kwargs:
raise ValueError("kwargs should not contain 'host' since it is passed dynamically to new_authenticator")
self.sasl_kwargs = sasl_kwargs
def new_authenticator(self, host):
return SaslAuthenticator(host, **self.sasl_kwargs)
class SaslAuthenticator(Authenticator):
"""
A pass-through :class:`~.Authenticator` using the third party package
'pure-sasl' for authentication
.. versionadded:: 2.1.4
"""
def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs):
if SASLClient is None:
raise ImportError('The puresasl library has not been installed')
self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
def initial_response(self):
return self.sasl.process()
def evaluate_challenge(self, challenge):
return self.sasl.process(challenge)
+
+# TODO remove me next major
+DSEPlainTextAuthProvider = PlainTextAuthProvider
+
+
+class DSEGSSAPIAuthProvider(AuthProvider):
+ """
+ Auth provider for GSS API authentication. Works with legacy `KerberosAuthenticator`
+ or `DseAuthenticator` if `kerberos` scheme is enabled.
+ """
+ def __init__(self, service='dse', qops=('auth',), resolve_host_name=True, **properties):
+ """
+ :param service: name of the service
+ :param qops: iterable of "Quality of Protection" allowed; see ``puresasl.QOP``
+ :param resolve_host_name: boolean flag indicating whether the authenticator should reverse-lookup an FQDN when
+ creating a new authenticator. Default is ``True``, which will resolve, or return the numeric address if there is no PTR
+ record. Setting ``False`` creates the authenticator with the numeric address known by Cassandra
+ :param properties: additional keyword properties to pass for the ``puresasl.mechanisms.GSSAPIMechanism`` class.
+ Presently, 'principal' (user) is the only one referenced in the ``pure-sasl`` implementation
+ """
+ if not _have_puresasl:
+ raise ImportError('The puresasl library has not been installed')
+ if not _have_kerberos:
+ raise ImportError('The kerberos library has not been installed')
+ self.service = service
+ self.qops = qops
+ self.resolve_host_name = resolve_host_name
+ self.properties = properties
+
+ def new_authenticator(self, host):
+ if self.resolve_host_name:
+ host = socket.getnameinfo((host, 0), 0)[0]
+ return GSSAPIAuthenticator(host, self.service, self.qops, self.properties)
+
+
+class BaseDSEAuthenticator(Authenticator):
+ def get_mechanism(self):
+ raise NotImplementedError("get_mechanism not implemented")
+
+ def get_initial_challenge(self):
+ raise NotImplementedError("get_initial_challenge not implemented")
+
+ def initial_response(self):
+ if self.server_authenticator_class == "com.datastax.bdp.cassandra.auth.DseAuthenticator":
+ return self.get_mechanism()
+ else:
+ return self.evaluate_challenge(self.get_initial_challenge())
+
+
+class PlainTextAuthenticator(BaseDSEAuthenticator):
+
+ def __init__(self, username, password):
+ self.username = username
+ self.password = password
+
+ def get_mechanism(self):
+ return six.b("PLAIN")
+
+ def get_initial_challenge(self):
+ return six.b("PLAIN-START")
+
+ def evaluate_challenge(self, challenge):
+ if challenge == six.b('PLAIN-START'):
+ data = "\x00%s\x00%s" % (self.username, self.password)
+ return data if six.PY2 else data.encode()
+ raise Exception('Did not receive a valid challenge response from server')
+
+
+class TransitionalModePlainTextAuthenticator(PlainTextAuthenticator):
+ """
+ Authenticator used when DSE authentication is configured with transitional mode.
+ """
+
+ def __init__(self):
+ super(TransitionalModePlainTextAuthenticator, self).__init__('', '')
+
+
+class GSSAPIAuthenticator(BaseDSEAuthenticator):
+ def __init__(self, host, service, qops, properties):
+ properties = properties or {}
+ self.sasl = SASLClient(host, service, 'GSSAPI', qops=qops, **properties)
+
+ def get_mechanism(self):
+ return six.b("GSSAPI")
+
+ def get_initial_challenge(self):
+ return six.b("GSSAPI-START")
+
+ def evaluate_challenge(self, challenge):
+ if challenge == six.b('GSSAPI-START'):
+ return self.sasl.process()
+ else:
+ return self.sasl.process(challenge)
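# A minimal sketch (not part of the driver source) of the two initial-response
# paths implemented above: against DseAuthenticator the mechanism name is sent
# first, while a legacy PasswordAuthenticator receives the PLAIN token
# immediately.
def _plain_text_initial_response_example():
    auth = PlainTextAuthenticator('cassandra', 'cassandra')
    auth.server_authenticator_class = 'com.datastax.bdp.cassandra.auth.DseAuthenticator'
    assert auth.initial_response() == six.b('PLAIN')
    auth.server_authenticator_class = 'org.apache.cassandra.auth.PasswordAuthenticator'
    assert auth.initial_response() == six.b('\x00cassandra\x00cassandra')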
diff --git a/cassandra/cluster.py b/cassandra/cluster.py
index 8fcbe33..7e101af 100644
--- a/cassandra/cluster.py
+++ b/cassandra/cluster.py
@@ -1,4651 +1,5257 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module houses the main classes you will interact with,
:class:`.Cluster` and :class:`.Session`.
"""
from __future__ import absolute_import
import atexit
+from binascii import hexlify
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures
from copy import copy
from functools import partial, wraps
-from itertools import groupby, count
+from itertools import groupby, count, chain
+import json
import logging
from warnings import warn
from random import random
import six
from six.moves import filter, range, queue as Queue
import socket
import sys
import time
from threading import Lock, RLock, Thread, Event
+import uuid
import weakref
from weakref import WeakValueDictionary
-try:
- from weakref import WeakSet
-except ImportError:
- from cassandra.util import WeakSet # NOQA
from cassandra import (ConsistencyLevel, AuthenticationFailed,
OperationTimedOut, UnsupportedOperation,
SchemaTargetType, DriverException, ProtocolVersion,
UnresolvableContactPoints)
-from cassandra.auth import PlainTextAuthProvider
+from cassandra.auth import _proxy_execute_key, PlainTextAuthProvider
from cassandra.connection import (ConnectionException, ConnectionShutdown,
ConnectionHeartbeat, ProtocolVersionUnsupported,
EndPoint, DefaultEndPoint, DefaultEndPointFactory,
- SniEndPointFactory)
+ ContinuousPagingState, SniEndPointFactory, ConnectionBusy)
from cassandra.cqltypes import UserType
from cassandra.encoder import Encoder
from cassandra.protocol import (QueryMessage, ResultMessage,
ErrorMessage, ReadTimeoutErrorMessage,
WriteTimeoutErrorMessage,
UnavailableErrorMessage,
OverloadedErrorMessage,
PrepareMessage, ExecuteMessage,
PreparedQueryNotFound,
IsBootstrappingErrorMessage,
TruncateError, ServerError,
BatchMessage, RESULT_KIND_PREPARED,
RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
- RESULT_KIND_SCHEMA_CHANGE, ProtocolHandler)
-from cassandra.metadata import Metadata, protect_name, murmur3
+ RESULT_KIND_SCHEMA_CHANGE, ProtocolHandler,
+ RESULT_KIND_VOID, ProtocolException)
+from cassandra.metadata import Metadata, protect_name, murmur3, _NodeInfo
from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy,
ExponentialReconnectionPolicy, HostDistance,
RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan,
- NoSpeculativeExecutionPolicy)
+ NoSpeculativeExecutionPolicy, DefaultLoadBalancingPolicy,
+ NeverRetryPolicy)
from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler,
HostConnectionPool, HostConnection,
NoConnectionsAvailable)
from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement,
BatchStatement, bind_params, QueryTrace, TraceUnavailable,
- named_tuple_factory, dict_factory, tuple_factory, FETCH_SIZE_UNSET)
+ named_tuple_factory, dict_factory, tuple_factory, FETCH_SIZE_UNSET,
+ HostTargetingStatement)
+from cassandra.marshal import int64_pack
from cassandra.timestamps import MonotonicTimestampGenerator
from cassandra.compat import Mapping
+from cassandra.util import _resolve_contact_points_to_string_map, Version
+
+from cassandra.datastax.insights.reporter import MonitorReporter
+from cassandra.datastax.insights.util import version_supports_insights
+
+from cassandra.datastax.graph import (graph_object_row_factory, GraphOptions, GraphSON1Serializer,
+ GraphProtocol, GraphSON2Serializer, GraphStatement, SimpleGraphStatement,
+ graph_graphson2_row_factory, graph_graphson3_row_factory,
+ GraphSON3Serializer)
+from cassandra.datastax.graph.query import _request_timeout_key, _GraphSONContextRowFactory
from cassandra.datastax import cloud as dscloud
+try:
+ from cassandra.io.twistedreactor import TwistedConnection
+except ImportError:
+ TwistedConnection = None
+
+try:
+ from cassandra.io.eventletreactor import EventletConnection
+except ImportError:
+ EventletConnection = None
+
+try:
+ from weakref import WeakSet
+except ImportError:
+ from cassandra.util import WeakSet # NOQA
+
+if six.PY3:
+ long = int
def _is_eventlet_monkey_patched():
if 'eventlet.patcher' not in sys.modules:
return False
import eventlet.patcher
return eventlet.patcher.is_monkey_patched('socket')
def _is_gevent_monkey_patched():
if 'gevent.monkey' not in sys.modules:
return False
import gevent.socket
return socket.socket is gevent.socket.socket
# default to gevent when we are monkey patched with gevent, eventlet when
# monkey patched with eventlet, otherwise if libev is available, use that as
# the default because it's fastest. Otherwise, use asyncore.
if _is_gevent_monkey_patched():
from cassandra.io.geventreactor import GeventConnection as DefaultConnection
elif _is_eventlet_monkey_patched():
from cassandra.io.eventletreactor import EventletConnection as DefaultConnection
else:
try:
from cassandra.io.libevreactor import LibevConnection as DefaultConnection # NOQA
except ImportError:
from cassandra.io.asyncorereactor import AsyncoreConnection as DefaultConnection # NOQA
# Forces load of utf8 encoding module to avoid deadlock that occurs
# if code that is being imported tries to import the module in a separate
# thread.
# See http://bugs.python.org/issue10923
"".encode('utf8')
log = logging.getLogger(__name__)
DEFAULT_MIN_REQUESTS = 5
DEFAULT_MAX_REQUESTS = 100
DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST = 2
DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST = 8
DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST = 1
DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST = 2
+_GRAPH_PAGING_MIN_DSE_VERSION = Version('6.8.0')
_NOT_SET = object()
class NoHostAvailable(Exception):
"""
Raised when an operation is attempted but all connections are
busy, defunct, closed, or resulted in errors when used.
"""
errors = None
"""
A map of the form ``{ip: exception}`` which details the particular
Exception that was caught for each host the operation was attempted
against.
"""
def __init__(self, message, errors):
Exception.__init__(self, message, errors)
self.errors = errors
def _future_completed(future):
""" Helper for run_in_executor() """
exc = future.exception()
if exc:
log.debug("Failed to run task on executor", exc_info=exc)
def run_in_executor(f):
"""
A decorator to run the given method in the ThreadPoolExecutor.
"""
@wraps(f)
def new_f(self, *args, **kwargs):
if self.is_shutdown:
return
try:
future = self.executor.submit(f, self, *args, **kwargs)
future.add_done_callback(_future_completed)
except Exception:
log.exception("Failed to submit task to executor")
return new_f
_clusters_for_shutdown = set()
def _register_cluster_shutdown(cluster):
_clusters_for_shutdown.add(cluster)
def _discard_cluster_shutdown(cluster):
_clusters_for_shutdown.discard(cluster)
def _shutdown_clusters():
clusters = _clusters_for_shutdown.copy() # copy because shutdown modifies the global set "discard"
for cluster in clusters:
cluster.shutdown()
atexit.register(_shutdown_clusters)
def default_lbp_factory():
if murmur3 is not None:
return TokenAwarePolicy(DCAwareRoundRobinPolicy())
return DCAwareRoundRobinPolicy()
+class ContinuousPagingOptions(object):
+
+ class PagingUnit(object):
+ BYTES = 1
+ ROWS = 2
+
+ page_unit = None
+ """
+ Value of PagingUnit. Default is PagingUnit.ROWS.
+
+ Units refer to the :attr:`~.Statement.fetch_size` or :attr:`~.Session.default_fetch_size`.
+ """
+
+ max_pages = None
+ """
+ Max number of pages to send
+ """
+
+ max_pages_per_second = None
+ """
+ Max rate at which to send pages
+ """
+
+ max_queue_size = None
+ """
+ The maximum queue size for caching pages, only honored for protocol version DSE_V2 and higher,
+ by default it is 4 and it must be at least 2.
+ """
+
+ def __init__(self, page_unit=PagingUnit.ROWS, max_pages=0, max_pages_per_second=0, max_queue_size=4):
+ self.page_unit = page_unit
+ self.max_pages = max_pages
+ self.max_pages_per_second = max_pages_per_second
+ if max_queue_size < 2:
+ raise ValueError('ContinuousPagingOptions.max_queue_size must be 2 or greater')
+ self.max_queue_size = max_queue_size
+
+ def page_unit_bytes(self):
+ return self.page_unit == ContinuousPagingOptions.PagingUnit.BYTES
+
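# A minimal sketch (not part of the driver source) of configuring continuous
# paging: page in row units, throttle the stream, and keep the page queue at
# the minimum allowed size of 2.
def _continuous_paging_options_example():
    options = ContinuousPagingOptions(
        page_unit=ContinuousPagingOptions.PagingUnit.ROWS,
        max_pages=0,                # 0 means no limit on the number of pages
        max_pages_per_second=1000,  # cap the server-side page rate
        max_queue_size=2)           # must be at least 2; default is 4
    assert not options.page_unit_bytes()
    return options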
+
def _addrinfo_or_none(contact_point, port):
"""
A helper function that wraps socket.getaddrinfo and returns None
when it fails to, e.g. resolve one of the hostnames. Used to address
PYTHON-895.
"""
try:
return socket.getaddrinfo(contact_point, port,
socket.AF_UNSPEC, socket.SOCK_STREAM)
except socket.gaierror:
log.debug('Could not resolve hostname "{}" '
'with port {}'.format(contact_point, port))
return None
-def _resolve_contact_points(contact_points, port):
- resolved = tuple(_addrinfo_or_none(p, port)
- for p in contact_points)
-
- if resolved and all((x is None for x in resolved)):
- raise UnresolvableContactPoints(contact_points, port)
-
- resolved = tuple(r for r in resolved if r is not None)
-
- return [endpoint[4][0]
- for addrinfo in resolved
- for endpoint in addrinfo]
+def _execution_profile_to_string(name):
+ default_profiles = {
+ EXEC_PROFILE_DEFAULT: 'EXEC_PROFILE_DEFAULT',
+ EXEC_PROFILE_GRAPH_DEFAULT: 'EXEC_PROFILE_GRAPH_DEFAULT',
+ EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT: 'EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT',
+ EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT: 'EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT',
+ }
+ if name in default_profiles:
+ return default_profiles[name]
-def _execution_profile_to_string(name):
- if name is EXEC_PROFILE_DEFAULT:
- return 'EXEC_PROFILE_DEFAULT'
return '"%s"' % (name,)
class ExecutionProfile(object):
load_balancing_policy = None
"""
An instance of :class:`.policies.LoadBalancingPolicy` or one of its subclasses.
Used in determining host distance for establishing connections, and routing requests.
Defaults to ``TokenAwarePolicy(DCAwareRoundRobinPolicy())`` if not specified
"""
retry_policy = None
"""
An instance of :class:`.policies.RetryPolicy` instance used when :class:`.Statement` objects do not have a
:attr:`~.Statement.retry_policy` explicitly set.
Defaults to :class:`.RetryPolicy` if not specified
"""
consistency_level = ConsistencyLevel.LOCAL_ONE
"""
:class:`.ConsistencyLevel` used when not specified on a :class:`.Statement`.
"""
serial_consistency_level = None
"""
Serial :class:`.ConsistencyLevel` used when not specified on a :class:`.Statement` (for LWT conditional statements).
"""
request_timeout = 10.0
"""
Request timeout used when not overridden in :meth:`.Session.execute`
"""
row_factory = staticmethod(named_tuple_factory)
"""
A callable to format results, accepting ``(colnames, rows)`` where ``colnames`` is a list of column names, and
``rows`` is a list of tuples, with each tuple representing a row of parsed values.
Some example implementations:
- :func:`cassandra.query.tuple_factory` - return a result row as a tuple
- :func:`cassandra.query.named_tuple_factory` - return a result row as a named tuple
- :func:`cassandra.query.dict_factory` - return a result row as a dict
- :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict
"""
speculative_execution_policy = None
"""
An instance of :class:`.policies.SpeculativeExecutionPolicy`
Defaults to :class:`.NoSpeculativeExecutionPolicy` if not specified
"""
- # indicates if set explicitly or uses default values
+ continuous_paging_options = None
+ """
+ *Note:* This feature is implemented to facilitate server integration testing. It is not intended for general use in the Python driver.
+ See :attr:`.Statement.fetch_size` or :attr:`Session.default_fetch_size` for configuring normal paging.
+
+ When set, requests will use DSE's continuous paging, which streams multiple pages without
+ intermediate requests.
+
+ This has the potential to materialize all results in memory at once if the consumer cannot keep up. Use options
+ to constrain page size and rate.
+
+ This is only available for DSE clusters.
+ """
+
+ # indicates if lbp was set explicitly or uses default values
_load_balancing_policy_explicit = False
_consistency_level_explicit = False
def __init__(self, load_balancing_policy=_NOT_SET, retry_policy=None,
consistency_level=_NOT_SET, serial_consistency_level=None,
- request_timeout=10.0, row_factory=named_tuple_factory, speculative_execution_policy=None):
+ request_timeout=10.0, row_factory=named_tuple_factory, speculative_execution_policy=None,
+ continuous_paging_options=None):
if load_balancing_policy is _NOT_SET:
self._load_balancing_policy_explicit = False
self.load_balancing_policy = default_lbp_factory()
else:
self._load_balancing_policy_explicit = True
self.load_balancing_policy = load_balancing_policy
if consistency_level is _NOT_SET:
self._consistency_level_explicit = False
self.consistency_level = ConsistencyLevel.LOCAL_ONE
else:
self._consistency_level_explicit = True
self.consistency_level = consistency_level
self.retry_policy = retry_policy or RetryPolicy()
if (serial_consistency_level is not None and
not ConsistencyLevel.is_serial(serial_consistency_level)):
raise ValueError("serial_consistency_level must be either "
"ConsistencyLevel.SERIAL "
"or ConsistencyLevel.LOCAL_SERIAL.")
self.serial_consistency_level = serial_consistency_level
self.request_timeout = request_timeout
self.row_factory = row_factory
self.speculative_execution_policy = speculative_execution_policy or NoSpeculativeExecutionPolicy()
+ self.continuous_paging_options = continuous_paging_options
+
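# A minimal sketch (not part of the driver source) of building a profile and
# registering it as the cluster default; EXEC_PROFILE_DEFAULT and Cluster are
# defined further down in this module.
def _execution_profile_example():
    from cassandra.policies import WhiteListRoundRobinPolicy
    profile = ExecutionProfile(
        load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1']),
        consistency_level=ConsistencyLevel.LOCAL_QUORUM,
        request_timeout=15.0,
        row_factory=dict_factory)
    return Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: profile})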
+
+class GraphExecutionProfile(ExecutionProfile):
+ graph_options = None
+ """
+ :class:`.GraphOptions` to use with this execution
+
+ Default options for graph queries, initialized as follows by default::
+
+ GraphOptions(graph_language=b'gremlin-groovy')
+
+ See cassandra.graph.GraphOptions
+ """
+
+ def __init__(self, load_balancing_policy=_NOT_SET, retry_policy=None,
+ consistency_level=_NOT_SET, serial_consistency_level=None,
+ request_timeout=30.0, row_factory=None,
+ graph_options=None, continuous_paging_options=_NOT_SET):
+ """
+ Default execution profile for graph execution.
+
+ See :class:`.ExecutionProfile` for base attributes. Note that if not explicitly set,
+ the row_factory and graph_options.graph_protocol are resolved during the query execution.
+ These options will resolve to graph_graphson3_row_factory and GraphProtocol.GRAPHSON_3_0
+ for the core graph engine (DSE 6.8+), otherwise graph_object_row_factory and GraphProtocol.GRAPHSON_1_0
+
+ In addition to default parameters shown in the signature, this profile also defaults ``retry_policy`` to
+ :class:`cassandra.policies.NeverRetryPolicy`.
+ """
+ retry_policy = retry_policy or NeverRetryPolicy()
+ super(GraphExecutionProfile, self).__init__(load_balancing_policy, retry_policy, consistency_level,
+ serial_consistency_level, request_timeout, row_factory,
+ continuous_paging_options=continuous_paging_options)
+ self.graph_options = graph_options or GraphOptions(graph_source=b'g',
+ graph_language=b'gremlin-groovy')
+
+
+class GraphAnalyticsExecutionProfile(GraphExecutionProfile):
+
+ def __init__(self, load_balancing_policy=None, retry_policy=None,
+ consistency_level=_NOT_SET, serial_consistency_level=None,
+ request_timeout=3600. * 24. * 7., row_factory=None,
+ graph_options=None):
+ """
+ Execution profile with timeout and load balancing appropriate for graph analytics queries.
+
+ See also :class:`~.GraphExecutionProfile`.
+
+ In addition to default parameters shown in the signature, this profile also defaults ``retry_policy`` to
+ :class:`cassandra.policies.NeverRetryPolicy`, and ``load_balancing_policy`` to one that targets the current Spark
+ master.
+
+ Note: The graph_options.graph_source is set automatically to b'a' (analytics)
+ when using GraphAnalyticsExecutionProfile. This is mandatory to target analytics nodes.
+ """
+ load_balancing_policy = load_balancing_policy or DefaultLoadBalancingPolicy(default_lbp_factory())
+ graph_options = graph_options or GraphOptions(graph_language=b'gremlin-groovy')
+ super(GraphAnalyticsExecutionProfile, self).__init__(load_balancing_policy, retry_policy, consistency_level,
+ serial_consistency_level, request_timeout, row_factory,
+ graph_options)
+ # ensure the graph_source is analytics, since this is the purpose of the GraphAnalyticsExecutionProfile
+ self.graph_options.set_source_analytics()
class ProfileManager(object):
def __init__(self):
self.profiles = dict()
def _profiles_without_explicit_lbps(self):
names = (profile_name for
profile_name, profile in self.profiles.items()
if not profile._load_balancing_policy_explicit)
return tuple(
'EXEC_PROFILE_DEFAULT' if n is EXEC_PROFILE_DEFAULT else n
for n in names
)
def distance(self, host):
distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values())
return HostDistance.LOCAL if HostDistance.LOCAL in distances else \
HostDistance.REMOTE if HostDistance.REMOTE in distances else \
HostDistance.IGNORED
def populate(self, cluster, hosts):
for p in self.profiles.values():
p.load_balancing_policy.populate(cluster, hosts)
def check_supported(self):
for p in self.profiles.values():
p.load_balancing_policy.check_supported()
def on_up(self, host):
for p in self.profiles.values():
p.load_balancing_policy.on_up(host)
def on_down(self, host):
for p in self.profiles.values():
p.load_balancing_policy.on_down(host)
def on_add(self, host):
for p in self.profiles.values():
p.load_balancing_policy.on_add(host)
def on_remove(self, host):
for p in self.profiles.values():
p.load_balancing_policy.on_remove(host)
@property
def default(self):
"""
internal-only; no checks are done because this entry is populated on cluster init
"""
return self.profiles[EXEC_PROFILE_DEFAULT]
EXEC_PROFILE_DEFAULT = object()
"""
Key for the ``Cluster`` default execution profile, used when no other profile is selected in
``Session.execute(execution_profile)``.
Use this as the key in ``Cluster(execution_profiles)`` to override the default profile.
"""
+EXEC_PROFILE_GRAPH_DEFAULT = object()
+"""
+Key for the default graph execution profile, used when no other profile is selected in
+``Session.execute_graph(execution_profile)``.
+
+Use this as the key in ``Cluster(execution_profiles)``
+to override the default graph profile.
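+
+For illustration, a minimal sketch of overriding the default graph profile (the graph
+name ``my_graph`` and the traversal are placeholders)::
+
+    from cassandra.cluster import Cluster, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT
+    from cassandra.graph import GraphOptions
+
+    ep = GraphExecutionProfile(graph_options=GraphOptions(graph_name='my_graph'))
+    cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep})
+    session = cluster.connect()
+    session.execute_graph("g.V().limit(1)")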
+"""
+
+EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT = object()
+"""
+Key for the default graph system execution profile. This can be used for graph statements using the DSE graph
+system API.
+
+Selected using ``Session.execute_graph(execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT)``.
+"""
+
+EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT = object()
+"""
+Key for the default graph analytics execution profile. This can be used for graph statements intended to
+use Spark/analytics as the traversal source.
+
+Selected using ``Session.execute_graph(execution_profile=EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT)``.
+"""
+
class _ConfigMode(object):
UNCOMMITTED = 0
LEGACY = 1
PROFILES = 2
class Cluster(object):
"""
The main class to use when interacting with a Cassandra cluster.
Typically, one instance of this class will be created for each
separate Cassandra cluster that your application interacts with.
Example usage::
>>> from cassandra.cluster import Cluster
>>> cluster = Cluster(['192.168.1.1', '192.168.1.2'])
>>> session = cluster.connect()
>>> session.execute("CREATE KEYSPACE ...")
>>> ...
>>> cluster.shutdown()
``Cluster`` and ``Session`` also provide context management functions
which implicitly handle shutdown when leaving scope.
"""
contact_points = ['127.0.0.1']
"""
The list of contact points to try connecting for cluster discovery. A
- contact point can be a string (ip, hostname) or a
+ contact point can be a string (ip or hostname), a tuple (ip/hostname, port) or a
:class:`.connection.EndPoint` instance.
Defaults to loopback interface.
Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit
local_dc set (as is the default), the DC is chosen from an arbitrary
host in contact_points. In this case, contact_points should contain
only nodes from a single, local DC.
Note: In the next major version, if you specify contact points, you will
also be required to also explicitly specify a load-balancing policy. This
change will help prevent cases where users had hard-to-debug issues
surrounding unintuitive default load-balancing policy behavior.
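For illustration, a minimal sketch showing the accepted forms of contact points
(the addresses are placeholders)::

    from cassandra.connection import DefaultEndPoint

    cluster = Cluster(contact_points=[
        '10.0.0.1',                          # ip or hostname
        ('10.0.0.2', 9043),                  # (host, port) tuple
        DefaultEndPoint('10.0.0.3', 9042),   # EndPoint instance
    ])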
"""
# tracks if contact_points was set explicitly or with default values
_contact_points_explicit = None
port = 9042
"""
The server-side port to open connections to. Defaults to 9042.
"""
cql_version = None
"""
If a specific version of CQL should be used, this may be set to that
string version. Otherwise, the highest CQL version supported by the
server will be automatically used.
"""
- protocol_version = ProtocolVersion.V4
+ protocol_version = ProtocolVersion.DSE_V2
"""
The maximum version of the native protocol to use.
See :class:`.ProtocolVersion` for more information about versions.
If not set in the constructor, the driver will automatically downgrade
version based on a negotiation with the server, but it is most efficient
to set this to the maximum supported by your version of Cassandra.
Setting this will also prevent conflicting versions negotiated if your
cluster is upgraded.
"""
allow_beta_protocol_version = False
"""
Setting true injects a flag in all messages that makes the server accept and use "beta" protocol version.
Used for testing new protocol features incrementally before the new version is complete.
"""
no_compact = False
compression = True
"""
Controls compression for communications between the driver and Cassandra.
If left as the default of :const:`True`, either lz4 or snappy compression
may be used, depending on what is supported by both the driver
and Cassandra. If both are fully supported, lz4 will be preferred.
You may also set this to 'snappy' or 'lz4' to request that specific
compression type.
Setting this to :const:`False` disables compression.
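For example, to request lz4 explicitly (this assumes the optional ``lz4`` package is
installed)::

    cluster = Cluster(compression='lz4')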
"""
_auth_provider = None
_auth_provider_callable = None
@property
def auth_provider(self):
"""
When :attr:`~.Cluster.protocol_version` is 2 or higher, this should
be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`,
such as :class:`~.PlainTextAuthProvider`.
When :attr:`~.Cluster.protocol_version` is 1, this should be
a function that accepts one argument, the IP address of a node,
and returns a dict of credentials for that node.
When not using authentication, this should be left as :const:`None`.
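For illustration, a minimal sketch using plain-text authentication (the credentials are
placeholders)::

    from cassandra.auth import PlainTextAuthProvider

    auth_provider = PlainTextAuthProvider(username='user', password='pass')
    cluster = Cluster(auth_provider=auth_provider)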
"""
return self._auth_provider
@auth_provider.setter # noqa
def auth_provider(self, value):
if not value:
self._auth_provider = value
return
try:
self._auth_provider_callable = value.new_authenticator
except AttributeError:
if self.protocol_version > 1:
raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider "
"interface when protocol_version >= 2")
elif not callable(value):
raise TypeError("auth_provider must be callable when protocol_version == 1")
self._auth_provider_callable = value
self._auth_provider = value
_load_balancing_policy = None
@property
def load_balancing_policy(self):
"""
An instance of :class:`.policies.LoadBalancingPolicy` or
one of its subclasses.
.. versionchanged:: 2.6.0
Defaults to :class:`~.TokenAwarePolicy` (:class:`~.DCAwareRoundRobinPolicy`)
when using CPython (where the murmur3 extension is available); :class:`~.DCAwareRoundRobinPolicy`
otherwise. The default local DC will be chosen from contact points.
**Please see** :class:`~.DCAwareRoundRobinPolicy` **for a discussion on default behavior with respect to
DC locality and remote nodes.**
"""
return self._load_balancing_policy
@load_balancing_policy.setter
def load_balancing_policy(self, lbp):
if self._config_mode == _ConfigMode.PROFILES:
raise ValueError("Cannot set Cluster.load_balancing_policy while using Configuration Profiles. Set this in a profile instead.")
self._load_balancing_policy = lbp
self._config_mode = _ConfigMode.LEGACY
@property
def _default_load_balancing_policy(self):
return self.profile_manager.default.load_balancing_policy
reconnection_policy = ExponentialReconnectionPolicy(1.0, 600.0)
"""
An instance of :class:`.policies.ReconnectionPolicy`. Defaults to an instance
of :class:`.ExponentialReconnectionPolicy` with a base delay of one second and
a max delay of ten minutes.
"""
_default_retry_policy = RetryPolicy()
@property
def default_retry_policy(self):
"""
A default :class:`.policies.RetryPolicy` instance to use for all
:class:`.Statement` objects which do not have a :attr:`~.Statement.retry_policy`
explicitly set.
"""
return self._default_retry_policy
@default_retry_policy.setter
def default_retry_policy(self, policy):
if self._config_mode == _ConfigMode.PROFILES:
raise ValueError("Cannot set Cluster.default_retry_policy while using Configuration Profiles. Set this in a profile instead.")
self._default_retry_policy = policy
self._config_mode = _ConfigMode.LEGACY
conviction_policy_factory = SimpleConvictionPolicy
"""
A factory function which creates instances of
:class:`.policies.ConvictionPolicy`. Defaults to
:class:`.policies.SimpleConvictionPolicy`.
"""
address_translator = IdentityTranslator()
"""
:class:`.policies.AddressTranslator` instance to be used in translating server node addresses
to driver connection addresses.
"""
connect_to_remote_hosts = True
"""
If left as :const:`True`, hosts that are considered :attr:`~.HostDistance.REMOTE`
by the :attr:`~.Cluster.load_balancing_policy` will have a connection
opened to them. Otherwise, they will not have a connection opened to them.
Note that the default load balancing policy ignores remote hosts by default.
.. versionadded:: 2.1.0
"""
metrics_enabled = False
"""
Whether or not metric collection is enabled. If enabled, :attr:`.metrics`
will be an instance of :class:`~cassandra.metrics.Metrics`.
"""
metrics = None
"""
An instance of :class:`cassandra.metrics.Metrics` if :attr:`.metrics_enabled` is
:const:`True`, else :const:`None`.
"""
ssl_options = None
"""
Using ssl_options without ssl_context is deprecated and will be removed in the
next major release.
An optional dict which will be used as kwargs for ``ssl.SSLContext.wrap_socket`` (or
``ssl.wrap_socket()`` if used without ssl_context) when new sockets are created.
This should be used when client encryption is enabled in Cassandra.
The following documentation only applies when ssl_options is used without ssl_context.
By default, a ``ca_certs`` value should be supplied (the value should be
a string pointing to the location of the CA certs file), and you probably
- want to specify ``ssl_version`` as ``ssl.PROTOCOL_TLSv1`` to match
+ want to specify ``ssl_version`` as ``ssl.PROTOCOL_TLS`` to match
Cassandra's default protocol.
.. versionchanged:: 3.3.0
In addition to ``wrap_socket`` kwargs, clients may also specify ``'check_hostname': True`` to verify the cert hostname
as outlined in RFC 2818 and RFC 6125. Note that this requires the certificate to be transferred, so
should almost always require the option ``'cert_reqs': ssl.CERT_REQUIRED``. Note also that this functionality was not built into
Python standard library until (2.7.9, 3.2). To enable this mechanism in earlier versions, patch ``ssl.match_hostname``
with a custom or `back-ported function `_.
"""
ssl_context = None
"""
An optional ``ssl.SSLContext`` instance which will be used when new sockets are created.
This should be used when client encryption is enabled in Cassandra.
``wrap_socket`` options can be set using :attr:`~Cluster.ssl_options`. ssl_options will
be used as kwargs for ``ssl.SSLContext.wrap_socket``.
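For illustration, a minimal client-encryption sketch (the certificate path is a placeholder)::

    import ssl

    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS)
    ssl_context.load_verify_locations('/path/to/rootca.crt')
    ssl_context.verify_mode = ssl.CERT_REQUIRED
    cluster = Cluster(ssl_context=ssl_context)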
.. versionadded:: 3.17.0
"""
sockopts = None
"""
An optional list of tuples which will be used as arguments to
``socket.setsockopt()`` for all created sockets.
Note: some drivers find setting TCPNODELAY beneficial in the context of
their execution model. It was not found generally beneficial for this driver.
To try with your own workload, set ``sockopts = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]``
"""
max_schema_agreement_wait = 10
"""
The maximum duration (in seconds) that the driver will wait for schema
agreement across the cluster. Defaults to ten seconds.
If set <= 0, the driver will bypass schema agreement waits altogether.
"""
metadata = None
"""
An instance of :class:`cassandra.metadata.Metadata`.
"""
connection_class = DefaultConnection
"""
This determines what event loop system will be used for managing
I/O with Cassandra. These are the current options:
* :class:`cassandra.io.asyncorereactor.AsyncoreConnection`
* :class:`cassandra.io.libevreactor.LibevConnection`
* :class:`cassandra.io.eventletreactor.EventletConnection` (requires monkey-patching - see doc for details)
* :class:`cassandra.io.geventreactor.GeventConnection` (requires monkey-patching - see doc for details)
* :class:`cassandra.io.twistedreactor.TwistedConnection`
* EXPERIMENTAL: :class:`cassandra.io.asyncioreactor.AsyncioConnection`
By default, ``AsyncoreConnection`` will be used, which uses
the ``asyncore`` module in the Python standard library.
If ``libev`` is installed, ``LibevConnection`` will be used instead.
If ``gevent`` or ``eventlet`` monkey-patching is detected, the corresponding
connection class will be used automatically.
``AsyncioConnection``, which uses the ``asyncio`` module in the Python
standard library, is also available, but currently experimental. Note that
it requires ``asyncio`` features that were only introduced in the 3.4 line
in 3.4.6, and in the 3.5 line in 3.5.1.
"""
control_connection_timeout = 2.0
"""
A timeout, in seconds, for queries made by the control connection, such
as querying the current schema and information about nodes in the cluster.
If set to :const:`None`, there will be no timeout for these queries.
"""
idle_heartbeat_interval = 30
"""
Interval, in seconds, on which to heartbeat idle connections. This helps
keep connections open through network devices that expire idle connections.
It also helps discover bad connections early in low-traffic scenarios.
Setting to zero disables heartbeats.
"""
idle_heartbeat_timeout = 30
"""
Timeout, in seconds, for which the heartbeat waits for idle connection responses.
Lowering this value can help discover bad connections earlier.
"""
schema_event_refresh_window = 2
"""
Window, in seconds, within which a schema component will be refreshed after
receiving a schema_change event.
The driver delays a random amount of time in the range [0.0, window)
before executing the refresh. This serves two purposes:
1.) Spread the refresh for deployments with large fanout from C* to client tier,
preventing a 'thundering herd' problem with many clients refreshing simultaneously.
2.) Remove redundant refreshes. Redundant events arriving within the delay period
are discarded, and only one refresh is executed.
Setting this to zero will execute refreshes immediately.
Setting this negative will disable schema refreshes in response to push events
(refreshes will still occur in response to schema change responses to DDL statements
executed by Sessions of this Cluster).
"""
topology_event_refresh_window = 10
"""
Window, in seconds, within which the node and token list will be refreshed after
receiving a topology_change event.
Setting this to zero will execute refreshes immediately.
Setting this negative will disable node refreshes in response to push events.
See :attr:`.schema_event_refresh_window` for discussion of rationale
"""
status_event_refresh_window = 2
"""
Window, in seconds, within which the driver will start the reconnect after
receiving a status_change event.
Setting this to zero will connect immediately.
This is primarily used to avoid 'thundering herd' in deployments with large fanout from cluster to clients.
When nodes come up, clients attempt to reprepare prepared statements (depending on :attr:`.reprepare_on_up`), and
establish connection pools. This can cause a rush of connections and queries if not mitigated with this factor.
"""
prepare_on_all_hosts = True
"""
Specifies whether statements should be prepared on all hosts, or just one.
This can reasonably be disabled on long-running applications with numerous clients preparing statements on startup,
where a randomized initial condition of the load balancing policy can be expected to distribute prepares from
different clients across the cluster.
"""
reprepare_on_up = True
"""
Specifies whether all known prepared statements should be prepared on a node when it comes up.
May be used to avoid overwhelming a node on return, or if it is suspected that the node was only marked down due to a
network issue. If statements are not reprepared, they are prepared on the first execution, causing
an extra roundtrip for one or more client requests.
"""
connect_timeout = 5
"""
Timeout, in seconds, for creating new connections.
This timeout covers the entire connection negotiation, including TCP
establishment, options passing, and authentication.
"""
timestamp_generator = None
"""
An object, shared between all sessions created by this cluster instance,
that generates timestamps when client-side timestamp generation is enabled.
By default, each :class:`Cluster` uses a new
:class:`~.MonotonicTimestampGenerator`.
Applications can set this value for custom timestamp behavior. See the
documentation for :meth:`Session.timestamp_generator`.
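For illustration, a minimal sketch of a custom generator (the function name is a
placeholder; any zero-argument callable returning an integer timestamp in microseconds works)::

    import time

    def my_timestamp_generator():
        # microseconds since the unix epoch
        return int(time.time() * 1e6)

    cluster = Cluster(timestamp_generator=my_timestamp_generator)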
"""
+ monitor_reporting_enabled = True
+ """
+ A boolean indicating whether monitor reporting, which sends gathered data to
+ Insights when running against DSE 6.8 and higher, is enabled.
+ """
+
+ monitor_reporting_interval = 30
+ """
+ The interval, in seconds, at which monitor reporting data is sent to
+ Insights when running against DSE 6.8 and higher.
+ """
+
+ client_id = None
+ """
+ A UUID that uniquely identifies this Cluster object to Insights. This will
+ be generated automatically unless the user provides one.
+ """
+
+ application_name = ''
+ """
+ A string identifying this application to Insights.
+ """
+
+ application_version = ''
+ """
+ A string identifying this application's version to Insights.
+ """
+
cloud = None
"""
A dict of the cloud configuration. Example::
{
# path to the secure connect bundle
- 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip'
+ 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip',
+
+ # optional config options
+ 'use_default_tempdir': True # use the system temp dir for the zip extraction
}
The zip file will be temporarily extracted in the same directory to
load the configuration and certificates.
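For illustration, a minimal connection sketch (the bundle path and credentials are
placeholders)::

    from cassandra.auth import PlainTextAuthProvider

    cloud_config = {'secure_connect_bundle': '/path/to/secure-connect-dbname.zip'}
    auth_provider = PlainTextAuthProvider('clientId', 'clientSecret')
    cluster = Cluster(cloud=cloud_config, auth_provider=auth_provider)
    session = cluster.connect()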
"""
@property
def schema_metadata_enabled(self):
"""
Flag indicating whether internal schema metadata is updated.
When disabled, the driver does not populate Cluster.metadata.keyspaces on connect, or on schema change events. This
can be used to speed initial connection, and reduce load on client and server during operation. Turning this off
gives up token-aware request routing, and programmatic inspection of the metadata model.
"""
return self.control_connection._schema_meta_enabled
@schema_metadata_enabled.setter
def schema_metadata_enabled(self, enabled):
self.control_connection._schema_meta_enabled = bool(enabled)
@property
def token_metadata_enabled(self):
"""
Flag indicating whether internal token metadata is updated.
When disabled, the driver does not query node token information on connect, or on topology change events. This
can be used to speed initial connection, and reduce load on client and server during operation. It is most useful
in large clusters using vnodes, where the token map can be expensive to compute. Turning this off
gives up token-aware request routing, and programmatic inspection of the token ring.
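For example, both metadata refreshes can be disabled at construction time (a sketch; this
trades token-aware routing and metadata inspection for faster startup)::

    cluster = Cluster(schema_metadata_enabled=False, token_metadata_enabled=False)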
"""
return self.control_connection._token_meta_enabled
@token_metadata_enabled.setter
def token_metadata_enabled(self, enabled):
self.control_connection._token_meta_enabled = bool(enabled)
endpoint_factory = None
"""
An :class:`~.connection.EndPointFactory` instance to use internally when creating
a socket connection to a node. You can ignore this unless you need a special
connection mechanism.
"""
profile_manager = None
_config_mode = _ConfigMode.UNCOMMITTED
sessions = None
control_connection = None
scheduler = None
executor = None
is_shutdown = False
_is_setup = False
_prepared_statements = None
_prepared_statement_lock = None
_idle_heartbeat = None
_protocol_version_explicit = False
_discount_down_events = True
_user_types = None
"""
A map of {keyspace: {type_name: UserType}}
"""
_listeners = None
_listener_lock = None
def __init__(self,
contact_points=_NOT_SET,
port=9042,
compression=True,
auth_provider=None,
load_balancing_policy=None,
reconnection_policy=None,
default_retry_policy=None,
conviction_policy_factory=None,
metrics_enabled=False,
connection_class=None,
ssl_options=None,
sockopts=None,
cql_version=None,
protocol_version=_NOT_SET,
executor_threads=2,
max_schema_agreement_wait=10,
control_connection_timeout=2.0,
idle_heartbeat_interval=30,
schema_event_refresh_window=2,
topology_event_refresh_window=10,
connect_timeout=5,
schema_metadata_enabled=True,
token_metadata_enabled=True,
address_translator=None,
status_event_refresh_window=2,
prepare_on_all_hosts=True,
reprepare_on_up=True,
execution_profiles=None,
allow_beta_protocol_version=False,
timestamp_generator=None,
idle_heartbeat_timeout=30,
no_compact=False,
ssl_context=None,
endpoint_factory=None,
+ application_name=None,
+ application_version=None,
+ monitor_reporting_enabled=True,
+ monitor_reporting_interval=30,
+ client_id=None,
cloud=None):
"""
``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
establishing connection pools or refreshing metadata.
Any of the mutable Cluster attributes may be set as keyword arguments to the constructor.
"""
+ if connection_class is not None:
+ self.connection_class = connection_class
if cloud is not None:
+ self.cloud = cloud
if contact_points is not _NOT_SET or endpoint_factory or ssl_context or ssl_options:
raise ValueError("contact_points, endpoint_factory, ssl_context, and ssl_options "
"cannot be specified with a cloud configuration")
- cloud_config = dscloud.get_cloud_config(cloud)
+ uses_twisted = TwistedConnection and issubclass(self.connection_class, TwistedConnection)
+ uses_eventlet = EventletConnection and issubclass(self.connection_class, EventletConnection)
+ cloud_config = dscloud.get_cloud_config(cloud, create_pyopenssl_context=uses_twisted or uses_eventlet)
ssl_context = cloud_config.ssl_context
ssl_options = {'check_hostname': True}
if (auth_provider is None and cloud_config.username
and cloud_config.password):
auth_provider = PlainTextAuthProvider(cloud_config.username, cloud_config.password)
endpoint_factory = SniEndPointFactory(cloud_config.sni_host, cloud_config.sni_port)
contact_points = [
endpoint_factory.create_from_sni(host_id)
for host_id in cloud_config.host_ids
]
if contact_points is not None:
if contact_points is _NOT_SET:
self._contact_points_explicit = False
contact_points = ['127.0.0.1']
else:
self._contact_points_explicit = True
if isinstance(contact_points, six.string_types):
raise TypeError("contact_points should not be a string, it should be a sequence (e.g. list) of strings")
if None in contact_points:
raise ValueError("contact_points should not contain None (it can resolve to localhost)")
self.contact_points = contact_points
self.port = port
self.endpoint_factory = endpoint_factory or DefaultEndPointFactory(port=self.port)
self.endpoint_factory.configure(self)
- raw_contact_points = [cp for cp in self.contact_points if not isinstance(cp, EndPoint)]
+ raw_contact_points = []
+ for cp in [cp for cp in self.contact_points if not isinstance(cp, EndPoint)]:
+ raw_contact_points.append(cp if isinstance(cp, tuple) else (cp, port))
+
self.endpoints_resolved = [cp for cp in self.contact_points if isinstance(cp, EndPoint)]
+ self._endpoint_map_for_insights = {repr(ep): '{ip}:{port}'.format(ip=ep.address, port=ep.port)
+ for ep in self.endpoints_resolved}
+
+ strs_resolved_map = _resolve_contact_points_to_string_map(raw_contact_points)
+ self.endpoints_resolved.extend(list(chain(
+ *[
+ [DefaultEndPoint(ip, port) for ip, port in xs if ip is not None]
+ for xs in strs_resolved_map.values() if xs is not None
+ ]
+ )))
- try:
- self.endpoints_resolved += [DefaultEndPoint(address, self.port)
- for address in _resolve_contact_points(raw_contact_points, self.port)]
- except UnresolvableContactPoints:
- # rethrow if no EndPoint was provided
- if not self.endpoints_resolved:
- raise
+ self._endpoint_map_for_insights.update(
+ {key: ['{ip}:{port}'.format(ip=ip, port=port) for ip, port in value]
+ for key, value in strs_resolved_map.items() if value is not None}
+ )
+
+ if contact_points and (not self.endpoints_resolved):
+ # only want to raise here if the user specified CPs but resolution failed
+ raise UnresolvableContactPoints(self._endpoint_map_for_insights)
self.compression = compression
if protocol_version is not _NOT_SET:
self.protocol_version = protocol_version
self._protocol_version_explicit = True
self.allow_beta_protocol_version = allow_beta_protocol_version
self.no_compact = no_compact
self.auth_provider = auth_provider
if load_balancing_policy is not None:
if isinstance(load_balancing_policy, type):
raise TypeError("load_balancing_policy should not be a class, it should be an instance of that class")
self.load_balancing_policy = load_balancing_policy
else:
self._load_balancing_policy = default_lbp_factory() # set internal attribute to avoid committing to legacy config mode
if reconnection_policy is not None:
if isinstance(reconnection_policy, type):
raise TypeError("reconnection_policy should not be a class, it should be an instance of that class")
self.reconnection_policy = reconnection_policy
if default_retry_policy is not None:
if isinstance(default_retry_policy, type):
raise TypeError("default_retry_policy should not be a class, it should be an instance of that class")
self.default_retry_policy = default_retry_policy
if conviction_policy_factory is not None:
if not callable(conviction_policy_factory):
raise ValueError("conviction_policy_factory must be callable")
self.conviction_policy_factory = conviction_policy_factory
if address_translator is not None:
if isinstance(address_translator, type):
raise TypeError("address_translator should not be a class, it should be an instance of that class")
self.address_translator = address_translator
- if connection_class is not None:
- self.connection_class = connection_class
-
if timestamp_generator is not None:
if not callable(timestamp_generator):
raise ValueError("timestamp_generator must be callable")
self.timestamp_generator = timestamp_generator
else:
self.timestamp_generator = MonotonicTimestampGenerator()
self.profile_manager = ProfileManager()
self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile(
self.load_balancing_policy,
self.default_retry_policy,
request_timeout=Session._default_timeout,
row_factory=Session._row_factory
)
+
# legacy mode if either of these is not default
if load_balancing_policy or default_retry_policy:
if execution_profiles:
raise ValueError("Clusters constructed with execution_profiles should not specify legacy parameters "
"load_balancing_policy or default_retry_policy. Configure this in a profile instead.")
self._config_mode = _ConfigMode.LEGACY
warn("Legacy execution parameters will be removed in 4.0. Consider using "
"execution profiles.", DeprecationWarning)
else:
+ profiles = self.profile_manager.profiles
if execution_profiles:
- self.profile_manager.profiles.update(execution_profiles)
+ profiles.update(execution_profiles)
self._config_mode = _ConfigMode.PROFILES
- if self._contact_points_explicit:
+ lbp = DefaultLoadBalancingPolicy(self.profile_manager.default.load_balancing_policy)
+ profiles.setdefault(EXEC_PROFILE_GRAPH_DEFAULT, GraphExecutionProfile(load_balancing_policy=lbp))
+ profiles.setdefault(EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT,
+ GraphExecutionProfile(load_balancing_policy=lbp, request_timeout=60. * 3.))
+ profiles.setdefault(EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT,
+ GraphAnalyticsExecutionProfile(load_balancing_policy=lbp))
+
+ if self._contact_points_explicit and not self.cloud: # avoid this warning for cloud users.
if self._config_mode is _ConfigMode.PROFILES:
default_lbp_profiles = self.profile_manager._profiles_without_explicit_lbps()
if default_lbp_profiles:
log.warning(
'Cluster.__init__ called with contact_points '
'specified, but load-balancing policies are not '
'specified in some ExecutionProfiles. In the next '
'major version, this will raise an error; please '
'specify a load-balancing policy. '
'(contact_points = {cp}, '
'EPs without explicit LBPs = {eps})'
''.format(cp=contact_points, eps=default_lbp_profiles))
else:
if load_balancing_policy is None:
log.warning(
'Cluster.__init__ called with contact_points '
'specified, but no load_balancing_policy. In the next '
'major version, this will raise an error; please '
'specify a load-balancing policy. '
'(contact_points = {cp}, lbp = {lbp})'
''.format(cp=contact_points, lbp=load_balancing_policy))
self.metrics_enabled = metrics_enabled
if ssl_options and not ssl_context:
warn('Using ssl_options without ssl_context is '
'deprecated and will result in an error in '
'the next major release. Please use ssl_context '
'to prepare for that release.',
DeprecationWarning)
self.ssl_options = ssl_options
self.ssl_context = ssl_context
self.sockopts = sockopts
self.cql_version = cql_version
self.max_schema_agreement_wait = max_schema_agreement_wait
self.control_connection_timeout = control_connection_timeout
self.idle_heartbeat_interval = idle_heartbeat_interval
self.idle_heartbeat_timeout = idle_heartbeat_timeout
self.schema_event_refresh_window = schema_event_refresh_window
self.topology_event_refresh_window = topology_event_refresh_window
self.status_event_refresh_window = status_event_refresh_window
self.connect_timeout = connect_timeout
self.prepare_on_all_hosts = prepare_on_all_hosts
self.reprepare_on_up = reprepare_on_up
+ self.monitor_reporting_enabled = monitor_reporting_enabled
+ self.monitor_reporting_interval = monitor_reporting_interval
self._listeners = set()
self._listener_lock = Lock()
# let Session objects be GC'ed (and shutdown) when the user no longer
# holds a reference.
self.sessions = WeakSet()
self.metadata = Metadata()
self.control_connection = None
self._prepared_statements = WeakValueDictionary()
self._prepared_statement_lock = Lock()
self._user_types = defaultdict(dict)
self._min_requests_per_connection = {
HostDistance.LOCAL: DEFAULT_MIN_REQUESTS,
HostDistance.REMOTE: DEFAULT_MIN_REQUESTS
}
self._max_requests_per_connection = {
HostDistance.LOCAL: DEFAULT_MAX_REQUESTS,
HostDistance.REMOTE: DEFAULT_MAX_REQUESTS
}
self._core_connections_per_host = {
HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST
}
self._max_connections_per_host = {
HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST
}
self.executor = self._create_thread_pool_executor(max_workers=executor_threads)
self.scheduler = _Scheduler(self.executor)
self._lock = RLock()
if self.metrics_enabled:
from cassandra.metrics import Metrics
self.metrics = Metrics(weakref.proxy(self))
self.control_connection = ControlConnection(
self, self.control_connection_timeout,
self.schema_event_refresh_window, self.topology_event_refresh_window,
self.status_event_refresh_window,
schema_metadata_enabled, token_metadata_enabled)
+ if client_id is None:
+ self.client_id = uuid.uuid4()
+ if application_name is not None:
+ self.application_name = application_name
+ if application_version is not None:
+ self.application_version = application_version
+
def _create_thread_pool_executor(self, **kwargs):
"""
Create a ThreadPoolExecutor for the cluster. In most cases, the built-in
`concurrent.futures.ThreadPoolExecutor` is used.
- Python 3.7 and Eventlet cause the `concurrent.futures.ThreadPoolExecutor`
+ Python 3.7+ and Eventlet cause the `concurrent.futures.ThreadPoolExecutor`
to hang indefinitely. In that case, the user needs to have the `futurist`
package so we can use the `futurist.GreenThreadPoolExecutor` class instead.
:param kwargs: All keyword args are passed to the ThreadPoolExecutor constructor.
:return: A ThreadPoolExecutor instance.
"""
tpe_class = ThreadPoolExecutor
if sys.version_info >= (3, 7):
try:
from cassandra.io.eventletreactor import EventletConnection
is_eventlet = issubclass(self.connection_class, EventletConnection)
except:
# Eventlet is not available or can't be detected
return tpe_class(**kwargs)
if is_eventlet:
try:
from futurist import GreenThreadPoolExecutor
tpe_class = GreenThreadPoolExecutor
except ImportError:
# futurist is not available
raise ImportError(
- ("Python 3.7 and Eventlet cause the `concurrent.futures.ThreadPoolExecutor` "
+ ("Python 3.7+ and Eventlet cause the `concurrent.futures.ThreadPoolExecutor` "
"to hang indefinitely. If you want to use the Eventlet reactor, you "
"need to install the `futurist` package to allow the driver to use "
"the GreenThreadPoolExecutor. See https://github.com/eventlet/eventlet/issues/508 "
"for more details."))
return tpe_class(**kwargs)
def register_user_type(self, keyspace, user_type, klass):
"""
Registers a class to use to represent a particular user-defined type.
Query parameters for this user-defined type will be assumed to be
instances of `klass`. Result sets for this user-defined type will
be instances of `klass`. If no class is registered for a user-defined
type, a namedtuple will be used for result sets, and non-prepared
statements may not encode parameters for this type correctly.
`keyspace` is the name of the keyspace that the UDT is defined in.
`user_type` is the string name of the UDT to register the mapping
for.
`klass` should be a class with attributes whose names match the
fields of the user-defined type. The constructor must accept kwargs
for each of the fields in the UDT.
This method should only be called after the type has been created
within Cassandra.
Example::
cluster = Cluster(protocol_version=3)
session = cluster.connect()
session.set_keyspace('mykeyspace')
session.execute("CREATE TYPE address (street text, zipcode int)")
session.execute("CREATE TABLE users (id int PRIMARY KEY, location address)")
# create a class to map to the "address" UDT
class Address(object):
def __init__(self, street, zipcode):
self.street = street
self.zipcode = zipcode
cluster.register_user_type('mykeyspace', 'address', Address)
# insert a row using an instance of Address
session.execute("INSERT INTO users (id, location) VALUES (%s, %s)",
(0, Address("123 Main St.", 78723)))
# results will include Address instances
results = session.execute("SELECT * FROM users")
row = results[0]
print(row.id, row.location.street, row.location.zipcode)
"""
if self.protocol_version < 3:
log.warning("User Type serialization is only supported in native protocol version 3+ (%d in use). "
"CQL encoding for simple statements will still work, but named tuples will "
"be returned when reading type %s.%s.", self.protocol_version, keyspace, user_type)
self._user_types[keyspace][user_type] = klass
for session in tuple(self.sessions):
session.user_type_registered(keyspace, user_type, klass)
UserType.evict_udt_class(keyspace, user_type)
def add_execution_profile(self, name, profile, pool_wait_timeout=5):
"""
Adds an :class:`.ExecutionProfile` to the cluster. This makes it available for use by ``name`` in :meth:`.Session.execute`
and :meth:`.Session.execute_async`. This method will raise if the profile already exists.
Normally profiles will be injected at cluster initialization via ``Cluster(execution_profiles)``. This method
provides a way of adding them dynamically.
Adding a new profile updates the connection pools according to the specified ``load_balancing_policy``. By default,
this method will wait up to five seconds for the pool creation to complete, so the profile can be used immediately
upon return. This behavior can be controlled using ``pool_wait_timeout`` (see
`concurrent.futures.wait `_
for timeout semantics).
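A minimal usage sketch (the profile name and settings are illustrative)::

    from cassandra import ConsistencyLevel
    from cassandra.cluster import Cluster, ExecutionProfile

    cluster = Cluster()
    session = cluster.connect()
    cluster.add_execution_profile(
        'long_running',
        ExecutionProfile(request_timeout=60.0,
                         consistency_level=ConsistencyLevel.LOCAL_QUORUM))
    session.execute("SELECT release_version FROM system.local",
                    execution_profile='long_running')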
"""
if not isinstance(profile, ExecutionProfile):
raise TypeError("profile must be an instance of ExecutionProfile")
if self._config_mode == _ConfigMode.LEGACY:
raise ValueError("Cannot add execution profiles when legacy parameters are set explicitly.")
if name in self.profile_manager.profiles:
raise ValueError("Profile {} already exists".format(name))
contact_points_but_no_lbp = (
self._contact_points_explicit and not
profile._load_balancing_policy_explicit)
if contact_points_but_no_lbp:
log.warning(
'Tried to add an ExecutionProfile with name {name}. '
'{self} was explicitly configured with contact_points, but '
'{ep} was not explicitly configured with a '
'load_balancing_policy. In the next major version, trying to '
'add an ExecutionProfile without an explicitly configured LBP '
'to a cluster with explicitly configured contact_points will '
'raise an exception; please specify a load-balancing policy '
'in the ExecutionProfile.'
''.format(name=_execution_profile_to_string(name), self=self, ep=profile))
self.profile_manager.profiles[name] = profile
profile.load_balancing_policy.populate(self, self.metadata.all_hosts())
# on_up after populate allows things like DCA LBP to choose default local dc
for host in filter(lambda h: h.is_up, self.metadata.all_hosts()):
profile.load_balancing_policy.on_up(host)
futures = set()
for session in tuple(self.sessions):
self._set_default_dbaas_consistency(session)
futures.update(session.update_created_pools())
_, not_done = wait_futures(futures, pool_wait_timeout)
if not_done:
raise OperationTimedOut("Failed to create all new connection pools in the %ss timeout." % pool_wait_timeout)
def get_min_requests_per_connection(self, host_distance):
return self._min_requests_per_connection[host_distance]
def set_min_requests_per_connection(self, host_distance, min_requests):
"""
Sets a threshold for concurrent requests per connection, below which
connections will be considered for disposal (down to core connections;
see :meth:`~Cluster.set_core_connections_per_host`).
Pertains to connection pool management in protocol versions {1,2}.
"""
if self.protocol_version >= 3:
raise UnsupportedOperation(
"Cluster.set_min_requests_per_connection() only has an effect "
"when using protocol_version 1 or 2.")
if min_requests < 0 or min_requests > 126 or \
min_requests >= self._max_requests_per_connection[host_distance]:
raise ValueError("min_requests must be 0-126 and less than the max_requests for this host_distance (%d)" %
(self._max_requests_per_connection[host_distance],))
self._min_requests_per_connection[host_distance] = min_requests
def get_max_requests_per_connection(self, host_distance):
return self._max_requests_per_connection[host_distance]
def set_max_requests_per_connection(self, host_distance, max_requests):
"""
Sets a threshold for concurrent requests per connection, above which new
connections will be created to a host (up to max connections;
see :meth:`~Cluster.set_max_connections_per_host`).
Pertains to connection pool management in protocol versions {1,2}.
"""
if self.protocol_version >= 3:
raise UnsupportedOperation(
"Cluster.set_max_requests_per_connection() only has an effect "
"when using protocol_version 1 or 2.")
if max_requests < 1 or max_requests > 127 or \
max_requests <= self._min_requests_per_connection[host_distance]:
raise ValueError("max_requests must be 1-127 and greater than the min_requests for this host_distance (%d)" %
(self._min_requests_per_connection[host_distance],))
self._max_requests_per_connection[host_distance] = max_requests
def get_core_connections_per_host(self, host_distance):
"""
Gets the minimum number of connections per Session that will be opened
for each host with :class:`~.HostDistance` equal to `host_distance`.
The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for
:attr:`~HostDistance.REMOTE`.
This property is ignored if :attr:`~.Cluster.protocol_version` is
3 or higher.
"""
return self._core_connections_per_host[host_distance]
def set_core_connections_per_host(self, host_distance, core_connections):
"""
Sets the minimum number of connections per Session that will be opened
for each host with :class:`~.HostDistance` equal to `host_distance`.
The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for
:attr:`~HostDistance.REMOTE`.
Protocol version 1 and 2 are limited in the number of concurrent
requests they can send per connection. The driver implements connection
pooling to support higher levels of concurrency.
If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this
is not supported (there is always one connection per host, unless
the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`)
- and using this will result in an :exc:`~.UnsupporteOperation`.
+ and using this will result in an :exc:`~.UnsupportedOperation`.
"""
if self.protocol_version >= 3:
raise UnsupportedOperation(
"Cluster.set_core_connections_per_host() only has an effect "
"when using protocol_version 1 or 2.")
old = self._core_connections_per_host[host_distance]
self._core_connections_per_host[host_distance] = core_connections
if old < core_connections:
self._ensure_core_connections()
def get_max_connections_per_host(self, host_distance):
"""
Gets the maximum number of connections per Session that will be opened
for each host with :class:`~.HostDistance` equal to `host_distance`.
The default is 8 for :attr:`~HostDistance.LOCAL` and 2 for
:attr:`~HostDistance.REMOTE`.
This property is ignored if :attr:`~.Cluster.protocol_version` is
3 or higher.
"""
return self._max_connections_per_host[host_distance]
def set_max_connections_per_host(self, host_distance, max_connections):
"""
Sets the maximum number of connections per Session that will be opened
for each host with :class:`~.HostDistance` equal to `host_distance`.
The default is 8 for :attr:`~HostDistance.LOCAL` and 2 for
:attr:`~HostDistance.REMOTE`.
If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this
is not supported (there is always one connection per host, unless
the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`)
- and using this will result in an :exc:`~.UnsupporteOperation`.
+ and using this will result in an :exc:`~.UnsupportedOperation`.
"""
if self.protocol_version >= 3:
raise UnsupportedOperation(
"Cluster.set_max_connections_per_host() only has an effect "
"when using protocol_version 1 or 2.")
self._max_connections_per_host[host_distance] = max_connections
def connection_factory(self, endpoint, *args, **kwargs):
"""
Called to create a new connection with proper configuration.
Intended for internal use only.
"""
kwargs = self._make_connection_kwargs(endpoint, kwargs)
return self.connection_class.factory(endpoint, self.connect_timeout, *args, **kwargs)
def _make_connection_factory(self, host, *args, **kwargs):
kwargs = self._make_connection_kwargs(host.endpoint, kwargs)
return partial(self.connection_class.factory, host.endpoint, self.connect_timeout, *args, **kwargs)
def _make_connection_kwargs(self, endpoint, kwargs_dict):
if self._auth_provider_callable:
kwargs_dict.setdefault('authenticator', self._auth_provider_callable(endpoint.address))
kwargs_dict.setdefault('port', self.port)
kwargs_dict.setdefault('compression', self.compression)
kwargs_dict.setdefault('sockopts', self.sockopts)
kwargs_dict.setdefault('ssl_options', self.ssl_options)
kwargs_dict.setdefault('ssl_context', self.ssl_context)
kwargs_dict.setdefault('cql_version', self.cql_version)
kwargs_dict.setdefault('protocol_version', self.protocol_version)
kwargs_dict.setdefault('user_type_map', self._user_types)
kwargs_dict.setdefault('allow_beta_protocol_version', self.allow_beta_protocol_version)
kwargs_dict.setdefault('no_compact', self.no_compact)
return kwargs_dict
def protocol_downgrade(self, host_endpoint, previous_version):
if self._protocol_version_explicit:
raise DriverException("ProtocolError returned from server while using explicitly set client protocol_version %d" % (previous_version,))
-
new_version = ProtocolVersion.get_lower_supported(previous_version)
if new_version < ProtocolVersion.MIN_SUPPORTED:
raise DriverException(
"Cannot downgrade protocol version below minimum supported version: %d" % (ProtocolVersion.MIN_SUPPORTED,))
log.warning("Downgrading core protocol version from %d to %d for %s. "
"To avoid this, it is best practice to explicitly set Cluster(protocol_version) to the version supported by your cluster. "
"http://datastax.github.io/python-driver/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_endpoint)
self.protocol_version = new_version
def connect(self, keyspace=None, wait_for_all_pools=False):
"""
Creates and returns a new :class:`~.Session` object.
If `keyspace` is specified, that keyspace will be the default keyspace for
operations on the ``Session``.
`wait_for_all_pools` specifies whether this call should wait for all connection pools to be
established or attempted. Default is `False`, which means it will return when the first
successful connection is established. Remaining pools are added asynchronously.
"""
with self._lock:
if self.is_shutdown:
raise DriverException("Cluster is already shut down")
if not self._is_setup:
log.debug("Connecting to cluster, contact points: %s; protocol version: %s",
self.contact_points, self.protocol_version)
self.connection_class.initialize_reactor()
_register_cluster_shutdown(self)
for endpoint in self.endpoints_resolved:
host, new = self.add_host(endpoint, signal=False)
if new:
host.set_up()
for listener in self.listeners:
listener.on_add(host)
self.profile_manager.populate(
weakref.proxy(self), self.metadata.all_hosts())
self.load_balancing_policy.populate(
weakref.proxy(self), self.metadata.all_hosts()
)
try:
self.control_connection.connect()
# we set all contact points up for connecting, but we won't infer state after this
for endpoint in self.endpoints_resolved:
h = self.metadata.get_host(endpoint)
if h and self.profile_manager.distance(h) == HostDistance.IGNORED:
h.is_up = None
log.debug("Control connection created")
except Exception:
log.exception("Control connection failed to connect, "
"shutting down Cluster:")
self.shutdown()
raise
self.profile_manager.check_supported() # todo: rename this method
if self.idle_heartbeat_interval:
self._idle_heartbeat = ConnectionHeartbeat(
self.idle_heartbeat_interval,
self.get_connection_holders,
timeout=self.idle_heartbeat_timeout
)
self._is_setup = True
session = self._new_session(keyspace)
if wait_for_all_pools:
wait_futures(session._initial_connect_futures)
self._set_default_dbaas_consistency(session)
return session
def _set_default_dbaas_consistency(self, session):
if session.cluster.metadata.dbaas:
for profile in self.profile_manager.profiles.values():
if not profile._consistency_level_explicit:
profile.consistency_level = ConsistencyLevel.LOCAL_QUORUM
session._default_consistency_level = ConsistencyLevel.LOCAL_QUORUM
def get_connection_holders(self):
holders = []
for s in tuple(self.sessions):
holders.extend(s.get_pools())
holders.append(self.control_connection)
return holders
def shutdown(self):
"""
Closes all sessions and connection associated with this Cluster.
To ensure all connections are properly closed, **you should always
call shutdown() on a Cluster instance when you are done with it**.
Once shutdown, a Cluster should not be used for any purpose.
"""
with self._lock:
if self.is_shutdown:
return
else:
self.is_shutdown = True
if self._idle_heartbeat:
self._idle_heartbeat.stop()
self.scheduler.shutdown()
self.control_connection.shutdown()
for session in tuple(self.sessions):
session.shutdown()
self.executor.shutdown()
_discard_cluster_shutdown(self)
def __enter__(self):
return self
def __exit__(self, *args):
self.shutdown()
def _new_session(self, keyspace):
session = Session(self, self.metadata.all_hosts(), keyspace)
self._session_register_user_types(session)
self.sessions.add(session)
return session
def _session_register_user_types(self, session):
for keyspace, type_map in six.iteritems(self._user_types):
for udt_name, klass in six.iteritems(type_map):
session.user_type_registered(keyspace, udt_name, klass)
def _cleanup_failed_on_up_handling(self, host):
self.profile_manager.on_down(host)
self.control_connection.on_down(host)
for session in tuple(self.sessions):
session.remove_pool(host)
self._start_reconnector(host, is_host_addition=False)
def _on_up_future_completed(self, host, futures, results, lock, finished_future):
with lock:
futures.discard(finished_future)
try:
results.append(finished_future.result())
except Exception as exc:
results.append(exc)
if futures:
return
try:
# all futures have completed at this point
for exc in [f for f in results if isinstance(f, Exception)]:
log.error("Unexpected failure while marking node %s up:", host, exc_info=exc)
self._cleanup_failed_on_up_handling(host)
return
if not all(results):
log.debug("Connection pool could not be created, not marking node %s up", host)
self._cleanup_failed_on_up_handling(host)
return
log.info("Connection pools established for node %s", host)
# mark the host as up and notify all listeners
host.set_up()
for listener in self.listeners:
listener.on_up(host)
finally:
with host.lock:
host._currently_handling_node_up = False
# see if there are any pools to add or remove now that the host is marked up
for session in tuple(self.sessions):
session.update_created_pools()
def on_up(self, host):
"""
Intended for internal use only.
"""
if self.is_shutdown:
return
log.debug("Waiting to acquire lock for handling up status of node %s", host)
with host.lock:
if host._currently_handling_node_up:
log.debug("Another thread is already handling up status of node %s", host)
return
if host.is_up:
log.debug("Host %s was already marked up", host)
return
host._currently_handling_node_up = True
log.debug("Starting to handle up status of node %s", host)
have_future = False
futures = set()
try:
log.info("Host %s may be up; will prepare queries and open connection pool", host)
reconnector = host.get_and_set_reconnection_handler(None)
if reconnector:
log.debug("Now that host %s is up, cancelling the reconnection handler", host)
reconnector.cancel()
if self.profile_manager.distance(host) != HostDistance.IGNORED:
self._prepare_all_queries(host)
log.debug("Done preparing all queries for host %s, ", host)
for session in tuple(self.sessions):
session.remove_pool(host)
log.debug("Signalling to load balancing policies that host %s is up", host)
self.profile_manager.on_up(host)
log.debug("Signalling to control connection that host %s is up", host)
self.control_connection.on_up(host)
log.debug("Attempting to open new connection pools for host %s", host)
futures_lock = Lock()
futures_results = []
callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock)
for session in tuple(self.sessions):
future = session.add_or_renew_pool(host, is_host_addition=False)
if future is not None:
have_future = True
future.add_done_callback(callback)
futures.add(future)
except Exception:
log.exception("Unexpected failure handling node %s being marked up:", host)
for future in futures:
future.cancel()
self._cleanup_failed_on_up_handling(host)
with host.lock:
host._currently_handling_node_up = False
raise
else:
if not have_future:
with host.lock:
host.set_up()
host._currently_handling_node_up = False
# for testing purposes
return futures
def _start_reconnector(self, host, is_host_addition):
if self.profile_manager.distance(host) == HostDistance.IGNORED:
return
schedule = self.reconnection_policy.new_schedule()
# in order to not hold references to this Cluster open and prevent
# proper shutdown when the program ends, we'll just make a closure
# of the current Cluster attributes to create new Connections with
conn_factory = self._make_connection_factory(host)
reconnector = _HostReconnectionHandler(
host, conn_factory, is_host_addition, self.on_add, self.on_up,
self.scheduler, schedule, host.get_and_set_reconnection_handler,
new_handler=None)
old_reconnector = host.get_and_set_reconnection_handler(reconnector)
if old_reconnector:
log.debug("Old host reconnector found for %s, cancelling", host)
old_reconnector.cancel()
log.debug("Starting reconnector for host %s", host)
reconnector.start()
@run_in_executor
def on_down(self, host, is_host_addition, expect_host_to_be_down=False):
"""
Intended for internal use only.
"""
if self.is_shutdown:
return
with host.lock:
was_up = host.is_up
# ignore down signals if we have open pools to the host
# this is to avoid closing pools when a control connection host became isolated
if self._discount_down_events and self.profile_manager.distance(host) != HostDistance.IGNORED:
connected = False
for session in tuple(self.sessions):
pool_states = session.get_pool_state()
pool_state = pool_states.get(host)
if pool_state:
connected |= pool_state['open_count'] > 0
if connected:
return
host.set_down()
if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting():
return
log.warning("Host %s has been marked down", host)
self.profile_manager.on_down(host)
self.control_connection.on_down(host)
for session in tuple(self.sessions):
session.on_down(host)
for listener in self.listeners:
listener.on_down(host)
self._start_reconnector(host, is_host_addition)
def on_add(self, host, refresh_nodes=True):
if self.is_shutdown:
return
log.debug("Handling new host %r and notifying listeners", host)
distance = self.profile_manager.distance(host)
if distance != HostDistance.IGNORED:
self._prepare_all_queries(host)
log.debug("Done preparing queries for new host %r", host)
self.profile_manager.on_add(host)
self.control_connection.on_add(host, refresh_nodes)
if distance == HostDistance.IGNORED:
log.debug("Not adding connection pool for new host %r because the "
"load balancing policy has marked it as IGNORED", host)
self._finalize_add(host, set_up=False)
return
futures_lock = Lock()
futures_results = []
futures = set()
def future_completed(future):
with futures_lock:
futures.discard(future)
try:
futures_results.append(future.result())
except Exception as exc:
futures_results.append(exc)
if futures:
return
log.debug('All futures have completed for added host %s', host)
for exc in [f for f in futures_results if isinstance(f, Exception)]:
log.error("Unexpected failure while adding node %s, will not mark up:", host, exc_info=exc)
return
if not all(futures_results):
log.warning("Connection pool could not be created, not marking node %s up", host)
return
self._finalize_add(host)
have_future = False
for session in tuple(self.sessions):
future = session.add_or_renew_pool(host, is_host_addition=True)
if future is not None:
have_future = True
futures.add(future)
future.add_done_callback(future_completed)
if not have_future:
self._finalize_add(host)
def _finalize_add(self, host, set_up=True):
if set_up:
host.set_up()
for listener in self.listeners:
listener.on_add(host)
# see if there are any pools to add or remove now that the host is marked up
for session in tuple(self.sessions):
session.update_created_pools()
def on_remove(self, host):
if self.is_shutdown:
return
log.debug("Removing host %s", host)
host.set_down()
self.profile_manager.on_remove(host)
for session in tuple(self.sessions):
session.on_remove(host)
for listener in self.listeners:
listener.on_remove(host)
self.control_connection.on_remove(host)
+ reconnection_handler = host.get_and_set_reconnection_handler(None)
+ if reconnection_handler:
+ reconnection_handler.cancel()
+
def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False):
is_down = host.signal_connection_failure(connection_exc)
if is_down:
self.on_down(host, is_host_addition, expect_host_to_be_down)
return is_down
def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True):
"""
Called when adding initial contact points and when the control
connection subsequently discovers a new node.
Returns a Host instance, and a flag indicating whether it was new in
the metadata.
Intended for internal use only.
"""
host, new = self.metadata.add_or_return_host(Host(endpoint, self.conviction_policy_factory, datacenter, rack))
if new and signal:
log.info("New Cassandra host %r discovered", host)
self.on_add(host, refresh_nodes)
return host, new
def remove_host(self, host):
"""
Called when the control connection observes that a node has left the
ring. Intended for internal use only.
"""
if host and self.metadata.remove_host(host):
log.info("Cassandra host %s removed", host)
self.on_remove(host)
def register_listener(self, listener):
"""
Adds a :class:`cassandra.policies.HostStateListener` subclass instance to
the list of listeners to be notified when a host is added, removed,
marked up, or marked down.
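Example (an illustrative sketch; assumes ``cluster`` is an existing
:class:`.Cluster` instance, and the listener simply prints state changes)::
>>> from cassandra.policies import HostStateListener
>>> class PrintingListener(HostStateListener):
...     def on_up(self, host):
...         print("Host %s is up" % host)
...     def on_down(self, host):
...         print("Host %s is down" % host)
...     def on_add(self, host):
...         print("Host %s was added" % host)
...     def on_remove(self, host):
...         print("Host %s was removed" % host)
>>> cluster.register_listener(PrintingListener())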
"""
with self._listener_lock:
self._listeners.add(listener)
def unregister_listener(self, listener):
""" Removes a registered listener. """
with self._listener_lock:
self._listeners.remove(listener)
@property
def listeners(self):
with self._listener_lock:
return self._listeners.copy()
def _ensure_core_connections(self):
"""
If any host has fewer than the configured number of core connections
open, attempt to open connections until that number is met.
"""
for session in tuple(self.sessions):
for pool in tuple(session._pools.values()):
pool.ensure_core_connections()
@staticmethod
def _validate_refresh_schema(keyspace, table, usertype, function, aggregate):
if any((table, usertype, function, aggregate)):
if not keyspace:
raise ValueError("keyspace is required to refresh specific sub-entity {table, usertype, function, aggregate}")
if sum(1 for e in (table, usertype, function, aggregate) if e) > 1:
raise ValueError("{table, usertype, function, aggregate} are mutually exclusive")
@staticmethod
def _target_type_from_refresh_args(keyspace, table, usertype, function, aggregate):
if aggregate:
return SchemaTargetType.AGGREGATE
elif function:
return SchemaTargetType.FUNCTION
elif usertype:
return SchemaTargetType.TYPE
elif table:
return SchemaTargetType.TABLE
elif keyspace:
return SchemaTargetType.KEYSPACE
return None
def get_control_connection_host(self):
"""
Returns the control connection host metadata.
"""
connection = self.control_connection._connection
endpoint = connection.endpoint if connection else None
return self.metadata.get_host(endpoint) if endpoint else None
def refresh_schema_metadata(self, max_schema_agreement_wait=None):
"""
Synchronously refresh all schema metadata.
By default, the timeout for this operation is governed by :attr:`~.Cluster.max_schema_agreement_wait`
and :attr:`~.Cluster.control_connection_timeout`.
Passing max_schema_agreement_wait here overrides :attr:`~.Cluster.max_schema_agreement_wait`.
Setting max_schema_agreement_wait <= 0 will bypass schema agreement and refresh schema immediately.
An Exception is raised if schema refresh fails for any reason.
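Example (illustrative; assumes ``cluster`` is a connected :class:`.Cluster`)::
>>> cluster.refresh_schema_metadata()  # wait for schema agreement, then refresh
>>> cluster.refresh_schema_metadata(max_schema_agreement_wait=0)  # refresh immediately, skipping agreement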
"""
if not self.control_connection.refresh_schema(schema_agreement_wait=max_schema_agreement_wait, force=True):
raise DriverException("Schema metadata was not refreshed. See log for details.")
def refresh_keyspace_metadata(self, keyspace, max_schema_agreement_wait=None):
"""
Synchronously refresh keyspace metadata. This applies to keyspace-level information such as replication
and durability settings. It does not refresh tables, types, etc. contained in the keyspace.
See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior
"""
if not self.control_connection.refresh_schema(target_type=SchemaTargetType.KEYSPACE, keyspace=keyspace,
schema_agreement_wait=max_schema_agreement_wait, force=True):
raise DriverException("Keyspace metadata was not refreshed. See log for details.")
def refresh_table_metadata(self, keyspace, table, max_schema_agreement_wait=None):
"""
Synchronously refresh table metadata. This applies to a table, and any triggers or indexes attached
to the table.
See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior
"""
if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=table,
schema_agreement_wait=max_schema_agreement_wait, force=True):
raise DriverException("Table metadata was not refreshed. See log for details.")
def refresh_materialized_view_metadata(self, keyspace, view, max_schema_agreement_wait=None):
"""
Synchronously refresh materialized view metadata.
See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior
"""
if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=view,
schema_agreement_wait=max_schema_agreement_wait, force=True):
raise DriverException("View metadata was not refreshed. See log for details.")
def refresh_user_type_metadata(self, keyspace, user_type, max_schema_agreement_wait=None):
"""
Synchronously refresh user defined type metadata.
See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior
"""
if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TYPE, keyspace=keyspace, type=user_type,
schema_agreement_wait=max_schema_agreement_wait, force=True):
raise DriverException("User Type metadata was not refreshed. See log for details.")
def refresh_user_function_metadata(self, keyspace, function, max_schema_agreement_wait=None):
"""
Synchronously refresh user defined function metadata.
``function`` is a :class:`cassandra.UserFunctionDescriptor`.
See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior
"""
if not self.control_connection.refresh_schema(target_type=SchemaTargetType.FUNCTION, keyspace=keyspace, function=function,
schema_agreement_wait=max_schema_agreement_wait, force=True):
raise DriverException("User Function metadata was not refreshed. See log for details.")
def refresh_user_aggregate_metadata(self, keyspace, aggregate, max_schema_agreement_wait=None):
"""
Synchronously refresh user defined aggregate metadata.
``aggregate`` is a :class:`cassandra.UserAggregateDescriptor`.
See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior
"""
if not self.control_connection.refresh_schema(target_type=SchemaTargetType.AGGREGATE, keyspace=keyspace, aggregate=aggregate,
schema_agreement_wait=max_schema_agreement_wait, force=True):
raise DriverException("User Aggregate metadata was not refreshed. See log for details.")
def refresh_nodes(self, force_token_rebuild=False):
"""
Synchronously refresh the node list and token metadata
`force_token_rebuild` can be used to rebuild the token map metadata, even if no new nodes are discovered.
An Exception is raised if node refresh fails for any reason.
"""
if not self.control_connection.refresh_node_list_and_token_map(force_token_rebuild):
raise DriverException("Node list was not refreshed. See log for details.")
def set_meta_refresh_enabled(self, enabled):
"""
*Deprecated:* set :attr:`~.Cluster.schema_metadata_enabled` and :attr:`~.Cluster.token_metadata_enabled` instead
Sets a flag to enable (True) or disable (False) all metadata refresh queries.
This applies to both schema and node topology.
Disabling this is useful to minimize refreshes during multiple changes.
Meta refresh must be enabled for the driver to become aware of any cluster
topology changes or schema updates.
"""
warn("Cluster.set_meta_refresh_enabled is deprecated and will be removed in 4.0. Set "
"Cluster.schema_metadata_enabled and Cluster.token_metadata_enabled instead.", DeprecationWarning)
self.schema_metadata_enabled = enabled
self.token_metadata_enabled = enabled
@classmethod
def _send_chunks(cls, connection, host, chunks, set_keyspace=False):
for ks_chunk in chunks:
messages = [PrepareMessage(query=s.query_string,
keyspace=s.keyspace if set_keyspace else None)
for s in ks_chunk]
# TODO: make this timeout configurable somehow?
responses = connection.wait_for_responses(*messages, timeout=5.0, fail_on_error=False)
for success, response in responses:
if not success:
log.debug("Got unexpected response when preparing "
"statement on host %s: %r", host, response)
def _prepare_all_queries(self, host):
if not self._prepared_statements or not self.reprepare_on_up:
return
log.debug("Preparing all known prepared statements against host %s", host)
connection = None
try:
connection = self.connection_factory(host.endpoint)
statements = list(self._prepared_statements.values())
if ProtocolVersion.uses_keyspace_flag(self.protocol_version):
# V5 protocol and higher: the keyspace is sent with each PREPARE message, so there is no need to set it on the connection first
chunks = []
for i in range(0, len(statements), 10):
chunks.append(statements[i:i + 10])
self._send_chunks(connection, host, chunks, True)
else:
for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace):
if keyspace is not None:
connection.set_keyspace_blocking(keyspace)
# prepare 10 statements at a time
ks_statements = list(ks_statements)
chunks = []
for i in range(0, len(ks_statements), 10):
chunks.append(ks_statements[i:i + 10])
self._send_chunks(connection, host, chunks)
log.debug("Done preparing all known prepared statements against host %s", host)
except OperationTimedOut as timeout:
log.warning("Timed out trying to prepare all statements on host %s: %s", host, timeout)
except (ConnectionException, socket.error) as exc:
log.warning("Error trying to prepare all statements on host %s: %r", host, exc)
except Exception:
log.exception("Error trying to prepare all statements on host %s", host)
finally:
if connection:
connection.close()
def add_prepared(self, query_id, prepared_statement):
with self._prepared_statement_lock:
self._prepared_statements[query_id] = prepared_statement
class Session(object):
"""
A collection of connection pools for each host in the cluster.
Instances of this class should not be created directly, only
using :meth:`.Cluster.connect()`.
Queries and statements can be executed through ``Session`` instances
using the :meth:`~.Session.execute()` and :meth:`~.Session.execute_async()`
methods.
Example usage::
>>> session = cluster.connect()
>>> session.set_keyspace("mykeyspace")
>>> session.execute("SELECT * FROM mycf")
"""
cluster = None
hosts = None
keyspace = None
is_shutdown = False
+ session_id = None
+ _monitor_reporter = None
_row_factory = staticmethod(named_tuple_factory)
@property
def row_factory(self):
"""
The format to return row results in. By default, each
returned row will be a named tuple. You can alternatively
use any of the following:
- :func:`cassandra.query.tuple_factory` - return a result row as a tuple
- :func:`cassandra.query.named_tuple_factory` - return a result row as a named tuple
- :func:`cassandra.query.dict_factory` - return a result row as a dict
- :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict
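Example (an illustrative sketch; applies only to legacy configuration,
since this setting cannot be changed when execution profiles are in use)::
>>> from cassandra.query import dict_factory
>>> session.row_factory = dict_factory
>>> rows = session.execute("SELECT * FROM mykeyspace.mycf")  # each row is now a dict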
"""
return self._row_factory
@row_factory.setter
def row_factory(self, rf):
self._validate_set_legacy_config('row_factory', rf)
_default_timeout = 10.0
@property
def default_timeout(self):
"""
A default timeout, measured in seconds, for queries executed through
:meth:`.execute()` or :meth:`.execute_async()`. This default may be
overridden with the `timeout` parameter for either of those methods.
Setting this to :const:`None` will cause no timeouts to be set by default.
Please see :meth:`.ResponseFuture.result` for details on the scope and
effect of this timeout.
.. versionadded:: 2.0.0
"""
return self._default_timeout
@default_timeout.setter
def default_timeout(self, timeout):
self._validate_set_legacy_config('default_timeout', timeout)
_default_consistency_level = ConsistencyLevel.LOCAL_ONE
@property
def default_consistency_level(self):
"""
*Deprecated:* use execution profiles instead
The default :class:`~ConsistencyLevel` for operations executed through
this session. This default may be overridden by setting the
:attr:`~.Statement.consistency_level` on individual statements.
.. versionadded:: 1.2.0
.. versionchanged:: 3.0.0
default changed from ONE to LOCAL_ONE
"""
return self._default_consistency_level
@default_consistency_level.setter
def default_consistency_level(self, cl):
"""
*Deprecated:* use execution profiles instead
"""
warn("Setting the consistency level at the session level will be removed in 4.0. Consider using "
"execution profiles and setting the desired consitency level to the EXEC_PROFILE_DEFAULT profile."
, DeprecationWarning)
self._validate_set_legacy_config('default_consistency_level', cl)
_default_serial_consistency_level = None
@property
def default_serial_consistency_level(self):
"""
The default :class:`~ConsistencyLevel` for serial phase of conditional updates executed through
this session. This default may be overridden by setting the
:attr:`~.Statement.serial_consistency_level` on individual statements.
Only valid for ``protocol_version >= 2``.
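Example (illustrative; only applies to legacy configuration, and only
:attr:`ConsistencyLevel.SERIAL` or :attr:`ConsistencyLevel.LOCAL_SERIAL` are accepted)::
>>> from cassandra import ConsistencyLevel
>>> session.default_serial_consistency_level = ConsistencyLevel.LOCAL_SERIAL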
"""
return self._default_serial_consistency_level
@default_serial_consistency_level.setter
def default_serial_consistency_level(self, cl):
if (cl is not None and
not ConsistencyLevel.is_serial(cl)):
raise ValueError("default_serial_consistency_level must be either "
"ConsistencyLevel.SERIAL "
"or ConsistencyLevel.LOCAL_SERIAL.")
self._validate_set_legacy_config('default_serial_consistency_level', cl)
max_trace_wait = 2.0
"""
The maximum amount of time (in seconds) the driver will wait for trace
details to be populated server-side for a query before giving up.
If the `trace` parameter for :meth:`~.execute()` or :meth:`~.execute_async()`
is :const:`True`, the driver will repeatedly attempt to fetch trace
details for the query (using exponential backoff) until this limit is
hit. If the limit is passed, an error will be logged and the
:attr:`.Statement.trace` will be left as :const:`None`. """
default_fetch_size = 5000
"""
By default, this many rows will be fetched at a time. Setting
this to :const:`None` will disable automatic paging for large query
results. The fetch size can also be specified per-query through
:attr:`.Statement.fetch_size`.
This only takes effect when protocol version 2 or higher is used.
See :attr:`.Cluster.protocol_version` for details.
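Example (an illustrative sketch of the per-statement form; the keyspace and
table names are placeholders)::
>>> from cassandra.query import SimpleStatement
>>> statement = SimpleStatement("SELECT * FROM mykeyspace.mycf", fetch_size=100)
>>> for row in session.execute(statement):
...     print(row)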
.. versionadded:: 2.0.0
"""
use_client_timestamp = True
"""
When using protocol version 3 or higher, write timestamps may be supplied
client-side at the protocol level. (Normally they are generated
server-side by the coordinator node.) Note that timestamps specified
within a CQL query will override this timestamp.
.. versionadded:: 2.1.0
"""
timestamp_generator = None
"""
When :attr:`use_client_timestamp` is set, sessions call this object and use
the result as the timestamp. (Note that timestamps specified within a CQL
query will override this timestamp.) By default, a new
:class:`~.MonotonicTimestampGenerator` is created for
each :class:`Cluster` instance.
Applications can set this value for custom timestamp behavior. For
example, an application could share a timestamp generator across
:class:`Cluster` objects to guarantee that the application will use unique,
increasing timestamps across clusters, or set it to ``lambda:
int(time.time() * 1e6)`` if losing records over clock inconsistencies is
acceptable for the application. Custom :attr:`timestamp_generator` s should
be callable, and calling them should return an integer representing microseconds
since some point in time, typically UNIX epoch.
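Example (an illustrative sketch; assumes two clusters should draw timestamps
from a single shared monotonic sequence)::
>>> from cassandra.timestamps import MonotonicTimestampGenerator
>>> shared_generator = MonotonicTimestampGenerator()
>>> cluster_one = Cluster(timestamp_generator=shared_generator)
>>> cluster_two = Cluster(timestamp_generator=shared_generator)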
.. versionadded:: 3.8.0
"""
encoder = None
"""
A :class:`~cassandra.encoder.Encoder` instance that will be used when
formatting query parameters for non-prepared statements. This is not used
for prepared statements (because prepared statements give the driver more
information about what CQL types are expected, allowing it to accept a
wider range of python types).
The encoder uses a mapping from python types to encoder methods (for
specific CQL types). This mapping can be modified by users as they see
fit. Methods of :class:`~cassandra.encoder.Encoder` should be used for mapping
values if possible, because they take precautions to avoid injections and
properly sanitize data.
Example::
cluster = Cluster()
session = cluster.connect("mykeyspace")
session.encoder.mapping[tuple] = session.encoder.cql_encode_tuple
session.execute("CREATE TABLE mytable (k int PRIMARY KEY, col tuple)")
session.execute("INSERT INTO mytable (k, col) VALUES (%s, %s)", [0, (123, 'abc')])
.. versionadded:: 2.1.0
"""
client_protocol_handler = ProtocolHandler
"""
Specifies a protocol handler that will be used for client-initiated requests (i.e. not for
internal driver requests). This can be used to override or extend features such as
message or type serialization/deserialization.
The default pure python implementation is :class:`cassandra.protocol.ProtocolHandler`.
When compiled with Cython, there are also built-in faster alternatives. See :ref:`faster_deser`
"""
+ session_id = None
+ """
+ A UUID that uniquely identifies this Session to Insights. This will be
+ generated automatically.
+ """
+
_lock = None
_pools = None
_profile_manager = None
_metrics = None
_request_init_callbacks = None
+ _graph_paging_available = False
def __init__(self, cluster, hosts, keyspace=None):
self.cluster = cluster
self.hosts = hosts
self.keyspace = keyspace
self._lock = RLock()
self._pools = {}
self._profile_manager = cluster.profile_manager
self._metrics = cluster.metrics
self._request_init_callbacks = []
self._protocol_version = self.cluster.protocol_version
self.encoder = Encoder()
# create connection pools in parallel
self._initial_connect_futures = set()
for host in hosts:
future = self.add_or_renew_pool(host, is_host_addition=False)
if future:
self._initial_connect_futures.add(future)
futures = wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED)
while futures.not_done and not any(f.result() for f in futures.done):
futures = wait_futures(futures.not_done, return_when=FIRST_COMPLETED)
if not any(f.result() for f in self._initial_connect_futures):
msg = "Unable to connect to any servers"
if self.keyspace:
msg += " using keyspace '%s'" % self.keyspace
raise NoHostAvailable(msg, [h.address for h in hosts])
+ self.session_id = uuid.uuid4()
+ self._graph_paging_available = self._check_graph_paging_available()
+
+ if self.cluster.monitor_reporting_enabled:
+ cc_host = self.cluster.get_control_connection_host()
+ valid_insights_version = (cc_host and version_supports_insights(cc_host.dse_version))
+ if valid_insights_version:
+ self._monitor_reporter = MonitorReporter(
+ interval_sec=self.cluster.monitor_reporting_interval,
+ session=self,
+ )
+ else:
+ if cc_host:
+ log.debug('Not starting MonitorReporter thread for Insights; '
+ 'not supported by server version {v} on '
+ 'ControlConnection host {c}'.format(v=cc_host.release_version, c=cc_host))
+
+ log.debug('Started Session with client_id {} and session_id {}'.format(self.cluster.client_id,
+ self.session_id))
+
def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False,
custom_payload=None, execution_profile=EXEC_PROFILE_DEFAULT,
- paging_state=None, host=None):
+ paging_state=None, host=None, execute_as=None):
"""
Execute the given query and synchronously wait for the response.
If an error is encountered while executing the query, an Exception
will be raised.
`query` may be a query string or an instance of :class:`cassandra.query.Statement`.
`parameters` may be a sequence or dict of parameters to bind. If a
sequence is used, ``%s`` should be used as the placeholder for each
argument. If a dict is used, ``%(name)s`` style placeholders must
be used.
`timeout` should specify a floating-point timeout (in seconds) after
which an :exc:`.OperationTimedOut` exception will be raised if the query
- has not completed. If not set, the timeout defaults to
- :attr:`~.Session.default_timeout`. If set to :const:`None`, there is
- no timeout. Please see :meth:`.ResponseFuture.result` for details on
+ has not completed. If not set, the timeout defaults to the request_timeout of the selected ``execution_profile``.
+ If set to :const:`None`, there is no timeout. Please see :meth:`.ResponseFuture.result` for details on
the scope and effect of this timeout.
If `trace` is set to :const:`True`, the query will be sent with tracing enabled.
The trace details can be obtained using the returned :class:`.ResultSet` object.
`custom_payload` is a :ref:`custom_payload` dict to be passed to the server.
If `query` is a Statement with its own custom_payload, the message payload
will be a union of the two, with the values specified here taking precedence.
`execution_profile` is the execution profile to use for this request. It can be a key to a profile configured
via :meth:`Cluster.add_execution_profile` or an instance (from :meth:`Session.execution_profile_clone_update`,
for example).
`paging_state` is an optional paging state, reused from a previous :class:`ResultSet`.
`host` is the :class:`cassandra.pool.Host` that should handle the query. If the host specified is down or
not yet connected, the query will fail with :class:`NoHostAvailable`. Using this is
discouraged except in a few cases, e.g., querying node-local tables and applying schema changes.
+
+ `execute_as` is the user that will be used on the server to execute the request. This is only available
+ on a DSE cluster.
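+
+ Example (an illustrative sketch; the keyspace, table, and role names below
+ are placeholders)::
+
+ >>> rows = session.execute("SELECT * FROM mykeyspace.users WHERE id = %s", [42])
+ >>> rows = session.execute("SELECT * FROM mykeyspace.users", execute_as="analytics_user")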
"""
- return self.execute_async(query, parameters, trace, custom_payload,
- timeout, execution_profile, paging_state, host).result()
+
+ return self.execute_async(query, parameters, trace, custom_payload, timeout, execution_profile, paging_state, host, execute_as).result()
def execute_async(self, query, parameters=None, trace=False, custom_payload=None,
timeout=_NOT_SET, execution_profile=EXEC_PROFILE_DEFAULT,
- paging_state=None, host=None):
+ paging_state=None, host=None, execute_as=None):
"""
Execute the given query and return a :class:`~.ResponseFuture` object
which callbacks may be attached to for asynchronous response
delivery. You may also call :meth:`~.ResponseFuture.result()`
on the :class:`.ResponseFuture` to synchronously block for results at
any time.
See :meth:`Session.execute` for parameter definitions.
Example usage::
>>> session = cluster.connect()
>>> future = session.execute_async("SELECT * FROM mycf")
>>> def log_results(results):
... for row in results:
... log.info("Results: %s", row)
>>> def log_error(exc):
... log.error("Operation failed: %s", exc)
>>> future.add_callbacks(log_results, log_error)
Async execution with blocking wait for results::
>>> future = session.execute_async("SELECT * FROM mycf")
>>> # do other stuff...
>>> try:
... results = future.result()
... except Exception:
... log.exception("Operation failed:")
"""
+ custom_payload = custom_payload if custom_payload else {}
+ if execute_as:
+ custom_payload[_proxy_execute_key] = six.b(execute_as)
+
future = self._create_response_future(
query, parameters, trace, custom_payload, timeout,
execution_profile, paging_state, host)
future._protocol_handler = self.client_protocol_handler
self._on_request(future)
future.send_request()
return future
+ def execute_graph(self, query, parameters=None, trace=False, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT, execute_as=None):
+ """
+ Executes a Gremlin query string or GraphStatement synchronously,
+ and returns a ResultSet from this execution.
+
+ `parameters` is a dict of named parameters to bind. The values must be
+ JSON-serializable.
+
+ `execution_profile`: Selects an execution profile for the request.
+
+ `execute_as` is the user that will be used on the server to execute the request.
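+
+ Example (an illustrative sketch; assumes a graph execution profile is already
+ configured and that the graph contains ``person`` vertices)::
+
+ >>> rs = session.execute_graph("g.V().hasLabel('person').count()")
+ >>> rs = session.execute_graph("g.V().has('person', 'name', name)", {'name': 'marko'})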
+ """
+ return self.execute_graph_async(query, parameters, trace, execution_profile, execute_as).result()
+
+ def execute_graph_async(self, query, parameters=None, trace=False, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT, execute_as=None):
+ """
+ Execute the graph query and return a :class:`ResponseFuture` object to which
+ callbacks may be attached for asynchronous response delivery. You may also call
+ ``ResponseFuture.result()`` to synchronously block for results at any time.
+ """
+ if self.cluster._config_mode is _ConfigMode.LEGACY:
+ raise ValueError(("Cannot execute graph queries using Cluster legacy parameters. "
+ "Consider using Execution profiles: "
+ "https://docs.datastax.com/en/developer/python-driver/latest/execution_profiles/#execution-profiles"))
+
+ if not isinstance(query, GraphStatement):
+ query = SimpleGraphStatement(query)
+
+ # Clone and look up instance here so we can resolve and apply the extended attributes
+ execution_profile = self.execution_profile_clone_update(execution_profile)
+
+ if not hasattr(execution_profile, 'graph_options'):
+ raise ValueError(
+ "Execution profile for graph queries must derive from GraphExecutionProfile, and provide graph_options")
+
+ self._resolve_execution_profile_options(execution_profile)
+
+ # make sure the graphson context row factory is bound to this cluster
+ try:
+ if issubclass(execution_profile.row_factory, _GraphSONContextRowFactory):
+ execution_profile.row_factory = execution_profile.row_factory(self.cluster)
+ except TypeError:
+ # issubclass might fail if arg1 is an instance
+ pass
+
+ # set graph paging if needed
+ self._maybe_set_graph_paging(execution_profile)
+
+ graph_parameters = None
+ if parameters:
+ graph_parameters = self._transform_params(parameters, graph_options=execution_profile.graph_options)
+
+ custom_payload = execution_profile.graph_options.get_options_map()
+ if execute_as:
+ custom_payload[_proxy_execute_key] = six.b(execute_as)
+ custom_payload[_request_timeout_key] = int64_pack(long(execution_profile.request_timeout * 1000))
+
+ future = self._create_response_future(query, parameters=None, trace=trace, custom_payload=custom_payload,
+ timeout=_NOT_SET, execution_profile=execution_profile)
+
+ future.message.query_params = graph_parameters
+ future._protocol_handler = self.client_protocol_handler
+
+ if execution_profile.graph_options.is_analytics_source and \
+ isinstance(execution_profile.load_balancing_policy, DefaultLoadBalancingPolicy):
+ self._target_analytics_master(future)
+ else:
+ future.send_request()
+ return future
+
+ def _maybe_set_graph_paging(self, execution_profile):
+ graph_paging = execution_profile.continuous_paging_options
+ if execution_profile.continuous_paging_options is _NOT_SET:
+ graph_paging = ContinuousPagingOptions() if self._graph_paging_available else None
+
+ execution_profile.continuous_paging_options = graph_paging
+
+ def _check_graph_paging_available(self):
+ """Verify if we can enable graph paging. This executed only once when the session is created."""
+
+ if not ProtocolVersion.has_continuous_paging_next_pages(self._protocol_version):
+ return False
+
+ for host in self.cluster.metadata.all_hosts():
+ if host.dse_version is None:
+ return False
+
+ version = Version(host.dse_version)
+ if version < _GRAPH_PAGING_MIN_DSE_VERSION:
+ return False
+
+ return True
+
+ def _resolve_execution_profile_options(self, execution_profile):
+ """
+ Determine the GraphSON protocol and row factory for a graph query. This is useful
+ for automatically configuring the execution profile when executing a query against
+ a Core graph.
+
+ If `graph_protocol` is not explicitly specified, the following rules apply:
+ - Default to GraphProtocol.GRAPHSON_1_0, or GRAPHSON_2_0 if the `graph_language` is not gremlin-groovy.
+ - If `graph_options.graph_name` is specified and is a Core graph, set GraphSON_3_0.
+ If `row_factory` is not explicitly specified, the following rules apply:
+ - Default to graph_object_row_factory.
+ - If `graph_options.graph_name` is specified and is a Core graph, set graph_graphson3_row_factory.
+ """
+ if execution_profile.graph_options.graph_protocol is not None and \
+ execution_profile.row_factory is not None:
+ return
+
+ graph_options = execution_profile.graph_options
+
+ is_core_graph = False
+ if graph_options.graph_name:
+ # graph_options.graph_name is bytes ...
+ name = graph_options.graph_name.decode('utf-8')
+ if name in self.cluster.metadata.keyspaces:
+ ks_metadata = self.cluster.metadata.keyspaces[name]
+ if ks_metadata.graph_engine == 'Core':
+ is_core_graph = True
+
+ if is_core_graph:
+ graph_protocol = GraphProtocol.GRAPHSON_3_0
+ row_factory = graph_graphson3_row_factory
+ else:
+ if graph_options.graph_language == GraphOptions.DEFAULT_GRAPH_LANGUAGE:
+ graph_protocol = GraphOptions.DEFAULT_GRAPH_PROTOCOL
+ row_factory = graph_object_row_factory
+ else:
+ # if not gremlin-groovy, GraphSON_2_0
+ graph_protocol = GraphProtocol.GRAPHSON_2_0
+ row_factory = graph_graphson2_row_factory
+
+ # Only apply if not set explicitly
+ if graph_options.graph_protocol is None:
+ graph_options.graph_protocol = graph_protocol
+ if execution_profile.row_factory is None:
+ execution_profile.row_factory = row_factory
+
+ def _transform_params(self, parameters, graph_options):
+ if not isinstance(parameters, dict):
+ raise ValueError('The parameters must be a dictionary. Unnamed parameters are not allowed.')
+
+ # Serialize python types to graphson
+ serializer = GraphSON1Serializer
+ if graph_options.graph_protocol == GraphProtocol.GRAPHSON_2_0:
+ serializer = GraphSON2Serializer()
+ elif graph_options.graph_protocol == GraphProtocol.GRAPHSON_3_0:
+ # only required for core graphs
+ context = {
+ 'cluster': self.cluster,
+ 'graph_name': graph_options.graph_name.decode('utf-8') if graph_options.graph_name else None
+ }
+ serializer = GraphSON3Serializer(context)
+
+ serialized_parameters = serializer.serialize(parameters)
+ return [json.dumps(serialized_parameters).encode('utf-8')]
+
+ def _target_analytics_master(self, future):
+ future._start_timer()
+ master_query_future = self._create_response_future("CALL DseClientTool.getAnalyticsGraphServer()",
+ parameters=None, trace=False,
+ custom_payload=None, timeout=future.timeout)
+ master_query_future.row_factory = tuple_factory
+ master_query_future.send_request()
+
+ cb = self._on_analytics_master_result
+ args = (master_query_future, future)
+ master_query_future.add_callbacks(callback=cb, callback_args=args, errback=cb, errback_args=args)
+
+ def _on_analytics_master_result(self, response, master_future, query_future):
+ try:
+ row = master_future.result()[0]
+ addr = row[0]['location']
+ delimiter_index = addr.rfind(':') # assumes : - not robust, but that's what is being provided
+ if delimiter_index > 0:
+ addr = addr[:delimiter_index]
+ targeted_query = HostTargetingStatement(query_future.query, addr)
+ query_future.query_plan = query_future._load_balancer.make_query_plan(self.keyspace, targeted_query)
+ except Exception:
+ log.debug("Failed querying analytics master (request might not be routed optimally). "
+ "Make sure the session is connecting to a graph analytics datacenter.", exc_info=True)
+
+ self.submit(query_future.send_request)
+
def _create_response_future(self, query, parameters, trace, custom_payload,
timeout, execution_profile=EXEC_PROFILE_DEFAULT,
paging_state=None, host=None):
""" Returns the ResponseFuture before calling send_request() on it """
prepared_statement = None
if isinstance(query, six.string_types):
query = SimpleStatement(query)
elif isinstance(query, PreparedStatement):
query = query.bind(parameters)
if self.cluster._config_mode == _ConfigMode.LEGACY:
if execution_profile is not EXEC_PROFILE_DEFAULT:
raise ValueError("Cannot specify execution_profile while using legacy parameters.")
if timeout is _NOT_SET:
timeout = self.default_timeout
cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level
serial_cl = query.serial_consistency_level if query.serial_consistency_level is not None else self.default_serial_consistency_level
retry_policy = query.retry_policy or self.cluster.default_retry_policy
row_factory = self.row_factory
load_balancing_policy = self.cluster.load_balancing_policy
spec_exec_policy = None
+ continuous_paging_options = None
else:
execution_profile = self._maybe_get_execution_profile(execution_profile)
if timeout is _NOT_SET:
timeout = execution_profile.request_timeout
cl = query.consistency_level if query.consistency_level is not None else execution_profile.consistency_level
serial_cl = query.serial_consistency_level if query.serial_consistency_level is not None else execution_profile.serial_consistency_level
+ continuous_paging_options = execution_profile.continuous_paging_options
retry_policy = query.retry_policy or execution_profile.retry_policy
row_factory = execution_profile.row_factory
load_balancing_policy = execution_profile.load_balancing_policy
spec_exec_policy = execution_profile.speculative_execution_policy
fetch_size = query.fetch_size
if fetch_size is FETCH_SIZE_UNSET and self._protocol_version >= 2:
fetch_size = self.default_fetch_size
elif self._protocol_version == 1:
fetch_size = None
start_time = time.time()
if self._protocol_version >= 3 and self.use_client_timestamp:
timestamp = self.cluster.timestamp_generator()
else:
timestamp = None
+ supports_continuous_paging_state = (
+ ProtocolVersion.has_continuous_paging_next_pages(self._protocol_version)
+ )
+ if continuous_paging_options and supports_continuous_paging_state:
+ continuous_paging_state = ContinuousPagingState(continuous_paging_options.max_queue_size)
+ else:
+ continuous_paging_state = None
+
if isinstance(query, SimpleStatement):
query_string = query.query_string
statement_keyspace = query.keyspace if ProtocolVersion.uses_keyspace_flag(self._protocol_version) else None
if parameters:
query_string = bind_params(query_string, parameters, self.encoder)
message = QueryMessage(
query_string, cl, serial_cl,
- fetch_size, timestamp=timestamp,
- keyspace=statement_keyspace)
+ fetch_size, paging_state, timestamp,
+ continuous_paging_options, statement_keyspace)
elif isinstance(query, BoundStatement):
prepared_statement = query.prepared_statement
message = ExecuteMessage(
prepared_statement.query_id, query.values, cl,
- serial_cl, fetch_size,
- timestamp=timestamp, skip_meta=bool(prepared_statement.result_metadata),
+ serial_cl, fetch_size, paging_state, timestamp,
+ skip_meta=bool(prepared_statement.result_metadata),
+ continuous_paging_options=continuous_paging_options,
result_metadata_id=prepared_statement.result_metadata_id)
elif isinstance(query, BatchStatement):
if self._protocol_version < 2:
raise UnsupportedOperation(
"BatchStatement execution is only supported with protocol version "
"2 or higher (supported in Cassandra 2.0 and higher). Consider "
"setting Cluster.protocol_version to 2 to support this operation.")
statement_keyspace = query.keyspace if ProtocolVersion.uses_keyspace_flag(self._protocol_version) else None
message = BatchMessage(
query.batch_type, query._statements_and_parameters, cl,
serial_cl, timestamp, statement_keyspace)
+ elif isinstance(query, GraphStatement):
+ # the statement_keyspace is not applicable to GraphStatement
+ message = QueryMessage(query.query, cl, serial_cl, fetch_size,
+ paging_state, timestamp,
+ continuous_paging_options)
message.tracing = trace
-
message.update_custom_payload(query.custom_payload)
message.update_custom_payload(custom_payload)
message.allow_beta_protocol_version = self.cluster.allow_beta_protocol_version
- message.paging_state = paging_state
spec_exec_plan = spec_exec_policy.new_plan(query.keyspace or self.keyspace, query) if query.is_idempotent and spec_exec_policy else None
return ResponseFuture(
self, message, query, timeout, metrics=self._metrics,
prepared_statement=prepared_statement, retry_policy=retry_policy, row_factory=row_factory,
load_balancer=load_balancing_policy, start_time=start_time, speculative_execution_plan=spec_exec_plan,
- host=host)
+ continuous_paging_state=continuous_paging_state, host=host)
def get_execution_profile(self, name):
"""
Returns the execution profile associated with the provided ``name``.
:param name: The name (or key) of the execution profile.
"""
profiles = self.cluster.profile_manager.profiles
try:
return profiles[name]
except KeyError:
eps = [_execution_profile_to_string(ep) for ep in profiles.keys()]
raise ValueError("Invalid execution_profile: %s; valid profiles are: %s." % (
_execution_profile_to_string(name), ', '.join(eps)))
def _maybe_get_execution_profile(self, ep):
return ep if isinstance(ep, ExecutionProfile) else self.get_execution_profile(ep)
def execution_profile_clone_update(self, ep, **kwargs):
"""
Returns a clone of the ``ep`` profile. ``kwargs`` can be specified to update attributes
of the returned profile.
This is a shallow clone, so any objects referenced by the profile are shared. This means the load balancing
policy instance is shared with (and continues to be maintained by) the active profiles. It also means that updates
to any other shared mutable objects will be seen by the active profile. In cases where this is not desirable, be sure
to replace the instance instead of manipulating the shared object.
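Example (an illustrative sketch; clones the default profile with a different
consistency level, then uses it for a single request)::
>>> from cassandra import ConsistencyLevel
>>> from cassandra.cluster import EXEC_PROFILE_DEFAULT
>>> quorum_profile = session.execution_profile_clone_update(
...     EXEC_PROFILE_DEFAULT, consistency_level=ConsistencyLevel.QUORUM)
>>> session.execute("SELECT * FROM mykeyspace.mycf", execution_profile=quorum_profile)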
"""
clone = copy(self._maybe_get_execution_profile(ep))
for attr, value in kwargs.items():
setattr(clone, attr, value)
return clone
def add_request_init_listener(self, fn, *args, **kwargs):
"""
Adds a callback with arguments to be called when any request is created.
It will be invoked as `fn(response_future, *args, **kwargs)` after each client request is created,
and before the request is sent. This can be used to create extensions by adding result callbacks to the
response future.
`response_future` is the :class:`.ResponseFuture` for the request.
Note that the init callback is done on the client thread creating the request, so you may need to consider
synchronization if you have multiple threads. Any callbacks added to the response future will be executed
on the event loop thread, so the normal advice about minimizing cycles and avoiding blocking applies (see the note in
:meth:`.ResponseFuture.add_callbacks`).
See `this example `_ in the
source tree for an example.
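Example (a minimal sketch; the callback simply counts how many requests
have been created on this session)::
>>> request_count = [0]
>>> def count_request(response_future):
...     request_count[0] += 1
>>> session.add_request_init_listener(count_request)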
"""
self._request_init_callbacks.append((fn, args, kwargs))
def remove_request_init_listener(self, fn, *args, **kwargs):
"""
Removes a callback and arguments from the list.
See :meth:`.Session.add_request_init_listener`.
"""
self._request_init_callbacks.remove((fn, args, kwargs))
def _on_request(self, response_future):
for fn, args, kwargs in self._request_init_callbacks:
fn(response_future, *args, **kwargs)
def prepare(self, query, custom_payload=None, keyspace=None):
"""
Prepares a query string, returning a :class:`~cassandra.query.PreparedStatement`
instance which can be used as follows::
>>> session = cluster.connect("mykeyspace")
>>> query = "INSERT INTO users (id, name, age) VALUES (?, ?, ?)"
>>> prepared = session.prepare(query)
>>> session.execute(prepared, (user.id, user.name, user.age))
Or you may bind values to the prepared statement ahead of time::
>>> prepared = session.prepare(query)
>>> bound_stmt = prepared.bind((user.id, user.name, user.age))
>>> session.execute(bound_stmt)
Of course, prepared statements may (and should) be reused::
>>> prepared = session.prepare(query)
>>> for user in users:
... bound = prepared.bind((user.id, user.name, user.age))
... session.execute(bound)
Alternatively, if :attr:`~.Cluster.protocol_version` is 5 or higher
(requires Cassandra 4.0+), the keyspace can be specified as a
parameter. This lets you avoid qualifying the keyspace in the query even if
you did not specify a keyspace in :meth:`~.Cluster.connect`. It will even
let you prepare and use statements against a keyspace other
than the one originally specified on connection:
>>> analyticskeyspace_prepared = session.prepare(
... "INSERT INTO user_activity id, last_activity VALUES (?, ?)",
... keyspace="analyticskeyspace") # note the different keyspace
**Important**: PreparedStatements should be prepared only once.
Preparing the same query more than once will likely affect performance.
`custom_payload` is a key value map to be passed along with the prepare
message. See :ref:`custom_payload`.
"""
message = PrepareMessage(query=query, keyspace=keyspace)
future = ResponseFuture(self, message, query=None, timeout=self.default_timeout)
try:
future.send_request()
- query_id, bind_metadata, pk_indexes, result_metadata, result_metadata_id = future.result()
+ response = future.result().one()
except Exception:
log.exception("Error preparing query:")
raise
prepared_keyspace = keyspace if keyspace else None
prepared_statement = PreparedStatement.from_message(
- query_id, bind_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace,
- self._protocol_version, result_metadata, result_metadata_id)
+ response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace,
+ self._protocol_version, response.column_metadata, response.result_metadata_id)
prepared_statement.custom_payload = future.custom_payload
- self.cluster.add_prepared(query_id, prepared_statement)
+ self.cluster.add_prepared(response.query_id, prepared_statement)
if self.cluster.prepare_on_all_hosts:
host = future._current_host
try:
self.prepare_on_all_hosts(prepared_statement.query_string, host, prepared_keyspace)
except Exception:
log.exception("Error preparing query on all hosts:")
return prepared_statement
def prepare_on_all_hosts(self, query, excluded_host, keyspace=None):
"""
Prepare the given query on all hosts, excluding ``excluded_host``.
Intended for internal use only.
"""
futures = []
for host in tuple(self._pools.keys()):
if host != excluded_host and host.is_up:
future = ResponseFuture(self, PrepareMessage(query=query, keyspace=keyspace),
None, self.default_timeout)
# we don't care about errors preparing against specific hosts,
# since we can always prepare them as needed when the prepared
# statement is used. Just log errors and continue on.
try:
request_id = future._query(host)
except Exception:
log.exception("Error preparing query for host %s:", host)
continue
if request_id is None:
# the error has already been logged by ResponseFuture
log.debug("Failed to prepare query for host %s: %r",
host, future._errors.get(host))
continue
futures.append((host, future))
for host, future in futures:
try:
future.result()
except Exception:
log.exception("Error preparing query for host %s:", host)
def shutdown(self):
"""
Close all connections. ``Session`` instances should not be used
for any purpose after being shutdown.
"""
with self._lock:
if self.is_shutdown:
return
else:
self.is_shutdown = True
# PYTHON-673. If shutdown was called shortly after session init, avoid
# a race by cancelling any initial connection attempts that haven't started,
# then blocking on any that have.
for future in self._initial_connect_futures:
future.cancel()
wait_futures(self._initial_connect_futures)
+ if self._monitor_reporter:
+ self._monitor_reporter.stop()
+
for pool in tuple(self._pools.values()):
pool.shutdown()
def __enter__(self):
return self
def __exit__(self, *args):
self.shutdown()
def __del__(self):
try:
# Ensure all connections are closed, in case the Session object is deleted by the GC
self.shutdown()
except:
# Ignore all errors. Shutdown errors can be caught by the user
# when cluster.shutdown() is called explicitly.
pass
def add_or_renew_pool(self, host, is_host_addition):
"""
For internal use only.
"""
distance = self._profile_manager.distance(host)
if distance == HostDistance.IGNORED:
return None
def run_add_or_renew_pool():
try:
if self._protocol_version >= 3:
new_pool = HostConnection(host, distance, self)
else:
# TODO remove host pool again ???
new_pool = HostConnectionPool(host, distance, self)
except AuthenticationFailed as auth_exc:
conn_exc = ConnectionException(str(auth_exc), endpoint=host)
self.cluster.signal_connection_failure(host, conn_exc, is_host_addition)
return False
except Exception as conn_exc:
log.warning("Failed to create connection pool for new host %s:",
host, exc_info=conn_exc)
# the host itself will still be marked down, so we need to pass
# a special flag to make sure the reconnector is created
self.cluster.signal_connection_failure(
host, conn_exc, is_host_addition, expect_host_to_be_down=True)
return False
previous = self._pools.get(host)
with self._lock:
while new_pool._keyspace != self.keyspace:
self._lock.release()
set_keyspace_event = Event()
errors_returned = []
def callback(pool, errors):
errors_returned.extend(errors)
set_keyspace_event.set()
new_pool._set_keyspace_for_all_conns(self.keyspace, callback)
set_keyspace_event.wait(self.cluster.connect_timeout)
if not set_keyspace_event.is_set() or errors_returned:
log.warning("Failed setting keyspace for pool after keyspace changed during connect: %s", errors_returned)
self.cluster.on_down(host, is_host_addition)
new_pool.shutdown()
self._lock.acquire()
return False
self._lock.acquire()
self._pools[host] = new_pool
log.debug("Added pool for host %s to session", host)
if previous:
previous.shutdown()
return True
return self.submit(run_add_or_renew_pool)
def remove_pool(self, host):
pool = self._pools.pop(host, None)
if pool:
log.debug("Removed connection pool for %r", host)
return self.submit(pool.shutdown)
else:
return None
def update_created_pools(self):
"""
When the set of live nodes changes, the load balancer will change its
mind on host distances. It might change it on the node that came or left,
but also on other nodes (for instance, if a node dies, another
previously ignored node may now be considered).
This method ensures that all hosts for which a pool should exist
have one, and hosts that shouldn't don't.
For internal use only.
"""
futures = set()
for host in self.cluster.metadata.all_hosts():
distance = self._profile_manager.distance(host)
pool = self._pools.get(host)
future = None
if not pool or pool.is_shutdown:
# we don't eagerly set is_up on previously ignored hosts. None is included here
# to allow us to attempt connections to hosts that have gone from ignored to something
# else.
if distance != HostDistance.IGNORED and host.is_up in (True, None):
future = self.add_or_renew_pool(host, False)
elif distance != pool.host_distance:
# the distance has changed
if distance == HostDistance.IGNORED:
future = self.remove_pool(host)
else:
pool.host_distance = distance
if future:
futures.add(future)
return futures
def on_down(self, host):
"""
Called by the parent Cluster instance when a node is marked down.
Only intended for internal use.
"""
future = self.remove_pool(host)
if future:
future.add_done_callback(lambda f: self.update_created_pools())
def on_remove(self, host):
""" Internal """
self.on_down(host)
def set_keyspace(self, keyspace):
"""
Set the default keyspace for all queries made through this Session.
This operation blocks until complete.
"""
self.execute('USE %s' % (protect_name(keyspace),))
def _set_keyspace_for_all_pools(self, keyspace, callback):
"""
Asynchronously sets the keyspace on all pools. When all
pools have set all of their connections, `callback` will be
called with a dictionary of all errors that occurred, keyed
by the `Host` that they occurred against.
"""
with self._lock:
self.keyspace = keyspace
remaining_callbacks = set(self._pools.values())
errors = {}
if not remaining_callbacks:
callback(errors)
return
def pool_finished_setting_keyspace(pool, host_errors):
remaining_callbacks.remove(pool)
if host_errors:
errors[pool.host] = host_errors
if not remaining_callbacks:
callback(errors)
for pool in tuple(self._pools.values()):
pool._set_keyspace_for_all_conns(keyspace, pool_finished_setting_keyspace)
def user_type_registered(self, keyspace, user_type, klass):
"""
Called by the parent Cluster instance when the user registers a new
mapping from a user-defined type to a class. Intended for internal
use only.
"""
try:
ks_meta = self.cluster.metadata.keyspaces[keyspace]
except KeyError:
raise UserTypeDoesNotExist(
'Keyspace %s does not exist or has not been discovered by the driver' % (keyspace,))
try:
type_meta = ks_meta.user_types[user_type]
except KeyError:
raise UserTypeDoesNotExist(
'User type %s does not exist in keyspace %s' % (user_type, keyspace))
field_names = type_meta.field_names
if six.PY2:
# go from unicode to string to avoid decode errors from implicit
# decode when formatting non-ascii values
field_names = [fn.encode('utf-8') for fn in field_names]
def encode(val):
return '{ %s }' % ' , '.join('%s : %s' % (
field_name,
self.encoder.cql_encode_all_types(getattr(val, field_name, None))
) for field_name in field_names)
self.encoder.mapping[klass] = encode
def submit(self, fn, *args, **kwargs):
""" Internal """
if not self.is_shutdown:
return self.cluster.executor.submit(fn, *args, **kwargs)
def get_pool_state(self):
return dict((host, pool.get_state()) for host, pool in tuple(self._pools.items()))
def get_pools(self):
return self._pools.values()
def _validate_set_legacy_config(self, attr_name, value):
if self.cluster._config_mode == _ConfigMode.PROFILES:
raise ValueError("Cannot set Session.%s while using Configuration Profiles. Set this in a profile instead." % (attr_name,))
setattr(self, '_' + attr_name, value)
self.cluster._config_mode = _ConfigMode.LEGACY
class UserTypeDoesNotExist(Exception):
"""
An attempt was made to use a user-defined type that does not exist.
.. versionadded:: 2.1.0
"""
pass
class _ControlReconnectionHandler(_ReconnectionHandler):
"""
Internal
"""
def __init__(self, control_connection, *args, **kwargs):
_ReconnectionHandler.__init__(self, *args, **kwargs)
self.control_connection = weakref.proxy(control_connection)
def try_reconnect(self):
return self.control_connection._reconnect_internal()
def on_reconnection(self, connection):
self.control_connection._set_new_connection(connection)
def on_exception(self, exc, next_delay):
# TODO only overridden to add logging, so add logging
if isinstance(exc, AuthenticationFailed):
return False
else:
log.debug("Error trying to reconnect control connection: %r", exc)
return True
def _watch_callback(obj_weakref, method_name, *args, **kwargs):
"""
A callback handler for the ControlConnection that tolerates
weak references.
"""
obj = obj_weakref()
if obj is None:
return
getattr(obj, method_name)(*args, **kwargs)
def _clear_watcher(conn, expiring_weakref):
"""
Called when the ControlConnection object is about to be finalized.
This clears watchers on the underlying Connection object.
"""
try:
conn.control_conn_disposed()
except ReferenceError:
pass
class ControlConnection(object):
"""
Internal
"""
_SELECT_PEERS = "SELECT * FROM system.peers"
- _SELECT_PEERS_NO_TOKENS = "SELECT host_id, peer, data_center, rack, rpc_address, release_version, schema_version FROM system.peers"
+ _SELECT_PEERS_NO_TOKENS_TEMPLATE = "SELECT host_id, peer, data_center, rack, rpc_address, {nt_col_name}, release_version, schema_version FROM system.peers"
_SELECT_LOCAL = "SELECT * FROM system.local WHERE key='local'"
_SELECT_LOCAL_NO_TOKENS = "SELECT host_id, cluster_name, data_center, rack, partitioner, release_version, schema_version FROM system.local WHERE key='local'"
# Used only when token_metadata_enabled is set to False
_SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS = "SELECT rpc_address FROM system.local WHERE key='local'"
- _SELECT_SCHEMA_PEERS = "SELECT peer, host_id, rpc_address, schema_version FROM system.peers"
+ _SELECT_SCHEMA_PEERS_TEMPLATE = "SELECT peer, host_id, {nt_col_name}, schema_version FROM system.peers"
_SELECT_SCHEMA_LOCAL = "SELECT schema_version FROM system.local WHERE key='local'"
+ _SELECT_PEERS_V2 = "SELECT * FROM system.peers_v2"
+ _SELECT_PEERS_NO_TOKENS_V2 = "SELECT host_id, peer, peer_port, data_center, rack, native_address, native_port, release_version, schema_version FROM system.peers_v2"
+ _SELECT_SCHEMA_PEERS_V2 = "SELECT host_id, peer, peer_port, native_address, native_port, schema_version FROM system.peers_v2"
+
+ _MINIMUM_NATIVE_ADDRESS_DSE_VERSION = Version("6.0.0")
+
+ class PeersQueryType(object):
+ """internal Enum for _peers_query"""
+ PEERS = 0
+ PEERS_SCHEMA = 1
+
_is_shutdown = False
_timeout = None
_protocol_version = None
_schema_event_refresh_window = None
_topology_event_refresh_window = None
_status_event_refresh_window = None
_schema_meta_enabled = True
_token_meta_enabled = True
+ _uses_peers_v2 = True
+
# for testing purposes
_time = time
def __init__(self, cluster, timeout,
schema_event_refresh_window,
topology_event_refresh_window,
status_event_refresh_window,
schema_meta_enabled=True,
token_meta_enabled=True):
# use a weak reference to allow the Cluster instance to be GC'ed (and
# shutdown) since implementing __del__ disables the cycle detector
self._cluster = weakref.proxy(cluster)
self._connection = None
self._timeout = timeout
self._schema_event_refresh_window = schema_event_refresh_window
self._topology_event_refresh_window = topology_event_refresh_window
self._status_event_refresh_window = status_event_refresh_window
self._schema_meta_enabled = schema_meta_enabled
self._token_meta_enabled = token_meta_enabled
self._lock = RLock()
self._schema_agreement_lock = Lock()
self._reconnection_handler = None
self._reconnection_lock = RLock()
self._event_schedule_times = {}
def connect(self):
if self._is_shutdown:
return
self._protocol_version = self._cluster.protocol_version
self._set_new_connection(self._reconnect_internal())
- self._cluster.metadata.dbaas = self._connection._product_type == dscloud.PRODUCT_APOLLO
+ self._cluster.metadata.dbaas = self._connection._product_type == dscloud.DATASTAX_CLOUD_PRODUCT_TYPE
def _set_new_connection(self, conn):
"""
Replace existing connection (if there is one) and close it.
"""
with self._lock:
old = self._connection
self._connection = conn
if old:
log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn)
old.close()
def _reconnect_internal(self):
"""
Tries to connect to each host in the query plan until one succeeds
or every attempt fails. If successful, a new Connection will be
returned. Otherwise, :exc:`NoHostAvailable` will be raised
with an "errors" arg that is a dict mapping host addresses
to the exception that was raised when an attempt was made to open
a connection to that host.
"""
errors = {}
lbp = (
self._cluster.load_balancing_policy
if self._cluster._config_mode == _ConfigMode.LEGACY else
self._cluster._default_load_balancing_policy
)
for host in lbp.make_query_plan():
try:
return self._try_connect(host)
except ConnectionException as exc:
errors[str(host.endpoint)] = exc
log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
self._cluster.signal_connection_failure(host, exc, is_host_addition=False)
except Exception as exc:
errors[str(host.endpoint)] = exc
log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
if self._is_shutdown:
raise DriverException("[control connection] Reconnection in progress during shutdown")
raise NoHostAvailable("Unable to connect to any servers", errors)
def _try_connect(self, host):
"""
Creates a new Connection, registers for pushed events, and refreshes
node/token and schema metadata.
"""
log.debug("[control connection] Opening new connection to %s", host)
while True:
try:
connection = self._cluster.connection_factory(host.endpoint, is_control_connection=True)
if self._is_shutdown:
connection.close()
raise DriverException("Reconnecting during shutdown")
break
except ProtocolVersionUnsupported as e:
self._cluster.protocol_downgrade(host.endpoint, e.startup_version)
+ except ProtocolException as e:
+ # protocol v5 is out of beta in C* >=4.0-beta5 and is now the default driver
+ # protocol version. If the protocol version was not explicitly specified,
+ # and the server raises a beta protocol error, we should downgrade.
+ if not self._cluster._protocol_version_explicit and e.is_beta_protocol_error:
+ self._cluster.protocol_downgrade(host.endpoint, self._cluster.protocol_version)
+ else:
+ raise
log.debug("[control connection] Established new connection %r, "
"registering watchers and refreshing schema and topology",
connection)
# use weak references in both directions
# _clear_watcher will be called when this ControlConnection is about to be finalized
# _watch_callback will get the actual callback from the Connection and relay it to
# this object (after dereferencing a weakref)
self_weakref = weakref.ref(self, partial(_clear_watcher, weakref.proxy(connection)))
try:
connection.register_watchers({
"TOPOLOGY_CHANGE": partial(_watch_callback, self_weakref, '_handle_topology_change'),
"STATUS_CHANGE": partial(_watch_callback, self_weakref, '_handle_status_change'),
"SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change')
}, register_timeout=self._timeout)
- sel_peers = self._SELECT_PEERS if self._token_meta_enabled else self._SELECT_PEERS_NO_TOKENS
+ sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection)
sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS
peers_query = QueryMessage(query=sel_peers, consistency_level=ConsistencyLevel.ONE)
local_query = QueryMessage(query=sel_local, consistency_level=ConsistencyLevel.ONE)
- shared_results = connection.wait_for_responses(
- peers_query, local_query, timeout=self._timeout)
+ (peers_success, peers_result), (local_success, local_result) = connection.wait_for_responses(
+ peers_query, local_query, timeout=self._timeout, fail_on_error=False)
+
+ if not local_success:
+ raise local_result
+
+ if not peers_success:
+ # error with the peers v2 query, fall back to peers v1
+ self._uses_peers_v2 = False
+ sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection)
+ peers_query = QueryMessage(query=sel_peers, consistency_level=ConsistencyLevel.ONE)
+ peers_result = connection.wait_for_response(
+ peers_query, timeout=self._timeout)
+ shared_results = (peers_result, local_result)
self._refresh_node_list_and_token_map(connection, preloaded_results=shared_results)
self._refresh_schema(connection, preloaded_results=shared_results, schema_agreement_wait=-1)
except Exception:
connection.close()
raise
return connection
def reconnect(self):
if self._is_shutdown:
return
self._submit(self._reconnect)
def _reconnect(self):
log.debug("[control connection] Attempting to reconnect")
try:
self._set_new_connection(self._reconnect_internal())
except NoHostAvailable:
# make a retry schedule (which includes backoff)
schedule = self._cluster.reconnection_policy.new_schedule()
with self._reconnection_lock:
# cancel existing reconnection attempts
if self._reconnection_handler:
self._reconnection_handler.cancel()
# when a connection is successfully made, _set_new_connection
# will be called with the new connection and then our
# _reconnection_handler will be cleared out
self._reconnection_handler = _ControlReconnectionHandler(
self, self._cluster.scheduler, schedule,
self._get_and_set_reconnection_handler,
new_handler=None)
self._reconnection_handler.start()
except Exception:
log.debug("[control connection] error reconnecting", exc_info=True)
raise
def _get_and_set_reconnection_handler(self, new_handler):
"""
Called by the _ControlReconnectionHandler when a new connection
is successfully created. Clears out the _reconnection_handler on
this ControlConnection.
"""
with self._reconnection_lock:
old = self._reconnection_handler
self._reconnection_handler = new_handler
return old
def _submit(self, *args, **kwargs):
try:
if not self._cluster.is_shutdown:
return self._cluster.executor.submit(*args, **kwargs)
except ReferenceError:
pass
return None
def shutdown(self):
# stop trying to reconnect (if we are)
with self._reconnection_lock:
if self._reconnection_handler:
self._reconnection_handler.cancel()
with self._lock:
if self._is_shutdown:
return
else:
self._is_shutdown = True
log.debug("Shutting down control connection")
if self._connection:
self._connection.close()
self._connection = None
def refresh_schema(self, force=False, **kwargs):
try:
if self._connection:
return self._refresh_schema(self._connection, force=force, **kwargs)
except ReferenceError:
pass # our weak reference to the Cluster is no good
except Exception:
log.debug("[control connection] Error refreshing schema", exc_info=True)
self._signal_error()
return False
def _refresh_schema(self, connection, preloaded_results=None, schema_agreement_wait=None, force=False, **kwargs):
if self._cluster.is_shutdown:
return False
agreed = self.wait_for_schema_agreement(connection,
preloaded_results=preloaded_results,
wait_time=schema_agreement_wait)
if not self._schema_meta_enabled and not force:
log.debug("[control connection] Skipping schema refresh because schema metadata is disabled")
return False
if not agreed:
log.debug("Skipping schema refresh due to lack of schema agreement")
return False
self._cluster.metadata.refresh(connection, self._timeout, **kwargs)
return True
def refresh_node_list_and_token_map(self, force_token_rebuild=False):
try:
if self._connection:
self._refresh_node_list_and_token_map(self._connection, force_token_rebuild=force_token_rebuild)
return True
except ReferenceError:
pass # our weak reference to the Cluster is no good
except Exception:
log.debug("[control connection] Error refreshing node list and token map", exc_info=True)
self._signal_error()
return False
def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
force_token_rebuild=False):
-
if preloaded_results:
log.debug("[control connection] Refreshing node list and token map using preloaded results")
peers_result = preloaded_results[0]
local_result = preloaded_results[1]
else:
cl = ConsistencyLevel.ONE
+ sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection)
if not self._token_meta_enabled:
log.debug("[control connection] Refreshing node list without token map")
- sel_peers = self._SELECT_PEERS_NO_TOKENS
sel_local = self._SELECT_LOCAL_NO_TOKENS
else:
log.debug("[control connection] Refreshing node list and token map")
- sel_peers = self._SELECT_PEERS
sel_local = self._SELECT_LOCAL
peers_query = QueryMessage(query=sel_peers, consistency_level=cl)
local_query = QueryMessage(query=sel_local, consistency_level=cl)
peers_result, local_result = connection.wait_for_responses(
peers_query, local_query, timeout=self._timeout)
- peers_result = dict_factory(*peers_result.results)
+ peers_result = dict_factory(peers_result.column_names, peers_result.parsed_rows)
partitioner = None
token_map = {}
found_hosts = set()
- if local_result.results:
+ if local_result.parsed_rows:
found_hosts.add(connection.endpoint)
- local_rows = dict_factory(*(local_result.results))
+ local_rows = dict_factory(local_result.column_names, local_result.parsed_rows)
local_row = local_rows[0]
cluster_name = local_row["cluster_name"]
self._cluster.metadata.cluster_name = cluster_name
partitioner = local_row.get("partitioner")
tokens = local_row.get("tokens")
host = self._cluster.metadata.get_host(connection.endpoint)
if host:
datacenter = local_row.get("data_center")
rack = local_row.get("rack")
self._update_location_info(host, datacenter, rack)
host.host_id = local_row.get("host_id")
host.listen_address = local_row.get("listen_address")
- host.broadcast_address = local_row.get("broadcast_address")
+ host.listen_port = local_row.get("listen_port")
+ host.broadcast_address = _NodeInfo.get_broadcast_address(local_row)
+ host.broadcast_port = _NodeInfo.get_broadcast_port(local_row)
- host.broadcast_rpc_address = self._address_from_row(local_row)
+ host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(local_row)
+ host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(local_row)
if host.broadcast_rpc_address is None:
if self._token_meta_enabled:
# local rpc_address is not available, use the connection endpoint
host.broadcast_rpc_address = connection.endpoint.address
+ host.broadcast_rpc_port = connection.endpoint.port
else:
# local rpc_address has not been queried yet, try to fetch it
# separately, which might fail because C* < 2.1.6 doesn't have rpc_address
# in system.local. See CASSANDRA-9436.
local_rpc_address_query = QueryMessage(query=self._SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS,
consistency_level=ConsistencyLevel.ONE)
success, local_rpc_address_result = connection.wait_for_response(
local_rpc_address_query, timeout=self._timeout, fail_on_error=False)
if success:
- row = dict_factory(*local_rpc_address_result.results)
- host.broadcast_rpc_address = row[0]['rpc_address']
+ row = dict_factory(
+ local_rpc_address_result.column_names,
+ local_rpc_address_result.parsed_rows)
+ host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(row[0])
+ host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(row[0])
else:
host.broadcast_rpc_address = connection.endpoint.address
+ host.broadcast_rpc_port = connection.endpoint.port
host.release_version = local_row.get("release_version")
host.dse_version = local_row.get("dse_version")
host.dse_workload = local_row.get("workload")
+ host.dse_workloads = local_row.get("workloads")
if partitioner and tokens:
token_map[host] = tokens
# Check metadata.partitioner to see if we haven't built anything yet. If
# every node in the cluster was in the contact points, we won't discover
# any new nodes, so we need this additional check. (See PYTHON-90)
should_rebuild_token_map = force_token_rebuild or self._cluster.metadata.partitioner is None
for row in peers_result:
+ if not self._is_valid_peer(row):
+ log.warning(
+ "Found an invalid row for peer (%s). Ignoring host." %
+ _NodeInfo.get_broadcast_rpc_address(row))
+ continue
+
endpoint = self._cluster.endpoint_factory.create(row)
- tokens = row.get("tokens", None)
- if 'tokens' in row and not tokens: # it was selected, but empty
- log.warning("Excluding host (%s) with no tokens in system.peers table of %s." % (endpoint, connection.endpoint))
- continue
if endpoint in found_hosts:
log.warning("Found multiple hosts with the same endpoint (%s). Excluding peer %s", endpoint, row.get("peer"))
continue
found_hosts.add(endpoint)
host = self._cluster.metadata.get_host(endpoint)
datacenter = row.get("data_center")
rack = row.get("rack")
if host is None:
log.debug("[control connection] Found new host to connect to: %s", endpoint)
host, _ = self._cluster.add_host(endpoint, datacenter, rack, signal=True, refresh_nodes=False)
should_rebuild_token_map = True
else:
should_rebuild_token_map |= self._update_location_info(host, datacenter, rack)
host.host_id = row.get("host_id")
- host.broadcast_address = row.get("peer")
- host.broadcast_rpc_address = self._address_from_row(row)
+ host.broadcast_address = _NodeInfo.get_broadcast_address(row)
+ host.broadcast_port = _NodeInfo.get_broadcast_port(row)
+ host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(row)
+ host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(row)
host.release_version = row.get("release_version")
host.dse_version = row.get("dse_version")
host.dse_workload = row.get("workload")
+ host.dse_workloads = row.get("workloads")
- if partitioner and tokens:
+ tokens = row.get("tokens", None)
+ if partitioner and tokens and self._token_meta_enabled:
token_map[host] = tokens
for old_host in self._cluster.metadata.all_hosts():
if old_host.endpoint.address != connection.endpoint and old_host.endpoint not in found_hosts:
should_rebuild_token_map = True
log.debug("[control connection] Removing host not found in peers metadata: %r", old_host)
self._cluster.remove_host(old_host)
log.debug("[control connection] Finished fetching ring info")
if partitioner and should_rebuild_token_map:
log.debug("[control connection] Rebuilding token map due to topology changes")
self._cluster.metadata.rebuild_token_map(partitioner, token_map)
+ @staticmethod
+ def _is_valid_peer(row):
+ return bool(_NodeInfo.get_broadcast_rpc_address(row) and row.get("host_id") and
+ row.get("data_center") and row.get("rack") and
+ ('tokens' not in row or row.get('tokens')))
+
def _update_location_info(self, host, datacenter, rack):
if host.datacenter == datacenter and host.rack == rack:
return False
# If the dc/rack information changes, we need to update the load balancing policy.
# For that, we remove and re-add the node against the policy. Not the most elegant, and assumes
# that the policy will update correctly, but in practice this should work.
self._cluster.profile_manager.on_down(host)
host.set_location_info(datacenter, rack)
self._cluster.profile_manager.on_up(host)
return True
def _delay_for_event_type(self, event_type, delay_window):
# this serves to order processing of correlated events (received within the window);
# the window and randomization still have the desired effect of skew across client instances
next_time = self._event_schedule_times.get(event_type, 0)
now = self._time.time()
if now <= next_time:
this_time = next_time + 0.01
delay = this_time - now
else:
delay = random() * delay_window
this_time = now + delay
self._event_schedule_times[event_type] = this_time
return delay
def _refresh_nodes_if_not_up(self, host):
"""
Used to mitigate refreshes for nodes that are already known.
Some versions of the server send superfluous NEW_NODE messages in addition to UP events.
"""
if not host or not host.is_up:
self.refresh_node_list_and_token_map()
def _handle_topology_change(self, event):
change_type = event["change_type"]
- host = self._cluster.metadata.get_host(event["address"][0])
+ addr, port = event["address"]
+ host = self._cluster.metadata.get_host(addr, port)
if change_type == "NEW_NODE" or change_type == "MOVED_NODE":
if self._topology_event_refresh_window >= 0:
delay = self._delay_for_event_type('topology_change', self._topology_event_refresh_window)
self._cluster.scheduler.schedule_unique(delay, self._refresh_nodes_if_not_up, host)
elif change_type == "REMOVED_NODE":
self._cluster.scheduler.schedule_unique(0, self._cluster.remove_host, host)
def _handle_status_change(self, event):
change_type = event["change_type"]
- host = self._cluster.metadata.get_host(event["address"][0])
+ addr, port = event["address"]
+ host = self._cluster.metadata.get_host(addr, port)
if change_type == "UP":
delay = self._delay_for_event_type('status_change', self._status_event_refresh_window)
if host is None:
# this is the first time we've seen the node
self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map)
else:
self._cluster.scheduler.schedule_unique(delay, self._cluster.on_up, host)
elif change_type == "DOWN":
# Note that there is a slight risk we can receive the event late and thus
# mark the host down even though we already had reconnected successfully.
# But it is unlikely, and doesn't have much consequence since we'll try reconnecting
# right away, so we favor the detection to make the Host.is_up more accurate.
if host is not None:
# this will be run by the scheduler
self._cluster.on_down(host, is_host_addition=False)
def _handle_schema_change(self, event):
if self._schema_event_refresh_window < 0:
return
delay = self._delay_for_event_type('schema_change', self._schema_event_refresh_window)
self._cluster.scheduler.schedule_unique(delay, self.refresh_schema, **event)
def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None):
total_timeout = wait_time if wait_time is not None else self._cluster.max_schema_agreement_wait
if total_timeout <= 0:
return True
# Each schema change typically generates two schema refreshes, one
# from the response type and one from the pushed notification. Holding
# a lock is just a simple way to cut down on the number of schema queries
# we'll make.
with self._schema_agreement_lock:
if self._is_shutdown:
return
if not connection:
connection = self._connection
if preloaded_results:
log.debug("[control connection] Attempting to use preloaded results for schema agreement")
peers_result = preloaded_results[0]
local_result = preloaded_results[1]
schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.endpoint)
if schema_mismatches is None:
return True
log.debug("[control connection] Waiting for schema agreement")
start = self._time.time()
elapsed = 0
cl = ConsistencyLevel.ONE
schema_mismatches = None
+ select_peers_query = self._get_peers_query(self.PeersQueryType.PEERS_SCHEMA, connection)
+
while elapsed < total_timeout:
- peers_query = QueryMessage(query=self._SELECT_SCHEMA_PEERS, consistency_level=cl)
+ peers_query = QueryMessage(query=select_peers_query, consistency_level=cl)
local_query = QueryMessage(query=self._SELECT_SCHEMA_LOCAL, consistency_level=cl)
try:
timeout = min(self._timeout, total_timeout - elapsed)
peers_result, local_result = connection.wait_for_responses(
peers_query, local_query, timeout=timeout)
except OperationTimedOut as timeout:
log.debug("[control connection] Timed out waiting for "
"response during schema agreement check: %s", timeout)
elapsed = self._time.time() - start
continue
except ConnectionShutdown:
if self._is_shutdown:
log.debug("[control connection] Aborting wait for schema match due to shutdown")
return None
else:
raise
schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.endpoint)
if schema_mismatches is None:
return True
log.debug("[control connection] Schemas mismatched, trying again")
self._time.sleep(0.2)
elapsed = self._time.time() - start
log.warning("Node %s is reporting a schema disagreement: %s",
connection.endpoint, schema_mismatches)
return False
def _get_schema_mismatches(self, peers_result, local_result, local_address):
- peers_result = dict_factory(*peers_result.results)
+ peers_result = dict_factory(peers_result.column_names, peers_result.parsed_rows)
versions = defaultdict(set)
- if local_result.results:
- local_row = dict_factory(*local_result.results)[0]
+ if local_result.parsed_rows:
+ local_row = dict_factory(local_result.column_names, local_result.parsed_rows)[0]
if local_row.get("schema_version"):
versions[local_row.get("schema_version")].add(local_address)
for row in peers_result:
schema_ver = row.get('schema_version')
if not schema_ver:
continue
endpoint = self._cluster.endpoint_factory.create(row)
peer = self._cluster.metadata.get_host(endpoint)
if peer and peer.is_up is not False:
versions[schema_ver].add(endpoint)
if len(versions) == 1:
log.debug("[control connection] Schemas match")
return None
return dict((version, list(nodes)) for version, nodes in six.iteritems(versions))
- def _address_from_row(self, row):
+ def _get_peers_query(self, peers_query_type, connection=None):
"""
- Parse the broadcast rpc address from a row and return it untranslated.
+ Determine the peers query to use.
+
+ :param peers_query_type: Should be one of PeersQueryType enum.
+
+ If _uses_peers_v2 is True, return the proper peers_v2 query (no templating).
+ Else, apply the logic below to choose the peers v1 address column name:
+
+ Given a connection:
+
+ - find the server product version running on the connection's host,
+ - use that to choose the column name for the transport address (see APOLLO-1130), and
+ - use that column name in the provided peers query template.
"""
- addr = None
- if "rpc_address" in row:
- addr = row.get("rpc_address") # peers and local
- if "native_transport_address" in row:
- addr = row.get("native_transport_address")
- if not addr or addr in ["0.0.0.0", "::"]:
- addr = row.get("peer")
+ if peers_query_type not in (self.PeersQueryType.PEERS, self.PeersQueryType.PEERS_SCHEMA):
+ raise ValueError("Invalid peers query type: %s" % peers_query_type)
+
+ if self._uses_peers_v2:
+ if peers_query_type == self.PeersQueryType.PEERS:
+ query = self._SELECT_PEERS_V2 if self._token_meta_enabled else self._SELECT_PEERS_NO_TOKENS_V2
+ else:
+ query = self._SELECT_SCHEMA_PEERS_V2
+ else:
+ if peers_query_type == self.PeersQueryType.PEERS and self._token_meta_enabled:
+ query = self._SELECT_PEERS
+ else:
+ query_template = (self._SELECT_SCHEMA_PEERS_TEMPLATE
+ if peers_query_type == self.PeersQueryType.PEERS_SCHEMA
+ else self._SELECT_PEERS_NO_TOKENS_TEMPLATE)
+
+ host_release_version = self._cluster.metadata.get_host(connection.endpoint).release_version
+ host_dse_version = self._cluster.metadata.get_host(connection.endpoint).dse_version
+ uses_native_address_query = (
+ host_dse_version and Version(host_dse_version) >= self._MINIMUM_NATIVE_ADDRESS_DSE_VERSION)
+
+ if uses_native_address_query:
+ query = query_template.format(nt_col_name="native_transport_address")
+ elif host_release_version:
+ query = query_template.format(nt_col_name="rpc_address")
+ else:
+ query = self._SELECT_PEERS
+
+ return query
- return addr
-
def _signal_error(self):
with self._lock:
if self._is_shutdown:
return
# try just signaling the cluster, as this will trigger a reconnect
# as part of marking the host down
if self._connection and self._connection.is_defunct:
host = self._cluster.metadata.get_host(self._connection.endpoint)
# host may be None if it's already been removed, but that indicates
# that errors have already been reported, so we're fine
if host:
self._cluster.signal_connection_failure(
host, self._connection.last_error, is_host_addition=False)
return
# if the connection is not defunct or the host already left, reconnect
# manually
self.reconnect()
def on_up(self, host):
pass
def on_down(self, host):
conn = self._connection
if conn and conn.endpoint == host.endpoint and \
self._reconnection_handler is None:
log.debug("[control connection] Control connection host (%s) is "
"considered down, starting reconnection", host)
# this will result in a task being submitted to the executor to reconnect
self.reconnect()
def on_add(self, host, refresh_nodes=True):
if refresh_nodes:
self.refresh_node_list_and_token_map(force_token_rebuild=True)
def on_remove(self, host):
c = self._connection
if c and c.endpoint == host.endpoint:
log.debug("[control connection] Control connection host (%s) is being removed. Reconnecting", host)
# refresh will be done on reconnect
self.reconnect()
else:
self.refresh_node_list_and_token_map(force_token_rebuild=True)
def get_connections(self):
c = getattr(self, '_connection', None)
return [c] if c else []
def return_connection(self, connection):
if connection is self._connection and (connection.is_defunct or connection.is_closed):
self.reconnect()
def _stop_scheduler(scheduler, thread):
try:
if not scheduler.is_shutdown:
scheduler.shutdown()
except ReferenceError:
pass
thread.join()
class _Scheduler(Thread):
_queue = None
_scheduled_tasks = None
_executor = None
is_shutdown = False
def __init__(self, executor):
self._queue = Queue.PriorityQueue()
self._scheduled_tasks = set()
self._count = count()
self._executor = executor
Thread.__init__(self, name="Task Scheduler")
self.daemon = True
self.start()
def shutdown(self):
try:
log.debug("Shutting down Cluster Scheduler")
except AttributeError:
# this can happen on interpreter shutdown
pass
self.is_shutdown = True
self._queue.put_nowait((0, 0, None))
self.join()
def schedule(self, delay, fn, *args, **kwargs):
self._insert_task(delay, (fn, args, tuple(kwargs.items())))
def schedule_unique(self, delay, fn, *args, **kwargs):
task = (fn, args, tuple(kwargs.items()))
if task not in self._scheduled_tasks:
self._insert_task(delay, task)
else:
log.debug("Ignoring schedule_unique for already-scheduled task: %r", task)
def _insert_task(self, delay, task):
if not self.is_shutdown:
run_at = time.time() + delay
self._scheduled_tasks.add(task)
self._queue.put_nowait((run_at, next(self._count), task))
else:
log.debug("Ignoring scheduled task after shutdown: %r", task)
def run(self):
while True:
if self.is_shutdown:
return
try:
while True:
run_at, i, task = self._queue.get(block=True, timeout=None)
if self.is_shutdown:
if task:
log.debug("Not executing scheduled task due to Scheduler shutdown")
return
if run_at <= time.time():
self._scheduled_tasks.discard(task)
fn, args, kwargs = task
kwargs = dict(kwargs)
future = self._executor.submit(fn, *args, **kwargs)
future.add_done_callback(self._log_if_failed)
else:
self._queue.put_nowait((run_at, i, task))
break
except Queue.Empty:
pass
time.sleep(0.1)
def _log_if_failed(self, future):
exc = future.exception()
if exc:
log.warning(
"An internally scheduled tasked failed with an unhandled exception:",
exc_info=exc)
def refresh_schema_and_set_result(control_conn, response_future, connection, **kwargs):
try:
log.debug("Refreshing schema in response to schema change. "
"%s", kwargs)
response_future.is_schema_agreed = control_conn._refresh_schema(connection, **kwargs)
except Exception:
log.exception("Exception refreshing schema in response to schema change:")
response_future.session.submit(control_conn.refresh_schema, **kwargs)
finally:
response_future._set_final_result(None)
class ResponseFuture(object):
"""
An asynchronous response delivery mechanism that is returned from calls
to :meth:`.Session.execute_async()`.
There are two ways for results to be delivered:
- Synchronously, by calling :meth:`.result()`
- Asynchronously, by attaching callback and errback functions via
:meth:`.add_callback()`, :meth:`.add_errback()`, and
:meth:`.add_callbacks()`.
"""
query = None
"""
The :class:`~.Statement` instance that is being executed through this
:class:`.ResponseFuture`.
"""
is_schema_agreed = True
"""
For DDL requests, this may be set to ``False`` if the schema agreement poll after the response fails.
Always ``True`` for non-DDL requests.
"""
request_encoded_size = None
"""
Size of the request message sent
"""
coordinator_host = None
"""
- The host from which we recieved a response
+ The host from which we received a response
"""
attempted_hosts = None
"""
A list of hosts tried, including all speculative executions, retries, and pages
"""
session = None
row_factory = None
message = None
default_timeout = None
_retry_policy = None
_profile_manager = None
_req_id = None
_final_result = _NOT_SET
_col_names = None
_col_types = None
_final_exception = None
_query_traces = None
_callbacks = None
_errbacks = None
_current_host = None
_connection = None
_query_retries = 0
_start_time = None
_metrics = None
_paging_state = None
_custom_payload = None
_warnings = None
_timer = None
_protocol_handler = ProtocolHandler
_spec_execution_plan = NoSpeculativeExecutionPlan()
+ _continuous_paging_options = None
+ _continuous_paging_session = None
_host = None
_warned_timeout = False
def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None,
retry_policy=RetryPolicy(), row_factory=None, load_balancer=None, start_time=None,
- speculative_execution_plan=None, host=None):
+ speculative_execution_plan=None, continuous_paging_state=None, host=None):
self.session = session
# TODO: normalize handling of retry policy and row factory
self.row_factory = row_factory or session.row_factory
self._load_balancer = load_balancer or session.cluster._default_load_balancing_policy
self.message = message
self.query = query
self.timeout = timeout
self._retry_policy = retry_policy
self._metrics = metrics
self.prepared_statement = prepared_statement
self._callback_lock = Lock()
self._start_time = start_time or time.time()
self._host = host
self._spec_execution_plan = speculative_execution_plan or self._spec_execution_plan
self._make_query_plan()
self._event = Event()
self._errors = {}
self._callbacks = []
self._errbacks = []
self.attempted_hosts = []
self._start_timer()
+ self._continuous_paging_state = continuous_paging_state
@property
def _time_remaining(self):
if self.timeout is None:
return None
return (self._start_time + self.timeout) - time.time()
def _start_timer(self):
if self._timer is None:
spec_delay = self._spec_execution_plan.next_execution(self._current_host)
if spec_delay >= 0:
if self._time_remaining is None or self._time_remaining > spec_delay:
self._timer = self.session.cluster.connection_class.create_timer(spec_delay, self._on_speculative_execute)
return
if self._time_remaining is not None:
self._timer = self.session.cluster.connection_class.create_timer(self._time_remaining, self._on_timeout)
def _cancel_timer(self):
if self._timer:
self._timer.cancel()
def _on_timeout(self, _attempts=0):
"""
Called when the request associated with this ResponseFuture times out.
This function may reschedule itself. The ``_attempts`` parameter tracks
the number of times this has happened. This parameter should only be
set in cases where ``_on_timeout`` reschedules itself.
"""
# PYTHON-853: for short timeouts, we sometimes race with our __init__
if self._connection is None and _attempts < 3:
self._timer = self.session.cluster.connection_class.create_timer(
0.01,
partial(self._on_timeout, _attempts=_attempts + 1)
)
return
if self._connection is not None:
try:
self._connection._requests.pop(self._req_id)
# PYTHON-1044
# This request might have been removed from the connection after the latter was defunct by heartbeat.
# We should still raise OperationTimedOut to reject the future so that the main event thread will not
# wait for it endlessly
except KeyError:
key = "Connection defunct by heartbeat"
errors = {key: "Client request timeout. See Session.execute[_async](timeout)"}
self._set_final_exception(OperationTimedOut(errors, self._current_host))
return
pool = self.session._pools.get(self._current_host)
if pool and not pool.is_shutdown:
with self._connection.lock:
self._connection.request_ids.append(self._req_id)
pool.return_connection(self._connection)
errors = self._errors
if not errors:
if self.is_schema_agreed:
key = str(self._current_host.endpoint) if self._current_host else 'no host queried before timeout'
errors = {key: "Client request timeout. See Session.execute[_async](timeout)"}
else:
connection = self.session.cluster.control_connection._connection
host = str(connection.endpoint) if connection else 'unknown'
errors = {host: "Request timed out while waiting for schema agreement. See Session.execute[_async](timeout) and Cluster.max_schema_agreement_wait."}
self._set_final_exception(OperationTimedOut(errors, self._current_host))
def _on_speculative_execute(self):
self._timer = None
if not self._event.is_set():
# PYTHON-836, the speculative queries must be after
# the query is sent from the main thread, otherwise the
# query from the main thread may raise NoHostAvailable
# if the _query_plan has been exhausted by the speculative queries.
# This also prevents a race condition accessing the iterator.
# We reschedule this call until the main thread has succeeded
# making a query
if not self.attempted_hosts:
self._timer = self.session.cluster.connection_class.create_timer(0.01, self._on_speculative_execute)
return
if self._time_remaining is not None:
if self._time_remaining <= 0:
self._on_timeout()
return
self.send_request(error_no_hosts=False)
self._start_timer()
def _make_query_plan(self):
# set the query_plan according to the load balancing policy,
# or to the explicit host target if set
if self._host:
# returning a single value effectively disables retries
self.query_plan = [self._host]
else:
# convert the list/generator/etc to an iterator so that subsequent
# calls to send_request (which retries may do) will resume where
# they last left off
self.query_plan = iter(self._load_balancer.make_query_plan(self.session.keyspace, self.query))
def send_request(self, error_no_hosts=True):
""" Internal """
# query_plan is an iterator, so this will resume where we last left
# off if send_request() is called multiple times
for host in self.query_plan:
req_id = self._query(host)
if req_id is not None:
self._req_id = req_id
return True
if self.timeout is not None and time.time() - self._start_time > self.timeout:
self._on_timeout()
return True
-
if error_no_hosts:
self._set_final_exception(NoHostAvailable(
"Unable to complete the operation against any hosts", self._errors))
return False
def _query(self, host, message=None, cb=None):
if message is None:
message = self.message
pool = self.session._pools.get(host)
if not pool:
self._errors[host] = ConnectionException("Host has been marked down or removed")
return None
elif pool.is_shutdown:
self._errors[host] = ConnectionException("Pool is shutdown")
return None
self._current_host = host
connection = None
try:
# TODO get connectTimeout from cluster settings
connection, request_id = pool.borrow_connection(timeout=2.0)
self._connection = connection
result_meta = self.prepared_statement.result_metadata if self.prepared_statement else []
if cb is None:
cb = partial(self._set_result, host, connection, pool)
self.request_encoded_size = connection.send_msg(message, request_id, cb=cb,
encoder=self._protocol_handler.encode_message,
decoder=self._protocol_handler.decode_message,
result_metadata=result_meta)
self.attempted_hosts.append(host)
return request_id
except NoConnectionsAvailable as exc:
log.debug("All connections for host %s are at capacity, moving to the next host", host)
self._errors[host] = exc
- return None
+ except ConnectionBusy as exc:
+ log.debug("Connection for host %s is busy, moving to the next host", host)
+ self._errors[host] = exc
except Exception as exc:
log.debug("Error querying host %s", host, exc_info=True)
self._errors[host] = exc
if self._metrics is not None:
self._metrics.on_connection_error()
if connection:
pool.return_connection(connection)
- return None
+
+ return None
@property
def has_more_pages(self):
"""
Returns :const:`True` if there are more pages left in the
query results, :const:`False` otherwise. This should only
be checked after the first page has been returned.
.. versionadded:: 2.0.0
"""
return self._paging_state is not None
@property
def warnings(self):
"""
Warnings returned from the server, if any. This will only be
set for protocol_version 4+.
Warnings may be returned for such things as oversized batches,
or too many tombstones in slice queries.
Ensure the future is complete before trying to access this property
(call :meth:`.result()`, or after callback is invoked).
Otherwise an exception may be raised if the response has not been received.
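For example, a minimal check after completion (assuming an existing ``session`` and ``query``)::
>>> future = session.execute_async(query)
>>> rows = future.result()   # make sure the response has arrived
>>> future.warnings          # warnings sent by the server, or None if there were none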
"""
# TODO: When timers are introduced, just make this wait
if not self._event.is_set():
raise DriverException("warnings cannot be retrieved before ResponseFuture is finalized")
return self._warnings
@property
def custom_payload(self):
"""
The custom payload returned from the server, if any. This will only be
set by Cassandra servers implementing a custom QueryHandler, and only
for protocol_version 4+.
Ensure the future is complete before trying to access this property
(call :meth:`.result()`, or after callback is invoked).
Otherwise an exception may be raised if the response has not been received.
:return: :ref:`custom_payload`.
"""
# TODO: When timers are introduced, just make this wait
if not self._event.is_set():
raise DriverException("custom_payload cannot be retrieved before ResponseFuture is finalized")
return self._custom_payload
def start_fetching_next_page(self):
"""
If there are more pages left in the query result, this asynchronously
starts fetching the next page. If there are no pages left, :exc:`.QueryExhausted`
is raised. Also see :attr:`.has_more_pages`.
This should only be called after the first page has been returned.
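A minimal callback-driven paging sketch (assuming an existing ``future`` obtained from
:meth:`.Session.execute_async`; ``handle_rows`` and ``handle_error`` are illustrative names)::
>>> def handle_rows(rows, future):
...     for row in rows:
...         pass  # process each row of the current page
...     if future.has_more_pages:
...         future.start_fetching_next_page()
>>> def handle_error(exc):
...     log.error("Query failed: %s", exc)
>>> future.add_callbacks(handle_rows, handle_error, callback_args=(future,))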
.. versionadded:: 2.0.0
"""
if not self._paging_state:
raise QueryExhausted()
self._make_query_plan()
self.message.paging_state = self._paging_state
self._event.clear()
self._final_result = _NOT_SET
self._final_exception = None
self._start_timer()
self.send_request()
def _reprepare(self, prepare_message, host, connection, pool):
cb = partial(self.session.submit, self._execute_after_prepare, host, connection, pool)
request_id = self._query(host, prepare_message, cb=cb)
if request_id is None:
# try to submit the original prepared statement on some other host
self.send_request()
def _set_result(self, host, connection, pool, response):
try:
self.coordinator_host = host
if pool:
pool.return_connection(connection)
trace_id = getattr(response, 'trace_id', None)
if trace_id:
if not self._query_traces:
self._query_traces = []
self._query_traces.append(QueryTrace(trace_id, self.session))
self._warnings = getattr(response, 'warnings', None)
self._custom_payload = getattr(response, 'custom_payload', None)
if isinstance(response, ResultMessage):
if response.kind == RESULT_KIND_SET_KEYSPACE:
session = getattr(self, 'session', None)
# since we're running on the event loop thread, we need to
# use a non-blocking method for setting the keyspace on
# all connections in this session, otherwise the event
# loop thread will deadlock waiting for keyspaces to be
# set. This uses a callback chain which ends with
# self._set_keyspace_completed() being called in the
# event loop thread.
if session:
session._set_keyspace_for_all_pools(
- response.results, self._set_keyspace_completed)
+ response.new_keyspace, self._set_keyspace_completed)
elif response.kind == RESULT_KIND_SCHEMA_CHANGE:
# refresh the schema before responding, but do it in another
# thread instead of the event loop thread
self.is_schema_agreed = False
self.session.submit(
refresh_schema_and_set_result,
self.session.cluster.control_connection,
- self, connection, **response.results)
+ self, connection, **response.schema_change_event)
+ elif response.kind == RESULT_KIND_ROWS:
+ self._paging_state = response.paging_state
+ self._col_names = response.column_names
+ self._col_types = response.column_types
+ if getattr(self.message, 'continuous_paging_options', None):
+ self._handle_continuous_paging_first_response(connection, response)
+ else:
+ self._set_final_result(self.row_factory(response.column_names, response.parsed_rows))
+ elif response.kind == RESULT_KIND_VOID:
+ self._set_final_result(None)
else:
- results = getattr(response, 'results', None)
- if results is not None and response.kind == RESULT_KIND_ROWS:
- self._paging_state = response.paging_state
- self._col_types = response.col_types
- self._col_names = results[0]
- results = self.row_factory(*results)
- self._set_final_result(results)
+ self._set_final_result(response)
elif isinstance(response, ErrorMessage):
retry_policy = self._retry_policy
if isinstance(response, ReadTimeoutErrorMessage):
if self._metrics is not None:
self._metrics.on_read_timeout()
retry = retry_policy.on_read_timeout(
self.query, retry_num=self._query_retries, **response.info)
elif isinstance(response, WriteTimeoutErrorMessage):
if self._metrics is not None:
self._metrics.on_write_timeout()
retry = retry_policy.on_write_timeout(
self.query, retry_num=self._query_retries, **response.info)
elif isinstance(response, UnavailableErrorMessage):
if self._metrics is not None:
self._metrics.on_unavailable()
retry = retry_policy.on_unavailable(
self.query, retry_num=self._query_retries, **response.info)
elif isinstance(response, (OverloadedErrorMessage,
IsBootstrappingErrorMessage,
TruncateError, ServerError)):
log.warning("Host %s error: %s.", host, response.summary)
if self._metrics is not None:
self._metrics.on_other_error()
+ cl = getattr(self.message, 'consistency_level', None)
retry = retry_policy.on_request_error(
- self.query, self.message.consistency_level, error=response,
+ self.query, cl, error=response,
retry_num=self._query_retries)
elif isinstance(response, PreparedQueryNotFound):
if self.prepared_statement:
query_id = self.prepared_statement.query_id
assert query_id == response.info, \
"Got different query ID in server response (%s) than we " \
"had before (%s)" % (response.info, query_id)
else:
query_id = response.info
try:
prepared_statement = self.session.cluster._prepared_statements[query_id]
except KeyError:
if not self.prepared_statement:
log.error("Tried to execute unknown prepared statement: id=%s",
hexlify(query_id))
self._set_final_exception(response)
return
else:
prepared_statement = self.prepared_statement
self.session.cluster._prepared_statements[query_id] = prepared_statement
current_keyspace = self._connection.keyspace
prepared_keyspace = prepared_statement.keyspace
if not ProtocolVersion.uses_keyspace_flag(self.session.cluster.protocol_version) \
and prepared_keyspace and current_keyspace != prepared_keyspace:
self._set_final_exception(
ValueError("The Session's current keyspace (%s) does "
"not match the keyspace the statement was "
"prepared with (%s)" %
(current_keyspace, prepared_keyspace)))
return
log.debug("Re-preparing unrecognized prepared statement against host %s: %s",
host, prepared_statement.query_string)
prepared_keyspace = prepared_statement.keyspace \
if ProtocolVersion.uses_keyspace_flag(self.session.cluster.protocol_version) else None
prepare_message = PrepareMessage(query=prepared_statement.query_string,
keyspace=prepared_keyspace)
# since this might block, run on the executor to avoid hanging
# the event loop thread
self.session.submit(self._reprepare, prepare_message, host, connection, pool)
return
else:
if hasattr(response, 'to_exception'):
self._set_final_exception(response.to_exception())
else:
self._set_final_exception(response)
return
self._handle_retry_decision(retry, response, host)
elif isinstance(response, ConnectionException):
if self._metrics is not None:
self._metrics.on_connection_error()
if not isinstance(response, ConnectionShutdown):
self._connection.defunct(response)
+ cl = getattr(self.message, 'consistency_level', None)
retry = self._retry_policy.on_request_error(
- self.query, self.message.consistency_level, error=response,
- retry_num=self._query_retries)
+ self.query, cl, error=response, retry_num=self._query_retries)
self._handle_retry_decision(retry, response, host)
elif isinstance(response, Exception):
if hasattr(response, 'to_exception'):
self._set_final_exception(response.to_exception())
else:
self._set_final_exception(response)
else:
# we got some other kind of response message
msg = "Got unexpected message: %r" % (response,)
exc = ConnectionException(msg, host)
self._cancel_timer()
self._connection.defunct(exc)
self._set_final_exception(exc)
except Exception as exc:
# almost certainly caused by a bug, but we need to set something here
log.exception("Unexpected exception while handling result in ResponseFuture:")
self._set_final_exception(exc)
+ def _handle_continuous_paging_first_response(self, connection, response):
+ self._continuous_paging_session = connection.new_continuous_paging_session(response.stream_id,
+ self._protocol_handler.decode_message,
+ self.row_factory,
+ self._continuous_paging_state)
+ self._continuous_paging_session.on_message(response)
+ self._set_final_result(self._continuous_paging_session.results())
+
def _set_keyspace_completed(self, errors):
if not errors:
self._set_final_result(None)
else:
self._set_final_exception(ConnectionException(
"Failed to set keyspace on all hosts: %s" % (errors,)))
def _execute_after_prepare(self, host, connection, pool, response):
"""
Handle the response to our attempt to prepare a statement.
If it succeeded, run the original query again against the same host.
"""
if pool:
pool.return_connection(connection)
if self._final_exception:
return
if isinstance(response, ResultMessage):
if response.kind == RESULT_KIND_PREPARED:
if self.prepared_statement:
- # result metadata is the only thing that could have
- # changed from an alter
- (_, _, _,
- self.prepared_statement.result_metadata,
- new_metadata_id) = response.results
+ if self.prepared_statement.query_id != response.query_id:
+ self._set_final_exception(DriverException(
+ "ID mismatch while trying to reprepare (expected {expected}, got {got}). "
+ "This prepared statement won't work anymore. "
+ "This usually happens when you run a 'USE...' "
+ "query after the statement was prepared.".format(
+ expected=hexlify(self.prepared_statement.query_id), got=hexlify(response.query_id)
+ )
+ ))
+ self.prepared_statement.result_metadata = response.column_metadata
+ new_metadata_id = response.result_metadata_id
if new_metadata_id is not None:
self.prepared_statement.result_metadata_id = new_metadata_id
# use self._query to re-use the same host and
# at the same time properly borrow the connection
request_id = self._query(host)
if request_id is None:
# this host errored out, move on to the next
self.send_request()
else:
self._set_final_exception(ConnectionException(
"Got unexpected response when preparing statement "
"on host %s: %s" % (host, response)))
elif isinstance(response, ErrorMessage):
if hasattr(response, 'to_exception'):
self._set_final_exception(response.to_exception())
else:
self._set_final_exception(response)
elif isinstance(response, ConnectionException):
log.debug("Connection error when preparing statement on host %s: %s",
host, response)
# try again on a different host, preparing again if necessary
self._errors[host] = response
self.send_request()
else:
self._set_final_exception(ConnectionException(
"Got unexpected response type when preparing "
"statement on host %s: %s" % (host, response)))
def _set_final_result(self, response):
self._cancel_timer()
if self._metrics is not None:
self._metrics.request_timer.addValue(time.time() - self._start_time)
with self._callback_lock:
self._final_result = response
# save off current callbacks inside lock for execution outside it
# -- prevents case where _final_result is set, then a callback is
# added and executed on the spot, then executed again as a
# registered callback
to_call = tuple(
partial(fn, response, *args, **kwargs)
for (fn, args, kwargs) in self._callbacks
)
self._event.set()
# apply each callback
for callback_partial in to_call:
callback_partial()
def _set_final_exception(self, response):
self._cancel_timer()
if self._metrics is not None:
self._metrics.request_timer.addValue(time.time() - self._start_time)
with self._callback_lock:
self._final_exception = response
# save off current errbacks inside lock for execution outside it --
# prevents case where _final_exception is set, then an errback is
# added and executed on the spot, then executed again as a
# registered errback
to_call = tuple(
partial(fn, response, *args, **kwargs)
for (fn, args, kwargs) in self._errbacks
)
self._event.set()
# apply each callback
for callback_partial in to_call:
callback_partial()
def _handle_retry_decision(self, retry_decision, response, host):
def exception_from_response(response):
if hasattr(response, 'to_exception'):
return response.to_exception()
else:
return response
retry_type, consistency = retry_decision
if retry_type in (RetryPolicy.RETRY, RetryPolicy.RETRY_NEXT_HOST):
self._query_retries += 1
reuse = retry_type == RetryPolicy.RETRY
self._retry(reuse, consistency, host)
elif retry_type is RetryPolicy.RETHROW:
self._set_final_exception(exception_from_response(response))
else: # IGNORE
if self._metrics is not None:
self._metrics.on_ignore()
self._set_final_result(None)
self._errors[host] = exception_from_response(response)
def _retry(self, reuse_connection, consistency_level, host):
if self._final_exception:
# the connection probably broke while we were waiting
# to retry the operation
return
if self._metrics is not None:
self._metrics.on_retry()
if consistency_level is not None:
self.message.consistency_level = consistency_level
# don't retry on the event loop thread
self.session.submit(self._retry_task, reuse_connection, host)
def _retry_task(self, reuse_connection, host):
if self._final_exception:
# the connection probably broke while we were waiting
# to retry the operation
return
if reuse_connection and self._query(host) is not None:
return
# otherwise, move onto another host
self.send_request()
def result(self):
"""
Return the final result or raise an Exception if errors were
encountered. If the final result or error has not been set
yet, this method will block until it is set, or the timeout
set for the request expires.
Timeout is specified in the Session request execution functions.
If the timeout is exceeded, an :exc:`cassandra.OperationTimedOut` will be raised.
This is a client-side timeout. For more information
about server-side coordinator timeouts, see :class:`.policies.RetryPolicy`.
Example usage::
>>> future = session.execute_async("SELECT * FROM mycf")
>>> # do other stuff...
>>> try:
... rows = future.result()
... for row in rows:
... ... # process results
... except Exception:
... log.exception("Operation failed:")
"""
self._event.wait()
if self._final_result is not _NOT_SET:
return ResultSet(self, self._final_result)
else:
raise self._final_exception
def get_query_trace_ids(self):
"""
Returns the trace session ids for this future, if tracing was enabled (does not fetch trace data).
"""
return [trace.trace_id for trace in self._query_traces]
def get_query_trace(self, max_wait=None, query_cl=ConsistencyLevel.LOCAL_ONE):
"""
Fetches and returns the query trace of the last response, or `None` if tracing was
not enabled.
Note that this may raise an exception if there are problems retrieving the trace
details from Cassandra. If the trace is not available after `max_wait`,
:exc:`cassandra.query.TraceUnavailable` will be raised.
If the ResponseFuture is not done (async execution) and you try to retrieve the trace,
:exc:`cassandra.query.TraceUnavailable` will be raised.
`query_cl` is the consistency level used to poll the trace tables.
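Example (assuming an existing ``session``; tracing must be requested at execution time)::
>>> future = session.execute_async("SELECT * FROM system.local", trace=True)
>>> rows = future.result()
>>> trace = future.get_query_trace(max_wait=2.0)
>>> trace.events if trace else None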
"""
if self._final_result is _NOT_SET and self._final_exception is None:
raise TraceUnavailable(
"Trace information was not available. The ResponseFuture is not done.")
if self._query_traces:
return self._get_query_trace(len(self._query_traces) - 1, max_wait, query_cl)
def get_all_query_traces(self, max_wait_per=None, query_cl=ConsistencyLevel.LOCAL_ONE):
"""
Fetches and returns the query traces for all query pages, if tracing was enabled.
See note in :meth:`~.get_query_trace` regarding possible exceptions.
"""
if self._query_traces:
return [self._get_query_trace(i, max_wait_per, query_cl) for i in range(len(self._query_traces))]
return []
def _get_query_trace(self, i, max_wait, query_cl):
trace = self._query_traces[i]
if not trace.events:
trace.populate(max_wait=max_wait, query_cl=query_cl)
return trace
def add_callback(self, fn, *args, **kwargs):
"""
Attaches a callback function to be called when the final results arrive.
By default, `fn` will be called with the results as the first and only
argument. If `*args` or `**kwargs` are supplied, they will be passed
through as additional positional or keyword arguments to `fn`.
If an error is hit while executing the operation, a callback attached
here will not be called. Use :meth:`.add_errback()` or :meth:`add_callbacks()`
if you wish to handle that case.
If the final result has already been seen when this method is called,
the callback will be called immediately (before this method returns).
Note: in the case that the result is not available when the callback is added,
the callback is executed by the IO event thread. This means that the callback
should not block or attempt further synchronous requests, because no further
IO will be processed until the callback returns.
**Important**: if the callback you attach results in an exception being
raised, **the exception will be ignored**, so please ensure your
callback handles all error cases that you care about.
Usage example::
>>> session = cluster.connect("mykeyspace")
>>> def handle_results(rows, start_time, should_log=False):
... if should_log:
... log.info("Total time: %f", time.time() - start_time)
... ...
>>> future = session.execute_async("SELECT * FROM users")
>>> future.add_callback(handle_results, time.time(), should_log=True)
"""
run_now = False
with self._callback_lock:
# Always add fn to self._callbacks, even when we're about to
# execute it, to prevent races with functions like
# start_fetching_next_page that reset _final_result
self._callbacks.append((fn, args, kwargs))
if self._final_result is not _NOT_SET:
run_now = True
if run_now:
fn(self._final_result, *args, **kwargs)
return self
def add_errback(self, fn, *args, **kwargs):
"""
Like :meth:`.add_callback()`, but handles error cases.
An Exception instance will be passed as the first positional argument
to `fn`.
"""
run_now = False
with self._callback_lock:
# Always add fn to self._errbacks, even when we're about to execute
# it, to prevent races with functions like start_fetching_next_page
# that reset _final_exception
self._errbacks.append((fn, args, kwargs))
if self._final_exception:
run_now = True
if run_now:
fn(self._final_exception, *args, **kwargs)
return self
def add_callbacks(self, callback, errback,
callback_args=(), callback_kwargs=None,
errback_args=(), errback_kwargs=None):
"""
A convenient combination of :meth:`.add_callback()` and
:meth:`.add_errback()`.
Example usage::
>>> session = cluster.connect()
>>> query = "SELECT * FROM mycf"
>>> future = session.execute_async(query)
>>> def log_results(results, level='debug'):
... for row in results:
... log.log(level, "Result: %s", row)
>>> def log_error(exc, query):
... log.error("Query '%s' failed: %s", query, exc)
>>> future.add_callbacks(
... callback=log_results, callback_kwargs={'level': 'info'},
... errback=log_error, errback_args=(query,))
"""
self.add_callback(callback, *callback_args, **(callback_kwargs or {}))
self.add_errback(errback, *errback_args, **(errback_kwargs or {}))
def clear_callbacks(self):
with self._callback_lock:
self._callbacks = []
self._errbacks = []
def __str__(self):
result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result
return "" \
% (self.query, self._req_id, result, self._final_exception, self.coordinator_host)
__repr__ = __str__
class QueryExhausted(Exception):
"""
Raised when :meth:`.ResponseFuture.start_fetching_next_page()` is called and
there are no more pages. You can check :attr:`.ResponseFuture.has_more_pages`
before calling to avoid this.
.. versionadded:: 2.0.0
"""
pass
class ResultSet(object):
"""
An iterator over the rows from a query result. Also supplies basic equality
and indexing methods for backward-compatibility. These methods materialize
the entire result set (loading all pages), and should only be used if the
total result size is understood. Warnings are emitted when paged results
are materialized in this fashion.
You can treat this as a normal iterator over rows::
>>> from cassandra.query import SimpleStatement
>>> statement = SimpleStatement("SELECT * FROM users", fetch_size=10)
>>> for user_row in session.execute(statement):
... process_user(user_row)
Whenever there are no more rows in the current page, the next page will
be fetched transparently. However, note that it *is* possible for
an :class:`Exception` to be raised while fetching the next page, just
like you might see on a normal call to ``session.execute()``.
"""
def __init__(self, response_future, initial_response):
self.response_future = response_future
self.column_names = response_future._col_names
self.column_types = response_future._col_types
self._set_current_rows(initial_response)
self._page_iter = None
self._list_mode = False
@property
def has_more_pages(self):
"""
True if the last response indicated more pages; False otherwise
"""
return self.response_future.has_more_pages
@property
def current_rows(self):
"""
The list of current page rows. May be empty if the result was empty,
or this is the last page.
"""
return self._current_rows or []
+ def all(self):
+ """
+ Returns all the remaining rows as a list. This is basically
+ a convenient shortcut to `list(result_set)`.
+
+ This function is not recommended for queries that return a large number of elements.
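+
+ A short sketch (assuming an existing ``session``; the table name is illustrative)::
+ >>> rows = session.execute("SELECT * FROM users").all()
+ >>> len(rows)  # every remaining page has been fetched and materialized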
+ """
+ return list(self)
+
def one(self):
"""
Return a single row of the results or None if empty. This is basically
a shortcut to `result_set.current_rows[0]` and should only be used when
you know a query returns a single row. Consider using an iterator if the
ResultSet contains more than one row.
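Example (assuming an existing ``session``)::
>>> row = session.execute("SELECT release_version FROM system.local").one()
>>> row.release_version if row else None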
"""
row = None
if self._current_rows:
try:
row = self._current_rows[0]
except TypeError: # generator object is not subscriptable, PYTHON-1026
row = next(iter(self._current_rows))
return row
def __iter__(self):
if self._list_mode:
return iter(self._current_rows)
self._page_iter = iter(self._current_rows)
return self
def next(self):
try:
return next(self._page_iter)
except StopIteration:
if not self.response_future.has_more_pages:
if not self._list_mode:
self._current_rows = []
raise
- self.fetch_next_page()
- self._page_iter = iter(self._current_rows)
+ if not self.response_future._continuous_paging_session:
+ self.fetch_next_page()
+ self._page_iter = iter(self._current_rows)
return next(self._page_iter)
__next__ = next
def fetch_next_page(self):
"""
Manually, synchronously fetch the next page. Supplied for manually retrieving pages
and inspecting :attr:`~.current_rows`. It is not necessary to call this when iterating
through results; paging happens implicitly in iteration.
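A manual paging sketch (assuming an existing ``session``; ``users`` and ``process`` are illustrative)::
>>> from cassandra.query import SimpleStatement
>>> results = session.execute(SimpleStatement("SELECT * FROM users", fetch_size=100))
>>> while True:
...     for row in results.current_rows:
...         process(row)
...     if not results.has_more_pages:
...         break
...     results.fetch_next_page()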
"""
if self.response_future.has_more_pages:
self.response_future.start_fetching_next_page()
result = self.response_future.result()
self._current_rows = result._current_rows # ResultSet has already _set_current_rows to the appropriate form
else:
self._current_rows = []
def _set_current_rows(self, result):
if isinstance(result, Mapping):
self._current_rows = [result] if result else []
return
try:
iter(result) # can't check directly for generator types because cython generators are different
self._current_rows = result
except TypeError:
self._current_rows = [result] if result else []
def _fetch_all(self):
self._current_rows = list(self)
self._page_iter = None
def _enter_list_mode(self, operator):
if self._list_mode:
return
if self._page_iter:
raise RuntimeError("Cannot use %s when results have been iterated." % operator)
if self.response_future.has_more_pages:
log.warning("Using %s on paged results causes entire result set to be materialized.", operator)
self._fetch_all() # done regardless of paging status in case the row factory produces a generator
self._list_mode = True
def __eq__(self, other):
self._enter_list_mode("equality operator")
return self._current_rows == other
def __getitem__(self, i):
if i == 0:
warn("ResultSet indexing support will be removed in 4.0. Consider using "
"ResultSet.one() to get a single row.", DeprecationWarning)
self._enter_list_mode("index operator")
return self._current_rows[i]
def __nonzero__(self):
return bool(self._current_rows)
__bool__ = __nonzero__
def get_query_trace(self, max_wait_sec=None):
"""
Gets the last query trace from the associated future.
See :meth:`.ResponseFuture.get_query_trace` for details.
"""
return self.response_future.get_query_trace(max_wait_sec)
def get_all_query_traces(self, max_wait_sec_per=None):
"""
Gets all query traces from the associated future.
See :meth:`.ResponseFuture.get_all_query_traces` for details.
"""
return self.response_future.get_all_query_traces(max_wait_sec_per)
+ def cancel_continuous_paging(self):
+ try:
+ self.response_future._continuous_paging_session.cancel()
+ except AttributeError:
+ raise DriverException("Attempted to cancel paging with no active session. This is only for requests with ContinuousdPagingOptions.")
+
@property
def was_applied(self):
"""
For LWT results, returns whether the transaction was applied.
Result is indeterminate if called on a result that was not an LWT request or on
a :class:`.query.BatchStatement` containing LWT. In the latter case, either the whole batch
succeeds or fails.
Only valid when one of the internal row factories is in use.
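Example of a conditional update (assuming an existing ``session``; table and column names are illustrative)::
>>> rs = session.execute("UPDATE users SET name = 'b' WHERE id = 1 IF name = 'a'")
>>> if not rs.was_applied:
...     pass  # the condition did not match; handle the contended update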
"""
if self.response_future.row_factory not in (named_tuple_factory, dict_factory, tuple_factory):
raise RuntimeError("Cannot determine LWT result with row factory %s" % (self.response_future.row_factory,))
is_batch_statement = isinstance(self.response_future.query, BatchStatement)
if is_batch_statement and (not self.column_names or self.column_names[0] != "[applied]"):
raise RuntimeError("No LWT were present in the BatchStatement")
if not is_batch_statement and len(self.current_rows) != 1:
raise RuntimeError("LWT result should have exactly one row. This has %d." % (len(self.current_rows)))
row = self.current_rows[0]
if isinstance(row, tuple):
return row[0]
else:
return row['[applied]']
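# Editor's sketch (not part of the patch): checking an LWT result; the table, columns,
# and values are hypothetical, and the default named_tuple_factory satisfies the
# row-factory requirement above.
#
#   result = session.execute(
#       "INSERT INTO ks.users (id, name) VALUES (%s, %s) IF NOT EXISTS",
#       (user_id, name))
#   if not result.was_applied:
#       handle_conflict()                  # hypothetical application callback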
@property
def paging_state(self):
"""
Server paging state of the query. Can be `None` if the query was not paged.
The driver treats paging state as opaque, but it may contain primary key data, so applications may want to
avoid sending this to untrusted parties.
"""
return self.response_future._paging_state
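# Editor's sketch (not part of the patch): resuming a paged query from a saved paging
# state, assuming Session.execute() accepts a paging_state argument as in recent driver
# releases; the query and fetch_size are hypothetical.
#
#   statement = SimpleStatement("SELECT * FROM ks.tbl", fetch_size=500)
#   first = session.execute(statement)
#   token = first.paging_state             # opaque; may embed primary key data
#   resumed = session.execute(statement, paging_state=token)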
diff --git a/cassandra/connection.py b/cassandra/connection.py
index ba08ae2..0d8a50e 100644
--- a/cassandra/connection.py
+++ b/cassandra/connection.py
@@ -1,1418 +1,1765 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import # to enable import io from stdlib
from collections import defaultdict, deque
import errno
from functools import wraps, partial, total_ordering
from heapq import heappush, heappop
import io
import logging
import six
from six.moves import range
import socket
import struct
import sys
-from threading import Thread, Event, RLock
+from threading import Thread, Event, RLock, Condition
import time
+import ssl
+import weakref
-try:
- import ssl
-except ImportError:
- ssl = None # NOQA
if 'gevent.monkey' in sys.modules:
from gevent.queue import Queue, Empty
else:
from six.moves.queue import Queue, Empty # noqa
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut, ProtocolVersion
from cassandra.marshal import int32_pack
from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage,
StartupMessage, ErrorMessage, CredentialsMessage,
QueryMessage, ResultMessage, ProtocolHandler,
InvalidRequestException, SupportedMessage,
AuthResponseMessage, AuthChallengeMessage,
AuthSuccessMessage, ProtocolException,
- RegisterMessage)
+ RegisterMessage, ReviseRequestMessage)
+from cassandra.segment import SegmentCodec, CrcException
from cassandra.util import OrderedDict
log = logging.getLogger(__name__)
+segment_codec_no_compression = SegmentCodec()
+segment_codec_lz4 = None
+
# We use an ordered dictionary and specifically add lz4 before
# snappy so that lz4 will be preferred. Changing the order of this
# will change the compression preferences for the driver.
locally_supported_compressions = OrderedDict()
try:
import lz4
except ImportError:
pass
else:
# The compress and decompress functions we need were moved from the lz4 to
# the lz4.block namespace, so we try both here.
try:
from lz4 import block as lz4_block
except ImportError:
lz4_block = lz4
try:
lz4_block.compress
lz4_block.decompress
except AttributeError:
raise ImportError(
'lz4 not imported correctly. Imported object should have '
'.compress and .decompress attributes but does not. '
'Please file a bug report on JIRA. (Imported object was '
'{lz4_block})'.format(lz4_block=repr(lz4_block))
)
# Cassandra writes the uncompressed message length in big endian order,
# but the lz4 lib requires little endian order, so we wrap these
# functions to handle that
def lz4_compress(byts):
# write length in big-endian instead of little-endian
return int32_pack(len(byts)) + lz4_block.compress(byts)[4:]
def lz4_decompress(byts):
# flip from big-endian to little-endian
return lz4_block.decompress(byts[3::-1] + byts[4:])
locally_supported_compressions['lz4'] = (lz4_compress, lz4_decompress)
+ segment_codec_lz4 = SegmentCodec(lz4_compress, lz4_decompress)
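# Editor's note, a worked example of the byte-order flip above: for an uncompressed
# length of 256 the server sends the big-endian prefix b'\x00\x00\x01\x00';
# byts[3::-1] reverses those four bytes into the little-endian b'\x00\x01\x00\x00'
# expected by lz4.block.decompress, while byts[4:] is the compressed payload unchanged.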
try:
import snappy
except ImportError:
pass
else:
# work around apparently buggy snappy decompress
def decompress(byts):
if byts == '\x00':
return ''
return snappy.decompress(byts)
locally_supported_compressions['snappy'] = (snappy.compress, decompress)
DRIVER_NAME, DRIVER_VERSION = 'DataStax Python Driver', sys.modules['cassandra'].__version__
PROTOCOL_VERSION_MASK = 0x7f
HEADER_DIRECTION_FROM_CLIENT = 0x00
HEADER_DIRECTION_TO_CLIENT = 0x80
HEADER_DIRECTION_MASK = 0x80
frame_header_v1_v2 = struct.Struct('>BbBi')
frame_header_v3 = struct.Struct('>BhBi')
class EndPoint(object):
"""
Represents the information to connect to a cassandra node.
"""
@property
def address(self):
"""
The IP address of the node. This is the RPC address the driver uses when connecting to the node
"""
raise NotImplementedError()
@property
def port(self):
"""
The port of the node.
"""
raise NotImplementedError()
@property
def ssl_options(self):
"""
SSL options specific to this endpoint.
"""
return None
@property
def socket_family(self):
"""
The socket family of the endpoint.
"""
return socket.AF_UNSPEC
def resolve(self):
"""
Resolve the endpoint to an address/port. This is called
only on socket connection.
"""
raise NotImplementedError()
class EndPointFactory(object):
cluster = None
def configure(self, cluster):
"""
This is called by the cluster during its initialization.
"""
self.cluster = cluster
return self
def create(self, row):
"""
Create an EndPoint from a system.peers row.
"""
raise NotImplementedError()
@total_ordering
class DefaultEndPoint(EndPoint):
"""
Default EndPoint implementation, basically just an address and port.
"""
def __init__(self, address, port=9042):
self._address = address
self._port = port
@property
def address(self):
return self._address
@property
def port(self):
return self._port
def resolve(self):
return self._address, self._port
def __eq__(self, other):
return isinstance(other, DefaultEndPoint) and \
self.address == other.address and self.port == other.port
def __hash__(self):
return hash((self.address, self.port))
def __lt__(self, other):
return (self.address, self.port) < (other.address, other.port)
def __str__(self):
return str("%s:%d" % (self.address, self.port))
def __repr__(self):
return "<%s: %s:%d>" % (self.__class__.__name__, self.address, self.port)
class DefaultEndPointFactory(EndPointFactory):
port = None
"""
- If set, force all endpoints to use this port.
+ If no port is discovered in the row, this is the default port
+ used for endpoint creation.
"""
def __init__(self, port=None):
self.port = port
def create(self, row):
- addr = None
- if "rpc_address" in row:
- addr = row.get("rpc_address")
- if "native_transport_address" in row:
- addr = row.get("native_transport_address")
- if not addr or addr in ["0.0.0.0", "::"]:
- addr = row.get("peer")
+ # TODO next major... move this class so we don't need this kind of hack
+ from cassandra.metadata import _NodeInfo
+ addr = _NodeInfo.get_broadcast_rpc_address(row)
+ port = _NodeInfo.get_broadcast_rpc_port(row)
+ if port is None:
+ port = self.port if self.port else 9042
# create the endpoint with the translated address
+ # TODO next major, create a TranslatedEndPoint type
return DefaultEndPoint(
self.cluster.address_translator.translate(addr),
- self.port if self.port is not None else 9042)
+ port)
@total_ordering
class SniEndPoint(EndPoint):
"""SNI Proxy EndPoint implementation."""
def __init__(self, proxy_address, server_name, port=9042):
self._proxy_address = proxy_address
self._index = 0
self._resolved_address = None # resolved address
self._port = port
self._server_name = server_name
self._ssl_options = {'server_hostname': server_name}
@property
def address(self):
return self._proxy_address
@property
def port(self):
return self._port
@property
def ssl_options(self):
return self._ssl_options
def resolve(self):
try:
resolved_addresses = socket.getaddrinfo(self._proxy_address, self._port,
socket.AF_UNSPEC, socket.SOCK_STREAM)
except socket.gaierror:
log.debug('Could not resolve sni proxy hostname "%s" '
'with port %d' % (self._proxy_address, self._port))
raise
# round-robin pick
self._resolved_address = sorted(addr[4][0] for addr in resolved_addresses)[self._index % len(resolved_addresses)]
self._index += 1
return self._resolved_address, self._port
def __eq__(self, other):
return (isinstance(other, SniEndPoint) and
self.address == other.address and self.port == other.port and
self._server_name == other._server_name)
def __hash__(self):
return hash((self.address, self.port, self._server_name))
def __lt__(self, other):
return ((self.address, self.port, self._server_name) <
(other.address, other.port, other._server_name))
def __str__(self):
return str("%s:%d:%s" % (self.address, self.port, self._server_name))
def __repr__(self):
return "<%s: %s:%d:%s>" % (self.__class__.__name__,
self.address, self.port, self._server_name)
class SniEndPointFactory(EndPointFactory):
def __init__(self, proxy_address, port):
self._proxy_address = proxy_address
self._port = port
def create(self, row):
host_id = row.get("host_id")
if host_id is None:
raise ValueError("No host_id to create the SniEndPoint")
return SniEndPoint(self._proxy_address, str(host_id), self._port)
def create_from_sni(self, sni):
return SniEndPoint(self._proxy_address, sni, self._port)
@total_ordering
class UnixSocketEndPoint(EndPoint):
"""
Unix Socket EndPoint implementation.
"""
def __init__(self, unix_socket_path):
self._unix_socket_path = unix_socket_path
@property
def address(self):
return self._unix_socket_path
@property
def port(self):
return None
@property
def socket_family(self):
return socket.AF_UNIX
def resolve(self):
return self.address, None
def __eq__(self, other):
return (isinstance(other, UnixSocketEndPoint) and
self._unix_socket_path == other._unix_socket_path)
def __hash__(self):
return hash(self._unix_socket_path)
def __lt__(self, other):
return self._unix_socket_path < other._unix_socket_path
def __str__(self):
return str("%s" % (self._unix_socket_path,))
def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self._unix_socket_path)
class _Frame(object):
def __init__(self, version, flags, stream, opcode, body_offset, end_pos):
self.version = version
self.flags = flags
self.stream = stream
self.opcode = opcode
self.body_offset = body_offset
self.end_pos = end_pos
def __eq__(self, other): # facilitates testing
if isinstance(other, _Frame):
return (self.version == other.version and
self.flags == other.flags and
self.stream == other.stream and
self.opcode == other.opcode and
self.body_offset == other.body_offset and
self.end_pos == other.end_pos)
return NotImplemented
def __str__(self):
return "ver({0}); flags({1:04b}); stream({2}); op({3}); offset({4}); len({5})".format(self.version, self.flags, self.stream, self.opcode, self.body_offset, self.end_pos - self.body_offset)
NONBLOCKING = (errno.EAGAIN, errno.EWOULDBLOCK)
class ConnectionException(Exception):
"""
An unrecoverable error was hit when attempting to use a connection,
or the connection was already closed or defunct.
"""
def __init__(self, message, endpoint=None):
Exception.__init__(self, message)
self.endpoint = endpoint
@property
def host(self):
return self.endpoint.address
class ConnectionShutdown(ConnectionException):
"""
Raised when a connection has been marked as defunct or has been closed.
"""
pass
class ProtocolVersionUnsupported(ConnectionException):
"""
Server rejected startup message due to unsupported protocol version
"""
def __init__(self, endpoint, startup_version):
msg = "Unsupported protocol version on %s: %d" % (endpoint, startup_version)
super(ProtocolVersionUnsupported, self).__init__(msg, endpoint)
self.startup_version = startup_version
class ConnectionBusy(Exception):
"""
An attempt was made to send a message through a :class:`.Connection` that
was already at the max number of in-flight operations.
"""
pass
class ProtocolError(Exception):
"""
Communication did not match the protocol that this driver expects.
"""
pass
+class CrcMismatchException(ConnectionException):
+ pass
+
+
+class ContinuousPagingState(object):
+ """
+ A class for specifying continuous paging state, only supported starting with DSE_V2.
+ """
+
+ num_pages_requested = None
+ """
+ How many pages we have already requested
+ """
+
+ num_pages_received = None
+ """
+ How many pages we have already received
+ """
+
+ max_queue_size = None
+ """
+ The max queue size chosen by the user via the options
+ """
+
+ def __init__(self, max_queue_size):
+ self.num_pages_requested = max_queue_size # the initial query requests max_queue_size
+ self.num_pages_received = 0
+ self.max_queue_size = max_queue_size
+
+
+class ContinuousPagingSession(object):
+ def __init__(self, stream_id, decoder, row_factory, connection, state):
+ self.stream_id = stream_id
+ self.decoder = decoder
+ self.row_factory = row_factory
+ self.connection = connection
+ self._condition = Condition()
+ self._stop = False
+ self._page_queue = deque()
+ self._state = state
+ self.released = False
+
+ def on_message(self, result):
+ if isinstance(result, ResultMessage):
+ self.on_page(result)
+ elif isinstance(result, ErrorMessage):
+ self.on_error(result)
+
+ def on_page(self, result):
+ with self._condition:
+ if self._state:
+ self._state.num_pages_received += 1
+ self._page_queue.appendleft((result.column_names, result.parsed_rows, None))
+ self._stop |= result.continuous_paging_last
+ self._condition.notify()
+
+ if result.continuous_paging_last:
+ self.released = True
+
+ def on_error(self, error):
+ if isinstance(error, ErrorMessage):
+ error = error.to_exception()
+
+ log.debug("Got error %s for session %s", error, self.stream_id)
+
+ with self._condition:
+ self._page_queue.appendleft((None, None, error))
+ self._stop = True
+ self._condition.notify()
+
+ self.released = True
+
+ def results(self):
+ try:
+ self._condition.acquire()
+ while True:
+ while not self._page_queue and not self._stop:
+ self._condition.wait(timeout=5)
+ while self._page_queue:
+ names, rows, err = self._page_queue.pop()
+ if err:
+ raise err
+ self.maybe_request_more()
+ self._condition.release()
+ for row in self.row_factory(names, rows):
+ yield row
+ self._condition.acquire()
+ if self._stop:
+ break
+ finally:
+ try:
+ self._condition.release()
+ except RuntimeError:
+ # This exception happens if the CP results are not entirely consumed
+ # and the session is terminated by the runtime. See PYTHON-1054
+ pass
+
+ def maybe_request_more(self):
+ if not self._state:
+ return
+
+ max_queue_size = self._state.max_queue_size
+ num_in_flight = self._state.num_pages_requested - self._state.num_pages_received
+ space_in_queue = max_queue_size - len(self._page_queue) - num_in_flight
+ log.debug("Session %s from %s, space in CP queue: %s, requested: %s, received: %s, num_in_flight: %s",
+ self.stream_id, self.connection.host, space_in_queue, self._state.num_pages_requested,
+ self._state.num_pages_received, num_in_flight)
+
+ if space_in_queue >= max_queue_size / 2:
+ self.update_next_pages(space_in_queue)
+
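# Editor's note, a worked example of the flow control above (numbers are hypothetical):
# with max_queue_size=8, num_pages_requested=8 and num_pages_received=8, and two pages
# still unconsumed in the queue, num_in_flight = 8 - 8 = 0 and
# space_in_queue = 8 - 2 - 0 = 6. Since 6 >= 8 / 2, update_next_pages(6) asks the server
# for six more pages; with only three free slots, no revise request would be sent yet.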
+ def update_next_pages(self, num_next_pages):
+ try:
+ self._state.num_pages_requested += num_next_pages
+ log.debug("Updating backpressure for session %s from %s", self.stream_id, self.connection.host)
+ with self.connection.lock:
+ self.connection.send_msg(ReviseRequestMessage(ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE,
+ self.stream_id,
+ next_pages=num_next_pages),
+ self.connection.get_request_id(),
+ self._on_backpressure_response)
+ except ConnectionShutdown as ex:
+ log.debug("Failed to update backpressure for session %s from %s, connection is shutdown",
+ self.stream_id, self.connection.host)
+ self.on_error(ex)
+
+ def _on_backpressure_response(self, response):
+ if isinstance(response, ResultMessage):
+ log.debug("Paging session %s backpressure updated.", self.stream_id)
+ else:
+ log.error("Failed updating backpressure for session %s from %s: %s", self.stream_id, self.connection.host,
+ response.to_exception() if hasattr(response, 'to_exception') else response)
+ self.on_error(response)
+
+ def cancel(self):
+ try:
+ log.debug("Canceling paging session %s from %s", self.stream_id, self.connection.host)
+ with self.connection.lock:
+ self.connection.send_msg(ReviseRequestMessage(ReviseRequestMessage.RevisionType.PAGING_CANCEL,
+ self.stream_id),
+ self.connection.get_request_id(),
+ self._on_cancel_response)
+ except ConnectionShutdown:
+ log.debug("Failed to cancel session %s from %s, connection is shutdown",
+ self.stream_id, self.connection.host)
+
+ with self._condition:
+ self._stop = True
+ self._condition.notify()
+
+ def _on_cancel_response(self, response):
+ if isinstance(response, ResultMessage):
+ log.debug("Paging session %s canceled.", self.stream_id)
+ else:
+ log.error("Failed canceling streaming session %s from %s: %s", self.stream_id, self.connection.host,
+ response.to_exception() if hasattr(response, 'to_exception') else response)
+ self.released = True
+
+
def defunct_on_error(f):
@wraps(f)
def wrapper(self, *args, **kwargs):
try:
return f(self, *args, **kwargs)
except Exception as exc:
self.defunct(exc)
return wrapper
DEFAULT_CQL_VERSION = '3.0.0'
if six.PY3:
def int_from_buf_item(i):
return i
else:
int_from_buf_item = ord
+class _ConnectionIOBuffer(object):
+ """
+ Abstraction class to ease the use of the different connection io buffers. With
+ protocol V5 and checksumming, the data is read, validated and copied to another
+ cql frame buffer.
+ """
+ _io_buffer = None
+ _cql_frame_buffer = None
+ _connection = None
+ _segment_consumed = False
+
+ def __init__(self, connection):
+ self._io_buffer = io.BytesIO()
+ self._connection = weakref.proxy(connection)
+
+ @property
+ def io_buffer(self):
+ return self._io_buffer
+
+ @property
+ def cql_frame_buffer(self):
+ return self._cql_frame_buffer if self.is_checksumming_enabled else \
+ self._io_buffer
+
+ def set_checksumming_buffer(self):
+ self.reset_io_buffer()
+ self._cql_frame_buffer = io.BytesIO()
+
+ @property
+ def is_checksumming_enabled(self):
+ return self._connection._is_checksumming_enabled
+
+ @property
+ def has_consumed_segment(self):
+ return self._segment_consumed
+
+ def readable_io_bytes(self):
+ return self.io_buffer.tell()
+
+ def readable_cql_frame_bytes(self):
+ return self.cql_frame_buffer.tell()
+
+ def reset_io_buffer(self):
+ self._io_buffer = io.BytesIO(self._io_buffer.read())
+ self._io_buffer.seek(0, 2) # 2 == SEEK_END
+
+ def reset_cql_frame_buffer(self):
+ if self.is_checksumming_enabled:
+ self._cql_frame_buffer = io.BytesIO(self._cql_frame_buffer.read())
+ self._cql_frame_buffer.seek(0, 2) # 2 == SEEK_END
+ else:
+ self.reset_io_buffer()
+
+
class Connection(object):
CALLBACK_ERR_THREAD_THRESHOLD = 100
in_buffer_size = 4096
out_buffer_size = 4096
cql_version = None
no_compact = False
protocol_version = ProtocolVersion.MAX_SUPPORTED
keyspace = None
compression = True
+ _compression_type = None
compressor = None
decompressor = None
endpoint = None
ssl_options = None
ssl_context = None
last_error = None
# The current number of operations that are in flight. More precisely,
# the number of request IDs that are currently in use.
in_flight = 0
# Max concurrent requests allowed per connection. This is set optimistically high, allowing
# all request ids to be used in protocol version 3+. Normally concurrency would be controlled
# at a higher level by the application or concurrent.execute_concurrent. This attribute
# is for lower-level integrations that want some upper bound without reimplementing.
max_in_flight = 2 ** 15
# A set of available request IDs. When using the v3 protocol or higher,
# this will not initially include all request IDs in order to save memory,
# but the set will grow if it is exhausted.
request_ids = None
# Tracks the highest used request ID in order to help with growing the
# request_ids set
highest_request_id = 0
is_defunct = False
is_closed = False
lock = None
user_type_map = None
msg_received = False
is_unsupported_proto_version = False
is_control_connection = False
signaled_error = False # used for flagging at the pool level
allow_beta_protocol_version = False
- _iobuf = None
_current_frame = None
_socket = None
_socket_impl = socket
_ssl_impl = ssl
_check_hostname = False
_product_type = None
+ _is_checksumming_enabled = False
+
+ @property
+ def _iobuf(self):
+ # backward compatibility, to avoid any change in the reactors
+ return self._io_buffer.io_buffer
+
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
ssl_options=None, sockopts=None, compression=True,
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False,
ssl_context=None):
# TODO next major rename host to endpoint and remove port kwarg.
self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port)
self.authenticator = authenticator
self.ssl_options = ssl_options.copy() if ssl_options else None
self.ssl_context = ssl_context
self.sockopts = sockopts
self.compression = compression
self.cql_version = cql_version
self.protocol_version = protocol_version
self.is_control_connection = is_control_connection
self.user_type_map = user_type_map
self.connect_timeout = connect_timeout
self.allow_beta_protocol_version = allow_beta_protocol_version
self.no_compact = no_compact
self._push_watchers = defaultdict(set)
self._requests = {}
- self._iobuf = io.BytesIO()
+ self._io_buffer = _ConnectionIOBuffer(self)
+ self._continuous_paging_sessions = {}
+ self._socket_writable = True
if ssl_options:
self._check_hostname = bool(self.ssl_options.pop('check_hostname', False))
if self._check_hostname:
if not getattr(ssl, 'match_hostname', None):
raise RuntimeError("ssl_options specify 'check_hostname', but ssl.match_hostname is not provided. "
"Patch or upgrade Python to use this option.")
self.ssl_options.update(self.endpoint.ssl_options or {})
elif self.endpoint.ssl_options:
self.ssl_options = self.endpoint.ssl_options
if protocol_version >= 3:
self.max_request_id = min(self.max_in_flight - 1, (2 ** 15) - 1)
# Don't fill the deque with 2**15 items right away. Start with some and add
# more if needed.
initial_size = min(300, self.max_in_flight)
self.request_ids = deque(range(initial_size))
self.highest_request_id = initial_size - 1
else:
self.max_request_id = min(self.max_in_flight, (2 ** 7) - 1)
self.request_ids = deque(range(self.max_request_id + 1))
self.highest_request_id = self.max_request_id
self.lock = RLock()
self.connected_event = Event()
@property
def host(self):
return self.endpoint.address
@property
def port(self):
return self.endpoint.port
@classmethod
def initialize_reactor(cls):
"""
Called once by Cluster.connect(). This should be used by implementations
to set up any resources that will be shared across connections.
"""
pass
@classmethod
def handle_fork(cls):
"""
Called after a forking. This should cleanup any remaining reactor state
from the parent process.
"""
pass
@classmethod
def create_timer(cls, timeout, callback):
raise NotImplementedError()
@classmethod
def factory(cls, endpoint, timeout, *args, **kwargs):
"""
A factory function which returns connections which have
succeeded in connecting and are ready for service (or
raises an exception otherwise).
"""
start = time.time()
kwargs['connect_timeout'] = timeout
conn = cls(endpoint, *args, **kwargs)
elapsed = time.time() - start
conn.connected_event.wait(timeout - elapsed)
if conn.last_error:
if conn.is_unsupported_proto_version:
raise ProtocolVersionUnsupported(endpoint, conn.protocol_version)
raise conn.last_error
elif not conn.connected_event.is_set():
conn.close()
raise OperationTimedOut("Timed out creating connection (%s seconds)" % timeout)
else:
return conn
+ def _wrap_socket_from_context(self):
+ ssl_options = self.ssl_options or {}
+ # PYTHON-1186: set the server_hostname only if the SSLContext has
+ # check_hostname enabled and it is not already provided by the EndPoint ssl options
+ if (self.ssl_context.check_hostname and
+ 'server_hostname' not in ssl_options):
+ ssl_options = ssl_options.copy()
+ ssl_options['server_hostname'] = self.endpoint.address
+ self._socket = self.ssl_context.wrap_socket(self._socket, **ssl_options)
+
+ def _initiate_connection(self, sockaddr):
+ self._socket.connect(sockaddr)
+
+ def _match_hostname(self):
+ ssl.match_hostname(self._socket.getpeercert(), self.endpoint.address)
+
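# Editor's sketch (not part of the patch): with an SSLContext whose check_hostname is
# enabled, the driver now supplies server_hostname from the endpoint address unless the
# EndPoint's ssl_options already provide one (PYTHON-1186). The CA path is hypothetical.
#
#   import ssl
#   from cassandra.cluster import Cluster
#   ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
#   ctx.verify_mode = ssl.CERT_REQUIRED
#   ctx.check_hostname = True
#   ctx.load_verify_locations('/path/to/rootca.pem')
#   cluster = Cluster(['10.0.0.1'], ssl_context=ctx)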
def _get_socket_addresses(self):
address, port = self.endpoint.resolve()
if hasattr(socket, 'AF_UNIX') and self.endpoint.socket_family == socket.AF_UNIX:
return [(socket.AF_UNIX, socket.SOCK_STREAM, 0, None, address)]
addresses = socket.getaddrinfo(address, port, self.endpoint.socket_family, socket.SOCK_STREAM)
if not addresses:
raise ConnectionException("getaddrinfo returned empty list for %s" % (self.endpoint,))
return addresses
def _connect_socket(self):
sockerr = None
addresses = self._get_socket_addresses()
for (af, socktype, proto, _, sockaddr) in addresses:
try:
self._socket = self._socket_impl.socket(af, socktype, proto)
if self.ssl_context:
- self._socket = self.ssl_context.wrap_socket(self._socket,
- **(self.ssl_options or {}))
+ self._wrap_socket_from_context()
elif self.ssl_options:
if not self._ssl_impl:
raise RuntimeError("This version of Python was not compiled with SSL support")
self._socket = self._ssl_impl.wrap_socket(self._socket, **self.ssl_options)
self._socket.settimeout(self.connect_timeout)
- self._socket.connect(sockaddr)
+ self._initiate_connection(sockaddr)
self._socket.settimeout(None)
if self._check_hostname:
- ssl.match_hostname(self._socket.getpeercert(), self.endpoint.address)
+ self._match_hostname()
sockerr = None
break
except socket.error as err:
if self._socket:
self._socket.close()
self._socket = None
sockerr = err
if sockerr:
raise socket.error(sockerr.errno, "Tried connecting to %s. Last error: %s" %
([a[4] for a in addresses], sockerr.strerror or sockerr))
if self.sockopts:
for args in self.sockopts:
self._socket.setsockopt(*args)
+ def _enable_compression(self):
+ if self._compressor:
+ self.compressor = self._compressor
+
+ def _enable_checksumming(self):
+ self._io_buffer.set_checksumming_buffer()
+ self._is_checksumming_enabled = True
+ self._segment_codec = segment_codec_lz4 if self.compressor else segment_codec_no_compression
+ log.debug("Enabling protocol checksumming on connection (%s).", id(self))
+
def close(self):
raise NotImplementedError()
def defunct(self, exc):
with self.lock:
if self.is_defunct or self.is_closed:
return
self.is_defunct = True
exc_info = sys.exc_info()
# if we are not handling an exception, just use the passed exception, and don't try to format exc_info with the message
if any(exc_info):
log.debug("Defuncting connection (%s) to %s:",
id(self), self.endpoint, exc_info=exc_info)
else:
log.debug("Defuncting connection (%s) to %s: %s",
id(self), self.endpoint, exc)
self.last_error = exc
self.close()
+ self.error_all_cp_sessions(exc)
self.error_all_requests(exc)
self.connected_event.set()
return exc
+ def error_all_cp_sessions(self, exc):
+ stream_ids = list(self._continuous_paging_sessions.keys())
+ for stream_id in stream_ids:
+ self._continuous_paging_sessions[stream_id].on_error(exc)
+
def error_all_requests(self, exc):
with self.lock:
requests = self._requests
self._requests = {}
if not requests:
return
new_exc = ConnectionShutdown(str(exc))
+
def try_callback(cb):
try:
cb(new_exc)
except Exception:
log.warning("Ignoring unhandled exception while erroring requests for a "
"failed connection (%s) to host %s:",
id(self), self.endpoint, exc_info=True)
# run first callback from this thread to ensure pool state before leaving
cb, _, _ = requests.popitem()[1]
try_callback(cb)
if not requests:
return
# additional requests are optionally errored from a separate thread
# The default callback and retry logic is fairly expensive -- we don't
# want to tie up the event thread when there are many requests
def err_all_callbacks():
for cb, _, _ in requests.values():
try_callback(cb)
if len(requests) < Connection.CALLBACK_ERR_THREAD_THRESHOLD:
err_all_callbacks()
else:
# daemon thread here because we want to stay decoupled from the cluster TPE
# TODO: would it make sense to just have a driver-global TPE?
t = Thread(target=err_all_callbacks)
t.daemon = True
t.start()
def get_request_id(self):
"""
This must be called while self.lock is held.
"""
try:
return self.request_ids.popleft()
except IndexError:
new_request_id = self.highest_request_id + 1
# in_flight checks should guarantee this
assert new_request_id <= self.max_request_id
self.highest_request_id = new_request_id
return self.highest_request_id
def handle_pushed(self, response):
log.debug("Message pushed from server: %r", response)
for cb in self._push_watchers.get(response.event_type, []):
try:
cb(response.event_args)
except Exception:
log.exception("Pushed event handler errored, ignoring:")
def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=None):
if self.is_defunct:
raise ConnectionShutdown("Connection to %s is defunct" % self.endpoint)
elif self.is_closed:
raise ConnectionShutdown("Connection to %s is closed" % self.endpoint)
+ elif not self._socket_writable:
+ raise ConnectionBusy("Connection %s is overloaded" % self.endpoint)
# queue the decoder function with the request
# this allows us to inject custom functions per request to encode, decode messages
self._requests[request_id] = (cb, decoder, result_metadata)
- msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor, allow_beta_protocol_version=self.allow_beta_protocol_version)
+ msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor,
+ allow_beta_protocol_version=self.allow_beta_protocol_version)
+
+ if self._is_checksumming_enabled:
+ buffer = io.BytesIO()
+ self._segment_codec.encode(buffer, msg)
+ msg = buffer.getvalue()
+
self.push(msg)
return len(msg)
def wait_for_response(self, msg, timeout=None, **kwargs):
return self.wait_for_responses(msg, timeout=timeout, **kwargs)[0]
def wait_for_responses(self, *msgs, **kwargs):
"""
Returns a list of (success, response) tuples. If success
is False, response will be an Exception. Otherwise, response
will be the normal query response.
If fail_on_error was left as True and one of the requests
failed, the corresponding Exception will be raised.
"""
if self.is_closed or self.is_defunct:
raise ConnectionShutdown("Connection %s is already closed" % (self, ))
timeout = kwargs.get('timeout')
fail_on_error = kwargs.get('fail_on_error', True)
waiter = ResponseWaiter(self, len(msgs), fail_on_error)
# busy wait for sufficient space on the connection
messages_sent = 0
while True:
needed = len(msgs) - messages_sent
with self.lock:
available = min(needed, self.max_request_id - self.in_flight + 1)
request_ids = [self.get_request_id() for _ in range(available)]
self.in_flight += available
for i, request_id in enumerate(request_ids):
self.send_msg(msgs[messages_sent + i],
request_id,
partial(waiter.got_response, index=messages_sent + i))
messages_sent += available
if messages_sent == len(msgs):
break
else:
if timeout is not None:
timeout -= 0.01
if timeout <= 0.0:
raise OperationTimedOut()
time.sleep(0.01)
try:
return waiter.deliver(timeout)
except OperationTimedOut:
raise
except Exception as exc:
self.defunct(exc)
raise
def register_watcher(self, event_type, callback, register_timeout=None):
"""
Register a callback for a given event type.
"""
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=[event_type]),
timeout=register_timeout)
def register_watchers(self, type_callback_dict, register_timeout=None):
"""
Register multiple callback/event type pairs, expressed as a dict.
"""
for event_type, callback in type_callback_dict.items():
self._push_watchers[event_type].add(callback)
self.wait_for_response(
RegisterMessage(event_list=type_callback_dict.keys()),
timeout=register_timeout)
def control_conn_disposed(self):
self.is_control_connection = False
self._push_watchers = {}
@defunct_on_error
def _read_frame_header(self):
- buf = self._iobuf.getvalue()
+ buf = self._io_buffer.cql_frame_buffer.getvalue()
pos = len(buf)
if pos:
version = int_from_buf_item(buf[0]) & PROTOCOL_VERSION_MASK
- if version > ProtocolVersion.MAX_SUPPORTED:
+ if version not in ProtocolVersion.SUPPORTED_VERSIONS:
raise ProtocolError("This version of the driver does not support protocol version %d" % version)
frame_header = frame_header_v3 if version >= 3 else frame_header_v1_v2
# this frame header struct is everything after the version byte
header_size = frame_header.size + 1
if pos >= header_size:
flags, stream, op, body_len = frame_header.unpack_from(buf, 1)
if body_len < 0:
raise ProtocolError("Received negative body length: %r" % body_len)
self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size)
return pos
- def _reset_frame(self):
- self._iobuf = io.BytesIO(self._iobuf.read())
- self._iobuf.seek(0, 2) # io.SEEK_END == 2 (constant not present in 2.6)
- self._current_frame = None
+ @defunct_on_error
+ def _process_segment_buffer(self):
+ readable_bytes = self._io_buffer.readable_io_bytes()
+ if readable_bytes >= self._segment_codec.header_length_with_crc:
+ try:
+ self._io_buffer.io_buffer.seek(0)
+ segment_header = self._segment_codec.decode_header(self._io_buffer.io_buffer)
+
+ if readable_bytes >= segment_header.segment_length:
+ segment = self._segment_codec.decode(self._iobuf, segment_header)
+ self._io_buffer._segment_consumed = True
+ self._io_buffer.cql_frame_buffer.write(segment.payload)
+ else:
+ # not enough data to read the full segment; reset the buffer position to the
+ # beginning so we don't lose the header we just read
+ self._io_buffer._segment_consumed = False
+ self._io_buffer.io_buffer.seek(0)
+ except CrcException as exc:
+ # re-raise an exception that inherits from ConnectionException
+ raise CrcMismatchException(str(exc), self.endpoint)
+ else:
+ self._io_buffer._segment_consumed = False
def process_io_buffer(self):
while True:
+ if self._is_checksumming_enabled and self._io_buffer.readable_io_bytes():
+ self._process_segment_buffer()
+ self._io_buffer.reset_io_buffer()
+
+ if self._is_checksumming_enabled and not self._io_buffer.has_consumed_segment:
+ # We couldn't read an entire segment from the io buffer, so return
+ # control to allow more bytes to be read off the wire
+ return
+
if not self._current_frame:
pos = self._read_frame_header()
else:
- pos = self._iobuf.tell()
+ pos = self._io_buffer.readable_cql_frame_bytes()
if not self._current_frame or pos < self._current_frame.end_pos:
+ if self._is_checksumming_enabled and self._io_buffer.readable_io_bytes():
+ # We have a multi-segment message and need to read more
+ # data to complete the current cql frame
+ continue
+
# we don't have a complete header yet or we
# already saw a header, but we don't have a
# complete message yet
return
else:
frame = self._current_frame
- self._iobuf.seek(frame.body_offset)
- msg = self._iobuf.read(frame.end_pos - frame.body_offset)
+ self._io_buffer.cql_frame_buffer.seek(frame.body_offset)
+ msg = self._io_buffer.cql_frame_buffer.read(frame.end_pos - frame.body_offset)
self.process_msg(frame, msg)
- self._reset_frame()
+ self._io_buffer.reset_cql_frame_buffer()
+ self._current_frame = None
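# Editor's note, a sketch of the v5 checksumming read path implemented above: raw socket
# bytes accumulate in io_buffer; _process_segment_buffer() decodes one segment header,
# verifies the CRCs (raising CrcMismatchException on corruption) and appends the decoded
# payload to cql_frame_buffer; frame header/body parsing then runs against
# cql_frame_buffer exactly as it previously ran against the single _iobuf.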
@defunct_on_error
def process_msg(self, header, body):
self.msg_received = True
stream_id = header.stream
if stream_id < 0:
callback = None
decoder = ProtocolHandler.decode_message
result_metadata = None
else:
- try:
- callback, decoder, result_metadata = self._requests.pop(stream_id)
- # This can only happen if the stream_id was
- # removed due to an OperationTimedOut
- except KeyError:
- return
-
- with self.lock:
- self.request_ids.append(stream_id)
+ if stream_id in self._continuous_paging_sessions:
+ paging_session = self._continuous_paging_sessions[stream_id]
+ callback = paging_session.on_message
+ decoder = paging_session.decoder
+ result_metadata = None
+ else:
+ try:
+ callback, decoder, result_metadata = self._requests.pop(stream_id)
+ # This can only happen if the stream_id was
+ # removed due to an OperationTimedOut
+ except KeyError:
+ return
try:
response = decoder(header.version, self.user_type_map, stream_id,
header.flags, header.opcode, body, self.decompressor, result_metadata)
except Exception as exc:
log.exception("Error decoding response from Cassandra. "
"%s; buffer: %r", header, self._iobuf.getvalue())
if callback is not None:
callback(exc)
self.defunct(exc)
return
try:
if stream_id >= 0:
if isinstance(response, ProtocolException):
if 'unsupported protocol version' in response.message:
self.is_unsupported_proto_version = True
else:
log.error("Closing connection %s due to protocol error: %s", self, response.summary_msg())
self.defunct(response)
if callback is not None:
callback(response)
else:
self.handle_pushed(response)
except Exception:
log.exception("Callback handler errored, ignoring:")
+ # done after callback because the callback might signal this as a paging session
+ if stream_id >= 0:
+ if stream_id in self._continuous_paging_sessions:
+ if self._continuous_paging_sessions[stream_id].released:
+ self.remove_continuous_paging_session(stream_id)
+ else:
+ with self.lock:
+ self.request_ids.append(stream_id)
+
+ def new_continuous_paging_session(self, stream_id, decoder, row_factory, state):
+ session = ContinuousPagingSession(stream_id, decoder, row_factory, self, state)
+ self._continuous_paging_sessions[stream_id] = session
+ return session
+
+ def remove_continuous_paging_session(self, stream_id):
+ try:
+ self._continuous_paging_sessions.pop(stream_id)
+ with self.lock:
+ log.debug("Returning cp session stream id %s", stream_id)
+ self.request_ids.append(stream_id)
+ except KeyError:
+ pass
+
@defunct_on_error
def _send_options_message(self):
log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.endpoint)
self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response)
@defunct_on_error
def _handle_options_response(self, options_response):
if self.is_defunct:
return
if not isinstance(options_response, SupportedMessage):
if isinstance(options_response, ConnectionException):
raise options_response
else:
log.error("Did not get expected SupportedMessage response; "
"instead, got: %s", options_response)
raise ConnectionException("Did not get expected SupportedMessage "
"response; instead, got: %s"
% (options_response,))
log.debug("Received options response on new connection (%s) from %s",
id(self), self.endpoint)
supported_cql_versions = options_response.cql_versions
remote_supported_compressions = options_response.options['COMPRESSION']
self._product_type = options_response.options.get('PRODUCT_TYPE', [None])[0]
if self.cql_version:
if self.cql_version not in supported_cql_versions:
raise ProtocolError(
"cql_version %r is not supported by remote (w/ native "
"protocol). Supported versions: %r"
% (self.cql_version, supported_cql_versions))
else:
self.cql_version = supported_cql_versions[0]
self._compressor = None
compression_type = None
if self.compression:
overlap = (set(locally_supported_compressions.keys()) &
set(remote_supported_compressions))
if len(overlap) == 0:
log.debug("No available compression types supported on both ends."
" locally supported: %r. remotely supported: %r",
locally_supported_compressions.keys(),
remote_supported_compressions)
else:
compression_type = None
if isinstance(self.compression, six.string_types):
# the user picked a specific compression type ('snappy' or 'lz4')
if self.compression not in remote_supported_compressions:
raise ProtocolError(
"The requested compression type (%s) is not supported by the Cassandra server at %s"
% (self.compression, self.endpoint))
compression_type = self.compression
else:
# our locally supported compressions are ordered to prefer
# lz4, if available
for k in locally_supported_compressions.keys():
if k in overlap:
compression_type = k
break
- # set the decompressor here, but set the compressor only after
- # a successful Ready message
- self._compressor, self.decompressor = \
- locally_supported_compressions[compression_type]
+ # If snappy compression is selected with v5 checksumming, the connection
+ # will fail with an OperationTimedOut. Only lz4 is supported
+ if (compression_type == 'snappy' and
+ ProtocolVersion.has_checksumming_support(self.protocol_version)):
+ log.debug("Snappy compression is not supported with protocol version %s and "
+ "checksumming. Consider installing lz4. Disabling compression.", self.protocol_version)
+ compression_type = None
+ else:
+ # set the decompressor here, but set the compressor only after
+ # a successful Ready message
+ self._compression_type = compression_type
+ self._compressor, self.decompressor = \
+ locally_supported_compressions[compression_type]
self._send_startup_message(compression_type, no_compact=self.no_compact)
@defunct_on_error
def _send_startup_message(self, compression=None, no_compact=False):
log.debug("Sending StartupMessage on %s", self)
opts = {'DRIVER_NAME': DRIVER_NAME,
'DRIVER_VERSION': DRIVER_VERSION}
if compression:
opts['COMPRESSION'] = compression
if no_compact:
opts['NO_COMPACT'] = 'true'
sm = StartupMessage(cqlversion=self.cql_version, options=opts)
self.send_msg(sm, self.get_request_id(), cb=self._handle_startup_response)
log.debug("Sent StartupMessage on %s", self)
@defunct_on_error
def _handle_startup_response(self, startup_response, did_authenticate=False):
if self.is_defunct:
return
+
if isinstance(startup_response, ReadyMessage):
if self.authenticator:
log.warning("An authentication challenge was not sent, "
"this is suspicious because the driver expects "
"authentication (configured authenticator = %s)",
self.authenticator.__class__.__name__)
log.debug("Got ReadyMessage on new connection (%s) from %s", id(self), self.endpoint)
- if self._compressor:
- self.compressor = self._compressor
+ self._enable_compression()
+
+ if ProtocolVersion.has_checksumming_support(self.protocol_version):
+ self._enable_checksumming()
+
self.connected_event.set()
elif isinstance(startup_response, AuthenticateMessage):
log.debug("Got AuthenticateMessage on new connection (%s) from %s: %s",
id(self), self.endpoint, startup_response.authenticator)
if self.authenticator is None:
- raise AuthenticationFailed('Remote end requires authentication.')
+ log.error("Failed to authenticate to %s. If you are trying to connect to a DSE cluster, "
+ "consider using TransitionalModePlainTextAuthProvider "
+ "if DSE authentication is configured with transitional mode" % (self.host,))
+ raise AuthenticationFailed('Remote end requires authentication')
+
+ self._enable_compression()
+ if ProtocolVersion.has_checksumming_support(self.protocol_version):
+ self._enable_checksumming()
if isinstance(self.authenticator, dict):
log.debug("Sending credentials-based auth response on %s", self)
cm = CredentialsMessage(creds=self.authenticator)
callback = partial(self._handle_startup_response, did_authenticate=True)
self.send_msg(cm, self.get_request_id(), cb=callback)
else:
log.debug("Sending SASL-based auth response on %s", self)
self.authenticator.server_authenticator_class = startup_response.authenticator
initial_response = self.authenticator.initial_response()
initial_response = "" if initial_response is None else initial_response
- self.send_msg(AuthResponseMessage(initial_response), self.get_request_id(), self._handle_auth_response)
+ self.send_msg(AuthResponseMessage(initial_response), self.get_request_id(),
+ self._handle_auth_response)
elif isinstance(startup_response, ErrorMessage):
log.debug("Received ErrorMessage on new connection (%s) from %s: %s",
id(self), self.endpoint, startup_response.summary_msg())
if did_authenticate:
raise AuthenticationFailed(
"Failed to authenticate to %s: %s" %
(self.endpoint, startup_response.summary_msg()))
else:
raise ConnectionException(
"Failed to initialize new connection to %s: %s"
% (self.endpoint, startup_response.summary_msg()))
elif isinstance(startup_response, ConnectionShutdown):
log.debug("Connection to %s was closed during the startup handshake", (self.endpoint))
raise startup_response
else:
msg = "Unexpected response during Connection setup: %r"
log.error(msg, startup_response)
raise ProtocolError(msg % (startup_response,))
@defunct_on_error
def _handle_auth_response(self, auth_response):
if self.is_defunct:
return
if isinstance(auth_response, AuthSuccessMessage):
log.debug("Connection %s successfully authenticated", self)
self.authenticator.on_authentication_success(auth_response.token)
if self._compressor:
self.compressor = self._compressor
self.connected_event.set()
elif isinstance(auth_response, AuthChallengeMessage):
response = self.authenticator.evaluate_challenge(auth_response.challenge)
msg = AuthResponseMessage("" if response is None else response)
log.debug("Responding to auth challenge on %s", self)
self.send_msg(msg, self.get_request_id(), self._handle_auth_response)
elif isinstance(auth_response, ErrorMessage):
log.debug("Received ErrorMessage on new connection (%s) from %s: %s",
id(self), self.endpoint, auth_response.summary_msg())
raise AuthenticationFailed(
"Failed to authenticate to %s: %s" %
(self.endpoint, auth_response.summary_msg()))
elif isinstance(auth_response, ConnectionShutdown):
log.debug("Connection to %s was closed during the authentication process", self.endpoint)
raise auth_response
else:
msg = "Unexpected response during Connection authentication to %s: %r"
log.error(msg, self.endpoint, auth_response)
raise ProtocolError(msg % (self.endpoint, auth_response))
def set_keyspace_blocking(self, keyspace):
if not keyspace or keyspace == self.keyspace:
return
query = QueryMessage(query='USE "%s"' % (keyspace,),
consistency_level=ConsistencyLevel.ONE)
try:
result = self.wait_for_response(query)
except InvalidRequestException as ire:
# the keyspace probably doesn't exist
raise ire.to_exception()
except Exception as exc:
conn_exc = ConnectionException(
"Problem while setting keyspace: %r" % (exc,), self.endpoint)
self.defunct(conn_exc)
raise conn_exc
if isinstance(result, ResultMessage):
self.keyspace = keyspace
else:
conn_exc = ConnectionException(
"Problem while setting keyspace: %r" % (result,), self.endpoint)
self.defunct(conn_exc)
raise conn_exc
def set_keyspace_async(self, keyspace, callback):
"""
Use this in order to avoid deadlocking the event loop thread.
When the operation completes, `callback` will be called with
two arguments: this connection and an Exception if an error
occurred, otherwise :const:`None`.
This method will always increment :attr:`.in_flight` attribute, even if
it doesn't need to make a request, just to maintain an
":attr:`.in_flight` is incremented" invariant.
"""
# Here we increment in_flight unconditionally, whether we need to issue
# a request or not. This is bad, but allows callers -- specifically
# _set_keyspace_for_all_conns -- to assume that we increment
# self.in_flight during this call. This allows the passed callback to
# safely call HostConnection{Pool,}.return_connection on this
# Connection.
#
# We use a busy wait on the lock here because:
# - we'll only spin if the connection is at max capacity, which is very
# unlikely for a set_keyspace call
# - it allows us to avoid signaling a condition every time a request completes
while True:
with self.lock:
if self.in_flight < self.max_request_id:
self.in_flight += 1
break
time.sleep(0.001)
if not keyspace or keyspace == self.keyspace:
callback(self, None)
return
query = QueryMessage(query='USE "%s"' % (keyspace,),
consistency_level=ConsistencyLevel.ONE)
def process_result(result):
if isinstance(result, ResultMessage):
self.keyspace = keyspace
callback(self, None)
elif isinstance(result, InvalidRequestException):
callback(self, result.to_exception())
else:
callback(self, self.defunct(ConnectionException(
"Problem while setting keyspace: %r" % (result,), self.endpoint)))
# We've incremented self.in_flight above, so we "have permission" to
# acquire a new request id
request_id = self.get_request_id()
self.send_msg(query, request_id, process_result)
@property
def is_idle(self):
return not self.msg_received
def reset_idle(self):
self.msg_received = False
def __str__(self):
status = ""
if self.is_defunct:
status = " (defunct)"
elif self.is_closed:
status = " (closed)"
return "<%s(%r) %s%s>" % (self.__class__.__name__, id(self), self.endpoint, status)
__repr__ = __str__
class ResponseWaiter(object):
def __init__(self, connection, num_responses, fail_on_error):
self.connection = connection
self.pending = num_responses
self.fail_on_error = fail_on_error
self.error = None
self.responses = [None] * num_responses
self.event = Event()
def got_response(self, response, index):
with self.connection.lock:
self.connection.in_flight -= 1
if isinstance(response, Exception):
if hasattr(response, 'to_exception'):
response = response.to_exception()
if self.fail_on_error:
self.error = response
self.event.set()
else:
self.responses[index] = (False, response)
else:
if not self.fail_on_error:
self.responses[index] = (True, response)
else:
self.responses[index] = response
self.pending -= 1
if not self.pending:
self.event.set()
def deliver(self, timeout=None):
"""
If fail_on_error was set to False, a list of (success, response)
tuples will be returned. If success is False, response will be
an Exception. Otherwise, response will be the normal query response.
If fail_on_error was left as True and one of the requests
failed, the corresponding Exception will be raised. Otherwise,
the normal response will be returned.
"""
self.event.wait(timeout)
if self.error:
raise self.error
elif not self.event.is_set():
raise OperationTimedOut()
else:
return self.responses
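# Editor's sketch (not part of the patch): collecting per-request outcomes from an
# internal connection with fail_on_error=False; the messages themselves are hypothetical.
#
#   responses = conn.wait_for_responses(msg1, msg2, timeout=10, fail_on_error=False)
#   for success, response in responses:
#       if not success:
#           log.warning("internal request failed: %s", response)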
class HeartbeatFuture(object):
def __init__(self, connection, owner):
self._exception = None
self._event = Event()
self.connection = connection
self.owner = owner
log.debug("Sending options message heartbeat on idle connection (%s) %s",
id(connection), connection.endpoint)
with connection.lock:
- if connection.in_flight <= connection.max_request_id:
+ if connection.in_flight < connection.max_request_id:
connection.in_flight += 1
connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback)
else:
self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold")
self._event.set()
def wait(self, timeout):
self._event.wait(timeout)
if self._event.is_set():
if self._exception:
raise self._exception
else:
raise OperationTimedOut("Connection heartbeat timeout after %s seconds" % (timeout,), self.connection.endpoint)
def _options_callback(self, response):
if isinstance(response, SupportedMessage):
log.debug("Received options response on connection (%s) from %s",
id(self.connection), self.connection.endpoint)
else:
if isinstance(response, ConnectionException):
self._exception = response
else:
self._exception = ConnectionException("Received unexpected response to OptionsMessage: %s"
% (response,))
self._event.set()
class ConnectionHeartbeat(Thread):
def __init__(self, interval_sec, get_connection_holders, timeout):
Thread.__init__(self, name="Connection heartbeat")
self._interval = interval_sec
self._timeout = timeout
self._get_connection_holders = get_connection_holders
self._shutdown_event = Event()
self.daemon = True
self.start()
class ShutdownException(Exception):
pass
def run(self):
self._shutdown_event.wait(self._interval)
while not self._shutdown_event.is_set():
start_time = time.time()
futures = []
failed_connections = []
try:
for connections, owner in [(o.get_connections(), o) for o in self._get_connection_holders()]:
for connection in connections:
self._raise_if_stopped()
if not (connection.is_defunct or connection.is_closed):
if connection.is_idle:
try:
futures.append(HeartbeatFuture(connection, owner))
except Exception as e:
log.warning("Failed sending heartbeat message on connection (%s) to %s",
id(connection), connection.endpoint)
failed_connections.append((connection, owner, e))
else:
connection.reset_idle()
else:
log.debug("Cannot send heartbeat message on connection (%s) to %s",
id(connection), connection.endpoint)
# make sure the owner sees this defunct/closed connection
owner.return_connection(connection)
self._raise_if_stopped()
# Wait max `self._timeout` seconds for all HeartbeatFutures to complete
timeout = self._timeout
start_time = time.time()
for f in futures:
self._raise_if_stopped()
connection = f.connection
try:
f.wait(timeout)
# TODO: move this, along with connection locks in pool, down into Connection
with connection.lock:
connection.in_flight -= 1
connection.reset_idle()
except Exception as e:
log.warning("Heartbeat failed for connection (%s) to %s",
id(connection), connection.endpoint)
failed_connections.append((f.connection, f.owner, e))
timeout = self._timeout - (time.time() - start_time)
for connection, owner, exc in failed_connections:
self._raise_if_stopped()
if not connection.is_control_connection:
# Only HostConnection supports shutdown_on_error
owner.shutdown_on_error = True
connection.defunct(exc)
owner.return_connection(connection)
except self.ShutdownException:
pass
except Exception:
log.error("Failed connection heartbeat", exc_info=True)
elapsed = time.time() - start_time
self._shutdown_event.wait(max(self._interval - elapsed, 0.01))
def stop(self):
self._shutdown_event.set()
self.join()
def _raise_if_stopped(self):
if self._shutdown_event.is_set():
raise self.ShutdownException()
class Timer(object):
canceled = False
def __init__(self, timeout, callback):
self.end = time.time() + timeout
self.callback = callback
def __lt__(self, other):
return self.end < other.end
def cancel(self):
self.canceled = True
def finish(self, time_now):
if self.canceled:
return True
if time_now >= self.end:
self.callback()
return True
return False
class TimerManager(object):
def __init__(self):
self._queue = []
self._new_timers = []
def add_timer(self, timer):
"""
called from client thread with a Timer object
"""
self._new_timers.append((timer.end, timer))
def service_timeouts(self):
"""
run callbacks on all expired timers
Called from the event thread
:return: next end time, or None
"""
queue = self._queue
if self._new_timers:
new_timers = self._new_timers
while new_timers:
heappush(queue, new_timers.pop())
if queue:
now = time.time()
while queue:
try:
timer = queue[0][1]
if timer.finish(now):
heappop(queue)
else:
return timer.end
except Exception:
log.exception("Exception while servicing timeout callback: ")
@property
def next_timeout(self):
try:
return self._queue[0][0]
except IndexError:
pass
diff --git a/cassandra/cqlengine/connection.py b/cassandra/cqlengine/connection.py
index 884e04e..90e6d90 100644
--- a/cassandra/cqlengine/connection.py
+++ b/cassandra/cqlengine/connection.py
@@ -1,384 +1,392 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import logging
import six
import threading
from cassandra.cluster import Cluster, _ConfigMode, _NOT_SET, NoHostAvailable, UserTypeDoesNotExist, ConsistencyLevel
from cassandra.query import SimpleStatement, dict_factory
from cassandra.cqlengine import CQLEngineException
from cassandra.cqlengine.statements import BaseCQLStatement
log = logging.getLogger(__name__)
NOT_SET = _NOT_SET # required for passing timeout to Session.execute
cluster = None
session = None
# connections registry
DEFAULT_CONNECTION = object()
_connections = {}
# Because type models may be registered before a connection is present,
# and because sessions may be replaced, we must register UDTs here, in order
# to have them registered when a new session is established.
udt_by_keyspace = defaultdict(dict)
def format_log_context(msg, connection=None, keyspace=None):
"""Format log message to add keyspace and connection context"""
connection_info = connection or 'DEFAULT_CONNECTION'
if keyspace:
msg = '[Connection: {0}, Keyspace: {1}] {2}'.format(connection_info, keyspace, msg)
else:
msg = '[Connection: {0}] {1}'.format(connection_info, msg)
return msg
class UndefinedKeyspaceException(CQLEngineException):
pass
class Connection(object):
"""CQLEngine Connection"""
name = None
hosts = None
consistency = None
retry_connect = False
lazy_connect = False
lazy_connect_lock = None
cluster_options = None
cluster = None
session = None
def __init__(self, name, hosts, consistency=None,
lazy_connect=False, retry_connect=False, cluster_options=None):
self.hosts = hosts
self.name = name
self.consistency = consistency
self.lazy_connect = lazy_connect
self.retry_connect = retry_connect
self.cluster_options = cluster_options if cluster_options else {}
self.lazy_connect_lock = threading.RLock()
@classmethod
def from_session(cls, name, session):
instance = cls(name=name, hosts=session.hosts)
instance.cluster, instance.session = session.cluster, session
instance.setup_session()
return instance
def setup(self):
"""Setup the connection"""
global cluster, session
if 'username' in self.cluster_options or 'password' in self.cluster_options:
raise CQLEngineException("Username & Password are now handled by using the native driver's auth_provider")
if self.lazy_connect:
return
- self.cluster = Cluster(self.hosts, **self.cluster_options)
+ if 'cloud' in self.cluster_options:
+ if self.hosts:
+ log.warning("Ignoring hosts %s because a cloud config was provided.", self.hosts)
+ self.cluster = Cluster(**self.cluster_options)
+ else:
+ self.cluster = Cluster(self.hosts, **self.cluster_options)
+
try:
self.session = self.cluster.connect()
log.debug(format_log_context("connection initialized with internally created session", connection=self.name))
except NoHostAvailable:
if self.retry_connect:
log.warning(format_log_context("connect failed, setting up for re-attempt on first use", connection=self.name))
self.lazy_connect = True
raise
if DEFAULT_CONNECTION in _connections and _connections[DEFAULT_CONNECTION] == self:
cluster = _connections[DEFAULT_CONNECTION].cluster
session = _connections[DEFAULT_CONNECTION].session
self.setup_session()
def setup_session(self):
if self.cluster._config_mode == _ConfigMode.PROFILES:
self.cluster.profile_manager.default.row_factory = dict_factory
if self.consistency is not None:
self.cluster.profile_manager.default.consistency_level = self.consistency
else:
self.session.row_factory = dict_factory
if self.consistency is not None:
self.session.default_consistency_level = self.consistency
enc = self.session.encoder
enc.mapping[tuple] = enc.cql_encode_tuple
_register_known_types(self.session.cluster)
def handle_lazy_connect(self):
# if lazy_connect is False, the cluster is already set up and ready;
# no need to acquire the lock
if not self.lazy_connect:
return
with self.lazy_connect_lock:
# lazy_connect might have been set to False by another thread while waiting for the lock
# In this case, do nothing.
if self.lazy_connect:
log.debug(format_log_context("Lazy connect enabled", connection=self.name))
self.lazy_connect = False
self.setup()
def register_connection(name, hosts=None, consistency=None, lazy_connect=False,
retry_connect=False, cluster_options=None, default=False,
session=None):
"""
Add a connection to the connection registry. ``hosts`` and ``session`` are
mutually exclusive, and ``consistency``, ``lazy_connect``,
``retry_connect``, and ``cluster_options`` only work with ``hosts``. Using
``hosts`` will create a new :class:`cassandra.cluster.Cluster` and
:class:`cassandra.cluster.Session`.
:param list hosts: list of hosts, (``contact_points`` for :class:`cassandra.cluster.Cluster`).
:param int consistency: The default :class:`~.ConsistencyLevel` for the
registered connection's new session. Default is the same as
:attr:`.Session.default_consistency_level`. For use with ``hosts`` only;
will fail when used with ``session``.
:param bool lazy_connect: True if the connection should not be established
until first use. For use with ``hosts`` only; will fail when used with ``session``.
:param bool retry_connect: True if we should retry to connect even if there
was a connection failure initially. For use with ``hosts`` only; will
fail when used with ``session``.
:param dict cluster_options: A dict of options to be used as keyword
arguments to :class:`cassandra.cluster.Cluster`. For use with ``hosts``
only; will fail when used with ``session``.
:param bool default: If True, set the new connection as the cqlengine
default
:param Session session: A :class:`cassandra.cluster.Session` to be used in
the created connection.
"""
if name in _connections:
log.warning("Registering connection '{0}' when it already exists.".format(name))
if session is not None:
invalid_config_args = (hosts is not None or
consistency is not None or
lazy_connect is not False or
retry_connect is not False or
cluster_options is not None)
if invalid_config_args:
raise CQLEngineException(
"Session configuration arguments and 'session' argument are mutually exclusive"
)
conn = Connection.from_session(name, session=session)
else: # use hosts argument
conn = Connection(
name, hosts=hosts,
consistency=consistency, lazy_connect=lazy_connect,
retry_connect=retry_connect, cluster_options=cluster_options
)
conn.setup()
_connections[name] = conn
if default:
set_default_connection(name)
return conn
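# Illustrative usage sketch (connection names and contact points are hypothetical):
#
#   # hosts-based registration: cqlengine creates the Cluster and Session itself
#   register_connection('cluster1', hosts=['127.0.0.1'], lazy_connect=True, default=True)
#
#   # session-based registration: reuse an existing Session; hosts, consistency,
#   # lazy_connect, retry_connect and cluster_options must not be passed
#   register_connection('from_session', session=my_session)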
def unregister_connection(name):
global cluster, session
if name not in _connections:
return
if DEFAULT_CONNECTION in _connections and _connections[name] == _connections[DEFAULT_CONNECTION]:
del _connections[DEFAULT_CONNECTION]
cluster = None
session = None
conn = _connections[name]
if conn.cluster:
conn.cluster.shutdown()
del _connections[name]
log.debug("Connection '{0}' has been removed from the registry.".format(name))
def set_default_connection(name):
global cluster, session
if name not in _connections:
raise CQLEngineException("Connection '{0}' doesn't exist.".format(name))
log.debug("Connection '{0}' has been set as default.".format(name))
_connections[DEFAULT_CONNECTION] = _connections[name]
cluster = _connections[name].cluster
session = _connections[name].session
def get_connection(name=None):
if not name:
name = DEFAULT_CONNECTION
if name not in _connections:
raise CQLEngineException("Connection name '{0}' doesn't exist in the registry.".format(name))
conn = _connections[name]
conn.handle_lazy_connect()
return conn
def default():
"""
Configures the default connection to localhost, using the driver defaults
(except for row_factory)
"""
try:
conn = get_connection()
if conn.session:
log.warning("configuring new default connection for cqlengine when one was already set")
except:
pass
register_connection('default', hosts=None, default=True)
log.debug("cqlengine connection initialized with default session to localhost")
def set_session(s):
"""
Configures the default connection with a preexisting :class:`cassandra.cluster.Session`
Note: the mapper presently requires a Session :attr:`~.row_factory` set to ``dict_factory``.
This may be relaxed in the future
"""
try:
conn = get_connection()
except CQLEngineException:
# no default connection set; initialize one
register_connection('default', session=s, default=True)
conn = get_connection()
else:
if conn.session:
log.warning("configuring new default session for cqlengine when one was already set")
if not any([
s.cluster.profile_manager.default.row_factory is dict_factory and s.cluster._config_mode in [_ConfigMode.PROFILES, _ConfigMode.UNCOMMITTED],
s.row_factory is dict_factory and s.cluster._config_mode in [_ConfigMode.LEGACY, _ConfigMode.UNCOMMITTED],
]):
raise CQLEngineException("Failed to initialize: row_factory must be 'dict_factory'")
conn.session = s
conn.cluster = s.cluster
# Set default keyspace from given session's keyspace
if conn.session.keyspace:
from cassandra.cqlengine import models
models.DEFAULT_KEYSPACE = conn.session.keyspace
conn.setup_session()
log.debug("cqlengine default connection initialized with %s", s)
+# TODO next major: if a cloud config is specified in kwargs, hosts will be ignored.
+# This function should be refactored to reflect this change. PYTHON-1265
def setup(
hosts,
default_keyspace,
consistency=None,
lazy_connect=False,
retry_connect=False,
**kwargs):
"""
Set up the driver connection used by the mapper
:param list hosts: list of hosts, (``contact_points`` for :class:`cassandra.cluster.Cluster`)
:param str default_keyspace: The default keyspace to use
:param int consistency: The global default :class:`~.ConsistencyLevel` - default is the same as :attr:`.Session.default_consistency_level`
:param bool lazy_connect: True if the connection should not be established until first use
:param bool retry_connect: True if we should retry to connect even if there was a connection failure initially
:param \*\*kwargs: Pass-through keyword arguments for :class:`cassandra.cluster.Cluster`
"""
from cassandra.cqlengine import models
models.DEFAULT_KEYSPACE = default_keyspace
register_connection('default', hosts=hosts, consistency=consistency, lazy_connect=lazy_connect,
retry_connect=retry_connect, cluster_options=kwargs, default=True)
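# Illustrative sketch of mapper setup (keyspace, bundle path and credentials are
# hypothetical). When a cloud config is passed through kwargs, hosts are ignored
# as noted in the TODO above, so None can be supplied.
#
#   from cassandra.auth import PlainTextAuthProvider
#
#   # either: plain contact points
#   setup(['127.0.0.1'], 'my_keyspace', retry_connect=True)
#
#   # or: Astra/cloud via a secure connect bundle
#   setup(None, 'my_keyspace',
#         cloud={'secure_connect_bundle': '/path/to/secure-connect.zip'},
#         auth_provider=PlainTextAuthProvider('client_id', 'client_secret'))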
def execute(query, params=None, consistency_level=None, timeout=NOT_SET, connection=None):
conn = get_connection(connection)
if not conn.session:
raise CQLEngineException("It is required to setup() cqlengine before executing queries")
if isinstance(query, SimpleStatement):
pass  # SimpleStatement is used as-is
elif isinstance(query, BaseCQLStatement):
params = query.get_context()
query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size)
elif isinstance(query, six.string_types):
query = SimpleStatement(query, consistency_level=consistency_level)
log.debug(format_log_context('Query: {}, Params: {}'.format(query.query_string, params), connection=connection))
result = conn.session.execute(query, params, timeout=timeout)
return result
def get_session(connection=None):
conn = get_connection(connection)
return conn.session
def get_cluster(connection=None):
conn = get_connection(connection)
if not conn.cluster:
raise CQLEngineException("%s.cluster is not configured. Call one of the setup or default functions first." % __name__)
return conn.cluster
def register_udt(keyspace, type_name, klass, connection=None):
udt_by_keyspace[keyspace][type_name] = klass
try:
cluster = get_cluster(connection)
except CQLEngineException:
cluster = None
if cluster:
try:
cluster.register_user_type(keyspace, type_name, klass)
except UserTypeDoesNotExist:
pass # new types are covered in management sync functions
def _register_known_types(cluster):
from cassandra.cqlengine import models
for ks_name, name_type_map in udt_by_keyspace.items():
for type_name, klass in name_type_map.items():
try:
cluster.register_user_type(ks_name or models.DEFAULT_KEYSPACE, type_name, klass)
except UserTypeDoesNotExist:
pass # new types are covered in management sync functions
diff --git a/cassandra/cqlengine/management.py b/cassandra/cqlengine/management.py
index 42ded8a..536bde6 100644
--- a/cassandra/cqlengine/management.py
+++ b/cassandra/cqlengine/management.py
@@ -1,548 +1,548 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
import json
import logging
import os
import six
import warnings
from itertools import product
from cassandra import metadata
from cassandra.cqlengine import CQLEngineException
from cassandra.cqlengine import columns, query
from cassandra.cqlengine.connection import execute, get_cluster, format_log_context
from cassandra.cqlengine.models import Model
from cassandra.cqlengine.named import NamedTable
from cassandra.cqlengine.usertype import UserType
CQLENG_ALLOW_SCHEMA_MANAGEMENT = 'CQLENG_ALLOW_SCHEMA_MANAGEMENT'
Field = namedtuple('Field', ['name', 'type'])
log = logging.getLogger(__name__)
# system keyspaces
schema_columnfamilies = NamedTable('system', 'schema_columnfamilies')
def _get_context(keyspaces, connections):
"""Return all the execution contexts"""
if keyspaces:
if not isinstance(keyspaces, (list, tuple)):
raise ValueError('keyspaces must be a list or a tuple.')
if connections:
if not isinstance(connections, (list, tuple)):
raise ValueError('connections must be a list or a tuple.')
keyspaces = keyspaces if keyspaces else [None]
connections = connections if connections else [None]
return product(connections, keyspaces)
def create_keyspace_simple(name, replication_factor, durable_writes=True, connections=None):
"""
Creates a keyspace with SimpleStrategy for replica placement
If the keyspace already exists, it will not be modified.
**This function should be used with caution, especially in production environments.
Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).**
*There are plans to guard schema-modifying functions with an environment-driven conditional.*
:param str name: name of keyspace to create
:param int replication_factor: keyspace replication factor, used with :attr:`~.SimpleStrategy`
:param bool durable_writes: Write log is bypassed if set to False
:param list connections: List of connection names
"""
_create_keyspace(name, durable_writes, 'SimpleStrategy',
{'replication_factor': replication_factor}, connections=connections)
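# Illustrative usage (keyspace name is hypothetical; the
# CQLENG_ALLOW_SCHEMA_MANAGEMENT environment variable should be set, see
# _allow_schema_modification below):
#
#   create_keyspace_simple('my_keyspace', replication_factor=1)
#   # for multi-DC deployments, see create_keyspace_network_topology below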
def create_keyspace_network_topology(name, dc_replication_map, durable_writes=True, connections=None):
"""
Creates a keyspace with NetworkTopologyStrategy for replica placement
If the keyspace already exists, it will not be modified.
**This function should be used with caution, especially in production environments.
Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).**
*There are plans to guard schema-modifying functions with an environment-driven conditional.*
:param str name: name of keyspace to create
:param dict dc_replication_map: map of dc_names: replication_factor
:param bool durable_writes: Write log is bypassed if set to False
:param list connections: List of connection names
"""
_create_keyspace(name, durable_writes, 'NetworkTopologyStrategy', dc_replication_map, connections=connections)
def _create_keyspace(name, durable_writes, strategy_class, strategy_options, connections=None):
if not _allow_schema_modification():
return
if connections:
if not isinstance(connections, (list, tuple)):
raise ValueError('Connections must be a list or a tuple.')
def __create_keyspace(name, durable_writes, strategy_class, strategy_options, connection=None):
cluster = get_cluster(connection)
if name not in cluster.metadata.keyspaces:
log.info(format_log_context("Creating keyspace %s", connection=connection), name)
ks_meta = metadata.KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options)
execute(ks_meta.as_cql_query(), connection=connection)
else:
log.info(format_log_context("Not creating keyspace %s because it already exists", connection=connection), name)
if connections:
for connection in connections:
__create_keyspace(name, durable_writes, strategy_class, strategy_options, connection=connection)
else:
__create_keyspace(name, durable_writes, strategy_class, strategy_options)
def drop_keyspace(name, connections=None):
"""
Drops a keyspace, if it exists.
*There are plans to guard schema-modifying functions with an environment-driven conditional.*
**This function should be used with caution, especially in production environments.
Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).**
:param str name: name of keyspace to drop
:param list connections: List of connection names
"""
if not _allow_schema_modification():
return
if connections:
if not isinstance(connections, (list, tuple)):
raise ValueError('Connections must be a list or a tuple.')
def _drop_keyspace(name, connection=None):
cluster = get_cluster(connection)
if name in cluster.metadata.keyspaces:
execute("DROP KEYSPACE {0}".format(metadata.protect_name(name)), connection=connection)
if connections:
for connection in connections:
_drop_keyspace(name, connection)
else:
_drop_keyspace(name)
def _get_index_name_by_column(table, column_name):
"""
Find the index name for a given table and column.
"""
protected_name = metadata.protect_name(column_name)
possible_index_values = [protected_name, "values(%s)" % protected_name]
for index_metadata in table.indexes.values():
options = dict(index_metadata.index_options)
if options.get('target') in possible_index_values:
return index_metadata.name
def sync_table(model, keyspaces=None, connections=None):
"""
Inspects the model and creates / updates the corresponding table and columns.
If `keyspaces` is specified, the table will be synched for all specified keyspaces.
Note that the `Model.__keyspace__` is ignored in that case.
If `connections` is specified, the table will be synched for all specified connections. Note that the `Model.__connection__` is ignored in that case.
If not specified, it will try to get the connection from the Model.
Any User Defined Types used in the table are implicitly synchronized.
This function can only add fields that are not part of the primary key.
Note that columns removed from the model are not dropped in the database;
they are simply no longer read by (and will not show up on) the model.
**This function should be used with caution, especially in production environments.
Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).**
*There are plans to guard schema-modifying functions with an environment-driven conditional.*
"""
context = _get_context(keyspaces, connections)
for connection, keyspace in context:
with query.ContextQuery(model, keyspace=keyspace) as m:
_sync_table(m, connection=connection)
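# Illustrative sketch (model and keyspace names are hypothetical):
#
#   from cassandra.cqlengine import columns
#   from cassandra.cqlengine.models import Model
#
#   class User(Model):
#       __keyspace__ = 'my_keyspace'
#       user_id = columns.UUID(primary_key=True)
#       name = columns.Text()
#
#   sync_table(User)                            # uses Model.__keyspace__/__connection__
#   sync_table(User, keyspaces=['ks1', 'ks2'])  # Model.__keyspace__ is ignored here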
def _sync_table(model, connection=None):
if not _allow_schema_modification():
return
if not issubclass(model, Model):
raise CQLEngineException("Models must be derived from base Model.")
if model.__abstract__:
raise CQLEngineException("cannot create table from abstract model")
cf_name = model.column_family_name()
raw_cf_name = model._raw_column_family_name()
ks_name = model._get_keyspace()
connection = connection or model._get_connection()
cluster = get_cluster(connection)
try:
keyspace = cluster.metadata.keyspaces[ks_name]
except KeyError:
msg = format_log_context("Keyspace '{0}' for model {1} does not exist.", connection=connection)
raise CQLEngineException(msg.format(ks_name, model))
tables = keyspace.tables
syncd_types = set()
for col in model._columns.values():
udts = []
columns.resolve_udts(col, udts)
for udt in [u for u in udts if u not in syncd_types]:
_sync_type(ks_name, udt, syncd_types, connection=connection)
if raw_cf_name not in tables:
log.debug(format_log_context("sync_table creating new table %s", keyspace=ks_name, connection=connection), cf_name)
qs = _get_create_table(model)
try:
execute(qs, connection=connection)
except CQLEngineException as ex:
# Cassandra 1.2 doesn't return cf names, so we have to examine the exception
# and ignore if it says the column family already exists
- if "Cannot add already existing column family" not in unicode(ex):
+ if "Cannot add already existing column family" not in six.text_type(ex):
raise
else:
log.debug(format_log_context("sync_table checking existing table %s", keyspace=ks_name, connection=connection), cf_name)
table_meta = tables[raw_cf_name]
_validate_pk(model, table_meta)
table_columns = table_meta.columns
model_fields = set()
for model_name, col in model._columns.items():
db_name = col.db_field_name
model_fields.add(db_name)
if db_name in table_columns:
col_meta = table_columns[db_name]
if col_meta.cql_type != col.db_type:
msg = format_log_context('Existing table {0} has column "{1}" with a type ({2}) differing from the model type ({3}).'
' Model should be updated.', keyspace=ks_name, connection=connection)
msg = msg.format(cf_name, db_name, col_meta.cql_type, col.db_type)
warnings.warn(msg)
log.warning(msg)
continue
if col.primary_key or col.partition_key:
msg = format_log_context("Cannot add primary key '{0}' (with db_field '{1}') to existing table {2}", keyspace=ks_name, connection=connection)
raise CQLEngineException(msg.format(model_name, db_name, cf_name))
query = "ALTER TABLE {0} add {1}".format(cf_name, col.get_column_def())
execute(query, connection=connection)
db_fields_not_in_model = model_fields.symmetric_difference(table_columns)
if db_fields_not_in_model:
msg = format_log_context("Table {0} has fields not referenced by model: {1}", keyspace=ks_name, connection=connection)
log.info(msg.format(cf_name, db_fields_not_in_model))
_update_options(model, connection=connection)
table = cluster.metadata.keyspaces[ks_name].tables[raw_cf_name]
indexes = [c for n, c in model._columns.items() if c.index]
# TODO: support multiple indexes in C* 3.0+
for column in indexes:
index_name = _get_index_name_by_column(table, column.db_field_name)
if index_name:
continue
qs = ['CREATE INDEX']
qs += ['ON {0}'.format(cf_name)]
qs += ['("{0}")'.format(column.db_field_name)]
qs = ' '.join(qs)
execute(qs, connection=connection)
def _validate_pk(model, table_meta):
model_partition = [c.db_field_name for c in model._partition_keys.values()]
meta_partition = [c.name for c in table_meta.partition_key]
model_clustering = [c.db_field_name for c in model._clustering_keys.values()]
meta_clustering = [c.name for c in table_meta.clustering_key]
if model_partition != meta_partition or model_clustering != meta_clustering:
def _pk_string(partition, clustering):
return "PRIMARY KEY (({0}){1})".format(', '.join(partition), ', ' + ', '.join(clustering) if clustering else '')
raise CQLEngineException("Model {0} PRIMARY KEY composition does not match existing table {1}. "
"Model: {2}; Table: {3}. "
"Update model or drop the table.".format(model, model.column_family_name(),
_pk_string(model_partition, model_clustering),
_pk_string(meta_partition, meta_clustering)))
def sync_type(ks_name, type_model, connection=None):
"""
Inspects the type_model and creates / updates the corresponding type.
Note that fields removed from the type_model are not dropped in the database (this operation is not supported);
they are simply no longer read by (and will not show up on) the type_model.
**This function should be used with caution, especially in production environments.
Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).**
*There are plans to guard schema-modifying functions with an environment-driven conditional.*
"""
if not _allow_schema_modification():
return
if not issubclass(type_model, UserType):
raise CQLEngineException("Types must be derived from base UserType.")
_sync_type(ks_name, type_model, connection=connection)
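# Illustrative sketch (type and keyspace names are hypothetical):
#
#   from cassandra.cqlengine import columns
#   from cassandra.cqlengine.usertype import UserType
#
#   class Address(UserType):
#       street = columns.Text()
#       zipcode = columns.Integer()
#
#   sync_type('my_keyspace', Address)  # creates the UDT or adds missing fields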
def _sync_type(ks_name, type_model, omit_subtypes=None, connection=None):
syncd_sub_types = omit_subtypes or set()
for field in type_model._fields.values():
udts = []
columns.resolve_udts(field, udts)
for udt in [u for u in udts if u not in syncd_sub_types]:
_sync_type(ks_name, udt, syncd_sub_types, connection=connection)
syncd_sub_types.add(udt)
type_name = type_model.type_name()
type_name_qualified = "%s.%s" % (ks_name, type_name)
cluster = get_cluster(connection)
keyspace = cluster.metadata.keyspaces[ks_name]
defined_types = keyspace.user_types
if type_name not in defined_types:
log.debug(format_log_context("sync_type creating new type %s", keyspace=ks_name, connection=connection), type_name_qualified)
cql = get_create_type(type_model, ks_name)
execute(cql, connection=connection)
cluster.refresh_user_type_metadata(ks_name, type_name)
type_model.register_for_keyspace(ks_name, connection=connection)
else:
type_meta = defined_types[type_name]
defined_fields = type_meta.field_names
model_fields = set()
for field in type_model._fields.values():
model_fields.add(field.db_field_name)
if field.db_field_name not in defined_fields:
execute("ALTER TYPE {0} ADD {1}".format(type_name_qualified, field.get_column_def()), connection=connection)
else:
field_type = type_meta.field_types[defined_fields.index(field.db_field_name)]
if field_type != field.db_type:
msg = format_log_context('Existing user type {0} has field "{1}" with a type ({2}) differing from the model user type ({3}).'
' UserType should be updated.', keyspace=ks_name, connection=connection)
msg = msg.format(type_name_qualified, field.db_field_name, field_type, field.db_type)
warnings.warn(msg)
log.warning(msg)
type_model.register_for_keyspace(ks_name, connection=connection)
if len(defined_fields) == len(model_fields):
log.info(format_log_context("Type %s did not require synchronization", keyspace=ks_name, connection=connection), type_name_qualified)
return
db_fields_not_in_model = model_fields.symmetric_difference(defined_fields)
if db_fields_not_in_model:
msg = format_log_context("Type %s has fields not referenced by model: %s", keyspace=ks_name, connection=connection)
log.info(msg, type_name_qualified, db_fields_not_in_model)
def get_create_type(type_model, keyspace):
type_meta = metadata.UserType(keyspace,
type_model.type_name(),
(f.db_field_name for f in type_model._fields.values()),
(v.db_type for v in type_model._fields.values()))
return type_meta.as_cql_query()
def _get_create_table(model):
ks_table_name = model.column_family_name()
query_strings = ['CREATE TABLE {0}'.format(ks_table_name)]
# add column types
pkeys = [] # primary keys
ckeys = [] # clustering keys
qtypes = [] # field types
def add_column(col):
s = col.get_column_def()
if col.primary_key:
keys = (pkeys if col.partition_key else ckeys)
keys.append('"{0}"'.format(col.db_field_name))
qtypes.append(s)
for name, col in model._columns.items():
add_column(col)
qtypes.append('PRIMARY KEY (({0}){1})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or ''))
query_strings += ['({0})'.format(', '.join(qtypes))]
property_strings = []
_order = ['"{0}" {1}'.format(c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()]
if _order:
property_strings.append('CLUSTERING ORDER BY ({0})'.format(', '.join(_order)))
# options strings use the V3 format, which matches CQL more closely and does not require mapping
property_strings += metadata.TableMetadataV3._make_option_strings(model.__options__ or {})
if property_strings:
query_strings += ['WITH {0}'.format(' AND '.join(property_strings))]
return ' '.join(query_strings)
def _get_table_metadata(model, connection=None):
# returns the table as provided by the native driver for a given model
cluster = get_cluster(connection)
ks = model._get_keyspace()
table = model._raw_column_family_name()
table = cluster.metadata.keyspaces[ks].tables[table]
return table
def _options_map_from_strings(option_strings):
# converts option strings to a mapping of option names to strings or dicts
options = {}
for option in option_strings:
name, value = option.split('=')
i = value.find('{')
if i >= 0:
value = value[i:value.rfind('}') + 1].replace("'", '"') # from cql single quotes to json double; not aware of any values that would be escaped right now
value = json.loads(value)
else:
value = value.strip()
options[name.strip()] = value
return options
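# Example of the conversion above (illustrative option strings):
#
#   _options_map_from_strings(["comment = 'some comment'",
#                              "caching = {'keys': 'ALL', 'rows_per_partition': 'NONE'}"])
#   # -> {'comment': "'some comment'",
#   #     'caching': {'keys': 'ALL', 'rows_per_partition': 'NONE'}}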
def _update_options(model, connection=None):
"""Updates the table options for the given model if necessary.
:param model: The model to update.
:param connection: Name of the connection to use
:return: `True`, if the options were modified in Cassandra,
`False` otherwise.
:rtype: bool
"""
ks_name = model._get_keyspace()
msg = format_log_context("Checking %s for option differences", keyspace=ks_name, connection=connection)
log.debug(msg, model)
model_options = model.__options__ or {}
table_meta = _get_table_metadata(model, connection=connection)
# go to CQL string first to normalize meta from different versions
existing_option_strings = set(table_meta._make_option_strings(table_meta.options))
existing_options = _options_map_from_strings(existing_option_strings)
model_option_strings = metadata.TableMetadataV3._make_option_strings(model_options)
model_options = _options_map_from_strings(model_option_strings)
update_options = {}
for name, value in model_options.items():
try:
existing_value = existing_options[name]
except KeyError:
msg = format_log_context("Invalid table option: '%s'; known options: %s", keyspace=ks_name, connection=connection)
raise KeyError(msg % (name, existing_options.keys()))
if isinstance(existing_value, six.string_types):
if value != existing_value:
update_options[name] = value
else:
try:
for k, v in value.items():
if existing_value[k] != v:
update_options[name] = value
break
except KeyError:
update_options[name] = value
if update_options:
options = ' AND '.join(metadata.TableMetadataV3._make_option_strings(update_options))
query = "ALTER TABLE {0} WITH {1}".format(model.column_family_name(), options)
execute(query, connection=connection)
return True
return False
def drop_table(model, keyspaces=None, connections=None):
"""
Drops the table indicated by the model, if it exists.
If `keyspaces` is specified, the table will be dropped for all specified keyspaces. Note that the `Model.__keyspace__` is ignored in that case.
If `connections` is specified, the table will be dropped for all specified connections. Note that the `Model.__connection__` is ignored in that case.
If not specified, it will try to get the connection from the Model.
**This function should be used with caution, especially in production environments.
Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).**
*There are plans to guard schema-modifying functions with an environment-driven conditional.*
"""
context = _get_context(keyspaces, connections)
for connection, keyspace in context:
with query.ContextQuery(model, keyspace=keyspace) as m:
_drop_table(m, connection=connection)
def _drop_table(model, connection=None):
if not _allow_schema_modification():
return
connection = connection or model._get_connection()
# don't try to delete nonexistent tables
meta = get_cluster(connection).metadata
ks_name = model._get_keyspace()
raw_cf_name = model._raw_column_family_name()
try:
meta.keyspaces[ks_name].tables[raw_cf_name]
execute('DROP TABLE {0};'.format(model.column_family_name()), connection=connection)
except KeyError:
pass
def _allow_schema_modification():
if not os.getenv(CQLENG_ALLOW_SCHEMA_MANAGEMENT):
msg = CQLENG_ALLOW_SCHEMA_MANAGEMENT + " environment variable is not set. Future versions of this package will require this variable to enable management functions."
warnings.warn(msg)
log.warning(msg)
return True
diff --git a/cassandra/cqlengine/models.py b/cassandra/cqlengine/models.py
index 9fe5d3e..b3c7c9e 100644
--- a/cassandra/cqlengine/models.py
+++ b/cassandra/cqlengine/models.py
@@ -1,1088 +1,1088 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
import six
from warnings import warn
from cassandra.cqlengine import CQLEngineException, ValidationError
from cassandra.cqlengine import columns
from cassandra.cqlengine import connection
from cassandra.cqlengine import query
from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist
from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned
from cassandra.metadata import protect_name
from cassandra.util import OrderedDict
log = logging.getLogger(__name__)
def _clone_model_class(model, attrs):
new_type = type(model.__name__, (model,), attrs)
try:
new_type.__abstract__ = model.__abstract__
new_type.__discriminator_value__ = model.__discriminator_value__
new_type.__default_ttl__ = model.__default_ttl__
except AttributeError:
pass
return new_type
class ModelException(CQLEngineException):
pass
class ModelDefinitionException(ModelException):
pass
class PolymorphicModelException(ModelException):
pass
class UndefinedKeyspaceWarning(Warning):
pass
DEFAULT_KEYSPACE = None
class hybrid_classmethod(object):
"""
Allows a method to behave as both a class method and
normal instance method depending on how it's called
"""
def __init__(self, clsmethod, instmethod):
self.clsmethod = clsmethod
self.instmethod = instmethod
def __get__(self, instance, owner):
if instance is None:
return self.clsmethod.__get__(owner, owner)
else:
return self.instmethod.__get__(instance, owner)
def __call__(self, *args, **kwargs):
"""
Just a hint to IDEs that it's ok to call this
"""
raise NotImplementedError
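# Minimal sketch of the dual dispatch this descriptor provides (names are hypothetical):
#
#   class Example(object):
#       def _cls_way(cls):
#           return 'called on the class'
#       def _inst_way(self):
#           return 'called on an instance'
#       way = hybrid_classmethod(_cls_way, _inst_way)
#
#   Example.way()    # -> 'called on the class'
#   Example().way()  # -> 'called on an instance'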
class QuerySetDescriptor(object):
"""
Returns a fresh queryset for the model it's declared on
every time it's accessed
"""
def __get__(self, obj, model):
""" :rtype: ModelQuerySet """
if model.__abstract__:
raise CQLEngineException('cannot execute queries against abstract models')
queryset = model.__queryset__(model)
# if this is a concrete polymorphic model, and the discriminator
# key is an indexed column, add a filter clause to only return
# logical rows of the proper type
if model._is_polymorphic and not model._is_polymorphic_base:
name, column = model._discriminator_column_name, model._discriminator_column
if column.partition_key or column.index:
# look for existing poly types
return queryset.filter(**{name: model.__discriminator_value__})
return queryset
def __call__(self, *args, **kwargs):
"""
Just a hint to IDEs that it's ok to call this
:rtype: ModelQuerySet
"""
raise NotImplementedError
class ConditionalDescriptor(object):
"""
returns a query set descriptor
"""
def __get__(self, instance, model):
if instance:
def conditional_setter(*prepared_conditional, **unprepared_conditionals):
if len(prepared_conditional) > 0:
conditionals = prepared_conditional[0]
else:
conditionals = instance.objects.iff(**unprepared_conditionals)._conditional
instance._conditional = conditionals
return instance
return conditional_setter
qs = model.__queryset__(model)
def conditional_setter(**unprepared_conditionals):
conditionals = model.objects.iff(**unprepared_conditionals)._conditional
qs._conditional = conditionals
return qs
return conditional_setter
def __call__(self, *args, **kwargs):
raise NotImplementedError
class TTLDescriptor(object):
"""
returns a query set descriptor
"""
def __get__(self, instance, model):
if instance:
# instance = copy.deepcopy(instance)
# instance method
def ttl_setter(ts):
instance._ttl = ts
return instance
return ttl_setter
qs = model.__queryset__(model)
def ttl_setter(ts):
qs._ttl = ts
return qs
return ttl_setter
def __call__(self, *args, **kwargs):
raise NotImplementedError
class TimestampDescriptor(object):
"""
returns a query set descriptor with a timestamp specified
"""
def __get__(self, instance, model):
if instance:
# instance method
def timestamp_setter(ts):
instance._timestamp = ts
return instance
return timestamp_setter
return model.objects.timestamp
def __call__(self, *args, **kwargs):
raise NotImplementedError
class IfNotExistsDescriptor(object):
"""
returns a query set descriptor with an if_not_exists flag specified
"""
def __get__(self, instance, model):
if instance:
# instance method
def ifnotexists_setter(ife=True):
instance._if_not_exists = ife
return instance
return ifnotexists_setter
return model.objects.if_not_exists
def __call__(self, *args, **kwargs):
raise NotImplementedError
class IfExistsDescriptor(object):
"""
returns a query set descriptor with an if_exists flag specified
"""
def __get__(self, instance, model):
if instance:
# instance method
def ifexists_setter(ife=True):
instance._if_exists = ife
return instance
return ifexists_setter
return model.objects.if_exists
def __call__(self, *args, **kwargs):
raise NotImplementedError
class ConsistencyDescriptor(object):
"""
returns a query set descriptor when accessed on the class, or the instance when accessed on an instance
"""
def __get__(self, instance, model):
if instance:
# instance = copy.deepcopy(instance)
def consistency_setter(consistency):
instance.__consistency__ = consistency
return instance
return consistency_setter
qs = model.__queryset__(model)
def consistency_setter(consistency):
qs._consistency = consistency
return qs
return consistency_setter
def __call__(self, *args, **kwargs):
raise NotImplementedError
class UsingDescriptor(object):
"""
return a query set descriptor with a connection context specified
"""
def __get__(self, instance, model):
if instance:
# instance method
def using_setter(connection=None):
if connection:
instance._connection = connection
return instance
return using_setter
return model.objects.using
def __call__(self, *args, **kwargs):
raise NotImplementedError
class ColumnQueryEvaluator(query.AbstractQueryableColumn):
"""
Wraps a column and allows it to be used in comparator
expressions, returning query operators
e.g.:
Model.column == 5
"""
def __init__(self, column):
self.column = column
def __unicode__(self):
return self.column.db_field_name
def _get_column(self):
return self.column
class ColumnDescriptor(object):
"""
Handles the reading and writing of column values to and from
a model instance's value manager, as well as creating
comparator queries
"""
def __init__(self, column):
"""
:param column:
:type column: columns.Column
:return:
"""
self.column = column
self.query_evaluator = ColumnQueryEvaluator(self.column)
def __get__(self, instance, owner):
"""
Returns either the value or column, depending
on if an instance is provided or not
:param instance: the model instance
:type instance: Model
"""
try:
return instance._values[self.column.column_name].getval()
except AttributeError:
return self.query_evaluator
def __set__(self, instance, value):
"""
Sets the value on an instance, raises an exception with classes
TODO: use None instance to create update statements
"""
if instance:
return instance._values[self.column.column_name].setval(value)
else:
raise AttributeError('cannot reassign column values')
def __delete__(self, instance):
"""
Sets the column value to None, if possible
"""
if instance:
if self.column.can_delete:
instance._values[self.column.column_name].delval()
else:
raise AttributeError('cannot delete {0} columns'.format(self.column.column_name))
class BaseModel(object):
"""
The base model class. Don't inherit from this directly; inherit from Model, defined below.
"""
class DoesNotExist(_DoesNotExist):
pass
class MultipleObjectsReturned(_MultipleObjectsReturned):
pass
objects = QuerySetDescriptor()
ttl = TTLDescriptor()
consistency = ConsistencyDescriptor()
iff = ConditionalDescriptor()
# custom timestamps, see USING TIMESTAMP X
timestamp = TimestampDescriptor()
if_not_exists = IfNotExistsDescriptor()
if_exists = IfExistsDescriptor()
using = UsingDescriptor()
# _len is lazily created by __len__
__table_name__ = None
__table_name_case_sensitive__ = False
__keyspace__ = None
__connection__ = None
__discriminator_value__ = None
__options__ = None
__compute_routing_key__ = True
# the queryset class used for this class
__queryset__ = query.ModelQuerySet
__dmlquery__ = query.DMLQuery
__consistency__ = None # can be set per query
_timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP)
_if_not_exists = False # optional if_not_exists flag to check existence before insertion
_if_exists = False # optional if_exists flag to check existence before update
_table_name = None # used internally to cache a derived table name
_connection = None
def __init__(self, **values):
self._ttl = None
self._timestamp = None
self._conditional = None
self._batch = None
self._timeout = connection.NOT_SET
self._is_persisted = False
self._connection = None
self._values = {}
for name, column in self._columns.items():
# Set default values on instantiation. Thanks to this, we don't have
# to wait any longer for a call to validate() to have cqlengine set
# default column values.
column_default = column.get_default() if column.has_default else None
value = values.get(name, column_default)
if value is not None or isinstance(column, columns.BaseContainerColumn):
value = column.to_python(value)
value_mngr = column.value_manager(self, column, value)
value_mngr.explicit = name in values
self._values[name] = value_mngr
def __repr__(self):
return '{0}({1})'.format(self.__class__.__name__,
', '.join('{0}={1!r}'.format(k, getattr(self, k))
for k in self._defined_columns.keys()
if k != self._discriminator_column_name))
def __str__(self):
"""
Pretty printing of models by their primary key
"""
return '{0} <{1}>'.format(self.__class__.__name__,
', '.join('{0}={1}'.format(k, getattr(self, k)) for k in self._primary_keys.keys()))
@classmethod
def _routing_key_from_values(cls, pk_values, protocol_version):
return cls._key_serializer(pk_values, protocol_version)
@classmethod
def _discover_polymorphic_submodels(cls):
if not cls._is_polymorphic_base:
raise ModelException('_discover_polymorphic_submodels can only be called on polymorphic base classes')
def _discover(klass):
if not klass._is_polymorphic_base and klass.__discriminator_value__ is not None:
cls._discriminator_map[klass.__discriminator_value__] = klass
for subklass in klass.__subclasses__():
_discover(subklass)
_discover(cls)
@classmethod
def _get_model_by_discriminator_value(cls, key):
if not cls._is_polymorphic_base:
raise ModelException('_get_model_by_discriminator_value can only be called on polymorphic base classes')
return cls._discriminator_map.get(key)
@classmethod
def _construct_instance(cls, values):
"""
method used to construct instances from query results
this is where polymorphic deserialization occurs
"""
# we're going to take the values, which is from the DB as a dict
# and translate that into our local fields
# the db_map is a db_field -> model field map
if cls._db_map:
values = dict((cls._db_map.get(k, k), v) for k, v in values.items())
if cls._is_polymorphic:
disc_key = values.get(cls._discriminator_column_name)
if disc_key is None:
raise PolymorphicModelException('discriminator value was not found in values')
poly_base = cls if cls._is_polymorphic_base else cls._polymorphic_base
klass = poly_base._get_model_by_discriminator_value(disc_key)
if klass is None:
poly_base._discover_polymorphic_submodels()
klass = poly_base._get_model_by_discriminator_value(disc_key)
if klass is None:
raise PolymorphicModelException(
'unrecognized discriminator column {0} for class {1}'.format(disc_key, poly_base.__name__)
)
if not issubclass(klass, cls):
raise PolymorphicModelException(
'{0} is not a subclass of {1}'.format(klass.__name__, cls.__name__)
)
values = dict((k, v) for k, v in values.items() if k in klass._columns.keys())
else:
klass = cls
instance = klass(**values)
instance._set_persisted(force=True)
return instance
def _set_persisted(self, force=False):
# ensure we don't modify any values not affected by the last save/update
for v in [v for v in self._values.values() if v.changed or force]:
v.reset_previous_value()
v.explicit = False
self._is_persisted = True
def _can_update(self):
"""
Called by the save function to check if this should be
persisted with update or insert
:return:
"""
if not self._is_persisted:
return False
- return all(not self._values[k].changed for k in self._primary_keys)
+ return all([not self._values[k].changed for k in self._primary_keys])
@classmethod
def _get_keyspace(cls):
"""
Returns the manual keyspace, if set, otherwise the default keyspace
"""
return cls.__keyspace__ or DEFAULT_KEYSPACE
@classmethod
def _get_column(cls, name):
"""
Returns the column matching the given name, raising a key error if
it doesn't exist
:param name: the name of the column to return
:rtype: Column
"""
return cls._columns[name]
@classmethod
def _get_column_by_db_name(cls, name):
"""
Returns the column, mapped by db_field name
"""
return cls._columns.get(cls._db_map.get(name, name))
def __eq__(self, other):
if self.__class__ != other.__class__:
return False
# check attribute keys
keys = set(self._columns.keys())
other_keys = set(other._columns.keys())
if keys != other_keys:
return False
return all(getattr(self, key, None) == getattr(other, key, None) for key in other_keys)
def __ne__(self, other):
return not self.__eq__(other)
@classmethod
def column_family_name(cls, include_keyspace=True):
"""
Returns the column family name if it's been defined
otherwise, it creates it from the module and class name
"""
cf_name = protect_name(cls._raw_column_family_name())
if include_keyspace:
keyspace = cls._get_keyspace()
if not keyspace:
raise CQLEngineException("Model keyspace is not set and no default is available. Set model keyspace or setup connection before attempting to generate a query.")
return '{0}.{1}'.format(protect_name(keyspace), cf_name)
return cf_name
@classmethod
def _raw_column_family_name(cls):
if not cls._table_name:
if cls.__table_name__:
if cls.__table_name_case_sensitive__:
warn("Model __table_name_case_sensitive__ will be removed in 4.0.", PendingDeprecationWarning)
cls._table_name = cls.__table_name__
else:
table_name = cls.__table_name__.lower()
if cls.__table_name__ != table_name:
warn(("Model __table_name__ will be case sensitive by default in 4.0. "
"You should fix the __table_name__ value of the '{0}' model.").format(cls.__name__))
cls._table_name = table_name
else:
if cls._is_polymorphic and not cls._is_polymorphic_base:
cls._table_name = cls._polymorphic_base._raw_column_family_name()
else:
camelcase = re.compile(r'([a-z])([A-Z])')
ccase = lambda s: camelcase.sub(lambda v: '{0}_{1}'.format(v.group(1), v.group(2).lower()), s)
cf_name = ccase(cls.__name__)
# trim to less than 48 characters or cassandra will complain
cf_name = cf_name[-48:]
cf_name = cf_name.lower()
cf_name = re.sub(r'^_+', '', cf_name)
cls._table_name = cf_name
return cls._table_name
def _set_column_value(self, name, value):
"""Function to change a column value without changing the value manager states"""
self._values[name].value = value  # internal assignment, skip the main setter
def validate(self):
"""
Cleans and validates the field values
"""
for name, col in self._columns.items():
v = getattr(self, name)
if v is None and not self._values[name].explicit and col.has_default:
v = col.get_default()
val = col.validate(v)
self._set_column_value(name, val)
# Let an instance be used like a dict of its columns keys/values
def __iter__(self):
""" Iterate over column ids. """
for column_id in self._columns.keys():
yield column_id
def __getitem__(self, key):
""" Returns column's value. """
if not isinstance(key, six.string_types):
raise TypeError
if key not in self._columns.keys():
raise KeyError
return getattr(self, key)
def __setitem__(self, key, val):
""" Sets a column's value. """
if not isinstance(key, six.string_types):
raise TypeError
if key not in self._columns.keys():
raise KeyError
return setattr(self, key, val)
def __len__(self):
"""
Returns the number of columns defined on that model.
"""
try:
return self._len
except:
self._len = len(self._columns.keys())
return self._len
def keys(self):
""" Returns a list of column IDs. """
return [k for k in self]
def values(self):
""" Returns list of column values. """
return [self[k] for k in self]
def items(self):
""" Returns a list of column ID/value tuples. """
return [(k, self[k]) for k in self]
def _as_dict(self):
""" Returns a map of column names to cleaned values """
values = self._dynamic_columns or {}
for name, col in self._columns.items():
values[name] = col.to_database(getattr(self, name, None))
return values
@classmethod
def create(cls, **kwargs):
"""
Create an instance of this model in the database.
Takes the model column values as keyword arguments. Setting a value to
`None` is equivalent to running a CQL `DELETE` on that column.
Returns the instance.
"""
extra_columns = set(kwargs.keys()) - set(cls._columns.keys())
if extra_columns:
raise ValidationError("Incorrect columns passed: {0}".format(extra_columns))
return cls.objects.create(**kwargs)
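# Illustrative usage (model and column names are hypothetical):
#
#   person = Person.create(first_name='Kimberly', last_name='Eggleston')
#   # passing a column the model does not define raises ValidationError:
#   # Person.create(first_name='Kimberly', nickname='Kim')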
@classmethod
def all(cls):
"""
Returns a queryset representing all stored objects
This is a pass-through to the model objects().all()
"""
return cls.objects.all()
@classmethod
def filter(cls, *args, **kwargs):
"""
Returns a queryset based on filter parameters.
This is a pass-through to the model objects().:method:`~cqlengine.queries.filter`.
"""
return cls.objects.filter(*args, **kwargs)
@classmethod
def get(cls, *args, **kwargs):
"""
Returns a single object based on the passed filter constraints.
This is a pass-through to the model objects().:method:`~cqlengine.queries.get`.
"""
return cls.objects.get(*args, **kwargs)
def timeout(self, timeout):
"""
Sets a timeout for use in :meth:`~.save`, :meth:`~.update`, and :meth:`~.delete`
operations
"""
assert self._batch is None, 'Setting both timeout and batch is not supported'
self._timeout = timeout
return self
def save(self):
"""
Saves an object to the database.
.. code-block:: python
#create a person instance
person = Person(first_name='Kimberly', last_name='Eggleston')
#saves it to Cassandra
person.save()
"""
# handle polymorphic models
if self._is_polymorphic:
if self._is_polymorphic_base:
raise PolymorphicModelException('cannot save polymorphic base model')
else:
setattr(self, self._discriminator_column_name, self.__discriminator_value__)
self.validate()
self.__dmlquery__(self.__class__, self,
batch=self._batch,
ttl=self._ttl,
timestamp=self._timestamp,
consistency=self.__consistency__,
if_not_exists=self._if_not_exists,
conditional=self._conditional,
timeout=self._timeout,
if_exists=self._if_exists).save()
self._set_persisted()
self._timestamp = None
return self
def update(self, **values):
"""
Performs an update on the model instance. You can pass in values to set on the model
for updating, or you can call without values to execute an update against any modified
fields. If no fields on the model have been modified since loading, no query will be
performed. Model validation is performed normally. Setting a value to `None` is
equivalent to running a CQL `DELETE` on that column.
It is possible to do a blind update, that is, to update a field without having first selected the object out of the database.
See :ref:`Blind Updates `
"""
for column_id, v in values.items():
col = self._columns.get(column_id)
# check for nonexistent columns
if col is None:
raise ValidationError(
"{0}.{1} has no column named: {2}".format(
self.__module__, self.__class__.__name__, column_id))
# check for primary key update attempts
if col.is_primary_key:
current_value = getattr(self, column_id)
if v != current_value:
raise ValidationError(
"Cannot apply update to primary key '{0}' for {1}.{2}".format(
column_id, self.__module__, self.__class__.__name__))
setattr(self, column_id, v)
# handle polymorphic models
if self._is_polymorphic:
if self._is_polymorphic_base:
raise PolymorphicModelException('cannot update polymorphic base model')
else:
setattr(self, self._discriminator_column_name, self.__discriminator_value__)
self.validate()
self.__dmlquery__(self.__class__, self,
batch=self._batch,
ttl=self._ttl,
timestamp=self._timestamp,
consistency=self.__consistency__,
conditional=self._conditional,
timeout=self._timeout,
if_exists=self._if_exists).update()
self._set_persisted()
self._timestamp = None
return self
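# Illustrative sketch of a blind update (model, primary key and values are
# hypothetical): instantiate with only the primary key, then update the columns
# to change without selecting the row first.
#
#   Person(id=some_uuid).update(last_name='Eggleston')
#
#   # attempting to change a primary key value through update() raises ValidationError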
def delete(self):
"""
Deletes the object from the database
"""
self.__dmlquery__(self.__class__, self,
batch=self._batch,
timestamp=self._timestamp,
consistency=self.__consistency__,
timeout=self._timeout,
conditional=self._conditional,
if_exists=self._if_exists).delete()
def get_changed_columns(self):
"""
Returns a list of the columns that have been updated since instantiation or save
"""
return [k for k, v in self._values.items() if v.changed]
@classmethod
def _class_batch(cls, batch):
return cls.objects.batch(batch)
def _inst_batch(self, batch):
assert self._timeout is connection.NOT_SET, 'Setting both timeout and batch is not supported'
if self._connection:
raise CQLEngineException("Cannot specify a connection on model in batch mode.")
self._batch = batch
return self
batch = hybrid_classmethod(_class_batch, _inst_batch)
@classmethod
def _class_get_connection(cls):
return cls.__connection__
def _inst_get_connection(self):
return self._connection or self.__connection__
_get_connection = hybrid_classmethod(_class_get_connection, _inst_get_connection)
class ModelMetaClass(type):
def __new__(cls, name, bases, attrs):
# move column definitions into columns dict
# and set default column names
column_dict = OrderedDict()
primary_keys = OrderedDict()
pk_name = None
# get inherited properties
inherited_columns = OrderedDict()
for base in bases:
for k, v in getattr(base, '_defined_columns', {}).items():
inherited_columns.setdefault(k, v)
# short circuit __abstract__ inheritance
is_abstract = attrs['__abstract__'] = attrs.get('__abstract__', False)
# short circuit __discriminator_value__ inheritance
attrs['__discriminator_value__'] = attrs.get('__discriminator_value__')
# TODO __default__ttl__ should be removed in the next major release
options = attrs.get('__options__') or {}
attrs['__default_ttl__'] = options.get('default_time_to_live')
column_definitions = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)]
column_definitions = sorted(column_definitions, key=lambda x: x[1].position)
is_polymorphic_base = any([c[1].discriminator_column for c in column_definitions])
column_definitions = [x for x in inherited_columns.items()] + column_definitions
discriminator_columns = [c for c in column_definitions if c[1].discriminator_column]
is_polymorphic = len(discriminator_columns) > 0
if len(discriminator_columns) > 1:
raise ModelDefinitionException('only one discriminator_column can be defined in a model, {0} found'.format(len(discriminator_columns)))
if attrs['__discriminator_value__'] and not is_polymorphic:
raise ModelDefinitionException('__discriminator_value__ specified, but no base columns defined with discriminator_column=True')
discriminator_column_name, discriminator_column = discriminator_columns[0] if discriminator_columns else (None, None)
if isinstance(discriminator_column, (columns.BaseContainerColumn, columns.Counter)):
raise ModelDefinitionException('counter and container columns cannot be used as discriminator columns')
# find polymorphic base class
polymorphic_base = None
if is_polymorphic and not is_polymorphic_base:
def _get_polymorphic_base(bases):
for base in bases:
if getattr(base, '_is_polymorphic_base', False):
return base
klass = _get_polymorphic_base(base.__bases__)
if klass:
return klass
polymorphic_base = _get_polymorphic_base(bases)
defined_columns = OrderedDict(column_definitions)
# check for primary key
if not is_abstract and not any([v.primary_key for k, v in column_definitions]):
raise ModelDefinitionException("At least 1 primary key is required.")
counter_columns = [c for c in defined_columns.values() if isinstance(c, columns.Counter)]
data_columns = [c for c in defined_columns.values() if not c.primary_key and not isinstance(c, columns.Counter)]
if counter_columns and data_columns:
raise ModelDefinitionException('counter models may not have data columns')
has_partition_keys = any(v.partition_key for (k, v) in column_definitions)
def _transform_column(col_name, col_obj):
column_dict[col_name] = col_obj
if col_obj.primary_key:
primary_keys[col_name] = col_obj
col_obj.set_column_name(col_name)
# set properties
attrs[col_name] = ColumnDescriptor(col_obj)
partition_key_index = 0
# transform column definitions
for k, v in column_definitions:
# don't allow a column with the same name as a built-in attribute or method
if k in BaseModel.__dict__:
raise ModelDefinitionException("column '{0}' conflicts with built-in attribute/method".format(k))
# counter column primary keys are not allowed
if (v.primary_key or v.partition_key) and isinstance(v, columns.Counter):
raise ModelDefinitionException('counter columns cannot be used as primary keys')
# this will mark the first primary key column as a partition
# key, if one hasn't been set already
if not has_partition_keys and v.primary_key:
v.partition_key = True
has_partition_keys = True
if v.partition_key:
v._partition_key_index = partition_key_index
partition_key_index += 1
overriding = column_dict.get(k)
if overriding:
v.position = overriding.position
v.partition_key = overriding.partition_key
v._partition_key_index = overriding._partition_key_index
_transform_column(k, v)
partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key)
clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key)
if attrs.get('__compute_routing_key__', True):
key_cols = [c for c in partition_keys.values()]
partition_key_index = dict((col.db_field_name, col._partition_key_index) for col in key_cols)
key_cql_types = [c.cql_type for c in key_cols]
key_serializer = staticmethod(lambda parts, proto_version: [t.to_binary(p, proto_version) for t, p in zip(key_cql_types, parts)])
else:
partition_key_index = {}
key_serializer = staticmethod(lambda parts, proto_version: None)
# setup partition key shortcut
if len(partition_keys) == 0:
if not is_abstract:
raise ModelException("at least one partition key must be defined")
if len(partition_keys) == 1:
pk_name = [x for x in partition_keys.keys()][0]
attrs['pk'] = attrs[pk_name]
else:
# composite partition key case, get/set a tuple of values
_get = lambda self: tuple(self._values[c].getval() for c in partition_keys.keys())
_set = lambda self, val: tuple(self._values[c].setval(v) for (c, v) in zip(partition_keys.keys(), val))
attrs['pk'] = property(_get, _set)
# some validation
col_names = set()
for v in column_dict.values():
# check for duplicate column names
if v.db_field_name in col_names:
raise ModelException("{0} defines the column '{1}' more than once".format(name, v.db_field_name))
if v.clustering_order and not (v.primary_key and not v.partition_key):
raise ModelException("clustering_order may be specified only for clustering primary keys")
if v.clustering_order and v.clustering_order.lower() not in ('asc', 'desc'):
raise ModelException("invalid clustering order '{0}' for column '{1}'".format(repr(v.clustering_order), v.db_field_name))
col_names.add(v.db_field_name)
# create db_name -> model name map for loading
db_map = {}
for col_name, field in column_dict.items():
db_field = field.db_field_name
if db_field != col_name:
db_map[db_field] = col_name
# add management members to the class
attrs['_columns'] = column_dict
attrs['_primary_keys'] = primary_keys
attrs['_defined_columns'] = defined_columns
# maps the database field to the models key
attrs['_db_map'] = db_map
attrs['_pk_name'] = pk_name
attrs['_dynamic_columns'] = {}
attrs['_partition_keys'] = partition_keys
attrs['_partition_key_index'] = partition_key_index
attrs['_key_serializer'] = key_serializer
attrs['_clustering_keys'] = clustering_keys
attrs['_has_counter'] = len(counter_columns) > 0
# add polymorphic management attributes
attrs['_is_polymorphic_base'] = is_polymorphic_base
attrs['_is_polymorphic'] = is_polymorphic
attrs['_polymorphic_base'] = polymorphic_base
attrs['_discriminator_column'] = discriminator_column
attrs['_discriminator_column_name'] = discriminator_column_name
attrs['_discriminator_map'] = {} if is_polymorphic_base else None
# setup class exceptions
DoesNotExistBase = None
for base in bases:
DoesNotExistBase = getattr(base, 'DoesNotExist', None)
if DoesNotExistBase is not None:
break
DoesNotExistBase = DoesNotExistBase or attrs.pop('DoesNotExist', BaseModel.DoesNotExist)
attrs['DoesNotExist'] = type('DoesNotExist', (DoesNotExistBase,), {})
MultipleObjectsReturnedBase = None
for base in bases:
MultipleObjectsReturnedBase = getattr(base, 'MultipleObjectsReturned', None)
if MultipleObjectsReturnedBase is not None:
break
MultipleObjectsReturnedBase = MultipleObjectsReturnedBase or attrs.pop('MultipleObjectsReturned', BaseModel.MultipleObjectsReturned)
attrs['MultipleObjectsReturned'] = type('MultipleObjectsReturned', (MultipleObjectsReturnedBase,), {})
# create the class and add a QuerySet to it
klass = super(ModelMetaClass, cls).__new__(cls, name, bases, attrs)
udts = []
for col in column_dict.values():
columns.resolve_udts(col, udts)
for user_type in set(udts):
user_type.register_for_keyspace(klass._get_keyspace())
return klass
@six.add_metaclass(ModelMetaClass)
class Model(BaseModel):
__abstract__ = True
"""
*Optional.* Indicates that this model is only intended to be used as a base class for other models.
You can't create tables for abstract models, but checks around schema validity are skipped during class construction.
"""
__table_name__ = None
"""
*Optional.* Sets the name of the CQL table for this model. If left blank, the table name will be the name of the model, with its module name as its prefix. Manually defined table names are not inherited.
"""
__table_name_case_sensitive__ = False
"""
*Optional.* By default, __table_name__ is case-insensitive. Set this to True to preserve the case of the table name.
"""
__keyspace__ = None
"""
Sets the name of the keyspace used by this model.
"""
__connection__ = None
"""
Sets the name of the default connection used by this model.
"""
__options__ = None
"""
*Optional* Table options applied with this model
(e.g. compaction, default ttl, cache settings, etc.)
"""
__discriminator_value__ = None
"""
*Optional* Specifies a value for the discriminator column when using model inheritance.
"""
__compute_routing_key__ = True
"""
*Optional* Setting this to False disables computing the routing key used for token-aware routing
"""
diff --git a/cassandra/cqlengine/usertype.py b/cassandra/cqlengine/usertype.py
index adf3f5e..155068d 100644
--- a/cassandra/cqlengine/usertype.py
+++ b/cassandra/cqlengine/usertype.py
@@ -1,215 +1,229 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import re
import six
from cassandra.util import OrderedDict
from cassandra.cqlengine import CQLEngineException
from cassandra.cqlengine import columns
from cassandra.cqlengine import connection as conn
from cassandra.cqlengine import models
class UserTypeException(CQLEngineException):
pass
class UserTypeDefinitionException(UserTypeException):
pass
class BaseUserType(object):
"""
The base type class; don't inherit from this, inherit from UserType, defined below
"""
__type_name__ = None
_fields = None
_db_map = None
def __init__(self, **values):
self._values = {}
if self._db_map:
values = dict((self._db_map.get(k, k), v) for k, v in values.items())
for name, field in self._fields.items():
field_default = field.get_default() if field.has_default else None
value = values.get(name, field_default)
if value is not None or isinstance(field, columns.BaseContainerColumn):
value = field.to_python(value)
value_mngr = field.value_manager(self, field, value)
value_mngr.explicit = name in values
self._values[name] = value_mngr
def __eq__(self, other):
if self.__class__ != other.__class__:
return False
keys = set(self._fields.keys())
other_keys = set(other._fields.keys())
if keys != other_keys:
return False
for key in other_keys:
if getattr(self, key, None) != getattr(other, key, None):
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self):
return "{{{0}}}".format(', '.join("'{0}': {1}".format(k, getattr(self, k)) for k, v in six.iteritems(self._values)))
def has_changed_fields(self):
return any(v.changed for v in self._values.values())
def reset_changed_fields(self):
for v in self._values.values():
v.reset_previous_value()
def __iter__(self):
for field in self._fields.keys():
yield field
def __getattr__(self, attr):
# provides the mapping from db_field to fields
try:
return getattr(self, self._db_map[attr])
except KeyError:
raise AttributeError(attr)
def __getitem__(self, key):
if not isinstance(key, six.string_types):
raise TypeError
if key not in self._fields.keys():
raise KeyError
return getattr(self, key)
def __setitem__(self, key, val):
if not isinstance(key, six.string_types):
raise TypeError
if key not in self._fields.keys():
raise KeyError
return setattr(self, key, val)
def __len__(self):
try:
return self._len
except:
self._len = len(self._fields.keys())
return self._len
def keys(self):
""" Returns a list of column IDs. """
return [k for k in self]
def values(self):
""" Returns list of column values. """
return [self[k] for k in self]
def items(self):
""" Returns a list of column ID/value tuples. """
return [(k, self[k]) for k in self]
@classmethod
def register_for_keyspace(cls, keyspace, connection=None):
conn.register_udt(keyspace, cls.type_name(), cls, connection=connection)
@classmethod
def type_name(cls):
"""
Returns the type name if it's been defined
otherwise, it creates it from the class name
"""
if cls.__type_name__:
type_name = cls.__type_name__.lower()
else:
camelcase = re.compile(r'([a-z])([A-Z])')
ccase = lambda s: camelcase.sub(lambda v: '{0}_{1}'.format(v.group(1), v.group(2)), s)
type_name = ccase(cls.__name__)
# trim to less than 48 characters or cassandra will complain
type_name = type_name[-48:]
type_name = type_name.lower()
type_name = re.sub(r'^_+', '', type_name)
cls.__type_name__ = type_name
return type_name
def validate(self):
"""
Cleans and validates the field values
"""
for name, field in self._fields.items():
v = getattr(self, name)
if v is None and not self._values[name].explicit and field.has_default:
v = field.get_default()
val = field.validate(v)
setattr(self, name, val)
class UserTypeMetaClass(type):
def __new__(cls, name, bases, attrs):
field_dict = OrderedDict()
field_defs = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)]
field_defs = sorted(field_defs, key=lambda x: x[1].position)
def _transform_column(field_name, field_obj):
field_dict[field_name] = field_obj
field_obj.set_column_name(field_name)
attrs[field_name] = models.ColumnDescriptor(field_obj)
# transform field definitions
for k, v in field_defs:
# don't allow a field with the same name as a built-in attribute or method
if k in BaseUserType.__dict__:
raise UserTypeDefinitionException("field '{0}' conflicts with built-in attribute/method".format(k))
_transform_column(k, v)
attrs['_fields'] = field_dict
db_map = {}
for field_name, field in field_dict.items():
db_field = field.db_field_name
if db_field != field_name:
if db_field in field_dict:
raise UserTypeDefinitionException("db_field '{0}' for field '{1}' conflicts with another attribute name".format(db_field, field_name))
db_map[db_field] = field_name
attrs['_db_map'] = db_map
klass = super(UserTypeMetaClass, cls).__new__(cls, name, bases, attrs)
return klass
@six.add_metaclass(UserTypeMetaClass)
class UserType(BaseUserType):
"""
This class is used to model User Defined Types. To define a type, declare a class inheriting from this,
and assign field types as class attributes:
.. code-block:: python
# connect with default keyspace ...
from cassandra.cqlengine.columns import Text, Integer
from cassandra.cqlengine.usertype import UserType
class address(UserType):
street = Text()
zipcode = Integer()
from cassandra.cqlengine import management
management.sync_type(address)
Please see :ref:`user_types` for a complete example and discussion.
"""
__type_name__ = None
"""
*Optional.* Sets the name of the CQL type for this type.
If not specified, the type name will be the name of the class, with its module name as its prefix.
"""
diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py
index 55bb022..7946a63 100644
--- a/cassandra/cqltypes.py
+++ b/cassandra/cqltypes.py
@@ -1,1147 +1,1423 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Representation of Cassandra data types. These classes should make it simple for
the library (and caller software) to deal with Cassandra-style Java class type
names and CQL type specifiers, and convert between them cleanly. Parameterized
types are fully supported in both flavors. Once you have the right Type object
for the type you want, you can use it to serialize, deserialize, or retrieve
the corresponding CQL or Cassandra type strings.
"""
# NOTE:
# If/when the need arises to interpret types from CQL string literals in
# different ways (for https://issues.apache.org/jira/browse/CASSANDRA-3799,
# for example), these classes would be a good place to tack on
# .from_cql_literal() and .as_cql_literal() classmethods (or whatever).
from __future__ import absolute_import # to enable import io from stdlib
import ast
from binascii import unhexlify
import calendar
from collections import namedtuple
from decimal import Decimal
import io
+from itertools import chain
import logging
import re
import socket
import time
import six
from six.moves import range
+import struct
import sys
from uuid import UUID
-import warnings
-
-if six.PY3:
- import ipaddress
from cassandra.marshal import (int8_pack, int8_unpack, int16_pack, int16_unpack,
uint16_pack, uint16_unpack, uint32_pack, uint32_unpack,
int32_pack, int32_unpack, int64_pack, int64_unpack,
float_pack, float_unpack, double_pack, double_unpack,
- varint_pack, varint_unpack, vints_pack, vints_unpack)
+ varint_pack, varint_unpack, point_be, point_le,
+ vints_pack, vints_unpack)
from cassandra import util
+_little_endian_flag = 1 # we always serialize LE
+if six.PY3:
+ import ipaddress
+
+_ord = ord if six.PY2 else lambda x: x
+
apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
cassandra_empty_type = 'org.apache.cassandra.db.marshal.EmptyType'
cql_empty_type = 'empty'
log = logging.getLogger(__name__)
if six.PY3:
_number_types = frozenset((int, float))
long = int
def _name_from_hex_string(encoded_name):
bin_str = unhexlify(encoded_name)
return bin_str.decode('ascii')
else:
_number_types = frozenset((int, long, float))
_name_from_hex_string = unhexlify
def trim_if_startswith(s, prefix):
if s.startswith(prefix):
return s[len(prefix):]
return s
_casstypes = {}
_cqltypes = {}
cql_type_scanner = re.Scanner((
('frozen', None),
(r'[a-zA-Z0-9_]+', lambda s, t: t),
(r'[\s,<>]', None),
))
def cql_types_from_string(cql_type):
return cql_type_scanner.scan(cql_type)[0]
class CassandraTypeType(type):
"""
The CassandraType objects in this module will normally be used directly,
rather than through instances of those types. They can be instantiated,
of course, but the type information is what this driver mainly needs.
This metaclass registers CassandraType classes in the global
by-cassandra-typename and by-cql-typename registries, unless their class
name starts with an underscore.
"""
def __new__(metacls, name, bases, dct):
dct.setdefault('cassname', name)
cls = type.__new__(metacls, name, bases, dct)
if not name.startswith('_'):
_casstypes[name] = cls
if not cls.typename.startswith(apache_cassandra_type_prefix):
_cqltypes[cls.typename] = cls
return cls
casstype_scanner = re.Scanner((
(r'[()]', lambda s, t: t),
(r'[a-zA-Z0-9_.:=>]+', lambda s, t: t),
(r'[\s,]', None),
))
def cqltype_to_python(cql_string):
"""
Given a cql type string, creates a list that can be manipulated in python
Example:
int -> ['int']
frozen<tuple<text, int>> -> ['frozen', ['tuple', ['text', 'int']]]
"""
scanner = re.Scanner((
(r'[a-zA-Z0-9_]+', lambda s, t: "'{}'".format(t)),
(r'<', lambda s, t: ', ['),
(r'>', lambda s, t: ']'),
(r'[, ]', lambda s, t: t),
(r'".*?"', lambda s, t: "'{}'".format(t)),
))
scanned_tokens = scanner.scan(cql_string)[0]
hierarchy = ast.literal_eval(''.join(scanned_tokens))
return [hierarchy] if isinstance(hierarchy, str) else list(hierarchy)
def python_to_cqltype(types):
"""
Opposite of the `cqltype_to_python` function. Given a python list, creates a cql type string from the representation
Example:
['int'] -> int
['frozen', ['tuple', ['text', 'int']]] -> frozen<tuple<text, int>>
"""
scanner = re.Scanner((
(r"'[a-zA-Z0-9_]+'", lambda s, t: t[1:-1]),
(r'^\[', lambda s, t: None),
(r'\]$', lambda s, t: None),
(r',\s*\[', lambda s, t: '<'),
(r'\]', lambda s, t: '>'),
(r'[, ]', lambda s, t: t),
(r'\'".*?"\'', lambda s, t: t[1:-1]),
))
scanned_tokens = scanner.scan(repr(types))[0]
cql = ''.join(scanned_tokens).replace('\\\\', '\\')
return cql
def _strip_frozen_from_python(types):
"""
Given a python list representing a cql type, removes 'frozen'
Example:
['frozen', ['tuple', ['text', 'int']]] -> ['tuple', ['text', 'int']]
"""
while 'frozen' in types:
index = types.index('frozen')
types = types[:index] + types[index + 1] + types[index + 2:]
new_types = [_strip_frozen_from_python(item) if isinstance(item, list) else item for item in types]
return new_types
def strip_frozen(cql):
"""
Given a cql type string, removes the frozen wrappers
Example:
frozen<tuple<text, int>> -> tuple<text, int>
"""
types = cqltype_to_python(cql)
types_without_frozen = _strip_frozen_from_python(types)
cql = python_to_cqltype(types_without_frozen)
return cql
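# A doctest-style sketch of the three helpers above:
#
#     >>> cqltype_to_python('map<text, frozen<list<int>>>')
#     ['map', ['text', 'frozen', ['list', ['int']]]]
#     >>> python_to_cqltype(['map', ['text', 'list', ['int']]])
#     'map<text, list<int>>'
#     >>> strip_frozen('map<text, frozen<list<int>>>')
#     'map<text, list<int>>'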
def lookup_casstype_simple(casstype):
"""
Given a Cassandra type name (either fully distinguished or not), hand
back the CassandraType class responsible for it. If a name is not
recognized, a custom _UnrecognizedType subclass will be created for it.
This function does not handle complex types (so no type parameters--
nothing with parentheses). Use lookup_casstype() instead if you might need
that.
"""
shortname = trim_if_startswith(casstype, apache_cassandra_type_prefix)
try:
typeclass = _casstypes[shortname]
except KeyError:
typeclass = mkUnrecognizedType(casstype)
return typeclass
def parse_casstype_args(typestring):
tokens, remainder = casstype_scanner.scan(typestring)
if remainder:
raise ValueError("weird characters %r at end" % remainder)
# use a stack of (types, names) lists
args = [([], [])]
for tok in tokens:
if tok == '(':
args.append(([], []))
elif tok == ')':
types, names = args.pop()
prev_types, prev_names = args[-1]
prev_types[-1] = prev_types[-1].apply_parameters(types, names)
else:
types, names = args[-1]
parts = re.split(':|=>', tok)
tok = parts.pop()
if parts:
names.append(parts[0])
else:
names.append(None)
ctype = lookup_casstype_simple(tok)
types.append(ctype)
# return the first (outer) type, which will have all parameters applied
return args[0][0][0]
def lookup_casstype(casstype):
"""
Given a Cassandra type as a string (possibly including parameters), hand
back the CassandraType class responsible for it. If a name is not
recognized, a custom _UnrecognizedType subclass will be created for it.
Example:
>>> lookup_casstype('org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.Int32Type)')
"""
if isinstance(casstype, (CassandraType, CassandraTypeType)):
return casstype
try:
return parse_casstype_args(casstype)
except (ValueError, AssertionError, IndexError) as e:
raise ValueError("Don't know how to parse type string %r: %s" % (casstype, e))
def is_reversed_casstype(data_type):
return issubclass(data_type, ReversedType)
class EmptyValue(object):
""" See _CassandraType.support_empty_values """
def __str__(self):
return "EMPTY"
__repr__ = __str__
EMPTY = EmptyValue()
@six.add_metaclass(CassandraTypeType)
class _CassandraType(object):
subtypes = ()
num_subtypes = 0
empty_binary_ok = False
support_empty_values = False
"""
Back in the Thrift days, empty strings were used for "null" values of
all types, including non-string types. For most users, an empty
string value in an int column is the same as being null/not present,
so the driver normally returns None in this case. (For string-like
types, it *will* return an empty string by default instead of None.)
To avoid this behavior, set this to :const:`True`. Instead of returning
None for empty string values, the EMPTY singleton (an instance
of EmptyValue) will be returned.
"""
def __repr__(self):
return '<%s( %r )>' % (self.cql_parameterized_type(), self.val)
@classmethod
def from_binary(cls, byts, protocol_version):
"""
Deserialize a bytestring into a value. See the deserialize() method
for more information. This method differs in that if None or the empty
string is passed in, None may be returned.
"""
if byts is None:
return None
elif len(byts) == 0 and not cls.empty_binary_ok:
return EMPTY if cls.support_empty_values else None
return cls.deserialize(byts, protocol_version)
@classmethod
def to_binary(cls, val, protocol_version):
"""
Serialize a value into a bytestring. See the serialize() method for
more information. This method differs in that if None is passed in,
the result is the empty string.
"""
return b'' if val is None else cls.serialize(val, protocol_version)
@staticmethod
def deserialize(byts, protocol_version):
"""
Given a bytestring, deserialize into a value according to the protocol
for this type. Note that this does not create a new instance of this
class; it merely gives back a value that would be appropriate to go
inside an instance of this class.
"""
return byts
@staticmethod
def serialize(val, protocol_version):
"""
Given a value appropriate for this class, serialize it according to the
protocol for this type and return the corresponding bytestring.
"""
return val
@classmethod
def cass_parameterized_type_with(cls, subtypes, full=False):
"""
Return the name of this type as it would be expressed by Cassandra,
optionally fully qualified. If subtypes is not None, it is expected
to be a list of other CassandraType subclasses, and the output
string includes the Cassandra names for those subclasses as well,
as parameters to this one.
Example:
>>> LongType.cass_parameterized_type_with(())
'LongType'
>>> LongType.cass_parameterized_type_with((), full=True)
'org.apache.cassandra.db.marshal.LongType'
>>> SetType.cass_parameterized_type_with([DecimalType], full=True)
'org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.DecimalType)'
"""
cname = cls.cassname
if full and '.' not in cname:
cname = apache_cassandra_type_prefix + cname
if not subtypes:
return cname
sublist = ', '.join(styp.cass_parameterized_type(full=full) for styp in subtypes)
return '%s(%s)' % (cname, sublist)
@classmethod
def apply_parameters(cls, subtypes, names=None):
"""
Given a set of other CassandraTypes, create a new subtype of this type
using them as parameters. This is how composite types are constructed.
>>> MapType.apply_parameters([DateType, BooleanType])
`subtypes` will be a sequence of CassandraTypes. If provided, `names`
will be an equally long sequence of column names or Nones.
"""
if cls.num_subtypes != 'UNKNOWN' and len(subtypes) != cls.num_subtypes:
raise ValueError("%s types require %d subtypes (%d given)"
% (cls.typename, cls.num_subtypes, len(subtypes)))
newname = cls.cass_parameterized_type_with(subtypes)
if six.PY2 and isinstance(newname, unicode):
newname = newname.encode('utf-8')
return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname, 'fieldnames': names})
@classmethod
def cql_parameterized_type(cls):
"""
Return a CQL type specifier for this type. If this type has parameters,
they are included in standard CQL <> notation.
"""
if not cls.subtypes:
return cls.typename
return '%s<%s>' % (cls.typename, ', '.join(styp.cql_parameterized_type() for styp in cls.subtypes))
@classmethod
def cass_parameterized_type(cls, full=False):
"""
Return a Cassandra type specifier for this type. If this type has
parameters, they are included in the standard () notation.
"""
return cls.cass_parameterized_type_with(cls.subtypes, full=full)
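# A short sketch tying the classmethods above together (MapType, UTF8Type and
# Int32Type are defined later in this module):
#
#     >>> m = MapType.apply_parameters([UTF8Type, Int32Type])
#     >>> m.cql_parameterized_type()
#     'map<text, int>'
#     >>> m.cass_parameterized_type()
#     'MapType(UTF8Type, Int32Type)'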
# it's initially named with a _ to avoid registering it as a real type, but
# client programs may want to use the name still for isinstance(), etc
CassandraType = _CassandraType
class _UnrecognizedType(_CassandraType):
num_subtypes = 'UNKNOWN'
if six.PY3:
def mkUnrecognizedType(casstypename):
return CassandraTypeType(casstypename,
(_UnrecognizedType,),
{'typename': "'%s'" % casstypename})
else:
def mkUnrecognizedType(casstypename): # noqa
return CassandraTypeType(casstypename.encode('utf8'),
(_UnrecognizedType,),
{'typename': "'%s'" % casstypename})
class BytesType(_CassandraType):
typename = 'blob'
empty_binary_ok = True
@staticmethod
def serialize(val, protocol_version):
return six.binary_type(val)
class DecimalType(_CassandraType):
typename = 'decimal'
@staticmethod
def deserialize(byts, protocol_version):
scale = int32_unpack(byts[:4])
unscaled = varint_unpack(byts[4:])
return Decimal('%de%d' % (unscaled, -scale))
@staticmethod
def serialize(dec, protocol_version):
try:
sign, digits, exponent = dec.as_tuple()
except AttributeError:
try:
sign, digits, exponent = Decimal(dec).as_tuple()
except Exception:
raise TypeError("Invalid type for Decimal value: %r", dec)
unscaled = int(''.join([str(digit) for digit in digits]))
if sign:
unscaled *= -1
scale = int32_pack(-exponent)
unscaled = varint_pack(unscaled)
return scale + unscaled
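# The wire format above is a 4-byte big-endian scale followed by the varint
# unscaled value; e.g. Decimal('12.34') -> scale 2, unscaled 1234. A round-trip sketch:
#
#     >>> DecimalType.deserialize(DecimalType.serialize(Decimal('12.34'), 4), 4)
#     Decimal('12.34')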
class UUIDType(_CassandraType):
typename = 'uuid'
@staticmethod
def deserialize(byts, protocol_version):
return UUID(bytes=byts)
@staticmethod
def serialize(uuid, protocol_version):
try:
return uuid.bytes
except AttributeError:
raise TypeError("Got a non-UUID object for a UUID value")
class BooleanType(_CassandraType):
typename = 'boolean'
@staticmethod
def deserialize(byts, protocol_version):
return bool(int8_unpack(byts))
@staticmethod
def serialize(truth, protocol_version):
return int8_pack(truth)
class ByteType(_CassandraType):
typename = 'tinyint'
@staticmethod
def deserialize(byts, protocol_version):
return int8_unpack(byts)
@staticmethod
def serialize(byts, protocol_version):
return int8_pack(byts)
if six.PY2:
class AsciiType(_CassandraType):
typename = 'ascii'
empty_binary_ok = True
else:
class AsciiType(_CassandraType):
typename = 'ascii'
empty_binary_ok = True
@staticmethod
def deserialize(byts, protocol_version):
return byts.decode('ascii')
@staticmethod
def serialize(var, protocol_version):
try:
return var.encode('ascii')
except UnicodeDecodeError:
return var
class FloatType(_CassandraType):
typename = 'float'
@staticmethod
def deserialize(byts, protocol_version):
return float_unpack(byts)
@staticmethod
def serialize(byts, protocol_version):
return float_pack(byts)
class DoubleType(_CassandraType):
typename = 'double'
@staticmethod
def deserialize(byts, protocol_version):
return double_unpack(byts)
@staticmethod
def serialize(byts, protocol_version):
return double_pack(byts)
class LongType(_CassandraType):
typename = 'bigint'
@staticmethod
def deserialize(byts, protocol_version):
return int64_unpack(byts)
@staticmethod
def serialize(byts, protocol_version):
return int64_pack(byts)
class Int32Type(_CassandraType):
typename = 'int'
@staticmethod
def deserialize(byts, protocol_version):
return int32_unpack(byts)
@staticmethod
def serialize(byts, protocol_version):
return int32_pack(byts)
class IntegerType(_CassandraType):
typename = 'varint'
@staticmethod
def deserialize(byts, protocol_version):
return varint_unpack(byts)
@staticmethod
def serialize(byts, protocol_version):
return varint_pack(byts)
class InetAddressType(_CassandraType):
typename = 'inet'
@staticmethod
def deserialize(byts, protocol_version):
if len(byts) == 16:
return util.inet_ntop(socket.AF_INET6, byts)
else:
# util.inet_pton could also handle, but this is faster
# since we've already determined the AF
return socket.inet_ntoa(byts)
@staticmethod
def serialize(addr, protocol_version):
try:
if ':' in addr:
return util.inet_pton(socket.AF_INET6, addr)
else:
# util.inet_pton could also handle, but this is faster
# since we've already determined the AF
return socket.inet_aton(addr)
except:
if six.PY3 and isinstance(addr, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
return addr.packed
raise ValueError("can't interpret %r as an inet address" % (addr,))
class CounterColumnType(LongType):
typename = 'counter'
cql_timestamp_formats = (
'%Y-%m-%d %H:%M',
'%Y-%m-%d %H:%M:%S',
'%Y-%m-%dT%H:%M',
'%Y-%m-%dT%H:%M:%S',
'%Y-%m-%d'
)
_have_warned_about_timestamps = False
class DateType(_CassandraType):
typename = 'timestamp'
@staticmethod
def interpret_datestring(val):
if val[-5] in ('+', '-'):
offset = (int(val[-4:-2]) * 3600 + int(val[-2:]) * 60) * int(val[-5] + '1')
val = val[:-5]
else:
offset = -time.timezone
for tformat in cql_timestamp_formats:
try:
tval = time.strptime(val, tformat)
except ValueError:
continue
# scale seconds to millis for the raw value
return (calendar.timegm(tval) + offset) * 1e3
else:
raise ValueError("can't interpret %r as a date" % (val,))
@staticmethod
def deserialize(byts, protocol_version):
timestamp = int64_unpack(byts) / 1000.0
return util.datetime_from_timestamp(timestamp)
@staticmethod
def serialize(v, protocol_version):
try:
# v is datetime
timestamp_seconds = calendar.timegm(v.utctimetuple())
timestamp = timestamp_seconds * 1e3 + getattr(v, 'microsecond', 0) / 1e3
except AttributeError:
try:
timestamp = calendar.timegm(v.timetuple()) * 1e3
except AttributeError:
# Ints and floats are valid timestamps too
if type(v) not in _number_types:
raise TypeError('DateType arguments must be a datetime, date, or timestamp')
timestamp = v
return int64_pack(long(timestamp))
class TimestampType(DateType):
pass
class TimeUUIDType(DateType):
typename = 'timeuuid'
def my_timestamp(self):
return util.unix_time_from_uuid1(self.val)
@staticmethod
def deserialize(byts, protocol_version):
return UUID(bytes=byts)
@staticmethod
def serialize(timeuuid, protocol_version):
try:
return timeuuid.bytes
except AttributeError:
raise TypeError("Got a non-UUID object for a UUID value")
class SimpleDateType(_CassandraType):
typename = 'date'
date_format = "%Y-%m-%d"
# Values of the 'date' type are encoded as 32-bit unsigned integers
# representing a number of days with epoch (January 1st, 1970) at the center of the
# range (2^31).
EPOCH_OFFSET_DAYS = 2 ** 31
@staticmethod
def deserialize(byts, protocol_version):
days = uint32_unpack(byts) - SimpleDateType.EPOCH_OFFSET_DAYS
return util.Date(days)
@staticmethod
def serialize(val, protocol_version):
try:
days = val.days_from_epoch
except AttributeError:
if isinstance(val, six.integer_types):
# the DB wants offset int values, but util.Date init takes days from epoch
# here we assume int values are offset, as they would appear in CQL
# short circuit to avoid subtracting just to add offset
return uint32_pack(val)
days = util.Date(val).days_from_epoch
return uint32_pack(days + SimpleDateType.EPOCH_OFFSET_DAYS)
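# For example, 1970-01-02 is one day after the epoch, so it serializes to the
# unsigned value 2 ** 31 + 1 (a sketch, using the packing helpers imported above):
#
#     >>> import datetime
#     >>> SimpleDateType.serialize(datetime.date(1970, 1, 2), 4) == uint32_pack(2 ** 31 + 1)
#     True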
class ShortType(_CassandraType):
typename = 'smallint'
@staticmethod
def deserialize(byts, protocol_version):
return int16_unpack(byts)
@staticmethod
def serialize(byts, protocol_version):
return int16_pack(byts)
class TimeType(_CassandraType):
typename = 'time'
@staticmethod
def deserialize(byts, protocol_version):
return util.Time(int64_unpack(byts))
@staticmethod
def serialize(val, protocol_version):
try:
nano = val.nanosecond_time
except AttributeError:
nano = util.Time(val).nanosecond_time
return int64_pack(nano)
class DurationType(_CassandraType):
typename = 'duration'
@staticmethod
def deserialize(byts, protocol_version):
months, days, nanoseconds = vints_unpack(byts)
return util.Duration(months, days, nanoseconds)
@staticmethod
def serialize(duration, protocol_version):
try:
m, d, n = duration.months, duration.days, duration.nanoseconds
except AttributeError:
raise TypeError('DurationType arguments must be a Duration.')
return vints_pack([m, d, n])
class UTF8Type(_CassandraType):
typename = 'text'
empty_binary_ok = True
@staticmethod
def deserialize(byts, protocol_version):
return byts.decode('utf8')
@staticmethod
def serialize(ustr, protocol_version):
try:
return ustr.encode('utf-8')
except UnicodeDecodeError:
# already utf-8
return ustr
class VarcharType(UTF8Type):
typename = 'varchar'
class _ParameterizedType(_CassandraType):
num_subtypes = 'UNKNOWN'
@classmethod
def deserialize(cls, byts, protocol_version):
if not cls.subtypes:
raise NotImplementedError("can't deserialize unparameterized %s"
% cls.typename)
return cls.deserialize_safe(byts, protocol_version)
@classmethod
def serialize(cls, val, protocol_version):
if not cls.subtypes:
raise NotImplementedError("can't serialize unparameterized %s"
% cls.typename)
return cls.serialize_safe(val, protocol_version)
class _SimpleParameterizedType(_ParameterizedType):
@classmethod
def deserialize_safe(cls, byts, protocol_version):
subtype, = cls.subtypes
if protocol_version >= 3:
unpack = int32_unpack
length = 4
else:
unpack = uint16_unpack
length = 2
numelements = unpack(byts[:length])
p = length
result = []
inner_proto = max(3, protocol_version)
for _ in range(numelements):
itemlen = unpack(byts[p:p + length])
p += length
- item = byts[p:p + itemlen]
- p += itemlen
- result.append(subtype.from_binary(item, inner_proto))
+ if itemlen < 0:
+ result.append(None)
+ else:
+ item = byts[p:p + itemlen]
+ p += itemlen
+ result.append(subtype.from_binary(item, inner_proto))
return cls.adapter(result)
@classmethod
def serialize_safe(cls, items, protocol_version):
if isinstance(items, six.string_types):
raise TypeError("Received a string for a type that expects a sequence")
subtype, = cls.subtypes
pack = int32_pack if protocol_version >= 3 else uint16_pack
buf = io.BytesIO()
buf.write(pack(len(items)))
inner_proto = max(3, protocol_version)
for item in items:
itembytes = subtype.to_binary(item, inner_proto)
buf.write(pack(len(itembytes)))
buf.write(itembytes)
return buf.getvalue()
class ListType(_SimpleParameterizedType):
typename = 'list'
num_subtypes = 1
adapter = list
class SetType(_SimpleParameterizedType):
typename = 'set'
num_subtypes = 1
adapter = util.sortedset
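# A round-trip sketch of the collection encoding above (protocol v3+ uses
# 4-byte counts and lengths; negative element lengths deserialize to None):
#
#     >>> IntList = ListType.apply_parameters([Int32Type])
#     >>> IntList.deserialize(IntList.serialize([1, 2, 3], 4), 4)
#     [1, 2, 3]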
class MapType(_ParameterizedType):
typename = 'map'
num_subtypes = 2
@classmethod
def deserialize_safe(cls, byts, protocol_version):
key_type, value_type = cls.subtypes
if protocol_version >= 3:
unpack = int32_unpack
length = 4
else:
unpack = uint16_unpack
length = 2
numelements = unpack(byts[:length])
p = length
themap = util.OrderedMapSerializedKey(key_type, protocol_version)
inner_proto = max(3, protocol_version)
for _ in range(numelements):
key_len = unpack(byts[p:p + length])
p += length
- keybytes = byts[p:p + key_len]
- p += key_len
+ if key_len < 0:
+ keybytes = None
+ key = None
+ else:
+ keybytes = byts[p:p + key_len]
+ p += key_len
+ key = key_type.from_binary(keybytes, inner_proto)
+
val_len = unpack(byts[p:p + length])
p += length
- valbytes = byts[p:p + val_len]
- p += val_len
- key = key_type.from_binary(keybytes, inner_proto)
- val = value_type.from_binary(valbytes, inner_proto)
+ if val_len < 0:
+ val = None
+ else:
+ valbytes = byts[p:p + val_len]
+ p += val_len
+ val = value_type.from_binary(valbytes, inner_proto)
+
themap._insert_unchecked(key, keybytes, val)
return themap
@classmethod
def serialize_safe(cls, themap, protocol_version):
key_type, value_type = cls.subtypes
pack = int32_pack if protocol_version >= 3 else uint16_pack
buf = io.BytesIO()
buf.write(pack(len(themap)))
try:
items = six.iteritems(themap)
except AttributeError:
raise TypeError("Got a non-map object for a map value")
inner_proto = max(3, protocol_version)
for key, val in items:
keybytes = key_type.to_binary(key, inner_proto)
valbytes = value_type.to_binary(val, inner_proto)
buf.write(pack(len(keybytes)))
buf.write(keybytes)
buf.write(pack(len(valbytes)))
buf.write(valbytes)
return buf.getvalue()
class TupleType(_ParameterizedType):
typename = 'tuple'
@classmethod
def deserialize_safe(cls, byts, protocol_version):
proto_version = max(3, protocol_version)
p = 0
values = []
for col_type in cls.subtypes:
if p == len(byts):
break
itemlen = int32_unpack(byts[p:p + 4])
p += 4
if itemlen >= 0:
item = byts[p:p + itemlen]
p += itemlen
else:
item = None
# collections inside UDTs are always encoded with at least the
# version 3 format
values.append(col_type.from_binary(item, proto_version))
if len(values) < len(cls.subtypes):
nones = [None] * (len(cls.subtypes) - len(values))
values = values + nones
return tuple(values)
@classmethod
def serialize_safe(cls, val, protocol_version):
if len(val) > len(cls.subtypes):
raise ValueError("Expected %d items in a tuple, but got %d: %s" %
(len(cls.subtypes), len(val), val))
proto_version = max(3, protocol_version)
buf = io.BytesIO()
for item, subtype in zip(val, cls.subtypes):
if item is not None:
packed_item = subtype.to_binary(item, proto_version)
buf.write(int32_pack(len(packed_item)))
buf.write(packed_item)
else:
buf.write(int32_pack(-1))
return buf.getvalue()
@classmethod
def cql_parameterized_type(cls):
subtypes_string = ', '.join(sub.cql_parameterized_type() for sub in cls.subtypes)
return 'frozen<tuple<%s>>' % (subtypes_string,)
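# A tuple round-trip sketch (None elements are written with length -1 and read
# back as None; output shown for Python 3):
#
#     >>> T = TupleType.apply_parameters([UTF8Type, Int32Type])
#     >>> T.deserialize(T.serialize(('a', None), 4), 4)
#     ('a', None)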
class UserType(TupleType):
typename = "org.apache.cassandra.db.marshal.UserType"
_cache = {}
_module = sys.modules[__name__]
@classmethod
def make_udt_class(cls, keyspace, udt_name, field_names, field_types):
assert len(field_names) == len(field_types)
if six.PY2 and isinstance(udt_name, unicode):
udt_name = udt_name.encode('utf-8')
instance = cls._cache.get((keyspace, udt_name))
if not instance or instance.fieldnames != field_names or instance.subtypes != field_types:
instance = type(udt_name, (cls,), {'subtypes': field_types,
'cassname': cls.cassname,
'typename': udt_name,
'fieldnames': field_names,
'keyspace': keyspace,
'mapped_class': None,
'tuple_type': cls._make_registered_udt_namedtuple(keyspace, udt_name, field_names)})
cls._cache[(keyspace, udt_name)] = instance
return instance
@classmethod
def evict_udt_class(cls, keyspace, udt_name):
if six.PY2 and isinstance(udt_name, unicode):
udt_name = udt_name.encode('utf-8')
try:
del cls._cache[(keyspace, udt_name)]
except KeyError:
pass
@classmethod
def apply_parameters(cls, subtypes, names):
keyspace = subtypes[0].cass_parameterized_type() # when parsed from cassandra type, the keyspace is created as an unrecognized cass type; This gets the name back
udt_name = _name_from_hex_string(subtypes[1].cassname)
field_names = tuple(_name_from_hex_string(encoded_name) for encoded_name in names[2:]) # using tuple here to match what comes into make_udt_class from other sources (for caching equality test)
return cls.make_udt_class(keyspace, udt_name, field_names, tuple(subtypes[2:]))
@classmethod
def cql_parameterized_type(cls):
return "frozen<%s>" % (cls.typename,)
@classmethod
def deserialize_safe(cls, byts, protocol_version):
values = super(UserType, cls).deserialize_safe(byts, protocol_version)
if cls.mapped_class:
return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
elif cls.tuple_type:
return cls.tuple_type(*values)
else:
return tuple(values)
@classmethod
def serialize_safe(cls, val, protocol_version):
proto_version = max(3, protocol_version)
buf = io.BytesIO()
for i, (fieldname, subtype) in enumerate(zip(cls.fieldnames, cls.subtypes)):
# first treat as a tuple, else by custom type
try:
item = val[i]
except TypeError:
item = getattr(val, fieldname)
if item is not None:
packed_item = subtype.to_binary(item, proto_version)
buf.write(int32_pack(len(packed_item)))
buf.write(packed_item)
else:
buf.write(int32_pack(-1))
return buf.getvalue()
@classmethod
def _make_registered_udt_namedtuple(cls, keyspace, name, field_names):
# this is required to make the type resolvable via this module...
# required when unregistered udts are pickled for use as keys in
# util.OrderedMap
t = cls._make_udt_tuple_type(name, field_names)
if t:
qualified_name = "%s_%s" % (keyspace, name)
setattr(cls._module, qualified_name, t)
return t
@classmethod
def _make_udt_tuple_type(cls, name, field_names):
# fall back to positionally named, then unnamed tuples
# for CQL identifiers that aren't valid in Python
try:
t = namedtuple(name, field_names)
except ValueError:
try:
t = namedtuple(name, util._positional_rename_invalid_identifiers(field_names))
log.warning("could not create a namedtuple for '%s' because one or more "
"field names are not valid Python identifiers (%s); "
"returning positionally-named fields" % (name, field_names))
except ValueError:
t = None
log.warning("could not create a namedtuple for '%s' because the name is "
"not a valid Python identifier; will return tuples in "
"its place" % (name,))
return t
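# A sketch of building a UDT marshaller directly (the keyspace, type and field
# names here are hypothetical; normally make_udt_class is driven by schema metadata):
#
#     >>> Addr = UserType.make_udt_class('ks', 'address', ('street', 'zipcode'), (UTF8Type, Int32Type))
#     >>> Addr.deserialize(Addr.serialize(('Main St', 12345), 4), 4)
#     address(street='Main St', zipcode=12345)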
class CompositeType(_ParameterizedType):
typename = "org.apache.cassandra.db.marshal.CompositeType"
@classmethod
def cql_parameterized_type(cls):
"""
There is no CQL notation for Composites, so we override this.
"""
typestring = cls.cass_parameterized_type(full=True)
return "'%s'" % (typestring,)
@classmethod
def deserialize_safe(cls, byts, protocol_version):
result = []
for subtype in cls.subtypes:
if not byts:
# CompositeType can have missing elements at the end
break
element_length = uint16_unpack(byts[:2])
element = byts[2:2 + element_length]
# skip element length, element, and the EOC (one byte)
byts = byts[2 + element_length + 1:]
result.append(subtype.from_binary(element, protocol_version))
return tuple(result)
class DynamicCompositeType(_ParameterizedType):
typename = "org.apache.cassandra.db.marshal.DynamicCompositeType"
@classmethod
def cql_parameterized_type(cls):
sublist = ', '.join('%s=>%s' % (alias, typ.cass_parameterized_type(full=True)) for alias, typ in zip(cls.fieldnames, cls.subtypes))
return "'%s(%s)'" % (cls.typename, sublist)
class ColumnToCollectionType(_ParameterizedType):
"""
This class only really exists so that we can cleanly evaluate types when
Cassandra includes this. We don't actually need or want the extra
information.
"""
typename = "org.apache.cassandra.db.marshal.ColumnToCollectionType"
class ReversedType(_ParameterizedType):
typename = "org.apache.cassandra.db.marshal.ReversedType"
num_subtypes = 1
@classmethod
def deserialize_safe(cls, byts, protocol_version):
subtype, = cls.subtypes
return subtype.from_binary(byts, protocol_version)
@classmethod
def serialize_safe(cls, val, protocol_version):
subtype, = cls.subtypes
return subtype.to_binary(val, protocol_version)
class FrozenType(_ParameterizedType):
typename = "frozen"
num_subtypes = 1
@classmethod
def deserialize_safe(cls, byts, protocol_version):
subtype, = cls.subtypes
return subtype.from_binary(byts, protocol_version)
@classmethod
def serialize_safe(cls, val, protocol_version):
subtype, = cls.subtypes
return subtype.to_binary(val, protocol_version)
def is_counter_type(t):
if isinstance(t, six.string_types):
t = lookup_casstype(t)
return issubclass(t, CounterColumnType)
def cql_typename(casstypename):
"""
Translate a Cassandra-style type specifier (optionally-fully-distinguished
Java class names for data types, along with optional parameters) into a
CQL-style type specifier.
>>> cql_typename('DateType')
'timestamp'
>>> cql_typename('org.apache.cassandra.db.marshal.ListType(IntegerType)')
'list'
"""
return lookup_casstype(casstypename).cql_parameterized_type()
+
+
+class WKBGeometryType(object):
+ POINT = 1
+ LINESTRING = 2
+ POLYGON = 3
+
+
+class PointType(CassandraType):
+ typename = 'PointType'
+
+ _type = struct.pack('<BI', _little_endian_flag, WKBGeometryType.POINT)
+
+ # <type>[<time0><precision0>[<time1><precision1>]]
+ type_ = int8_unpack(byts[0:1])
+
+ if type_ in (BoundKind.to_int(BoundKind.BOTH_OPEN_RANGE),
+ BoundKind.to_int(BoundKind.SINGLE_DATE_OPEN)):
+ time0 = precision0 = None
+ else:
+ time0 = int64_unpack(byts[1:9])
+ precision0 = int8_unpack(byts[9:10])
+
+ if type_ == BoundKind.to_int(BoundKind.CLOSED_RANGE):
+ time1 = int64_unpack(byts[10:18])
+ precision1 = int8_unpack(byts[18:19])
+ else:
+ time1 = precision1 = None
+
+ if time0 is not None:
+ date_range_bound0 = util.DateRangeBound(
+ time0,
+ cls._decode_precision(precision0)
+ )
+ if time1 is not None:
+ date_range_bound1 = util.DateRangeBound(
+ time1,
+ cls._decode_precision(precision1)
+ )
+
+ if type_ == BoundKind.to_int(BoundKind.SINGLE_DATE):
+ return util.DateRange(value=date_range_bound0)
+ if type_ == BoundKind.to_int(BoundKind.CLOSED_RANGE):
+ return util.DateRange(lower_bound=date_range_bound0,
+ upper_bound=date_range_bound1)
+ if type_ == BoundKind.to_int(BoundKind.OPEN_RANGE_HIGH):
+ return util.DateRange(lower_bound=date_range_bound0,
+ upper_bound=util.OPEN_BOUND)
+ if type_ == BoundKind.to_int(BoundKind.OPEN_RANGE_LOW):
+ return util.DateRange(lower_bound=util.OPEN_BOUND,
+ upper_bound=date_range_bound0)
+ if type_ == BoundKind.to_int(BoundKind.BOTH_OPEN_RANGE):
+ return util.DateRange(lower_bound=util.OPEN_BOUND,
+ upper_bound=util.OPEN_BOUND)
+ if type_ == BoundKind.to_int(BoundKind.SINGLE_DATE_OPEN):
+ return util.DateRange(value=util.OPEN_BOUND)
+ raise ValueError('Could not deserialize %r' % (byts,))
+
+ @classmethod
+ def serialize(cls, v, protocol_version):
+ buf = io.BytesIO()
+ bound_kind, bounds = None, ()
+
+ try:
+ value = v.value
+ except AttributeError:
+ raise ValueError(
+ '%s.serialize expects an object with a value attribute; got '
+ '%r' % (cls.__name__, v)
+ )
+
+ if value is None:
+ try:
+ lower_bound, upper_bound = v.lower_bound, v.upper_bound
+ except AttributeError:
+ raise ValueError(
+ '%s.serialize expects an object with lower_bound and '
+ 'upper_bound attributes; got %r' % (cls.__name__, v)
+ )
+ if lower_bound == util.OPEN_BOUND and upper_bound == util.OPEN_BOUND:
+ bound_kind = BoundKind.BOTH_OPEN_RANGE
+ elif lower_bound == util.OPEN_BOUND:
+ bound_kind = BoundKind.OPEN_RANGE_LOW
+ bounds = (upper_bound,)
+ elif upper_bound == util.OPEN_BOUND:
+ bound_kind = BoundKind.OPEN_RANGE_HIGH
+ bounds = (lower_bound,)
+ else:
+ bound_kind = BoundKind.CLOSED_RANGE
+ bounds = lower_bound, upper_bound
+ else: # value is not None
+ if value == util.OPEN_BOUND:
+ bound_kind = BoundKind.SINGLE_DATE_OPEN
+ else:
+ bound_kind = BoundKind.SINGLE_DATE
+ bounds = (value,)
+
+ if bound_kind is None:
+ raise ValueError(
+ 'Cannot serialize %r; could not find bound kind' % (v,)
+ )
+
+ buf.write(int8_pack(BoundKind.to_int(bound_kind)))
+ for bound in bounds:
+ buf.write(int64_pack(bound.milliseconds))
+ buf.write(int8_pack(cls._encode_precision(bound.precision)))
+
+ return buf.getvalue()
diff --git a/cassandra/datastax/cloud/__init__.py b/cassandra/datastax/cloud/__init__.py
index ed9435e..ecb4a73 100644
--- a/cassandra/datastax/cloud/__init__.py
+++ b/cassandra/datastax/cloud/__init__.py
@@ -1,167 +1,196 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import json
+import sys
import tempfile
import shutil
+import six
from six.moves.urllib.request import urlopen
_HAS_SSL = True
try:
- from ssl import SSLContext, PROTOCOL_TLSv1, CERT_REQUIRED
+ from ssl import SSLContext, PROTOCOL_TLS, CERT_REQUIRED
except:
_HAS_SSL = False
from zipfile import ZipFile
# 2.7 vs 3.x
try:
from zipfile import BadZipFile
except:
from zipfile import BadZipfile as BadZipFile
from cassandra import DriverException
log = logging.getLogger(__name__)
__all__ = ['get_cloud_config']
-PRODUCT_APOLLO = "DATASTAX_APOLLO"
+DATASTAX_CLOUD_PRODUCT_TYPE = "DATASTAX_APOLLO"
class CloudConfig(object):
username = None
password = None
host = None
port = None
keyspace = None
local_dc = None
ssl_context = None
sni_host = None
sni_port = None
host_ids = None
@classmethod
def from_dict(cls, d):
c = cls()
c.port = d.get('port', None)
try:
c.port = int(d['port'])
except:
pass
c.username = d.get('username', None)
c.password = d.get('password', None)
c.host = d.get('host', None)
c.keyspace = d.get('keyspace', None)
c.local_dc = d.get('localDC', None)
return c
-def get_cloud_config(cloud_config):
+def get_cloud_config(cloud_config, create_pyopenssl_context=False):
if not _HAS_SSL:
raise DriverException("A Python installation with SSL is required to connect to a cloud cluster.")
if 'secure_connect_bundle' not in cloud_config:
raise ValueError("The cloud config doesn't have a secure_connect_bundle specified.")
try:
- config = read_cloud_config_from_zip(cloud_config)
+ config = read_cloud_config_from_zip(cloud_config, create_pyopenssl_context)
except BadZipFile:
raise ValueError("Unable to open the zip file for the cloud config. Check your secure connect bundle.")
- return read_metadata_info(config, cloud_config)
+ config = read_metadata_info(config, cloud_config)
+ if create_pyopenssl_context:
+ config.ssl_context = config.pyopenssl_context
+ return config
-def read_cloud_config_from_zip(cloud_config):
+def read_cloud_config_from_zip(cloud_config, create_pyopenssl_context):
secure_bundle = cloud_config['secure_connect_bundle']
+ use_default_tempdir = cloud_config.get('use_default_tempdir', None)
with ZipFile(secure_bundle) as zipfile:
- base_dir = os.path.dirname(secure_bundle)
+ base_dir = tempfile.gettempdir() if use_default_tempdir else os.path.dirname(secure_bundle)
tmp_dir = tempfile.mkdtemp(dir=base_dir)
try:
zipfile.extractall(path=tmp_dir)
- return parse_cloud_config(os.path.join(tmp_dir, 'config.json'), cloud_config)
+ return parse_cloud_config(os.path.join(tmp_dir, 'config.json'), cloud_config, create_pyopenssl_context)
finally:
shutil.rmtree(tmp_dir)
-def parse_cloud_config(path, cloud_config):
+def parse_cloud_config(path, cloud_config, create_pyopenssl_context):
with open(path, 'r') as stream:
data = json.load(stream)
config = CloudConfig.from_dict(data)
config_dir = os.path.dirname(path)
if 'ssl_context' in cloud_config:
config.ssl_context = cloud_config['ssl_context']
else:
# Load the ssl_context before we delete the temporary directory
ca_cert_location = os.path.join(config_dir, 'ca.crt')
cert_location = os.path.join(config_dir, 'cert')
key_location = os.path.join(config_dir, 'key')
+ # Regardless of whether we create a pyopenssl context, we still need the built-in one
+ # to connect to the metadata service
config.ssl_context = _ssl_context_from_cert(ca_cert_location, cert_location, key_location)
+ if create_pyopenssl_context:
+ config.pyopenssl_context = _pyopenssl_context_from_cert(ca_cert_location, cert_location, key_location)
return config
def read_metadata_info(config, cloud_config):
url = "https://{}:{}/metadata".format(config.host, config.port)
timeout = cloud_config['connect_timeout'] if 'connect_timeout' in cloud_config else 5
try:
response = urlopen(url, context=config.ssl_context, timeout=timeout)
except Exception as e:
log.exception(e)
- raise DriverException("Unable to connect to the metadata service at %s" % url)
+ raise DriverException("Unable to connect to the metadata service at %s. "
+ "Check the cluster status in the cloud console. " % url)
if response.code != 200:
raise DriverException(("Error while fetching the metadata at: %s. "
"The service returned error code %d." % (url, response.code)))
return parse_metadata_info(config, response.read().decode('utf-8'))
def parse_metadata_info(config, http_data):
try:
data = json.loads(http_data)
except:
msg = "Failed to load cluster metadata"
raise DriverException(msg)
contact_info = data['contact_info']
config.local_dc = contact_info['local_dc']
proxy_info = contact_info['sni_proxy_address'].split(':')
config.sni_host = proxy_info[0]
try:
config.sni_port = int(proxy_info[1])
except:
config.sni_port = 9042
config.host_ids = [host_id for host_id in contact_info['contact_points']]
return config
def _ssl_context_from_cert(ca_cert_location, cert_location, key_location):
- ssl_context = SSLContext(PROTOCOL_TLSv1)
+ ssl_context = SSLContext(PROTOCOL_TLS)
ssl_context.load_verify_locations(ca_cert_location)
ssl_context.verify_mode = CERT_REQUIRED
ssl_context.load_cert_chain(certfile=cert_location, keyfile=key_location)
return ssl_context
+
+
+def _pyopenssl_context_from_cert(ca_cert_location, cert_location, key_location):
+ try:
+ from OpenSSL import SSL
+ except ImportError as e:
+ six.reraise(
+ ImportError,
+ ImportError("PyOpenSSL must be installed to connect to Astra with the Eventlet or Twisted event loops"),
+ sys.exc_info()[2]
+ )
+ ssl_context = SSL.Context(SSL.TLSv1_METHOD)
+ ssl_context.set_verify(SSL.VERIFY_PEER, callback=lambda _1, _2, _3, _4, ok: ok)
+ ssl_context.use_certificate_file(cert_location)
+ ssl_context.use_privatekey_file(key_location)
+ ssl_context.load_verify_locations(ca_cert_location)
+
+ return ssl_context
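+# A minimal connection sketch using this cloud config path (the bundle path and
+# credentials below are placeholders):
+#
+#     from cassandra.cluster import Cluster
+#     from cassandra.auth import PlainTextAuthProvider
+#
+#     cluster = Cluster(
+#         cloud={'secure_connect_bundle': '/path/to/secure-connect-db.zip',
+#                'use_default_tempdir': True},
+#         auth_provider=PlainTextAuthProvider('client_id', 'client_secret'))
+#     session = cluster.connect()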
\ No newline at end of file
diff --git a/tests/integration/advanced/__init__.py b/cassandra/datastax/graph/__init__.py
similarity index 54%
copy from tests/integration/advanced/__init__.py
copy to cassandra/datastax/graph/__init__.py
index 662f5b8..11785c8 100644
--- a/tests/integration/advanced/__init__.py
+++ b/cassandra/datastax/graph/__init__.py
@@ -1,23 +1,23 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License
+# limitations under the License.
-try:
- import unittest2 as unittest
-except ImportError:
- import unittest # noqa
-try:
- from ccmlib import common
-except ImportError as e:
- raise unittest.SkipTest('ccm is a dependency for integration tests:', e)
+from cassandra.datastax.graph.types import Element, Vertex, VertexProperty, Edge, Path, T
+from cassandra.datastax.graph.query import (
+ GraphOptions, GraphProtocol, GraphStatement, SimpleGraphStatement, Result,
+ graph_object_row_factory, single_object_row_factory,
+ graph_result_row_factory, graph_graphson2_row_factory,
+ graph_graphson3_row_factory
+)
+from cassandra.datastax.graph.graphson import *
diff --git a/cassandra/datastax/graph/fluent/__init__.py b/cassandra/datastax/graph/fluent/__init__.py
new file mode 100644
index 0000000..44a0d13
--- /dev/null
+++ b/cassandra/datastax/graph/fluent/__init__.py
@@ -0,0 +1,303 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import copy
+
+from concurrent.futures import Future
+
+HAVE_GREMLIN = False
+try:
+ import gremlin_python
+ HAVE_GREMLIN = True
+except ImportError:
+ # gremlinpython is not installed.
+ pass
+
+if HAVE_GREMLIN:
+ from gremlin_python.structure.graph import Graph
+ from gremlin_python.driver.remote_connection import RemoteConnection, RemoteTraversal
+ from gremlin_python.process.traversal import Traverser, TraversalSideEffects
+ from gremlin_python.process.graph_traversal import GraphTraversal
+
+ from cassandra.cluster import Session, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT
+ from cassandra.datastax.graph import GraphOptions, GraphProtocol
+ from cassandra.datastax.graph.query import _GraphSONContextRowFactory
+
+ from cassandra.datastax.graph.fluent.serializers import (
+ GremlinGraphSONReaderV2,
+ GremlinGraphSONReaderV3,
+ dse_graphson2_deserializers,
+ gremlin_graphson2_deserializers,
+ dse_graphson3_deserializers,
+ gremlin_graphson3_deserializers
+ )
+ from cassandra.datastax.graph.fluent.query import _DefaultTraversalBatch, _query_from_traversal
+
+ log = logging.getLogger(__name__)
+
+ __all__ = ['BaseGraphRowFactory', 'graph_traversal_row_factory',
+ 'graph_traversal_dse_object_row_factory', 'DSESessionRemoteGraphConnection', 'DseGraph']
+
+ # Traversal result keys
+ _bulk_key = 'bulk'
+ _result_key = 'result'
+
+
+ class BaseGraphRowFactory(_GraphSONContextRowFactory):
+ """
+ Base row factory for graph traversal. This class wraps a
+ graphson reader function to handle additional features of Gremlin/DSE
+ and is callable as a normal row factory.
+
+ Currently supported:
+ - bulk results
+ """
+
+ def __call__(self, column_names, rows):
+ for row in rows:
+ parsed_row = self.graphson_reader.readObject(row[0])
+ yield parsed_row[_result_key]
+ bulk = parsed_row.get(_bulk_key, 1)
+ for _ in range(bulk - 1):
+ yield copy.deepcopy(parsed_row[_result_key])
+
+
+ class _GremlinGraphSON2RowFactory(BaseGraphRowFactory):
+ """Row Factory that returns the decoded graphson2."""
+ graphson_reader_class = GremlinGraphSONReaderV2
+ graphson_reader_kwargs = {'deserializer_map': gremlin_graphson2_deserializers}
+
+
+ class _DseGraphSON2RowFactory(BaseGraphRowFactory):
+ """Row Factory that returns the decoded graphson2 as DSE types."""
+ graphson_reader_class = GremlinGraphSONReaderV2
+ graphson_reader_kwargs = {'deserializer_map': dse_graphson2_deserializers}
+
+ gremlin_graphson2_traversal_row_factory = _GremlinGraphSON2RowFactory
+ # TODO remove in next major
+ graph_traversal_row_factory = gremlin_graphson2_traversal_row_factory
+
+ dse_graphson2_traversal_row_factory = _DseGraphSON2RowFactory
+ # TODO remove in next major
+ graph_traversal_dse_object_row_factory = dse_graphson2_traversal_row_factory
+
+
+ class _GremlinGraphSON3RowFactory(BaseGraphRowFactory):
+ """Row Factory that returns the decoded graphson2."""
+ graphson_reader_class = GremlinGraphSONReaderV3
+ graphson_reader_kwargs = {'deserializer_map': gremlin_graphson3_deserializers}
+
+
+ class _DseGraphSON3RowFactory(BaseGraphRowFactory):
+ """Row Factory that returns the decoded graphson3 as DSE types."""
+ graphson_reader_class = GremlinGraphSONReaderV3
+ graphson_reader_kwargs = {'deserializer_map': dse_graphson3_deserializers}
+
+
+ gremlin_graphson3_traversal_row_factory = _GremlinGraphSON3RowFactory
+ dse_graphson3_traversal_row_factory = _DseGraphSON3RowFactory
+
+
+ class DSESessionRemoteGraphConnection(RemoteConnection):
+ """
+ A Tinkerpop RemoteConnection to execute traversal queries on DSE.
+
+ :param session: A DSE session
+ :param graph_name: (Optional) DSE Graph name.
+ :param execution_profile: (Optional) Execution profile for traversal queries. Default is set to `EXEC_PROFILE_GRAPH_DEFAULT`.
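+
+ .. code-block:: python
+
+ # Example sketch: bind a TinkerPop traversal source directly to a DSE
+ # session. Assumes an existing `session` and a graph named 'my_graph'.
+ from gremlin_python.structure.graph import Graph
+
+ rc = DSESessionRemoteGraphConnection(session, 'my_graph')
+ g = Graph().traversal().withRemote(rc)
+ print(g.V().toList())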
+ """
+
+ session = None
+ graph_name = None
+ execution_profile = None
+
+ def __init__(self, session, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT):
+ super(DSESessionRemoteGraphConnection, self).__init__(None, None)
+
+ if not isinstance(session, Session):
+ raise ValueError('A DSE Session must be provided to execute graph traversal queries.')
+
+ self.session = session
+ self.graph_name = graph_name
+ self.execution_profile = execution_profile
+
+ @staticmethod
+ def _traversers_generator(traversers):
+ for t in traversers:
+ yield Traverser(t)
+
+ def _prepare_query(self, bytecode):
+ ep = self.session.execution_profile_clone_update(self.execution_profile)
+ graph_options = ep.graph_options
+ graph_options.graph_name = self.graph_name or graph_options.graph_name
+ graph_options.graph_language = DseGraph.DSE_GRAPH_QUERY_LANGUAGE
+ # We resolve the execution profile options here to know which gremlin row factory to set
+ self.session._resolve_execution_profile_options(ep)
+
+ context = None
+ if graph_options.graph_protocol == GraphProtocol.GRAPHSON_2_0:
+ row_factory = gremlin_graphson2_traversal_row_factory
+ elif graph_options.graph_protocol == GraphProtocol.GRAPHSON_3_0:
+ row_factory = gremlin_graphson3_traversal_row_factory
+ context = {
+ 'cluster': self.session.cluster,
+ 'graph_name': graph_options.graph_name.decode('utf-8')
+ }
+ else:
+ raise ValueError('Unknown graph protocol: {}'.format(graph_options.graph_protocol))
+
+ ep.row_factory = row_factory
+ query = DseGraph.query_from_traversal(bytecode, graph_options.graph_protocol, context)
+
+ return query, ep
+
+ @staticmethod
+ def _handle_query_results(result_set, gremlin_future):
+ try:
+ gremlin_future.set_result(
+ RemoteTraversal(DSESessionRemoteGraphConnection._traversers_generator(result_set), TraversalSideEffects())
+ )
+ except Exception as e:
+ gremlin_future.set_exception(e)
+
+ @staticmethod
+ def _handle_query_error(response, gremlin_future):
+ gremlin_future.set_exception(response)
+
+ def submit(self, bytecode):
+ # the only reason we don't use submitAsync here
+ # is to avoid an unnecessary Future wrapper
+ query, ep = self._prepare_query(bytecode)
+
+ traversers = self.session.execute_graph(query, execution_profile=ep)
+ return RemoteTraversal(self._traversers_generator(traversers), TraversalSideEffects())
+
+ def submitAsync(self, bytecode):
+ query, ep = self._prepare_query(bytecode)
+
+ # to be compatible with gremlinpython, we need to return a concurrent.futures.Future
+ gremlin_future = Future()
+ response_future = self.session.execute_graph_async(query, execution_profile=ep)
+ response_future.add_callback(self._handle_query_results, gremlin_future)
+ response_future.add_errback(self._handle_query_error, gremlin_future)
+
+ return gremlin_future
+
+ def __str__(self):
+ return "".format(self.graph_name)
+
+ __repr__ = __str__
+
+
+ class DseGraph(object):
+ """
+ Dse Graph utility class for GraphTraversal construction and execution.
+ """
+
+ DSE_GRAPH_QUERY_LANGUAGE = 'bytecode-json'
+ """
+ Graph query language. Default is 'bytecode-json' (GraphSON).
+ """
+
+ DSE_GRAPH_QUERY_PROTOCOL = GraphProtocol.GRAPHSON_2_0
+ """
+ Graph query protocol. Default is GraphProtocol.GRAPHSON_2_0.
+ """
+
+ @staticmethod
+ def query_from_traversal(traversal, graph_protocol=DSE_GRAPH_QUERY_PROTOCOL, context=None):
+ """
+ From a GraphTraversal, return a query string based on the language specified in `DseGraph.DSE_GRAPH_QUERY_LANGUAGE`.
+
+ :param traversal: The GraphTraversal object
+ :param graph_protocol: The graph protocol. Default is `DseGraph.DSE_GRAPH_QUERY_PROTOCOL`.
+ :param context: The dict of the serialization context, needed for GraphSON3 (tuple, udt),
+ e.g. {'cluster': cluster, 'graph_name': name}
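+
+ .. code-block:: python
+
+ # Example sketch: serialize a bytecode traversal to a GraphSON2 query
+ # string and execute it explicitly. Assumes an existing `session` and a
+ # registered graph execution profile named 'graph_profile'.
+ g = DseGraph.traversal_source()
+ query = DseGraph.query_from_traversal(g.V().has('name', 'alice'))
+ session.execute_graph(query, execution_profile='graph_profile')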
+ """
+
+ if isinstance(traversal, GraphTraversal):
+ for strategy in traversal.traversal_strategies.traversal_strategies:
+ rc = strategy.remote_connection
+ if (isinstance(rc, DSESessionRemoteGraphConnection) and
+ rc.session or rc.graph_name or rc.execution_profile):
+ log.warning("GraphTraversal session, graph_name and execution_profile are "
+ "only taken into account when executed with TinkerPop.")
+
+ return _query_from_traversal(traversal, graph_protocol, context)
+
+ @staticmethod
+ def traversal_source(session=None, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT,
+ traversal_class=None):
+ """
+ Returns a TinkerPop GraphTraversalSource bound to the session and graph_name if provided.
+
+ :param session: (Optional) A DSE session
+ :param graph_name: (Optional) DSE Graph name
+ :param execution_profile: (Optional) Execution profile for traversal queries. Default is set to `EXEC_PROFILE_GRAPH_DEFAULT`.
+ :param traversal_class: (Optional) The GraphTraversalSource class to use (DSL).
+
+ .. code-block:: python
+
+ from cassandra.cluster import Cluster
+ from cassandra.datastax.graph.fluent import DseGraph
+
+ c = Cluster()
+ session = c.connect()
+
+ g = DseGraph.traversal_source(session, 'my_graph')
+ print(g.V().valueMap().toList())
+
+ """
+
+ graph = Graph()
+ traversal_source = graph.traversal(traversal_class)
+
+ if session:
+ traversal_source = traversal_source.withRemote(
+ DSESessionRemoteGraphConnection(session, graph_name, execution_profile))
+
+ return traversal_source
+
+ @staticmethod
+ def create_execution_profile(graph_name, graph_protocol=DSE_GRAPH_QUERY_PROTOCOL, **kwargs):
+ """
+ Creates an ExecutionProfile for GraphTraversal execution. You need to register that execution profile with the
+ cluster using `cluster.add_execution_profile`.
+
+ :param graph_name: The graph name
+ :param graph_protocol: (Optional) The graph protocol, default is `DSE_GRAPH_QUERY_PROTOCOL`.
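+
+ .. code-block:: python
+
+ # Example sketch: create a GraphSON3 (core graph) traversal profile and
+ # register it under an assumed key before executing traversals.
+ from cassandra.datastax.graph import GraphProtocol
+
+ ep = DseGraph.create_execution_profile('my_graph', GraphProtocol.GRAPHSON_3_0)
+ cluster.add_execution_profile('graph_traversal_profile', ep)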
+ """
+
+ if graph_protocol == GraphProtocol.GRAPHSON_2_0:
+ row_factory = dse_graphson2_traversal_row_factory
+ elif graph_protocol == GraphProtocol.GRAPHSON_3_0:
+ row_factory = dse_graphson3_traversal_row_factory
+ else:
+ raise ValueError('Unknown graph protocol: {}'.format(graph_protocol))
+
+ ep = GraphExecutionProfile(row_factory=row_factory,
+ graph_options=GraphOptions(graph_name=graph_name,
+ graph_language=DseGraph.DSE_GRAPH_QUERY_LANGUAGE,
+ graph_protocol=graph_protocol),
+ **kwargs)
+ return ep
+
+ @staticmethod
+ def batch(*args, **kwargs):
+ """
+ Returns a :class:`cassandra.datastax.graph.fluent.query.TraversalBatch` object that allows you to
+ execute multiple traversals in the same transaction.
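+
+ .. code-block:: python
+
+ # Example sketch: apply two traversals atomically. Assumes an existing
+ # `session` whose default graph execution profile targets a graph; `g`
+ # is a local (non-remote) traversal source used only to build bytecode.
+ g = DseGraph.traversal_source()
+ batch = DseGraph.batch(session)
+ batch.add(g.addV('person').property('name', 'alice'))
+ batch.add(g.addV('person').property('name', 'bob'))
+ batch.execute()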
+ """
+ return _DefaultTraversalBatch(*args, **kwargs)
diff --git a/cassandra/datastax/graph/fluent/_predicates.py b/cassandra/datastax/graph/fluent/_predicates.py
new file mode 100644
index 0000000..95bd533
--- /dev/null
+++ b/cassandra/datastax/graph/fluent/_predicates.py
@@ -0,0 +1,202 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+from gremlin_python.process.traversal import P
+
+from cassandra.util import Distance
+
+__all__ = ['GeoP', 'TextDistanceP', 'Search', 'GeoUnit', 'Geo', 'CqlCollection']
+
+
+class GeoP(object):
+
+ def __init__(self, operator, value, other=None):
+ self.operator = operator
+ self.value = value
+ self.other = other
+
+ @staticmethod
+ def inside(*args, **kwargs):
+ return GeoP("inside", *args, **kwargs)
+
+ def __eq__(self, other):
+ return isinstance(other,
+ self.__class__) and self.operator == other.operator and self.value == other.value and self.other == other.other
+
+ def __repr__(self):
+ return self.operator + "(" + str(self.value) + ")" if self.other is None else self.operator + "(" + str(
+ self.value) + "," + str(self.other) + ")"
+
+
+class TextDistanceP(object):
+
+ def __init__(self, operator, value, distance):
+ self.operator = operator
+ self.value = value
+ self.distance = distance
+
+ @staticmethod
+ def fuzzy(*args):
+ return TextDistanceP("fuzzy", *args)
+
+ @staticmethod
+ def token_fuzzy(*args):
+ return TextDistanceP("tokenFuzzy", *args)
+
+ @staticmethod
+ def phrase(*args):
+ return TextDistanceP("phrase", *args)
+
+ def __eq__(self, other):
+ return isinstance(other,
+ self.__class__) and self.operator == other.operator and self.value == other.value and self.distance == other.distance
+
+ def __repr__(self):
+ return self.operator + "(" + str(self.value) + "," + str(self.distance) + ")"
+
+
+class Search(object):
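+ """
+ Predicates for text search over indexed text properties in fluent
+ graph traversals.
+
+ .. code-block:: python
+
+ # Example sketch: assumes a traversal source `g` and a search-indexed
+ # text property 'name'.
+ g.V().has('name', Search.token_prefix('al')).toList()
+ """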
+
+ @staticmethod
+ def token(value):
+ """
+ Search any instance of a certain token within the text property targeted.
+ :param value: the value to look for.
+ """
+ return P('token', value)
+
+ @staticmethod
+ def token_prefix(value):
+ """
+ Search any instance of a certain token prefix within the text property targeted.
+ :param value: the value to look for.
+ """
+ return P('tokenPrefix', value)
+
+ @staticmethod
+ def token_regex(value):
+ """
+ Search any instance of the provided regular expression for the targeted property.
+ :param value: the value to look for.
+ """
+ return P('tokenRegex', value)
+
+ @staticmethod
+ def prefix(value):
+ """
+ Search for a specific prefix at the beginning of the text property targeted.
+ :param value: the value to look for.
+ """
+ return P('prefix', value)
+
+ @staticmethod
+ def regex(value):
+ """
+ Search for this regular expression inside the text property targeted.
+ :param value: the value to look for.
+ """
+ return P('regex', value)
+
+ @staticmethod
+ def fuzzy(value, distance):
+ """
+ Search for a fuzzy string inside the text property targeted.
+ :param value: the value to look for.
+ :param distance: The distance for the fuzzy search, e.g. 1 to allow one-letter misspellings.
+ """
+ return TextDistanceP.fuzzy(value, distance)
+
+ @staticmethod
+ def token_fuzzy(value, distance):
+ """
+ Search for a token fuzzy inside the text property targeted.
+ :param value: the value to look for.
+ :param distance: The distance for the token fuzzy search, e.g. 1 to allow one-letter misspellings.
+ """
+ return TextDistanceP.token_fuzzy(value, distance)
+
+ @staticmethod
+ def phrase(value, proximity):
+ """
+ Search for a phrase inside the text property targeted.
+ :param value: the value to look for.
+ :param proximity: The proximity for the phrase search, e.g. phrase('David Felcey', 2) to find 'David Felcey' with up to two middle names.
+ """
+ return TextDistanceP.phrase(value, proximity)
+
+
+class CqlCollection(object):
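+ """
+ Predicates for filtering on CQL collection properties (list, set, map)
+ in fluent graph traversals.
+
+ .. code-block:: python
+
+ # Example sketch: assumes a traversal source `g` and a list/set
+ # property 'tags'.
+ g.V().has('tags', CqlCollection.contains('python')).toList()
+ """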
+
+ @staticmethod
+ def contains(value):
+ """
+ Search for a value inside a cql list/set column.
+ :param value: the value to look for.
+ """
+ return P('contains', value)
+
+ @staticmethod
+ def contains_value(value):
+ """
+ Search for a map value.
+ :param value: the value to look for.
+ """
+ return P('containsValue', value)
+
+ @staticmethod
+ def contains_key(value):
+ """
+ Search for a map key.
+ :param value: the value to look for.
+ """
+ return P('containsKey', value)
+
+ @staticmethod
+ def entry_eq(value):
+ """
+ Search for a map entry.
+ :param value: the value to look for.
+ """
+ return P('entryEq', value)
+
+
+class GeoUnit(object):
+ _EARTH_MEAN_RADIUS_KM = 6371.0087714
+ _DEGREES_TO_RADIANS = math.pi / 180
+ _DEG_TO_KM = _DEGREES_TO_RADIANS * _EARTH_MEAN_RADIUS_KM
+ _KM_TO_DEG = 1 / _DEG_TO_KM
+ _MILES_TO_KM = 1.609344001
+
+ MILES = _MILES_TO_KM * _KM_TO_DEG
+ KILOMETERS = _KM_TO_DEG
+ METERS = _KM_TO_DEG / 1000.0
+ DEGREES = 1
+
+
+class Geo(object):
+
+ @staticmethod
+ def inside(value, units=GeoUnit.DEGREES):
+ """
+ Search for any instance of geometry inside the targeted Distance.
+ :param value: A Distance to look for.
+ :param units: The units for ``value``. See GeoUnit enum. (Can also
+ provide an integer to use as a multiplier to convert ``value`` to
+ degrees.)
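+
+ .. code-block:: python
+
+ # Example sketch: vertices whose 'coordinates' point lies within 5
+ # kilometers of (2.0, 2.0). Assumes a traversal source `g`.
+ from cassandra.util import Distance
+ g.V().has('coordinates', Geo.inside(Distance(2.0, 2.0, 5.0), GeoUnit.KILOMETERS)).toList()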
+ """
+ return GeoP.inside(
+ value=Distance(x=value.x, y=value.y, radius=value.radius * units)
+ )
diff --git a/cassandra/datastax/graph/fluent/_query.py b/cassandra/datastax/graph/fluent/_query.py
new file mode 100644
index 0000000..bd89046
--- /dev/null
+++ b/cassandra/datastax/graph/fluent/_query.py
@@ -0,0 +1,229 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import six
+import logging
+
+from cassandra.graph import SimpleGraphStatement, GraphProtocol
+from cassandra.cluster import EXEC_PROFILE_GRAPH_DEFAULT
+
+from gremlin_python.process.graph_traversal import GraphTraversal
+from gremlin_python.structure.io.graphsonV2d0 import GraphSONWriter as GraphSONWriterV2
+from gremlin_python.structure.io.graphsonV3d0 import GraphSONWriter as GraphSONWriterV3
+
+from cassandra.datastax.graph.fluent.serializers import GremlinUserTypeIO, \
+ dse_graphson2_serializers, dse_graphson3_serializers
+
+log = logging.getLogger(__name__)
+
+
+__all__ = ['TraversalBatch', '_query_from_traversal', '_DefaultTraversalBatch']
+
+
+class _GremlinGraphSONWriterAdapter(object):
+
+ def __init__(self, context, **kwargs):
+ super(_GremlinGraphSONWriterAdapter, self).__init__(**kwargs)
+ self.context = context
+ self.user_types = None
+
+ def serialize(self, value, _):
+ return self.toDict(value)
+
+ def get_serializer(self, value):
+ serializer = None
+ try:
+ serializer = self.serializers[type(value)]
+ except KeyError:
+ for key, ser in self.serializers.items():
+ if isinstance(value, key):
+ serializer = ser
+
+ if self.context:
+ # Check if UDT
+ if self.user_types is None:
+ try:
+ user_types = self.context['cluster']._user_types[self.context['graph_name']]
+ self.user_types = dict(map(reversed, six.iteritems(user_types)))
+ except KeyError:
+ self.user_types = {}
+
+ # Custom detection to map a namedtuple to udt
+ if (tuple in self.serializers and serializer is self.serializers[tuple] and hasattr(value, '_fields') or
+ (not serializer and type(value) in self.user_types)):
+ serializer = GremlinUserTypeIO
+
+ if serializer:
+ try:
+ # A serializer can have specialized serializers (e.g. for Int32 and Int64, so value dependent)
+ serializer = serializer.get_specialized_serializer(value)
+ except AttributeError:
+ pass
+
+ return serializer
+
+ def toDict(self, obj):
+ serializer = self.get_serializer(obj)
+ return serializer.dictify(obj, self) if serializer else obj
+
+ def definition(self, value):
+ serializer = self.get_serializer(value)
+ return serializer.definition(value, self)
+
+
+class GremlinGraphSON2Writer(_GremlinGraphSONWriterAdapter, GraphSONWriterV2):
+ pass
+
+
+class GremlinGraphSON3Writer(_GremlinGraphSONWriterAdapter, GraphSONWriterV3):
+ pass
+
+
+graphson2_writer = GremlinGraphSON2Writer
+graphson3_writer = GremlinGraphSON3Writer
+
+
+def _query_from_traversal(traversal, graph_protocol, context=None):
+ """
+ From a GraphTraversal, return a query string.
+
+ :param traversal: The GraphTraversal object
+ :param graph_protocol: The graph protocol that determines the output format.
+ """
+ if graph_protocol == GraphProtocol.GRAPHSON_2_0:
+ graphson_writer = graphson2_writer(context, serializer_map=dse_graphson2_serializers)
+ elif graph_protocol == GraphProtocol.GRAPHSON_3_0:
+ if context is None:
+ raise ValueError('GraphSON3 serialization requires a context.')
+ graphson_writer = graphson3_writer(context, serializer_map=dse_graphson3_serializers)
+ else:
+ raise ValueError('Unknown graph protocol: {}'.format(graph_protocol))
+
+ try:
+ query = graphson_writer.writeObject(traversal)
+ except Exception:
+ log.exception("Error preparing graphson traversal query:")
+ raise
+
+ return query
+
+
+class TraversalBatch(object):
+ """
+ A `TraversalBatch` is used to execute multiple graph traversals in a
+ single transaction. If any traversal in the batch fails, the entire
+ batch will fail to apply.
+
+ If a TraversalBatch is bound to a DSE session, it can be executed using
+ `traversal_batch.execute()`.
+ """
+
+ _session = None
+ _execution_profile = None
+
+ def __init__(self, session=None, execution_profile=None):
+ """
+ :param session: (Optional) A DSE session
+ :param execution_profile: (Optional) The execution profile to use for the batch execution
+ """
+ self._session = session
+ self._execution_profile = execution_profile
+
+ def add(self, traversal):
+ """
+ Add a traversal to the batch.
+
+ :param traversal: A gremlin GraphTraversal
+ """
+ raise NotImplementedError()
+
+ def add_all(self, traversals):
+ """
+ Adds a sequence of traversals to the batch.
+
+ :param traversals: A sequence of gremlin GraphTraversal objects
+ """
+ raise NotImplementedError()
+
+ def execute(self):
+ """
+ Execute the traversal batch if bound to a `DSE Session`.
+ """
+ raise NotImplementedError()
+
+ def as_graph_statement(self, graph_protocol=GraphProtocol.GRAPHSON_2_0):
+ """
+ Return the traversal batch as GraphStatement.
+
+ :param graph_protocol: The graph protocol for the GraphSONWriter. Default is GraphProtocol.GRAPHSON_2_0.
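+
+ .. code-block:: python
+
+ # Example sketch: build the statement and execute it yourself. Assumes
+ # existing `batch` and `session` objects.
+ statement = batch.as_graph_statement()
+ session.execute_graph(statement)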
+ """
+ raise NotImplementedError()
+
+ def clear(self):
+ """
+ Clear a traversal batch for reuse.
+ """
+ raise NotImplementedError()
+
+ def __len__(self):
+ raise NotImplementedError()
+
+ def __str__(self):
+ return u'<TraversalBatch traversals={0}>'.format(len(self))
+ __repr__ = __str__
+
+
+class _DefaultTraversalBatch(TraversalBatch):
+
+ _traversals = None
+
+ def __init__(self, *args, **kwargs):
+ super(_DefaultTraversalBatch, self).__init__(*args, **kwargs)
+ self._traversals = []
+
+ def add(self, traversal):
+ if not isinstance(traversal, GraphTraversal):
+ raise ValueError('traversal should be a gremlin GraphTraversal')
+
+ self._traversals.append(traversal)
+ return self
+
+ def add_all(self, traversals):
+ for traversal in traversals:
+ self.add(traversal)
+
+ def as_graph_statement(self, graph_protocol=GraphProtocol.GRAPHSON_2_0, context=None):
+ statements = [_query_from_traversal(t, graph_protocol, context) for t in self._traversals]
+ query = u"[{0}]".format(','.join(statements))
+ return SimpleGraphStatement(query)
+
+ def execute(self):
+ if self._session is None:
+ raise ValueError('A DSE Session must be provided to execute the traversal batch.')
+
+ execution_profile = self._execution_profile if self._execution_profile else EXEC_PROFILE_GRAPH_DEFAULT
+ graph_options = self._session.get_execution_profile(execution_profile).graph_options
+ context = {
+ 'cluster': self._session.cluster,
+ 'graph_name': graph_options.graph_name
+ }
+ statement = self.as_graph_statement(graph_options.graph_protocol, context=context) \
+ if graph_options.graph_protocol else self.as_graph_statement(context=context)
+ return self._session.execute_graph(statement, execution_profile=execution_profile)
+
+ def clear(self):
+ del self._traversals[:]
+
+ def __len__(self):
+ return len(self._traversals)
diff --git a/cassandra/datastax/graph/fluent/_serializers.py b/cassandra/datastax/graph/fluent/_serializers.py
new file mode 100644
index 0000000..db8e715
--- /dev/null
+++ b/cassandra/datastax/graph/fluent/_serializers.py
@@ -0,0 +1,262 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+
+import six
+
+from gremlin_python.structure.io.graphsonV2d0 import (
+ GraphSONReader as GraphSONReaderV2,
+ GraphSONUtil as GraphSONUtil, # no difference between v2 and v3
+ VertexDeserializer as VertexDeserializerV2,
+ VertexPropertyDeserializer as VertexPropertyDeserializerV2,
+ PropertyDeserializer as PropertyDeserializerV2,
+ EdgeDeserializer as EdgeDeserializerV2,
+ PathDeserializer as PathDeserializerV2
+)
+
+from gremlin_python.structure.io.graphsonV3d0 import (
+ GraphSONReader as GraphSONReaderV3,
+ VertexDeserializer as VertexDeserializerV3,
+ VertexPropertyDeserializer as VertexPropertyDeserializerV3,
+ PropertyDeserializer as PropertyDeserializerV3,
+ EdgeDeserializer as EdgeDeserializerV3,
+ PathDeserializer as PathDeserializerV3
+)
+
+try:
+ from gremlin_python.structure.io.graphsonV2d0 import (
+ TraversalMetricsDeserializer as TraversalMetricsDeserializerV2,
+ MetricsDeserializer as MetricsDeserializerV2
+ )
+ from gremlin_python.structure.io.graphsonV3d0 import (
+ TraversalMetricsDeserializer as TraversalMetricsDeserializerV3,
+ MetricsDeserializer as MetricsDeserializerV3
+ )
+except ImportError:
+ TraversalMetricsDeserializerV2 = MetricsDeserializerV2 = None
+ TraversalMetricsDeserializerV3 = MetricsDeserializerV3 = None
+
+from cassandra.graph import (
+ GraphSON2Serializer,
+ GraphSON2Deserializer,
+ GraphSON3Serializer,
+ GraphSON3Deserializer
+)
+from cassandra.graph.graphson import UserTypeIO, TypeWrapperTypeIO
+from cassandra.datastax.graph.fluent.predicates import GeoP, TextDistanceP
+from cassandra.util import Distance
+
+
+__all__ = ['GremlinGraphSONReader', 'GeoPSerializer', 'TextDistancePSerializer',
+ 'DistanceIO', 'gremlin_deserializers', 'deserializers', 'serializers',
+ 'GremlinGraphSONReaderV2', 'GremlinGraphSONReaderV3', 'dse_graphson2_serializers',
+ 'dse_graphson2_deserializers', 'dse_graphson3_serializers', 'dse_graphson3_deserializers',
+ 'gremlin_graphson2_deserializers', 'gremlin_graphson3_deserializers', 'GremlinUserTypeIO']
+
+
+class _GremlinGraphSONTypeSerializer(object):
+ TYPE_KEY = "@type"
+ VALUE_KEY = "@value"
+ serializer = None
+
+ def __init__(self, serializer):
+ self.serializer = serializer
+
+ def dictify(self, v, writer):
+ value = self.serializer.serialize(v, writer)
+ if self.serializer is TypeWrapperTypeIO:
+ graphson_base_type = v.type_io.graphson_base_type
+ graphson_type = v.type_io.graphson_type
+ else:
+ graphson_base_type = self.serializer.graphson_base_type
+ graphson_type = self.serializer.graphson_type
+
+ if graphson_base_type is None:
+ out = value
+ else:
+ out = {self.TYPE_KEY: graphson_type}
+ if value is not None:
+ out[self.VALUE_KEY] = value
+
+ return out
+
+ def definition(self, value, writer=None):
+ return self.serializer.definition(value, writer)
+
+ def get_specialized_serializer(self, value):
+ ser = self.serializer.get_specialized_serializer(value)
+ if ser is not self.serializer:
+ return _GremlinGraphSONTypeSerializer(ser)
+ return self
+
+
+class _GremlinGraphSONTypeDeserializer(object):
+
+ deserializer = None
+
+ def __init__(self, deserializer):
+ self.deserializer = deserializer
+
+ def objectify(self, v, reader):
+ return self.deserializer.deserialize(v, reader)
+
+
+def _make_gremlin_graphson2_deserializer(graphson_type):
+ return _GremlinGraphSONTypeDeserializer(
+ GraphSON2Deserializer.get_deserializer(graphson_type.graphson_type)
+ )
+
+
+def _make_gremlin_graphson3_deserializer(graphson_type):
+ return _GremlinGraphSONTypeDeserializer(
+ GraphSON3Deserializer.get_deserializer(graphson_type.graphson_type)
+ )
+
+
+class _GremlinGraphSONReader(object):
+ """Gremlin GraphSONReader Adapter, required to use gremlin types"""
+
+ context = None
+
+ def __init__(self, context, deserializer_map=None):
+ self.context = context
+ super(_GremlinGraphSONReader, self).__init__(deserializer_map)
+
+ def deserialize(self, obj):
+ return self.toObject(obj)
+
+
+class GremlinGraphSONReaderV2(_GremlinGraphSONReader, GraphSONReaderV2):
+ pass
+
+# TODO remove next major
+GremlinGraphSONReader = GremlinGraphSONReaderV2
+
+class GremlinGraphSONReaderV3(_GremlinGraphSONReader, GraphSONReaderV3):
+ pass
+
+
+class GeoPSerializer(object):
+ @classmethod
+ def dictify(cls, p, writer):
+ out = {
+ "predicateType": "Geo",
+ "predicate": p.operator,
+ "value": [writer.toDict(p.value), writer.toDict(p.other)] if p.other is not None else writer.toDict(p.value)
+ }
+ return GraphSONUtil.typedValue("P", out, prefix='dse')
+
+
+class TextDistancePSerializer(object):
+ @classmethod
+ def dictify(cls, p, writer):
+ out = {
+ "predicate": p.operator,
+ "value": {
+ 'query': writer.toDict(p.value),
+ 'distance': writer.toDict(p.distance)
+ }
+ }
+ return GraphSONUtil.typedValue("P", out)
+
+
+class DistanceIO(object):
+ @classmethod
+ def dictify(cls, v, _):
+ return GraphSONUtil.typedValue('Distance', six.text_type(v), prefix='dse')
+
+
+GremlinUserTypeIO = _GremlinGraphSONTypeSerializer(UserTypeIO)
+
+# GraphSON2
+dse_graphson2_serializers = OrderedDict([
+ (t, _GremlinGraphSONTypeSerializer(s))
+ for t, s in six.iteritems(GraphSON2Serializer.get_type_definitions())
+])
+
+dse_graphson2_serializers.update(OrderedDict([
+ (Distance, DistanceIO),
+ (GeoP, GeoPSerializer),
+ (TextDistanceP, TextDistancePSerializer)
+]))
+
+# TODO remove next major, this is just in case someone was using it
+serializers = dse_graphson2_serializers
+
+dse_graphson2_deserializers = {
+ k: _make_gremlin_graphson2_deserializer(v)
+ for k, v in six.iteritems(GraphSON2Deserializer.get_type_definitions())
+}
+
+dse_graphson2_deserializers.update({
+ "dse:Distance": DistanceIO,
+})
+
+# TODO remove next major, this is just in case someone was using it
+deserializers = dse_graphson2_deserializers
+
+gremlin_graphson2_deserializers = dse_graphson2_deserializers.copy()
+gremlin_graphson2_deserializers.update({
+ 'g:Vertex': VertexDeserializerV2,
+ 'g:VertexProperty': VertexPropertyDeserializerV2,
+ 'g:Edge': EdgeDeserializerV2,
+ 'g:Property': PropertyDeserializerV2,
+ 'g:Path': PathDeserializerV2
+})
+
+if TraversalMetricsDeserializerV2:
+ gremlin_graphson2_deserializers.update({
+ 'g:TraversalMetrics': TraversalMetricsDeserializerV2,
+ 'g:Metrics': MetricsDeserializerV2
+ })
+
+# TODO remove next major, this is just in case someone was using it
+gremlin_deserializers = gremlin_graphson2_deserializers
+
+# GraphSON3
+dse_graphson3_serializers = OrderedDict([
+ (t, _GremlinGraphSONTypeSerializer(s))
+ for t, s in six.iteritems(GraphSON3Serializer.get_type_definitions())
+])
+
+dse_graphson3_serializers.update(OrderedDict([
+ (Distance, DistanceIO),
+ (GeoP, GeoPSerializer),
+ (TextDistanceP, TextDistancePSerializer)
+]))
+
+dse_graphson3_deserializers = {
+ k: _make_gremlin_graphson3_deserializer(v)
+ for k, v in six.iteritems(GraphSON3Deserializer.get_type_definitions())
+}
+
+dse_graphson3_deserializers.update({
+ "dse:Distance": DistanceIO
+})
+
+gremlin_graphson3_deserializers = dse_graphson3_deserializers.copy()
+gremlin_graphson3_deserializers.update({
+ 'g:Vertex': VertexDeserializerV3,
+ 'g:VertexProperty': VertexPropertyDeserializerV3,
+ 'g:Edge': EdgeDeserializerV3,
+ 'g:Property': PropertyDeserializerV3,
+ 'g:Path': PathDeserializerV3
+})
+
+if TraversalMetricsDeserializerV3:
+ gremlin_graphson3_deserializers.update({
+ 'g:TraversalMetrics': TraversalMetricsDeserializerV3,
+ 'g:Metrics': MetricsDeserializerV3
+ })
diff --git a/tests/integration/advanced/__init__.py b/cassandra/datastax/graph/fluent/predicates.py
similarity index 70%
copy from tests/integration/advanced/__init__.py
copy to cassandra/datastax/graph/fluent/predicates.py
index 662f5b8..6bfd6b3 100644
--- a/tests/integration/advanced/__init__.py
+++ b/cassandra/datastax/graph/fluent/predicates.py
@@ -1,23 +1,20 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License
+# limitations under the License.
try:
- import unittest2 as unittest
+ import gremlin_python
+ from cassandra.datastax.graph.fluent._predicates import *
except ImportError:
- import unittest # noqa
-
-try:
- from ccmlib import common
-except ImportError as e:
- raise unittest.SkipTest('ccm is a dependency for integration tests:', e)
+ # gremlinpython is not installed.
+ pass
diff --git a/tests/integration/advanced/__init__.py b/cassandra/datastax/graph/fluent/query.py
similarity index 70%
copy from tests/integration/advanced/__init__.py
copy to cassandra/datastax/graph/fluent/query.py
index 662f5b8..c5026cc 100644
--- a/tests/integration/advanced/__init__.py
+++ b/cassandra/datastax/graph/fluent/query.py
@@ -1,23 +1,20 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License
+# limitations under the License.
try:
- import unittest2 as unittest
+ import gremlin_python
+ from cassandra.datastax.graph.fluent._query import *
except ImportError:
- import unittest # noqa
-
-try:
- from ccmlib import common
-except ImportError as e:
- raise unittest.SkipTest('ccm is a dependency for integration tests:', e)
+ # gremlinpython is not installed.
+ pass
diff --git a/tests/integration/advanced/__init__.py b/cassandra/datastax/graph/fluent/serializers.py
similarity index 70%
copy from tests/integration/advanced/__init__.py
copy to cassandra/datastax/graph/fluent/serializers.py
index 662f5b8..680e613 100644
--- a/tests/integration/advanced/__init__.py
+++ b/cassandra/datastax/graph/fluent/serializers.py
@@ -1,23 +1,20 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License
+# limitations under the License.
try:
- import unittest2 as unittest
+ import gremlin_python
+ from cassandra.datastax.graph.fluent._serializers import *
except ImportError:
- import unittest # noqa
-
-try:
- from ccmlib import common
-except ImportError as e:
- raise unittest.SkipTest('ccm is a dependency for integration tests:', e)
+ # gremlinpython is not installed.
+ pass
diff --git a/cassandra/datastax/graph/graphson.py b/cassandra/datastax/graph/graphson.py
new file mode 100644
index 0000000..4b333eb
--- /dev/null
+++ b/cassandra/datastax/graph/graphson.py
@@ -0,0 +1,1151 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import base64
+import uuid
+import re
+import json
+from decimal import Decimal
+from collections import OrderedDict
+import logging
+import itertools
+from functools import partial
+
+import six
+
+try:
+ import ipaddress
+except ImportError:
+ ipaddress = None
+
+
+from cassandra.cqltypes import cql_types_from_string
+from cassandra.metadata import UserType
+from cassandra.util import Polygon, Point, LineString, Duration
+from cassandra.datastax.graph.types import Vertex, VertexProperty, Edge, Path, T
+
+__all__ = ['GraphSON1Serializer', 'GraphSON1Deserializer', 'GraphSON1TypeDeserializer',
+ 'GraphSON2Serializer', 'GraphSON2Deserializer', 'GraphSON2Reader',
+ 'GraphSON3Serializer', 'GraphSON3Deserializer', 'GraphSON3Reader',
+ 'to_bigint', 'to_int', 'to_double', 'to_float', 'to_smallint',
+ 'BooleanTypeIO', 'Int16TypeIO', 'Int32TypeIO', 'DoubleTypeIO',
+ 'FloatTypeIO', 'UUIDTypeIO', 'BigDecimalTypeIO', 'DurationTypeIO', 'InetTypeIO',
+ 'InstantTypeIO', 'LocalDateTypeIO', 'LocalTimeTypeIO', 'Int64TypeIO', 'BigIntegerTypeIO',
+ 'LocalDateTypeIO', 'PolygonTypeIO', 'PointTypeIO', 'LineStringTypeIO', 'BlobTypeIO',
+ 'GraphSON3Serializer', 'GraphSON3Deserializer', 'UserTypeIO', 'TypeWrapperTypeIO']
+
+"""
+Supported types:
+
+DSE Graph | GraphSON 2.0 | GraphSON 3.0 | Python Driver
+------------ | -------------- | -------------- | ------------
+text | string | string | str
+boolean | | | bool
+bigint | g:Int64 | g:Int64 | long
+int | g:Int32 | g:Int32 | int
+double | g:Double | g:Double | float
+float | g:Float | g:Float | float
+uuid | g:UUID | g:UUID | UUID
+bigdecimal | gx:BigDecimal | gx:BigDecimal | Decimal
+duration | gx:Duration | N/A | timedelta (Classic graph only)
+DSE Duration | N/A | dse:Duration | Duration (Core graph only)
+inet | gx:InetAddress | gx:InetAddress | str (unicode), IPV4Address/IPV6Address (PY3)
+timestamp | gx:Instant | gx:Instant | datetime.datetime
+date | gx:LocalDate | gx:LocalDate | datetime.date
+time | gx:LocalTime | gx:LocalTime | datetime.time
+smallint | gx:Int16 | gx:Int16 | int
+varint | gx:BigInteger | gx:BigInteger | long
+date | gx:LocalDate | gx:LocalDate | Date
+polygon | dse:Polygon | dse:Polygon | Polygon
+point | dse:Point | dse:Point | Point
+linestring | dse:Linestring | dse:LineString | LineString
+blob | dse:Blob | dse:Blob | bytearray, buffer (PY2), memoryview (PY3), bytes (PY3)
+blob | gx:ByteBuffer | gx:ByteBuffer | bytearray, buffer (PY2), memoryview (PY3), bytes (PY3)
+list | N/A | g:List | list (Core graph only)
+map | N/A | g:Map | dict (Core graph only)
+set | N/A | g:Set | set or list (Core graph only)
+ Can return a list due to numerical values returned by Java
+tuple | N/A | dse:Tuple | tuple (Core graph only)
+udt | N/A | dse:UDT | class or namedtuple (Core graph only)
+"""
+
+MAX_INT32 = 2 ** 32 - 1
+MIN_INT32 = -2 ** 31
+
+log = logging.getLogger(__name__)
+
+
+class _GraphSONTypeType(type):
+ """GraphSONType metaclass, required to create a class property."""
+
+ @property
+ def graphson_type(cls):
+ return "{0}:{1}".format(cls.prefix, cls.graphson_base_type)
+
+
+@six.add_metaclass(_GraphSONTypeType)
+class GraphSONTypeIO(object):
+ """Represent a serializable GraphSON type"""
+
+ prefix = 'g'
+ graphson_base_type = None
+ cql_type = None
+
+ @classmethod
+ def definition(cls, value, writer=None):
+ return {'cqlType': cls.cql_type}
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ return six.text_type(value)
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return value
+
+ @classmethod
+ def get_specialized_serializer(cls, value):
+ return cls
+
+
+class TextTypeIO(GraphSONTypeIO):
+ cql_type = 'text'
+
+
+class BooleanTypeIO(GraphSONTypeIO):
+ graphson_base_type = None
+ cql_type = 'boolean'
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ return bool(value)
+
+
+class IntegerTypeIO(GraphSONTypeIO):
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ return value
+
+ @classmethod
+ def get_specialized_serializer(cls, value):
+ if type(value) in six.integer_types and (value > MAX_INT32 or value < MIN_INT32):
+ return Int64TypeIO
+
+ return Int32TypeIO
+
+
+class Int16TypeIO(IntegerTypeIO):
+ prefix = 'gx'
+ graphson_base_type = 'Int16'
+ cql_type = 'smallint'
+
+
+class Int32TypeIO(IntegerTypeIO):
+ graphson_base_type = 'Int32'
+ cql_type = 'int'
+
+
+class Int64TypeIO(IntegerTypeIO):
+ graphson_base_type = 'Int64'
+ cql_type = 'bigint'
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ if six.PY3:
+ return value
+ return long(value)
+
+
+class FloatTypeIO(GraphSONTypeIO):
+ graphson_base_type = 'Float'
+ cql_type = 'float'
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ return value
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return float(value)
+
+
+class DoubleTypeIO(FloatTypeIO):
+ graphson_base_type = 'Double'
+ cql_type = 'double'
+
+
+class BigIntegerTypeIO(IntegerTypeIO):
+ prefix = 'gx'
+ graphson_base_type = 'BigInteger'
+
+
+class LocalDateTypeIO(GraphSONTypeIO):
+ FORMAT = '%Y-%m-%d'
+
+ prefix = 'gx'
+ graphson_base_type = 'LocalDate'
+ cql_type = 'date'
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ return value.isoformat()
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ try:
+ return datetime.datetime.strptime(value, cls.FORMAT).date()
+ except ValueError:
+ # negative date
+ return value
+
+
+class InstantTypeIO(GraphSONTypeIO):
+ prefix = 'gx'
+ graphson_base_type = 'Instant'
+ cql_type = 'timestamp'
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ if isinstance(value, datetime.datetime):
+ value = datetime.datetime(*value.utctimetuple()[:6]).replace(microsecond=value.microsecond)
+ else:
+ value = datetime.datetime.combine(value, datetime.datetime.min.time())
+
+ return "{0}Z".format(value.isoformat())
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ try:
+ d = datetime.datetime.strptime(value, '%Y-%m-%dT%H:%M:%S.%fZ')
+ except ValueError:
+ d = datetime.datetime.strptime(value, '%Y-%m-%dT%H:%M:%SZ')
+ return d
+
+
+class LocalTimeTypeIO(GraphSONTypeIO):
+ FORMATS = [
+ '%H:%M',
+ '%H:%M:%S',
+ '%H:%M:%S.%f'
+ ]
+
+ prefix = 'gx'
+ graphson_base_type = 'LocalTime'
+ cql_type = 'time'
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ return value.strftime(cls.FORMATS[2])
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ dt = None
+ for f in cls.FORMATS:
+ try:
+ dt = datetime.datetime.strptime(value, f)
+ break
+ except ValueError:
+ continue
+
+ if dt is None:
+ raise ValueError('Unable to decode LocalTime: {0}'.format(value))
+
+ return dt.time()
+
+
+class BlobTypeIO(GraphSONTypeIO):
+ prefix = 'dse'
+ graphson_base_type = 'Blob'
+ cql_type = 'blob'
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ value = base64.b64encode(value)
+ if six.PY3:
+ value = value.decode('utf-8')
+ return value
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return bytearray(base64.b64decode(value))
+
+
+class ByteBufferTypeIO(BlobTypeIO):
+ prefix = 'gx'
+ graphson_base_type = 'ByteBuffer'
+
+
+class UUIDTypeIO(GraphSONTypeIO):
+ graphson_base_type = 'UUID'
+ cql_type = 'uuid'
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return uuid.UUID(value)
+
+
+class BigDecimalTypeIO(GraphSONTypeIO):
+ prefix = 'gx'
+ graphson_base_type = 'BigDecimal'
+ cql_type = 'bigdecimal'
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return Decimal(value)
+
+
+class DurationTypeIO(GraphSONTypeIO):
+ prefix = 'gx'
+ graphson_base_type = 'Duration'
+ cql_type = 'duration'
+
+ _duration_regex = re.compile(r"""
+ ^P((?P<days>\d+)D)?
+ T((?P<hours>\d+)H)?
+ ((?P<minutes>\d+)M)?
+ ((?P<seconds>[0-9.]+)S)?$
+ """, re.VERBOSE)
+ _duration_format = "P{days}DT{hours}H{minutes}M{seconds}S"
+
+ _seconds_in_minute = 60
+ _seconds_in_hour = 60 * _seconds_in_minute
+ _seconds_in_day = 24 * _seconds_in_hour
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ total_seconds = int(value.total_seconds())
+ days, total_seconds = divmod(total_seconds, cls._seconds_in_day)
+ hours, total_seconds = divmod(total_seconds, cls._seconds_in_hour)
+ minutes, total_seconds = divmod(total_seconds, cls._seconds_in_minute)
+ total_seconds += value.microseconds / 1e6
+
+ return cls._duration_format.format(
+ days=int(days), hours=int(hours), minutes=int(minutes), seconds=total_seconds
+ )
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ duration = cls._duration_regex.match(value)
+ if duration is None:
+ raise ValueError('Invalid duration: {0}'.format(value))
+
+ duration = {k: float(v) if v is not None else 0
+ for k, v in six.iteritems(duration.groupdict())}
+ return datetime.timedelta(days=duration['days'], hours=duration['hours'],
+ minutes=duration['minutes'], seconds=duration['seconds'])
+
+
+class DseDurationTypeIO(GraphSONTypeIO):
+ prefix = 'dse'
+ graphson_base_type = 'Duration'
+ cql_type = 'duration'
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ return {
+ 'months': value.months,
+ 'days': value.days,
+ 'nanos': value.nanoseconds
+ }
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return Duration(
+ reader.deserialize(value['months']),
+ reader.deserialize(value['days']),
+ reader.deserialize(value['nanos'])
+ )
+
+
+class TypeWrapperTypeIO(GraphSONTypeIO):
+
+ @classmethod
+ def definition(cls, value, writer=None):
+ return {'cqlType': value.type_io.cql_type}
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ return value.type_io.serialize(value.value)
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return value.type_io.deserialize(value.value)
+
+
+class PointTypeIO(GraphSONTypeIO):
+ prefix = 'dse'
+ graphson_base_type = 'Point'
+ cql_type = "org.apache.cassandra.db.marshal.PointType"
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return Point.from_wkt(value)
+
+
+class LineStringTypeIO(GraphSONTypeIO):
+ prefix = 'dse'
+ graphson_base_type = 'LineString'
+ cql_type = "org.apache.cassandra.db.marshal.LineStringType"
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return LineString.from_wkt(value)
+
+
+class PolygonTypeIO(GraphSONTypeIO):
+ prefix = 'dse'
+ graphson_base_type = 'Polygon'
+ cql_type = "org.apache.cassandra.db.marshal.PolygonType"
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return Polygon.from_wkt(value)
+
+
+class InetTypeIO(GraphSONTypeIO):
+ prefix = 'gx'
+ graphson_base_type = 'InetAddress'
+ cql_type = 'inet'
+
+
+class VertexTypeIO(GraphSONTypeIO):
+ graphson_base_type = 'Vertex'
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ vertex = Vertex(id=reader.deserialize(value["id"]),
+ label=value["label"] if "label" in value else "vertex",
+ type='vertex',
+ properties={})
+ # avoid the properties processing in Vertex.__init__
+ vertex.properties = reader.deserialize(value.get('properties', {}))
+ return vertex
+
+
+class VertexPropertyTypeIO(GraphSONTypeIO):
+ graphson_base_type = 'VertexProperty'
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return VertexProperty(label=value['label'],
+ value=reader.deserialize(value["value"]),
+ properties=reader.deserialize(value.get('properties', {})))
+
+
+class EdgeTypeIO(GraphSONTypeIO):
+ graphson_base_type = 'Edge'
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ in_vertex = Vertex(id=reader.deserialize(value["inV"]),
+ label=value['inVLabel'],
+ type='vertex',
+ properties={})
+ out_vertex = Vertex(id=reader.deserialize(value["outV"]),
+ label=value['outVLabel'],
+ type='vertex',
+ properties={})
+ return Edge(
+ id=reader.deserialize(value["id"]),
+ label=value["label"] if "label" in value else "vertex",
+ type='edge',
+ properties=reader.deserialize(value.get("properties", {})),
+ inV=in_vertex,
+ inVLabel=value['inVLabel'],
+ outV=out_vertex,
+ outVLabel=value['outVLabel']
+ )
+
+
+class PropertyTypeIO(GraphSONTypeIO):
+ graphson_base_type = 'Property'
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return {value["key"]: reader.deserialize(value["value"])}
+
+
+class PathTypeIO(GraphSONTypeIO):
+ graphson_base_type = 'Path'
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ labels = [set(label) for label in reader.deserialize(value['labels'])]
+ objects = [obj for obj in reader.deserialize(value['objects'])]
+ p = Path(labels, [])
+ p.objects = objects # avoid the object processing in Path.__init__
+ return p
+
+
+class TraversalMetricsTypeIO(GraphSONTypeIO):
+ graphson_base_type = 'TraversalMetrics'
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return reader.deserialize(value)
+
+
+class MetricsTypeIO(GraphSONTypeIO):
+ graphson_base_type = 'Metrics'
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return reader.deserialize(value)
+
+
+class JsonMapTypeIO(GraphSONTypeIO):
+ """In GraphSON2, dict are simply serialized as json map"""
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ out = {}
+ for k, v in six.iteritems(value):
+ out[k] = writer.serialize(v, writer)
+
+ return out
+
+
+class MapTypeIO(GraphSONTypeIO):
+ """In GraphSON3, dict has its own type"""
+
+ graphson_base_type = 'Map'
+ cql_type = 'map'
+
+ @classmethod
+ def definition(cls, value, writer=None):
+ out = OrderedDict([('cqlType', cls.cql_type)])
+ out['definition'] = []
+ for k, v in six.iteritems(value):
+ # we only need the first key/value pair to write the definition
+ out['definition'].append(writer.definition(k))
+ out['definition'].append(writer.definition(v))
+ break
+ return out
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ out = []
+ for k, v in six.iteritems(value):
+ out.append(writer.serialize(k, writer))
+ out.append(writer.serialize(v, writer))
+
+ return out
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ out = {}
+ a, b = itertools.tee(value)
+ for key, val in zip(
+ itertools.islice(a, 0, None, 2),
+ itertools.islice(b, 1, None, 2)
+ ):
+ out[reader.deserialize(key)] = reader.deserialize(val)
+ return out
+
+
+class ListTypeIO(GraphSONTypeIO):
+ """In GraphSON3, list has its own type"""
+
+ graphson_base_type = 'List'
+ cql_type = 'list'
+
+ @classmethod
+ def definition(cls, value, writer=None):
+ out = OrderedDict([('cqlType', cls.cql_type)])
+ out['definition'] = []
+ if value:
+ out['definition'].append(writer.definition(value[0]))
+ return out
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ return [writer.serialize(v, writer) for v in value]
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return [reader.deserialize(obj) for obj in value]
+
+
+class SetTypeIO(GraphSONTypeIO):
+ """In GraphSON3, set has its own type"""
+
+ graphson_base_type = 'Set'
+ cql_type = 'set'
+
+ @classmethod
+ def definition(cls, value, writer=None):
+ out = OrderedDict([('cqlType', cls.cql_type)])
+ out['definition'] = []
+ for v in value:
+ # we only take into account the first value for the definition
+ out['definition'].append(writer.definition(v))
+ break
+ return out
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ return [writer.serialize(v, writer) for v in value]
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ lst = [reader.deserialize(obj) for obj in value]
+
+ s = set(lst)
+ if len(s) != len(lst):
+ log.warning("Coercing g:Set to list due to numerical values returned by Java. "
+ "See TINKERPOP-1844 for details.")
+ return lst
+
+ return s
+
+
+class BulkSetTypeIO(GraphSONTypeIO):
+ graphson_base_type = "BulkSet"
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ out = []
+
+ a, b = itertools.tee(value)
+ for val, bulk in zip(
+ itertools.islice(a, 0, None, 2),
+ itertools.islice(b, 1, None, 2)
+ ):
+ val = reader.deserialize(val)
+ bulk = reader.deserialize(bulk)
+ for n in range(bulk):
+ out.append(val)
+
+ return out
+
+
+class TupleTypeIO(GraphSONTypeIO):
+ prefix = 'dse'
+ graphson_base_type = 'Tuple'
+ cql_type = 'tuple'
+
+ @classmethod
+ def definition(cls, value, writer=None):
+ out = OrderedDict()
+ out['cqlType'] = cls.cql_type
+ serializers = [writer.get_serializer(s) for s in value]
+ out['definition'] = [s.definition(v, writer) for v, s in zip(value, serializers)]
+ return out
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ out = cls.definition(value, writer)
+ out['value'] = [writer.serialize(v, writer) for v in value]
+ return out
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return tuple(reader.deserialize(obj) for obj in value['value'])
+
+
+class UserTypeIO(GraphSONTypeIO):
+ prefix = 'dse'
+ graphson_base_type = 'UDT'
+ cql_type = 'udt'
+
+ FROZEN_REMOVAL_REGEX = re.compile(r'frozen<"*([^"]+)"*>')
+
+ @classmethod
+ def cql_types_from_string(cls, typ):
+ # sanitizing: remove frozen references and double quotes...
+ return cql_types_from_string(
+ re.sub(cls.FROZEN_REMOVAL_REGEX, r'\1', typ)
+ )
+
+ @classmethod
+ def get_udt_definition(cls, value, writer):
+ user_type_name = writer.user_types[type(value)]
+ keyspace = writer.context['graph_name']
+ return writer.context['cluster'].metadata.keyspaces[keyspace].user_types[user_type_name]
+
+ @classmethod
+ def is_collection(cls, typ):
+ return typ in ['list', 'tuple', 'map', 'set']
+
+ @classmethod
+ def is_udt(cls, typ, writer):
+ keyspace = writer.context['graph_name']
+ if keyspace in writer.context['cluster'].metadata.keyspaces:
+ return typ in writer.context['cluster'].metadata.keyspaces[keyspace].user_types
+ return False
+
+ @classmethod
+ def field_definition(cls, types, writer, name=None):
+ """
+ Build the udt field definition. This is required when we have a complex udt type.
+ """
+ index = -1
+ out = [OrderedDict() if name is None else OrderedDict([('fieldName', name)])]
+
+ while types:
+ index += 1
+ typ = types.pop(0)
+ if index > 0:
+ out.append(OrderedDict())
+
+ if cls.is_udt(typ, writer):
+ keyspace = writer.context['graph_name']
+ udt = writer.context['cluster'].metadata.keyspaces[keyspace].user_types[typ]
+ out[index].update(cls.definition(udt, writer))
+ elif cls.is_collection(typ):
+ out[index]['cqlType'] = typ
+ definition = cls.field_definition(types, writer)
+ out[index]['definition'] = definition if isinstance(definition, list) else [definition]
+ else:
+ out[index]['cqlType'] = typ
+
+ return out if len(out) > 1 else out[0]
+
+ @classmethod
+ def definition(cls, value, writer=None):
+ udt = value if isinstance(value, UserType) else cls.get_udt_definition(value, writer)
+ return OrderedDict([
+ ('cqlType', cls.cql_type),
+ ('keyspace', udt.keyspace),
+ ('name', udt.name),
+ ('definition', [
+ cls.field_definition(cls.cql_types_from_string(typ), writer, name=name)
+ for name, typ in zip(udt.field_names, udt.field_types)])
+ ])
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ udt = cls.get_udt_definition(value, writer)
+ out = cls.definition(value, writer)
+ out['value'] = []
+ for name, typ in zip(udt.field_names, udt.field_types):
+ out['value'].append(writer.serialize(getattr(value, name), writer))
+ return out
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ udt_class = reader.context['cluster']._user_types[value['keyspace']][value['name']]
+ kwargs = zip(
+ list(map(lambda v: v['fieldName'], value['definition'])),
+ [reader.deserialize(v) for v in value['value']]
+ )
+ return udt_class(**dict(kwargs))
+
+
+class TTypeIO(GraphSONTypeIO):
+ prefix = 'g'
+ graphson_base_type = 'T'
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return T.name_to_value[value]
+
+
+class _BaseGraphSONSerializer(object):
+
+ _serializers = OrderedDict()
+
+ @classmethod
+ def register(cls, type, serializer):
+ cls._serializers[type] = serializer
+
+ @classmethod
+ def get_type_definitions(cls):
+ return cls._serializers.copy()
+
+ @classmethod
+ def get_serializer(cls, value):
+ """
+ Get the serializer for a python object.
+
+ :param value: The python object.
+ """
+
+ # The serializer matching logic is as follows:
+ # 1. Try to find the python type by direct access.
+ # 2. Try to find the first serializer by class inheritance.
+ # 3. If no serializer found, return the raw value.
+
+ # Note that when trying to find the serializer by class inheritance,
+ # the order that serializers are registered is important. The use of
+ # an OrderedDict is to avoid the difference between executions.
+ serializer = None
+ try:
+ serializer = cls._serializers[type(value)]
+ except KeyError:
+ for key, serializer_ in cls._serializers.items():
+ if isinstance(value, key):
+ serializer = serializer_
+ break
+
+ if serializer:
+ # A serializer can have specialized serializers (e.g. for Int32 and Int64, so value dependent)
+ serializer = serializer.get_specialized_serializer(value)
+
+ return serializer
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ """
+ Serialize a python object to GraphSON.
+
+ e.g. 'P42DT10H5M37S'
+ e.g. {'key': value}
+
+ :param value: The python object to serialize.
+ :param writer: A graphson serializer for recursive types (Optional)
+ """
+ serializer = cls.get_serializer(value)
+ if serializer:
+ return serializer.serialize(value, writer or cls)
+
+ return value
+
+
+class GraphSON1Serializer(_BaseGraphSONSerializer):
+ """
+ Serialize python objects to graphson types.
+ """
+
+ # When we fall back to a superclass's serializer, we iterate over this map.
+ # We want that iteration order to be consistent, so we use an OrderedDict,
+ # not a dict.
+ _serializers = OrderedDict([
+ (str, TextTypeIO),
+ (bool, BooleanTypeIO),
+ (bytearray, ByteBufferTypeIO),
+ (Decimal, BigDecimalTypeIO),
+ (datetime.date, LocalDateTypeIO),
+ (datetime.time, LocalTimeTypeIO),
+ (datetime.timedelta, DurationTypeIO),
+ (datetime.datetime, InstantTypeIO),
+ (uuid.UUID, UUIDTypeIO),
+ (Polygon, PolygonTypeIO),
+ (Point, PointTypeIO),
+ (LineString, LineStringTypeIO),
+ (dict, JsonMapTypeIO),
+ (float, FloatTypeIO)
+ ])
+
+
+if ipaddress:
+ GraphSON1Serializer.register(ipaddress.IPv4Address, InetTypeIO)
+ GraphSON1Serializer.register(ipaddress.IPv6Address, InetTypeIO)
+
+if six.PY2:
+ GraphSON1Serializer.register(buffer, ByteBufferTypeIO)
+ GraphSON1Serializer.register(unicode, TextTypeIO)
+else:
+ GraphSON1Serializer.register(memoryview, ByteBufferTypeIO)
+ GraphSON1Serializer.register(bytes, ByteBufferTypeIO)
+
+
+class _BaseGraphSONDeserializer(object):
+
+ _deserializers = {}
+
+ @classmethod
+ def get_type_definitions(cls):
+ return cls._deserializers.copy()
+
+ @classmethod
+ def register(cls, graphson_type, serializer):
+ cls._deserializers[graphson_type] = serializer
+
+ @classmethod
+ def get_deserializer(cls, graphson_type):
+ try:
+ return cls._deserializers[graphson_type]
+ except KeyError:
+ raise ValueError('Invalid `graphson_type` specified: {}'.format(graphson_type))
+
+ @classmethod
+ def deserialize(cls, graphson_type, value):
+ """
+ Deserialize a `graphson_type` value to a python object.
+
+ :param graphson_type: The GraphSON type tag, e.g. 'gx:Instant'
+ :param value: The graphson value to deserialize.
+ """
+ return cls.get_deserializer(graphson_type).deserialize(value)
+
+
+class GraphSON1Deserializer(_BaseGraphSONDeserializer):
+ """
+ Deserialize graphson1 types to python objects.
+ """
+ _TYPES = [UUIDTypeIO, BigDecimalTypeIO, InstantTypeIO, BlobTypeIO, ByteBufferTypeIO,
+ PointTypeIO, LineStringTypeIO, PolygonTypeIO, LocalDateTypeIO,
+ LocalTimeTypeIO, DurationTypeIO, InetTypeIO]
+
+ _deserializers = {
+ t.graphson_type: t
+ for t in _TYPES
+ }
+
+ @classmethod
+ def deserialize_date(cls, value):
+ return cls._deserializers[LocalDateTypeIO.graphson_type].deserialize(value)
+
+ @classmethod
+ def deserialize_time(cls, value):
+ return cls._deserializers[LocalTimeTypeIO.graphson_type].deserialize(value)
+
+ @classmethod
+ def deserialize_timestamp(cls, value):
+ return cls._deserializers[InstantTypeIO.graphson_type].deserialize(value)
+
+ @classmethod
+ def deserialize_duration(cls, value):
+ return cls._deserializers[DurationTypeIO.graphson_type].deserialize(value)
+
+ @classmethod
+ def deserialize_int(cls, value):
+ return int(value)
+
+ deserialize_smallint = deserialize_int
+
+ deserialize_varint = deserialize_int
+
+ @classmethod
+ def deserialize_bigint(cls, value):
+ if six.PY3:
+ return cls.deserialize_int(value)
+ return long(value)
+
+ @classmethod
+ def deserialize_double(cls, value):
+ return float(value)
+
+ deserialize_float = deserialize_double
+
+ @classmethod
+ def deserialize_uuid(cls, value):
+ return cls._deserializers[UUIDTypeIO.graphson_type].deserialize(value)
+
+ @classmethod
+ def deserialize_decimal(cls, value):
+ return cls._deserializers[BigDecimalTypeIO.graphson_type].deserialize(value)
+
+ @classmethod
+ def deserialize_blob(cls, value):
+ return cls._deserializers[ByteBufferTypeIO.graphson_type].deserialize(value)
+
+ @classmethod
+ def deserialize_point(cls, value):
+ return cls._deserializers[PointTypeIO.graphson_type].deserialize(value)
+
+ @classmethod
+ def deserialize_linestring(cls, value):
+ return cls._deserializers[LineStringTypeIO.graphson_type].deserialize(value)
+
+ @classmethod
+ def deserialize_polygon(cls, value):
+ return cls._deserializers[PolygonTypeIO.graphson_type].deserialize(value)
+
+ @classmethod
+ def deserialize_inet(cls, value):
+ return value
+
+ @classmethod
+ def deserialize_boolean(cls, value):
+ return value
+
+
+# TODO Remove in the next major
+GraphSON1TypeDeserializer = GraphSON1Deserializer
+GraphSON1TypeSerializer = GraphSON1Serializer
+
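+# Deserialization sketch (illustrative inputs): the helpers on
+# GraphSON1Deserializer convert graphson1 text values to python types, e.g.
+#
+#     GraphSON1Deserializer.deserialize_int('42')      # -> 42
+#     GraphSON1Deserializer.deserialize_double('1.5')  # -> 1.5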
+
+class GraphSON2Serializer(_BaseGraphSONSerializer):
+ TYPE_KEY = "@type"
+ VALUE_KEY = "@value"
+
+ _serializers = GraphSON1Serializer.get_type_definitions()
+
+ def serialize(self, value, writer=None):
+ """
+ Serialize a type to GraphSON2.
+
+ e.g. {'@type': 'gx:Duration', '@value': 'P2DT4H'}
+
+ :param value: The python object to serialize.
+ """
+ serializer = self.get_serializer(value)
+ if not serializer:
+ raise ValueError("Unable to find a serializer for value of type: ".format(type(value)))
+
+ val = serializer.serialize(value, writer or self)
+ if serializer is TypeWrapperTypeIO:
+ graphson_base_type = value.type_io.graphson_base_type
+ graphson_type = value.type_io.graphson_type
+ else:
+ graphson_base_type = serializer.graphson_base_type
+ graphson_type = serializer.graphson_type
+
+ if graphson_base_type is None:
+ out = val
+ else:
+ out = {self.TYPE_KEY: graphson_type}
+ if val is not None:
+ out[self.VALUE_KEY] = val
+
+ return out
+
+
+GraphSON2Serializer.register(int, IntegerTypeIO)
+if six.PY2:
+ GraphSON2Serializer.register(long, IntegerTypeIO)
+
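+# Envelope sketch (the exact numeric tag depends on the Int32/Int64
+# specialization chosen by IntegerTypeIO, so the output below is illustrative):
+#
+#     GraphSON2Serializer().serialize(1)
+#     # -> {'@type': 'g:Int32', '@value': 1}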
+
+class GraphSON2Deserializer(_BaseGraphSONDeserializer):
+
+ _TYPES = GraphSON1Deserializer._TYPES + [
+ Int16TypeIO, Int32TypeIO, Int64TypeIO, DoubleTypeIO, FloatTypeIO,
+ BigIntegerTypeIO, VertexTypeIO, VertexPropertyTypeIO, EdgeTypeIO,
+ PathTypeIO, PropertyTypeIO, TraversalMetricsTypeIO, MetricsTypeIO]
+
+ _deserializers = {
+ t.graphson_type: t
+ for t in _TYPES
+ }
+
+
+class GraphSON2Reader(object):
+ """
+ GraphSON2 reader that parses JSON and deserializes it to python objects.
+ """
+
+ def __init__(self, context, extra_deserializer_map=None):
+ """
+ :param extra_deserializer_map: map from GraphSON type tag to deserializer instance implementing `deserialize`
+ """
+ self.context = context
+ self.deserializers = GraphSON2Deserializer.get_type_definitions()
+ if extra_deserializer_map:
+ self.deserializers.update(extra_deserializer_map)
+
+ def read(self, json_data):
+ """
+ Read and deserialize ``json_data``.
+ """
+ return self.deserialize(json.loads(json_data))
+
+ def deserialize(self, obj):
+ """
+ Deserialize GraphSON type-tagged dict values into objects mapped in self.deserializers
+ """
+ if isinstance(obj, dict):
+ try:
+ des = self.deserializers[obj[GraphSON2Serializer.TYPE_KEY]]
+ return des.deserialize(obj[GraphSON2Serializer.VALUE_KEY], self)
+ except KeyError:
+ pass
+ # lists and maps are treated as plain JSON objects (they could get dedicated deserializers later)
+ return {self.deserialize(k): self.deserialize(v) for k, v in six.iteritems(obj)}
+ elif isinstance(obj, list):
+ return [self.deserialize(o) for o in obj]
+ else:
+ return obj
+
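+# Reader sketch: plain JSON passes through untouched, while '@type'-tagged
+# dicts are routed to the matching deserializer (the context is mostly needed
+# for UDT deserialization, so a bare dict is enough here):
+#
+#     reader = GraphSON2Reader({'cluster': None})
+#     reader.read('{"@type": "g:Int32", "@value": 1}')  # -> 1
+#     reader.read('{"name": "jim"}')                    # -> {'name': 'jim'}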
+
+class TypeIOWrapper(object):
+ """Used to force a graphson type during serialization"""
+
+ type_io = None
+ value = None
+
+ def __init__(self, type_io, value):
+ self.type_io = type_io
+ self.value = value
+
+
+def _wrap_value(type_io, value):
+ return TypeIOWrapper(type_io, value)
+
+
+to_bigint = partial(_wrap_value, Int64TypeIO)
+to_int = partial(_wrap_value, Int32TypeIO)
+to_smallint = partial(_wrap_value, Int16TypeIO)
+to_double = partial(_wrap_value, DoubleTypeIO)
+to_float = partial(_wrap_value, FloatTypeIO)
+
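+# e.g. to_int(1) wraps the value in a TypeIOWrapper bound to Int32TypeIO so the
+# GraphSON3 serializer emits the g:Int32 tag regardless of the value's python
+# type (sketch; see the TypeWrapperTypeIO branch in GraphSON2Serializer.serialize
+# above -- the wrapper type itself is registered on GraphSON3Serializer below).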
+
+class GraphSON3Serializer(GraphSON2Serializer):
+
+ _serializers = GraphSON2Serializer.get_type_definitions()
+
+ context = None
+ """A dict of the serialization context"""
+
+ def __init__(self, context):
+ self.context = context
+ self.user_types = None
+
+ def definition(self, value):
+ serializer = self.get_serializer(value)
+ return serializer.definition(value, self)
+
+ def get_serializer(self, value):
+ """Custom get_serializer to support UDT/Tuple"""
+
+ serializer = super(GraphSON3Serializer, self).get_serializer(value)
+ is_namedtuple_udt = serializer is TupleTypeIO and hasattr(value, '_fields')
+ if not serializer or is_namedtuple_udt:
+ # Check if UDT
+ if self.user_types is None:
+ try:
+ user_types = self.context['cluster']._user_types[self.context['graph_name']]
+ self.user_types = dict(map(reversed, six.iteritems(user_types)))
+ except KeyError:
+ self.user_types = {}
+
+ serializer = UserTypeIO if (is_namedtuple_udt or (type(value) in self.user_types)) else serializer
+
+ return serializer
+
+
+GraphSON3Serializer.register(dict, MapTypeIO)
+GraphSON3Serializer.register(list, ListTypeIO)
+GraphSON3Serializer.register(set, SetTypeIO)
+GraphSON3Serializer.register(tuple, TupleTypeIO)
+GraphSON3Serializer.register(Duration, DseDurationTypeIO)
+GraphSON3Serializer.register(TypeIOWrapper, TypeWrapperTypeIO)
+
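+# GraphSON3 layers container, UDT and duration handling on top of the GraphSON2
+# definitions registered above: plain dicts, lists, sets and tuples are now
+# wrapped in typed GraphSON3 envelopes (e.g. a dict goes through MapTypeIO)
+# instead of being emitted as raw JSON.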
+
+class GraphSON3Deserializer(GraphSON2Deserializer):
+ _TYPES = GraphSON2Deserializer._TYPES + [MapTypeIO, ListTypeIO,
+ SetTypeIO, TupleTypeIO,
+ UserTypeIO, DseDurationTypeIO,
+ TTypeIO, BulkSetTypeIO]
+
+ _deserializers = {t.graphson_type: t for t in _TYPES}
+
+
+class GraphSON3Reader(GraphSON2Reader):
+ """
+ GraphSON3 reader that parses JSON and deserializes it to python objects.
+ """
+
+ def __init__(self, context, extra_deserializer_map=None):
+ """
+ :param context: A dict of the context, mostly used as context for udt deserialization.
+ :param extra_deserializer_map: map from GraphSON type tag to deserializer instance implementing `deserialize`
+ """
+ self.context = context
+ self.deserializers = GraphSON3Deserializer.get_type_definitions()
+ if extra_deserializer_map:
+ self.deserializers.update(extra_deserializer_map)
diff --git a/cassandra/datastax/graph/query.py b/cassandra/datastax/graph/query.py
new file mode 100644
index 0000000..7c0e265
--- /dev/null
+++ b/cassandra/datastax/graph/query.py
@@ -0,0 +1,332 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from warnings import warn
+
+import six
+
+from cassandra import ConsistencyLevel
+from cassandra.query import Statement, SimpleStatement
+from cassandra.datastax.graph.types import Vertex, Edge, Path, VertexProperty
+from cassandra.datastax.graph.graphson import GraphSON2Reader, GraphSON3Reader
+
+
+__all__ = [
+ 'GraphProtocol', 'GraphOptions', 'GraphStatement', 'SimpleGraphStatement',
+ 'single_object_row_factory', 'graph_result_row_factory', 'graph_object_row_factory',
+ 'graph_graphson2_row_factory', 'Result', 'graph_graphson3_row_factory'
+]
+
+# (attr, description, server option)
+_graph_options = (
+ ('graph_name', 'name of the targeted graph.', 'graph-name'),
+ ('graph_source', 'choose the graph traversal source, configured on the server side.', 'graph-source'),
+ ('graph_language', 'the language used in the queries (default "gremlin-groovy")', 'graph-language'),
+ ('graph_protocol', 'the graph protocol that the server should use for query results (default "graphson-1-0")', 'graph-results'),
+ ('graph_read_consistency_level', '''read :class:`cassandra.ConsistencyLevel` for graph queries (if distinct from session default).
+Setting this overrides the native :attr:`.Statement.consistency_level` for read operations from Cassandra persistence''', 'graph-read-consistency'),
+ ('graph_write_consistency_level', '''write :class:`cassandra.ConsistencyLevel` for graph queries (if distinct from session default).
+Setting this overrides the native :attr:`.Statement.consistency_level` for write operations to Cassandra persistence.''', 'graph-write-consistency')
+)
+_graph_option_names = tuple(option[0] for option in _graph_options)
+
+# this is defined by the execution profile attribute, not in graph options
+_request_timeout_key = 'request-timeout'
+
+
+class GraphProtocol(object):
+
+ GRAPHSON_1_0 = b'graphson-1.0'
+ """
+ GraphSON1
+ """
+
+ GRAPHSON_2_0 = b'graphson-2.0'
+ """
+ GraphSON2
+ """
+
+ GRAPHSON_3_0 = b'graphson-3.0'
+ """
+ GraphSON3
+ """
+
+
+class GraphOptions(object):
+ """
+ Options for DSE Graph Query handler.
+ """
+ # See _graph_options map above for notes on valid options
+
+ DEFAULT_GRAPH_PROTOCOL = GraphProtocol.GRAPHSON_1_0
+ DEFAULT_GRAPH_LANGUAGE = b'gremlin-groovy'
+
+ def __init__(self, **kwargs):
+ self._graph_options = {}
+ kwargs.setdefault('graph_source', 'g')
+ kwargs.setdefault('graph_language', GraphOptions.DEFAULT_GRAPH_LANGUAGE)
+ for attr, value in six.iteritems(kwargs):
+ if attr not in _graph_option_names:
+ warn("Unknown keyword argument received for GraphOptions: {0}".format(attr))
+ setattr(self, attr, value)
+
+ def copy(self):
+ new_options = GraphOptions()
+ new_options._graph_options = self._graph_options.copy()
+ return new_options
+
+ def update(self, options):
+ self._graph_options.update(options._graph_options)
+
+ def get_options_map(self, other_options=None):
+ """
+ Returns a map for these options updated with other options,
+ and mapped to graph payload types.
+ """
+ options = self._graph_options.copy()
+ if other_options:
+ options.update(other_options._graph_options)
+
+ # consistency levels are special-cased so they can be enums in the API, and names in the protocol
+ for cl in ('graph-write-consistency', 'graph-read-consistency'):
+ cl_enum = options.get(cl)
+ if cl_enum is not None:
+ options[cl] = six.b(ConsistencyLevel.value_to_name[cl_enum])
+ return options
+
+ def set_source_default(self):
+ """
+ Sets ``graph_source`` to the server-defined default traversal source ('default')
+ """
+ self.graph_source = 'default'
+
+ def set_source_analytics(self):
+ """
+ Sets ``graph_source`` to the server-defined analytic traversal source ('a')
+ """
+ self.graph_source = 'a'
+
+ def set_source_graph(self):
+ """
+ Sets ``graph_source`` to the server-defined graph traversal source ('g')
+ """
+ self.graph_source = 'g'
+
+ def set_graph_protocol(self, protocol):
+ """
+ Sets ``graph_protocol`` as server graph results format (See :class:`cassandra.datastax.graph.GraphProtocol`)
+ """
+ self.graph_protocol = protocol
+
+ @property
+ def is_default_source(self):
+ return self.graph_source in (b'default', None)
+
+ @property
+ def is_analytics_source(self):
+ """
+ True if ``graph_source`` is set to the server-defined analytics traversal source ('a')
+ """
+ return self.graph_source == b'a'
+
+ @property
+ def is_graph_source(self):
+ """
+ True if ``graph_source`` is set to the server-defined graph traversal source ('g')
+ """
+ return self.graph_source == b'g'
+
+
+for opt in _graph_options:
+
+ def get(self, key=opt[2]):
+ return self._graph_options.get(key)
+
+ def set(self, value, key=opt[2]):
+ if value is not None:
+ # normalize text here so it doesn't have to be done every time we get options map
+ if isinstance(value, six.text_type) and not isinstance(value, six.binary_type):
+ value = six.b(value)
+ self._graph_options[key] = value
+ else:
+ self._graph_options.pop(key, None)
+
+ def delete(self, key=opt[2]):
+ self._graph_options.pop(key, None)
+
+ setattr(GraphOptions, opt[0], property(get, set, delete, opt[1]))
+
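+# Options sketch (illustrative): the loop above synthesizes the graph_* properties,
+# and text values are normalized to bytes on assignment, e.g.
+#
+#     opts = GraphOptions(graph_name='my_graph')
+#     opts.graph_name         # -> b'my_graph'
+#     opts.get_options_map()  # includes {'graph-name': b'my_graph', 'graph-source': b'g', ...}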
+
+class GraphStatement(Statement):
+ """ An abstract class representing a graph query."""
+
+ @property
+ def query(self):
+ raise NotImplementedError()
+
+ def __str__(self):
+ return u'<GraphStatement query="{0}">'.format(self.query)
+ __repr__ = __str__
+
+
+class SimpleGraphStatement(GraphStatement, SimpleStatement):
+ """
+ Simple graph statement for :meth:`.Session.execute_graph`.
+ Takes the same parameters as :class:`.SimpleStatement`.
+ """
+ @property
+ def query(self):
+ return self._query_string
+
+
+def single_object_row_factory(column_names, rows):
+ """
+ Returns the JSON string value of graph results.
+ """
+ return [row[0] for row in rows]
+
+
+def graph_result_row_factory(column_names, rows):
+ """
+ Returns a :class:`Result` object that can load graph results and produce specific types.
+ The Result JSON is deserialized and unpacked from the top-level 'result' dict.
+ """
+ return [Result(json.loads(row[0])['result']) for row in rows]
+
+
+def graph_object_row_factory(column_names, rows):
+ """
+ Like :func:`~.graph_result_row_factory`, except known element types (:class:`~.Vertex`, :class:`~.Edge`) are
+ converted to their simplified objects. Some low-level metadata is shed in this conversion. Unknown result types are
+ still returned as :class:`Result`.
+ """
+ return _graph_object_sequence(json.loads(row[0])['result'] for row in rows)
+
+
+def _graph_object_sequence(objects):
+ for o in objects:
+ res = Result(o)
+ if isinstance(o, dict):
+ typ = res.value.get('type')
+ if typ == 'vertex':
+ res = res.as_vertex()
+ elif typ == 'edge':
+ res = res.as_edge()
+ yield res
+
+
+class _GraphSONContextRowFactory(object):
+ graphson_reader_class = None
+ graphson_reader_kwargs = None
+
+ def __init__(self, cluster):
+ context = {'cluster': cluster}
+ kwargs = self.graphson_reader_kwargs or {}
+ self.graphson_reader = self.graphson_reader_class(context, **kwargs)
+
+ def __call__(self, column_names, rows):
+ return [self.graphson_reader.read(row[0])['result'] for row in rows]
+
+
+class _GraphSON2RowFactory(_GraphSONContextRowFactory):
+ """Row factory to deserialize GraphSON2 results."""
+ graphson_reader_class = GraphSON2Reader
+
+
+class _GraphSON3RowFactory(_GraphSONContextRowFactory):
+ """Row factory to deserialize GraphSON3 results."""
+ graphson_reader_class = GraphSON3Reader
+
+
+graph_graphson2_row_factory = _GraphSON2RowFactory
+graph_graphson3_row_factory = _GraphSON3RowFactory
+
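+# These factories are intended to be used as ``row_factory`` on a graph
+# execution profile, e.g. (sketch, assumes a configured cluster elsewhere):
+#
+#     GraphExecutionProfile(row_factory=graph_graphson3_row_factory,
+#                           graph_options=GraphOptions(graph_protocol=GraphProtocol.GRAPHSON_3_0))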
+
+class Result(object):
+ """
+ Represents deserialized graph results.
+ Property and item getters are provided for convenience.
+ """
+
+ value = None
+ """
+ Deserialized value from the result
+ """
+
+ def __init__(self, value):
+ self.value = value
+
+ def __getattr__(self, attr):
+ if not isinstance(self.value, dict):
+ raise ValueError("Value cannot be accessed as a dict")
+
+ if attr in self.value:
+ return self.value[attr]
+
+ raise AttributeError("Result has no top-level attribute %r" % (attr,))
+
+ def __getitem__(self, item):
+ if isinstance(self.value, dict) and isinstance(item, six.string_types):
+ return self.value[item]
+ elif isinstance(self.value, list) and isinstance(item, int):
+ return self.value[item]
+ else:
+ raise ValueError("Result cannot be indexed by %r" % (item,))
+
+ def __str__(self):
+ return str(self.value)
+
+ def __repr__(self):
+ return "%s(%r)" % (Result.__name__, self.value)
+
+ def __eq__(self, other):
+ return self.value == other.value
+
+ def as_vertex(self):
+ """
+ Return a :class:`Vertex` parsed from this result
+
+ Raises TypeError if parsing fails (i.e. the result structure is not valid).
+ """
+ try:
+ return Vertex(self.id, self.label, self.type, self.value.get('properties', {}))
+ except (AttributeError, ValueError, TypeError):
+ raise TypeError("Could not create Vertex from %r" % (self,))
+
+ def as_edge(self):
+ """
+ Return a :class:`Edge` parsed from this result
+
+ Raises TypeError if parsing fails (i.e. the result structure is not valid).
+ """
+ try:
+ return Edge(self.id, self.label, self.type, self.value.get('properties', {}),
+ self.inV, self.inVLabel, self.outV, self.outVLabel)
+ except (AttributeError, ValueError, TypeError):
+ raise TypeError("Could not create Edge from %r" % (self,))
+
+ def as_path(self):
+ """
+ Return a :class:`Path` parsed from this result
+
+ Raises TypeError if parsing fails (i.e. the result structure is not valid).
+ """
+ try:
+ return Path(self.labels, self.objects)
+ except (AttributeError, ValueError, TypeError):
+ raise TypeError("Could not create Path from %r" % (self,))
+
+ def as_vertex_property(self):
+ return VertexProperty(self.value.get('label'), self.value.get('value'), self.value.get('properties', {}))
diff --git a/cassandra/datastax/graph/types.py b/cassandra/datastax/graph/types.py
new file mode 100644
index 0000000..9817c99
--- /dev/null
+++ b/cassandra/datastax/graph/types.py
@@ -0,0 +1,210 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ['Element', 'Vertex', 'Edge', 'VertexProperty', 'Path', 'T']
+
+
+class Element(object):
+
+ element_type = None
+
+ _attrs = ('id', 'label', 'type', 'properties')
+
+ def __init__(self, id, label, type, properties):
+ if type != self.element_type:
+ raise TypeError("Attempted to create %s from %s element", (type, self.element_type))
+
+ self.id = id
+ self.label = label
+ self.type = type
+ self.properties = self._extract_properties(properties)
+
+ @staticmethod
+ def _extract_properties(properties):
+ return dict(properties)
+
+ def __eq__(self, other):
+ return all(getattr(self, attr) == getattr(other, attr) for attr in self._attrs)
+
+ def __str__(self):
+ return str(dict((k, getattr(self, k)) for k in self._attrs))
+
+
+class Vertex(Element):
+ """
+ Represents a Vertex element from a graph query.
+
+ Vertex ``properties`` are extracted into a ``dict`` of property names to list of :class:`~VertexProperty` (list
+ because they are always encoded that way, and sometimes have multiple cardinality; VertexProperty because sometimes
+ the properties themselves have property maps).
+ """
+
+ element_type = 'vertex'
+
+ @staticmethod
+ def _extract_properties(properties):
+ # vertex properties are always encoded as a list, regardless of Cardinality
+ return dict((k, [VertexProperty(k, p['value'], p.get('properties')) for p in v]) for k, v in properties.items())
+
+ def __repr__(self):
+ properties = dict((name, [{'label': prop.label, 'value': prop.value, 'properties': prop.properties} for prop in prop_list])
+ for name, prop_list in self.properties.items())
+ return "%s(%r, %r, %r, %r)" % (self.__class__.__name__,
+ self.id, self.label,
+ self.type, properties)
+
+
+class VertexProperty(object):
+ """
+ Vertex properties have a top-level value and an optional ``dict`` of properties.
+ """
+
+ label = None
+ """
+ label of the property
+ """
+
+ value = None
+ """
+ Value of the property
+ """
+
+ properties = None
+ """
+ dict of properties attached to the property
+ """
+
+ def __init__(self, label, value, properties=None):
+ self.label = label
+ self.value = value
+ self.properties = properties or {}
+
+ def __eq__(self, other):
+ return isinstance(other, VertexProperty) and self.label == other.label and self.value == other.value and self.properties == other.properties
+
+ def __repr__(self):
+ return "%s(%r, %r, %r)" % (self.__class__.__name__, self.label, self.value, self.properties)
+
+
+class Edge(Element):
+ """
+ Represents an Edge element from a graph query.
+
+ Attributes match initializer parameters.
+ """
+
+ element_type = 'edge'
+
+ _attrs = Element._attrs + ('inV', 'inVLabel', 'outV', 'outVLabel')
+
+ def __init__(self, id, label, type, properties,
+ inV, inVLabel, outV, outVLabel):
+ super(Edge, self).__init__(id, label, type, properties)
+ self.inV = inV
+ self.inVLabel = inVLabel
+ self.outV = outV
+ self.outVLabel = outVLabel
+
+ def __repr__(self):
+ return "%s(%r, %r, %r, %r, %r, %r, %r, %r)" %\
+ (self.__class__.__name__,
+ self.id, self.label,
+ self.type, self.properties,
+ self.inV, self.inVLabel,
+ self.outV, self.outVLabel)
+
+
+class Path(object):
+ """
+ Represents a graph path.
+
+ Labels list is taken verbatim from the results.
+
+ Objects are either :class:`~.Result` or :class:`~.Vertex`/:class:`~.Edge` for recognized types
+ """
+
+ labels = None
+ """
+ List of labels in the path
+ """
+
+ objects = None
+ """
+ List of objects in the path
+ """
+
+ def __init__(self, labels, objects):
+ # TODO fix next major
+ # The Path class should not do any deserialization by itself. To fix in the next major.
+ from cassandra.datastax.graph.query import _graph_object_sequence
+ self.labels = labels
+ self.objects = list(_graph_object_sequence(objects))
+
+ def __eq__(self, other):
+ return self.labels == other.labels and self.objects == other.objects
+
+ def __str__(self):
+ return str({'labels': self.labels, 'objects': self.objects})
+
+ def __repr__(self):
+ return "%s(%r, %r)" % (self.__class__.__name__, self.labels, [o.value for o in self.objects])
+
+
+class T(object):
+ """
+ Represents a collection of tokens for more concise Traversal definitions.
+ """
+
+ name = None
+ val = None
+
+ # class attributes
+ id = None
+ """
+ Token referencing an element's id.
+ """
+
+ key = None
+ """
+ Token referencing a property key.
+ """
+ label = None
+ """
+ Token referencing an element's label.
+ """
+ value = None
+ """
+ Token referencing a property value.
+ """
+
+ def __init__(self, name, val):
+ self.name = name
+ self.val = val
+
+ def __str__(self):
+ return self.name
+
+ def __repr__(self):
+ return "T.%s" % (self.name, )
+
+
+T.id = T("id", 1)
+T.id_ = T("id_", 2)
+T.key = T("key", 3)
+T.label = T("label", 4)
+T.value = T("value", 5)
+
+T.name_to_value = {
+ 'id': T.id,
+ 'id_': T.id_,
+ 'key': T.key,
+ 'label': T.label,
+ 'value': T.value
+}
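+
+# e.g. TTypeIO.deserialize('label') in graphson.py resolves to T.label through
+# this name_to_value mapping.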
diff --git a/tests/integration/advanced/__init__.py b/cassandra/datastax/insights/__init__.py
similarity index 67%
copy from tests/integration/advanced/__init__.py
copy to cassandra/datastax/insights/__init__.py
index 662f5b8..2c9ca17 100644
--- a/tests/integration/advanced/__init__.py
+++ b/cassandra/datastax/insights/__init__.py
@@ -1,23 +1,13 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License
-
-try:
- import unittest2 as unittest
-except ImportError:
- import unittest # noqa
-
-try:
- from ccmlib import common
-except ImportError as e:
- raise unittest.SkipTest('ccm is a dependency for integration tests:', e)
+# limitations under the License.
diff --git a/cassandra/datastax/insights/registry.py b/cassandra/datastax/insights/registry.py
new file mode 100644
index 0000000..3dd1d25
--- /dev/null
+++ b/cassandra/datastax/insights/registry.py
@@ -0,0 +1,123 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import six
+from collections import OrderedDict
+from warnings import warn
+
+from cassandra.datastax.insights.util import namespace
+
+_NOT_SET = object()
+
+
+def _default_serializer_for_object(obj, policy):
+ # the insights server expects an 'options' dict for policy
+ # objects, but not for other objects
+ if policy:
+ return {'type': obj.__class__.__name__,
+ 'namespace': namespace(obj.__class__),
+ 'options': {}}
+ else:
+ return {'type': obj.__class__.__name__,
+ 'namespace': namespace(obj.__class__)}
+
+
+class InsightsSerializerRegistry(object):
+
+ initialized = False
+
+ def __init__(self, mapping_dict=None):
+ mapping_dict = mapping_dict or {}
+ class_order = self._class_topological_sort(mapping_dict)
+ self._mapping_dict = OrderedDict(
+ ((cls, mapping_dict[cls]) for cls in class_order)
+ )
+
+ def serialize(self, obj, policy=False, default=_NOT_SET, cls=None):
+ try:
+ return self._get_serializer(cls if cls is not None else obj.__class__)(obj)
+ except Exception:
+ if default is _NOT_SET:
+ result = _default_serializer_for_object(obj, policy)
+ else:
+ result = default
+
+ return result
+
+ def _get_serializer(self, cls):
+ try:
+ return self._mapping_dict[cls]
+ except KeyError:
+ for registered_cls, serializer in six.iteritems(self._mapping_dict):
+ if issubclass(cls, registered_cls):
+ return self._mapping_dict[registered_cls]
+ raise ValueError('No serializer registered for class {}'.format(cls))
+
+ def register(self, cls, serializer):
+ self._mapping_dict[cls] = serializer
+ self._mapping_dict = OrderedDict(
+ ((cls, self._mapping_dict[cls])
+ for cls in self._class_topological_sort(self._mapping_dict))
+ )
+
+ def register_serializer_for(self, cls):
+ """
+ Parameterized registration helper decorator. Given a class `cls`,
+ produces a function that registers the decorated function as a
+ serializer for it.
+ """
+ def decorator(serializer):
+ self.register(cls, serializer)
+ return serializer
+
+ return decorator
+
+ @staticmethod
+ def _class_topological_sort(classes):
+ """
+ A simple topological sort for classes. Takes an iterable of class objects
+ and returns a list A of those classes, ordered such that A[X] is never a
+ superclass of A[Y] for X < Y.
+
+ This is an inefficient sort, but that's ok because classes are infrequently
+ registered. It's more important that this be maintainable than fast.
+
+ We can't use `.sort()` or `sorted()` with a custom `key` -- those assume
+ a total ordering, which we don't have.
+ """
+ unsorted, sorted_ = list(classes), []
+ while unsorted:
+ head, tail = unsorted[0], unsorted[1:]
+
+ # if head has no subclasses remaining, it can safely go in the list
+ if not any(issubclass(x, head) for x in tail):
+ sorted_.append(head)
+ else:
+ # move to the back -- head has to wait until all its subclasses
+ # are sorted into the list
+ tail.append(head)
+
+ unsorted = tail
+
+ # check that sort is valid
+ for i, head in enumerate(sorted_):
+ for after_head_value in sorted_[(i + 1):]:
+ if issubclass(after_head_value, head):
+ warn('Sorting classes produced an invalid ordering.\n'
+ 'In: {classes}\n'
+ 'Out: {sorted_}'.format(classes=classes, sorted_=sorted_))
+ return sorted_
+
+
+insights_registry = InsightsSerializerRegistry()
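+
+# Registration sketch (MyPolicy is a hypothetical class; the real registrations
+# live in cassandra/datastax/insights/serializers.py):
+#
+#     @insights_registry.register_serializer_for(MyPolicy)
+#     def my_policy_serializer(policy):
+#         return {'type': policy.__class__.__name__,
+#                 'namespace': namespace(policy.__class__),
+#                 'options': {}}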
diff --git a/cassandra/datastax/insights/reporter.py b/cassandra/datastax/insights/reporter.py
new file mode 100644
index 0000000..b05a88d
--- /dev/null
+++ b/cassandra/datastax/insights/reporter.py
@@ -0,0 +1,222 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import Counter
+import datetime
+import json
+import logging
+import multiprocessing
+import random
+import platform
+import socket
+import ssl
+import sys
+from threading import Event, Thread
+import time
+import six
+
+from cassandra.policies import HostDistance
+from cassandra.util import ms_timestamp_from_datetime
+from cassandra.datastax.insights.registry import insights_registry
+from cassandra.datastax.insights.serializers import initialize_registry
+
+log = logging.getLogger(__name__)
+
+
+class MonitorReporter(Thread):
+
+ def __init__(self, interval_sec, session):
+ """
+ takes an int indicating interval between requests, a function returning
+ the connection to be used, and the timeout per request
+ """
+ # Thread is an old-style class so we can't super()
+ Thread.__init__(self, name='monitor_reporter')
+
+ initialize_registry(insights_registry)
+
+ self._interval, self._session = interval_sec, session
+
+ self._shutdown_event = Event()
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ self._send_via_rpc(self._get_startup_data())
+
+ # introduce some jitter -- send up to 1/10 of _interval early
+ self._shutdown_event.wait(self._interval * random.uniform(.9, 1))
+
+ while not self._shutdown_event.is_set():
+ start_time = time.time()
+
+ self._send_via_rpc(self._get_status_data())
+
+ elapsed = time.time() - start_time
+ self._shutdown_event.wait(max(self._interval - elapsed, 0.01))
+
+ # TODO: redundant with ConnectionHeartbeat.ShutdownException
+ class ShutDownException(Exception):
+ pass
+
+ def _send_via_rpc(self, data):
+ try:
+ self._session.execute(
+ "CALL InsightsRpc.reportInsight(%s)", (json.dumps(data),)
+ )
+ log.debug('Insights RPC data: {}'.format(data))
+ except Exception as e:
+ log.debug('Insights RPC send failed with {}'.format(e))
+ log.debug('Insights RPC data: {}'.format(data))
+
+ def _get_status_data(self):
+ cc = self._session.cluster.control_connection
+
+ connected_nodes = {
+ host.address: {
+ 'connections': state['open_count'],
+ 'inFlightQueries': state['in_flights']
+ }
+ for (host, state) in self._session.get_pool_state().items()
+ }
+
+ return {
+ 'metadata': {
+ # shared across drivers; never change
+ 'name': 'driver.status',
+ # format version
+ 'insightMappingId': 'v1',
+ 'insightType': 'EVENT',
+ # since epoch
+ 'timestamp': ms_timestamp_from_datetime(datetime.datetime.utcnow()),
+ 'tags': {
+ 'language': 'python'
+ }
+ },
+ # 'clientId', 'sessionId' and 'controlConnection' are mandatory;
+ # the rest of the properties are optional
+ 'data': {
+ # 'clientId' must be the same as the one provided in the startup message
+ 'clientId': str(self._session.cluster.client_id),
+ # 'sessionId' must be the same as the one provided in the startup message
+ 'sessionId': str(self._session.session_id),
+ 'controlConnection': cc._connection.host if cc._connection else None,
+ 'connectedNodes': connected_nodes
+ }
+ }
+
+ def _get_startup_data(self):
+ cc = self._session.cluster.control_connection
+ try:
+ local_ipaddr = cc._connection._socket.getsockname()[0]
+ except Exception as e:
+ local_ipaddr = None
+ log.debug('Unable to get local socket addr from {}: {}'.format(cc._connection, e))
+ hostname = socket.getfqdn()
+
+ host_distances_counter = Counter(
+ self._session.cluster.profile_manager.distance(host)
+ for host in self._session.hosts
+ )
+ host_distances_dict = {
+ 'local': host_distances_counter[HostDistance.LOCAL],
+ 'remote': host_distances_counter[HostDistance.REMOTE],
+ 'ignored': host_distances_counter[HostDistance.IGNORED]
+ }
+
+ try:
+ compression_type = cc._connection._compression_type
+ except AttributeError:
+ compression_type = 'NONE'
+
+ cert_validation = None
+ try:
+ if self._session.cluster.ssl_context:
+ if isinstance(self._session.cluster.ssl_context, ssl.SSLContext):
+ cert_validation = self._session.cluster.ssl_context.verify_mode == ssl.CERT_REQUIRED
+ else: # pyopenssl
+ from OpenSSL import SSL
+ cert_validation = self._session.cluster.ssl_context.get_verify_mode() != SSL.VERIFY_NONE
+ elif self._session.cluster.ssl_options:
+ cert_validation = self._session.cluster.ssl_options.get('cert_reqs') == ssl.CERT_REQUIRED
+ except Exception as e:
+ log.debug('Unable to get the cert validation: {}'.format(e))
+
+ uname_info = platform.uname()
+
+ return {
+ 'metadata': {
+ 'name': 'driver.startup',
+ 'insightMappingId': 'v1',
+ 'insightType': 'EVENT',
+ 'timestamp': ms_timestamp_from_datetime(datetime.datetime.utcnow()),
+ 'tags': {
+ 'language': 'python'
+ },
+ },
+ 'data': {
+ 'driverName': 'DataStax Python Driver',
+ 'driverVersion': sys.modules['cassandra'].__version__,
+ 'clientId': str(self._session.cluster.client_id),
+ 'sessionId': str(self._session.session_id),
+ 'applicationName': self._session.cluster.application_name or 'python',
+ 'applicationNameWasGenerated': not self._session.cluster.application_name,
+ 'applicationVersion': self._session.cluster.application_version,
+ 'contactPoints': self._session.cluster._endpoint_map_for_insights,
+ 'dataCenters': list(set(h.datacenter for h in self._session.cluster.metadata.all_hosts()
+ if (h.datacenter and
+ self._session.cluster.profile_manager.distance(h) == HostDistance.LOCAL))),
+ 'initialControlConnection': cc._connection.host if cc._connection else None,
+ 'protocolVersion': self._session.cluster.protocol_version,
+ 'localAddress': local_ipaddr,
+ 'hostName': hostname,
+ 'executionProfiles': insights_registry.serialize(self._session.cluster.profile_manager),
+ 'configuredConnectionLength': host_distances_dict,
+ 'heartbeatInterval': self._session.cluster.idle_heartbeat_interval,
+ 'compression': compression_type.upper() if compression_type else 'NONE',
+ 'reconnectionPolicy': insights_registry.serialize(self._session.cluster.reconnection_policy),
+ 'sslConfigured': {
+ 'enabled': bool(self._session.cluster.ssl_options or self._session.cluster.ssl_context),
+ 'certValidation': cert_validation
+ },
+ 'authProvider': {
+ 'type': (self._session.cluster.auth_provider.__class__.__name__
+ if self._session.cluster.auth_provider else
+ None)
+ },
+ 'otherOptions': {
+ },
+ 'platformInfo': {
+ 'os': {
+ 'name': uname_info.system if six.PY3 else uname_info[0],
+ 'version': uname_info.release if six.PY3 else uname_info[2],
+ 'arch': uname_info.machine if six.PY3 else uname_info[4]
+ },
+ 'cpus': {
+ 'length': multiprocessing.cpu_count(),
+ 'model': platform.processor()
+ },
+ 'runtime': {
+ 'python': sys.version,
+ 'event_loop': self._session.cluster.connection_class.__name__
+ }
+ },
+ 'periodicStatusInterval': self._interval
+ }
+ }
+
+ def stop(self):
+ log.debug("Shutting down Monitor Reporter")
+ self._shutdown_event.set()
+ self.join()
diff --git a/cassandra/datastax/insights/serializers.py b/cassandra/datastax/insights/serializers.py
new file mode 100644
index 0000000..aec4467
--- /dev/null
+++ b/cassandra/datastax/insights/serializers.py
@@ -0,0 +1,221 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import six
+
+
+def initialize_registry(insights_registry):
+ # This will be called from the cluster module, so we put all this behavior
+ # in a function to avoid circular imports
+
+ if insights_registry.initialized:
+ return False
+
+ from cassandra import ConsistencyLevel
+ from cassandra.cluster import (
+ ExecutionProfile, GraphExecutionProfile,
+ ProfileManager, ContinuousPagingOptions,
+ EXEC_PROFILE_DEFAULT, EXEC_PROFILE_GRAPH_DEFAULT,
+ EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT,
+ EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT,
+ _NOT_SET
+ )
+ from cassandra.datastax.graph import GraphOptions
+ from cassandra.datastax.insights.registry import insights_registry
+ from cassandra.datastax.insights.util import namespace
+ from cassandra.policies import (
+ RoundRobinPolicy,
+ DCAwareRoundRobinPolicy,
+ TokenAwarePolicy,
+ WhiteListRoundRobinPolicy,
+ HostFilterPolicy,
+ ConstantReconnectionPolicy,
+ ExponentialReconnectionPolicy,
+ RetryPolicy,
+ SpeculativeExecutionPolicy,
+ ConstantSpeculativeExecutionPolicy,
+ WrapperPolicy
+ )
+
+ import logging
+
+ log = logging.getLogger(__name__)
+
+ @insights_registry.register_serializer_for(RoundRobinPolicy)
+ def round_robin_policy_insights_serializer(policy):
+ return {'type': policy.__class__.__name__,
+ 'namespace': namespace(policy.__class__),
+ 'options': {}}
+
+ @insights_registry.register_serializer_for(DCAwareRoundRobinPolicy)
+ def dc_aware_round_robin_policy_insights_serializer(policy):
+ return {'type': policy.__class__.__name__,
+ 'namespace': namespace(policy.__class__),
+ 'options': {'local_dc': policy.local_dc,
+ 'used_hosts_per_remote_dc': policy.used_hosts_per_remote_dc}
+ }
+
+ @insights_registry.register_serializer_for(TokenAwarePolicy)
+ def token_aware_policy_insights_serializer(policy):
+ return {'type': policy.__class__.__name__,
+ 'namespace': namespace(policy.__class__),
+ 'options': {'child_policy': insights_registry.serialize(policy._child_policy,
+ policy=True),
+ 'shuffle_replicas': policy.shuffle_replicas}
+ }
+
+ @insights_registry.register_serializer_for(WhiteListRoundRobinPolicy)
+ def whitelist_round_robin_policy_insights_serializer(policy):
+ return {'type': policy.__class__.__name__,
+ 'namespace': namespace(policy.__class__),
+ 'options': {'allowed_hosts': policy._allowed_hosts}
+ }
+
+ @insights_registry.register_serializer_for(HostFilterPolicy)
+ def host_filter_policy_insights_serializer(policy):
+ return {
+ 'type': policy.__class__.__name__,
+ 'namespace': namespace(policy.__class__),
+ 'options': {'child_policy': insights_registry.serialize(policy._child_policy,
+ policy=True),
+ 'predicate': policy.predicate.__name__}
+ }
+
+ @insights_registry.register_serializer_for(ConstantReconnectionPolicy)
+ def constant_reconnection_policy_insights_serializer(policy):
+ return {'type': policy.__class__.__name__,
+ 'namespace': namespace(policy.__class__),
+ 'options': {'delay': policy.delay,
+ 'max_attempts': policy.max_attempts}
+ }
+
+ @insights_registry.register_serializer_for(ExponentialReconnectionPolicy)
+ def exponential_reconnection_policy_insights_serializer(policy):
+ return {'type': policy.__class__.__name__,
+ 'namespace': namespace(policy.__class__),
+ 'options': {'base_delay': policy.base_delay,
+ 'max_delay': policy.max_delay,
+ 'max_attempts': policy.max_attempts}
+ }
+
+ @insights_registry.register_serializer_for(RetryPolicy)
+ def retry_policy_insights_serializer(policy):
+ return {'type': policy.__class__.__name__,
+ 'namespace': namespace(policy.__class__),
+ 'options': {}}
+
+ @insights_registry.register_serializer_for(SpeculativeExecutionPolicy)
+ def speculative_execution_policy_insights_serializer(policy):
+ return {'type': policy.__class__.__name__,
+ 'namespace': namespace(policy.__class__),
+ 'options': {}}
+
+ @insights_registry.register_serializer_for(ConstantSpeculativeExecutionPolicy)
+ def constant_speculative_execution_policy_insights_serializer(policy):
+ return {'type': policy.__class__.__name__,
+ 'namespace': namespace(policy.__class__),
+ 'options': {'delay': policy.delay,
+ 'max_attempts': policy.max_attempts}
+ }
+
+ @insights_registry.register_serializer_for(WrapperPolicy)
+ def wrapper_policy_insights_serializer(policy):
+ return {'type': policy.__class__.__name__,
+ 'namespace': namespace(policy.__class__),
+ 'options': {
+ 'child_policy': insights_registry.serialize(policy._child_policy,
+ policy=True)
+ }}
+
+ @insights_registry.register_serializer_for(ExecutionProfile)
+ def execution_profile_insights_serializer(profile):
+ return {
+ 'loadBalancing': insights_registry.serialize(profile.load_balancing_policy,
+ policy=True),
+ 'retry': insights_registry.serialize(profile.retry_policy,
+ policy=True),
+ 'readTimeout': profile.request_timeout,
+ 'consistency': ConsistencyLevel.value_to_name.get(profile.consistency_level, None),
+ 'serialConsistency': ConsistencyLevel.value_to_name.get(profile.serial_consistency_level, None),
+ 'continuousPagingOptions': (insights_registry.serialize(profile.continuous_paging_options)
+ if (profile.continuous_paging_options is not None and
+ profile.continuous_paging_options is not _NOT_SET) else
+ None),
+ 'speculativeExecution': insights_registry.serialize(profile.speculative_execution_policy),
+ 'graphOptions': None
+ }
+
+ @insights_registry.register_serializer_for(GraphExecutionProfile)
+ def graph_execution_profile_insights_serializer(profile):
+ rv = insights_registry.serialize(profile, cls=ExecutionProfile)
+ rv['graphOptions'] = insights_registry.serialize(profile.graph_options)
+ return rv
+
+ _EXEC_PROFILE_DEFAULT_KEYS = (EXEC_PROFILE_DEFAULT,
+ EXEC_PROFILE_GRAPH_DEFAULT,
+ EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT,
+ EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT)
+
+ @insights_registry.register_serializer_for(ProfileManager)
+ def profile_manager_insights_serializer(manager):
+ defaults = {
+ # Insights's expected default
+ 'default': insights_registry.serialize(manager.profiles[EXEC_PROFILE_DEFAULT]),
+ # named entries for the driver's default profiles (the default is duplicated under both keys)
+ 'EXEC_PROFILE_DEFAULT': insights_registry.serialize(manager.profiles[EXEC_PROFILE_DEFAULT]),
+ 'EXEC_PROFILE_GRAPH_DEFAULT': insights_registry.serialize(manager.profiles[EXEC_PROFILE_GRAPH_DEFAULT]),
+ 'EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT': insights_registry.serialize(
+ manager.profiles[EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT]
+ ),
+ 'EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT': insights_registry.serialize(
+ manager.profiles[EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT]
+ )
+ }
+ other = {
+ key: insights_registry.serialize(value)
+ for key, value in manager.profiles.items()
+ if key not in _EXEC_PROFILE_DEFAULT_KEYS
+ }
+ overlapping_keys = set(defaults) & set(other)
+ if overlapping_keys:
+ log.debug('The following key names overlap default key sentinel keys '
+ 'and these non-default EPs will not be displayed in Insights '
+ ': {}'.format(list(overlapping_keys)))
+
+ other.update(defaults)
+ return other
+
+ @insights_registry.register_serializer_for(GraphOptions)
+ def graph_options_insights_serializer(options):
+ rv = {
+ 'source': options.graph_source,
+ 'language': options.graph_language,
+ 'graphProtocol': options.graph_protocol
+ }
+ updates = {k: v.decode('utf-8') for k, v in six.iteritems(rv)
+ if isinstance(v, six.binary_type)}
+ rv.update(updates)
+ return rv
+
+ @insights_registry.register_serializer_for(ContinuousPagingOptions)
+ def continuous_paging_options_insights_serializer(paging_options):
+ return {
+ 'page_unit': paging_options.page_unit,
+ 'max_pages': paging_options.max_pages,
+ 'max_pages_per_second': paging_options.max_pages_per_second,
+ 'max_queue_size': paging_options.max_queue_size
+ }
+
+ insights_registry.initialized = True
+ return True
diff --git a/cassandra/datastax/insights/util.py b/cassandra/datastax/insights/util.py
new file mode 100644
index 0000000..a483b3f
--- /dev/null
+++ b/cassandra/datastax/insights/util.py
@@ -0,0 +1,75 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import traceback
+from warnings import warn
+
+from cassandra.util import Version
+
+
+DSE_60 = Version('6.0.0')
+DSE_51_MIN_SUPPORTED = Version('5.1.13')
+DSE_60_MIN_SUPPORTED = Version('6.0.5')
+
+
+log = logging.getLogger(__name__)
+
+
+def namespace(cls):
+ """
+ Best-effort method for getting the namespace in which a class is defined.
+ """
+ try:
+ # __module__ can be None
+ module = cls.__module__ or ''
+ except Exception:
+ warn("Unable to obtain namespace for {cls} for Insights, returning ''. "
+ "Exception: \n{e}".format(e=traceback.format_exc(), cls=cls))
+ module = ''
+
+ module_internal_namespace = _module_internal_namespace_or_emtpy_string(cls)
+ if module_internal_namespace:
+ return '.'.join((module, module_internal_namespace))
+ return module
+
+
+def _module_internal_namespace_or_emtpy_string(cls):
+ """
+ Best-effort method for getting the module-internal namespace in which a
+ class is defined -- i.e. the namespace _inside_ the module.
+ """
+ try:
+ qualname = cls.__qualname__
+ except AttributeError:
+ return ''
+
+ return '.'.join(
+ # the last segment is the name of the class -- use everything else
+ qualname.split('.')[:-1]
+ )
+
+
+def version_supports_insights(dse_version):
+ if dse_version:
+ try:
+ dse_version = Version(dse_version)
+ return (DSE_51_MIN_SUPPORTED <= dse_version < DSE_60
+ or
+ DSE_60_MIN_SUPPORTED <= dse_version)
+ except Exception:
+ warn("Unable to check version {v} for Insights compatibility, returning False. "
+ "Exception: \n{e}".format(e=traceback.format_exc(), v=dse_version))
+
+ return False
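+# e.g. version_supports_insights('6.0.5') -> True and
+# version_supports_insights('6.0.0') -> False per the DSE_* bounds above;
+# unparsable versions warn and return False.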
diff --git a/cassandra/encoder.py b/cassandra/encoder.py
index 00f7bf1..f2c3f8d 100644
--- a/cassandra/encoder.py
+++ b/cassandra/encoder.py
@@ -1,243 +1,249 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
These functions are used to convert Python objects into CQL strings.
When non-prepared statements are executed, these encoder functions are
called on each query parameter.
"""
import logging
log = logging.getLogger(__name__)
from binascii import hexlify
import calendar
import datetime
import math
import sys
import types
from uuid import UUID
import six
+from cassandra.util import (OrderedDict, OrderedMap, OrderedMapSerializedKey,
+ sortedset, Time, Date, Point, LineString, Polygon)
+
if six.PY3:
import ipaddress
-from cassandra.util import (OrderedDict, OrderedMap, OrderedMapSerializedKey,
- sortedset, Time, Date)
-
if six.PY3:
long = int
def cql_quote(term):
# The ordering of this method is important for the result of this method to
# be a native str type (for both Python 2 and 3)
if isinstance(term, str):
return "'%s'" % str(term).replace("'", "''")
# This branch of the if statement will only be used by Python 2 to catch
# unicode strings, text_type is used to prevent type errors with Python 3.
elif isinstance(term, six.text_type):
return "'%s'" % term.encode('utf8').replace("'", "''")
else:
return str(term)
class ValueSequence(list):
pass
class Encoder(object):
"""
A container for mapping python types to CQL string literals when working
with non-prepared statements. The type :attr:`~.Encoder.mapping` can be
directly customized by users.
"""
mapping = None
"""
A map of python types to encoder functions.
"""
def __init__(self):
self.mapping = {
float: self.cql_encode_float,
bytearray: self.cql_encode_bytes,
str: self.cql_encode_str,
int: self.cql_encode_object,
UUID: self.cql_encode_object,
datetime.datetime: self.cql_encode_datetime,
datetime.date: self.cql_encode_date,
datetime.time: self.cql_encode_time,
Date: self.cql_encode_date_ext,
Time: self.cql_encode_time,
dict: self.cql_encode_map_collection,
OrderedDict: self.cql_encode_map_collection,
OrderedMap: self.cql_encode_map_collection,
OrderedMapSerializedKey: self.cql_encode_map_collection,
list: self.cql_encode_list_collection,
tuple: self.cql_encode_list_collection, # TODO: change to tuple in next major
set: self.cql_encode_set_collection,
sortedset: self.cql_encode_set_collection,
frozenset: self.cql_encode_set_collection,
types.GeneratorType: self.cql_encode_list_collection,
- ValueSequence: self.cql_encode_sequence
+ ValueSequence: self.cql_encode_sequence,
+ Point: self.cql_encode_str_quoted,
+ LineString: self.cql_encode_str_quoted,
+ Polygon: self.cql_encode_str_quoted
}
if six.PY2:
self.mapping.update({
unicode: self.cql_encode_unicode,
buffer: self.cql_encode_bytes,
long: self.cql_encode_object,
types.NoneType: self.cql_encode_none,
})
else:
self.mapping.update({
memoryview: self.cql_encode_bytes,
bytes: self.cql_encode_bytes,
type(None): self.cql_encode_none,
ipaddress.IPv4Address: self.cql_encode_ipaddress,
ipaddress.IPv6Address: self.cql_encode_ipaddress
})
def cql_encode_none(self, val):
"""
Converts :const:`None` to the string 'NULL'.
"""
return 'NULL'
def cql_encode_unicode(self, val):
"""
Converts :class:`unicode` objects to UTF-8 encoded strings with quote escaping.
"""
return cql_quote(val.encode('utf-8'))
def cql_encode_str(self, val):
"""
Escapes quotes in :class:`str` objects.
"""
return cql_quote(val)
+ def cql_encode_str_quoted(self, val):
+ return "'%s'" % val
+
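+ # e.g. cql_encode_str_quoted(Point(1.0, 2.0)) produces a quoted WKT literal
+ # such as "'POINT (1.0 2.0)'" (sketch; relies on the geometry types' __str__).
+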
if six.PY3:
def cql_encode_bytes(self, val):
return (b'0x' + hexlify(val)).decode('utf-8')
elif sys.version_info >= (2, 7):
def cql_encode_bytes(self, val): # noqa
return b'0x' + hexlify(val)
else:
# python 2.6 requires string or read-only buffer for hexlify
def cql_encode_bytes(self, val): # noqa
return b'0x' + hexlify(buffer(val))
def cql_encode_object(self, val):
"""
Default encoder for all objects that do not have a specific encoder function
registered. This function simply calls :meth:`str()` on the object.
"""
return str(val)
def cql_encode_float(self, val):
"""
Encode floats using repr to preserve precision
"""
if math.isinf(val):
return 'Infinity' if val > 0 else '-Infinity'
elif math.isnan(val):
return 'NaN'
else:
return repr(val)
def cql_encode_datetime(self, val):
"""
Converts a :class:`datetime.datetime` object to a (string) integer timestamp
with millisecond precision.
"""
timestamp = calendar.timegm(val.utctimetuple())
return str(long(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3))
def cql_encode_date(self, val):
"""
Converts a :class:`datetime.date` object to a string with format
``YYYY-MM-DD``.
"""
return "'%s'" % val.strftime('%Y-%m-%d')
def cql_encode_time(self, val):
"""
Converts a :class:`cassandra.util.Time` object to a string with format
``HH:MM:SS.mmmuuunnn``.
"""
return "'%s'" % val
def cql_encode_date_ext(self, val):
"""
Encodes a :class:`cassandra.util.Date` object as an integer
"""
# using the int form in case the Date exceeds datetime.[MIN|MAX]YEAR
return str(val.days_from_epoch + 2 ** 31)
def cql_encode_sequence(self, val):
"""
Converts a sequence to a string of the form ``(item1, item2, ...)``. This
is suitable for ``IN`` value lists.
"""
return '(%s)' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v)
for v in val)
cql_encode_tuple = cql_encode_sequence
"""
Converts a sequence to a string of the form ``(item1, item2, ...)``. This
is suitable for ``tuple`` type columns.
"""
def cql_encode_map_collection(self, val):
"""
Converts a dict into a string of the form ``{key1: val1, key2: val2, ...}``.
This is suitable for ``map`` type columns.
"""
return '{%s}' % ', '.join('%s: %s' % (
self.mapping.get(type(k), self.cql_encode_object)(k),
self.mapping.get(type(v), self.cql_encode_object)(v)
) for k, v in six.iteritems(val))
def cql_encode_list_collection(self, val):
"""
Converts a sequence to a string of the form ``[item1, item2, ...]``. This
is suitable for ``list`` type columns.
"""
return '[%s]' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val)
def cql_encode_set_collection(self, val):
"""
Converts a sequence to a string of the form ``{item1, item2, ...}``. This
is suitable for ``set`` type columns.
"""
return '{%s}' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val)
def cql_encode_all_types(self, val, as_text_type=False):
"""
Converts any type into a CQL string, defaulting to ``cql_encode_object``
if :attr:`~Encoder.mapping` does not contain an entry for the type.
"""
encoded = self.mapping.get(type(val), self.cql_encode_object)(val)
if as_text_type and not isinstance(encoded, six.text_type):
return encoded.decode('utf-8')
return encoded
if six.PY3:
def cql_encode_ipaddress(self, val):
"""
Converts an ipaddress (IPV4Address, IPV6Address) to a CQL string. This
is suitable for ``inet`` type columns.
"""
return "'%s'" % val.compressed
diff --git a/tests/integration/advanced/__init__.py b/cassandra/graph/__init__.py
similarity index 67%
copy from tests/integration/advanced/__init__.py
copy to cassandra/graph/__init__.py
index 662f5b8..51bd1de 100644
--- a/tests/integration/advanced/__init__.py
+++ b/cassandra/graph/__init__.py
@@ -1,23 +1,16 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License
+# limitations under the License.
-try:
- import unittest2 as unittest
-except ImportError:
- import unittest # noqa
-
-try:
- from ccmlib import common
-except ImportError as e:
- raise unittest.SkipTest('ccm is a dependency for integration tests:', e)
+# This is only for backward compatibility when migrating from dse-driver.
+from cassandra.datastax.graph import *
\ No newline at end of file
diff --git a/tests/integration/advanced/__init__.py b/cassandra/graph/graphson.py
similarity index 67%
copy from tests/integration/advanced/__init__.py
copy to cassandra/graph/graphson.py
index 662f5b8..d37c172 100644
--- a/tests/integration/advanced/__init__.py
+++ b/cassandra/graph/graphson.py
@@ -1,23 +1,16 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License
+# limitations under the License.
-try:
- import unittest2 as unittest
-except ImportError:
- import unittest # noqa
-
-try:
- from ccmlib import common
-except ImportError as e:
- raise unittest.SkipTest('ccm is a dependency for integration tests:', e)
+# This is only for backward compatibility when migrating from dse-driver.
+from cassandra.datastax.graph.graphson import *
diff --git a/tests/integration/advanced/__init__.py b/cassandra/graph/query.py
similarity index 67%
copy from tests/integration/advanced/__init__.py
copy to cassandra/graph/query.py
index 662f5b8..50eef72 100644
--- a/tests/integration/advanced/__init__.py
+++ b/cassandra/graph/query.py
@@ -1,23 +1,16 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License
+# limitations under the License.
-try:
- import unittest2 as unittest
-except ImportError:
- import unittest # noqa
-
-try:
- from ccmlib import common
-except ImportError as e:
- raise unittest.SkipTest('ccm is a dependency for integration tests:', e)
+# This is only for backward compatibility when migrating from dse-driver.
+from cassandra.datastax.graph.query import *
diff --git a/tests/integration/advanced/__init__.py b/cassandra/graph/types.py
similarity index 67%
copy from tests/integration/advanced/__init__.py
copy to cassandra/graph/types.py
index 662f5b8..c8b613f 100644
--- a/tests/integration/advanced/__init__.py
+++ b/cassandra/graph/types.py
@@ -1,23 +1,16 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License
+# limitations under the License.
-try:
- import unittest2 as unittest
-except ImportError:
- import unittest # noqa
-
-try:
- from ccmlib import common
-except ImportError as e:
- raise unittest.SkipTest('ccm is a dependency for integration tests:', e)
+# This is only for backward compatibility when migrating from dse-driver.
+from cassandra.datastax.graph.types import *
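# Illustrative sketch (not part of the patch): the new cassandra.graph package is a
# thin re-export of cassandra.datastax.graph, so code written against dse-driver keeps
# working unchanged. GraphOptions and SimpleGraphStatement are examples of names
# assumed to be re-exported through the star imports above.
from cassandra.graph import GraphOptions, SimpleGraphStatement
from cassandra.datastax.graph import GraphOptions as _GraphOptions

assert GraphOptions is _GraphOptions      # same class object, two import paths
stmt = SimpleGraphStatement("g.V().limit(1)")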
diff --git a/cassandra/io/asyncioreactor.py b/cassandra/io/asyncioreactor.py
index b386388..7cb0444 100644
--- a/cassandra/io/asyncioreactor.py
+++ b/cassandra/io/asyncioreactor.py
@@ -1,215 +1,225 @@
from cassandra.connection import Connection, ConnectionShutdown
import asyncio
import logging
import os
import socket
import ssl
from threading import Lock, Thread, get_ident
log = logging.getLogger(__name__)
# This module uses ``yield from`` and ``@asyncio.coroutine`` over ``await`` and
# ``async def`` for pre-Python-3.5 compatibility, so keep in mind that the
# managed coroutines are generator-based, not native coroutines. See PEP 492:
# https://www.python.org/dev/peps/pep-0492/#coroutine-objects
try:
asyncio.run_coroutine_threadsafe
except AttributeError:
raise ImportError(
'Cannot use asyncioreactor without access to '
'asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)'
)
class AsyncioTimer(object):
"""
An ``asyncioreactor``-specific Timer. Similar to :class:`.connection.Timer`,
but with a slightly different API due to limitations in the underlying
``call_later`` interface. Not meant to be used with a
:class:`.connection.TimerManager`.
"""
@property
def end(self):
raise NotImplementedError('{} is not compatible with TimerManager and '
'does not implement .end()')
def __init__(self, timeout, callback, loop):
delayed = self._call_delayed_coro(timeout=timeout,
callback=callback,
loop=loop)
self._handle = asyncio.run_coroutine_threadsafe(delayed, loop=loop)
@staticmethod
@asyncio.coroutine
def _call_delayed_coro(timeout, callback, loop):
yield from asyncio.sleep(timeout, loop=loop)
return callback()
def __lt__(self, other):
try:
return self._handle < other._handle
except AttributeError:
return NotImplemented
def cancel(self):
self._handle.cancel()
def finish(self):
# connection.Timer method not implemented here because we can't inspect
# the Handle returned from call_later
raise NotImplementedError('{} is not compatible with TimerManager and '
'does not implement .finish()')
class AsyncioConnection(Connection):
"""
An experimental implementation of :class:`.Connection` that uses the
``asyncio`` module in the Python standard library for its event loop.
Note that it requires ``asyncio`` features that were only introduced in the
3.4 line in 3.4.6, and in the 3.5 line in 3.5.1.
"""
_loop = None
_pid = os.getpid()
_lock = Lock()
_loop_thread = None
_write_queue = None
+ _write_queue_lock = None
def __init__(self, *args, **kwargs):
Connection.__init__(self, *args, **kwargs)
self._connect_socket()
self._socket.setblocking(0)
self._write_queue = asyncio.Queue(loop=self._loop)
+ self._write_queue_lock = asyncio.Lock(loop=self._loop)
# see initialize_reactor -- loop is running in a separate thread, so we
# have to use a threadsafe call
self._read_watcher = asyncio.run_coroutine_threadsafe(
self.handle_read(), loop=self._loop
)
self._write_watcher = asyncio.run_coroutine_threadsafe(
self.handle_write(), loop=self._loop
)
self._send_options_message()
@classmethod
def initialize_reactor(cls):
with cls._lock:
if cls._pid != os.getpid():
cls._loop = None
if cls._loop is None:
cls._loop = asyncio.new_event_loop()
asyncio.set_event_loop(cls._loop)
if not cls._loop_thread:
# daemonize so the loop will be shut down on interpreter
# shutdown
cls._loop_thread = Thread(target=cls._loop.run_forever,
daemon=True, name="asyncio_thread")
cls._loop_thread.start()
@classmethod
def create_timer(cls, timeout, callback):
return AsyncioTimer(timeout, callback, loop=cls._loop)
def close(self):
with self.lock:
if self.is_closed:
return
self.is_closed = True
# close from the loop thread to avoid races when removing file
# descriptors
asyncio.run_coroutine_threadsafe(
self._close(), loop=self._loop
)
@asyncio.coroutine
def _close(self):
log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint))
if self._write_watcher:
self._write_watcher.cancel()
if self._read_watcher:
self._read_watcher.cancel()
if self._socket:
self._loop.remove_writer(self._socket.fileno())
self._loop.remove_reader(self._socket.fileno())
self._socket.close()
log.debug("Closed socket to %s" % (self.endpoint,))
if not self.is_defunct:
self.error_all_requests(
ConnectionShutdown("Connection to %s was closed" % self.endpoint))
# don't leave in-progress operations hanging
self.connected_event.set()
def push(self, data):
buff_size = self.out_buffer_size
if len(data) > buff_size:
+ chunks = []
for i in range(0, len(data), buff_size):
- self._push_chunk(data[i:i + buff_size])
+ chunks.append(data[i:i + buff_size])
else:
- self._push_chunk(data)
+ chunks = [data]
- def _push_chunk(self, chunk):
if self._loop_thread.ident != get_ident():
asyncio.run_coroutine_threadsafe(
- self._write_queue.put(chunk),
+ self._push_msg(chunks),
loop=self._loop
)
else:
# avoid races/hangs by just scheduling this, not using threadsafe
- self._loop.create_task(self._write_queue.put(chunk))
+ self._loop.create_task(self._push_msg(chunks))
+
+ @asyncio.coroutine
+ def _push_msg(self, chunks):
+ # This lock ensures all chunks of a message are sequential in the Queue
+ with (yield from self._write_queue_lock):
+ for chunk in chunks:
+ self._write_queue.put_nowait(chunk)
+
@asyncio.coroutine
def handle_write(self):
while True:
try:
next_msg = yield from self._write_queue.get()
if next_msg:
yield from self._loop.sock_sendall(self._socket, next_msg)
except socket.error as err:
log.debug("Exception in send for %s: %s", self, err)
self.defunct(err)
return
except asyncio.CancelledError:
return
@asyncio.coroutine
def handle_read(self):
while True:
try:
buf = yield from self._loop.sock_recv(self._socket, self.in_buffer_size)
self._iobuf.write(buf)
# sock_recv expects EWOULDBLOCK if socket provides no data, but
# nonblocking ssl sockets raise these instead, so we handle them
# ourselves by yielding to the event loop, where the socket will
# get the reading/writing it "wants" before retrying
except (ssl.SSLWantWriteError, ssl.SSLWantReadError):
yield
continue
except socket.error as err:
log.debug("Exception during socket recv for %s: %s",
self, err)
self.defunct(err)
return # leave the read loop
except asyncio.CancelledError:
return
if buf and self._iobuf.tell():
self.process_io_buffer()
else:
log.debug("Connection %s closed by server", self)
self.close()
return
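# Illustrative sketch (not part of the patch): the intent behind _push_msg above.
# All chunks of one message are enqueued under a single lock acquisition so that two
# concurrent pushes cannot interleave their chunks and corrupt the outgoing byte
# stream. Modern async/await syntax is used here only for brevity; the reactor keeps
# the generator style for pre-3.5 compatibility.
import asyncio

CHUNK = 4  # stand-in for out_buffer_size

async def push(queue, lock, payload):
    chunks = [payload[i:i + CHUNK] for i in range(0, len(payload), CHUNK)]
    async with lock:                  # chunks of one message go in back-to-back
        for chunk in chunks:
            queue.put_nowait(chunk)

async def main():
    queue, lock = asyncio.Queue(), asyncio.Lock()
    await asyncio.gather(push(queue, lock, b'AAAAAAAA'),
                         push(queue, lock, b'BBBBBBBB'))
    print([queue.get_nowait() for _ in range(queue.qsize())])

asyncio.run(main())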
diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py
index d3dd0cf..681552e 100644
--- a/cassandra/io/asyncorereactor.py
+++ b/cassandra/io/asyncorereactor.py
@@ -1,464 +1,467 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
from collections import deque
from functools import partial
import logging
import os
import socket
import sys
from threading import Lock, Thread, Event
import time
import weakref
import sys
+import ssl
from six.moves import range
try:
from weakref import WeakSet
except ImportError:
from cassandra.util import WeakSet # noqa
import asyncore
-try:
- import ssl
-except ImportError:
- ssl = None # NOQA
-
from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING, Timer, TimerManager
+
log = logging.getLogger(__name__)
_dispatcher_map = {}
def _cleanup(loop):
if loop:
loop._cleanup()
class WaitableTimer(Timer):
def __init__(self, timeout, callback):
Timer.__init__(self, timeout, callback)
self.callback = callback
self.event = Event()
self.final_exception = None
def finish(self, time_now):
try:
finished = Timer.finish(self, time_now)
if finished:
self.event.set()
return True
return False
except Exception as e:
self.final_exception = e
self.event.set()
return True
def wait(self, timeout=None):
self.event.wait(timeout)
if self.final_exception:
raise self.final_exception
class _PipeWrapper(object):
def __init__(self, fd):
self.fd = fd
def fileno(self):
return self.fd
def close(self):
os.close(self.fd)
def getsockopt(self, level, optname, buflen=None):
# act like an unerrored socket for the asyncore error handling
if level == socket.SOL_SOCKET and optname == socket.SO_ERROR and not buflen:
return 0
raise NotImplementedError()
class _AsyncoreDispatcher(asyncore.dispatcher):
def __init__(self, socket):
asyncore.dispatcher.__init__(self, map=_dispatcher_map)
# inject after to avoid base class validation
self.set_socket(socket)
self._notified = False
def writable(self):
return False
def validate(self):
assert not self._notified
self.notify_loop()
assert self._notified
self.loop(0.1)
assert not self._notified
def loop(self, timeout):
asyncore.loop(timeout=timeout, use_poll=True, map=_dispatcher_map, count=1)
class _AsyncorePipeDispatcher(_AsyncoreDispatcher):
def __init__(self):
self.read_fd, self.write_fd = os.pipe()
_AsyncoreDispatcher.__init__(self, _PipeWrapper(self.read_fd))
def writable(self):
return False
def handle_read(self):
while len(os.read(self.read_fd, 4096)) == 4096:
pass
self._notified = False
def notify_loop(self):
if not self._notified:
self._notified = True
os.write(self.write_fd, b'x')
class _AsyncoreUDPDispatcher(_AsyncoreDispatcher):
"""
Experimental alternate dispatcher for avoiding busy wait in the asyncore loop. It is not used by default because
it relies on local port binding.
Port scanning is not implemented, so multiple clients on one host will collide. This address would need to be set per
instance, or this could be specialized to scan until an address is found.
To use::
from cassandra.io.asyncorereactor import _AsyncoreUDPDispatcher, AsyncoreLoop
AsyncoreLoop._loop_dispatch_class = _AsyncoreUDPDispatcher
"""
bind_address = ('localhost', 10000)
def __init__(self):
self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self._socket.bind(self.bind_address)
self._socket.setblocking(0)
_AsyncoreDispatcher.__init__(self, self._socket)
def handle_read(self):
try:
d = self._socket.recvfrom(1)
while d and d[1]:
d = self._socket.recvfrom(1)
except socket.error as e:
pass
self._notified = False
def notify_loop(self):
if not self._notified:
self._notified = True
self._socket.sendto(b'', self.bind_address)
def loop(self, timeout):
asyncore.loop(timeout=timeout, use_poll=False, map=_dispatcher_map, count=1)
class _BusyWaitDispatcher(object):
max_write_latency = 0.001
"""
Timeout pushed down to asyncore select/poll. Dictates the amount of time it will sleep before coming back to check
if anything is writable.
"""
def notify_loop(self):
pass
def loop(self, timeout):
if not _dispatcher_map:
time.sleep(0.005)
count = timeout // self.max_write_latency
asyncore.loop(timeout=self.max_write_latency, use_poll=True, map=_dispatcher_map, count=count)
def validate(self):
pass
def close(self):
pass
class AsyncoreLoop(object):
timer_resolution = 0.1 # used as the max interval to be in the io loop before returning to service timeouts
_loop_dispatch_class = _AsyncorePipeDispatcher if os.name != 'nt' else _BusyWaitDispatcher
def __init__(self):
self._pid = os.getpid()
self._loop_lock = Lock()
self._started = False
self._shutdown = False
self._thread = None
self._timers = TimerManager()
try:
dispatcher = self._loop_dispatch_class()
dispatcher.validate()
log.debug("Validated loop dispatch with %s", self._loop_dispatch_class)
except Exception:
log.exception("Failed validating loop dispatch with %s. Using busy wait execution instead.", self._loop_dispatch_class)
dispatcher.close()
dispatcher = _BusyWaitDispatcher()
self._loop_dispatcher = dispatcher
def maybe_start(self):
should_start = False
did_acquire = False
try:
did_acquire = self._loop_lock.acquire(False)
if did_acquire and not self._started:
self._started = True
should_start = True
finally:
if did_acquire:
self._loop_lock.release()
if should_start:
self._thread = Thread(target=self._run_loop, name="asyncore_cassandra_driver_event_loop")
self._thread.daemon = True
self._thread.start()
def wake_loop(self):
self._loop_dispatcher.notify_loop()
def _run_loop(self):
log.debug("Starting asyncore event loop")
with self._loop_lock:
while not self._shutdown:
try:
self._loop_dispatcher.loop(self.timer_resolution)
self._timers.service_timeouts()
except Exception:
- log.debug("Asyncore event loop stopped unexepectedly", exc_info=True)
+ try:
+ log.debug("Asyncore event loop stopped unexpectedly", exc_info=True)
+ except Exception:
+ # TODO: Remove when Python 2 support is removed
+ # PYTHON-1266. If our logger has disappeared, there's nothing we
+ # can do, so just log nothing.
+ pass
break
self._started = False
log.debug("Asyncore event loop ended")
def add_timer(self, timer):
self._timers.add_timer(timer)
# This function is called from a different thread than the event loop
# thread, so for this call to be thread safe, we must wake up the loop
# in case it's stuck at a select
self.wake_loop()
def _cleanup(self):
global _dispatcher_map
self._shutdown = True
if not self._thread:
return
log.debug("Waiting for event loop thread to join...")
self._thread.join(timeout=1.0)
if self._thread.is_alive():
log.warning(
"Event loop thread could not be joined, so shutdown may not be clean. "
"Please call Cluster.shutdown() to avoid this.")
log.debug("Event loop thread was joined")
# Ensure all connections are closed and in-flight requests cancelled
for conn in tuple(_dispatcher_map.values()):
if conn is not self._loop_dispatcher:
conn.close()
self._timers.service_timeouts()
# Once all the connections are closed, close the dispatcher
self._loop_dispatcher.close()
log.debug("Dispatchers were closed")
_global_loop = None
atexit.register(partial(_cleanup, _global_loop))
class AsyncoreConnection(Connection, asyncore.dispatcher):
"""
An implementation of :class:`.Connection` that uses the ``asyncore``
module in the Python standard library for its event loop.
"""
_writable = False
_readable = False
@classmethod
def initialize_reactor(cls):
global _global_loop
if not _global_loop:
_global_loop = AsyncoreLoop()
else:
current_pid = os.getpid()
if _global_loop._pid != current_pid:
log.debug("Detected fork, clearing and reinitializing reactor state")
cls.handle_fork()
_global_loop = AsyncoreLoop()
@classmethod
def handle_fork(cls):
global _dispatcher_map, _global_loop
_dispatcher_map = {}
if _global_loop:
_global_loop._cleanup()
_global_loop = None
@classmethod
def create_timer(cls, timeout, callback):
timer = Timer(timeout, callback)
_global_loop.add_timer(timer)
return timer
def __init__(self, *args, **kwargs):
Connection.__init__(self, *args, **kwargs)
self.deque = deque()
self.deque_lock = Lock()
self._connect_socket()
# start the event loop if needed
_global_loop.maybe_start()
init_handler = WaitableTimer(
timeout=0,
callback=partial(asyncore.dispatcher.__init__,
self, self._socket, _dispatcher_map)
)
_global_loop.add_timer(init_handler)
init_handler.wait(kwargs["connect_timeout"])
self._writable = True
self._readable = True
self._send_options_message()
def close(self):
with self.lock:
if self.is_closed:
return
self.is_closed = True
log.debug("Closing connection (%s) to %s", id(self), self.endpoint)
self._writable = False
self._readable = False
# We don't have to wait for this to be closed, we can just schedule it
self.create_timer(0, partial(asyncore.dispatcher.close, self))
log.debug("Closed socket to %s", self.endpoint)
if not self.is_defunct:
self.error_all_requests(
ConnectionShutdown("Connection to %s was closed" % self.endpoint))
# This happens when the connection is shut down while waiting for the ReadyMessage
if not self.connected_event.is_set():
self.last_error = ConnectionShutdown("Connection to %s was closed" % self.endpoint)
# don't leave in-progress operations hanging
self.connected_event.set()
def handle_error(self):
self.defunct(sys.exc_info()[1])
def handle_close(self):
log.debug("Connection %s closed by server", self)
self.close()
def handle_write(self):
while True:
with self.deque_lock:
try:
next_msg = self.deque.popleft()
except IndexError:
self._writable = False
return
try:
sent = self.send(next_msg)
self._readable = True
except socket.error as err:
if (err.args[0] in NONBLOCKING or
err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE)):
with self.deque_lock:
self.deque.appendleft(next_msg)
else:
self.defunct(err)
return
else:
if sent < len(next_msg):
with self.deque_lock:
self.deque.appendleft(next_msg[sent:])
if sent == 0:
return
def handle_read(self):
try:
while True:
buf = self.recv(self.in_buffer_size)
self._iobuf.write(buf)
if len(buf) < self.in_buffer_size:
break
except socket.error as err:
- if ssl and isinstance(err, ssl.SSLError):
+ if isinstance(err, ssl.SSLError):
if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
if not self._iobuf.tell():
return
else:
self.defunct(err)
return
elif err.args[0] in NONBLOCKING:
if not self._iobuf.tell():
return
else:
self.defunct(err)
return
if self._iobuf.tell():
self.process_io_buffer()
if not self._requests and not self.is_control_connection:
self._readable = False
def push(self, data):
sabs = self.out_buffer_size
if len(data) > sabs:
chunks = []
for i in range(0, len(data), sabs):
chunks.append(data[i:i + sabs])
else:
chunks = [data]
with self.deque_lock:
self.deque.extend(chunks)
self._writable = True
_global_loop.wake_loop()
def writable(self):
return self._writable
def readable(self):
- return self._readable or (self.is_control_connection and not (self.is_defunct or self.is_closed))
+ return self._readable or ((self.is_control_connection or self._continuous_paging_sessions) and not (self.is_defunct or self.is_closed))
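# Illustrative sketch (not part of the patch): the PYTHON-1266 guard added above.
# During interpreter shutdown the logging machinery may already be torn down, so even
# the "log and break" path can raise; swallowing that secondary failure keeps the
# event loop thread from dying with an unrelated traceback.
import logging

log = logging.getLogger("example")

def safe_debug(msg, *args):
    try:
        log.debug(msg, *args, exc_info=True)
    except Exception:
        # nothing sensible left to do if logging itself is gone
        pass

safe_debug("Asyncore event loop stopped unexpectedly")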
diff --git a/cassandra/io/eventletreactor.py b/cassandra/io/eventletreactor.py
index 2b16ef6..162661f 100644
--- a/cassandra/io/eventletreactor.py
+++ b/cassandra/io/eventletreactor.py
@@ -1,155 +1,194 @@
# Copyright 2014 Symantec Corporation
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Originally derived from MagnetoDB source:
# https://github.com/stackforge/magnetodb/blob/2015.1.0b1/magnetodb/common/cassandra/io/eventletreactor.py
-
import eventlet
from eventlet.green import socket
from eventlet.queue import Queue
from greenlet import GreenletExit
import logging
from threading import Event
import time
from six.moves import xrange
from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager
+try:
+ from eventlet.green.OpenSSL import SSL
+ _PYOPENSSL = True
+except ImportError as e:
+ _PYOPENSSL = False
+ no_pyopenssl_error = e
log = logging.getLogger(__name__)
+def _check_pyopenssl():
+ if not _PYOPENSSL:
+ raise ImportError(
+ "{}, pyOpenSSL must be installed to enable "
+ "SSL support with the Eventlet event loop".format(str(no_pyopenssl_error))
+ )
+
+
class EventletConnection(Connection):
"""
An implementation of :class:`.Connection` that utilizes ``eventlet``.
This implementation assumes all eventlet monkey patching is active. It is not tested with partial patching.
"""
_read_watcher = None
_write_watcher = None
_socket_impl = eventlet.green.socket
_ssl_impl = eventlet.green.ssl
_timers = None
_timeout_watcher = None
_new_timer = None
@classmethod
def initialize_reactor(cls):
eventlet.monkey_patch()
if not cls._timers:
cls._timers = TimerManager()
cls._timeout_watcher = eventlet.spawn(cls.service_timeouts)
cls._new_timer = Event()
@classmethod
def create_timer(cls, timeout, callback):
timer = Timer(timeout, callback)
cls._timers.add_timer(timer)
cls._new_timer.set()
return timer
@classmethod
def service_timeouts(cls):
"""
cls._timeout_watcher runs in this loop forever.
It is usually waiting for the next timeout on the cls._new_timer Event.
When new timers are added, that event is set so that the watcher can
wake up and possibly set an earlier timeout.
"""
timer_manager = cls._timers
while True:
next_end = timer_manager.service_timeouts()
sleep_time = max(next_end - time.time(), 0) if next_end else 10000
cls._new_timer.wait(sleep_time)
cls._new_timer.clear()
def __init__(self, *args, **kwargs):
Connection.__init__(self, *args, **kwargs)
+ self.uses_legacy_ssl_options = self.ssl_options and not self.ssl_context
self._write_queue = Queue()
self._connect_socket()
self._read_watcher = eventlet.spawn(lambda: self.handle_read())
self._write_watcher = eventlet.spawn(lambda: self.handle_write())
self._send_options_message()
+ def _wrap_socket_from_context(self):
+ _check_pyopenssl()
+ self._socket = SSL.Connection(self.ssl_context, self._socket)
+ self._socket.set_connect_state()
+ if self.ssl_options and 'server_hostname' in self.ssl_options:
+ # This is necessary for SNI
+ self._socket.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii'))
+
+ def _initiate_connection(self, sockaddr):
+ if self.uses_legacy_ssl_options:
+ super(EventletConnection, self)._initiate_connection(sockaddr)
+ else:
+ self._socket.connect(sockaddr)
+ if self.ssl_context or self.ssl_options:
+ self._socket.do_handshake()
+
+ def _match_hostname(self):
+ if self.uses_legacy_ssl_options:
+ super(EventletConnection, self)._match_hostname()
+ else:
+ cert_name = self._socket.get_peer_certificate().get_subject().commonName
+ if cert_name != self.endpoint.address:
+ raise Exception("Hostname verification failed! Certificate name '{}' "
+ "doesn't endpoint '{}'".format(cert_name, self.endpoint.address))
+
def close(self):
with self.lock:
if self.is_closed:
return
self.is_closed = True
log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint))
cur_gthread = eventlet.getcurrent()
if self._read_watcher and self._read_watcher != cur_gthread:
self._read_watcher.kill()
if self._write_watcher and self._write_watcher != cur_gthread:
self._write_watcher.kill()
if self._socket:
self._socket.close()
log.debug("Closed socket to %s" % (self.endpoint,))
if not self.is_defunct:
self.error_all_requests(
ConnectionShutdown("Connection to %s was closed" % self.endpoint))
# don't leave in-progress operations hanging
self.connected_event.set()
def handle_close(self):
log.debug("connection closed by server")
self.close()
def handle_write(self):
while True:
try:
next_msg = self._write_queue.get()
self._socket.sendall(next_msg)
except socket.error as err:
log.debug("Exception during socket send for %s: %s", self, err)
self.defunct(err)
return # Leave the write loop
except GreenletExit: # graceful greenthread exit
return
def handle_read(self):
while True:
try:
buf = self._socket.recv(self.in_buffer_size)
self._iobuf.write(buf)
except socket.error as err:
log.debug("Exception during socket recv for %s: %s",
self, err)
self.defunct(err)
return # leave the read loop
except GreenletExit: # graceful greenthread exit
return
if buf and self._iobuf.tell():
self.process_io_buffer()
else:
log.debug("Connection %s closed by server", self)
self.close()
return
def push(self, data):
chunk_size = self.out_buffer_size
for i in xrange(0, len(data), chunk_size):
self._write_queue.put(data[i:i + chunk_size])
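# Illustrative sketch (not part of the patch): what _wrap_socket_from_context does
# with pyOpenSSL. The server_hostname entry from ssl_options is forwarded to
# set_tlsext_host_name so SNI still works when the handshake is driven through the
# green OpenSSL wrapper. Assumes pyOpenSSL is installed; the hostname and context
# settings are placeholders.
import socket
from OpenSSL import SSL

ctx = SSL.Context(SSL.TLSv1_2_METHOD)        # the driver takes this from ssl_context
raw = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tls = SSL.Connection(ctx, raw)
tls.set_connect_state()
tls.set_tlsext_host_name(b"db.example.com")  # ssl_options['server_hostname']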
diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py
index 7d4bf8e..54e2d0d 100644
--- a/cassandra/io/libevreactor.py
+++ b/cassandra/io/libevreactor.py
@@ -1,377 +1,386 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
from collections import deque
from functools import partial
import logging
import os
import socket
import ssl
from threading import Lock, Thread
import time
from six.moves import range
from cassandra.connection import (Connection, ConnectionShutdown,
NONBLOCKING, Timer, TimerManager)
try:
import cassandra.io.libevwrapper as libev
except ImportError:
raise ImportError(
"The C extension needed to use libev was not found. This "
"probably means that you didn't have the required build dependencies "
"when installing the driver. See "
"http://datastax.github.io/python-driver/installation.html#c-extensions "
"for instructions on installing build dependencies and building "
"the C extension.")
log = logging.getLogger(__name__)
def _cleanup(loop):
if loop:
loop._cleanup()
class LibevLoop(object):
def __init__(self):
self._pid = os.getpid()
self._loop = libev.Loop()
self._notifier = libev.Async(self._loop)
self._notifier.start()
# prevent _notifier from keeping the loop from returning
self._loop.unref()
self._started = False
self._shutdown = False
self._lock = Lock()
self._lock_thread = Lock()
self._thread = None
# set of all connections; only replaced with a new copy
# while holding _conn_set_lock, never modified in place
self._live_conns = set()
# newly created connections that need their write/read watcher started
self._new_conns = set()
# recently closed connections that need their write/read watcher stopped
self._closed_conns = set()
self._conn_set_lock = Lock()
self._preparer = libev.Prepare(self._loop, self._loop_will_run)
# prevent _preparer from keeping the loop from returning
self._loop.unref()
self._preparer.start()
self._timers = TimerManager()
self._loop_timer = libev.Timer(self._loop, self._on_loop_timer)
def maybe_start(self):
should_start = False
with self._lock:
if not self._started:
log.debug("Starting libev event loop")
self._started = True
should_start = True
if should_start:
with self._lock_thread:
if not self._shutdown:
self._thread = Thread(target=self._run_loop, name="event_loop")
self._thread.daemon = True
self._thread.start()
self._notifier.send()
def _run_loop(self):
while True:
self._loop.start()
# there are still active watchers, no deadlock
with self._lock:
if not self._shutdown and self._live_conns:
log.debug("Restarting event loop")
continue
else:
# all Connections have been closed, no active watchers
log.debug("All Connections currently closed, event loop ended")
self._started = False
break
def _cleanup(self):
self._shutdown = True
if not self._thread:
return
for conn in self._live_conns | self._new_conns | self._closed_conns:
conn.close()
for watcher in (conn._write_watcher, conn._read_watcher):
if watcher:
watcher.stop()
self.notify() # wake the timer watcher
# PYTHON-752 Thread might have just been created and not started
with self._lock_thread:
self._thread.join(timeout=1.0)
if self._thread.is_alive():
log.warning(
"Event loop thread could not be joined, so shutdown may not be clean. "
"Please call Cluster.shutdown() to avoid this.")
log.debug("Event loop thread was joined")
def add_timer(self, timer):
self._timers.add_timer(timer)
self._notifier.send() # wake up in case this timer is earlier
def _update_timer(self):
if not self._shutdown:
next_end = self._timers.service_timeouts()
if next_end:
self._loop_timer.start(next_end - time.time()) # timer handles negative values
else:
self._loop_timer.stop()
def _on_loop_timer(self):
self._timers.service_timeouts()
def notify(self):
self._notifier.send()
def connection_created(self, conn):
with self._conn_set_lock:
new_live_conns = self._live_conns.copy()
new_live_conns.add(conn)
self._live_conns = new_live_conns
new_new_conns = self._new_conns.copy()
new_new_conns.add(conn)
self._new_conns = new_new_conns
def connection_destroyed(self, conn):
with self._conn_set_lock:
new_live_conns = self._live_conns.copy()
new_live_conns.discard(conn)
self._live_conns = new_live_conns
new_closed_conns = self._closed_conns.copy()
new_closed_conns.add(conn)
self._closed_conns = new_closed_conns
self._notifier.send()
def _loop_will_run(self, prepare):
changed = False
for conn in self._live_conns:
if not conn.deque and conn._write_watcher_is_active:
if conn._write_watcher:
conn._write_watcher.stop()
conn._write_watcher_is_active = False
changed = True
elif conn.deque and not conn._write_watcher_is_active:
conn._write_watcher.start()
conn._write_watcher_is_active = True
changed = True
if self._new_conns:
with self._conn_set_lock:
to_start = self._new_conns
self._new_conns = set()
for conn in to_start:
conn._read_watcher.start()
changed = True
if self._closed_conns:
with self._conn_set_lock:
to_stop = self._closed_conns
self._closed_conns = set()
for conn in to_stop:
if conn._write_watcher:
conn._write_watcher.stop()
# clear reference cycles from IO callback
del conn._write_watcher
if conn._read_watcher:
conn._read_watcher.stop()
# clear reference cycles from IO callback
del conn._read_watcher
changed = True
# TODO: update to do connection management, timer updates through dedicated async 'notifier' callbacks
self._update_timer()
if changed:
self._notifier.send()
_global_loop = None
atexit.register(partial(_cleanup, _global_loop))
class LibevConnection(Connection):
"""
An implementation of :class:`.Connection` that uses libev for its event loop.
"""
_write_watcher_is_active = False
_read_watcher = None
_write_watcher = None
_socket = None
@classmethod
def initialize_reactor(cls):
global _global_loop
if not _global_loop:
_global_loop = LibevLoop()
else:
if _global_loop._pid != os.getpid():
log.debug("Detected fork, clearing and reinitializing reactor state")
cls.handle_fork()
_global_loop = LibevLoop()
@classmethod
def handle_fork(cls):
global _global_loop
if _global_loop:
_global_loop._cleanup()
_global_loop = None
@classmethod
def create_timer(cls, timeout, callback):
timer = Timer(timeout, callback)
_global_loop.add_timer(timer)
return timer
def __init__(self, *args, **kwargs):
Connection.__init__(self, *args, **kwargs)
self.deque = deque()
self._deque_lock = Lock()
self._connect_socket()
self._socket.setblocking(0)
with _global_loop._lock:
self._read_watcher = libev.IO(self._socket.fileno(), libev.EV_READ, _global_loop._loop, self.handle_read)
self._write_watcher = libev.IO(self._socket.fileno(), libev.EV_WRITE, _global_loop._loop, self.handle_write)
self._send_options_message()
_global_loop.connection_created(self)
# start the global event loop if needed
_global_loop.maybe_start()
def close(self):
with self.lock:
if self.is_closed:
return
self.is_closed = True
log.debug("Closing connection (%s) to %s", id(self), self.endpoint)
_global_loop.connection_destroyed(self)
self._socket.close()
log.debug("Closed socket to %s", self.endpoint)
# don't leave in-progress operations hanging
if not self.is_defunct:
self.error_all_requests(
ConnectionShutdown("Connection to %s was closed" % self.endpoint))
def handle_write(self, watcher, revents, errno=None):
if revents & libev.EV_ERROR:
if errno:
exc = IOError(errno, os.strerror(errno))
else:
exc = Exception("libev reported an error")
self.defunct(exc)
return
while True:
try:
with self._deque_lock:
next_msg = self.deque.popleft()
except IndexError:
+ if not self._socket_writable:
+ self._socket_writable = True
return
try:
sent = self._socket.send(next_msg)
except socket.error as err:
if (err.args[0] in NONBLOCKING or
err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE)):
+ if err.args[0] in NONBLOCKING:
+ self._socket_writable = False
with self._deque_lock:
self.deque.appendleft(next_msg)
else:
self.defunct(err)
return
else:
if sent < len(next_msg):
with self._deque_lock:
self.deque.appendleft(next_msg[sent:])
+ # We have seen cases where send() returns 0 instead of raising a NONBLOCKING
+ # error; this is not normally expected. https://bugs.python.org/issue20951
+ if sent == 0:
+ self._socket_writable = False
+ return
def handle_read(self, watcher, revents, errno=None):
if revents & libev.EV_ERROR:
if errno:
exc = IOError(errno, os.strerror(errno))
else:
exc = Exception("libev reported an error")
self.defunct(exc)
return
try:
while True:
buf = self._socket.recv(self.in_buffer_size)
self._iobuf.write(buf)
if len(buf) < self.in_buffer_size:
break
except socket.error as err:
- if ssl and isinstance(err, ssl.SSLError):
+ if isinstance(err, ssl.SSLError):
if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
if not self._iobuf.tell():
return
else:
self.defunct(err)
return
elif err.args[0] in NONBLOCKING:
if not self._iobuf.tell():
return
else:
self.defunct(err)
return
if self._iobuf.tell():
self.process_io_buffer()
else:
log.debug("Connection %s closed by server", self)
self.close()
def push(self, data):
sabs = self.out_buffer_size
if len(data) > sabs:
chunks = []
for i in range(0, len(data), sabs):
chunks.append(data[i:i + sabs])
else:
chunks = [data]
with self._deque_lock:
self.deque.extend(chunks)
_global_loop.notify()
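# Illustrative sketch (not part of the patch): the push() pattern shared by the
# asyncore and libev reactors. A payload larger than out_buffer_size is split into
# buffer-sized chunks and all chunks are appended under one lock, so a concurrent
# writer cannot slot its own chunks into the middle of this message.
from collections import deque
from threading import Lock

OUT_BUFFER_SIZE = 4096  # stand-in for Connection.out_buffer_size

def push(dq, lock, data):
    if len(data) > OUT_BUFFER_SIZE:
        chunks = [data[i:i + OUT_BUFFER_SIZE]
                  for i in range(0, len(data), OUT_BUFFER_SIZE)]
    else:
        chunks = [data]
    with lock:
        dq.extend(chunks)

dq, lock = deque(), Lock()
push(dq, lock, b"x" * 10000)
print(len(dq), "chunks queued")   # 3 chunks for a 10000-byte payload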
diff --git a/cassandra/io/twistedreactor.py b/cassandra/io/twistedreactor.py
index 1dbe9d8..9b3ff09 100644
--- a/cassandra/io/twistedreactor.py
+++ b/cassandra/io/twistedreactor.py
@@ -1,317 +1,303 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module that implements an event loop based on twisted
( https://twistedmatrix.com ).
"""
import atexit
-from functools import partial
import logging
-from threading import Thread, Lock
import time
-from twisted.internet import reactor, protocol
+from functools import partial
+from threading import Thread, Lock
import weakref
-from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager
+from twisted.internet import reactor, protocol
+from twisted.internet.endpoints import connectProtocol, TCP4ClientEndpoint, SSL4ClientEndpoint
+from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
+from twisted.python.failure import Failure
+from zope.interface import implementer
+from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager, ConnectionException
+try:
+ from OpenSSL import SSL
+ _HAS_SSL = True
+except ImportError as e:
+ _HAS_SSL = False
+ import_exception = e
log = logging.getLogger(__name__)
def _cleanup(cleanup_weakref):
try:
cleanup_weakref()._cleanup()
except ReferenceError:
return
class TwistedConnectionProtocol(protocol.Protocol):
"""
Twisted Protocol class for handling data received and connection
made events.
"""
- def __init__(self):
- self.connection = None
+ def __init__(self, connection):
+ self.connection = connection
def dataReceived(self, data):
"""
Callback function that is called when data has been received
on the connection.
Reaches back to the Connection object and queues the data for
processing.
"""
self.connection._iobuf.write(data)
self.connection.handle_read()
+
def connectionMade(self):
"""
Callback function that is called when a connection has succeeded.
Reaches back to the Connection object and confirms that the connection
is ready.
"""
- try:
- # Non SSL connection
- self.connection = self.transport.connector.factory.conn
- except AttributeError:
- # SSL connection
- self.connection = self.transport.connector.factory.wrappedFactory.conn
-
self.connection.client_connection_made(self.transport)
def connectionLost(self, reason):
# reason is a Failure instance
- self.connection.defunct(reason.value)
-
-
-class TwistedConnectionClientFactory(protocol.ClientFactory):
-
- def __init__(self, connection):
- # ClientFactory does not define __init__() in parent classes
- # and does not inherit from object.
- self.conn = connection
-
- def buildProtocol(self, addr):
- """
- Twisted function that defines which kind of protocol to use
- in the ClientFactory.
- """
- return TwistedConnectionProtocol()
-
- def clientConnectionFailed(self, connector, reason):
- """
- Overridden twisted callback which is called when the
- connection attempt fails.
- """
- log.debug("Connect failed: %s", reason)
- self.conn.defunct(reason.value)
-
- def clientConnectionLost(self, connector, reason):
- """
- Overridden twisted callback which is called when the
- connection goes away (cleanly or otherwise).
-
- It should be safe to call defunct() here instead of just close, because
- we can assume that if the connection was closed cleanly, there are no
- requests to error out. If this assumption turns out to be false, we
- can call close() instead of defunct() when "reason" is an appropriate
- type.
- """
log.debug("Connect lost: %s", reason)
- self.conn.defunct(reason.value)
+ self.connection.defunct(reason.value)
class TwistedLoop(object):
_lock = None
_thread = None
_timeout_task = None
_timeout = None
def __init__(self):
self._lock = Lock()
self._timers = TimerManager()
def maybe_start(self):
with self._lock:
if not reactor.running:
self._thread = Thread(target=reactor.run,
name="cassandra_driver_twisted_event_loop",
kwargs={'installSignalHandlers': False})
self._thread.daemon = True
self._thread.start()
atexit.register(partial(_cleanup, weakref.ref(self)))
def _cleanup(self):
if self._thread:
reactor.callFromThread(reactor.stop)
self._thread.join(timeout=1.0)
if self._thread.is_alive():
log.warning("Event loop thread could not be joined, so "
"shutdown may not be clean. Please call "
"Cluster.shutdown() to avoid this.")
log.debug("Event loop thread was joined")
def add_timer(self, timer):
self._timers.add_timer(timer)
# callFromThread to schedule from the loop thread, where
# the timeout task can safely be modified
reactor.callFromThread(self._schedule_timeout, timer.end)
def _schedule_timeout(self, next_timeout):
if next_timeout:
delay = max(next_timeout - time.time(), 0)
if self._timeout_task and self._timeout_task.active():
if next_timeout < self._timeout:
self._timeout_task.reset(delay)
self._timeout = next_timeout
else:
self._timeout_task = reactor.callLater(delay, self._on_loop_timer)
self._timeout = next_timeout
def _on_loop_timer(self):
self._timers.service_timeouts()
self._schedule_timeout(self._timers.next_timeout)
-try:
- from twisted.internet import ssl
- import OpenSSL.crypto
- from OpenSSL.crypto import load_certificate, FILETYPE_PEM
-
- class _SSLContextFactory(ssl.ClientContextFactory):
- def __init__(self, ssl_options, check_hostname, host):
- self.ssl_options = ssl_options
- self.check_hostname = check_hostname
- self.host = host
-
- def getContext(self):
- # This version has to be OpenSSL.SSL.DESIRED_VERSION
- # instead of ssl.DESIRED_VERSION as in other loops
- self.method = self.ssl_options["ssl_version"]
- context = ssl.ClientContextFactory.getContext(self)
+@implementer(IOpenSSLClientConnectionCreator)
+class _SSLCreator(object):
+ def __init__(self, endpoint, ssl_context, ssl_options, check_hostname, timeout):
+ self.endpoint = endpoint
+ self.ssl_options = ssl_options
+ self.check_hostname = check_hostname
+ self.timeout = timeout
+
+ if ssl_context:
+ self.context = ssl_context
+ else:
+ self.context = SSL.Context(SSL.TLSv1_METHOD)
if "certfile" in self.ssl_options:
- context.use_certificate_file(self.ssl_options["certfile"])
+ self.context.use_certificate_file(self.ssl_options["certfile"])
if "keyfile" in self.ssl_options:
- context.use_privatekey_file(self.ssl_options["keyfile"])
+ self.context.use_privatekey_file(self.ssl_options["keyfile"])
if "ca_certs" in self.ssl_options:
- x509 = load_certificate(FILETYPE_PEM, open(self.ssl_options["ca_certs"]).read())
- store = context.get_cert_store()
- store.add_cert(x509)
+ self.context.load_verify_locations(self.ssl_options["ca_certs"])
if "cert_reqs" in self.ssl_options:
- # This expects OpenSSL.SSL.VERIFY_NONE/OpenSSL.SSL.VERIFY_PEER
- # or OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT
- context.set_verify(self.ssl_options["cert_reqs"],
- callback=self.verify_callback)
- return context
-
- def verify_callback(self, connection, x509, errnum, errdepth, ok):
- if ok:
- if self.check_hostname and self.host != x509.get_subject().commonName:
- return False
- return ok
+ self.context.set_verify(
+ self.ssl_options["cert_reqs"],
+ callback=self.verify_callback
+ )
+ self.context.set_info_callback(self.info_callback)
- _HAS_SSL = True
+ def verify_callback(self, connection, x509, errnum, errdepth, ok):
+ return ok
-except ImportError as e:
- _HAS_SSL = False
+ def info_callback(self, connection, where, ret):
+ if where & SSL.SSL_CB_HANDSHAKE_DONE:
+ if self.check_hostname and self.endpoint.address != connection.get_peer_certificate().get_subject().commonName:
+ transport = connection.get_app_data()
+ transport.failVerification(Failure(ConnectionException("Hostname verification failed", self.endpoint)))
+
+ def clientConnectionForTLS(self, tlsProtocol):
+ connection = SSL.Connection(self.context, None)
+ connection.set_app_data(tlsProtocol)
+ if self.ssl_options and "server_hostname" in self.ssl_options:
+ connection.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii'))
+ return connection
class TwistedConnection(Connection):
"""
An implementation of :class:`.Connection` that utilizes the
Twisted event loop.
"""
_loop = None
@classmethod
def initialize_reactor(cls):
if not cls._loop:
cls._loop = TwistedLoop()
@classmethod
def create_timer(cls, timeout, callback):
timer = Timer(timeout, callback)
cls._loop.add_timer(timer)
return timer
def __init__(self, *args, **kwargs):
"""
Initialization method.
Note that we can't call reactor methods directly here because
it's not thread-safe, so we schedule the reactor/connection
stuff to be run from the event loop thread when it gets the
chance.
"""
Connection.__init__(self, *args, **kwargs)
self.is_closed = True
self.connector = None
self.transport = None
reactor.callFromThread(self.add_connection)
self._loop.maybe_start()
- def add_connection(self):
- """
- Convenience function to connect and store the resulting
- connector.
- """
- if self.ssl_options:
-
+ def _check_pyopenssl(self):
+ if self.ssl_context or self.ssl_options:
if not _HAS_SSL:
raise ImportError(
- str(e) +
+ str(import_exception) +
', pyOpenSSL must be installed to enable SSL support with the Twisted event loop'
)
- self.connector = reactor.connectSSL(
- host=self.endpoint.address, port=self.port,
- factory=TwistedConnectionClientFactory(self),
- contextFactory=_SSLContextFactory(self.ssl_options, self._check_hostname, self.endpoint.address),
- timeout=self.connect_timeout)
+ def add_connection(self):
+ """
+ Convenience function to connect and store the resulting
+ connector.
+ """
+ host, port = self.endpoint.resolve()
+ if self.ssl_context or self.ssl_options:
+ # Can't use optionsForClientTLS here because it *forces* hostname verification.
+ # Cool they enforce strong security, but we have to be able to turn it off
+ self._check_pyopenssl()
+
+ ssl_connection_creator = _SSLCreator(
+ self.endpoint,
+ self.ssl_context if self.ssl_context else None,
+ self.ssl_options,
+ self._check_hostname,
+ self.connect_timeout,
+ )
+
+ endpoint = SSL4ClientEndpoint(
+ reactor,
+ host,
+ port,
+ sslContextFactory=ssl_connection_creator,
+ timeout=self.connect_timeout,
+ )
else:
- self.connector = reactor.connectTCP(
- host=self.endpoint.address, port=self.port,
- factory=TwistedConnectionClientFactory(self),
- timeout=self.connect_timeout)
+ endpoint = TCP4ClientEndpoint(
+ reactor,
+ host,
+ port,
+ timeout=self.connect_timeout
+ )
+ connectProtocol(endpoint, TwistedConnectionProtocol(self))
def client_connection_made(self, transport):
"""
Called by twisted protocol when a connection attempt has
succeeded.
"""
with self.lock:
self.is_closed = False
self.transport = transport
self._send_options_message()
def close(self):
"""
Disconnect and error-out all requests.
"""
with self.lock:
if self.is_closed:
return
self.is_closed = True
log.debug("Closing connection (%s) to %s", id(self), self.endpoint)
- reactor.callFromThread(self.connector.disconnect)
+ reactor.callFromThread(self.transport.connector.disconnect)
log.debug("Closed socket to %s", self.endpoint)
if not self.is_defunct:
self.error_all_requests(
ConnectionShutdown("Connection to %s was closed" % self.endpoint))
# don't leave in-progress operations hanging
self.connected_event.set()
def handle_read(self):
"""
Process the incoming data buffer.
"""
self.process_io_buffer()
def push(self, data):
"""
This function is called when outgoing data should be queued
for sending.
Note that we can't call transport.write() directly because
it is not thread-safe, so we schedule it to run from within
the event loop when it gets the chance.
"""
reactor.callFromThread(self.transport.write, data)
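# Illustrative sketch (not part of the patch): the endpoint-based setup that replaces
# the ClientFactory path removed above. The protocol now receives its owning
# connection object in __init__ instead of fishing it off the transport's factory.
# Host and port are placeholders; run this under a Twisted reactor.
from twisted.internet import reactor
from twisted.internet.endpoints import TCP4ClientEndpoint, connectProtocol
from twisted.internet.protocol import Protocol

class SinkProtocol(Protocol):
    def __init__(self, owner):
        self.owner = owner            # analogous to TwistedConnectionProtocol.connection
    def dataReceived(self, data):
        self.owner.append(data)       # hand bytes back to the owner, like handle_read

received = []
endpoint = TCP4ClientEndpoint(reactor, "127.0.0.1", 9042, timeout=5)
d = connectProtocol(endpoint, SinkProtocol(received))
d.addErrback(lambda failure: print("connect failed:", failure.value))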
diff --git a/cassandra/marshal.py b/cassandra/marshal.py
index 3b80f34..43cb627 100644
--- a/cassandra/marshal.py
+++ b/cassandra/marshal.py
@@ -1,157 +1,165 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
import struct
def _make_packer(format_string):
packer = struct.Struct(format_string)
pack = packer.pack
unpack = lambda s: packer.unpack(s)[0]
return pack, unpack
int64_pack, int64_unpack = _make_packer('>q')
int32_pack, int32_unpack = _make_packer('>i')
int16_pack, int16_unpack = _make_packer('>h')
int8_pack, int8_unpack = _make_packer('>b')
uint64_pack, uint64_unpack = _make_packer('>Q')
uint32_pack, uint32_unpack = _make_packer('>I')
+uint32_le_pack, uint32_le_unpack = _make_packer('<I')
uint16_pack, uint16_unpack = _make_packer('>H')
uint8_pack, uint8_unpack = _make_packer('>B')
float_pack, float_unpack = _make_packer('>f')
double_pack, double_unpack = _make_packer('>d')
# Special case for cassandra header
header_struct = struct.Struct('>BBbB')
header_pack = header_struct.pack
header_unpack = header_struct.unpack
# in protocol version 3 and higher, the stream ID is two bytes
v3_header_struct = struct.Struct('>BBhB')
v3_header_pack = v3_header_struct.pack
v3_header_unpack = v3_header_struct.unpack
if six.PY3:
def byte2int(b):
return b
def varint_unpack(term):
val = int(''.join("%02x" % i for i in term), 16)
if (term[0] & 128) != 0:
len_term = len(term) # pulling this out of the expression to avoid overflow in cython optimized code
val -= 1 << (len_term * 8)
return val
else:
def byte2int(b):
return ord(b)
def varint_unpack(term): # noqa
val = int(term.encode('hex'), 16)
if (ord(term[0]) & 128) != 0:
len_term = len(term) # pulling this out of the expression to avoid overflow in cython optimized code
val = val - (1 << (len_term * 8))
return val
def bit_length(n):
if six.PY3 or isinstance(n, int):
return int.bit_length(n)
else:
return long.bit_length(n)
def varint_pack(big):
pos = True
if big == 0:
return b'\x00'
if big < 0:
bytelength = bit_length(abs(big) - 1) // 8 + 1
big = (1 << bytelength * 8) + big
pos = False
revbytes = bytearray()
while big > 0:
revbytes.append(big & 0xff)
big >>= 8
if pos and revbytes[-1] & 0x80:
revbytes.append(0)
revbytes.reverse()
return six.binary_type(revbytes)
+point_be = struct.Struct('>dd')
+point_le = struct.Struct('<dd')
+circle_be = struct.Struct('>ddd')
+circle_le = struct.Struct('<ddd')
def encode_zig_zag(n):
return (n << 1) ^ (n >> 63)
def decode_zig_zag(n):
return (n >> 1) ^ -(n & 1)
def vints_unpack(term): # noqa
values = []
n = 0
while n < len(term):
first_byte = byte2int(term[n])
if (first_byte & 128) == 0:
val = first_byte
else:
num_extra_bytes = 8 - (~first_byte & 0xff).bit_length()
val = first_byte & (0xff >> num_extra_bytes)
end = n + num_extra_bytes
while n < end:
n += 1
val <<= 8
val |= byte2int(term[n]) & 0xff
n += 1
values.append(decode_zig_zag(val))
return tuple(values)
def vints_pack(values):
revbytes = bytearray()
values = [int(v) for v in values[::-1]]
for value in values:
v = encode_zig_zag(value)
if v < 128:
revbytes.append(v)
else:
num_extra_bytes = 0
num_bits = v.bit_length()
# We need to reserve (num_extra_bytes+1) bits in the first byte
# ie. with 1 extra byte, the first byte needs to be something like '10XXXXXX' # 2 bits reserved
# ie. with 8 extra bytes, the first byte needs to be '11111111' # 8 bits reserved
reserved_bits = num_extra_bytes + 1
while num_bits > (8-(reserved_bits)):
num_extra_bytes += 1
num_bits -= 8
reserved_bits = min(num_extra_bytes + 1, 8)
revbytes.append(v & 0xff)
v >>= 8
if num_extra_bytes > 8:
raise ValueError('Value %d is too big and cannot be encoded as vint' % value)
# We can now store the last bits in the first byte
n = 8 - num_extra_bytes
v |= (0xff >> n << n)
revbytes.append(abs(v))
revbytes.reverse()
return six.binary_type(revbytes)
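# Illustrative sketch (not part of the patch): the zig-zag mapping used by the vint
# codec above. Signed integers are folded onto unsigned ones so small negative values
# still encode into few bytes, and the mapping is exactly reversible.
from cassandra.marshal import vints_pack, vints_unpack

def encode_zig_zag(n):                # mirrors the helper in marshal.py
    return (n << 1) ^ (n >> 63)

def decode_zig_zag(n):
    return (n >> 1) ^ -(n & 1)

for value in (0, -1, 1, -300, 2 ** 40):
    assert decode_zig_zag(encode_zig_zag(value)) == value

packed = vints_pack([3, -7, 1 << 20])
print(vints_unpack(packed))           # expected: (3, -7, 1048576)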
diff --git a/cassandra/metadata.py b/cassandra/metadata.py
index 1824b3f..a82fbe4 100644
--- a/cassandra/metadata.py
+++ b/cassandra/metadata.py
@@ -1,2845 +1,3416 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from binascii import unhexlify
from bisect import bisect_left
from collections import defaultdict
from functools import total_ordering
from hashlib import md5
-from itertools import islice, cycle
import json
import logging
import re
import six
from six.moves import zip
import sys
from threading import RLock
import struct
import random
murmur3 = None
try:
from cassandra.murmur3 import murmur3
except ImportError as e:
pass
from cassandra import SignatureDescriptor, ConsistencyLevel, InvalidRequest, Unauthorized
import cassandra.cqltypes as types
from cassandra.encoder import Encoder
from cassandra.marshal import varint_unpack
from cassandra.protocol import QueryMessage
from cassandra.query import dict_factory, bind_params
from cassandra.util import OrderedDict, Version
from cassandra.pool import HostDistance
from cassandra.connection import EndPoint
from cassandra.compat import Mapping
log = logging.getLogger(__name__)
cql_keywords = set((
'add', 'aggregate', 'all', 'allow', 'alter', 'and', 'apply', 'as', 'asc', 'ascii', 'authorize', 'batch', 'begin',
'bigint', 'blob', 'boolean', 'by', 'called', 'clustering', 'columnfamily', 'compact', 'contains', 'count',
- 'counter', 'create', 'custom', 'date', 'decimal', 'delete', 'desc', 'describe', 'distinct', 'double', 'drop',
+ 'counter', 'create', 'custom', 'date', 'decimal', 'default', 'delete', 'desc', 'describe', 'deterministic', 'distinct', 'double', 'drop',
'entries', 'execute', 'exists', 'filtering', 'finalfunc', 'float', 'from', 'frozen', 'full', 'function',
'functions', 'grant', 'if', 'in', 'index', 'inet', 'infinity', 'initcond', 'input', 'insert', 'int', 'into', 'is', 'json',
- 'key', 'keys', 'keyspace', 'keyspaces', 'language', 'limit', 'list', 'login', 'map', 'materialized', 'modify', 'nan', 'nologin',
- 'norecursive', 'nosuperuser', 'not', 'null', 'of', 'on', 'options', 'or', 'order', 'password', 'permission',
+ 'key', 'keys', 'keyspace', 'keyspaces', 'language', 'limit', 'list', 'login', 'map', 'materialized', 'mbean', 'mbeans', 'modify', 'monotonic',
+ 'nan', 'nologin', 'norecursive', 'nosuperuser', 'not', 'null', 'of', 'on', 'options', 'or', 'order', 'password', 'permission',
'permissions', 'primary', 'rename', 'replace', 'returns', 'revoke', 'role', 'roles', 'schema', 'select', 'set',
'sfunc', 'smallint', 'static', 'storage', 'stype', 'superuser', 'table', 'text', 'time', 'timestamp', 'timeuuid',
- 'tinyint', 'to', 'token', 'trigger', 'truncate', 'ttl', 'tuple', 'type', 'unlogged', 'update', 'use', 'user',
- 'users', 'using', 'uuid', 'values', 'varchar', 'varint', 'view', 'where', 'with', 'writetime'
+ 'tinyint', 'to', 'token', 'trigger', 'truncate', 'ttl', 'tuple', 'type', 'unlogged', 'unset', 'update', 'use', 'user',
+ 'users', 'using', 'uuid', 'values', 'varchar', 'varint', 'view', 'where', 'with', 'writetime',
+
+ # DSE specifics
+ "node", "nodes", "plan", "active", "application", "applications", "java", "executor", "executors", "std_out", "std_err",
+ "renew", "delegation", "no", "redact", "token", "lowercasestring", "cluster", "authentication", "schemes", "scheme",
+ "internal", "ldap", "kerberos", "remote", "object", "method", "call", "calls", "search", "schema", "config", "rows",
+ "columns", "profiles", "commit", "reload", "rebuild", "field", "workpool", "any", "submission", "indices",
+ "restrict", "unrestrict"
))
"""
Set of keywords in CQL.
Derived from .../cassandra/src/java/org/apache/cassandra/cql3/Cql.g
"""
cql_keywords_unreserved = set((
'aggregate', 'all', 'as', 'ascii', 'bigint', 'blob', 'boolean', 'called', 'clustering', 'compact', 'contains',
- 'count', 'counter', 'custom', 'date', 'decimal', 'distinct', 'double', 'exists', 'filtering', 'finalfunc', 'float',
+ 'count', 'counter', 'custom', 'date', 'decimal', 'deterministic', 'distinct', 'double', 'exists', 'filtering', 'finalfunc', 'float',
'frozen', 'function', 'functions', 'inet', 'initcond', 'input', 'int', 'json', 'key', 'keys', 'keyspaces',
- 'language', 'list', 'login', 'map', 'nologin', 'nosuperuser', 'options', 'password', 'permission', 'permissions',
+ 'language', 'list', 'login', 'map', 'monotonic', 'nologin', 'nosuperuser', 'options', 'password', 'permission', 'permissions',
'returns', 'role', 'roles', 'sfunc', 'smallint', 'static', 'storage', 'stype', 'superuser', 'text', 'time',
'timestamp', 'timeuuid', 'tinyint', 'trigger', 'ttl', 'tuple', 'type', 'user', 'users', 'uuid', 'values', 'varchar',
'varint', 'writetime'
))
"""
Set of unreserved keywords in CQL.
Derived from .../cassandra/src/java/org/apache/cassandra/cql3/Cql.g
"""
cql_keywords_reserved = cql_keywords - cql_keywords_unreserved
"""
Set of reserved keywords in CQL.
"""
_encoder = Encoder()
class Metadata(object):
"""
Holds a representation of the cluster schema and topology.
"""
cluster_name = None
""" The string name of the cluster. """
keyspaces = None
"""
A map from keyspace names to matching :class:`~.KeyspaceMetadata` instances.
"""
partitioner = None
"""
The string name of the partitioner for the cluster.
"""
token_map = None
""" A :class:`~.TokenMap` instance describing the ring topology. """
dbaas = False
""" A boolean indicating if connected to a DBaaS cluster """
def __init__(self):
self.keyspaces = {}
self.dbaas = False
self._hosts = {}
self._hosts_lock = RLock()
def export_schema_as_string(self):
"""
Returns a string that can be executed as a query in order to recreate
the entire schema. The string is formatted to be human readable.
"""
return "\n\n".join(ks.export_as_string() for ks in self.keyspaces.values())
def refresh(self, connection, timeout, target_type=None, change_type=None, **kwargs):
server_version = self.get_host(connection.endpoint).release_version
- parser = get_schema_parser(connection, server_version, timeout)
+ dse_version = self.get_host(connection.endpoint).dse_version
+ parser = get_schema_parser(connection, server_version, dse_version, timeout)
if not target_type:
self._rebuild_all(parser)
return
tt_lower = target_type.lower()
try:
parse_method = getattr(parser, 'get_' + tt_lower)
meta = parse_method(self.keyspaces, **kwargs)
if meta:
update_method = getattr(self, '_update_' + tt_lower)
if tt_lower == 'keyspace' and connection.protocol_version < 3:
# we didn't have 'type' target in legacy protocol versions, so we need to query those too
user_types = parser.get_types_map(self.keyspaces, **kwargs)
self._update_keyspace(meta, user_types)
else:
update_method(meta)
else:
drop_method = getattr(self, '_drop_' + tt_lower)
drop_method(**kwargs)
except AttributeError:
raise ValueError("Unknown schema target_type: '%s'" % target_type)
def _rebuild_all(self, parser):
current_keyspaces = set()
for keyspace_meta in parser.get_all_keyspaces():
current_keyspaces.add(keyspace_meta.name)
old_keyspace_meta = self.keyspaces.get(keyspace_meta.name, None)
self.keyspaces[keyspace_meta.name] = keyspace_meta
if old_keyspace_meta:
self._keyspace_updated(keyspace_meta.name)
else:
self._keyspace_added(keyspace_meta.name)
# remove not-just-added keyspaces
removed_keyspaces = [name for name in self.keyspaces.keys()
if name not in current_keyspaces]
self.keyspaces = dict((name, meta) for name, meta in self.keyspaces.items()
if name in current_keyspaces)
for ksname in removed_keyspaces:
self._keyspace_removed(ksname)
def _update_keyspace(self, keyspace_meta, new_user_types=None):
ks_name = keyspace_meta.name
old_keyspace_meta = self.keyspaces.get(ks_name, None)
self.keyspaces[ks_name] = keyspace_meta
if old_keyspace_meta:
keyspace_meta.tables = old_keyspace_meta.tables
keyspace_meta.user_types = new_user_types if new_user_types is not None else old_keyspace_meta.user_types
keyspace_meta.indexes = old_keyspace_meta.indexes
keyspace_meta.functions = old_keyspace_meta.functions
keyspace_meta.aggregates = old_keyspace_meta.aggregates
keyspace_meta.views = old_keyspace_meta.views
if (keyspace_meta.replication_strategy != old_keyspace_meta.replication_strategy):
self._keyspace_updated(ks_name)
else:
self._keyspace_added(ks_name)
def _drop_keyspace(self, keyspace):
if self.keyspaces.pop(keyspace, None):
self._keyspace_removed(keyspace)
def _update_table(self, meta):
try:
keyspace_meta = self.keyspaces[meta.keyspace_name]
# this is unfortunate, but protocol v4 does not differentiate
# between events for tables and views. .get_table will
# return one or the other based on the query results.
# Here we deal with that.
if isinstance(meta, TableMetadata):
keyspace_meta._add_table_metadata(meta)
else:
keyspace_meta._add_view_metadata(meta)
except KeyError:
# can happen if keyspace disappears while processing async event
pass
def _drop_table(self, keyspace, table):
try:
keyspace_meta = self.keyspaces[keyspace]
keyspace_meta._drop_table_metadata(table) # handles either table or view
except KeyError:
# can happen if keyspace disappears while processing async event
pass
def _update_type(self, type_meta):
try:
self.keyspaces[type_meta.keyspace].user_types[type_meta.name] = type_meta
except KeyError:
# can happen if keyspace disappears while processing async event
pass
def _drop_type(self, keyspace, type):
try:
self.keyspaces[keyspace].user_types.pop(type, None)
except KeyError:
# can happen if keyspace disappears while processing async event
pass
def _update_function(self, function_meta):
try:
self.keyspaces[function_meta.keyspace].functions[function_meta.signature] = function_meta
except KeyError:
# can happen if keyspace disappears while processing async event
pass
def _drop_function(self, keyspace, function):
try:
self.keyspaces[keyspace].functions.pop(function.signature, None)
except KeyError:
pass
def _update_aggregate(self, aggregate_meta):
try:
self.keyspaces[aggregate_meta.keyspace].aggregates[aggregate_meta.signature] = aggregate_meta
except KeyError:
pass
def _drop_aggregate(self, keyspace, aggregate):
try:
self.keyspaces[keyspace].aggregates.pop(aggregate.signature, None)
except KeyError:
pass
def _keyspace_added(self, ksname):
if self.token_map:
self.token_map.rebuild_keyspace(ksname, build_if_absent=False)
def _keyspace_updated(self, ksname):
if self.token_map:
self.token_map.rebuild_keyspace(ksname, build_if_absent=False)
def _keyspace_removed(self, ksname):
if self.token_map:
self.token_map.remove_keyspace(ksname)
def rebuild_token_map(self, partitioner, token_map):
"""
Rebuild our view of the topology from fresh rows from the
system topology tables.
For internal use only.
"""
self.partitioner = partitioner
if partitioner.endswith('RandomPartitioner'):
token_class = MD5Token
elif partitioner.endswith('Murmur3Partitioner'):
token_class = Murmur3Token
elif partitioner.endswith('ByteOrderedPartitioner'):
token_class = BytesToken
else:
self.token_map = None
return
token_to_host_owner = {}
ring = []
for host, token_strings in six.iteritems(token_map):
for token_string in token_strings:
token = token_class.from_string(token_string)
ring.append(token)
token_to_host_owner[token] = host
all_tokens = sorted(ring)
self.token_map = TokenMap(
token_class, token_to_host_owner, all_tokens, self)
def get_replicas(self, keyspace, key):
"""
Returns a list of :class:`.Host` instances that are replicas for a given
partition key.
"""
t = self.token_map
if not t:
return []
try:
return t.get_replicas(keyspace, t.token_class.from_key(key))
except NoMurmur3:
return []
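# Minimal usage sketch for the method above (illustrative only; assumes `cluster`
# is a connected Cluster instance and `routing_key` holds the serialized
# partition key, e.g. a BoundStatement.routing_key value):
#
#     for host in cluster.metadata.get_replicas('my_keyspace', routing_key):
#         print(host.address)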
def can_support_partitioner(self):
if self.partitioner.endswith('Murmur3Partitioner') and murmur3 is None:
return False
else:
return True
def add_or_return_host(self, host):
"""
Returns a tuple (host, new), where ``host`` is a Host
instance, and ``new`` is a bool indicating whether
the host was newly added.
"""
with self._hosts_lock:
try:
return self._hosts[host.endpoint], False
except KeyError:
self._hosts[host.endpoint] = host
return host, True
def remove_host(self, host):
with self._hosts_lock:
return bool(self._hosts.pop(host.endpoint, False))
- def get_host(self, endpoint_or_address):
+ def get_host(self, endpoint_or_address, port=None):
"""
- Find a host in the metadata for a specific endpoint. If a string inet address is passed,
- iterate all hosts to match the :attr:`~.pool.Host.broadcast_rpc_address` attribute.
+ Find a host in the metadata for a specific endpoint. If a string inet address and port are passed,
+ iterate all hosts to match the :attr:`~.pool.Host.broadcast_rpc_address` and
+ :attr:`~.pool.Host.broadcast_rpc_port` attributes.
"""
if not isinstance(endpoint_or_address, EndPoint):
- return self._get_host_by_address(endpoint_or_address)
+ return self._get_host_by_address(endpoint_or_address, port)
return self._hosts.get(endpoint_or_address)
- def _get_host_by_address(self, address):
+ def _get_host_by_address(self, address, port=None):
for host in six.itervalues(self._hosts):
- if host.broadcast_rpc_address == address:
+ if (host.broadcast_rpc_address == address and
+ (port is None or host.broadcast_rpc_port is None or host.broadcast_rpc_port == port)):
return host
+
return None
def all_hosts(self):
"""
Returns a list of all known :class:`.Host` instances in the cluster.
"""
with self._hosts_lock:
return list(self._hosts.values())
REPLICATION_STRATEGY_CLASS_PREFIX = "org.apache.cassandra.locator."
def trim_if_startswith(s, prefix):
if s.startswith(prefix):
return s[len(prefix):]
return s
_replication_strategies = {}
class ReplicationStrategyTypeType(type):
def __new__(metacls, name, bases, dct):
dct.setdefault('name', name)
cls = type.__new__(metacls, name, bases, dct)
if not name.startswith('_'):
_replication_strategies[name] = cls
return cls
+
@six.add_metaclass(ReplicationStrategyTypeType)
class _ReplicationStrategy(object):
options_map = None
@classmethod
def create(cls, strategy_class, options_map):
if not strategy_class:
return None
strategy_name = trim_if_startswith(strategy_class, REPLICATION_STRATEGY_CLASS_PREFIX)
rs_class = _replication_strategies.get(strategy_name, None)
if rs_class is None:
rs_class = _UnknownStrategyBuilder(strategy_name)
_replication_strategies[strategy_name] = rs_class
try:
rs_instance = rs_class(options_map)
except Exception as exc:
log.warning("Failed creating %s with options %s: %s", strategy_name, options_map, exc)
return None
return rs_instance
def make_token_replica_map(self, token_to_host_owner, ring):
raise NotImplementedError()
def export_for_schema(self):
raise NotImplementedError()
ReplicationStrategy = _ReplicationStrategy
class _UnknownStrategyBuilder(object):
def __init__(self, name):
self.name = name
def __call__(self, options_map):
strategy_instance = _UnknownStrategy(self.name, options_map)
return strategy_instance
class _UnknownStrategy(ReplicationStrategy):
def __init__(self, name, options_map):
self.name = name
self.options_map = options_map.copy() if options_map is not None else dict()
self.options_map['class'] = self.name
def __eq__(self, other):
return (isinstance(other, _UnknownStrategy) and
self.name == other.name and
self.options_map == other.options_map)
def export_for_schema(self):
"""
Returns a string version of these replication options which are
suitable for use in a CREATE KEYSPACE statement.
"""
if self.options_map:
return dict((str(key), str(value)) for key, value in self.options_map.items())
return "{'class': '%s'}" % (self.name, )
def make_token_replica_map(self, token_to_host_owner, ring):
return {}
+class ReplicationFactor(object):
+ """
+ Represent the replication factor of a keyspace.
+ """
+
+ all_replicas = None
+ """
+ The number of total replicas.
+ """
+
+ full_replicas = None
+ """
+ The number of replicas that own a full copy of the data. This is the same
+ as `all_replicas` when transient replication is not enabled.
+ """
+
+ transient_replicas = None
+ """
+ The number of transient replicas.
+
+ Only set if the keyspace has transient replication enabled.
+ """
+
+ def __init__(self, all_replicas, transient_replicas=None):
+ self.all_replicas = all_replicas
+ self.transient_replicas = transient_replicas
+ self.full_replicas = (all_replicas - transient_replicas) if transient_replicas else all_replicas
+
+ @staticmethod
+ def create(rf):
+ """
+ Given a replication factor string such as '3', or '3/1' when transient replication is enabled, parse and return a ReplicationFactor instance.
+ """
+ transient_replicas = None
+ try:
+ all_replicas = int(rf)
+ except ValueError:
+ try:
+ rf = rf.split('/')
+ all_replicas, transient_replicas = int(rf[0]), int(rf[1])
+ except Exception:
+ raise ValueError("Unable to determine replication factor from: {}".format(rf))
+
+ return ReplicationFactor(all_replicas, transient_replicas)
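+
+ # Usage sketch for the parser above (illustrative only, not part of the module):
+ #
+ #     rf = ReplicationFactor.create('3/1')
+ #     rf.all_replicas        # 3
+ #     rf.transient_replicas  # 1
+ #     rf.full_replicas       # 2
+ #     str(rf)                # '3/1'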
+
+ def __str__(self):
+ return ("%d/%d" % (self.all_replicas, self.transient_replicas) if self.transient_replicas
+ else "%d" % self.all_replicas)
+
+ def __eq__(self, other):
+ if not isinstance(other, ReplicationFactor):
+ return False
+
+ return self.all_replicas == other.all_replicas and self.full_replicas == other.full_replicas
+
+
class SimpleStrategy(ReplicationStrategy):
- replication_factor = None
+ replication_factor_info = None
"""
- The replication factor for this keyspace.
+ A :class:`cassandra.metadata.ReplicationFactor` instance.
"""
+ @property
+ def replication_factor(self):
+ """
+ The replication factor for this keyspace.
+
+ For backward compatibility, this returns the
+ :attr:`cassandra.metadata.ReplicationFactor.full_replicas` value of
+ :attr:`cassandra.metadata.SimpleStrategy.replication_factor_info`.
+ """
+ return self.replication_factor_info.full_replicas
+
def __init__(self, options_map):
- try:
- self.replication_factor = int(options_map['replication_factor'])
- except Exception:
- raise ValueError("SimpleStrategy requires an integer 'replication_factor' option")
+ self.replication_factor_info = ReplicationFactor.create(options_map['replication_factor'])
def make_token_replica_map(self, token_to_host_owner, ring):
replica_map = {}
for i in range(len(ring)):
j, hosts = 0, list()
while len(hosts) < self.replication_factor and j < len(ring):
token = ring[(i + j) % len(ring)]
host = token_to_host_owner[token]
if host not in hosts:
hosts.append(host)
j += 1
replica_map[ring[i]] = hosts
return replica_map
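# Worked example of the walk above (illustrative): with ring = [t1, t2, t3]
# owned by hosts A, B, C and replication_factor = 2, the map produced is
# {t1: [A, B], t2: [B, C], t3: [C, A]} -- each token gets its owner plus the
# next distinct hosts clockwise around the ring.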
def export_for_schema(self):
"""
Returns a string version of these replication options which are
suitable for use in a CREATE KEYSPACE statement.
"""
- return "{'class': 'SimpleStrategy', 'replication_factor': '%d'}" \
- % (self.replication_factor,)
+ return "{'class': 'SimpleStrategy', 'replication_factor': '%s'}" \
+ % (str(self.replication_factor_info),)
def __eq__(self, other):
if not isinstance(other, SimpleStrategy):
return False
- return self.replication_factor == other.replication_factor
+ return str(self.replication_factor_info) == str(other.replication_factor_info)
class NetworkTopologyStrategy(ReplicationStrategy):
+ dc_replication_factors_info = None
+ """
+ A map of datacenter names to the :class:`cassandra.metadata.ReplicationFactor` instance for that DC.
+ """
+
dc_replication_factors = None
"""
A map of datacenter names to the replication factor for that DC.
+
+ For backward compatibility, this maps to the :attr:`cassandra.metadata.ReplicationFactor.full_replicas`
+ value of the :attr:`cassandra.metadata.NetworkTopologyStrategy.dc_replication_factors_info` dict.
"""
def __init__(self, dc_replication_factors):
+ self.dc_replication_factors_info = dict(
+ (str(k), ReplicationFactor.create(v)) for k, v in dc_replication_factors.items())
self.dc_replication_factors = dict(
- (str(k), int(v)) for k, v in dc_replication_factors.items())
+ (dc, rf.full_replicas) for dc, rf in self.dc_replication_factors_info.items())
def make_token_replica_map(self, token_to_host_owner, ring):
- dc_rf_map = dict((dc, int(rf))
- for dc, rf in self.dc_replication_factors.items() if rf > 0)
+ dc_rf_map = dict(
+ (dc, full_replicas) for dc, full_replicas in self.dc_replication_factors.items()
+ if full_replicas > 0)
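+ # High-level sketch of the loop below (descriptive note): for every token we
+ # walk each datacenter's portion of the ring clockwise, preferring hosts on
+ # racks not yet represented; hosts skipped for rack diversity are appended at
+ # the end if replicas are still needed for that DC.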
# build a map of DCs to lists of indexes into `ring` for tokens that
# belong to that DC
dc_to_token_offset = defaultdict(list)
dc_racks = defaultdict(set)
hosts_per_dc = defaultdict(set)
for i, token in enumerate(ring):
host = token_to_host_owner[token]
dc_to_token_offset[host.datacenter].append(i)
if host.datacenter and host.rack:
dc_racks[host.datacenter].add(host.rack)
hosts_per_dc[host.datacenter].add(host)
# A map of DCs to an index into the dc_to_token_offset value for that dc.
# This is how we keep track of advancing around the ring for each DC.
dc_to_current_index = defaultdict(int)
replica_map = defaultdict(list)
for i in range(len(ring)):
replicas = replica_map[ring[i]]
# go through each DC and find the replicas in that DC
for dc in dc_to_token_offset.keys():
if dc not in dc_rf_map:
continue
# advance our per-DC index until we're up to at least the
# current token in the ring
token_offsets = dc_to_token_offset[dc]
index = dc_to_current_index[dc]
num_tokens = len(token_offsets)
while index < num_tokens and token_offsets[index] < i:
index += 1
dc_to_current_index[dc] = index
replicas_remaining = dc_rf_map[dc]
replicas_this_dc = 0
skipped_hosts = []
racks_placed = set()
racks_this_dc = dc_racks[dc]
hosts_this_dc = len(hosts_per_dc[dc])
for token_offset_index in six.moves.range(index, index+num_tokens):
if token_offset_index >= len(token_offsets):
token_offset_index = token_offset_index - len(token_offsets)
token_offset = token_offsets[token_offset_index]
host = token_to_host_owner[ring[token_offset]]
if replicas_remaining == 0 or replicas_this_dc == hosts_this_dc:
break
if host in replicas:
continue
if host.rack in racks_placed and len(racks_placed) < len(racks_this_dc):
skipped_hosts.append(host)
continue
replicas.append(host)
replicas_this_dc += 1
replicas_remaining -= 1
racks_placed.add(host.rack)
if len(racks_placed) == len(racks_this_dc):
for host in skipped_hosts:
if replicas_remaining == 0:
break
replicas.append(host)
replicas_remaining -= 1
del skipped_hosts[:]
return replica_map
def export_for_schema(self):
"""
Returns a string version of these replication options which are
suitable for use in a CREATE KEYSPACE statement.
"""
ret = "{'class': 'NetworkTopologyStrategy'"
- for dc, repl_factor in sorted(self.dc_replication_factors.items()):
- ret += ", '%s': '%d'" % (dc, repl_factor)
+ for dc, rf in sorted(self.dc_replication_factors_info.items()):
+ ret += ", '%s': '%s'" % (dc, str(rf))
return ret + "}"
def __eq__(self, other):
if not isinstance(other, NetworkTopologyStrategy):
return False
- return self.dc_replication_factors == other.dc_replication_factors
+ return self.dc_replication_factors_info == other.dc_replication_factors_info
class LocalStrategy(ReplicationStrategy):
def __init__(self, options_map):
pass
def make_token_replica_map(self, token_to_host_owner, ring):
return {}
def export_for_schema(self):
"""
Returns a string version of these replication options which are
suitable for use in a CREATE KEYSPACE statement.
"""
return "{'class': 'LocalStrategy'}"
def __eq__(self, other):
return isinstance(other, LocalStrategy)
class KeyspaceMetadata(object):
"""
A representation of the schema for a single keyspace.
"""
name = None
""" The string name of the keyspace. """
durable_writes = True
"""
A boolean indicating whether durable writes are enabled for this keyspace
or not.
"""
replication_strategy = None
"""
A :class:`.ReplicationStrategy` subclass object.
"""
tables = None
"""
A map from table names to instances of :class:`~.TableMetadata`.
"""
indexes = None
"""
A dict mapping index names to :class:`.IndexMetadata` instances.
"""
user_types = None
"""
A map from user-defined type names to instances of :class:`~cassandra.metadata.UserType`.
.. versionadded:: 2.1.0
"""
functions = None
"""
A map from user-defined function signatures to instances of :class:`~cassandra.metadata.Function`.
.. versionadded:: 2.6.0
"""
aggregates = None
"""
A map from user-defined aggregate signatures to instances of :class:`~cassandra.metadata.Aggregate`.
.. versionadded:: 2.6.0
"""
views = None
"""
A dict mapping view names to :class:`.MaterializedViewMetadata` instances.
"""
virtual = False
"""
A boolean indicating if this is a virtual keyspace or not. Always ``False``
- for clusters running pre-4.0 versions of Cassandra.
+ for clusters running Cassandra pre-4.0 and DSE pre-6.7 versions.
.. versionadded:: 3.15
"""
+ graph_engine = None
+ """
+ The string name of the graph engine enabled for this keyspace (Core/Classic), or None if graph is not enabled.
+ """
+
_exc_info = None
""" set if metadata parsing failed """
- def __init__(self, name, durable_writes, strategy_class, strategy_options):
+ def __init__(self, name, durable_writes, strategy_class, strategy_options, graph_engine=None):
self.name = name
self.durable_writes = durable_writes
self.replication_strategy = ReplicationStrategy.create(strategy_class, strategy_options)
self.tables = {}
self.indexes = {}
self.user_types = {}
self.functions = {}
self.aggregates = {}
self.views = {}
+ self.graph_engine = graph_engine
+
+ @property
+ def is_graph_enabled(self):
+ return self.graph_engine is not None
def export_as_string(self):
"""
Returns a CQL query string that can be used to recreate the entire keyspace,
including user-defined types and tables.
"""
- cql = "\n\n".join([self.as_cql_query() + ';'] +
- self.user_type_strings() +
- [f.export_as_string() for f in self.functions.values()] +
- [a.export_as_string() for a in self.aggregates.values()] +
- [t.export_as_string() for t in self.tables.values()])
+ # Make sure tables with a vertex label are exported before tables with an edge label
+ tables_with_vertex = [t for t in self.tables.values() if hasattr(t, 'vertex') and t.vertex]
+ other_tables = [t for t in self.tables.values() if t not in tables_with_vertex]
+
+ cql = "\n\n".join(
+ [self.as_cql_query() + ';'] +
+ self.user_type_strings() +
+ [f.export_as_string() for f in self.functions.values()] +
+ [a.export_as_string() for a in self.aggregates.values()] +
+ [t.export_as_string() for t in tables_with_vertex + other_tables])
+
if self._exc_info:
import traceback
ret = "/*\nWarning: Keyspace %s is incomplete because of an error processing metadata.\n" % \
(self.name)
for line in traceback.format_exception(*self._exc_info):
ret += line
ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % cql
return ret
if self.virtual:
return ("/*\nWarning: Keyspace {ks} is a virtual keyspace and cannot be recreated with CQL.\n"
"Structure, for reference:*/\n"
"{cql}\n"
"").format(ks=self.name, cql=cql)
return cql
def as_cql_query(self):
"""
Returns a CQL query string that can be used to recreate just this keyspace,
not including user-defined types and tables.
"""
if self.virtual:
return "// VIRTUAL KEYSPACE {}".format(protect_name(self.name))
ret = "CREATE KEYSPACE %s WITH replication = %s " % (
protect_name(self.name),
self.replication_strategy.export_for_schema())
- return ret + (' AND durable_writes = %s' % ("true" if self.durable_writes else "false"))
+ ret = ret + (' AND durable_writes = %s' % ("true" if self.durable_writes else "false"))
+ if self.graph_engine is not None:
+ ret = ret + (" AND graph_engine = '%s'" % self.graph_engine)
+ return ret
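+
+ # Example of the statement produced above for a hypothetical graph-enabled
+ # keyspace (illustrative sketch):
+ #
+ #     CREATE KEYSPACE ks WITH replication = {'class': 'SimpleStrategy',
+ #     'replication_factor': '3'} AND durable_writes = true AND graph_engine = 'Core'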
def user_type_strings(self):
user_type_strings = []
user_types = self.user_types.copy()
keys = sorted(user_types.keys())
for k in keys:
if k in user_types:
self.resolve_user_types(k, user_types, user_type_strings)
return user_type_strings
def resolve_user_types(self, key, user_types, user_type_strings):
user_type = user_types.pop(key)
for type_name in user_type.field_types:
for sub_type in types.cql_types_from_string(type_name):
if sub_type in user_types:
self.resolve_user_types(sub_type, user_types, user_type_strings)
user_type_strings.append(user_type.export_as_string())
def _add_table_metadata(self, table_metadata):
old_indexes = {}
old_meta = self.tables.get(table_metadata.name, None)
if old_meta:
# views are not queried with table, so they must be transferred to new
table_metadata.views = old_meta.views
# indexes will be updated with what is on the new metadata
old_indexes = old_meta.indexes
# note the intentional order of add before remove
# this makes sure the maps are never absent something that existed before this update
for index_name, index_metadata in six.iteritems(table_metadata.indexes):
self.indexes[index_name] = index_metadata
for index_name in (n for n in old_indexes if n not in table_metadata.indexes):
self.indexes.pop(index_name, None)
self.tables[table_metadata.name] = table_metadata
def _drop_table_metadata(self, table_name):
table_meta = self.tables.pop(table_name, None)
if table_meta:
for index_name in table_meta.indexes:
self.indexes.pop(index_name, None)
for view_name in table_meta.views:
self.views.pop(view_name, None)
return
# we can't tell table drops from views, so drop both
# (name is unique among them, within a keyspace)
view_meta = self.views.pop(table_name, None)
if view_meta:
try:
self.tables[view_meta.base_table_name].views.pop(table_name, None)
except KeyError:
pass
def _add_view_metadata(self, view_metadata):
try:
self.tables[view_metadata.base_table_name].views[view_metadata.name] = view_metadata
self.views[view_metadata.name] = view_metadata
except KeyError:
pass
class UserType(object):
"""
A user defined type, as created by ``CREATE TYPE`` statements.
User-defined types were introduced in Cassandra 2.1.
.. versionadded:: 2.1.0
"""
keyspace = None
"""
The string name of the keyspace in which this type is defined.
"""
name = None
"""
The name of this type.
"""
field_names = None
"""
An ordered list of the names for each field in this user-defined type.
"""
field_types = None
"""
An ordered list of the types for each field in this user-defined type.
"""
def __init__(self, keyspace, name, field_names, field_types):
self.keyspace = keyspace
self.name = name
# non-frozen collections can return None
self.field_names = field_names or []
self.field_types = field_types or []
def as_cql_query(self, formatted=False):
"""
Returns a CQL query that can be used to recreate this type.
If `formatted` is set to :const:`True`, extra whitespace will
be added to make the query more readable.
"""
ret = "CREATE TYPE %s.%s (%s" % (
protect_name(self.keyspace),
protect_name(self.name),
"\n" if formatted else "")
if formatted:
field_join = ",\n"
padding = " "
else:
field_join = ", "
padding = ""
fields = []
for field_name, field_type in zip(self.field_names, self.field_types):
fields.append("%s %s" % (protect_name(field_name), field_type))
ret += field_join.join("%s%s" % (padding, field) for field in fields)
ret += "\n)" if formatted else ")"
return ret
def export_as_string(self):
return self.as_cql_query(formatted=True) + ';'
class Aggregate(object):
"""
A user defined aggregate function, as created by ``CREATE AGGREGATE`` statements.
Aggregate functions were introduced in Cassandra 2.2
.. versionadded:: 2.6.0
"""
keyspace = None
"""
The string name of the keyspace in which this aggregate is defined
"""
name = None
"""
The name of this aggregate
"""
argument_types = None
"""
An ordered list of the types for each argument to the aggregate
"""
final_func = None
"""
Name of a final function
"""
initial_condition = None
"""
Initial condition of the aggregate
"""
return_type = None
"""
Return type of the aggregate
"""
state_func = None
"""
Name of a state function
"""
state_type = None
"""
Type of the aggregate state
"""
+ deterministic = None
+ """
+ Flag indicating if this function is guaranteed to produce the same result
+ for a particular input and state. This is available only with DSE >=6.0.
+ """
+
def __init__(self, keyspace, name, argument_types, state_func,
- state_type, final_func, initial_condition, return_type):
+ state_type, final_func, initial_condition, return_type,
+ deterministic):
self.keyspace = keyspace
self.name = name
self.argument_types = argument_types
self.state_func = state_func
self.state_type = state_type
self.final_func = final_func
self.initial_condition = initial_condition
self.return_type = return_type
+ self.deterministic = deterministic
def as_cql_query(self, formatted=False):
"""
Returns a CQL query that can be used to recreate this aggregate.
If `formatted` is set to :const:`True`, extra whitespace will
be added to make the query more readable.
"""
sep = '\n ' if formatted else ' '
keyspace = protect_name(self.keyspace)
name = protect_name(self.name)
type_list = ', '.join([types.strip_frozen(arg_type) for arg_type in self.argument_types])
state_func = protect_name(self.state_func)
state_type = types.strip_frozen(self.state_type)
ret = "CREATE AGGREGATE %(keyspace)s.%(name)s(%(type_list)s)%(sep)s" \
"SFUNC %(state_func)s%(sep)s" \
"STYPE %(state_type)s" % locals()
ret += ''.join((sep, 'FINALFUNC ', protect_name(self.final_func))) if self.final_func else ''
ret += ''.join((sep, 'INITCOND ', self.initial_condition)) if self.initial_condition is not None else ''
+ ret += '{}DETERMINISTIC'.format(sep) if self.deterministic else ''
return ret
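# Example of the CQL produced above for a hypothetical aggregate (sketch only):
#
#     CREATE AGGREGATE ks.average(int) SFUNC avg_state STYPE tuple<bigint, int>
#     FINALFUNC avg_final INITCOND (0, 0) DETERMINISTIC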
def export_as_string(self):
return self.as_cql_query(formatted=True) + ';'
@property
def signature(self):
return SignatureDescriptor.format_signature(self.name, self.argument_types)
class Function(object):
"""
A user defined function, as created by ``CREATE FUNCTION`` statements.
User-defined functions were introduced in Cassandra 2.2
.. versionadded:: 2.6.0
"""
keyspace = None
"""
The string name of the keyspace in which this function is defined
"""
name = None
"""
The name of this function
"""
argument_types = None
"""
An ordered list of the types for each argument to the function
"""
argument_names = None
"""
An ordered list of the names of each argument to the function
"""
return_type = None
"""
Return type of the function
"""
language = None
"""
Language of the function body
"""
body = None
"""
Function body string
"""
called_on_null_input = None
"""
Flag indicating whether this function should be called for rows with null values
(convenience function to avoid handling nulls explicitly if the result will just be null)
"""
+ deterministic = None
+ """
+ Flag indicating if this function is guaranteed to produce the same result
+ for a particular input. This is available only for DSE >=6.0.
+ """
+
+ monotonic = None
+ """
+ Flag indicating if this function is guaranteed to increase or decrease
+ monotonically on any of its arguments. This is available only for DSE >=6.0.
+ """
+
+ monotonic_on = None
+ """
+ A list containing the argument or arguments over which this function is
+ monotonic. This is available only for DSE >=6.0.
+ """
+
def __init__(self, keyspace, name, argument_types, argument_names,
- return_type, language, body, called_on_null_input):
+ return_type, language, body, called_on_null_input,
+ deterministic, monotonic, monotonic_on):
self.keyspace = keyspace
self.name = name
self.argument_types = argument_types
# argument_types (frozen<list<>>) will always be a list
# argument_name is not frozen in C* < 3.0 and may return None
self.argument_names = argument_names or []
self.return_type = return_type
self.language = language
self.body = body
self.called_on_null_input = called_on_null_input
+ self.deterministic = deterministic
+ self.monotonic = monotonic
+ self.monotonic_on = monotonic_on
def as_cql_query(self, formatted=False):
"""
Returns a CQL query that can be used to recreate this function.
If `formatted` is set to :const:`True`, extra whitespace will
be added to make the query more readable.
"""
sep = '\n ' if formatted else ' '
keyspace = protect_name(self.keyspace)
name = protect_name(self.name)
arg_list = ', '.join(["%s %s" % (protect_name(n), types.strip_frozen(t))
for n, t in zip(self.argument_names, self.argument_types)])
typ = self.return_type
lang = self.language
body = self.body
on_null = "CALLED" if self.called_on_null_input else "RETURNS NULL"
+ deterministic_token = ('DETERMINISTIC{}'.format(sep)
+ if self.deterministic else
+ '')
+ monotonic_tokens = '' # default for nonmonotonic function
+ if self.monotonic:
+ # monotonic on all arguments; ignore self.monotonic_on
+ monotonic_tokens = 'MONOTONIC{}'.format(sep)
+ elif self.monotonic_on:
+ # if monotonic == False and monotonic_on is nonempty, we know that
+ # monotonicity was specified with MONOTONIC ON <argument>, so there's
+ # exactly 1 value there
+ monotonic_tokens = 'MONOTONIC ON {}{}'.format(self.monotonic_on[0],
+ sep)
return "CREATE FUNCTION %(keyspace)s.%(name)s(%(arg_list)s)%(sep)s" \
"%(on_null)s ON NULL INPUT%(sep)s" \
"RETURNS %(typ)s%(sep)s" \
+ "%(deterministic_token)s" \
+ "%(monotonic_tokens)s" \
"LANGUAGE %(lang)s%(sep)s" \
"AS $$%(body)s$$" % locals()
def export_as_string(self):
return self.as_cql_query(formatted=True) + ';'
@property
def signature(self):
return SignatureDescriptor.format_signature(self.name, self.argument_types)
class TableMetadata(object):
"""
A representation of the schema for a single table.
"""
keyspace_name = None
""" String name of this Table's keyspace """
name = None
""" The string name of the table. """
partition_key = None
"""
A list of :class:`.ColumnMetadata` instances representing the columns in
the partition key for this table. This will always hold at least one
column.
"""
clustering_key = None
"""
A list of :class:`.ColumnMetadata` instances representing the columns
in the clustering key for this table. These are all of the
:attr:`.primary_key` columns that are not in the :attr:`.partition_key`.
Note that a table may have no clustering keys, in which case this will
be an empty list.
"""
@property
def primary_key(self):
"""
A list of :class:`.ColumnMetadata` representing the components of
the primary key for this table.
"""
return self.partition_key + self.clustering_key
columns = None
"""
A dict mapping column names to :class:`.ColumnMetadata` instances.
"""
indexes = None
"""
A dict mapping index names to :class:`.IndexMetadata` instances.
"""
is_compact_storage = False
options = None
"""
A dict mapping table option names to their specific settings for this
table.
"""
compaction_options = {
"min_compaction_threshold": "min_threshold",
"max_compaction_threshold": "max_threshold",
"compaction_strategy_class": "class"}
triggers = None
"""
A dict mapping trigger names to :class:`.TriggerMetadata` instances.
"""
views = None
"""
A dict mapping view names to :class:`.MaterializedViewMetadata` instances.
"""
_exc_info = None
""" set if metadata parsing failed """
virtual = False
"""
A boolean indicating if this is a virtual table or not. Always ``False``
- for clusters running pre-4.0 versions of Cassandra.
+ for clusters running Cassandra pre-4.0 and DSE pre-6.7 versions.
.. versionadded:: 3.15
"""
@property
def is_cql_compatible(self):
"""
A boolean indicating if this table can be represented as CQL in export
"""
if self.virtual:
return False
comparator = getattr(self, 'comparator', None)
if comparator:
# no compact storage with more than one column beyond PK if there
# are clustering columns
incompatible = (self.is_compact_storage and
len(self.columns) > len(self.primary_key) + 1 and
len(self.clustering_key) >= 1)
return not incompatible
return True
extensions = None
"""
Metadata describing configuration for table extensions
"""
def __init__(self, keyspace_name, name, partition_key=None, clustering_key=None, columns=None, triggers=None, options=None, virtual=False):
self.keyspace_name = keyspace_name
self.name = name
self.partition_key = [] if partition_key is None else partition_key
self.clustering_key = [] if clustering_key is None else clustering_key
self.columns = OrderedDict() if columns is None else columns
self.indexes = {}
self.options = {} if options is None else options
self.comparator = None
self.triggers = OrderedDict() if triggers is None else triggers
self.views = {}
self.virtual = virtual
def export_as_string(self):
"""
Returns a string of CQL queries that can be used to recreate this table
along with all indexes on it. The returned string is formatted to
be human readable.
"""
if self._exc_info:
import traceback
ret = "/*\nWarning: Table %s.%s is incomplete because of an error processing metadata.\n" % \
(self.keyspace_name, self.name)
for line in traceback.format_exception(*self._exc_info):
ret += line
ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % self._all_as_cql()
elif not self.is_cql_compatible:
# If we can't produce this table with CQL, comment inline
ret = "/*\nWarning: Table %s.%s omitted because it has constructs not compatible with CQL (was created via legacy API).\n" % \
(self.keyspace_name, self.name)
ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % self._all_as_cql()
elif self.virtual:
ret = ('/*\nWarning: Table {ks}.{tab} is a virtual table and cannot be recreated with CQL.\n'
'Structure, for reference:\n'
'{cql}\n*/').format(ks=self.keyspace_name, tab=self.name, cql=self._all_as_cql())
else:
ret = self._all_as_cql()
return ret
def _all_as_cql(self):
ret = self.as_cql_query(formatted=True)
ret += ";"
for index in self.indexes.values():
ret += "\n%s;" % index.as_cql_query()
for trigger_meta in self.triggers.values():
ret += "\n%s;" % (trigger_meta.as_cql_query(),)
for view_meta in self.views.values():
ret += "\n\n%s;" % (view_meta.as_cql_query(formatted=True),)
if self.extensions:
registry = _RegisteredExtensionType._extension_registry
for k in six.viewkeys(registry) & self.extensions: # no viewkeys on OrderedMapSerializeKey
ext = registry[k]
cql = ext.after_table_cql(self, k, self.extensions[k])
if cql:
ret += "\n\n%s" % (cql,)
return ret
def as_cql_query(self, formatted=False):
"""
Returns a CQL query that can be used to recreate this table (index
creations are not included). If `formatted` is set to :const:`True`,
extra whitespace will be added to make the query human readable.
"""
ret = "%s TABLE %s.%s (%s" % (
('VIRTUAL' if self.virtual else 'CREATE'),
protect_name(self.keyspace_name),
protect_name(self.name),
"\n" if formatted else "")
if formatted:
column_join = ",\n"
padding = " "
else:
column_join = ", "
padding = ""
columns = []
for col in self.columns.values():
columns.append("%s %s%s" % (protect_name(col.name), col.cql_type, ' static' if col.is_static else ''))
if len(self.partition_key) == 1 and not self.clustering_key:
columns[0] += " PRIMARY KEY"
ret += column_join.join("%s%s" % (padding, col) for col in columns)
# primary key
if len(self.partition_key) > 1 or self.clustering_key:
ret += "%s%sPRIMARY KEY (" % (column_join, padding)
if len(self.partition_key) > 1:
ret += "(%s)" % ", ".join(protect_name(col.name) for col in self.partition_key)
else:
ret += protect_name(self.partition_key[0].name)
if self.clustering_key:
ret += ", %s" % ", ".join(protect_name(col.name) for col in self.clustering_key)
ret += ")"
# properties
ret += "%s) WITH " % ("\n" if formatted else "")
ret += self._property_string(formatted, self.clustering_key, self.options, self.is_compact_storage)
return ret
@classmethod
def _property_string(cls, formatted, clustering_key, options_map, is_compact_storage=False):
properties = []
if is_compact_storage:
properties.append("COMPACT STORAGE")
if clustering_key:
cluster_str = "CLUSTERING ORDER BY "
inner = []
for col in clustering_key:
ordering = "DESC" if col.is_reversed else "ASC"
inner.append("%s %s" % (protect_name(col.name), ordering))
cluster_str += "(%s)" % ", ".join(inner)
properties.append(cluster_str)
properties.extend(cls._make_option_strings(options_map))
join_str = "\n AND " if formatted else " AND "
return join_str.join(properties)
@classmethod
def _make_option_strings(cls, options_map):
ret = []
options_copy = dict(options_map.items())
actual_options = json.loads(options_copy.pop('compaction_strategy_options', '{}'))
value = options_copy.pop("compaction_strategy_class", None)
actual_options.setdefault("class", value)
compaction_option_strings = ["'%s': '%s'" % (k, v) for k, v in actual_options.items()]
ret.append('compaction = {%s}' % ', '.join(compaction_option_strings))
for system_table_name in cls.compaction_options.keys():
options_copy.pop(system_table_name, None) # delete if present
options_copy.pop('compaction_strategy_option', None)
if not options_copy.get('compression'):
params = json.loads(options_copy.pop('compression_parameters', '{}'))
param_strings = ["'%s': '%s'" % (k, v) for k, v in params.items()]
ret.append('compression = {%s}' % ', '.join(param_strings))
for name, value in options_copy.items():
if value is not None:
if name == "comment":
value = value or ""
ret.append("%s = %s" % (name, protect_value(value)))
return list(sorted(ret))
+class TableMetadataV3(TableMetadata):
+ """
+ For C* 3.0+. `option_maps` takes a superset of map names, so if nothing
+ changes structurally, new option maps can just be appended to the list.
+ """
+ compaction_options = {}
+
+ option_maps = [
+ 'compaction', 'compression', 'caching',
+ 'nodesync' # added DSE 6.0
+ ]
+
+ @property
+ def is_cql_compatible(self):
+ return True
+
+ @classmethod
+ def _make_option_strings(cls, options_map):
+ ret = []
+ options_copy = dict(options_map.items())
+
+ for option in cls.option_maps:
+ value = options_copy.get(option)
+ if isinstance(value, Mapping):
+ del options_copy[option]
+ params = ("'%s': '%s'" % (k, v) for k, v in value.items())
+ ret.append("%s = {%s}" % (option, ', '.join(params)))
+
+ for name, value in options_copy.items():
+ if value is not None:
+ if name == "comment":
+ value = value or ""
+ ret.append("%s = %s" % (name, protect_value(value)))
+
+ return list(sorted(ret))
+
+
+class TableMetadataDSE68(TableMetadataV3):
+
+ vertex = None
+ """A :class:`.VertexMetadata` instance, if graph enabled"""
+
+ edge = None
+ """A :class:`.EdgeMetadata` instance, if graph enabled"""
+
+ def as_cql_query(self, formatted=False):
+ ret = super(TableMetadataDSE68, self).as_cql_query(formatted)
+
+ if self.vertex:
+ ret += " AND VERTEX LABEL %s" % protect_name(self.vertex.label_name)
+
+ if self.edge:
+ ret += " AND EDGE LABEL %s" % protect_name(self.edge.label_name)
+
+ ret += self._export_edge_as_cql(
+ self.edge.from_label,
+ self.edge.from_partition_key_columns,
+ self.edge.from_clustering_columns, "FROM")
+
+ ret += self._export_edge_as_cql(
+ self.edge.to_label,
+ self.edge.to_partition_key_columns,
+ self.edge.to_clustering_columns, "TO")
+
+ return ret
+
+ @staticmethod
+ def _export_edge_as_cql(label_name, partition_keys,
+ clustering_columns, keyword):
+ ret = " %s %s(" % (keyword, protect_name(label_name))
+
+ if len(partition_keys) == 1:
+ ret += protect_name(partition_keys[0])
+ else:
+ ret += "(%s)" % ", ".join([protect_name(k) for k in partition_keys])
+
+ if clustering_columns:
+ ret += ", %s" % ", ".join([protect_name(k) for k in clustering_columns])
+ ret += ")"
+
+ return ret
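+
+ # Illustrative fragment appended by as_cql_query/_export_edge_as_cql above for
+ # a hypothetical edge table (sketch only):
+ #
+ #     ... AND EDGE LABEL rates FROM person(person_id) TO restaurant(restaurant_id)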
+
+
class TableExtensionInterface(object):
"""
Defines CQL/DDL for Cassandra table extensions.
"""
# limited API for now. Could be expanded as new extension types materialize -- "extend_option_strings", for example
@classmethod
def after_table_cql(cls, ext_key, ext_blob):
"""
Called to produce CQL/DDL to follow the table definition.
Should contain requisite terminating semicolon(s).
"""
pass
class _RegisteredExtensionType(type):
_extension_registry = {}
def __new__(mcs, name, bases, dct):
cls = super(_RegisteredExtensionType, mcs).__new__(mcs, name, bases, dct)
if name != 'RegisteredTableExtension':
mcs._extension_registry[cls.name] = cls
return cls
@six.add_metaclass(_RegisteredExtensionType)
class RegisteredTableExtension(TableExtensionInterface):
"""
Extending this class registers it by name (associated by key in the `system_schema.tables.extensions` map).
"""
name = None
"""
Name of the extension (key in the map)
"""
def protect_name(name):
return maybe_escape_name(name)
def protect_names(names):
return [protect_name(n) for n in names]
def protect_value(value):
if value is None:
return 'NULL'
if isinstance(value, (int, float, bool)):
return str(value).lower()
return "'%s'" % value.replace("'", "''")
valid_cql3_word_re = re.compile(r'^[a-z][0-9a-z_]*$')
def is_valid_name(name):
if name is None:
return False
if name.lower() in cql_keywords_reserved:
return False
return valid_cql3_word_re.match(name) is not None
def maybe_escape_name(name):
if is_valid_name(name):
return name
return escape_name(name)
def escape_name(name):
return '"%s"' % (name.replace('"', '""'),)
class ColumnMetadata(object):
"""
A representation of a single column in a table.
"""
table = None
""" The :class:`.TableMetadata` this column belongs to. """
name = None
""" The string name of this column. """
cql_type = None
"""
The CQL type for the column.
"""
is_static = False
"""
If this column is static (available in Cassandra 2.1+), this will
be :const:`True`, otherwise :const:`False`.
"""
is_reversed = False
"""
If this column is reversed (DESC) as in clustering order
"""
_cass_type = None
def __init__(self, table_metadata, column_name, cql_type, is_static=False, is_reversed=False):
self.table = table_metadata
self.name = column_name
self.cql_type = cql_type
self.is_static = is_static
self.is_reversed = is_reversed
def __str__(self):
return "%s %s" % (self.name, self.cql_type)
class IndexMetadata(object):
"""
A representation of a secondary index on a column.
"""
keyspace_name = None
""" A string name of the keyspace. """
table_name = None
""" A string name of the table this index is on. """
name = None
""" A string name for the index. """
kind = None
""" A string representing the kind of index (COMPOSITE, CUSTOM,...). """
index_options = {}
""" A dict of index options. """
def __init__(self, keyspace_name, table_name, index_name, kind, index_options):
self.keyspace_name = keyspace_name
self.table_name = table_name
self.name = index_name
self.kind = kind
self.index_options = index_options
def as_cql_query(self):
"""
Returns a CQL query that can be used to recreate this index.
"""
options = dict(self.index_options)
index_target = options.pop("target")
if self.kind != "CUSTOM":
return "CREATE INDEX %s ON %s.%s (%s)" % (
protect_name(self.name),
protect_name(self.keyspace_name),
protect_name(self.table_name),
index_target)
else:
class_name = options.pop("class_name")
ret = "CREATE CUSTOM INDEX %s ON %s.%s (%s) USING '%s'" % (
protect_name(self.name),
protect_name(self.keyspace_name),
protect_name(self.table_name),
index_target,
class_name)
if options:
# PYTHON-1008: `ret` will always be a unicode
opts_cql_encoded = _encoder.cql_encode_all_types(options, as_text_type=True)
ret += " WITH OPTIONS = %s" % opts_cql_encoded
return ret
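# Illustrative output of the CUSTOM branch above for a hypothetical SASI index
# (shown wrapped; the method returns a single line):
#
#     CREATE CUSTOM INDEX idx ON ks.tbl (col)
#     USING 'org.apache.cassandra.index.sasi.SASIIndex' WITH OPTIONS = {'mode': 'CONTAINS'}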
def export_as_string(self):
"""
Returns a CQL query string that can be used to recreate this index.
"""
return self.as_cql_query() + ';'
class TokenMap(object):
"""
Information about the layout of the ring.
"""
token_class = None
"""
A subclass of :class:`.Token`, depending on what partitioner the cluster uses.
"""
token_to_host_owner = None
"""
A map of :class:`.Token` objects to the :class:`.Host` that owns that token.
"""
tokens_to_hosts_by_ks = None
"""
A map of keyspace names to a nested map of :class:`.Token` objects to
sets of :class:`.Host` objects.
"""
ring = None
"""
An ordered list of :class:`.Token` instances in the ring.
"""
_metadata = None
def __init__(self, token_class, token_to_host_owner, all_tokens, metadata):
self.token_class = token_class
self.ring = all_tokens
self.token_to_host_owner = token_to_host_owner
self.tokens_to_hosts_by_ks = {}
self._metadata = metadata
self._rebuild_lock = RLock()
def rebuild_keyspace(self, keyspace, build_if_absent=False):
with self._rebuild_lock:
try:
current = self.tokens_to_hosts_by_ks.get(keyspace, None)
if (build_if_absent and current is None) or (not build_if_absent and current is not None):
ks_meta = self._metadata.keyspaces.get(keyspace)
if ks_meta:
replica_map = self.replica_map_for_keyspace(self._metadata.keyspaces[keyspace])
self.tokens_to_hosts_by_ks[keyspace] = replica_map
except Exception:
# should not happen normally, but we don't want to blow up queries because of unexpected meta state
# bypass until new map is generated
self.tokens_to_hosts_by_ks[keyspace] = {}
log.exception("Failed creating a token map for keyspace '%s' with %s. PLEASE REPORT THIS: https://datastax-oss.atlassian.net/projects/PYTHON", keyspace, self.token_to_host_owner)
def replica_map_for_keyspace(self, ks_metadata):
strategy = ks_metadata.replication_strategy
if strategy:
return strategy.make_token_replica_map(self.token_to_host_owner, self.ring)
else:
return None
def remove_keyspace(self, keyspace):
self.tokens_to_hosts_by_ks.pop(keyspace, None)
def get_replicas(self, keyspace, token):
"""
Get a set of :class:`.Host` instances representing all of the
replica nodes for a given :class:`.Token`.
"""
tokens_to_hosts = self.tokens_to_hosts_by_ks.get(keyspace, None)
if tokens_to_hosts is None:
self.rebuild_keyspace(keyspace, build_if_absent=True)
tokens_to_hosts = self.tokens_to_hosts_by_ks.get(keyspace, None)
if tokens_to_hosts:
# The values in self.ring correspond to the end of the
# token range up to and including the value listed.
point = bisect_left(self.ring, token)
if point == len(self.ring):
return tokens_to_hosts[self.ring[0]]
else:
return tokens_to_hosts[self.ring[point]]
return []
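# Worked example of the wrap-around lookup above (illustrative): with
# ring = [t10, t20, t30], a token of 25 bisects to index 2 and maps to the
# replicas stored for t30; a token of 35 bisects past the end and wraps to the
# replicas stored for t10.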
@total_ordering
class Token(object):
"""
Abstract class representing a token.
"""
def __init__(self, token):
self.value = token
@classmethod
def hash_fn(cls, key):
return key
@classmethod
def from_key(cls, key):
return cls(cls.hash_fn(key))
@classmethod
def from_string(cls, token_string):
raise NotImplementedError()
def __eq__(self, other):
return self.value == other.value
def __lt__(self, other):
return self.value < other.value
def __hash__(self):
return hash(self.value)
def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self.value)
__str__ = __repr__
MIN_LONG = -(2 ** 63)
MAX_LONG = (2 ** 63) - 1
class NoMurmur3(Exception):
pass
class HashToken(Token):
@classmethod
def from_string(cls, token_string):
""" `token_string` should be the string representation from the server. """
# The hash partitioners just store the decimal value
return cls(int(token_string))
class Murmur3Token(HashToken):
"""
A token for ``Murmur3Partitioner``.
"""
@classmethod
def hash_fn(cls, key):
if murmur3 is not None:
h = int(murmur3(key))
return h if h != MIN_LONG else MAX_LONG
else:
raise NoMurmur3()
def __init__(self, token):
""" `token` is an int or string representing the token. """
self.value = int(token)
class MD5Token(HashToken):
"""
A token for ``RandomPartitioner``.
"""
@classmethod
def hash_fn(cls, key):
if isinstance(key, six.text_type):
key = key.encode('UTF-8')
return abs(varint_unpack(md5(key).digest()))
class BytesToken(Token):
"""
A token for ``ByteOrderedPartitioner``.
"""
@classmethod
def from_string(cls, token_string):
""" `token_string` should be the string representation from the server. """
# unhexlify works fine with unicode input in everything but pypy3, where it raises "TypeError: 'str' does not support the buffer interface"
if isinstance(token_string, six.text_type):
token_string = token_string.encode('ascii')
# The BOP stores a hex string
return cls(unhexlify(token_string))
class TriggerMetadata(object):
"""
A representation of a trigger for a table.
"""
table = None
""" The :class:`.TableMetadata` this trigger belongs to. """
name = None
""" The string name of this trigger. """
options = None
"""
A dict mapping trigger option names to their specific settings for this
table.
"""
def __init__(self, table_metadata, trigger_name, options=None):
self.table = table_metadata
self.name = trigger_name
self.options = options
def as_cql_query(self):
ret = "CREATE TRIGGER %s ON %s.%s USING %s" % (
protect_name(self.name),
protect_name(self.table.keyspace_name),
protect_name(self.table.name),
protect_value(self.options['class'])
)
return ret
def export_as_string(self):
return self.as_cql_query() + ';'
class _SchemaParser(object):
def __init__(self, connection, timeout):
self.connection = connection
self.timeout = timeout
def _handle_results(self, success, result, expected_failures=tuple()):
"""
Given a bool and a ResultSet (the form returned per result from
Connection.wait_for_responses), return a dictionary containing the
results. Used to process results from asynchronous queries to system
tables.
``expected_failures`` will usually be used to allow callers to ignore
``InvalidRequest`` errors caused by a missing system keyspace. For
example, some DSE versions report a 4.X server version, but do not have
virtual tables. Thus, running against 4.X servers, SchemaParserV4 uses
expected_failures to make a best-effort attempt to read those
keyspaces, but treat them as empty if they're not found.
:param success: A boolean representing whether or not the query
succeeded
:param result: The resultset in question.
:param expected_failures: An Exception class or an iterable thereof. If the
query failed, but raised an instance of an expected failure class, this
will ignore the failure and return an empty list.
"""
if not success and isinstance(result, expected_failures):
return []
elif success:
- return dict_factory(*result.results) if result else []
+ return dict_factory(result.column_names, result.parsed_rows) if result else []
else:
raise result
def _query_build_row(self, query_string, build_func):
result = self._query_build_rows(query_string, build_func)
return result[0] if result else None
def _query_build_rows(self, query_string, build_func):
query = QueryMessage(query=query_string, consistency_level=ConsistencyLevel.ONE)
responses = self.connection.wait_for_responses((query), timeout=self.timeout, fail_on_error=False)
(success, response) = responses[0]
if success:
- result = dict_factory(*response.results)
+ result = dict_factory(response.column_names, response.parsed_rows)
return [build_func(row) for row in result]
elif isinstance(response, InvalidRequest):
log.debug("user types table not found")
return []
else:
raise response
class SchemaParserV22(_SchemaParser):
+ """
+ For C* 2.2+
+ """
_SELECT_KEYSPACES = "SELECT * FROM system.schema_keyspaces"
_SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies"
_SELECT_COLUMNS = "SELECT * FROM system.schema_columns"
_SELECT_TRIGGERS = "SELECT * FROM system.schema_triggers"
_SELECT_TYPES = "SELECT * FROM system.schema_usertypes"
_SELECT_FUNCTIONS = "SELECT * FROM system.schema_functions"
_SELECT_AGGREGATES = "SELECT * FROM system.schema_aggregates"
_table_name_col = 'columnfamily_name'
_function_agg_arument_type_col = 'signature'
recognized_table_options = (
"comment",
"read_repair_chance",
"dclocal_read_repair_chance", # kept to be safe, but see _build_table_options()
"local_read_repair_chance",
"replicate_on_write",
"gc_grace_seconds",
"bloom_filter_fp_chance",
"caching",
"compaction_strategy_class",
"compaction_strategy_options",
"min_compaction_threshold",
"max_compaction_threshold",
"compression_parameters",
"min_index_interval",
"max_index_interval",
"index_interval",
"speculative_retry",
"rows_per_partition_to_cache",
"memtable_flush_period_in_ms",
"populate_io_cache_on_flush",
"compression",
"default_time_to_live")
def __init__(self, connection, timeout):
super(SchemaParserV22, self).__init__(connection, timeout)
self.keyspaces_result = []
self.tables_result = []
self.columns_result = []
self.triggers_result = []
self.types_result = []
self.functions_result = []
self.aggregates_result = []
self.keyspace_table_rows = defaultdict(list)
self.keyspace_table_col_rows = defaultdict(lambda: defaultdict(list))
self.keyspace_type_rows = defaultdict(list)
self.keyspace_func_rows = defaultdict(list)
self.keyspace_agg_rows = defaultdict(list)
self.keyspace_table_trigger_rows = defaultdict(lambda: defaultdict(list))
def get_all_keyspaces(self):
self._query_all()
for row in self.keyspaces_result:
keyspace_meta = self._build_keyspace_metadata(row)
try:
for table_row in self.keyspace_table_rows.get(keyspace_meta.name, []):
table_meta = self._build_table_metadata(table_row)
keyspace_meta._add_table_metadata(table_meta)
for usertype_row in self.keyspace_type_rows.get(keyspace_meta.name, []):
usertype = self._build_user_type(usertype_row)
keyspace_meta.user_types[usertype.name] = usertype
for fn_row in self.keyspace_func_rows.get(keyspace_meta.name, []):
fn = self._build_function(fn_row)
keyspace_meta.functions[fn.signature] = fn
for agg_row in self.keyspace_agg_rows.get(keyspace_meta.name, []):
agg = self._build_aggregate(agg_row)
keyspace_meta.aggregates[agg.signature] = agg
except Exception:
log.exception("Error while parsing metadata for keyspace %s. Metadata model will be incomplete.", keyspace_meta.name)
keyspace_meta._exc_info = sys.exc_info()
yield keyspace_meta
def get_table(self, keyspaces, keyspace, table):
cl = ConsistencyLevel.ONE
where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col,), (keyspace, table), _encoder)
cf_query = QueryMessage(query=self._SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl)
col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl)
triggers_query = QueryMessage(query=self._SELECT_TRIGGERS + where_clause, consistency_level=cl)
(cf_success, cf_result), (col_success, col_result), (triggers_success, triggers_result) \
= self.connection.wait_for_responses(cf_query, col_query, triggers_query, timeout=self.timeout, fail_on_error=False)
table_result = self._handle_results(cf_success, cf_result)
col_result = self._handle_results(col_success, col_result)
# the triggers table doesn't exist in C* 1.2
triggers_result = self._handle_results(triggers_success, triggers_result,
expected_failures=InvalidRequest)
if table_result:
return self._build_table_metadata(table_result[0], col_result, triggers_result)
def get_type(self, keyspaces, keyspace, type):
where_clause = bind_params(" WHERE keyspace_name = %s AND type_name = %s", (keyspace, type), _encoder)
return self._query_build_row(self._SELECT_TYPES + where_clause, self._build_user_type)
def get_types_map(self, keyspaces, keyspace):
where_clause = bind_params(" WHERE keyspace_name = %s", (keyspace,), _encoder)
types = self._query_build_rows(self._SELECT_TYPES + where_clause, self._build_user_type)
return dict((t.name, t) for t in types)
def get_function(self, keyspaces, keyspace, function):
where_clause = bind_params(" WHERE keyspace_name = %%s AND function_name = %%s AND %s = %%s" % (self._function_agg_arument_type_col,),
(keyspace, function.name, function.argument_types), _encoder)
return self._query_build_row(self._SELECT_FUNCTIONS + where_clause, self._build_function)
def get_aggregate(self, keyspaces, keyspace, aggregate):
where_clause = bind_params(" WHERE keyspace_name = %%s AND aggregate_name = %%s AND %s = %%s" % (self._function_agg_arument_type_col,),
(keyspace, aggregate.name, aggregate.argument_types), _encoder)
return self._query_build_row(self._SELECT_AGGREGATES + where_clause, self._build_aggregate)
def get_keyspace(self, keyspaces, keyspace):
where_clause = bind_params(" WHERE keyspace_name = %s", (keyspace,), _encoder)
return self._query_build_row(self._SELECT_KEYSPACES + where_clause, self._build_keyspace_metadata)
@classmethod
def _build_keyspace_metadata(cls, row):
try:
ksm = cls._build_keyspace_metadata_internal(row)
except Exception:
name = row["keyspace_name"]
ksm = KeyspaceMetadata(name, False, 'UNKNOWN', {})
ksm._exc_info = sys.exc_info() # capture exc_info before log because nose (test) logging clears it in certain circumstances
log.exception("Error while parsing metadata for keyspace %s row(%s)", name, row)
return ksm
@staticmethod
def _build_keyspace_metadata_internal(row):
name = row["keyspace_name"]
durable_writes = row["durable_writes"]
strategy_class = row["strategy_class"]
strategy_options = json.loads(row["strategy_options"])
return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options)
@classmethod
def _build_user_type(cls, usertype_row):
field_types = list(map(cls._schema_type_to_cql, usertype_row['field_types']))
return UserType(usertype_row['keyspace_name'], usertype_row['type_name'],
usertype_row['field_names'], field_types)
@classmethod
def _build_function(cls, function_row):
return_type = cls._schema_type_to_cql(function_row['return_type'])
+ deterministic = function_row.get('deterministic', False)
+ monotonic = function_row.get('monotonic', False)
+ monotonic_on = function_row.get('monotonic_on', ())
return Function(function_row['keyspace_name'], function_row['function_name'],
function_row[cls._function_agg_arument_type_col], function_row['argument_names'],
return_type, function_row['language'], function_row['body'],
- function_row['called_on_null_input'])
+ function_row['called_on_null_input'],
+ deterministic, monotonic, monotonic_on)
@classmethod
def _build_aggregate(cls, aggregate_row):
cass_state_type = types.lookup_casstype(aggregate_row['state_type'])
initial_condition = aggregate_row['initcond']
if initial_condition is not None:
initial_condition = _encoder.cql_encode_all_types(cass_state_type.deserialize(initial_condition, 3))
state_type = _cql_from_cass_type(cass_state_type)
return_type = cls._schema_type_to_cql(aggregate_row['return_type'])
return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'],
aggregate_row['signature'], aggregate_row['state_func'], state_type,
- aggregate_row['final_func'], initial_condition, return_type)
+ aggregate_row['final_func'], initial_condition, return_type,
+ aggregate_row.get('deterministic', False))
def _build_table_metadata(self, row, col_rows=None, trigger_rows=None):
keyspace_name = row["keyspace_name"]
cfname = row[self._table_name_col]
col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][cfname]
trigger_rows = trigger_rows or self.keyspace_table_trigger_rows[keyspace_name][cfname]
if not col_rows: # CASSANDRA-8487
log.warning("Building table metadata with no column meta for %s.%s",
keyspace_name, cfname)
table_meta = TableMetadata(keyspace_name, cfname)
try:
comparator = types.lookup_casstype(row["comparator"])
table_meta.comparator = comparator
is_dct_comparator = issubclass(comparator, types.DynamicCompositeType)
is_composite_comparator = issubclass(comparator, types.CompositeType)
column_name_types = comparator.subtypes if is_composite_comparator else (comparator,)
num_column_name_components = len(column_name_types)
last_col = column_name_types[-1]
column_aliases = row.get("column_aliases", None)
clustering_rows = [r for r in col_rows
if r.get('type', None) == "clustering_key"]
if len(clustering_rows) > 1:
clustering_rows = sorted(clustering_rows, key=lambda row: row.get('component_index'))
if column_aliases is not None:
column_aliases = json.loads(column_aliases)
if not column_aliases: # json load failed or column_aliases empty PYTHON-562
column_aliases = [r.get('column_name') for r in clustering_rows]
if is_composite_comparator:
if issubclass(last_col, types.ColumnToCollectionType):
# collections
is_compact = False
has_value = False
clustering_size = num_column_name_components - 2
elif (len(column_aliases) == num_column_name_components - 1 and
issubclass(last_col, types.UTF8Type)):
# aliases?
is_compact = False
has_value = False
clustering_size = num_column_name_components - 1
else:
# compact table
is_compact = True
has_value = column_aliases or not col_rows
clustering_size = num_column_name_components
# Some thrift tables define names in composite types (see PYTHON-192)
if not column_aliases and hasattr(comparator, 'fieldnames'):
column_aliases = filter(None, comparator.fieldnames)
else:
is_compact = True
if column_aliases or not col_rows or is_dct_comparator:
has_value = True
clustering_size = num_column_name_components
else:
has_value = False
clustering_size = 0
# partition key
partition_rows = [r for r in col_rows
if r.get('type', None) == "partition_key"]
if len(partition_rows) > 1:
partition_rows = sorted(partition_rows, key=lambda row: row.get('component_index'))
key_aliases = row.get("key_aliases")
if key_aliases is not None:
key_aliases = json.loads(key_aliases) if key_aliases else []
else:
# In 2.0+, we can use the 'type' column. In 3.0+, we have to use it.
key_aliases = [r.get('column_name') for r in partition_rows]
key_validator = row.get("key_validator")
if key_validator is not None:
key_type = types.lookup_casstype(key_validator)
key_types = key_type.subtypes if issubclass(key_type, types.CompositeType) else [key_type]
else:
key_types = [types.lookup_casstype(r.get('validator')) for r in partition_rows]
for i, col_type in enumerate(key_types):
if len(key_aliases) > i:
column_name = key_aliases[i]
elif i == 0:
column_name = "key"
else:
column_name = "key%d" % i
col = ColumnMetadata(table_meta, column_name, col_type.cql_parameterized_type())
table_meta.columns[column_name] = col
table_meta.partition_key.append(col)
# clustering key
for i in range(clustering_size):
if len(column_aliases) > i:
column_name = column_aliases[i]
else:
column_name = "column%d" % (i + 1)
data_type = column_name_types[i]
cql_type = _cql_from_cass_type(data_type)
is_reversed = types.is_reversed_casstype(data_type)
col = ColumnMetadata(table_meta, column_name, cql_type, is_reversed=is_reversed)
table_meta.columns[column_name] = col
table_meta.clustering_key.append(col)
# value alias (if present)
if has_value:
value_alias_rows = [r for r in col_rows
if r.get('type', None) == "compact_value"]
if not key_aliases: # TODO are we checking the right thing here?
value_alias = "value"
else:
value_alias = row.get("value_alias", None)
if value_alias is None and value_alias_rows: # CASSANDRA-8487
# In 2.0+, we can use the 'type' column. In 3.0+, we have to use it.
value_alias = value_alias_rows[0].get('column_name')
default_validator = row.get("default_validator")
if default_validator:
validator = types.lookup_casstype(default_validator)
else:
if value_alias_rows: # CASSANDRA-8487
validator = types.lookup_casstype(value_alias_rows[0].get('validator'))
cql_type = _cql_from_cass_type(validator)
col = ColumnMetadata(table_meta, value_alias, cql_type)
if value_alias: # CASSANDRA-8487
table_meta.columns[value_alias] = col
# other normal columns
for col_row in col_rows:
column_meta = self._build_column_metadata(table_meta, col_row)
if column_meta.name is not None:
table_meta.columns[column_meta.name] = column_meta
index_meta = self._build_index_metadata(column_meta, col_row)
if index_meta:
table_meta.indexes[index_meta.name] = index_meta
for trigger_row in trigger_rows:
trigger_meta = self._build_trigger_metadata(table_meta, trigger_row)
table_meta.triggers[trigger_meta.name] = trigger_meta
table_meta.options = self._build_table_options(row)
table_meta.is_compact_storage = is_compact
except Exception:
table_meta._exc_info = sys.exc_info()
log.exception("Error while parsing metadata for table %s.%s row(%s) columns(%s)", keyspace_name, cfname, row, col_rows)
return table_meta
def _build_table_options(self, row):
""" Setup the mostly-non-schema table options, like caching settings """
options = dict((o, row.get(o)) for o in self.recognized_table_options if o in row)
# the option name when creating tables is "dclocal_read_repair_chance",
# but the column name in system.schema_columnfamilies is
# "local_read_repair_chance". We'll store this as dclocal_read_repair_chance,
# since that's probably what users are expecting (and we need it for the
# CREATE TABLE statement anyway).
if "local_read_repair_chance" in options:
val = options.pop("local_read_repair_chance")
options["dclocal_read_repair_chance"] = val
return options
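# Illustrative (made-up value): a schema row carrying
# {'local_read_repair_chance': 0.1} surfaces in table_meta.options as
# {'dclocal_read_repair_chance': 0.1}.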
@classmethod
def _build_column_metadata(cls, table_metadata, row):
name = row["column_name"]
type_string = row["validator"]
data_type = types.lookup_casstype(type_string)
cql_type = _cql_from_cass_type(data_type)
is_static = row.get("type", None) == "static"
is_reversed = types.is_reversed_casstype(data_type)
column_meta = ColumnMetadata(table_metadata, name, cql_type, is_static, is_reversed)
column_meta._cass_type = data_type
return column_meta
@staticmethod
def _build_index_metadata(column_metadata, row):
index_name = row.get("index_name")
kind = row.get("index_type")
if index_name or kind:
options = row.get("index_options")
options = json.loads(options) if options else {}
options = options or {} # if the json parsed to None, init empty dict
# generate a CQL index identity string
target = protect_name(column_metadata.name)
if kind != "CUSTOM":
if "index_keys" in options:
target = 'keys(%s)' % (target,)
elif "index_values" in options:
# don't use any "function" for collection values
pass
else:
# it might be a "full" index on a frozen collection, but
# we need to check the data type to verify that, because
# there is no special index option for full-collection
# indexes.
data_type = column_metadata._cass_type
collection_types = ('map', 'set', 'list')
if data_type.typename == "frozen" and data_type.subtypes[0].typename in collection_types:
# no index option for full-collection index
target = 'full(%s)' % (target,)
options['target'] = target
return IndexMetadata(column_metadata.table.keyspace_name, column_metadata.table.name, index_name, kind, options)
@staticmethod
def _build_trigger_metadata(table_metadata, row):
name = row["trigger_name"]
options = row["trigger_options"]
trigger_meta = TriggerMetadata(table_metadata, name, options)
return trigger_meta
def _query_all(self):
cl = ConsistencyLevel.ONE
queries = [
QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl),
QueryMessage(query=self._SELECT_COLUMN_FAMILIES, consistency_level=cl),
QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl),
QueryMessage(query=self._SELECT_TYPES, consistency_level=cl),
QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl),
QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl),
QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl)
]
((ks_success, ks_result),
(table_success, table_result),
(col_success, col_result),
(types_success, types_result),
(functions_success, functions_result),
(aggregates_success, aggregates_result),
(triggers_success, triggers_result)) = (
self.connection.wait_for_responses(*queries, timeout=self.timeout,
fail_on_error=False)
)
self.keyspaces_result = self._handle_results(ks_success, ks_result)
self.tables_result = self._handle_results(table_success, table_result)
self.columns_result = self._handle_results(col_success, col_result)
# if we're connected to Cassandra < 2.0, the triggers table will not exist
if triggers_success:
- self.triggers_result = dict_factory(*triggers_result.results)
+ self.triggers_result = dict_factory(triggers_result.column_names, triggers_result.parsed_rows)
else:
if isinstance(triggers_result, InvalidRequest):
log.debug("triggers table not found")
elif isinstance(triggers_result, Unauthorized):
log.warning("this version of Cassandra does not allow access to schema_triggers metadata with authorization enabled (CASSANDRA-7967); "
"The driver will operate normally, but will not reflect triggers in the local metadata model, or schema strings.")
else:
raise triggers_result
# if we're connected to Cassandra < 2.1, the usertypes table will not exist
if types_success:
- self.types_result = dict_factory(*types_result.results)
+ self.types_result = dict_factory(types_result.column_names, types_result.parsed_rows)
else:
if isinstance(types_result, InvalidRequest):
log.debug("user types table not found")
self.types_result = {}
else:
raise types_result
# functions were introduced in Cassandra 2.2
if functions_success:
- self.functions_result = dict_factory(*functions_result.results)
+ self.functions_result = dict_factory(functions_result.column_names, functions_result.parsed_rows)
else:
if isinstance(functions_result, InvalidRequest):
log.debug("user functions table not found")
else:
raise functions_result
# aggregates were introduced in Cassandra 2.2
if aggregates_success:
- self.aggregates_result = dict_factory(*aggregates_result.results)
+ self.aggregates_result = dict_factory(aggregates_result.column_names, aggregates_result.parsed_rows)
else:
if isinstance(aggregates_result, InvalidRequest):
log.debug("user aggregates table not found")
else:
raise aggregates_result
self._aggregate_results()
def _aggregate_results(self):
m = self.keyspace_table_rows
for row in self.tables_result:
m[row["keyspace_name"]].append(row)
m = self.keyspace_table_col_rows
for row in self.columns_result:
ksname = row["keyspace_name"]
cfname = row[self._table_name_col]
m[ksname][cfname].append(row)
m = self.keyspace_type_rows
for row in self.types_result:
m[row["keyspace_name"]].append(row)
m = self.keyspace_func_rows
for row in self.functions_result:
m[row["keyspace_name"]].append(row)
m = self.keyspace_agg_rows
for row in self.aggregates_result:
m[row["keyspace_name"]].append(row)
m = self.keyspace_table_trigger_rows
for row in self.triggers_result:
ksname = row["keyspace_name"]
cfname = row[self._table_name_col]
m[ksname][cfname].append(row)
@staticmethod
def _schema_type_to_cql(type_string):
cass_type = types.lookup_casstype(type_string)
return _cql_from_cass_type(cass_type)
class SchemaParserV3(SchemaParserV22):
+ """
+ For C* 3.0+
+ """
_SELECT_KEYSPACES = "SELECT * FROM system_schema.keyspaces"
_SELECT_TABLES = "SELECT * FROM system_schema.tables"
_SELECT_COLUMNS = "SELECT * FROM system_schema.columns"
_SELECT_INDEXES = "SELECT * FROM system_schema.indexes"
_SELECT_TRIGGERS = "SELECT * FROM system_schema.triggers"
_SELECT_TYPES = "SELECT * FROM system_schema.types"
_SELECT_FUNCTIONS = "SELECT * FROM system_schema.functions"
_SELECT_AGGREGATES = "SELECT * FROM system_schema.aggregates"
_SELECT_VIEWS = "SELECT * FROM system_schema.views"
_table_name_col = 'table_name'
_function_agg_arument_type_col = 'argument_types'
+ _table_metadata_class = TableMetadataV3
+
recognized_table_options = (
'bloom_filter_fp_chance',
'caching',
'cdc',
'comment',
'compaction',
'compression',
'crc_check_chance',
'dclocal_read_repair_chance',
'default_time_to_live',
'gc_grace_seconds',
'max_index_interval',
'memtable_flush_period_in_ms',
'min_index_interval',
'read_repair_chance',
'speculative_retry')
def __init__(self, connection, timeout):
super(SchemaParserV3, self).__init__(connection, timeout)
self.indexes_result = []
self.keyspace_table_index_rows = defaultdict(lambda: defaultdict(list))
self.keyspace_view_rows = defaultdict(list)
def get_all_keyspaces(self):
for keyspace_meta in super(SchemaParserV3, self).get_all_keyspaces():
for row in self.keyspace_view_rows[keyspace_meta.name]:
view_meta = self._build_view_metadata(row)
keyspace_meta._add_view_metadata(view_meta)
yield keyspace_meta
def get_table(self, keyspaces, keyspace, table):
cl = ConsistencyLevel.ONE
where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col), (keyspace, table), _encoder)
cf_query = QueryMessage(query=self._SELECT_TABLES + where_clause, consistency_level=cl)
col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl)
indexes_query = QueryMessage(query=self._SELECT_INDEXES + where_clause, consistency_level=cl)
triggers_query = QueryMessage(query=self._SELECT_TRIGGERS + where_clause, consistency_level=cl)
# in protocol v4 we don't know if this event is a view or a table, so we look for both
where_clause = bind_params(" WHERE keyspace_name = %s AND view_name = %s", (keyspace, table), _encoder)
view_query = QueryMessage(query=self._SELECT_VIEWS + where_clause,
consistency_level=cl)
((cf_success, cf_result), (col_success, col_result),
(indexes_success, indexes_result), (triggers_success, triggers_result),
(view_success, view_result)) = (
self.connection.wait_for_responses(
cf_query, col_query, indexes_query, triggers_query,
view_query, timeout=self.timeout, fail_on_error=False)
)
table_result = self._handle_results(cf_success, cf_result)
col_result = self._handle_results(col_success, col_result)
if table_result:
indexes_result = self._handle_results(indexes_success, indexes_result)
triggers_result = self._handle_results(triggers_success, triggers_result)
return self._build_table_metadata(table_result[0], col_result, triggers_result, indexes_result)
view_result = self._handle_results(view_success, view_result)
if view_result:
return self._build_view_metadata(view_result[0], col_result)
@staticmethod
def _build_keyspace_metadata_internal(row):
name = row["keyspace_name"]
durable_writes = row["durable_writes"]
strategy_options = dict(row["replication"])
strategy_class = strategy_options.pop("class")
return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options)
@staticmethod
def _build_aggregate(aggregate_row):
return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'],
aggregate_row['argument_types'], aggregate_row['state_func'], aggregate_row['state_type'],
- aggregate_row['final_func'], aggregate_row['initcond'], aggregate_row['return_type'])
+ aggregate_row['final_func'], aggregate_row['initcond'], aggregate_row['return_type'],
+ aggregate_row.get('deterministic', False))
def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_rows=None, virtual=False):
keyspace_name = row["keyspace_name"]
table_name = row[self._table_name_col]
col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][table_name]
trigger_rows = trigger_rows or self.keyspace_table_trigger_rows[keyspace_name][table_name]
index_rows = index_rows or self.keyspace_table_index_rows[keyspace_name][table_name]
- table_meta = TableMetadataV3(keyspace_name, table_name, virtual=virtual)
+ table_meta = self._table_metadata_class(keyspace_name, table_name, virtual=virtual)
try:
table_meta.options = self._build_table_options(row)
flags = row.get('flags', set())
if flags:
is_dense = 'dense' in flags
compact_static = not is_dense and 'super' not in flags and 'compound' not in flags
table_meta.is_compact_storage = is_dense or 'super' in flags or 'compound' not in flags
elif virtual:
compact_static = False
table_meta.is_compact_storage = False
is_dense = False
else:
compact_static = True
table_meta.is_compact_storage = True
is_dense = False
self._build_table_columns(table_meta, col_rows, compact_static, is_dense, virtual)
for trigger_row in trigger_rows:
trigger_meta = self._build_trigger_metadata(table_meta, trigger_row)
table_meta.triggers[trigger_meta.name] = trigger_meta
for index_row in index_rows:
index_meta = self._build_index_metadata(table_meta, index_row)
if index_meta:
table_meta.indexes[index_meta.name] = index_meta
table_meta.extensions = row.get('extensions', {})
except Exception:
table_meta._exc_info = sys.exc_info()
log.exception("Error while parsing metadata for table %s.%s row(%s) columns(%s)", keyspace_name, table_name, row, col_rows)
return table_meta
def _build_table_options(self, row):
""" Setup the mostly-non-schema table options, like caching settings """
return dict((o, row.get(o)) for o in self.recognized_table_options if o in row)
def _build_table_columns(self, meta, col_rows, compact_static=False, is_dense=False, virtual=False):
# partition key
partition_rows = [r for r in col_rows
if r.get('kind', None) == "partition_key"]
if len(partition_rows) > 1:
partition_rows = sorted(partition_rows, key=lambda row: row.get('position'))
for r in partition_rows:
# we have to add meta here (and not in the later loop) because TableMetadata.columns is an
# OrderedDict, and it assumes keys are inserted first, in order, when exporting CQL
column_meta = self._build_column_metadata(meta, r)
meta.columns[column_meta.name] = column_meta
meta.partition_key.append(meta.columns[r.get('column_name')])
# clustering key
if not compact_static:
clustering_rows = [r for r in col_rows
if r.get('kind', None) == "clustering"]
if len(clustering_rows) > 1:
clustering_rows = sorted(clustering_rows, key=lambda row: row.get('position'))
for r in clustering_rows:
column_meta = self._build_column_metadata(meta, r)
meta.columns[column_meta.name] = column_meta
meta.clustering_key.append(meta.columns[r.get('column_name')])
for col_row in (r for r in col_rows
if r.get('kind', None) not in ('partition_key', 'clustering_key')):
column_meta = self._build_column_metadata(meta, col_row)
if is_dense and column_meta.cql_type == types.cql_empty_type:
continue
if compact_static and not column_meta.is_static:
# for compact static tables, we omit the clustering key and value, and only add the logical columns.
# They are marked not static so that it generates appropriate CQL
continue
if compact_static:
column_meta.is_static = False
meta.columns[column_meta.name] = column_meta
def _build_view_metadata(self, row, col_rows=None):
keyspace_name = row["keyspace_name"]
view_name = row["view_name"]
base_table_name = row["base_table_name"]
include_all_columns = row["include_all_columns"]
where_clause = row["where_clause"]
col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][view_name]
view_meta = MaterializedViewMetadata(keyspace_name, view_name, base_table_name,
include_all_columns, where_clause, self._build_table_options(row))
self._build_table_columns(view_meta, col_rows)
view_meta.extensions = row.get('extensions', {})
return view_meta
@staticmethod
def _build_column_metadata(table_metadata, row):
name = row["column_name"]
cql_type = row["type"]
is_static = row.get("kind", None) == "static"
is_reversed = row["clustering_order"].upper() == "DESC"
column_meta = ColumnMetadata(table_metadata, name, cql_type, is_static, is_reversed)
return column_meta
@staticmethod
def _build_index_metadata(table_metadata, row):
index_name = row.get("index_name")
kind = row.get("kind")
if index_name or kind:
index_options = row.get("options")
return IndexMetadata(table_metadata.keyspace_name, table_metadata.name, index_name, kind, index_options)
else:
return None
@staticmethod
def _build_trigger_metadata(table_metadata, row):
name = row["trigger_name"]
options = row["options"]
trigger_meta = TriggerMetadata(table_metadata, name, options)
return trigger_meta
def _query_all(self):
cl = ConsistencyLevel.ONE
queries = [
QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl),
QueryMessage(query=self._SELECT_TABLES, consistency_level=cl),
QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl),
QueryMessage(query=self._SELECT_TYPES, consistency_level=cl),
QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl),
QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl),
QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl),
QueryMessage(query=self._SELECT_INDEXES, consistency_level=cl),
QueryMessage(query=self._SELECT_VIEWS, consistency_level=cl)
]
((ks_success, ks_result),
(table_success, table_result),
(col_success, col_result),
(types_success, types_result),
(functions_success, functions_result),
(aggregates_success, aggregates_result),
(triggers_success, triggers_result),
(indexes_success, indexes_result),
(views_success, views_result)) = self.connection.wait_for_responses(
*queries, timeout=self.timeout, fail_on_error=False
)
self.keyspaces_result = self._handle_results(ks_success, ks_result)
self.tables_result = self._handle_results(table_success, table_result)
self.columns_result = self._handle_results(col_success, col_result)
self.triggers_result = self._handle_results(triggers_success, triggers_result)
self.types_result = self._handle_results(types_success, types_result)
self.functions_result = self._handle_results(functions_success, functions_result)
self.aggregates_result = self._handle_results(aggregates_success, aggregates_result)
self.indexes_result = self._handle_results(indexes_success, indexes_result)
self.views_result = self._handle_results(views_success, views_result)
self._aggregate_results()
def _aggregate_results(self):
super(SchemaParserV3, self)._aggregate_results()
m = self.keyspace_table_index_rows
for row in self.indexes_result:
ksname = row["keyspace_name"]
cfname = row[self._table_name_col]
m[ksname][cfname].append(row)
m = self.keyspace_view_rows
for row in self.views_result:
m[row["keyspace_name"]].append(row)
@staticmethod
def _schema_type_to_cql(type_string):
return type_string
+class SchemaParserDSE60(SchemaParserV3):
+ """
+ For DSE 6.0+
+ """
+ recognized_table_options = (SchemaParserV3.recognized_table_options +
+ ("nodesync",))
+
+
class SchemaParserV4(SchemaParserV3):
- recognized_table_options = tuple(
- opt for opt in
- SchemaParserV3.recognized_table_options
- if opt not in (
- # removed in V4: CASSANDRA-13910
- 'dclocal_read_repair_chance', 'read_repair_chance'
- )
- )
+ recognized_table_options = (
+ 'additional_write_policy',
+ 'bloom_filter_fp_chance',
+ 'caching',
+ 'cdc',
+ 'comment',
+ 'compaction',
+ 'compression',
+ 'crc_check_chance',
+ 'default_time_to_live',
+ 'gc_grace_seconds',
+ 'max_index_interval',
+ 'memtable_flush_period_in_ms',
+ 'min_index_interval',
+ 'read_repair',
+ 'speculative_retry')
_SELECT_VIRTUAL_KEYSPACES = 'SELECT * from system_virtual_schema.keyspaces'
_SELECT_VIRTUAL_TABLES = 'SELECT * from system_virtual_schema.tables'
_SELECT_VIRTUAL_COLUMNS = 'SELECT * from system_virtual_schema.columns'
def __init__(self, connection, timeout):
super(SchemaParserV4, self).__init__(connection, timeout)
self.virtual_keyspaces_rows = defaultdict(list)
self.virtual_tables_rows = defaultdict(list)
self.virtual_columns_rows = defaultdict(lambda: defaultdict(list))
def _query_all(self):
cl = ConsistencyLevel.ONE
# todo: this duplicates V3; we should find a way for _query_all methods
# to extend each other.
queries = [
# copied from V3
QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl),
QueryMessage(query=self._SELECT_TABLES, consistency_level=cl),
QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl),
QueryMessage(query=self._SELECT_TYPES, consistency_level=cl),
QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl),
QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl),
QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl),
QueryMessage(query=self._SELECT_INDEXES, consistency_level=cl),
QueryMessage(query=self._SELECT_VIEWS, consistency_level=cl),
# V4-only queries
QueryMessage(query=self._SELECT_VIRTUAL_KEYSPACES, consistency_level=cl),
QueryMessage(query=self._SELECT_VIRTUAL_TABLES, consistency_level=cl),
QueryMessage(query=self._SELECT_VIRTUAL_COLUMNS, consistency_level=cl)
]
responses = self.connection.wait_for_responses(
*queries, timeout=self.timeout, fail_on_error=False)
(
# copied from V3
(ks_success, ks_result),
(table_success, table_result),
(col_success, col_result),
(types_success, types_result),
(functions_success, functions_result),
(aggregates_success, aggregates_result),
(triggers_success, triggers_result),
(indexes_success, indexes_result),
(views_success, views_result),
# V4-only responses
(virtual_ks_success, virtual_ks_result),
(virtual_table_success, virtual_table_result),
(virtual_column_success, virtual_column_result)
) = responses
# copied from V3
self.keyspaces_result = self._handle_results(ks_success, ks_result)
self.tables_result = self._handle_results(table_success, table_result)
self.columns_result = self._handle_results(col_success, col_result)
self.triggers_result = self._handle_results(triggers_success, triggers_result)
self.types_result = self._handle_results(types_success, types_result)
self.functions_result = self._handle_results(functions_success, functions_result)
self.aggregates_result = self._handle_results(aggregates_success, aggregates_result)
self.indexes_result = self._handle_results(indexes_success, indexes_result)
self.views_result = self._handle_results(views_success, views_result)
# V4-only results
# These tables don't exist in some DSE versions reporting 4.X so we can
# ignore them if we got an error
self.virtual_keyspaces_result = self._handle_results(
virtual_ks_success, virtual_ks_result,
- expected_failures=InvalidRequest
+ expected_failures=(InvalidRequest,)
)
self.virtual_tables_result = self._handle_results(
virtual_table_success, virtual_table_result,
- expected_failures=InvalidRequest
+ expected_failures=(InvalidRequest,)
)
self.virtual_columns_result = self._handle_results(
virtual_column_success, virtual_column_result,
- expected_failures=InvalidRequest
+ expected_failures=(InvalidRequest,)
)
self._aggregate_results()
def _aggregate_results(self):
super(SchemaParserV4, self)._aggregate_results()
m = self.virtual_tables_rows
for row in self.virtual_tables_result:
m[row["keyspace_name"]].append(row)
m = self.virtual_columns_rows
for row in self.virtual_columns_result:
ks_name = row['keyspace_name']
tab_name = row[self._table_name_col]
m[ks_name][tab_name].append(row)
def get_all_keyspaces(self):
for x in super(SchemaParserV4, self).get_all_keyspaces():
yield x
for row in self.virtual_keyspaces_result:
ks_name = row['keyspace_name']
keyspace_meta = self._build_keyspace_metadata(row)
keyspace_meta.virtual = True
for table_row in self.virtual_tables_rows.get(ks_name, []):
table_name = table_row[self._table_name_col]
col_rows = self.virtual_columns_rows[ks_name][table_name]
keyspace_meta._add_table_metadata(
self._build_table_metadata(table_row,
col_rows=col_rows,
virtual=True)
)
yield keyspace_meta
@staticmethod
def _build_keyspace_metadata_internal(row):
+ # necessary fields that aren't in virtual keyspaces
row["durable_writes"] = row.get("durable_writes", None)
row["replication"] = row.get("replication", {})
row["replication"]["class"] = row["replication"].get("class", None)
return super(SchemaParserV4, SchemaParserV4)._build_keyspace_metadata_internal(row)
-class TableMetadataV3(TableMetadata):
- compaction_options = {}
+class SchemaParserDSE67(SchemaParserV4):
+ """
+ For DSE 6.7+
+ """
+ recognized_table_options = (SchemaParserV4.recognized_table_options +
+ ("nodesync",))
- option_maps = ['compaction', 'compression', 'caching']
- @property
- def is_cql_compatible(self):
- return True
+class SchemaParserDSE68(SchemaParserDSE67):
+ """
+ For DSE 6.8+
+ """
- @classmethod
- def _make_option_strings(cls, options_map):
- ret = []
- options_copy = dict(options_map.items())
+ _SELECT_VERTICES = "SELECT * FROM system_schema.vertices"
+ _SELECT_EDGES = "SELECT * FROM system_schema.edges"
- for option in cls.option_maps:
- value = options_copy.get(option)
- if isinstance(value, Mapping):
- del options_copy[option]
- params = ("'%s': '%s'" % (k, v) for k, v in value.items())
- ret.append("%s = {%s}" % (option, ', '.join(params)))
+ _table_metadata_class = TableMetadataDSE68
- for name, value in options_copy.items():
- if value is not None:
- if name == "comment":
- value = value or ""
- ret.append("%s = %s" % (name, protect_value(value)))
+ def __init__(self, connection, timeout):
+ super(SchemaParserDSE68, self).__init__(connection, timeout)
+ self.keyspace_table_vertex_rows = defaultdict(lambda: defaultdict(list))
+ self.keyspace_table_edge_rows = defaultdict(lambda: defaultdict(list))
- return list(sorted(ret))
+ def get_all_keyspaces(self):
+ for keyspace_meta in super(SchemaParserDSE68, self).get_all_keyspaces():
+ self._build_graph_metadata(keyspace_meta)
+ yield keyspace_meta
+
+ def get_table(self, keyspaces, keyspace, table):
+ table_meta = super(SchemaParserDSE68, self).get_table(keyspaces, keyspace, table)
+ cl = ConsistencyLevel.ONE
+ where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col), (keyspace, table), _encoder)
+ vertices_query = QueryMessage(query=self._SELECT_VERTICES + where_clause, consistency_level=cl)
+ edges_query = QueryMessage(query=self._SELECT_EDGES + where_clause, consistency_level=cl)
+
+ (vertices_success, vertices_result), (edges_success, edges_result) \
+ = self.connection.wait_for_responses(vertices_query, edges_query, timeout=self.timeout, fail_on_error=False)
+ vertices_result = self._handle_results(vertices_success, vertices_result)
+ edges_result = self._handle_results(edges_success, edges_result)
+
+ try:
+ if vertices_result:
+ table_meta.vertex = self._build_table_vertex_metadata(vertices_result[0])
+ elif edges_result:
+ table_meta.edge = self._build_table_edge_metadata(keyspaces[keyspace], edges_result[0])
+ except Exception:
+ table_meta.vertex = None
+ table_meta.edge = None
+ table_meta._exc_info = sys.exc_info()
+ log.exception("Error while parsing graph metadata for table %s.%s.", keyspace, table)
+
+ return table_meta
+
+ @staticmethod
+ def _build_keyspace_metadata_internal(row):
+ name = row["keyspace_name"]
+ durable_writes = row.get("durable_writes", None)
+ replication = dict(row.get("replication")) if 'replication' in row else {}
+ replication_class = replication.pop("class") if 'class' in replication else None
+ graph_engine = row.get("graph_engine", None)
+ return KeyspaceMetadata(name, durable_writes, replication_class, replication, graph_engine)
+
+ def _build_graph_metadata(self, keyspace_meta):
+
+ def _build_table_graph_metadata(table_meta):
+ for row in self.keyspace_table_vertex_rows[keyspace_meta.name][table_meta.name]:
+ table_meta.vertex = self._build_table_vertex_metadata(row)
+
+ for row in self.keyspace_table_edge_rows[keyspace_meta.name][table_meta.name]:
+ table_meta.edge = self._build_table_edge_metadata(keyspace_meta, row)
+
+ try:
+ # Make sure we process vertices before edges
+ for table_meta in [t for t in six.itervalues(keyspace_meta.tables)
+ if t.name in self.keyspace_table_vertex_rows[keyspace_meta.name]]:
+ _build_table_graph_metadata(table_meta)
+
+ # all other tables...
+ for table_meta in [t for t in six.itervalues(keyspace_meta.tables)
+ if t.name not in self.keyspace_table_vertex_rows[keyspace_meta.name]]:
+ _build_table_graph_metadata(table_meta)
+ except Exception:
+ # schema error, remove all graph metadata for this keyspace
+ for t in six.itervalues(keyspace_meta.tables):
+ t.edge = t.vertex = None
+ keyspace_meta._exc_info = sys.exc_info()
+ log.exception("Error while parsing graph metadata for keyspace %s", keyspace_meta.name)
+
+ @staticmethod
+ def _build_table_vertex_metadata(row):
+ return VertexMetadata(row.get("keyspace_name"), row.get("table_name"),
+ row.get("label_name"))
+
+ @staticmethod
+ def _build_table_edge_metadata(keyspace_meta, row):
+ from_table = row.get("from_table")
+ from_table_meta = keyspace_meta.tables.get(from_table)
+ from_label = from_table_meta.vertex.label_name
+ to_table = row.get("to_table")
+ to_table_meta = keyspace_meta.tables.get(to_table)
+ to_label = to_table_meta.vertex.label_name
+
+ return EdgeMetadata(
+ row.get("keyspace_name"), row.get("table_name"),
+ row.get("label_name"), from_table, from_label,
+ row.get("from_partition_key_columns"),
+ row.get("from_clustering_columns"), to_table, to_label,
+ row.get("to_partition_key_columns"),
+ row.get("to_clustering_columns"))
+
+ def _query_all(self):
+ cl = ConsistencyLevel.ONE
+ queries = [
+ # copied from v4
+ QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_TABLES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl),
+ QueryMessage(query=self._SELECT_TYPES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl),
+ QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl),
+ QueryMessage(query=self._SELECT_INDEXES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_VIEWS, consistency_level=cl),
+ QueryMessage(query=self._SELECT_VIRTUAL_KEYSPACES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_VIRTUAL_TABLES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_VIRTUAL_COLUMNS, consistency_level=cl),
+ # dse6.8 only
+ QueryMessage(query=self._SELECT_VERTICES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_EDGES, consistency_level=cl)
+ ]
+
+ responses = self.connection.wait_for_responses(
+ *queries, timeout=self.timeout, fail_on_error=False)
+ (
+ # copied from V4
+ (ks_success, ks_result),
+ (table_success, table_result),
+ (col_success, col_result),
+ (types_success, types_result),
+ (functions_success, functions_result),
+ (aggregates_success, aggregates_result),
+ (triggers_success, triggers_result),
+ (indexes_success, indexes_result),
+ (views_success, views_result),
+ (virtual_ks_success, virtual_ks_result),
+ (virtual_table_success, virtual_table_result),
+ (virtual_column_success, virtual_column_result),
+ # dse6.8 responses
+ (vertices_success, vertices_result),
+ (edges_success, edges_result)
+ ) = responses
+
+ # copied from V4
+ self.keyspaces_result = self._handle_results(ks_success, ks_result)
+ self.tables_result = self._handle_results(table_success, table_result)
+ self.columns_result = self._handle_results(col_success, col_result)
+ self.triggers_result = self._handle_results(triggers_success, triggers_result)
+ self.types_result = self._handle_results(types_success, types_result)
+ self.functions_result = self._handle_results(functions_success, functions_result)
+ self.aggregates_result = self._handle_results(aggregates_success, aggregates_result)
+ self.indexes_result = self._handle_results(indexes_success, indexes_result)
+ self.views_result = self._handle_results(views_success, views_result)
+
+ # These tables don't exist in some DSE versions reporting 4.X so we can
+ # ignore them if we got an error
+ self.virtual_keyspaces_result = self._handle_results(
+ virtual_ks_success, virtual_ks_result,
+ expected_failures=(InvalidRequest,)
+ )
+ self.virtual_tables_result = self._handle_results(
+ virtual_table_success, virtual_table_result,
+ expected_failures=(InvalidRequest,)
+ )
+ self.virtual_columns_result = self._handle_results(
+ virtual_column_success, virtual_column_result,
+ expected_failures=(InvalidRequest,)
+ )
+
+ # dse6.8-only results
+ self.vertices_result = self._handle_results(vertices_success, vertices_result)
+ self.edges_result = self._handle_results(edges_success, edges_result)
+
+ self._aggregate_results()
+
+ def _aggregate_results(self):
+ super(SchemaParserDSE68, self)._aggregate_results()
+
+ m = self.keyspace_table_vertex_rows
+ for row in self.vertices_result:
+ ksname = row["keyspace_name"]
+ cfname = row['table_name']
+ m[ksname][cfname].append(row)
+
+ m = self.keyspace_table_edge_rows
+ for row in self.edges_result:
+ ksname = row["keyspace_name"]
+ cfname = row['table_name']
+ m[ksname][cfname].append(row)
class MaterializedViewMetadata(object):
"""
A representation of a materialized view on a table
"""
keyspace_name = None
-
- """ A string name of the view."""
+ """ A string name of the keyspace of this view."""
name = None
""" A string name of the view."""
base_table_name = None
""" A string name of the base table for this view."""
partition_key = None
"""
A list of :class:`.ColumnMetadata` instances representing the columns in
the partition key for this view. This will always hold at least one
column.
"""
clustering_key = None
"""
A list of :class:`.ColumnMetadata` instances representing the columns
in the clustering key for this view.
Note that a table may have no clustering keys, in which case this will
be an empty list.
"""
columns = None
"""
A dict mapping column names to :class:`.ColumnMetadata` instances.
"""
include_all_columns = None
""" A flag indicating whether the view was created AS SELECT * """
where_clause = None
""" String WHERE clause for the view select statement. From server metadata """
options = None
"""
A dict mapping table option names to their specific settings for this
view.
"""
extensions = None
"""
Metadata describing configuration for table extensions
"""
def __init__(self, keyspace_name, view_name, base_table_name, include_all_columns, where_clause, options):
self.keyspace_name = keyspace_name
self.name = view_name
self.base_table_name = base_table_name
self.partition_key = []
self.clustering_key = []
self.columns = OrderedDict()
self.include_all_columns = include_all_columns
self.where_clause = where_clause
self.options = options or {}
def as_cql_query(self, formatted=False):
"""
Returns a CQL query that can be used to recreate this function.
If `formatted` is set to :const:`True`, extra whitespace will
be added to make the query more readable.
"""
sep = '\n ' if formatted else ' '
keyspace = protect_name(self.keyspace_name)
name = protect_name(self.name)
selected_cols = '*' if self.include_all_columns else ', '.join(protect_name(col.name) for col in self.columns.values())
base_table = protect_name(self.base_table_name)
where_clause = self.where_clause
part_key = ', '.join(protect_name(col.name) for col in self.partition_key)
if len(self.partition_key) > 1:
pk = "((%s)" % part_key
else:
pk = "(%s" % part_key
if self.clustering_key:
pk += ", %s" % ', '.join(protect_name(col.name) for col in self.clustering_key)
pk += ")"
properties = TableMetadataV3._property_string(formatted, self.clustering_key, self.options)
ret = ("CREATE MATERIALIZED VIEW %(keyspace)s.%(name)s AS%(sep)s"
"SELECT %(selected_cols)s%(sep)s"
"FROM %(keyspace)s.%(base_table)s%(sep)s"
"WHERE %(where_clause)s%(sep)s"
"PRIMARY KEY %(pk)s%(sep)s"
"WITH %(properties)s") % locals()
if self.extensions:
registry = _RegisteredExtensionType._extension_registry
for k in six.viewkeys(registry) & self.extensions: # no viewkeys on OrderedMapSerializeKey
ext = registry[k]
cql = ext.after_table_cql(self, k, self.extensions[k])
if cql:
ret += "\n\n%s" % (cql,)
return ret
def export_as_string(self):
return self.as_cql_query(formatted=True) + ";"
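# Rough sketch of the statement produced above (names, columns and the options
# string are illustrative, not actual driver output):
#
#   CREATE MATERIALIZED VIEW ks.users_by_email AS
#       SELECT email, user_id
#       FROM ks.users
#       WHERE email IS NOT NULL AND user_id IS NOT NULL
#       PRIMARY KEY (email, user_id)
#       WITH ...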
-def get_schema_parser(connection, server_version, timeout):
+class VertexMetadata(object):
+ """
+ A representation of a vertex on a table
+ """
+
+ keyspace_name = None
+ """ A string name of the keyspace. """
+
+ table_name = None
+ """ A string name of the table this vertex is on. """
+
+ label_name = None
+ """ A string name of the label of this vertex."""
+
+ def __init__(self, keyspace_name, table_name, label_name):
+ self.keyspace_name = keyspace_name
+ self.table_name = table_name
+ self.label_name = label_name
+
+
+class EdgeMetadata(object):
+ """
+ A representation of an edge on a table
+ """
+
+ keyspace_name = None
+ """A string name of the keyspace """
+
+ table_name = None
+ """A string name of the table this edge is on"""
+
+ label_name = None
+ """A string name of the label of this edge"""
+
+ from_table = None
+ """A string name of the from table of this edge (incoming vertex)"""
+
+ from_label = None
+ """A string name of the from table label of this edge (incoming vertex)"""
+
+ from_partition_key_columns = None
+ """The columns that match the partition key of the incoming vertex table."""
+
+ from_clustering_columns = None
+ """The columns that match the clustering columns of the incoming vertex table."""
+
+ to_table = None
+ """A string name of the to table of this edge (outgoing vertex)"""
+
+ to_label = None
+ """A string name of the to table label of this edge (outgoing vertex)"""
+
+ to_partition_key_columns = None
+ """The columns that match the partition key of the outgoing vertex table."""
+
+ to_clustering_columns = None
+ """The columns that match the clustering columns of the outgoing vertex table."""
+
+ def __init__(
+ self, keyspace_name, table_name, label_name, from_table,
+ from_label, from_partition_key_columns, from_clustering_columns,
+ to_table, to_label, to_partition_key_columns,
+ to_clustering_columns):
+ self.keyspace_name = keyspace_name
+ self.table_name = table_name
+ self.label_name = label_name
+ self.from_table = from_table
+ self.from_label = from_label
+ self.from_partition_key_columns = from_partition_key_columns
+ self.from_clustering_columns = from_clustering_columns
+ self.to_table = to_table
+ self.to_label = to_label
+ self.to_partition_key_columns = to_partition_key_columns
+ self.to_clustering_columns = to_clustering_columns
+
+
+def get_schema_parser(connection, server_version, dse_version, timeout):
version = Version(server_version)
+ if dse_version:
+ v = Version(dse_version)
+ if v >= Version('6.8.0'):
+ return SchemaParserDSE68(connection, timeout)
+ elif v >= Version('6.7.0'):
+ return SchemaParserDSE67(connection, timeout)
+ elif v >= Version('6.0.0'):
+ return SchemaParserDSE60(connection, timeout)
+
if version >= Version('4-a'):
return SchemaParserV4(connection, timeout)
- if version >= Version('3.0.0'):
+ elif version >= Version('3.0.0'):
return SchemaParserV3(connection, timeout)
else:
# we could further specialize by version. Right now just refactoring the
# multi-version parser we have as of C* 2.2.0rc1.
return SchemaParserV22(connection, timeout)
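# Illustrative usage sketch (the connection object and versions are assumed
# values, not taken from real cluster state):
#
#   parser = get_schema_parser(control_connection, '4.0.0', '6.8.0', timeout=10)
#   # dse_version takes precedence, so this yields a SchemaParserDSE68;
#   # with dse_version=None, the Cassandra version alone selects SchemaParserV4.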
def _cql_from_cass_type(cass_type):
"""
A string representation of the type for this column, such as "varchar"
or "map".
"""
if issubclass(cass_type, types.ReversedType):
return cass_type.subtypes[0].cql_parameterized_type()
else:
return cass_type.cql_parameterized_type()
+class RLACTableExtension(RegisteredTableExtension):
+ name = "DSE_RLACA"
+
+ @classmethod
+ def after_table_cql(cls, table_meta, ext_key, ext_blob):
+ return "RESTRICT ROWS ON %s.%s USING %s;" % (protect_name(table_meta.keyspace_name),
+ protect_name(table_meta.name),
+ protect_name(ext_blob.decode('utf-8')))
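+# Illustrative sketch of the CQL this extension appends (keyspace, table and
+# decoded blob are made-up values):
+#
+#   RESTRICT ROWS ON ks.users USING user_filter;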
NO_VALID_REPLICA = object()
def group_keys_by_replica(session, keyspace, table, keys):
"""
Returns a :class:`dict` with the keys grouped per host. This can be
used to more accurately group by IN clause or to batch the keys per host.
If a valid replica is not found for a particular key it will be grouped under
:class:`~.NO_VALID_REPLICA`
Example usage::
result = group_keys_by_replica(
session, "system", "peers",
(("127.0.0.1", ), ("127.0.0.2", ))
)
"""
cluster = session.cluster
partition_keys = cluster.metadata.keyspaces[keyspace].tables[table].partition_key
serializers = list(types._cqltypes[partition_key.cql_type] for partition_key in partition_keys)
keys_per_host = defaultdict(list)
distance = cluster._default_load_balancing_policy.distance
for key in keys:
serialized_key = [serializer.serialize(pk, cluster.protocol_version)
for serializer, pk in zip(serializers, key)]
if len(serialized_key) == 1:
routing_key = serialized_key[0]
else:
routing_key = b"".join(struct.pack(">H%dsB" % len(p), len(p), p, 0) for p in serialized_key)
all_replicas = cluster.metadata.get_replicas(keyspace, routing_key)
# First check if there are local replicas
valid_replicas = [host for host in all_replicas if
host.is_up and distance(host) == HostDistance.LOCAL]
if not valid_replicas:
valid_replicas = [host for host in all_replicas if host.is_up]
if valid_replicas:
keys_per_host[random.choice(valid_replicas)].append(key)
else:
# We will group under this statement all the keys for which
# we haven't found a valid replica
keys_per_host[NO_VALID_REPLICA].append(key)
return dict(keys_per_host)
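# Rough sketch of the result shape for the docstring example above (the actual
# grouping depends on token ownership, so the hosts here are placeholders):
#
#   {host_1: [("127.0.0.1",)],
#    host_2: [("127.0.0.2",)]}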
+
+# TODO next major reorg
+class _NodeInfo(object):
+ """
+ Internal utility functions to determine the different host addresses/ports
+ from a local or peers row.
+ """
+
+ @staticmethod
+ def get_broadcast_rpc_address(row):
+ # TODO next major, change the parsing logic to avoid any
+ # overriding of a non-null value
+ addr = row.get("rpc_address")
+ if "native_address" in row:
+ addr = row.get("native_address")
+ if "native_transport_address" in row:
+ addr = row.get("native_transport_address")
+ if not addr or addr in ["0.0.0.0", "::"]:
+ addr = row.get("peer")
+
+ return addr
+
+ @staticmethod
+ def get_broadcast_rpc_port(row):
+ port = row.get("rpc_port")
+ if port is None or port == 0:
+ port = row.get("native_port")
+
+ return port if port and port > 0 else None
+
+ @staticmethod
+ def get_broadcast_address(row):
+ addr = row.get("broadcast_address")
+ if addr is None:
+ addr = row.get("peer")
+
+ return addr
+
+ @staticmethod
+ def get_broadcast_port(row):
+ port = row.get("broadcast_port")
+ if port is None or port == 0:
+ port = row.get("peer_port")
+
+ return port if port and port > 0 else None
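+# Illustrative sketch of the precedence implemented above (row values made up):
+#
+#   row = {"rpc_address": "0.0.0.0", "native_transport_address": "10.0.0.5",
+#          "peer": "10.0.0.1"}
+#   _NodeInfo.get_broadcast_rpc_address(row)  # -> "10.0.0.5"
+#   # a missing or wildcard address falls back to the "peer" column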
diff --git a/cassandra/policies.py b/cassandra/policies.py
index d610666..fa1e8cf 100644
--- a/cassandra/policies.py
+++ b/cassandra/policies.py
@@ -1,1104 +1,1183 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import islice, cycle, groupby, repeat
import logging
from random import randint, shuffle
from threading import Lock
import socket
import warnings
from cassandra import WriteType as WT
# This is done this way because WriteType was originally
# defined here; it is re-exported so as not to break the API.
# It may be removed in the next major release.
WriteType = WT
from cassandra import ConsistencyLevel, OperationTimedOut
log = logging.getLogger(__name__)
class HostDistance(object):
"""
A measure of how "distant" a node is from the client, which
may influence how the load balancer distributes requests
and how many connections are opened to the node.
"""
IGNORED = -1
"""
A node with this distance should never be queried or have
connections opened to it.
"""
LOCAL = 0
"""
Nodes with ``LOCAL`` distance will be preferred for operations
under some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
and will have a greater number of connections opened against
them by default.
This distance is typically used for nodes within the same
datacenter as the client.
"""
REMOTE = 1
"""
Nodes with ``REMOTE`` distance will be treated as a last resort
by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
and will have a smaller number of connections opened against
them by default.
This distance is typically used for nodes outside of the
datacenter that the client is running in.
"""
class HostStateListener(object):
def on_up(self, host):
""" Called when a node is marked up. """
raise NotImplementedError()
def on_down(self, host):
""" Called when a node is marked down. """
raise NotImplementedError()
def on_add(self, host):
"""
Called when a node is added to the cluster. The newly added node
should be considered up.
"""
raise NotImplementedError()
def on_remove(self, host):
""" Called when a node is removed from the cluster. """
raise NotImplementedError()
class LoadBalancingPolicy(HostStateListener):
"""
Load balancing policies are used to decide how to distribute
requests among all possible coordinator nodes in the cluster.
In particular, they may focus on querying "near" nodes (those
in a local datacenter) or on querying nodes who happen to
be replicas for the requested data.
You may also use subclasses of :class:`.LoadBalancingPolicy` for
custom behavior.
"""
_hosts_lock = None
def __init__(self):
self._hosts_lock = Lock()
def distance(self, host):
"""
Returns a measure of how remote a :class:`~.pool.Host` is in
terms of the :class:`.HostDistance` enums.
"""
raise NotImplementedError()
def populate(self, cluster, hosts):
"""
This method is called to initialize the load balancing
policy with a set of :class:`.Host` instances before its
first use. The `cluster` parameter is an instance of
:class:`.Cluster`.
"""
raise NotImplementedError()
def make_query_plan(self, working_keyspace=None, query=None):
"""
Given a :class:`~.query.Statement` instance, return an iterable
of :class:`.Host` instances which should be queried in that
order. A generator may work well for custom implementations
of this method.
Note that the `query` argument may be :const:`None` when preparing
statements.
`working_keyspace` should be the string name of the current keyspace,
as set through :meth:`.Session.set_keyspace()` or with a ``USE``
statement.
"""
raise NotImplementedError()
def check_supported(self):
"""
This will be called after the cluster Metadata has been initialized.
If the load balancing policy implementation cannot be supported for
some reason (such as a missing C extension), this is the point at
which it should raise an exception.
"""
pass
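# A minimal sketch of a custom policy along the lines suggested by the
# docstrings above (illustrative only; a production policy should also
# implement the HostStateListener callbacks it relies on):
#
#   class AllLocalPolicy(LoadBalancingPolicy):
#       def populate(self, cluster, hosts):
#           self._hosts = list(hosts)
#
#       def distance(self, host):
#           return HostDistance.LOCAL
#
#       def make_query_plan(self, working_keyspace=None, query=None):
#           for host in self._hosts:
#               yield host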
class RoundRobinPolicy(LoadBalancingPolicy):
"""
A subclass of :class:`.LoadBalancingPolicy` which evenly
distributes queries across all nodes in the cluster,
regardless of what datacenter the nodes may be in.
"""
_live_hosts = frozenset(())
_position = 0
def populate(self, cluster, hosts):
self._live_hosts = frozenset(hosts)
if len(hosts) > 1:
self._position = randint(0, len(hosts) - 1)
def distance(self, host):
return HostDistance.LOCAL
def make_query_plan(self, working_keyspace=None, query=None):
# not thread-safe, but we don't care much about lost increments
# for the purposes of load balancing
pos = self._position
self._position += 1
hosts = self._live_hosts
length = len(hosts)
if length:
pos %= length
return islice(cycle(hosts), pos, pos + length)
else:
return []
def on_up(self, host):
with self._hosts_lock:
self._live_hosts = self._live_hosts.union((host, ))
def on_down(self, host):
with self._hosts_lock:
self._live_hosts = self._live_hosts.difference((host, ))
def on_add(self, host):
with self._hosts_lock:
self._live_hosts = self._live_hosts.union((host, ))
def on_remove(self, host):
with self._hosts_lock:
self._live_hosts = self._live_hosts.difference((host, ))
class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
"""
Similar to :class:`.RoundRobinPolicy`, but prefers hosts
in the local datacenter and only uses nodes in remote
datacenters as a last resort.
"""
local_dc = None
used_hosts_per_remote_dc = 0
def __init__(self, local_dc='', used_hosts_per_remote_dc=0):
"""
The `local_dc` parameter should be the name of the datacenter
(such as is reported by ``nodetool ring``) that should
be considered local. If not specified, the driver will choose
a local_dc based on the first host among :attr:`.Cluster.contact_points`
having a valid DC. If relying on this mechanism, all specified
contact points should be nodes in a single, local DC.
`used_hosts_per_remote_dc` controls how many nodes in
each remote datacenter will have connections opened
against them. In other words, `used_hosts_per_remote_dc` hosts
will be considered :attr:`~.HostDistance.REMOTE` and the
rest will be considered :attr:`~.HostDistance.IGNORED`.
By default, all remote hosts are ignored.
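For example, a minimal sketch (the contact point and datacenter name
below are placeholders):

.. code-block:: python

    from cassandra.cluster import Cluster
    from cassandra.policies import DCAwareRoundRobinPolicy

    # prefer nodes in 'dc1'; also open connections to up to two
    # hosts in each remote datacenter as a fallback
    policy = DCAwareRoundRobinPolicy(local_dc='dc1', used_hosts_per_remote_dc=2)
    cluster = Cluster(contact_points=['10.0.0.1'], load_balancing_policy=policy)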
"""
self.local_dc = local_dc
self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
self._dc_live_hosts = {}
self._position = 0
self._endpoints = []
LoadBalancingPolicy.__init__(self)
def _dc(self, host):
return host.datacenter or self.local_dc
def populate(self, cluster, hosts):
for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)):
self._dc_live_hosts[dc] = tuple(set(dc_hosts))
if not self.local_dc:
self._endpoints = [
endpoint
for endpoint in cluster.endpoints_resolved]
self._position = randint(0, len(hosts) - 1) if hosts else 0
def distance(self, host):
dc = self._dc(host)
if dc == self.local_dc:
return HostDistance.LOCAL
if not self.used_hosts_per_remote_dc:
return HostDistance.IGNORED
else:
dc_hosts = self._dc_live_hosts.get(dc)
if not dc_hosts:
return HostDistance.IGNORED
if host in list(dc_hosts)[:self.used_hosts_per_remote_dc]:
return HostDistance.REMOTE
else:
return HostDistance.IGNORED
def make_query_plan(self, working_keyspace=None, query=None):
# not thread-safe, but we don't care much about lost increments
# for the purposes of load balancing
pos = self._position
self._position += 1
local_live = self._dc_live_hosts.get(self.local_dc, ())
pos = (pos % len(local_live)) if local_live else 0
for host in islice(cycle(local_live), pos, pos + len(local_live)):
yield host
# the dict can change, so get candidate DCs iterating over keys of a copy
other_dcs = [dc for dc in self._dc_live_hosts.copy().keys() if dc != self.local_dc]
for dc in other_dcs:
remote_live = self._dc_live_hosts.get(dc, ())
for host in remote_live[:self.used_hosts_per_remote_dc]:
yield host
def on_up(self, host):
# not worrying about threads because this will happen during
# control connection startup/refresh
if not self.local_dc and host.datacenter:
if host.endpoint in self._endpoints:
self.local_dc = host.datacenter
log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); "
"if incorrect, please specify a local_dc to the constructor, "
"or limit contact points to local cluster nodes" %
(self.local_dc, host.endpoint))
del self._endpoints
dc = self._dc(host)
with self._hosts_lock:
current_hosts = self._dc_live_hosts.get(dc, ())
if host not in current_hosts:
self._dc_live_hosts[dc] = current_hosts + (host, )
def on_down(self, host):
dc = self._dc(host)
with self._hosts_lock:
current_hosts = self._dc_live_hosts.get(dc, ())
if host in current_hosts:
hosts = tuple(h for h in current_hosts if h != host)
if hosts:
self._dc_live_hosts[dc] = hosts
else:
del self._dc_live_hosts[dc]
def on_add(self, host):
self.on_up(host)
def on_remove(self, host):
self.on_down(host)
class TokenAwarePolicy(LoadBalancingPolicy):
"""
A :class:`.LoadBalancingPolicy` wrapper that adds token awareness to
a child policy.
This alters the child policy's behavior so that it first attempts to
send queries to :attr:`~.HostDistance.LOCAL` replicas (as determined
by the child policy) based on the :class:`.Statement`'s
:attr:`~.Statement.routing_key`. If :attr:`.shuffle_replicas` is
truthy, these replicas will be yielded in a random order. Once those
hosts are exhausted, the remaining hosts in the child policy's query
plan will be used in the order provided by the child policy.
If no :attr:`~.Statement.routing_key` is set on the query, the child
policy's query plan will be used as is.
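A typical composition wraps a datacenter-aware child policy (a sketch;
the datacenter name is a placeholder):

.. code-block:: python

    from cassandra.cluster import Cluster
    from cassandra.policies import TokenAwarePolicy, DCAwareRoundRobinPolicy

    policy = TokenAwarePolicy(DCAwareRoundRobinPolicy(local_dc='dc1'),
                              shuffle_replicas=True)
    cluster = Cluster(load_balancing_policy=policy)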
"""
_child_policy = None
_cluster_metadata = None
shuffle_replicas = False
"""
Yield local replicas in a random order.
"""
def __init__(self, child_policy, shuffle_replicas=False):
self._child_policy = child_policy
self.shuffle_replicas = shuffle_replicas
def populate(self, cluster, hosts):
self._cluster_metadata = cluster.metadata
self._child_policy.populate(cluster, hosts)
def check_supported(self):
if not self._cluster_metadata.can_support_partitioner():
raise RuntimeError(
'%s cannot be used with the cluster partitioner (%s) because '
'the relevant C extension for this driver was not compiled. '
'See the installation instructions for details on building '
'and installing the C extensions.' %
(self.__class__.__name__, self._cluster_metadata.partitioner))
def distance(self, *args, **kwargs):
return self._child_policy.distance(*args, **kwargs)
def make_query_plan(self, working_keyspace=None, query=None):
if query and query.keyspace:
keyspace = query.keyspace
else:
keyspace = working_keyspace
child = self._child_policy
if query is None:
for host in child.make_query_plan(keyspace, query):
yield host
else:
routing_key = query.routing_key
if routing_key is None or keyspace is None:
for host in child.make_query_plan(keyspace, query):
yield host
else:
replicas = self._cluster_metadata.get_replicas(keyspace, routing_key)
if self.shuffle_replicas:
shuffle(replicas)
for replica in replicas:
if replica.is_up and \
child.distance(replica) == HostDistance.LOCAL:
yield replica
for host in child.make_query_plan(keyspace, query):
# skip if we've already listed this host
if host not in replicas or \
child.distance(host) == HostDistance.REMOTE:
yield host
def on_up(self, *args, **kwargs):
return self._child_policy.on_up(*args, **kwargs)
def on_down(self, *args, **kwargs):
return self._child_policy.on_down(*args, **kwargs)
def on_add(self, *args, **kwargs):
return self._child_policy.on_add(*args, **kwargs)
def on_remove(self, *args, **kwargs):
return self._child_policy.on_remove(*args, **kwargs)
class WhiteListRoundRobinPolicy(RoundRobinPolicy):
"""
A subclass of :class:`.RoundRobinPolicy` which evenly
distributes queries across all nodes in the cluster,
regardless of what datacenter the nodes may be in, but
only if that node exists in the list of allowed nodes.
This policy addresses the issue described in
https://datastax-oss.atlassian.net/browse/JAVA-145,
where connection errors occur when connection
attempts are made to private IP addresses remotely.
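A minimal sketch (the addresses below are placeholders):

.. code-block:: python

    from cassandra.cluster import Cluster
    from cassandra.policies import WhiteListRoundRobinPolicy

    allowed = ['10.0.0.1', '10.0.0.2']
    cluster = Cluster(contact_points=allowed,
                      load_balancing_policy=WhiteListRoundRobinPolicy(allowed))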
"""
def __init__(self, hosts):
"""
The `hosts` parameter should be a sequence of hosts to permit
connections to.
"""
- self._allowed_hosts = hosts
+ self._allowed_hosts = tuple(hosts)
self._allowed_hosts_resolved = [endpoint[4][0] for a in self._allowed_hosts
for endpoint in socket.getaddrinfo(a, None, socket.AF_UNSPEC, socket.SOCK_STREAM)]
RoundRobinPolicy.__init__(self)
def populate(self, cluster, hosts):
self._live_hosts = frozenset(h for h in hosts if h.address in self._allowed_hosts_resolved)
if len(hosts) <= 1:
self._position = 0
else:
self._position = randint(0, len(hosts) - 1)
def distance(self, host):
if host.address in self._allowed_hosts_resolved:
return HostDistance.LOCAL
else:
return HostDistance.IGNORED
def on_up(self, host):
if host.address in self._allowed_hosts_resolved:
RoundRobinPolicy.on_up(self, host)
def on_add(self, host):
if host.address in self._allowed_hosts_resolved:
RoundRobinPolicy.on_add(self, host)
class HostFilterPolicy(LoadBalancingPolicy):
"""
A :class:`.LoadBalancingPolicy` subclass configured with a child policy,
and a single-argument predicate. This policy defers to the child policy for
hosts where ``predicate(host)`` is truthy. Hosts for which
``predicate(host)`` is falsey will be considered :attr:`.IGNORED`, and will
not be used in a query plan.
This can be used in cases where you need a whitelist or blacklist
policy, e.g. to prepare for decommissioning nodes or for testing:
.. code-block:: python
def address_is_ignored(host):
return host.address in [ignored_address0, ignored_address1]
blacklist_filter_policy = HostFilterPolicy(
child_policy=RoundRobinPolicy(),
predicate=address_is_ignored
)
cluster = Cluster(
primary_host,
load_balancing_policy=blacklist_filter_policy,
)
See the note in the :meth:`.make_query_plan` documentation for a caveat on
how wrapping ordering policies (e.g. :class:`.RoundRobinPolicy`) may break
desirable properties of the wrapped policy.
Please note that whitelist and blacklist policies are not recommended for
general, day-to-day use. You probably want something like
:class:`.DCAwareRoundRobinPolicy`, which prefers a local DC but has
fallbacks, over a brute-force method like whitelisting or blacklisting.
"""
def __init__(self, child_policy, predicate):
"""
:param child_policy: an instantiated :class:`.LoadBalancingPolicy`
that this one will defer to.
:param predicate: a one-parameter function that takes a :class:`.Host`.
If it returns a falsey value, the :class:`.Host` will
be :attr:`.IGNORED` and not returned in query plans.
"""
super(HostFilterPolicy, self).__init__()
self._child_policy = child_policy
self._predicate = predicate
def on_up(self, host, *args, **kwargs):
return self._child_policy.on_up(host, *args, **kwargs)
def on_down(self, host, *args, **kwargs):
return self._child_policy.on_down(host, *args, **kwargs)
def on_add(self, host, *args, **kwargs):
return self._child_policy.on_add(host, *args, **kwargs)
def on_remove(self, host, *args, **kwargs):
return self._child_policy.on_remove(host, *args, **kwargs)
@property
def predicate(self):
"""
A predicate, set on object initialization, that takes a :class:`.Host`
and returns a value. If the value is falsy, the :class:`.Host` is
:class:`~HostDistance.IGNORED`. If the value is truthy,
:class:`.HostFilterPolicy` defers to the child policy to determine the
host's distance.
This is a read-only value set in ``__init__``, implemented as a
``property``.
"""
return self._predicate
def distance(self, host):
"""
Checks if ``predicate(host)``, then returns
:attr:`~HostDistance.IGNORED` if falsey, and defers to the child policy
otherwise.
"""
if self.predicate(host):
return self._child_policy.distance(host)
else:
return HostDistance.IGNORED
def populate(self, cluster, hosts):
self._child_policy.populate(cluster=cluster, hosts=hosts)
def make_query_plan(self, working_keyspace=None, query=None):
"""
Defers to the child policy's
:meth:`.LoadBalancingPolicy.make_query_plan` and filters the results.
Note that this filtering may break desirable properties of the wrapped
policy in some cases. For instance, imagine if you configure this
policy to filter out ``host2``, and to wrap a round-robin policy that
rotates through three hosts in the order ``host1, host2, host3``,
``host2, host3, host1``, ``host3, host1, host2``, repeating. This
policy will yield ``host1, host3``, ``host3, host1``, ``host3, host1``,
disproportionately favoring ``host3``.
"""
child_qp = self._child_policy.make_query_plan(
working_keyspace=working_keyspace, query=query
)
for host in child_qp:
if self.predicate(host):
yield host
def check_supported(self):
return self._child_policy.check_supported()
class ConvictionPolicy(object):
"""
A policy which decides when hosts should be considered down
based on the types of failures and the number of failures.
If custom behavior is needed, this class may be subclassed.
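A sketch of a custom policy that tolerates a few failures before
convicting a host (illustrative only; the threshold is arbitrary):

.. code-block:: python

    from cassandra.policies import ConvictionPolicy

    class ThresholdConvictionPolicy(ConvictionPolicy):
        # hypothetical: convict only after three failures of any kind
        def __init__(self, host):
            ConvictionPolicy.__init__(self, host)
            self._failures = 0

        def add_failure(self, connection_exc):
            self._failures += 1
            return self._failures >= 3

        def reset(self):
            self._failures = 0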
"""
def __init__(self, host):
"""
`host` is an instance of :class:`.Host`.
"""
self.host = host
def add_failure(self, connection_exc):
"""
Implementations should return :const:`True` if the host should be
convicted, :const:`False` otherwise.
"""
raise NotImplementedError()
def reset(self):
"""
Implementations should clear out any convictions or state regarding
the host.
"""
raise NotImplementedError()
class SimpleConvictionPolicy(ConvictionPolicy):
"""
The default implementation of :class:`ConvictionPolicy`,
which simply marks a host as down after the first failure
of any kind.
"""
def add_failure(self, connection_exc):
return not isinstance(connection_exc, OperationTimedOut)
def reset(self):
pass
class ReconnectionPolicy(object):
"""
This class and its subclasses govern how frequently an attempt is made
to reconnect to nodes that are marked as dead.
If custom behavior is needed, this class may be subclassed.
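A sketch of a custom schedule that backs off linearly and never gives
up (illustrative only):

.. code-block:: python

    from itertools import count
    from cassandra.policies import ReconnectionPolicy

    class LinearReconnectionPolicy(ReconnectionPolicy):
        # 1s, 2s, 3s, ... capped at 30s, retrying forever
        def new_schedule(self):
            return (float(min(i, 30)) for i in count(1))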
"""
def new_schedule(self):
"""
This should return a finite or infinite iterable of delays (each as a
floating point number of seconds) between each failed reconnection
attempt. Note that if the iterable is finite, reconnection attempts
will cease once the iterable is exhausted.
"""
raise NotImplementedError()
class ConstantReconnectionPolicy(ReconnectionPolicy):
"""
A :class:`.ReconnectionPolicy` subclass which sleeps for a fixed delay
between each reconnection attempt.
"""
def __init__(self, delay, max_attempts=64):
"""
`delay` should be a floating point number of seconds to wait between
each attempt.
`max_attempts` should be a total number of attempts to be made before
giving up, or :const:`None` to continue reconnection attempts forever.
The default is 64.
"""
if delay < 0:
raise ValueError("delay must not be negative")
if max_attempts is not None and max_attempts < 0:
raise ValueError("max_attempts must not be negative")
self.delay = delay
self.max_attempts = max_attempts
def new_schedule(self):
if self.max_attempts:
return repeat(self.delay, self.max_attempts)
return repeat(self.delay)
class ExponentialReconnectionPolicy(ReconnectionPolicy):
"""
A :class:`.ReconnectionPolicy` subclass which exponentially increases
the length of the delay between each reconnection attempt up to
a set maximum delay.
A random amount of jitter (+/- 15%) will be added to the pure exponential
delay value to avoid situations in which many reconnection handlers are
trying to reconnect at exactly the same time.
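A sketch of typical configuration (the delays are illustrative):

.. code-block:: python

    from cassandra.cluster import Cluster
    from cassandra.policies import ExponentialReconnectionPolicy

    # start at 1 second and back off exponentially up to 60 seconds
    cluster = Cluster(
        reconnection_policy=ExponentialReconnectionPolicy(base_delay=1.0,
                                                          max_delay=60.0))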
"""
# TODO: max_attempts is 64 to preserve legacy default behavior
# consider changing to None in major release to prevent the policy
# giving up forever
def __init__(self, base_delay, max_delay, max_attempts=64):
"""
`base_delay` and `max_delay` should be in floating point units of
seconds.
`max_attempts` should be a total number of attempts to be made before
giving up, or :const:`None` to continue reconnection attempts forever.
The default is 64.
"""
if base_delay < 0 or max_delay < 0:
raise ValueError("Delays may not be negative")
if max_delay < base_delay:
raise ValueError("Max delay must be greater than base delay")
if max_attempts is not None and max_attempts < 0:
raise ValueError("max_attempts must not be negative")
self.base_delay = base_delay
self.max_delay = max_delay
self.max_attempts = max_attempts
def new_schedule(self):
i, overflowed = 0, False
while self.max_attempts is None or i < self.max_attempts:
if overflowed:
yield self.max_delay
else:
try:
yield self._add_jitter(min(self.base_delay * (2 ** i), self.max_delay))
except OverflowError:
overflowed = True
yield self.max_delay
i += 1
# Adds +/- 15% jitter to the delay provided
def _add_jitter(self, value):
jitter = randint(85, 115)
delay = (jitter * value) / 100
return min(max(self.base_delay, delay), self.max_delay)
class RetryPolicy(object):
"""
A policy that describes whether to retry, rethrow, or ignore coordinator
timeout and unavailable failures. These are failures reported from the
server side. Timeouts are configured by settings in ``cassandra.yaml``.
Unavailable failures occur when the coordinator cannot achieve the consistency
level for a request. For further information see the method descriptions
below.
To specify a default retry policy, set the
:attr:`.Cluster.default_retry_policy` attribute to an instance of this
class or one of its subclasses.
To specify a retry policy per query, set the :attr:`.Statement.retry_policy`
attribute to an instance of this class or one of its subclasses.
If custom behavior is needed for retrying certain operations,
this class may be subclassed.
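A sketch of a custom subclass that retries read timeouts once at the
same consistency level and otherwise rethrows (illustrative only):

.. code-block:: python

    from cassandra.policies import RetryPolicy

    class RetryReadsOncePolicy(RetryPolicy):
        def on_read_timeout(self, query, consistency, required_responses,
                            received_responses, data_retrieved, retry_num):
            if retry_num == 0:
                return self.RETRY, consistency
            return self.RETHROW, None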
"""
RETRY = 0
"""
This should be returned from the below methods if the operation
should be retried on the same connection.
"""
RETHROW = 1
"""
This should be returned from the below methods if the failure
should be propagated and no more retries attempted.
"""
IGNORE = 2
"""
This should be returned from the below methods if the failure
should be ignored but no more retries should be attempted.
"""
RETRY_NEXT_HOST = 3
"""
This should be returned from the below methods if the operation
should be retried on another connection.
"""
def on_read_timeout(self, query, consistency, required_responses,
received_responses, data_retrieved, retry_num):
"""
This is called when a read operation times out from the coordinator's
perspective (i.e. a replica did not respond to the coordinator in time).
It should return a tuple with two items: one of the class enums (such
as :attr:`.RETRY`) and a :class:`.ConsistencyLevel` to retry the
operation at or :const:`None` to keep the same consistency level.
`query` is the :class:`.Statement` that timed out.
`consistency` is the :class:`.ConsistencyLevel` that the operation was
attempted at.
The `required_responses` and `received_responses` parameters describe
how many replicas needed to respond to meet the requested consistency
level and how many actually did respond before the coordinator timed
out the request. `data_retrieved` is a boolean indicating whether
any of those responses contained data (as opposed to just a digest).
`retry_num` counts how many times the operation has been retried, so
the first time this method is called, `retry_num` will be 0.
By default, operations will be retried at most once, and only if
a sufficient number of replicas responded (with data digests).
"""
if retry_num != 0:
return self.RETHROW, None
elif received_responses >= required_responses and not data_retrieved:
return self.RETRY, consistency
else:
return self.RETHROW, None
def on_write_timeout(self, query, consistency, write_type,
required_responses, received_responses, retry_num):
"""
This is called when a write operation times out from the coordinator's
perspective (i.e. a replica did not respond to the coordinator in time).
`query` is the :class:`.Statement` that timed out.
`consistency` is the :class:`.ConsistencyLevel` that the operation was
attempted at.
`write_type` is one of the :class:`.WriteType` enums describing the
type of write operation.
The `required_responses` and `received_responses` parameters describe
how many replicas needed to acknowledge the write to meet the requested
consistency level and how many replicas actually did acknowledge the
write before the coordinator timed out the request.
`retry_num` counts how many times the operation has been retried, so
the first time this method is called, `retry_num` will be 0.
By default, failed write operations will be retried at most once, and
they will only be retried if the `write_type` was
:attr:`~.WriteType.BATCH_LOG`.
"""
if retry_num != 0:
return self.RETHROW, None
elif write_type == WriteType.BATCH_LOG:
return self.RETRY, consistency
else:
return self.RETHROW, None
def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num):
"""
This is called when the coordinator node determines that a read or
write operation cannot be successful because the number of live
replicas are too low to meet the requested :class:`.ConsistencyLevel`.
This means that the read or write operation was never forwarded to
any replicas.
`query` is the :class:`.Statement` that failed.
`consistency` is the :class:`.ConsistencyLevel` that the operation was
attempted at.
`required_replicas` is the number of replicas that would have needed to
acknowledge the operation to meet the requested consistency level.
`alive_replicas` is the number of replicas that the coordinator
considered alive at the time of the request.
`retry_num` counts how many times the operation has been retried, so
the first time this method is called, `retry_num` will be 0.
By default, if this is the first retry, it triggers a retry on the next
host in the query plan with the same consistency level. If this is not the
first retry, no retries will be attempted and the error will be re-raised.
"""
return (self.RETRY_NEXT_HOST, None) if retry_num == 0 else (self.RETHROW, None)
def on_request_error(self, query, consistency, error, retry_num):
"""
This is called when an unexpected error happens. This can be in the
following situations:
* On a connection error
* On server errors: overloaded, isBootstrapping, serverError, etc.
`query` is the :class:`.Statement` that timed out.
`consistency` is the :class:`.ConsistencyLevel` that the operation was
attempted at.
`error` is the instance of the exception.
`retry_num` counts how many times the operation has been retried, so
the first time this method is called, `retry_num` will be 0.
By default, this triggers a retry on the next host in the query plan
with the same consistency level.
"""
# TODO revisit this for the next major
# To preserve the same behavior as before, we don't take retry_num into account
return self.RETRY_NEXT_HOST, None
class FallthroughRetryPolicy(RetryPolicy):
"""
A retry policy that never retries and always propagates failures to
the application.
"""
def on_read_timeout(self, *args, **kwargs):
return self.RETHROW, None
def on_write_timeout(self, *args, **kwargs):
return self.RETHROW, None
def on_unavailable(self, *args, **kwargs):
return self.RETHROW, None
def on_request_error(self, *args, **kwargs):
return self.RETHROW, None
class DowngradingConsistencyRetryPolicy(RetryPolicy):
"""
*Deprecated:* This retry policy will be removed in the next major release.
A retry policy that sometimes retries with a lower consistency level than
the one initially requested.
**BEWARE**: This policy may retry queries using a lower consistency
level than the one initially requested. By doing so, it may break
consistency guarantees. In other words, if you use this retry policy,
there are cases (documented below) where a read at :attr:`~.QUORUM`
*may not* see a preceding write at :attr:`~.QUORUM`. Do not use this
policy unless you have understood the cases where this can happen and
are ok with that. It is also recommended to subclass this class so
that queries that required a consistency level downgrade can be
recorded (so that repairs can be made later, etc).
This policy implements the same retries as :class:`.RetryPolicy`,
but on top of that, it also retries in the following cases:
* On a read timeout: if the number of replicas that responded is
greater than one but lower than is required by the requested
consistency level, the operation is retried at a lower consistency
level.
* On a write timeout: if the operation is an :attr:`~.UNLOGGED_BATCH`
and at least one replica acknowledged the write, the operation is
retried at a lower consistency level. Furthermore, for other
write types, if at least one replica acknowledged the write, the
timeout is ignored.
* On an unavailable exception: if at least one replica is alive, the
operation is retried at a lower consistency level.
The reasoning behind this retry policy is as follows: if, based
on the information the Cassandra coordinator node returns, retrying the
operation with the initially requested consistency has a chance to
succeed, do it. Otherwise, if based on that information we know the
initially requested consistency level cannot be achieved currently, then:
* For writes, ignore the exception (thus silently failing the
consistency requirement) if we know the write has been persisted on at
least one replica.
* For reads, try reading at a lower consistency level (thus silently
failing the consistency requirement).
In other words, this policy implements the idea that if the requested
consistency level cannot be achieved, the next best thing for writes is
to make sure the data is persisted, and that reading something is better
than reading nothing, even if there is a risk of reading stale data.
"""
def __init__(self, *args, **kwargs):
super(DowngradingConsistencyRetryPolicy, self).__init__(*args, **kwargs)
warnings.warn('DowngradingConsistencyRetryPolicy is deprecated '
'and will be removed in the next major release.',
DeprecationWarning)
def _pick_consistency(self, num_responses):
if num_responses >= 3:
return self.RETRY, ConsistencyLevel.THREE
elif num_responses >= 2:
return self.RETRY, ConsistencyLevel.TWO
elif num_responses >= 1:
return self.RETRY, ConsistencyLevel.ONE
else:
return self.RETHROW, None
def on_read_timeout(self, query, consistency, required_responses,
received_responses, data_retrieved, retry_num):
if retry_num != 0:
return self.RETHROW, None
elif ConsistencyLevel.is_serial(consistency):
# Downgrading does not make sense for a CAS read query
return self.RETHROW, None
elif received_responses < required_responses:
return self._pick_consistency(received_responses)
elif not data_retrieved:
return self.RETRY, consistency
else:
return self.RETHROW, None
def on_write_timeout(self, query, consistency, write_type,
required_responses, received_responses, retry_num):
if retry_num != 0:
return self.RETHROW, None
if write_type in (WriteType.SIMPLE, WriteType.BATCH, WriteType.COUNTER):
if received_responses > 0:
# persisted on at least one replica
return self.IGNORE, None
else:
return self.RETHROW, None
elif write_type == WriteType.UNLOGGED_BATCH:
return self._pick_consistency(received_responses)
elif write_type == WriteType.BATCH_LOG:
return self.RETRY, consistency
return self.RETHROW, None
def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num):
if retry_num != 0:
return self.RETHROW, None
elif ConsistencyLevel.is_serial(consistency):
# failed at the paxos phase of a LWT, retry on the next host
return self.RETRY_NEXT_HOST, None
else:
return self._pick_consistency(alive_replicas)
class AddressTranslator(object):
"""
Interface for translating cluster-defined endpoints.
The driver discovers nodes using server metadata and topology change events. Normally,
the endpoint defined by the server is the right way to connect to a node. In some environments,
these addresses may not be reachable, or not preferred (public vs. private IPs in cloud environments,
suboptimal routing, etc.). This interface allows for translating from server-defined endpoints to
preferred addresses for driver connections.
*Note:* :attr:`~Cluster.contact_points` provided while creating the :class:`~.Cluster` instance are not
translated using this mechanism -- only addresses received from Cassandra nodes are.
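A sketch of a custom translator using a static mapping (the addresses
and the mapping are hypothetical):

.. code-block:: python

    from cassandra.cluster import Cluster
    from cassandra.policies import AddressTranslator

    class StaticTranslator(AddressTranslator):
        # hypothetical public-to-private address mapping
        MAPPING = {'203.0.113.10': '10.0.0.10'}

        def translate(self, addr):
            return self.MAPPING.get(addr, addr)

    cluster = Cluster(address_translator=StaticTranslator())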
"""
def translate(self, addr):
"""
Accepts the node IP address and returns a translated address to be used when connecting to this node.
"""
raise NotImplementedError()
class IdentityTranslator(AddressTranslator):
"""
Returns the endpoint with no translation.
"""
def translate(self, addr):
return addr
class EC2MultiRegionTranslator(AddressTranslator):
"""
Resolves private IPs of hosts in the same datacenter as the client, and public IPs of hosts in other datacenters.
"""
def translate(self, addr):
"""
Reverse-DNS the public broadcast_address, then look up that hostname to get the AWS-resolved IP, which
will point to the private IP address within the same datacenter.
"""
# get family of this address so we translate to the same
family = socket.getaddrinfo(addr, 0, socket.AF_UNSPEC, socket.SOCK_STREAM)[0][0]
host = socket.getfqdn(addr)
for a in socket.getaddrinfo(host, 0, family, socket.SOCK_STREAM):
try:
return a[4][0]
except Exception:
pass
return addr
class SpeculativeExecutionPolicy(object):
"""
Interface for specifying speculative execution plans
"""
def new_plan(self, keyspace, statement):
"""
Returns a new speculative execution plan for the given `keyspace` and `statement`.
:param keyspace: the keyspace of the statement to be executed
:param statement: the statement for which extra executions may be scheduled
:return: a :class:`.SpeculativeExecutionPlan`
"""
raise NotImplementedError()
class SpeculativeExecutionPlan(object):
def next_execution(self, host):
raise NotImplementedError()
class NoSpeculativeExecutionPlan(SpeculativeExecutionPlan):
def next_execution(self, host):
return -1
class NoSpeculativeExecutionPolicy(SpeculativeExecutionPolicy):
def new_plan(self, keyspace, statement):
return NoSpeculativeExecutionPlan()
class ConstantSpeculativeExecutionPolicy(SpeculativeExecutionPolicy):
"""
A speculative execution policy that sends a new query every X seconds (**delay**) for a maximum of Y attempts (**max_attempts**).
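A sketch of typical use through an execution profile (the values are
illustrative; the driver only issues speculative executions for
statements marked idempotent):

.. code-block:: python

    from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
    from cassandra.policies import ConstantSpeculativeExecutionPolicy

    # start an extra execution every 100ms, at most two per query
    profile = ExecutionProfile(
        speculative_execution_policy=ConstantSpeculativeExecutionPolicy(0.1, 2))
    cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: profile})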
"""
def __init__(self, delay, max_attempts):
self.delay = delay
self.max_attempts = max_attempts
class ConstantSpeculativeExecutionPlan(SpeculativeExecutionPlan):
def __init__(self, delay, max_attempts):
self.delay = delay
self.remaining = max_attempts
def next_execution(self, host):
if self.remaining > 0:
self.remaining -= 1
return self.delay
else:
return -1
def new_plan(self, keyspace, statement):
return self.ConstantSpeculativeExecutionPlan(self.delay, self.max_attempts)
+
+
+class WrapperPolicy(LoadBalancingPolicy):
+
+ def __init__(self, child_policy):
+ self._child_policy = child_policy
+
+ def distance(self, *args, **kwargs):
+ return self._child_policy.distance(*args, **kwargs)
+
+ def populate(self, cluster, hosts):
+ self._child_policy.populate(cluster, hosts)
+
+ def on_up(self, *args, **kwargs):
+ return self._child_policy.on_up(*args, **kwargs)
+
+ def on_down(self, *args, **kwargs):
+ return self._child_policy.on_down(*args, **kwargs)
+
+ def on_add(self, *args, **kwargs):
+ return self._child_policy.on_add(*args, **kwargs)
+
+ def on_remove(self, *args, **kwargs):
+ return self._child_policy.on_remove(*args, **kwargs)
+
+
+class DefaultLoadBalancingPolicy(WrapperPolicy):
+ """
+ A :class:`.LoadBalancingPolicy` wrapper that adds the ability to target a specific host first.
+
+ If no host is set on the query, the child policy's query plan will be used as is.
+ """
+
+ _cluster_metadata = None
+
+ def populate(self, cluster, hosts):
+ self._cluster_metadata = cluster.metadata
+ self._child_policy.populate(cluster, hosts)
+
+ def make_query_plan(self, working_keyspace=None, query=None):
+ if query and query.keyspace:
+ keyspace = query.keyspace
+ else:
+ keyspace = working_keyspace
+
+ # TODO remove next major since execute(..., host=XXX) is now available
+ addr = getattr(query, 'target_host', None) if query else None
+ target_host = self._cluster_metadata.get_host(addr)
+
+ child = self._child_policy
+ if target_host and target_host.is_up:
+ yield target_host
+ for h in child.make_query_plan(keyspace, query):
+ if h != target_host:
+ yield h
+ else:
+ for h in child.make_query_plan(keyspace, query):
+ yield h
+
+
+# TODO for backward compatibility, remove in next major
+class DSELoadBalancingPolicy(DefaultLoadBalancingPolicy):
+ """
+ *Deprecated:* This will be removed in the next major release,
+ consider using :class:`.DefaultLoadBalancingPolicy`.
+ """
+ def __init__(self, *args, **kwargs):
+ super(DSELoadBalancingPolicy, self).__init__(*args, **kwargs)
+ warnings.warn("DSELoadBalancingPolicy will be removed in 4.0. Consider using "
+ "DefaultLoadBalancingPolicy.", DeprecationWarning)
+
+
+class NeverRetryPolicy(RetryPolicy):
+ def _rethrow(self, *args, **kwargs):
+ return self.RETHROW, None
+
+ on_read_timeout = _rethrow
+ on_write_timeout = _rethrow
+ on_unavailable = _rethrow
diff --git a/cassandra/pool.py b/cassandra/pool.py
index cd814ef..cd27656 100644
--- a/cassandra/pool.py
+++ b/cassandra/pool.py
@@ -1,818 +1,866 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Connection pooling and host management.
"""
from functools import total_ordering
import logging
import socket
import time
from threading import Lock, RLock, Condition
import weakref
try:
from weakref import WeakSet
except ImportError:
from cassandra.util import WeakSet # NOQA
from cassandra import AuthenticationFailed
from cassandra.connection import ConnectionException, EndPoint, DefaultEndPoint
from cassandra.policies import HostDistance
log = logging.getLogger(__name__)
class NoConnectionsAvailable(Exception):
"""
All existing connections to a given host are busy, or there are
no open connections.
"""
pass
@total_ordering
class Host(object):
"""
Represents a single Cassandra node.
"""
endpoint = None
"""
The :class:`~.connection.EndPoint` to connect to the node.
"""
broadcast_address = None
"""
- broadcast address configured for the node, *if available* ('peer' in system.peers table).
- This is not present in the ``system.local`` table for older versions of Cassandra. It is also not queried if
- :attr:`~.Cluster.token_metadata_enabled` is ``False``.
+ broadcast address configured for the node, *if available*:
+
+ 'system.local.broadcast_address' or 'system.peers.peer' (Cassandra 2-3)
+ 'system.local.broadcast_address' or 'system.peers_v2.peer' (Cassandra 4)
+
+ This is not present in the ``system.local`` table for older versions of Cassandra. It
+ is also not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``.
+ """
+
+ broadcast_port = None
+ """
+ broadcast port configured for the node, *if available*:
+
+ 'system.local.broadcast_port' or 'system.peers_v2.peer_port' (Cassandra 4)
+
+ It is also not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``.
"""
broadcast_rpc_address = None
"""
- The broadcast rpc address of the node (`native_address` or `rpc_address`).
+ The broadcast rpc address of the node:
+
+ 'system.local.rpc_address' or 'system.peers.rpc_address' (Cassandra 3)
+ 'system.local.rpc_address' or 'system.peers.native_transport_address' (DSE 6+)
+ 'system.local.rpc_address' or 'system.peers_v2.native_address' (Cassandra 4)
+ """
+
+ broadcast_rpc_port = None
+ """
+ The broadcast rpc port of the node, *if available*:
+
+ 'system.local.rpc_port' or 'system.peers.native_transport_port' (DSE 6+)
+ 'system.local.rpc_port' or 'system.peers_v2.native_port' (Cassandra 4)
"""
listen_address = None
"""
- listen address configured for the node, *if available*. This is only available in the ``system.local`` table for newer
- versions of Cassandra. It is also not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``.
- Usually the same as ``broadcast_address`` unless configured differently in cassandra.yaml.
+ listen address configured for the node, *if available*:
+
+ 'system.local.listen_address'
+
+ This is only available in the ``system.local`` table for newer versions of Cassandra. It is also not
+ queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. Usually the same as ``broadcast_address``
+ unless configured differently in cassandra.yaml.
+ """
+
+ listen_port = None
+ """
+ listen port configured for the node, *if available*:
+
+ 'system.local.listen_port'
+
+ This is only available in the ``system.local`` table for newer versions of Cassandra. It is also not
+ queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``.
"""
conviction_policy = None
"""
A :class:`~.ConvictionPolicy` instance for determining when this node should
be marked up or down.
"""
is_up = None
"""
:const:`True` if the node is considered up, :const:`False` if it is
considered down, and :const:`None` if it is not known if the node is
up or down.
"""
release_version = None
"""
release_version as queried from the control connection system tables
"""
host_id = None
"""
The unique identifier of the cassandra node
"""
dse_version = None
"""
dse_version as queried from the control connection system tables. Only populated when connecting to
DSE with this property available. Not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``.
"""
dse_workload = None
"""
DSE workload queried from the control connection system tables. Only populated when connecting to
DSE with this property available. Not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``.
+ This is a legacy attribute that does not portray multiple workloads in a uniform fashion.
+ See also :attr:`~.Host.dse_workloads`.
+ """
+
+ dse_workloads = None
+ """
+ DSE workloads set, queried from the control connection system tables. Only populated when connecting to
+ DSE with this property available (added in DSE 5.1).
+ Not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``.
"""
_datacenter = None
_rack = None
_reconnection_handler = None
lock = None
_currently_handling_node_up = False
def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=None, host_id=None):
if endpoint is None:
raise ValueError("endpoint may not be None")
if conviction_policy_factory is None:
raise ValueError("conviction_policy_factory may not be None")
self.endpoint = endpoint if isinstance(endpoint, EndPoint) else DefaultEndPoint(endpoint)
self.conviction_policy = conviction_policy_factory(self)
self.host_id = host_id
self.set_location_info(datacenter, rack)
self.lock = RLock()
@property
def address(self):
"""
The IP address of the endpoint. This is the RPC address the driver uses when connecting to the node.
"""
# backward compatibility
return self.endpoint.address
@property
def datacenter(self):
""" The datacenter the node is in. """
return self._datacenter
@property
def rack(self):
""" The rack the node is in. """
return self._rack
def set_location_info(self, datacenter, rack):
"""
Sets the datacenter and rack for this node. Intended for internal
use (by the control connection, which periodically checks the
ring topology) only.
"""
self._datacenter = datacenter
self._rack = rack
def set_up(self):
if not self.is_up:
log.debug("Host %s is now marked up", self.endpoint)
self.conviction_policy.reset()
self.is_up = True
def set_down(self):
self.is_up = False
def signal_connection_failure(self, connection_exc):
return self.conviction_policy.add_failure(connection_exc)
def is_currently_reconnecting(self):
return self._reconnection_handler is not None
def get_and_set_reconnection_handler(self, new_handler):
"""
Atomically replaces the reconnection handler for this
host. Intended for internal use only.
"""
with self.lock:
old = self._reconnection_handler
self._reconnection_handler = new_handler
return old
def __eq__(self, other):
if isinstance(other, Host):
return self.endpoint == other.endpoint
else: # TODO Backward compatibility, remove next major
return self.endpoint.address == other
def __hash__(self):
return hash(self.endpoint)
def __lt__(self, other):
return self.endpoint < other.endpoint
def __str__(self):
return str(self.endpoint)
def __repr__(self):
dc = (" %s" % (self._datacenter,)) if self._datacenter else ""
return "<%s: %s%s>" % (self.__class__.__name__, self.endpoint, dc)
class _ReconnectionHandler(object):
"""
Abstract class for attempting reconnections with a given
schedule and scheduler.
"""
_cancelled = False
def __init__(self, scheduler, schedule, callback, *callback_args, **callback_kwargs):
self.scheduler = scheduler
self.schedule = schedule
self.callback = callback
self.callback_args = callback_args
self.callback_kwargs = callback_kwargs
def start(self):
if self._cancelled:
log.debug("Reconnection handler was cancelled before starting")
return
first_delay = next(self.schedule)
self.scheduler.schedule(first_delay, self.run)
def run(self):
if self._cancelled:
return
conn = None
try:
conn = self.try_reconnect()
except Exception as exc:
try:
next_delay = next(self.schedule)
except StopIteration:
# the schedule has been exhausted
next_delay = None
# call on_exception for logging purposes even if next_delay is None
if self.on_exception(exc, next_delay):
if next_delay is None:
log.warning(
"Will not continue to retry reconnection attempts "
"due to an exhausted retry schedule")
else:
self.scheduler.schedule(next_delay, self.run)
else:
if not self._cancelled:
self.on_reconnection(conn)
self.callback(*(self.callback_args), **(self.callback_kwargs))
finally:
if conn:
conn.close()
def cancel(self):
self._cancelled = True
def try_reconnect(self):
"""
Subclasses must implement this method. It should attempt to
open a new Connection and return it; if a failure occurs, an
Exception should be raised.
"""
raise NotImplementedError()
def on_reconnection(self, connection):
"""
Called when a new Connection is successfully opened. Nothing is
done by default.
"""
pass
def on_exception(self, exc, next_delay):
"""
Called when an Exception is raised when trying to connect.
`exc` is the Exception that was raised and `next_delay` is the
number of seconds (as a float) that the handler will wait before
attempting to connect again.
Subclasses should return :const:`False` if no more attempts to
connection should be made, :const:`True` otherwise. The default
behavior is to always retry unless the error is an
:exc:`.AuthenticationFailed` instance.
"""
if isinstance(exc, AuthenticationFailed):
return False
else:
return True
class _HostReconnectionHandler(_ReconnectionHandler):
def __init__(self, host, connection_factory, is_host_addition, on_add, on_up, *args, **kwargs):
_ReconnectionHandler.__init__(self, *args, **kwargs)
self.is_host_addition = is_host_addition
self.on_add = on_add
self.on_up = on_up
self.host = host
self.connection_factory = connection_factory
def try_reconnect(self):
return self.connection_factory()
def on_reconnection(self, connection):
log.info("Successful reconnection to %s, marking node up if it isn't already", self.host)
if self.is_host_addition:
self.on_add(self.host)
else:
self.on_up(self.host)
def on_exception(self, exc, next_delay):
if isinstance(exc, AuthenticationFailed):
return False
else:
log.warning("Error attempting to reconnect to %s, scheduling retry in %s seconds: %s",
self.host, next_delay, exc)
log.debug("Reconnection error details", exc_info=True)
return True
class HostConnection(object):
"""
When using v3 of the native protocol, this is used instead of a connection
pool per host (HostConnectionPool) due to the increased in-flight capacity
of individual connections.
"""
host = None
host_distance = None
is_shutdown = False
shutdown_on_error = False
_session = None
_connection = None
_lock = None
_keyspace = None
def __init__(self, host, host_distance, session):
self.host = host
self.host_distance = host_distance
self._session = weakref.proxy(session)
self._lock = Lock()
# this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
self._stream_available_condition = Condition(self._lock)
self._is_replacing = False
if host_distance == HostDistance.IGNORED:
log.debug("Not opening connection to ignored host %s", self.host)
return
elif host_distance == HostDistance.REMOTE and not session.cluster.connect_to_remote_hosts:
log.debug("Not opening connection to remote host %s", self.host)
return
log.debug("Initializing connection for host %s", self.host)
self._connection = session.cluster.connection_factory(host.endpoint)
self._keyspace = session.keyspace
if self._keyspace:
self._connection.set_keyspace_blocking(self._keyspace)
log.debug("Finished initializing connection for host %s", self.host)
def borrow_connection(self, timeout):
if self.is_shutdown:
raise ConnectionException(
"Pool for %s is shutdown" % (self.host,), self.host)
conn = self._connection
if not conn:
raise NoConnectionsAvailable()
start = time.time()
remaining = timeout
while True:
with conn.lock:
- if conn.in_flight <= conn.max_request_id:
+ if conn.in_flight < conn.max_request_id:
conn.in_flight += 1
return conn, conn.get_request_id()
if timeout is not None:
remaining = timeout - time.time() + start
if remaining < 0:
break
with self._stream_available_condition:
self._stream_available_condition.wait(remaining)
raise NoConnectionsAvailable("All request IDs are currently in use")
def return_connection(self, connection):
with connection.lock:
connection.in_flight -= 1
with self._stream_available_condition:
self._stream_available_condition.notify()
if connection.is_defunct or connection.is_closed:
if connection.signaled_error and not self.shutdown_on_error:
return
is_down = False
if not connection.signaled_error:
log.debug("Defunct or closed connection (%s) returned to pool, potentially "
"marking host %s as down", id(connection), self.host)
is_down = self._session.cluster.signal_connection_failure(
self.host, connection.last_error, is_host_addition=False)
connection.signaled_error = True
if self.shutdown_on_error and not is_down:
is_down = True
self._session.cluster.on_down(self.host, is_host_addition=False)
if is_down:
self.shutdown()
else:
self._connection = None
with self._lock:
if self._is_replacing:
return
self._is_replacing = True
self._session.submit(self._replace, connection)
def _replace(self, connection):
with self._lock:
if self.is_shutdown:
return
log.debug("Replacing connection (%s) to %s", id(connection), self.host)
try:
- conn = self._session.cluster.connection_factory(self.host)
+ conn = self._session.cluster.connection_factory(self.host.endpoint)
if self._keyspace:
conn.set_keyspace_blocking(self._keyspace)
self._connection = conn
except Exception:
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
self._session.submit(self._replace, connection)
else:
with self._lock:
self._is_replacing = False
self._stream_available_condition.notify()
def shutdown(self):
with self._lock:
if self.is_shutdown:
return
else:
self.is_shutdown = True
self._stream_available_condition.notify_all()
if self._connection:
self._connection.close()
self._connection = None
def _set_keyspace_for_all_conns(self, keyspace, callback):
if self.is_shutdown or not self._connection:
return
def connection_finished_setting_keyspace(conn, error):
self.return_connection(conn)
errors = [] if not error else [error]
callback(self, errors)
self._keyspace = keyspace
self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace)
def get_connections(self):
c = self._connection
return [c] if c else []
def get_state(self):
connection = self._connection
open_count = 1 if connection and not (connection.is_closed or connection.is_defunct) else 0
in_flights = [connection.in_flight] if connection else []
return {'shutdown': self.is_shutdown, 'open_count': open_count, 'in_flights': in_flights}
@property
def open_count(self):
connection = self._connection
return 1 if connection and not (connection.is_closed or connection.is_defunct) else 0
_MAX_SIMULTANEOUS_CREATION = 1
_MIN_TRASH_INTERVAL = 10
class HostConnectionPool(object):
"""
Used to pool connections to a host for v1 and v2 native protocol.
"""
host = None
host_distance = None
is_shutdown = False
open_count = 0
_scheduled_for_creation = 0
_next_trash_allowed_at = 0
_keyspace = None
def __init__(self, host, host_distance, session):
self.host = host
self.host_distance = host_distance
self._session = weakref.proxy(session)
self._lock = RLock()
self._conn_available_condition = Condition()
log.debug("Initializing new connection pool for host %s", self.host)
core_conns = session.cluster.get_core_connections_per_host(host_distance)
self._connections = [session.cluster.connection_factory(host.endpoint)
for i in range(core_conns)]
self._keyspace = session.keyspace
if self._keyspace:
for conn in self._connections:
conn.set_keyspace_blocking(self._keyspace)
self._trash = set()
self._next_trash_allowed_at = time.time()
self.open_count = core_conns
log.debug("Finished initializing new connection pool for host %s", self.host)
def borrow_connection(self, timeout):
if self.is_shutdown:
raise ConnectionException(
"Pool for %s is shutdown" % (self.host,), self.host)
conns = self._connections
if not conns:
# handled specially just for simpler code
log.debug("Detected empty pool, opening core conns to %s", self.host)
core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance)
with self._lock:
# we check the length of self._connections again
# along with self._scheduled_for_creation while holding the lock
# in case multiple threads hit this condition at the same time
to_create = core_conns - (len(self._connections) + self._scheduled_for_creation)
for i in range(to_create):
self._scheduled_for_creation += 1
self._session.submit(self._create_new_connection)
# in_flight is incremented by wait_for_conn
conn = self._wait_for_conn(timeout)
return conn
else:
# note: it would be nice to push changes to these config settings
# to pools instead of doing a new lookup on every
# borrow_connection() call
max_reqs = self._session.cluster.get_max_requests_per_connection(self.host_distance)
max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance)
least_busy = min(conns, key=lambda c: c.in_flight)
request_id = None
# to avoid another thread closing this connection while
# trashing it (through the return_connection process), hold
# the connection lock from this point until we've incremented
# its in_flight count
need_to_wait = False
with least_busy.lock:
if least_busy.in_flight < least_busy.max_request_id:
least_busy.in_flight += 1
request_id = least_busy.get_request_id()
else:
# once we release the lock, wait for another connection
need_to_wait = True
if need_to_wait:
# wait_for_conn will increment in_flight on the conn
least_busy, request_id = self._wait_for_conn(timeout)
# if we have too many requests on this connection but we still
# have space to open a new connection against this host, go ahead
# and schedule the creation of a new connection
if least_busy.in_flight >= max_reqs and len(self._connections) < max_conns:
self._maybe_spawn_new_connection()
return least_busy, request_id
def _maybe_spawn_new_connection(self):
with self._lock:
if self._scheduled_for_creation >= _MAX_SIMULTANEOUS_CREATION:
return
if self.open_count >= self._session.cluster.get_max_connections_per_host(self.host_distance):
return
self._scheduled_for_creation += 1
log.debug("Submitting task for creation of new Connection to %s", self.host)
self._session.submit(self._create_new_connection)
def _create_new_connection(self):
try:
self._add_conn_if_under_max()
except (ConnectionException, socket.error) as exc:
log.warning("Failed to create new connection to %s: %s", self.host, exc)
except Exception:
log.exception("Unexpectedly failed to create new connection")
finally:
with self._lock:
self._scheduled_for_creation -= 1
def _add_conn_if_under_max(self):
max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance)
with self._lock:
if self.is_shutdown:
return True
if self.open_count >= max_conns:
return True
self.open_count += 1
log.debug("Going to open new connection to host %s", self.host)
try:
conn = self._session.cluster.connection_factory(self.host.endpoint)
if self._keyspace:
conn.set_keyspace_blocking(self._session.keyspace)
self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL
with self._lock:
new_connections = self._connections[:] + [conn]
self._connections = new_connections
log.debug("Added new connection (%s) to pool for host %s, signaling availablility",
id(conn), self.host)
self._signal_available_conn()
return True
except (ConnectionException, socket.error) as exc:
log.warning("Failed to add new connection to pool for host %s: %s", self.host, exc)
with self._lock:
self.open_count -= 1
if self._session.cluster.signal_connection_failure(self.host, exc, is_host_addition=False):
self.shutdown()
return False
except AuthenticationFailed:
with self._lock:
self.open_count -= 1
return False
def _await_available_conn(self, timeout):
with self._conn_available_condition:
self._conn_available_condition.wait(timeout)
def _signal_available_conn(self):
with self._conn_available_condition:
self._conn_available_condition.notify()
def _signal_all_available_conn(self):
with self._conn_available_condition:
self._conn_available_condition.notify_all()
def _wait_for_conn(self, timeout):
start = time.time()
remaining = timeout
while remaining > 0:
# wait on our condition for the possibility that a connection
# is usable
self._await_available_conn(remaining)
# self.shutdown() may trigger the above Condition
if self.is_shutdown:
raise ConnectionException("Pool is shutdown")
conns = self._connections
if conns:
least_busy = min(conns, key=lambda c: c.in_flight)
with least_busy.lock:
if least_busy.in_flight < least_busy.max_request_id:
least_busy.in_flight += 1
return least_busy, least_busy.get_request_id()
remaining = timeout - (time.time() - start)
raise NoConnectionsAvailable()
def return_connection(self, connection):
with connection.lock:
connection.in_flight -= 1
in_flight = connection.in_flight
if connection.is_defunct or connection.is_closed:
if not connection.signaled_error:
log.debug("Defunct or closed connection (%s) returned to pool, potentially "
"marking host %s as down", id(connection), self.host)
is_down = self._session.cluster.signal_connection_failure(
self.host, connection.last_error, is_host_addition=False)
connection.signaled_error = True
if is_down:
self.shutdown()
else:
self._replace(connection)
else:
if connection in self._trash:
with connection.lock:
if connection.in_flight == 0:
with self._lock:
if connection in self._trash:
self._trash.remove(connection)
log.debug("Closing trashed connection (%s) to %s", id(connection), self.host)
connection.close()
return
core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance)
min_reqs = self._session.cluster.get_min_requests_per_connection(self.host_distance)
# we can use in_flight here without holding the connection lock
# because the fact that in_flight dipped below the min at some
# point is enough to start the trashing procedure
if len(self._connections) > core_conns and in_flight <= min_reqs and \
time.time() >= self._next_trash_allowed_at:
self._maybe_trash_connection(connection)
else:
self._signal_available_conn()
def _maybe_trash_connection(self, connection):
core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance)
did_trash = False
with self._lock:
if connection not in self._connections:
return
if self.open_count > core_conns:
did_trash = True
self.open_count -= 1
new_connections = self._connections[:]
new_connections.remove(connection)
self._connections = new_connections
with connection.lock:
if connection.in_flight == 0:
log.debug("Skipping trash and closing unused connection (%s) to %s", id(connection), self.host)
connection.close()
# skip adding it to the trash if we're already closing it
return
self._trash.add(connection)
if did_trash:
self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL
log.debug("Trashed connection (%s) to %s", id(connection), self.host)
def _replace(self, connection):
should_replace = False
with self._lock:
if connection in self._connections:
new_connections = self._connections[:]
new_connections.remove(connection)
self._connections = new_connections
self.open_count -= 1
should_replace = True
if should_replace:
log.debug("Replacing connection (%s) to %s", id(connection), self.host)
connection.close()
self._session.submit(self._retrying_replace)
else:
log.debug("Closing connection (%s) to %s", id(connection), self.host)
connection.close()
def _retrying_replace(self):
replaced = False
try:
replaced = self._add_conn_if_under_max()
except Exception:
log.exception("Failed replacing connection to %s", self.host)
if not replaced:
log.debug("Failed replacing connection to %s. Retrying.", self.host)
self._session.submit(self._retrying_replace)
def shutdown(self):
with self._lock:
if self.is_shutdown:
return
else:
self.is_shutdown = True
self._signal_all_available_conn()
for conn in self._connections:
conn.close()
self.open_count -= 1
for conn in self._trash:
conn.close()
def ensure_core_connections(self):
if self.is_shutdown:
return
core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance)
with self._lock:
to_create = core_conns - (len(self._connections) + self._scheduled_for_creation)
for i in range(to_create):
self._scheduled_for_creation += 1
self._session.submit(self._create_new_connection)
def _set_keyspace_for_all_conns(self, keyspace, callback):
"""
Asynchronously sets the keyspace for all connections. When all
connections have been set, `callback` will be called with two
arguments: this pool, and a list of any errors that occurred.
"""
remaining_callbacks = set(self._connections)
errors = []
if not remaining_callbacks:
callback(self, errors)
return
def connection_finished_setting_keyspace(conn, error):
self.return_connection(conn)
remaining_callbacks.remove(conn)
if error:
errors.append(error)
if not remaining_callbacks:
callback(self, errors)
self._keyspace = keyspace
for conn in self._connections:
conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace)
def get_connections(self):
return self._connections
def get_state(self):
in_flights = [c.in_flight for c in self._connections]
return {'shutdown': self.is_shutdown, 'open_count': self.open_count, 'in_flights': in_flights}
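# Illustrative sketch (not part of the pool): a callback matching the contract
# documented on _set_keyspace_for_all_conns() above -- it is invoked with this
# pool and a list of any errors once every connection has switched keyspace.
def _example_keyspace_callback(pool, errors):
    if errors:
        log.warning("Failed setting keyspace on pool for %s: %s", pool.host, errors)
    else:
        log.debug("Keyspace set on all connections to %s", pool.host)

# Hypothetical usage:
# pool._set_keyspace_for_all_conns("my_keyspace", _example_keyspace_callback)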
diff --git a/cassandra/protocol.py b/cassandra/protocol.py
index 7e11779..ed92a76 100644
--- a/cassandra/protocol.py
+++ b/cassandra/protocol.py
@@ -1,1409 +1,1484 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import # to enable import io from stdlib
from collections import namedtuple
import logging
import socket
from uuid import UUID
import six
from six.moves import range
import io
from cassandra import ProtocolVersion
from cassandra import type_codes, DriverException
from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
WriteFailure, ReadFailure, FunctionFailure,
AlreadyExists, InvalidRequest, Unauthorized,
UnsupportedOperation, UserFunctionDescriptor,
UserAggregateDescriptor, SchemaTargetType)
from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
- int8_pack, int8_unpack, uint64_pack, header_pack,
- v3_header_pack, uint32_pack)
+ uint8_pack, int8_unpack, uint64_pack, header_pack,
+ v3_header_pack, uint32_pack, uint32_le_unpack, uint32_le_pack)
from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
CounterColumnType, DateType, DecimalType,
DoubleType, FloatType, Int32Type,
InetAddressType, IntegerType, ListType,
LongType, MapType, SetType, TimeUUIDType,
UTF8Type, VarcharType, UUIDType, UserType,
TupleType, lookup_casstype, SimpleDateType,
TimeType, ByteType, ShortType, DurationType)
from cassandra import WriteType
from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY
from cassandra import util
log = logging.getLogger(__name__)
class NotSupportedError(Exception):
pass
class InternalError(Exception):
pass
ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type'])
HEADER_DIRECTION_TO_CLIENT = 0x80
HEADER_DIRECTION_MASK = 0x80
COMPRESSED_FLAG = 0x01
TRACING_FLAG = 0x02
CUSTOM_PAYLOAD_FLAG = 0x04
WARNING_FLAG = 0x08
USE_BETA_FLAG = 0x10
USE_BETA_MASK = ~USE_BETA_FLAG
_message_types_by_opcode = {}
_UNSET_VALUE = object()
def register_class(cls):
_message_types_by_opcode[cls.opcode] = cls
def get_registered_classes():
return _message_types_by_opcode.copy()
class _RegisterMessageType(type):
def __init__(cls, name, bases, dct):
if not name.startswith('_'):
register_class(cls)
@six.add_metaclass(_RegisterMessageType)
class _MessageType(object):
tracing = False
custom_payload = None
warnings = None
def update_custom_payload(self, other):
if other:
if not self.custom_payload:
self.custom_payload = {}
self.custom_payload.update(other)
if len(self.custom_payload) > 65535:
raise ValueError("Custom payload map exceeds max count allowed by protocol (65535)")
def __repr__(self):
return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))
def _get_params(message_obj):
base_attrs = dir(_MessageType)
return (
(n, a) for n, a in message_obj.__dict__.items()
if n not in base_attrs and not n.startswith('_') and not callable(a)
)
error_classes = {}
class ErrorMessage(_MessageType, Exception):
opcode = 0x00
name = 'ERROR'
summary = 'Unknown'
def __init__(self, code, message, info):
self.code = code
self.message = message
self.info = info
@classmethod
def recv_body(cls, f, protocol_version, *args):
code = read_int(f)
msg = read_string(f)
subcls = error_classes.get(code, cls)
extra_info = subcls.recv_error_info(f, protocol_version)
return subcls(code=code, message=msg, info=extra_info)
def summary_msg(self):
msg = 'Error from server: code=%04x [%s] message="%s"' \
% (self.code, self.summary, self.message)
if six.PY2 and isinstance(msg, six.text_type):
msg = msg.encode('utf-8')
return msg
def __str__(self):
return '<%s>' % self.summary_msg()
__repr__ = __str__
@staticmethod
def recv_error_info(f, protocol_version):
pass
def to_exception(self):
return self
class ErrorMessageSubclass(_RegisterMessageType):
def __init__(cls, name, bases, dct):
if cls.error_code is not None:  # ServerError has an error code of 0, so check for None explicitly
error_classes[cls.error_code] = cls
@six.add_metaclass(ErrorMessageSubclass)
class ErrorMessageSub(ErrorMessage):
error_code = None
class RequestExecutionException(ErrorMessageSub):
pass
class RequestValidationException(ErrorMessageSub):
pass
class ServerError(ErrorMessageSub):
summary = 'Server error'
error_code = 0x0000
class ProtocolException(ErrorMessageSub):
summary = 'Protocol error'
error_code = 0x000A
+ @property
+ def is_beta_protocol_error(self):
+ return 'USE_BETA flag is unset' in str(self)
+
class BadCredentials(ErrorMessageSub):
summary = 'Bad credentials'
error_code = 0x0100
class UnavailableErrorMessage(RequestExecutionException):
summary = 'Unavailable exception'
error_code = 0x1000
@staticmethod
def recv_error_info(f, protocol_version):
return {
'consistency': read_consistency_level(f),
'required_replicas': read_int(f),
'alive_replicas': read_int(f),
}
def to_exception(self):
return Unavailable(self.summary_msg(), **self.info)
class OverloadedErrorMessage(RequestExecutionException):
summary = 'Coordinator node overloaded'
error_code = 0x1001
class IsBootstrappingErrorMessage(RequestExecutionException):
summary = 'Coordinator node is bootstrapping'
error_code = 0x1002
class TruncateError(RequestExecutionException):
summary = 'Error during truncate'
error_code = 0x1003
class WriteTimeoutErrorMessage(RequestExecutionException):
summary = "Coordinator node timed out waiting for replica nodes' responses"
error_code = 0x1100
@staticmethod
def recv_error_info(f, protocol_version):
return {
'consistency': read_consistency_level(f),
'received_responses': read_int(f),
'required_responses': read_int(f),
'write_type': WriteType.name_to_value[read_string(f)],
}
def to_exception(self):
return WriteTimeout(self.summary_msg(), **self.info)
class ReadTimeoutErrorMessage(RequestExecutionException):
summary = "Coordinator node timed out waiting for replica nodes' responses"
error_code = 0x1200
@staticmethod
def recv_error_info(f, protocol_version):
return {
'consistency': read_consistency_level(f),
'received_responses': read_int(f),
'required_responses': read_int(f),
'data_retrieved': bool(read_byte(f)),
}
def to_exception(self):
return ReadTimeout(self.summary_msg(), **self.info)
class ReadFailureMessage(RequestExecutionException):
summary = "Replica(s) failed to execute read"
error_code = 0x1300
@staticmethod
def recv_error_info(f, protocol_version):
consistency = read_consistency_level(f)
received_responses = read_int(f)
required_responses = read_int(f)
if ProtocolVersion.uses_error_code_map(protocol_version):
error_code_map = read_error_code_map(f)
failures = len(error_code_map)
else:
error_code_map = None
failures = read_int(f)
data_retrieved = bool(read_byte(f))
return {
'consistency': consistency,
'received_responses': received_responses,
'required_responses': required_responses,
'failures': failures,
'error_code_map': error_code_map,
'data_retrieved': data_retrieved
}
def to_exception(self):
return ReadFailure(self.summary_msg(), **self.info)
class FunctionFailureMessage(RequestExecutionException):
summary = "User Defined Function failure"
error_code = 0x1400
@staticmethod
def recv_error_info(f, protocol_version):
return {
'keyspace': read_string(f),
'function': read_string(f),
'arg_types': [read_string(f) for _ in range(read_short(f))],
}
def to_exception(self):
return FunctionFailure(self.summary_msg(), **self.info)
class WriteFailureMessage(RequestExecutionException):
summary = "Replica(s) failed to execute write"
error_code = 0x1500
@staticmethod
def recv_error_info(f, protocol_version):
consistency = read_consistency_level(f)
received_responses = read_int(f)
required_responses = read_int(f)
if ProtocolVersion.uses_error_code_map(protocol_version):
error_code_map = read_error_code_map(f)
failures = len(error_code_map)
else:
error_code_map = None
failures = read_int(f)
write_type = WriteType.name_to_value[read_string(f)]
return {
'consistency': consistency,
'received_responses': received_responses,
'required_responses': required_responses,
'failures': failures,
'error_code_map': error_code_map,
'write_type': write_type
}
def to_exception(self):
return WriteFailure(self.summary_msg(), **self.info)
class CDCWriteException(RequestExecutionException):
summary = 'Failed to execute write due to CDC space exhaustion.'
error_code = 0x1600
class SyntaxException(RequestValidationException):
summary = 'Syntax error in CQL query'
error_code = 0x2000
class UnauthorizedErrorMessage(RequestValidationException):
summary = 'Unauthorized'
error_code = 0x2100
def to_exception(self):
return Unauthorized(self.summary_msg())
class InvalidRequestException(RequestValidationException):
summary = 'Invalid query'
error_code = 0x2200
def to_exception(self):
return InvalidRequest(self.summary_msg())
class ConfigurationException(RequestValidationException):
summary = 'Query invalid because of configuration issue'
error_code = 0x2300
class PreparedQueryNotFound(RequestValidationException):
summary = 'Matching prepared statement not found on this node'
error_code = 0x2500
@staticmethod
def recv_error_info(f, protocol_version):
# return the query ID
return read_binary_string(f)
class AlreadyExistsException(ConfigurationException):
summary = 'Item already exists'
error_code = 0x2400
@staticmethod
def recv_error_info(f, protocol_version):
return {
'keyspace': read_string(f),
'table': read_string(f),
}
def to_exception(self):
return AlreadyExists(**self.info)
+class ClientWriteError(RequestExecutionException):
+ summary = 'Client write failure.'
+ error_code = 0x8000
+
+
class StartupMessage(_MessageType):
opcode = 0x01
name = 'STARTUP'
KNOWN_OPTION_KEYS = set((
'CQL_VERSION',
'COMPRESSION',
'NO_COMPACT'
))
def __init__(self, cqlversion, options):
self.cqlversion = cqlversion
self.options = options
def send_body(self, f, protocol_version):
optmap = self.options.copy()
optmap['CQL_VERSION'] = self.cqlversion
write_stringmap(f, optmap)
class ReadyMessage(_MessageType):
opcode = 0x02
name = 'READY'
@classmethod
def recv_body(cls, *args):
return cls()
class AuthenticateMessage(_MessageType):
opcode = 0x03
name = 'AUTHENTICATE'
def __init__(self, authenticator):
self.authenticator = authenticator
@classmethod
def recv_body(cls, f, *args):
authname = read_string(f)
return cls(authenticator=authname)
class CredentialsMessage(_MessageType):
opcode = 0x04
name = 'CREDENTIALS'
def __init__(self, creds):
self.creds = creds
def send_body(self, f, protocol_version):
if protocol_version > 1:
raise UnsupportedOperation(
"Credentials-based authentication is not supported with "
"protocol version 2 or higher. Use the SASL authentication "
"mechanism instead.")
write_short(f, len(self.creds))
for credkey, credval in self.creds.items():
write_string(f, credkey)
write_string(f, credval)
class AuthChallengeMessage(_MessageType):
opcode = 0x0E
name = 'AUTH_CHALLENGE'
def __init__(self, challenge):
self.challenge = challenge
@classmethod
def recv_body(cls, f, *args):
return cls(read_binary_longstring(f))
class AuthResponseMessage(_MessageType):
opcode = 0x0F
name = 'AUTH_RESPONSE'
def __init__(self, response):
self.response = response
def send_body(self, f, protocol_version):
write_longstring(f, self.response)
class AuthSuccessMessage(_MessageType):
opcode = 0x10
name = 'AUTH_SUCCESS'
def __init__(self, token):
self.token = token
@classmethod
def recv_body(cls, f, *args):
return cls(read_longstring(f))
class OptionsMessage(_MessageType):
opcode = 0x05
name = 'OPTIONS'
def send_body(self, f, protocol_version):
pass
class SupportedMessage(_MessageType):
opcode = 0x06
name = 'SUPPORTED'
def __init__(self, cql_versions, options):
self.cql_versions = cql_versions
self.options = options
@classmethod
def recv_body(cls, f, *args):
options = read_stringmultimap(f)
cql_versions = options.pop('CQL_VERSION')
return cls(cql_versions=cql_versions, options=options)
# used for QueryMessage and ExecuteMessage
_VALUES_FLAG = 0x01
_SKIP_METADATA_FLAG = 0x02
_PAGE_SIZE_FLAG = 0x04
_WITH_PAGING_STATE_FLAG = 0x08
_WITH_SERIAL_CONSISTENCY_FLAG = 0x10
-_PROTOCOL_TIMESTAMP = 0x20
+_PROTOCOL_TIMESTAMP_FLAG = 0x20
+_NAMES_FOR_VALUES_FLAG = 0x40 # not used here
_WITH_KEYSPACE_FLAG = 0x80
_PREPARED_WITH_KEYSPACE_FLAG = 0x01
+_PAGE_SIZE_BYTES_FLAG = 0x40000000
+_PAGING_OPTIONS_FLAG = 0x80000000
-class QueryMessage(_MessageType):
- opcode = 0x07
- name = 'QUERY'
+class _QueryMessage(_MessageType):
- def __init__(self, query, consistency_level, serial_consistency_level=None,
- fetch_size=None, paging_state=None, timestamp=None, keyspace=None):
- self.query = query
+ def __init__(self, query_params, consistency_level,
+ serial_consistency_level=None, fetch_size=None,
+ paging_state=None, timestamp=None, skip_meta=False,
+ continuous_paging_options=None, keyspace=None):
+ self.query_params = query_params
self.consistency_level = consistency_level
self.serial_consistency_level = serial_consistency_level
self.fetch_size = fetch_size
self.paging_state = paging_state
self.timestamp = timestamp
+ self.skip_meta = skip_meta
+ self.continuous_paging_options = continuous_paging_options
self.keyspace = keyspace
- self._query_params = None # only used internally. May be set to a list of native-encoded values to have them sent with the request.
- def send_body(self, f, protocol_version):
- write_longstring(f, self.query)
+ def _write_query_params(self, f, protocol_version):
write_consistency_level(f, self.consistency_level)
flags = 0x00
- if self._query_params is not None:
+ if self.query_params is not None:
flags |= _VALUES_FLAG # also v2+, but we're only setting params internally right now
if self.serial_consistency_level:
if protocol_version >= 2:
flags |= _WITH_SERIAL_CONSISTENCY_FLAG
else:
raise UnsupportedOperation(
"Serial consistency levels require the use of protocol version "
"2 or higher. Consider setting Cluster.protocol_version to 2 "
"to support serial consistency levels.")
if self.fetch_size:
if protocol_version >= 2:
flags |= _PAGE_SIZE_FLAG
else:
raise UnsupportedOperation(
"Automatic query paging may only be used with protocol version "
"2 or higher. Consider setting Cluster.protocol_version to 2.")
if self.paging_state:
if protocol_version >= 2:
flags |= _WITH_PAGING_STATE_FLAG
else:
raise UnsupportedOperation(
"Automatic query paging may only be used with protocol version "
"2 or higher. Consider setting Cluster.protocol_version to 2.")
if self.timestamp is not None:
- flags |= _PROTOCOL_TIMESTAMP
+ flags |= _PROTOCOL_TIMESTAMP_FLAG
+
+ if self.continuous_paging_options:
+ if ProtocolVersion.has_continuous_paging_support(protocol_version):
+ flags |= _PAGING_OPTIONS_FLAG
+ else:
+ raise UnsupportedOperation(
+ "Continuous paging may only be used with protocol version "
+ "ProtocolVersion.DSE_V1 or higher. Consider setting Cluster.protocol_version to ProtocolVersion.DSE_V1.")
if self.keyspace is not None:
if ProtocolVersion.uses_keyspace_flag(protocol_version):
flags |= _WITH_KEYSPACE_FLAG
else:
raise UnsupportedOperation(
"Keyspaces may only be set on queries with protocol version "
- "5 or higher. Consider setting Cluster.protocol_version to 5.")
+ "5 or DSE_V2 or higher. Consider setting Cluster.protocol_version.")
if ProtocolVersion.uses_int_query_flags(protocol_version):
write_uint(f, flags)
else:
write_byte(f, flags)
- if self._query_params is not None:
- write_short(f, len(self._query_params))
- for param in self._query_params:
+ if self.query_params is not None:
+ write_short(f, len(self.query_params))
+ for param in self.query_params:
write_value(f, param)
-
if self.fetch_size:
write_int(f, self.fetch_size)
if self.paging_state:
write_longstring(f, self.paging_state)
if self.serial_consistency_level:
write_consistency_level(f, self.serial_consistency_level)
if self.timestamp is not None:
write_long(f, self.timestamp)
if self.keyspace is not None:
write_string(f, self.keyspace)
+ if self.continuous_paging_options:
+ self._write_paging_options(f, self.continuous_paging_options, protocol_version)
+
+ def _write_paging_options(self, f, paging_options, protocol_version):
+ write_int(f, paging_options.max_pages)
+ write_int(f, paging_options.max_pages_per_second)
+ if ProtocolVersion.has_continuous_paging_next_pages(protocol_version):
+ write_int(f, paging_options.max_queue_size)
+
+
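# Illustrative flag arithmetic (plain integers, no driver API): a request sent
# with bound values, a fetch size, a serial consistency level and a client
# timestamp sets these bits in the query-flags field written by
# _write_query_params() above.
_example_query_flags = (_VALUES_FLAG | _PAGE_SIZE_FLAG |
                        _WITH_SERIAL_CONSISTENCY_FLAG | _PROTOCOL_TIMESTAMP_FLAG)
assert _example_query_flags == 0x35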
+class QueryMessage(_QueryMessage):
+ opcode = 0x07
+ name = 'QUERY'
+
+ def __init__(self, query, consistency_level, serial_consistency_level=None,
+ fetch_size=None, paging_state=None, timestamp=None, continuous_paging_options=None, keyspace=None):
+ self.query = query
+ super(QueryMessage, self).__init__(None, consistency_level, serial_consistency_level, fetch_size,
+ paging_state, timestamp, False, continuous_paging_options, keyspace)
+
+ def send_body(self, f, protocol_version):
+ write_longstring(f, self.query)
+ self._write_query_params(f, protocol_version)
+
+
+class ExecuteMessage(_QueryMessage):
+ opcode = 0x0A
+ name = 'EXECUTE'
+
+ def __init__(self, query_id, query_params, consistency_level,
+ serial_consistency_level=None, fetch_size=None,
+ paging_state=None, timestamp=None, skip_meta=False,
+ continuous_paging_options=None, result_metadata_id=None):
+ self.query_id = query_id
+ self.result_metadata_id = result_metadata_id
+ super(ExecuteMessage, self).__init__(query_params, consistency_level, serial_consistency_level, fetch_size,
+ paging_state, timestamp, skip_meta, continuous_paging_options)
+
+ def _write_query_params(self, f, protocol_version):
+ if protocol_version == 1:
+ if self.serial_consistency_level:
+ raise UnsupportedOperation(
+ "Serial consistency levels require the use of protocol version "
+ "2 or higher. Consider setting Cluster.protocol_version to 2 "
+ "to support serial consistency levels.")
+ if self.fetch_size or self.paging_state:
+ raise UnsupportedOperation(
+ "Automatic query paging may only be used with protocol version "
+ "2 or higher. Consider setting Cluster.protocol_version to 2.")
+ write_short(f, len(self.query_params))
+ for param in self.query_params:
+ write_value(f, param)
+ write_consistency_level(f, self.consistency_level)
+ else:
+ super(ExecuteMessage, self)._write_query_params(f, protocol_version)
+
+ def send_body(self, f, protocol_version):
+ write_string(f, self.query_id)
+ if ProtocolVersion.uses_prepared_metadata(protocol_version):
+ write_string(f, self.result_metadata_id)
+ self._write_query_params(f, protocol_version)
CUSTOM_TYPE = object()
RESULT_KIND_VOID = 0x0001
RESULT_KIND_ROWS = 0x0002
RESULT_KIND_SET_KEYSPACE = 0x0003
RESULT_KIND_PREPARED = 0x0004
RESULT_KIND_SCHEMA_CHANGE = 0x0005
class ResultMessage(_MessageType):
opcode = 0x08
name = 'RESULT'
kind = None
results = None
paging_state = None
# Names match type name in module scope. Most are imported from cassandra.cqltypes (except CUSTOM_TYPE)
type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in type_codes.__dict__.items() if not k.startswith('_'))
_FLAGS_GLOBAL_TABLES_SPEC = 0x0001
_HAS_MORE_PAGES_FLAG = 0x0002
_NO_METADATA_FLAG = 0x0004
+ _CONTINUOUS_PAGING_FLAG = 0x40000000
+ _CONTINUOUS_PAGING_LAST_FLAG = 0x80000000
_METADATA_ID_FLAG = 0x0008
- def __init__(self, kind, results, paging_state=None, col_types=None):
+ kind = None
+
+ # These are all the things a result message might contain. They are populated according to 'kind'
+ column_names = None
+ column_types = None
+ parsed_rows = None
+ paging_state = None
+ continuous_paging_seq = None
+ continuous_paging_last = None
+ new_keyspace = None
+ column_metadata = None
+ query_id = None
+ bind_metadata = None
+ pk_indexes = None
+ schema_change_event = None
+
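# Illustrative mapping (derived from recv() below): RESULT_KIND_ROWS fills
# column_names/column_types/parsed_rows (and possibly paging_state),
# RESULT_KIND_SET_KEYSPACE fills new_keyspace, RESULT_KIND_PREPARED fills
# query_id/bind_metadata/pk_indexes, and RESULT_KIND_SCHEMA_CHANGE fills
# schema_change_event.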
+ def __init__(self, kind):
self.kind = kind
- self.results = results
- self.paging_state = paging_state
- self.col_types = col_types
+
+ def recv(self, f, protocol_version, user_type_map, result_metadata):
+ if self.kind == RESULT_KIND_VOID:
+ return
+ elif self.kind == RESULT_KIND_ROWS:
+ self.recv_results_rows(f, protocol_version, user_type_map, result_metadata)
+ elif self.kind == RESULT_KIND_SET_KEYSPACE:
+ self.new_keyspace = read_string(f)
+ elif self.kind == RESULT_KIND_PREPARED:
+ self.recv_results_prepared(f, protocol_version, user_type_map)
+ elif self.kind == RESULT_KIND_SCHEMA_CHANGE:
+ self.recv_results_schema_change(f, protocol_version)
+ else:
+ raise DriverException("Unknown RESULT kind: %d" % self.kind)
@classmethod
def recv_body(cls, f, protocol_version, user_type_map, result_metadata):
kind = read_int(f)
- paging_state = None
- col_types = None
- if kind == RESULT_KIND_VOID:
- results = None
- elif kind == RESULT_KIND_ROWS:
- paging_state, col_types, results, result_metadata_id = cls.recv_results_rows(
- f, protocol_version, user_type_map, result_metadata)
- elif kind == RESULT_KIND_SET_KEYSPACE:
- ksname = read_string(f)
- results = ksname
- elif kind == RESULT_KIND_PREPARED:
- results = cls.recv_results_prepared(f, protocol_version, user_type_map)
- elif kind == RESULT_KIND_SCHEMA_CHANGE:
- results = cls.recv_results_schema_change(f, protocol_version)
- else:
- raise DriverException("Unknown RESULT kind: %d" % kind)
- return cls(kind, results, paging_state, col_types)
+ msg = cls(kind)
+ msg.recv(f, protocol_version, user_type_map, result_metadata)
+ return msg
- @classmethod
- def recv_results_rows(cls, f, protocol_version, user_type_map, result_metadata):
- paging_state, column_metadata, result_metadata_id = cls.recv_results_metadata(f, user_type_map)
- column_metadata = column_metadata or result_metadata
+ def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata):
+ self.recv_results_metadata(f, user_type_map)
+ column_metadata = self.column_metadata or result_metadata
rowcount = read_int(f)
- rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
- colnames = [c[2] for c in column_metadata]
- coltypes = [c[3] for c in column_metadata]
+ rows = [self.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
+ self.column_names = [c[2] for c in column_metadata]
+ self.column_types = [c[3] for c in column_metadata]
try:
- parsed_rows = [
+ self.parsed_rows = [
tuple(ctype.from_binary(val, protocol_version)
- for ctype, val in zip(coltypes, row))
+ for ctype, val in zip(self.column_types, row))
for row in rows]
except Exception:
for row in rows:
for i in range(len(row)):
try:
- coltypes[i].from_binary(row[i], protocol_version)
+ self.column_types[i].from_binary(row[i], protocol_version)
except Exception as e:
- raise DriverException('Failed decoding result column "%s" of type %s: %s' % (colnames[i],
- coltypes[i].cql_parameterized_type(),
+ raise DriverException('Failed decoding result column "%s" of type %s: %s' % (self.column_names[i],
+ self.column_types[i].cql_parameterized_type(),
str(e)))
- return paging_state, coltypes, (colnames, parsed_rows), result_metadata_id
- @classmethod
- def recv_results_prepared(cls, f, protocol_version, user_type_map):
- query_id = read_binary_string(f)
+ def recv_results_prepared(self, f, protocol_version, user_type_map):
+ self.query_id = read_binary_string(f)
if ProtocolVersion.uses_prepared_metadata(protocol_version):
- result_metadata_id = read_binary_string(f)
+ self.result_metadata_id = read_binary_string(f)
else:
- result_metadata_id = None
- bind_metadata, pk_indexes, result_metadata, _ = cls.recv_prepared_metadata(f, protocol_version, user_type_map)
- return query_id, bind_metadata, pk_indexes, result_metadata, result_metadata_id
+ self.result_metadata_id = None
+ self.recv_prepared_metadata(f, protocol_version, user_type_map)
- @classmethod
- def recv_results_metadata(cls, f, user_type_map):
+ def recv_results_metadata(self, f, user_type_map):
flags = read_int(f)
colcount = read_int(f)
- if flags & cls._HAS_MORE_PAGES_FLAG:
- paging_state = read_binary_longstring(f)
- else:
- paging_state = None
+ if flags & self._HAS_MORE_PAGES_FLAG:
+ self.paging_state = read_binary_longstring(f)
- if flags & cls._METADATA_ID_FLAG:
- result_metadata_id = read_binary_string(f)
- else:
- result_metadata_id = None
-
- no_meta = bool(flags & cls._NO_METADATA_FLAG)
+ no_meta = bool(flags & self._NO_METADATA_FLAG)
if no_meta:
- return paging_state, [], result_metadata_id
+ return
+
+ if flags & self._CONTINUOUS_PAGING_FLAG:
+ self.continuous_paging_seq = read_int(f)
+ self.continuous_paging_last = flags & self._CONTINUOUS_PAGING_LAST_FLAG
- glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC)
+ if flags & self._METADATA_ID_FLAG:
+ self.result_metadata_id = read_binary_string(f)
+
+ glob_tblspec = bool(flags & self._FLAGS_GLOBAL_TABLES_SPEC)
if glob_tblspec:
ksname = read_string(f)
cfname = read_string(f)
column_metadata = []
for _ in range(colcount):
if glob_tblspec:
colksname = ksname
colcfname = cfname
else:
colksname = read_string(f)
colcfname = read_string(f)
colname = read_string(f)
- coltype = cls.read_type(f, user_type_map)
+ coltype = self.read_type(f, user_type_map)
column_metadata.append((colksname, colcfname, colname, coltype))
- return paging_state, column_metadata, result_metadata_id
- @classmethod
- def recv_prepared_metadata(cls, f, protocol_version, user_type_map):
+ self.column_metadata = column_metadata
+
+ def recv_prepared_metadata(self, f, protocol_version, user_type_map):
flags = read_int(f)
colcount = read_int(f)
pk_indexes = None
if protocol_version >= 4:
num_pk_indexes = read_int(f)
pk_indexes = [read_short(f) for _ in range(num_pk_indexes)]
- glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC)
+ glob_tblspec = bool(flags & self._FLAGS_GLOBAL_TABLES_SPEC)
if glob_tblspec:
ksname = read_string(f)
cfname = read_string(f)
bind_metadata = []
for _ in range(colcount):
if glob_tblspec:
colksname = ksname
colcfname = cfname
else:
colksname = read_string(f)
colcfname = read_string(f)
colname = read_string(f)
- coltype = cls.read_type(f, user_type_map)
+ coltype = self.read_type(f, user_type_map)
bind_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype))
if protocol_version >= 2:
- _, result_metadata, result_metadata_id = cls.recv_results_metadata(f, user_type_map)
- return bind_metadata, pk_indexes, result_metadata, result_metadata_id
- else:
- return bind_metadata, pk_indexes, None, None
+ self.recv_results_metadata(f, user_type_map)
- @classmethod
- def recv_results_schema_change(cls, f, protocol_version):
- return EventMessage.recv_schema_change(f, protocol_version)
+ self.bind_metadata = bind_metadata
+ self.pk_indexes = pk_indexes
+
+ def recv_results_schema_change(self, f, protocol_version):
+ self.schema_change_event = EventMessage.recv_schema_change(f, protocol_version)
@classmethod
def read_type(cls, f, user_type_map):
optid = read_short(f)
try:
typeclass = cls.type_codes[optid]
except KeyError:
raise NotSupportedError("Unknown data type code 0x%04x. Have to skip"
" entire result set." % (optid,))
if typeclass in (ListType, SetType):
subtype = cls.read_type(f, user_type_map)
typeclass = typeclass.apply_parameters((subtype,))
elif typeclass == MapType:
keysubtype = cls.read_type(f, user_type_map)
valsubtype = cls.read_type(f, user_type_map)
typeclass = typeclass.apply_parameters((keysubtype, valsubtype))
elif typeclass == TupleType:
num_items = read_short(f)
types = tuple(cls.read_type(f, user_type_map) for _ in range(num_items))
typeclass = typeclass.apply_parameters(types)
elif typeclass == UserType:
ks = read_string(f)
udt_name = read_string(f)
num_fields = read_short(f)
names, types = zip(*((read_string(f), cls.read_type(f, user_type_map))
for _ in range(num_fields)))
specialized_type = typeclass.make_udt_class(ks, udt_name, names, types)
specialized_type.mapped_class = user_type_map.get(ks, {}).get(udt_name)
typeclass = specialized_type
elif typeclass == CUSTOM_TYPE:
classname = read_string(f)
typeclass = lookup_casstype(classname)
return typeclass
@staticmethod
def recv_row(f, colcount):
return [read_value(f) for _ in range(colcount)]
class PrepareMessage(_MessageType):
opcode = 0x09
name = 'PREPARE'
def __init__(self, query, keyspace=None):
self.query = query
self.keyspace = keyspace
def send_body(self, f, protocol_version):
write_longstring(f, self.query)
flags = 0x00
if self.keyspace is not None:
if ProtocolVersion.uses_keyspace_flag(protocol_version):
flags |= _PREPARED_WITH_KEYSPACE_FLAG
else:
raise UnsupportedOperation(
"Keyspaces may only be set on queries with protocol version "
- "5 or higher. Consider setting Cluster.protocol_version to 5.")
+ "5 or DSE_V2 or higher. Consider setting Cluster.protocol_version.")
if ProtocolVersion.uses_prepare_flags(protocol_version):
write_uint(f, flags)
else:
# checks above should prevent this, but just to be safe...
if flags:
raise UnsupportedOperation(
"Attempted to set flags with value {flags:0=#8x} on"
"protocol version {pv}, which doesn't support flags"
"in prepared statements."
- "Consider setting Cluster.protocol_version to 5."
+ "Consider setting Cluster.protocol_version to 5 or DSE_V2."
"".format(flags=flags, pv=protocol_version))
if ProtocolVersion.uses_keyspace_flag(protocol_version):
if self.keyspace:
write_string(f, self.keyspace)
-class ExecuteMessage(_MessageType):
- opcode = 0x0A
- name = 'EXECUTE'
- def __init__(self, query_id, query_params, consistency_level,
- serial_consistency_level=None, fetch_size=None,
- paging_state=None, timestamp=None, skip_meta=False,
- result_metadata_id=None):
- self.query_id = query_id
- self.query_params = query_params
- self.consistency_level = consistency_level
- self.serial_consistency_level = serial_consistency_level
- self.fetch_size = fetch_size
- self.paging_state = paging_state
- self.timestamp = timestamp
- self.skip_meta = skip_meta
- self.result_metadata_id = result_metadata_id
-
- def send_body(self, f, protocol_version):
- write_string(f, self.query_id)
- if ProtocolVersion.uses_prepared_metadata(protocol_version):
- write_string(f, self.result_metadata_id)
- if protocol_version == 1:
- if self.serial_consistency_level:
- raise UnsupportedOperation(
- "Serial consistency levels require the use of protocol version "
- "2 or higher. Consider setting Cluster.protocol_version to 2 "
- "to support serial consistency levels.")
- if self.fetch_size or self.paging_state:
- raise UnsupportedOperation(
- "Automatic query paging may only be used with protocol version "
- "2 or higher. Consider setting Cluster.protocol_version to 2.")
- write_short(f, len(self.query_params))
- for param in self.query_params:
- write_value(f, param)
- write_consistency_level(f, self.consistency_level)
- else:
- write_consistency_level(f, self.consistency_level)
- flags = _VALUES_FLAG
- if self.serial_consistency_level:
- flags |= _WITH_SERIAL_CONSISTENCY_FLAG
- if self.fetch_size:
- flags |= _PAGE_SIZE_FLAG
- if self.paging_state:
- flags |= _WITH_PAGING_STATE_FLAG
- if self.timestamp is not None:
- if protocol_version >= 3:
- flags |= _PROTOCOL_TIMESTAMP
- else:
- raise UnsupportedOperation(
- "Protocol-level timestamps may only be used with protocol version "
- "3 or higher. Consider setting Cluster.protocol_version to 3.")
- if self.skip_meta:
- flags |= _SKIP_METADATA_FLAG
-
- if ProtocolVersion.uses_int_query_flags(protocol_version):
- write_uint(f, flags)
- else:
- write_byte(f, flags)
-
- write_short(f, len(self.query_params))
- for param in self.query_params:
- write_value(f, param)
- if self.fetch_size:
- write_int(f, self.fetch_size)
- if self.paging_state:
- write_longstring(f, self.paging_state)
- if self.serial_consistency_level:
- write_consistency_level(f, self.serial_consistency_level)
- if self.timestamp is not None:
- write_long(f, self.timestamp)
-
-
-
class BatchMessage(_MessageType):
opcode = 0x0D
name = 'BATCH'
def __init__(self, batch_type, queries, consistency_level,
serial_consistency_level=None, timestamp=None,
keyspace=None):
self.batch_type = batch_type
self.queries = queries
self.consistency_level = consistency_level
self.serial_consistency_level = serial_consistency_level
self.timestamp = timestamp
self.keyspace = keyspace
def send_body(self, f, protocol_version):
write_byte(f, self.batch_type.value)
write_short(f, len(self.queries))
for prepared, string_or_query_id, params in self.queries:
if not prepared:
write_byte(f, 0)
write_longstring(f, string_or_query_id)
else:
write_byte(f, 1)
write_short(f, len(string_or_query_id))
f.write(string_or_query_id)
write_short(f, len(params))
for param in params:
write_value(f, param)
write_consistency_level(f, self.consistency_level)
if protocol_version >= 3:
flags = 0
if self.serial_consistency_level:
flags |= _WITH_SERIAL_CONSISTENCY_FLAG
if self.timestamp is not None:
- flags |= _PROTOCOL_TIMESTAMP
+ flags |= _PROTOCOL_TIMESTAMP_FLAG
if self.keyspace:
if ProtocolVersion.uses_keyspace_flag(protocol_version):
flags |= _WITH_KEYSPACE_FLAG
else:
raise UnsupportedOperation(
"Keyspaces may only be set on queries with protocol version "
"5 or higher. Consider setting Cluster.protocol_version to 5.")
if ProtocolVersion.uses_int_query_flags(protocol_version):
write_int(f, flags)
else:
write_byte(f, flags)
if self.serial_consistency_level:
write_consistency_level(f, self.serial_consistency_level)
if self.timestamp is not None:
write_long(f, self.timestamp)
if ProtocolVersion.uses_keyspace_flag(protocol_version):
if self.keyspace is not None:
write_string(f, self.keyspace)
known_event_types = frozenset((
'TOPOLOGY_CHANGE',
'STATUS_CHANGE',
'SCHEMA_CHANGE'
))
class RegisterMessage(_MessageType):
opcode = 0x0B
name = 'REGISTER'
def __init__(self, event_list):
self.event_list = event_list
def send_body(self, f, protocol_version):
write_stringlist(f, self.event_list)
class EventMessage(_MessageType):
opcode = 0x0C
name = 'EVENT'
def __init__(self, event_type, event_args):
self.event_type = event_type
self.event_args = event_args
@classmethod
def recv_body(cls, f, protocol_version, *args):
event_type = read_string(f).upper()
if event_type in known_event_types:
read_method = getattr(cls, 'recv_' + event_type.lower())
return cls(event_type=event_type, event_args=read_method(f, protocol_version))
raise NotSupportedError('Unknown event type %r' % event_type)
@classmethod
def recv_topology_change(cls, f, protocol_version):
# "NEW_NODE" or "REMOVED_NODE"
change_type = read_string(f)
address = read_inet(f)
return dict(change_type=change_type, address=address)
@classmethod
def recv_status_change(cls, f, protocol_version):
# "UP" or "DOWN"
change_type = read_string(f)
address = read_inet(f)
return dict(change_type=change_type, address=address)
@classmethod
def recv_schema_change(cls, f, protocol_version):
# "CREATED", "DROPPED", or "UPDATED"
change_type = read_string(f)
if protocol_version >= 3:
target = read_string(f)
keyspace = read_string(f)
event = {'target_type': target, 'change_type': change_type, 'keyspace': keyspace}
if target != SchemaTargetType.KEYSPACE:
target_name = read_string(f)
if target == SchemaTargetType.FUNCTION:
event['function'] = UserFunctionDescriptor(target_name, [read_string(f) for _ in range(read_short(f))])
elif target == SchemaTargetType.AGGREGATE:
event['aggregate'] = UserAggregateDescriptor(target_name, [read_string(f) for _ in range(read_short(f))])
else:
event[target.lower()] = target_name
else:
keyspace = read_string(f)
table = read_string(f)
if table:
event = {'target_type': SchemaTargetType.TABLE, 'change_type': change_type, 'keyspace': keyspace, 'table': table}
else:
event = {'target_type': SchemaTargetType.KEYSPACE, 'change_type': change_type, 'keyspace': keyspace}
return event
+class ReviseRequestMessage(_MessageType):
+
+ class RevisionType(object):
+ PAGING_CANCEL = 1
+ PAGING_BACKPRESSURE = 2
+
+ opcode = 0xFF
+ name = 'REVISE_REQUEST'
+
+ def __init__(self, op_type, op_id, next_pages=0):
+ self.op_type = op_type
+ self.op_id = op_id
+ self.next_pages = next_pages
+
+ def send_body(self, f, protocol_version):
+ write_int(f, self.op_type)
+ write_int(f, self.op_id)
+ if self.op_type == ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE:
+ if self.next_pages <= 0:
+ raise UnsupportedOperation("Continuous paging backpressure requires next_pages > 0")
+ elif not ProtocolVersion.has_continuous_paging_next_pages(protocol_version):
+ raise UnsupportedOperation(
+ "Continuous paging backpressure may only be used with protocol version "
+ "ProtocolVersion.DSE_V2 or higher. Consider setting Cluster.protocol_version to ProtocolVersion.DSE_V2.")
+ else:
+ write_int(f, self.next_pages)
+
+
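# Illustrative construction (the op_id value is an assumption): cancelling a
# continuous paging operation via the message defined above.
_example_cancel = ReviseRequestMessage(
    ReviseRequestMessage.RevisionType.PAGING_CANCEL, op_id=42)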
class _ProtocolHandler(object):
"""
_ProtocolHandler handles encoding and decoding messages.
This class can be specialized to compose Handlers which implement alternative
result decoding or type deserialization. Class definitions are passed to :class:`cassandra.cluster.Cluster`
on initialization.
Contracted class methods are :meth:`_ProtocolHandler.encode_message` and :meth:`_ProtocolHandler.decode_message`.
"""
message_types_by_opcode = _message_types_by_opcode.copy()
"""
Default mapping of opcode to Message implementation. The default ``decode_message`` implementation uses
this to instantiate a message and populate using ``recv_body``. This mapping can be updated to inject specialized
result decoding implementations.
"""
@classmethod
def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta_protocol_version):
"""
Encodes a message using the specified frame parameters, and compressor
:param msg: the message, typically of cassandra.protocol._MessageType, generated by the driver
:param stream_id: protocol stream id for the frame header
:param protocol_version: version for the frame header, and used encoding contents
:param compressor: optional compression function to be used on the body
"""
flags = 0
body = io.BytesIO()
if msg.custom_payload:
if protocol_version < 4:
raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
flags |= CUSTOM_PAYLOAD_FLAG
write_bytesmap(body, msg.custom_payload)
msg.send_body(body, protocol_version)
body = body.getvalue()
- if compressor and len(body) > 0:
+ # With checksumming, the compression is done at the segment frame encoding
+ if (not ProtocolVersion.has_checksumming_support(protocol_version)
+ and compressor and len(body) > 0):
body = compressor(body)
flags |= COMPRESSED_FLAG
if msg.tracing:
flags |= TRACING_FLAG
if allow_beta_protocol_version:
flags |= USE_BETA_FLAG
buff = io.BytesIO()
cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, len(body))
buff.write(body)
return buff.getvalue()
@staticmethod
def _write_header(f, version, flags, stream_id, opcode, length):
"""
Write a CQL protocol frame header.
"""
pack = v3_header_pack if version >= 3 else header_pack
f.write(pack(version, flags, stream_id, opcode))
write_int(f, length)
@classmethod
def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcode, body,
decompressor, result_metadata):
"""
Decodes a native protocol message body
:param protocol_version: version to use decoding contents
:param user_type_map: map[keyspace name] = map[type name] = custom type to instantiate when deserializing this type
:param stream_id: native protocol stream id from the frame header
:param flags: native protocol flags bitmap from the header
:param opcode: native protocol opcode from the header
:param body: frame body
:param decompressor: optional decompression function to inflate the body
:return: a message decoded from the body and frame attributes
"""
- if flags & COMPRESSED_FLAG:
+ if (not ProtocolVersion.has_checksumming_support(protocol_version) and
+ flags & COMPRESSED_FLAG):
if decompressor is None:
raise RuntimeError("No de-compressor available for compressed frame!")
body = decompressor(body)
flags ^= COMPRESSED_FLAG
body = io.BytesIO(body)
if flags & TRACING_FLAG:
trace_id = UUID(bytes=body.read(16))
flags ^= TRACING_FLAG
else:
trace_id = None
if flags & WARNING_FLAG:
warnings = read_stringlist(body)
flags ^= WARNING_FLAG
else:
warnings = None
if flags & CUSTOM_PAYLOAD_FLAG:
custom_payload = read_bytesmap(body)
flags ^= CUSTOM_PAYLOAD_FLAG
else:
custom_payload = None
- flags &= USE_BETA_MASK # will only be set if we asserted it in connection estabishment
+        flags &= USE_BETA_MASK  # will only be set if we asserted it in connection establishment
if flags:
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
msg_class = cls.message_types_by_opcode[opcode]
msg = msg_class.recv_body(body, protocol_version, user_type_map, result_metadata)
msg.stream_id = stream_id
msg.trace_id = trace_id
msg.custom_payload = custom_payload
msg.warnings = warnings
if msg.warnings:
for w in msg.warnings:
log.warning("Server warning: %s", w)
return msg
+
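# Illustrative sketches (hypothetical classes/values, not part of the driver).
# 1) Specializing result decoding by overriding the opcode mapping, the same
#    pattern cython_protocol_handler() uses below. The leading underscore keeps
#    the example subclass out of the global opcode registry populated by
#    _RegisterMessageType.
class _ExampleResultMessage(ResultMessage):
    """Stand-in ResultMessage subclass used only to illustrate specialization."""

class _ExampleProtocolHandler(_ProtocolHandler):
    message_types_by_opcode = _ProtocolHandler.message_types_by_opcode.copy()
    message_types_by_opcode[_ExampleResultMessage.opcode] = _ExampleResultMessage

# 2) Encoding a body-less OPTIONS request with encode_message() above: for
#    protocol v4 the frame is just the 9-byte header (no compression, no beta flag).
_example_frame = _ProtocolHandler.encode_message(
    OptionsMessage(), stream_id=0, protocol_version=4,
    compressor=None, allow_beta_protocol_version=False)
assert len(_example_frame) == 9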
def cython_protocol_handler(colparser):
"""
Given a column parser to deserialize ResultMessages, return a suitable
Cython-based protocol handler.
There are three Cython-based protocol handlers:
- obj_parser.ListParser
decodes result messages into a list of tuples
- obj_parser.LazyParser
decodes result messages lazily by returning an iterator
- numpy_parser.NumPyParser
decodes result messages into NumPy arrays
The default is to use obj_parser.ListParser
"""
from cassandra.row_parser import make_recv_results_rows
class FastResultMessage(ResultMessage):
"""
Cython version of Result Message that has a faster implementation of
recv_results_row.
"""
# type_codes = ResultMessage.type_codes.copy()
code_to_type = dict((v, k) for k, v in ResultMessage.type_codes.items())
- recv_results_rows = classmethod(make_recv_results_rows(colparser))
+ recv_results_rows = make_recv_results_rows(colparser)
class CythonProtocolHandler(_ProtocolHandler):
"""
Use FastResultMessage to decode query result message messages.
"""
my_opcodes = _ProtocolHandler.message_types_by_opcode.copy()
my_opcodes[FastResultMessage.opcode] = FastResultMessage
message_types_by_opcode = my_opcodes
col_parser = colparser
return CythonProtocolHandler
if HAVE_CYTHON:
from cassandra.obj_parser import ListParser, LazyParser
ProtocolHandler = cython_protocol_handler(ListParser())
LazyProtocolHandler = cython_protocol_handler(LazyParser())
else:
# Use Python-based ProtocolHandler
ProtocolHandler = _ProtocolHandler
LazyProtocolHandler = None
if HAVE_CYTHON and HAVE_NUMPY:
from cassandra.numpy_parser import NumpyParser
NumpyProtocolHandler = cython_protocol_handler(NumpyParser())
else:
NumpyProtocolHandler = None
def read_byte(f):
return int8_unpack(f.read(1))
def write_byte(f, b):
- f.write(int8_pack(b))
+ f.write(uint8_pack(b))
def read_int(f):
return int32_unpack(f.read(4))
+def read_uint_le(f, size=4):
+ """
+ Read a sequence of little endian bytes and return an unsigned integer.
+ """
+
+ if size == 4:
+ value = uint32_le_unpack(f.read(4))
+ else:
+ value = 0
+ for i in range(size):
+ value |= (read_byte(f) & 0xFF) << 8 * i
+
+ return value
+
+
+def write_uint_le(f, i, size=4):
+ """
+ Write an unsigned integer on a sequence of little endian bytes.
+ """
+ if size == 4:
+ f.write(uint32_le_pack(i))
+ else:
+ for j in range(size):
+ shift = j * 8
+ write_byte(f, i >> shift & 0xFF)
+
+
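# Illustrative round trip of the little-endian helpers above (io is imported
# at module level): 0x01020304 is written least-significant byte first.
def _example_uint_le_roundtrip():
    buf = io.BytesIO()
    write_uint_le(buf, 0x01020304)
    assert buf.getvalue() == b'\x04\x03\x02\x01'
    buf.seek(0)
    return read_uint_le(buf)  # returns 0x01020304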
def write_int(f, i):
f.write(int32_pack(i))
def write_uint(f, i):
f.write(uint32_pack(i))
def write_long(f, i):
f.write(uint64_pack(i))
def read_short(f):
return uint16_unpack(f.read(2))
def write_short(f, s):
f.write(uint16_pack(s))
def read_consistency_level(f):
return read_short(f)
def write_consistency_level(f, cl):
write_short(f, cl)
def read_string(f):
size = read_short(f)
contents = f.read(size)
return contents.decode('utf8')
def read_binary_string(f):
size = read_short(f)
contents = f.read(size)
return contents
def write_string(f, s):
if isinstance(s, six.text_type):
s = s.encode('utf8')
write_short(f, len(s))
f.write(s)
def read_binary_longstring(f):
size = read_int(f)
contents = f.read(size)
return contents
def read_longstring(f):
return read_binary_longstring(f).decode('utf8')
def write_longstring(f, s):
if isinstance(s, six.text_type):
s = s.encode('utf8')
write_int(f, len(s))
f.write(s)
def read_stringlist(f):
numstrs = read_short(f)
return [read_string(f) for _ in range(numstrs)]
def write_stringlist(f, stringlist):
write_short(f, len(stringlist))
for s in stringlist:
write_string(f, s)
def read_stringmap(f):
numpairs = read_short(f)
strmap = {}
for _ in range(numpairs):
k = read_string(f)
strmap[k] = read_string(f)
return strmap
def write_stringmap(f, strmap):
write_short(f, len(strmap))
for k, v in strmap.items():
write_string(f, k)
write_string(f, v)
def read_bytesmap(f):
numpairs = read_short(f)
bytesmap = {}
for _ in range(numpairs):
k = read_string(f)
bytesmap[k] = read_value(f)
return bytesmap
def write_bytesmap(f, bytesmap):
write_short(f, len(bytesmap))
for k, v in bytesmap.items():
write_string(f, k)
write_value(f, v)
def read_stringmultimap(f):
numkeys = read_short(f)
strmmap = {}
for _ in range(numkeys):
k = read_string(f)
strmmap[k] = read_stringlist(f)
return strmmap
def write_stringmultimap(f, strmmap):
write_short(f, len(strmmap))
for k, v in strmmap.items():
write_string(f, k)
write_stringlist(f, v)
def read_error_code_map(f):
numpairs = read_int(f)
error_code_map = {}
for _ in range(numpairs):
endpoint = read_inet_addr_only(f)
error_code_map[endpoint] = read_short(f)
return error_code_map
def read_value(f):
size = read_int(f)
if size < 0:
return None
return f.read(size)
def write_value(f, v):
if v is None:
write_int(f, -1)
elif v is _UNSET_VALUE:
write_int(f, -2)
else:
write_int(f, len(v))
f.write(v)
def read_inet_addr_only(f):
size = read_byte(f)
addrbytes = f.read(size)
if size == 4:
addrfam = socket.AF_INET
elif size == 16:
addrfam = socket.AF_INET6
else:
raise InternalError("bad inet address: %r" % (addrbytes,))
return util.inet_ntop(addrfam, addrbytes)
def read_inet(f):
addr = read_inet_addr_only(f)
port = read_int(f)
return (addr, port)
def write_inet(f, addrtuple):
addr, port = addrtuple
if ':' in addr:
addrfam = socket.AF_INET6
else:
addrfam = socket.AF_INET
addrbytes = util.inet_pton(addrfam, addr)
write_byte(f, len(addrbytes))
f.write(addrbytes)
write_int(f, port)
diff --git a/cassandra/query.py b/cassandra/query.py
index 74a9896..0e7a41d 100644
--- a/cassandra/query.py
+++ b/cassandra/query.py
@@ -1,1089 +1,1102 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module holds classes for working with prepared statements and
specifying consistency levels and retry policies for individual
queries.
"""
from collections import namedtuple
from datetime import datetime, timedelta
import re
import struct
import time
import six
from six.moves import range, zip
import warnings
from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.util import unix_time_from_uuid1
from cassandra.encoder import Encoder
import cassandra.encoder
from cassandra.protocol import _UNSET_VALUE
from cassandra.util import OrderedDict, _sanitize_identifiers
import logging
log = logging.getLogger(__name__)
UNSET_VALUE = _UNSET_VALUE
"""
Specifies an unset value when binding a prepared statement.
Unset values are ignored, allowing prepared statements to be used without specifying a value for every bind parameter.
See https://issues.apache.org/jira/browse/CASSANDRA-7304 for further details on semantics.
.. versionadded:: 2.6.0
Only valid when using native protocol v4+
"""
NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]')
START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*')
END_BADCHAR_REGEX = re.compile('[^a-zA-Z0-9_]*$')
_clean_name_cache = {}
def _clean_column_name(name):
try:
return _clean_name_cache[name]
except KeyError:
clean = NON_ALPHA_REGEX.sub("_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name)))
_clean_name_cache[name] = clean
return clean
def tuple_factory(colnames, rows):
"""
Returns each row as a tuple
Example::
>>> from cassandra.query import tuple_factory
>>> session = cluster.connect('mykeyspace')
>>> session.row_factory = tuple_factory
>>> rows = session.execute("SELECT name, age FROM users LIMIT 1")
>>> print rows[0]
('Bob', 42)
.. versionchanged:: 2.0.0
moved from ``cassandra.decoder`` to ``cassandra.query``
"""
return rows
class PseudoNamedTupleRow(object):
"""
Helper class for pseudo_named_tuple_factory. These objects provide an
__iter__ interface, as well as index- and attribute-based access to values,
but otherwise do not attempt to implement the full namedtuple or iterable
interface.
"""
def __init__(self, ordered_dict):
self._dict = ordered_dict
self._tuple = tuple(ordered_dict.values())
def __getattr__(self, name):
return self._dict[name]
def __getitem__(self, idx):
return self._tuple[idx]
def __iter__(self):
return iter(self._tuple)
def __repr__(self):
return '{t}({od})'.format(t=self.__class__.__name__,
od=self._dict)
def pseudo_namedtuple_factory(colnames, rows):
"""
Returns each row as a :class:`.PseudoNamedTupleRow`. This is the fallback
factory for cases where :meth:`.named_tuple_factory` fails to create rows.
"""
return [PseudoNamedTupleRow(od)
for od in ordered_dict_factory(colnames, rows)]
def named_tuple_factory(colnames, rows):
"""
Returns each row as a namedtuple.
This is the default row factory.
Example::
>>> from cassandra.query import named_tuple_factory
>>> session = cluster.connect('mykeyspace')
>>> session.row_factory = named_tuple_factory
>>> rows = session.execute("SELECT name, age FROM users LIMIT 1")
>>> user = rows[0]
>>> # you can access field by their name:
>>> print "name: %s, age: %d" % (user.name, user.age)
name: Bob, age: 42
>>> # or you can access fields by their position (like a tuple)
>>> name, age = user
>>> print "name: %s, age: %d" % (name, age)
name: Bob, age: 42
>>> name = user[0]
>>> age = user[1]
>>> print "name: %s, age: %d" % (name, age)
name: Bob, age: 42
.. versionchanged:: 2.0.0
moved from ``cassandra.decoder`` to ``cassandra.query``
"""
clean_column_names = map(_clean_column_name, colnames)
try:
Row = namedtuple('Row', clean_column_names)
except SyntaxError:
warnings.warn(
"Failed creating namedtuple for a result because there were too "
"many columns. This is due to a Python limitation that affects "
"namedtuple in Python 3.0-3.6 (see issue18896). The row will be "
"created with {substitute_factory_name}, which lacks some namedtuple "
"features and is slower. To avoid slower performance accessing "
"values on row objects, Upgrade to Python 3.7, or use a different "
"row factory. (column names: {colnames})".format(
substitute_factory_name=pseudo_namedtuple_factory.__name__,
colnames=colnames
)
)
return pseudo_namedtuple_factory(colnames, rows)
except Exception:
clean_column_names = list(map(_clean_column_name, colnames)) # create list because py3 map object will be consumed by first attempt
log.warning("Failed creating named tuple for results with column names %s (cleaned: %s) "
"(see Python 'namedtuple' documentation for details on name rules). "
"Results will be returned with positional names. "
"Avoid this by choosing different names, using SELECT \"
\" AS aliases, "
"or specifying a different row_factory on your Session" %
(colnames, clean_column_names))
Row = namedtuple('Row', _sanitize_identifiers(clean_column_names))
return [Row(*row) for row in rows]
def dict_factory(colnames, rows):
"""
Returns each row as a dict.
Example::
>>> from cassandra.query import dict_factory
>>> session = cluster.connect('mykeyspace')
>>> session.row_factory = dict_factory
>>> rows = session.execute("SELECT name, age FROM users LIMIT 1")
>>> print rows[0]
{u'age': 42, u'name': u'Bob'}
.. versionchanged:: 2.0.0
moved from ``cassandra.decoder`` to ``cassandra.query``
"""
return [dict(zip(colnames, row)) for row in rows]
def ordered_dict_factory(colnames, rows):
"""
Like :meth:`~cassandra.query.dict_factory`, but returns each row as an OrderedDict,
so the order of the columns is preserved.
.. versionchanged:: 2.0.0
moved from ``cassandra.decoder`` to ``cassandra.query``
"""
return [OrderedDict(zip(colnames, row)) for row in rows]
FETCH_SIZE_UNSET = object()
class Statement(object):
"""
An abstract class representing a single query. There are three subclasses:
:class:`.SimpleStatement`, :class:`.BoundStatement`, and :class:`.BatchStatement`.
These can be passed to :meth:`.Session.execute()`.
"""
retry_policy = None
"""
An instance of a :class:`cassandra.policies.RetryPolicy` or one of its
subclasses. This controls when a query will be retried and how it
will be retried.
"""
consistency_level = None
"""
The :class:`.ConsistencyLevel` to be used for this operation. Defaults
to :const:`None`, which means that the default consistency level for
the Session this is executed in will be used.
"""
fetch_size = FETCH_SIZE_UNSET
"""
How many rows will be fetched at a time. This overrides the default
of :attr:`.Session.default_fetch_size`
This only takes effect when protocol version 2 or higher is used.
See :attr:`.Cluster.protocol_version` for details.
.. versionadded:: 2.0.0
"""
keyspace = None
"""
The string name of the keyspace this query acts on. This is used when
- :class:`~.TokenAwarePolicy` is configured for
- :attr:`.Cluster.load_balancing_policy`
+ :class:`~.TokenAwarePolicy` is configured in the profile load balancing policy.
It is set implicitly on :class:`.BoundStatement`, and :class:`.BatchStatement`,
but must be set explicitly on :class:`.SimpleStatement`.
.. versionadded:: 2.1.3
"""
custom_payload = None
"""
:ref:`custom_payload` to be passed to the server.
These are only allowed when using protocol version 4 or higher.
.. versionadded:: 2.6.0
"""
is_idempotent = False
"""
Flag indicating whether this statement is safe to run multiple times in speculative execution.
"""
_serial_consistency_level = None
_routing_key = None
def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None,
is_idempotent=False):
if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors
raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy')
if retry_policy is not None:
self.retry_policy = retry_policy
if consistency_level is not None:
self.consistency_level = consistency_level
self._routing_key = routing_key
if serial_consistency_level is not None:
self.serial_consistency_level = serial_consistency_level
if fetch_size is not FETCH_SIZE_UNSET:
self.fetch_size = fetch_size
if keyspace is not None:
self.keyspace = keyspace
if custom_payload is not None:
self.custom_payload = custom_payload
self.is_idempotent = is_idempotent
def _key_parts_packed(self, parts):
for p in parts:
l = len(p)
yield struct.pack(">H%dsB" % l, l, p, 0)
def _get_routing_key(self):
return self._routing_key
def _set_routing_key(self, key):
if isinstance(key, (list, tuple)):
if len(key) == 1:
self._routing_key = key[0]
else:
self._routing_key = b"".join(self._key_parts_packed(key))
else:
self._routing_key = key
def _del_routing_key(self):
self._routing_key = None
routing_key = property(
_get_routing_key,
_set_routing_key,
_del_routing_key,
"""
The :attr:`~.TableMetadata.partition_key` portion of the primary key,
which can be used to determine which nodes are replicas for the query.
If the partition key is a composite, a list or tuple must be passed in.
Each key component should be in its packed (binary) format, so all
components should be strings.
""")
def _get_serial_consistency_level(self):
return self._serial_consistency_level
def _set_serial_consistency_level(self, serial_consistency_level):
if (serial_consistency_level is not None and
not ConsistencyLevel.is_serial(serial_consistency_level)):
raise ValueError(
"serial_consistency_level must be either ConsistencyLevel.SERIAL "
"or ConsistencyLevel.LOCAL_SERIAL")
self._serial_consistency_level = serial_consistency_level
def _del_serial_consistency_level(self):
self._serial_consistency_level = None
serial_consistency_level = property(
_get_serial_consistency_level,
_set_serial_consistency_level,
_del_serial_consistency_level,
"""
The serial consistency level is only used by conditional updates
(``INSERT``, ``UPDATE`` and ``DELETE`` with an ``IF`` condition). For
those, the ``serial_consistency_level`` defines the consistency level of
the serial phase (or "paxos" phase) while the normal
:attr:`~.consistency_level` defines the consistency for the "learn" phase,
i.e. what type of reads will be guaranteed to see the update right away.
For example, if a conditional write has a :attr:`~.consistency_level` of
:attr:`~.ConsistencyLevel.QUORUM` (and is successful), then a
:attr:`~.ConsistencyLevel.QUORUM` read is guaranteed to see that write.
But if the regular :attr:`~.consistency_level` of that write is
:attr:`~.ConsistencyLevel.ANY`, then only a read with a
:attr:`~.consistency_level` of :attr:`~.ConsistencyLevel.SERIAL` is
guaranteed to see it (even a read with consistency
:attr:`~.ConsistencyLevel.ALL` is not guaranteed to be enough).
The serial consistency can only be one of :attr:`~.ConsistencyLevel.SERIAL`
or :attr:`~.ConsistencyLevel.LOCAL_SERIAL`. While ``SERIAL`` guarantees full
linearizability (with other ``SERIAL`` updates), ``LOCAL_SERIAL`` only
guarantees it in the local data center.
The serial consistency level is ignored for any query that is not a
conditional update. Serial reads should use the regular
:attr:`consistency_level`.
Serial consistency levels may only be used against Cassandra 2.0+
and the :attr:`~.Cluster.protocol_version` must be set to 2 or higher.
See :doc:`/lwt` for a discussion on how to work with results returned from
conditional statements.
.. versionadded:: 2.0.0
""")
class SimpleStatement(Statement):
"""
A simple, un-prepared query.
"""
def __init__(self, query_string, retry_policy=None, consistency_level=None, routing_key=None,
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None,
custom_payload=None, is_idempotent=False):
"""
`query_string` should be a literal CQL statement with the exception
of parameter placeholders that will be filled through the
`parameters` argument of :meth:`.Session.execute()`.
See :class:`Statement` attributes for a description of the other parameters.
"""
Statement.__init__(self, retry_policy, consistency_level, routing_key,
serial_consistency_level, fetch_size, keyspace, custom_payload, is_idempotent)
self._query_string = query_string
@property
def query_string(self):
return self._query_string
def __str__(self):
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
        return (u'<SimpleStatement query="%s", consistency=%s>' %
                (self.query_string, consistency))
__repr__ = __str__
class PreparedStatement(object):
"""
A statement that has been prepared against at least one Cassandra node.
Instances of this class should not be created directly, but through
:meth:`.Session.prepare()`.
A :class:`.PreparedStatement` should be prepared only once. Re-preparing a statement
may affect performance (as the operation requires a network roundtrip).
|prepared_stmt_head|: Do not use ``*`` in prepared statements if you might
change the schema of the table being queried. The driver and server each
maintain a map between metadata for a schema and statements that were
prepared against that schema. When a user changes a schema, e.g. by adding
or removing a column, the server invalidates its mappings involving that
schema. However, there is currently no way to propagate that invalidation
to drivers. Thus, after a schema change, the driver will incorrectly
interpret the results of ``SELECT *`` queries prepared before the schema
change. This is currently being addressed in `CASSANDRA-10786
    <https://issues.apache.org/jira/browse/CASSANDRA-10786>`_.
.. |prepared_stmt_head| raw:: html
      <b>A note about <code>*</code> in prepared statements</b>
"""
column_metadata = None #TODO: make this bind_metadata in next major
retry_policy = None
consistency_level = None
custom_payload = None
fetch_size = FETCH_SIZE_UNSET
keyspace = None # change to prepared_keyspace in major release
protocol_version = None
query_id = None
query_string = None
result_metadata = None
result_metadata_id = None
routing_key_indexes = None
_routing_key_index_set = None
serial_consistency_level = None # TODO never used?
def __init__(self, column_metadata, query_id, routing_key_indexes, query,
keyspace, protocol_version, result_metadata, result_metadata_id):
self.column_metadata = column_metadata
self.query_id = query_id
self.routing_key_indexes = routing_key_indexes
self.query_string = query
self.keyspace = keyspace
self.protocol_version = protocol_version
self.result_metadata = result_metadata
self.result_metadata_id = result_metadata_id
self.is_idempotent = False
@classmethod
def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
query, prepared_keyspace, protocol_version, result_metadata,
result_metadata_id):
if not column_metadata:
return PreparedStatement(column_metadata, query_id, None, query, prepared_keyspace, protocol_version, result_metadata, result_metadata_id)
if pk_indexes:
routing_key_indexes = pk_indexes
else:
routing_key_indexes = None
first_col = column_metadata[0]
ks_meta = cluster_metadata.keyspaces.get(first_col.keyspace_name)
if ks_meta:
table_meta = ks_meta.tables.get(first_col.table_name)
if table_meta:
partition_key_columns = table_meta.partition_key
# make a map of {column_name: index} for each column in the statement
statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata))
# a list of which indexes in the statement correspond to partition key items
try:
routing_key_indexes = [statement_indexes[c.name]
for c in partition_key_columns]
except KeyError: # we're missing a partition key component in the prepared
pass # statement; just leave routing_key_indexes as None
return PreparedStatement(column_metadata, query_id, routing_key_indexes,
query, prepared_keyspace, protocol_version, result_metadata,
result_metadata_id)
def bind(self, values):
"""
Creates and returns a :class:`BoundStatement` instance using `values`.
See :meth:`BoundStatement.bind` for rules on input ``values``.
"""
return BoundStatement(self).bind(values)
def is_routing_key_index(self, i):
if self._routing_key_index_set is None:
self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set()
return i in self._routing_key_index_set
def __str__(self):
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
        return (u'<PreparedStatement query="%s", consistency=%s>' %
                (self.query_string, consistency))
__repr__ = __str__
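A typical pattern is to prepare a statement once at startup and bind it per request; a brief sketch, assuming a connected ``session`` and a hypothetical ``users`` table:
.. code-block:: python

    from cassandra import ConsistencyLevel

    # one network round trip; reuse the resulting PreparedStatement many times
    insert_user = session.prepare("INSERT INTO users (name, age) VALUES (?, ?)")

    # shortcut: execute the prepared statement directly with a value sequence
    session.execute(insert_user, ("alice", 30))

    # explicit bind() returns a BoundStatement that can be customized per call
    bound = insert_user.bind({"name": "bob", "age": 25})
    bound.consistency_level = ConsistencyLevel.LOCAL_QUORUM
    session.execute(bound)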
class BoundStatement(Statement):
"""
A prepared statement that has been bound to a particular set of values.
These may be created directly or through :meth:`.PreparedStatement.bind()`.
"""
prepared_statement = None
"""
The :class:`PreparedStatement` instance that this was created from.
"""
values = None
"""
The sequence of values that were bound to the prepared statement.
"""
def __init__(self, prepared_statement, retry_policy=None, consistency_level=None, routing_key=None,
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None,
custom_payload=None):
"""
`prepared_statement` should be an instance of :class:`PreparedStatement`.
See :class:`Statement` attributes for a description of the other parameters.
"""
self.prepared_statement = prepared_statement
self.retry_policy = prepared_statement.retry_policy
self.consistency_level = prepared_statement.consistency_level
self.serial_consistency_level = prepared_statement.serial_consistency_level
self.fetch_size = prepared_statement.fetch_size
self.custom_payload = prepared_statement.custom_payload
self.is_idempotent = prepared_statement.is_idempotent
self.values = []
meta = prepared_statement.column_metadata
if meta:
self.keyspace = meta[0].keyspace_name
Statement.__init__(self, retry_policy, consistency_level, routing_key,
serial_consistency_level, fetch_size, keyspace, custom_payload,
prepared_statement.is_idempotent)
def bind(self, values):
"""
Binds a sequence of values for the prepared statement parameters
and returns this instance. Note that `values` *must* be:
* a sequence, even if you are only binding one value, or
* a dict that relates 1-to-1 between dict keys and columns
.. versionchanged:: 2.6.0
:data:`~.UNSET_VALUE` was introduced. These can be bound as positional parameters
in a sequence, or by name in a dict. Additionally, when using protocol v4+:
* short sequences will be extended to match bind parameters with UNSET_VALUE
* names may be omitted from a dict with UNSET_VALUE implied.
.. versionchanged:: 3.0.0
method will not throw if extra keys are present in bound dict (PYTHON-178)
"""
if values is None:
values = ()
proto_version = self.prepared_statement.protocol_version
col_meta = self.prepared_statement.column_metadata
# special case for binding dicts
if isinstance(values, dict):
values_dict = values
values = []
# sort values accordingly
for col in col_meta:
try:
values.append(values_dict[col.name])
except KeyError:
if proto_version >= 4:
values.append(UNSET_VALUE)
else:
raise KeyError(
'Column name `%s` not found in bound dict.' %
(col.name))
value_len = len(values)
col_meta_len = len(col_meta)
if value_len > col_meta_len:
raise ValueError(
"Too many arguments provided to bind() (got %d, expected %d)" %
(len(values), len(col_meta)))
# this is fail-fast for clarity pre-v4. When v4 can be assumed,
# the error will be better reported when UNSET_VALUE is implicitly added.
if proto_version < 4 and self.prepared_statement.routing_key_indexes and \
value_len < len(self.prepared_statement.routing_key_indexes):
raise ValueError(
"Too few arguments provided to bind() (got %d, required %d for routing key)" %
(value_len, len(self.prepared_statement.routing_key_indexes)))
self.raw_values = values
self.values = []
for value, col_spec in zip(values, col_meta):
if value is None:
self.values.append(None)
elif value is UNSET_VALUE:
if proto_version >= 4:
self._append_unset_value()
else:
raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version)
else:
try:
self.values.append(col_spec.type.serialize(value, proto_version))
except (TypeError, struct.error) as exc:
actual_type = type(value)
message = ('Received an argument of invalid type for column "%s". '
'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc))
raise TypeError(message)
if proto_version >= 4:
diff = col_meta_len - len(self.values)
if diff:
for _ in range(diff):
self._append_unset_value()
return self
def _append_unset_value(self):
next_index = len(self.values)
if self.prepared_statement.is_routing_key_index(next_index):
col_meta = self.prepared_statement.column_metadata[next_index]
raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name)
self.values.append(UNSET_VALUE)
@property
def routing_key(self):
if not self.prepared_statement.routing_key_indexes:
return None
if self._routing_key is not None:
return self._routing_key
routing_indexes = self.prepared_statement.routing_key_indexes
if len(routing_indexes) == 1:
self._routing_key = self.values[routing_indexes[0]]
else:
self._routing_key = b"".join(self._key_parts_packed(self.values[i] for i in routing_indexes))
return self._routing_key
def __str__(self):
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
        return (u'<BoundStatement query="%s", values=%s, consistency=%s>' %
                (self.prepared_statement.query_string, self.raw_values, consistency))
__repr__ = __str__
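On protocol v4 and higher, :data:`~.UNSET_VALUE` lets a bound statement skip a column entirely instead of overwriting it with ``null`` (and creating a tombstone). A sketch, assuming the same hypothetical ``users`` table with an extra ``email`` column:
.. code-block:: python

    from cassandra.query import UNSET_VALUE

    update = session.prepare("UPDATE users SET age = ?, email = ? WHERE name = ?")

    # positional: leave `email` untouched for this row
    session.execute(update, (31, UNSET_VALUE, "alice"))

    # dict binding: names omitted from the dict are implicitly UNSET_VALUE
    session.execute(update.bind({"age": 32, "name": "bob"}))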
class BatchType(object):
"""
A BatchType is used with :class:`.BatchStatement` instances to control
the atomicity of the batch operation.
.. versionadded:: 2.0.0
"""
LOGGED = None
"""
Atomic batch operation.
"""
UNLOGGED = None
"""
Non-atomic batch operation.
"""
COUNTER = None
"""
Batches of counter operations.
"""
def __init__(self, name, value):
self.name = name
self.value = value
def __str__(self):
return self.name
def __repr__(self):
return "BatchType.%s" % (self.name, )
BatchType.LOGGED = BatchType("LOGGED", 0)
BatchType.UNLOGGED = BatchType("UNLOGGED", 1)
BatchType.COUNTER = BatchType("COUNTER", 2)
class BatchStatement(Statement):
"""
A protocol-level batch of operations which are applied atomically
by default.
.. versionadded:: 2.0.0
"""
batch_type = None
"""
The :class:`.BatchType` for the batch operation. Defaults to
:attr:`.BatchType.LOGGED`.
"""
serial_consistency_level = None
"""
The same as :attr:`.Statement.serial_consistency_level`, but is only
supported when using protocol version 3 or higher.
"""
_statements_and_parameters = None
_session = None
def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
consistency_level=None, serial_consistency_level=None,
session=None, custom_payload=None):
"""
`batch_type` specifies The :class:`.BatchType` for the batch operation.
Defaults to :attr:`.BatchType.LOGGED`.
`retry_policy` should be a :class:`~.RetryPolicy` instance for
controlling retries on the operation.
`consistency_level` should be a :class:`~.ConsistencyLevel` value
to be used for all operations in the batch.
`custom_payload` is a :ref:`custom_payload` passed to the server.
Note: as Statement objects are added to the batch, this map is
updated with any values found in their custom payloads. These are
only allowed when using protocol version 4 or higher.
Example usage:
.. code-block:: python
insert_user = session.prepare("INSERT INTO users (name, age) VALUES (?, ?)")
batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM)
for (name, age) in users_to_insert:
batch.add(insert_user, (name, age))
session.execute(batch)
You can also mix different types of operations within a batch:
.. code-block:: python
batch = BatchStatement()
batch.add(SimpleStatement("INSERT INTO users (name, age) VALUES (%s, %s)"), (name, age))
batch.add(SimpleStatement("DELETE FROM pending_users WHERE name=%s"), (name,))
session.execute(batch)
.. versionadded:: 2.0.0
.. versionchanged:: 2.1.0
Added `serial_consistency_level` as a parameter
.. versionchanged:: 2.6.0
Added `custom_payload` as a parameter
"""
self.batch_type = batch_type
self._statements_and_parameters = []
self._session = session
Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level,
serial_consistency_level=serial_consistency_level, custom_payload=custom_payload)
def clear(self):
"""
This is a convenience method to clear a batch statement for reuse.
*Note:* it should not be used concurrently with uncompleted execution futures executing the same
``BatchStatement``.
"""
del self._statements_and_parameters[:]
self.keyspace = None
self.routing_key = None
if self.custom_payload:
self.custom_payload.clear()
def add(self, statement, parameters=None):
"""
Adds a :class:`.Statement` and optional sequence of parameters
to be used with the statement to the batch.
Like with other statements, parameters must be a sequence, even
if there is only one item.
"""
if isinstance(statement, six.string_types):
if parameters:
encoder = Encoder() if self._session is None else self._session.encoder
statement = bind_params(statement, parameters, encoder)
self._add_statement_and_params(False, statement, ())
elif isinstance(statement, PreparedStatement):
query_id = statement.query_id
bound_statement = statement.bind(() if parameters is None else parameters)
self._update_state(bound_statement)
self._add_statement_and_params(True, query_id, bound_statement.values)
elif isinstance(statement, BoundStatement):
if parameters:
raise ValueError(
"Parameters cannot be passed with a BoundStatement "
"to BatchStatement.add()")
self._update_state(statement)
self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values)
else:
# it must be a SimpleStatement
query_string = statement.query_string
if parameters:
encoder = Encoder() if self._session is None else self._session.encoder
query_string = bind_params(query_string, parameters, encoder)
self._update_state(statement)
self._add_statement_and_params(False, query_string, ())
return self
def add_all(self, statements, parameters):
"""
Adds a sequence of :class:`.Statement` objects and a matching sequence
of parameters to the batch. Statement and parameter sequences must be of equal length or
        one will be truncated. :const:`None` can be used in the parameters position where no parameters are needed.
"""
for statement, value in zip(statements, parameters):
self.add(statement, value)
def _add_statement_and_params(self, is_prepared, statement, parameters):
if len(self._statements_and_parameters) >= 0xFFFF:
raise ValueError("Batch statement cannot contain more than %d statements." % 0xFFFF)
self._statements_and_parameters.append((is_prepared, statement, parameters))
def _maybe_set_routing_attributes(self, statement):
if self.routing_key is None:
if statement.keyspace and statement.routing_key:
self.routing_key = statement.routing_key
self.keyspace = statement.keyspace
def _update_custom_payload(self, statement):
if statement.custom_payload:
if self.custom_payload is None:
self.custom_payload = {}
self.custom_payload.update(statement.custom_payload)
def _update_state(self, statement):
self._maybe_set_routing_attributes(statement)
self._update_custom_payload(statement)
def __len__(self):
return len(self._statements_and_parameters)
def __str__(self):
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
        return (u'<BatchStatement type=%s, statements=%d, consistency=%s>' %
                (self.batch_type, len(self), consistency))
__repr__ = __str__
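Counter columns require their own batch type, and a batch instance may be reused after :meth:`~.BatchStatement.clear` once its execution has completed; a sketch, assuming a connected ``session`` and a hypothetical ``page_views`` counter table:
.. code-block:: python

    from cassandra.query import BatchStatement, BatchType

    batch = BatchStatement(batch_type=BatchType.COUNTER)
    for page in ("/home", "/docs"):
        batch.add("UPDATE page_views SET views = views + 1 WHERE page = %s", (page,))
    session.execute(batch)

    batch.clear()  # reuse the same instance for the next group of increments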
ValueSequence = cassandra.encoder.ValueSequence
"""
A wrapper class that is used to specify that a sequence of values should
be treated as a CQL list of values instead of a single column collection when used
as part of the `parameters` argument for :meth:`.Session.execute()`.
This is typically needed when supplying a list of keys to select.
For example::
>>> my_user_ids = ('alice', 'bob', 'charles')
>>> query = "SELECT * FROM users WHERE user_id IN %s"
>>> session.execute(query, parameters=[ValueSequence(my_user_ids)])
"""
def bind_params(query, params, encoder):
if six.PY2 and isinstance(query, six.text_type):
query = query.encode('utf-8')
if isinstance(params, dict):
return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in six.iteritems(params))
else:
return query % tuple(encoder.cql_encode_all_types(v) for v in params)
class TraceUnavailable(Exception):
"""
Raised when complete trace details cannot be fetched from Cassandra.
"""
pass
class QueryTrace(object):
"""
A trace of the duration and events that occurred when executing
an operation.
"""
trace_id = None
"""
:class:`uuid.UUID` unique identifier for this tracing session. Matches
the ``session_id`` column in ``system_traces.sessions`` and
``system_traces.events``.
"""
request_type = None
"""
A string that very generally describes the traced operation.
"""
duration = None
"""
A :class:`datetime.timedelta` measure of the duration of the query.
"""
client = None
"""
    The IP address of the client that issued this request.
    This is only available when using Cassandra 2.2+.
"""
coordinator = None
"""
The IP address of the host that acted as coordinator for this request.
"""
parameters = None
"""
A :class:`dict` of parameters for the traced operation, such as the
specific query string.
"""
started_at = None
"""
A UTC :class:`datetime.datetime` object describing when the operation
was started.
"""
events = None
"""
A chronologically sorted list of :class:`.TraceEvent` instances
representing the steps the traced operation went through. This
corresponds to the rows in ``system_traces.events`` for this tracing
session.
"""
_session = None
_SELECT_SESSIONS_FORMAT = "SELECT * FROM system_traces.sessions WHERE session_id = %s"
_SELECT_EVENTS_FORMAT = "SELECT * FROM system_traces.events WHERE session_id = %s"
_BASE_RETRY_SLEEP = 0.003
def __init__(self, trace_id, session):
self.trace_id = trace_id
self._session = session
def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None):
"""
Retrieves the actual tracing details from Cassandra and populates the
attributes of this instance. Because tracing details are stored
asynchronously by Cassandra, this may need to retry the session
detail fetch. If the trace is still not available after `max_wait`
seconds, :exc:`.TraceUnavailable` will be raised; if `max_wait` is
:const:`None`, this will retry forever.
`wait_for_complete=False` bypasses the wait for duration to be populated.
This can be used to query events from partial sessions.
`query_cl` specifies a consistency level to use for polling the trace tables,
if it should be different than the session default.
"""
attempt = 0
start = time.time()
while True:
time_spent = time.time() - start
if max_wait is not None and time_spent >= max_wait:
raise TraceUnavailable(
"Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." % (max_wait,))
log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id)
session_results = self._execute(
SimpleStatement(self._SELECT_SESSIONS_FORMAT, consistency_level=query_cl), (self.trace_id,), time_spent, max_wait)
            # PYTHON-730: There is a race condition where the duration mutation is written before started_at for fast queries
is_complete = session_results and session_results[0].duration is not None and session_results[0].started_at is not None
if not session_results or (wait_for_complete and not is_complete):
time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt))
attempt += 1
continue
if is_complete:
log.debug("Fetched trace info for trace ID: %s", self.trace_id)
else:
log.debug("Fetching parital trace info for trace ID: %s", self.trace_id)
session_row = session_results[0]
self.request_type = session_row.request
self.duration = timedelta(microseconds=session_row.duration) if is_complete else None
self.started_at = session_row.started_at
self.coordinator = session_row.coordinator
self.parameters = session_row.parameters
# since C* 2.2
self.client = getattr(session_row, 'client', None)
log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id)
time_spent = time.time() - start
event_results = self._execute(
SimpleStatement(self._SELECT_EVENTS_FORMAT, consistency_level=query_cl), (self.trace_id,), time_spent, max_wait)
log.debug("Fetched trace events for trace ID: %s", self.trace_id)
self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread)
for r in event_results)
break
def _execute(self, query, parameters, time_spent, max_wait):
timeout = (max_wait - time_spent) if max_wait is not None else None
future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout)
# in case the user switched the row factory, set it to namedtuple for this query
future.row_factory = named_tuple_factory
future.send_request()
try:
return future.result()
except OperationTimedOut:
raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,))
def __str__(self):
return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \
% (self.request_type, self.trace_id, self.coordinator, self.started_at,
self.duration, self.parameters)
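Tracing is requested per execution and fetched lazily from the ``system_traces`` tables; a sketch, assuming a connected ``session``:
.. code-block:: python

    rs = session.execute("SELECT * FROM users WHERE name = 'alice'", trace=True)
    trace = rs.get_query_trace()  # may raise TraceUnavailable if not yet stored
    print(trace.request_type, trace.coordinator, trace.duration)
    for event in trace.events:
        print(event.source_elapsed, event.description)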
class TraceEvent(object):
"""
Representation of a single event within a query trace.
"""
description = None
"""
A brief description of the event.
"""
datetime = None
"""
A UTC :class:`datetime.datetime` marking when the event occurred.
"""
source = None
"""
The IP address of the node this event occurred on.
"""
source_elapsed = None
"""
A :class:`datetime.timedelta` measuring the amount of time until
this event occurred starting from when :attr:`.source` first
received the query.
"""
thread_name = None
"""
The name of the thread that this event occurred on.
"""
def __init__(self, description, timeuuid, source, source_elapsed, thread_name):
self.description = description
self.datetime = datetime.utcfromtimestamp(unix_time_from_uuid1(timeuuid))
self.source = source
if source_elapsed is not None:
self.source_elapsed = timedelta(microseconds=source_elapsed)
else:
self.source_elapsed = None
self.thread_name = thread_name
def __str__(self):
return "%s on %s[%s] at %s" % (self.description, self.source, self.thread_name, self.datetime)
+
+
+# TODO remove next major since we can target using the `host` attribute of session.execute
+class HostTargetingStatement(object):
+ """
+ Wraps any query statement and attaches a target host, making
+ it usable in a targeted LBP without modifying the user's statement.
+ """
+ def __init__(self, inner_statement, target_host):
+ self.__class__ = type(inner_statement.__class__.__name__,
+ (self.__class__, inner_statement.__class__),
+ {})
+ self.__dict__ = inner_statement.__dict__
+ self.target_host = target_host
diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx
index 49fafd3..3a4b2f4 100644
--- a/cassandra/row_parser.pyx
+++ b/cassandra/row_parser.pyx
@@ -1,50 +1,48 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cassandra.parsing cimport ParseDesc, ColumnParser
from cassandra.obj_parser import TupleRowParser
from cassandra.deserializers import make_deserializers
include "ioutils.pyx"
def make_recv_results_rows(ColumnParser colparser):
- def recv_results_rows(cls, f, int protocol_version, user_type_map, result_metadata):
+ def recv_results_rows(self, f, int protocol_version, user_type_map, result_metadata):
"""
Parse protocol data given as a BytesIO f into a set of columns (e.g. list of tuples)
This is used as the recv_results_rows method of (Fast)ResultMessage
"""
- paging_state, column_metadata, result_metadata_id = cls.recv_results_metadata(f, user_type_map)
+ self.recv_results_metadata(f, user_type_map)
- column_metadata = column_metadata or result_metadata
+ column_metadata = self.column_metadata or result_metadata
- colnames = [c[2] for c in column_metadata]
- coltypes = [c[3] for c in column_metadata]
+ self.column_names = [c[2] for c in column_metadata]
+ self.column_types = [c[3] for c in column_metadata]
- desc = ParseDesc(colnames, coltypes, make_deserializers(coltypes),
+ desc = ParseDesc(self.column_names, self.column_types, make_deserializers(self.column_types),
protocol_version)
reader = BytesIOReader(f.read())
try:
- parsed_rows = colparser.parse_rows(reader, desc)
+ self.parsed_rows = colparser.parse_rows(reader, desc)
except Exception as e:
# Use explicitly the TupleRowParser to display better error messages for column decoding failures
rowparser = TupleRowParser()
reader.buf_ptr = reader.buf
reader.pos = 0
rowcount = read_int(reader)
for i in range(rowcount):
rowparser.unpack_row(reader, desc)
- return (paging_state, coltypes, (colnames, parsed_rows), result_metadata_id)
-
return recv_results_rows
diff --git a/cassandra/segment.py b/cassandra/segment.py
new file mode 100644
index 0000000..e3881c4
--- /dev/null
+++ b/cassandra/segment.py
@@ -0,0 +1,224 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import zlib
+import six
+
+from cassandra import DriverException
+from cassandra.marshal import int32_pack
+from cassandra.protocol import write_uint_le, read_uint_le
+
+CRC24_INIT = 0x875060
+CRC24_POLY = 0x1974F0B
+CRC24_LENGTH = 3
+CRC32_LENGTH = 4
+CRC32_INITIAL = zlib.crc32(b"\xfa\x2d\x55\xca")
+
+
+class CrcException(Exception):
+ """
+ CRC mismatch error.
+
+ TODO: here to avoid import cycles with cassandra.connection. In the next
+    major, the exceptions should be declared in a separate exceptions.py
+ file.
+ """
+ pass
+
+
+def compute_crc24(data, length):
+ crc = CRC24_INIT
+
+ for _ in range(length):
+ crc ^= (data & 0xff) << 16
+ data >>= 8
+
+ for i in range(8):
+ crc <<= 1
+ if crc & 0x1000000 != 0:
+ crc ^= CRC24_POLY
+
+ return crc
+
+
+def compute_crc32(data, value):
+ crc32 = zlib.crc32(data, value)
+ if six.PY2:
+ crc32 &= 0xffffffff
+
+ return crc32
+
+
+class SegmentHeader(object):
+
+ payload_length = None
+ uncompressed_payload_length = None
+ is_self_contained = None
+
+ def __init__(self, payload_length, uncompressed_payload_length, is_self_contained):
+ self.payload_length = payload_length
+ self.uncompressed_payload_length = uncompressed_payload_length
+ self.is_self_contained = is_self_contained
+
+ @property
+ def segment_length(self):
+ """
+ Return the total length of the segment, including the CRC.
+ """
+ hl = SegmentCodec.UNCOMPRESSED_HEADER_LENGTH if self.uncompressed_payload_length < 1 \
+ else SegmentCodec.COMPRESSED_HEADER_LENGTH
+ return hl + CRC24_LENGTH + self.payload_length + CRC32_LENGTH
+
+
+class Segment(object):
+
+ MAX_PAYLOAD_LENGTH = 128 * 1024 - 1
+
+ payload = None
+ is_self_contained = None
+
+ def __init__(self, payload, is_self_contained):
+ self.payload = payload
+ self.is_self_contained = is_self_contained
+
+
+class SegmentCodec(object):
+
+ COMPRESSED_HEADER_LENGTH = 5
+ UNCOMPRESSED_HEADER_LENGTH = 3
+ FLAG_OFFSET = 17
+
+ compressor = None
+ decompressor = None
+
+ def __init__(self, compressor=None, decompressor=None):
+ self.compressor = compressor
+ self.decompressor = decompressor
+
+ @property
+ def header_length(self):
+ return self.COMPRESSED_HEADER_LENGTH if self.compression \
+ else self.UNCOMPRESSED_HEADER_LENGTH
+
+ @property
+ def header_length_with_crc(self):
+ return (self.COMPRESSED_HEADER_LENGTH if self.compression
+ else self.UNCOMPRESSED_HEADER_LENGTH) + CRC24_LENGTH
+
+ @property
+ def compression(self):
+ return self.compressor and self.decompressor
+
+ def compress(self, data):
+ # the uncompressed length is already encoded in the header, so
+ # we remove it here
+ return self.compressor(data)[4:]
+
+ def decompress(self, encoded_data, uncompressed_length):
+ return self.decompressor(int32_pack(uncompressed_length) + encoded_data)
+
+ def encode_header(self, buffer, payload_length, uncompressed_length, is_self_contained):
+ if payload_length > Segment.MAX_PAYLOAD_LENGTH:
+            raise DriverException('Payload length exceeds Segment.MAX_PAYLOAD_LENGTH')
+
+ header_data = payload_length
+
+ flag_offset = self.FLAG_OFFSET
+ if self.compression:
+ header_data |= uncompressed_length << flag_offset
+ flag_offset += 17
+
+ if is_self_contained:
+ header_data |= 1 << flag_offset
+
+ write_uint_le(buffer, header_data, size=self.header_length)
+ header_crc = compute_crc24(header_data, self.header_length)
+ write_uint_le(buffer, header_crc, size=CRC24_LENGTH)
+
+ def _encode_segment(self, buffer, payload, is_self_contained):
+ """
+ Encode a message to a single segment.
+ """
+ uncompressed_payload = payload
+ uncompressed_payload_length = len(payload)
+
+ if self.compression:
+ compressed_payload = self.compress(uncompressed_payload)
+ if len(compressed_payload) >= uncompressed_payload_length:
+ encoded_payload = uncompressed_payload
+ uncompressed_payload_length = 0
+ else:
+ encoded_payload = compressed_payload
+ else:
+ encoded_payload = uncompressed_payload
+
+ payload_length = len(encoded_payload)
+ self.encode_header(buffer, payload_length, uncompressed_payload_length, is_self_contained)
+ payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL)
+ buffer.write(encoded_payload)
+ write_uint_le(buffer, payload_crc)
+
+ def encode(self, buffer, msg):
+ """
+        Encode a message to one or more segments.
+ """
+ msg_length = len(msg)
+
+ if msg_length > Segment.MAX_PAYLOAD_LENGTH:
+ payloads = []
+ for i in range(0, msg_length, Segment.MAX_PAYLOAD_LENGTH):
+ payloads.append(msg[i:i + Segment.MAX_PAYLOAD_LENGTH])
+ else:
+ payloads = [msg]
+
+ is_self_contained = len(payloads) == 1
+ for payload in payloads:
+ self._encode_segment(buffer, payload, is_self_contained)
+
+ def decode_header(self, buffer):
+ header_data = read_uint_le(buffer, self.header_length)
+
+ expected_header_crc = read_uint_le(buffer, CRC24_LENGTH)
+ actual_header_crc = compute_crc24(header_data, self.header_length)
+ if actual_header_crc != expected_header_crc:
+            raise CrcException('CRC mismatch on header {:x}. Received {:x}, computed {:x}.'.format(
+ header_data, expected_header_crc, actual_header_crc))
+
+ payload_length = header_data & Segment.MAX_PAYLOAD_LENGTH
+ header_data >>= 17
+
+ if self.compression:
+ uncompressed_payload_length = header_data & Segment.MAX_PAYLOAD_LENGTH
+ header_data >>= 17
+ else:
+ uncompressed_payload_length = -1
+
+ is_self_contained = (header_data & 1) == 1
+
+ return SegmentHeader(payload_length, uncompressed_payload_length, is_self_contained)
+
+ def decode(self, buffer, header):
+ encoded_payload = buffer.read(header.payload_length)
+ expected_payload_crc = read_uint_le(buffer)
+
+ actual_payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL)
+ if actual_payload_crc != expected_payload_crc:
+            raise CrcException('CRC mismatch on payload. Received {:x}, computed {:x}.'.format(
+ expected_payload_crc, actual_payload_crc))
+
+ payload = encoded_payload
+ if self.compression and header.uncompressed_payload_length > 0:
+ payload = self.decompress(encoded_payload, header.uncompressed_payload_length)
+
+ return Segment(payload, header.is_self_contained)
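The codec above frames protocol v5 traffic into checksummed segments: a 3-byte header (5 bytes when compression is enabled) protected by CRC24, followed by the payload and its CRC32. This is internal machinery, but an uncompressed round trip can be sketched to illustrate the layout:
.. code-block:: python

    import io
    from cassandra.segment import SegmentCodec

    codec = SegmentCodec()               # no compressor/decompressor configured
    buf = io.BytesIO()
    codec.encode(buf, b"\x00" * 100)     # short message -> one self-contained segment

    buf.seek(0)
    header = codec.decode_header(buf)    # raises CrcException on a corrupted header
    segment = codec.decode(buf, header)  # verifies the payload CRC32
    assert segment.is_self_contained and len(segment.payload) == 100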
diff --git a/cassandra/util.py b/cassandra/util.py
index efb3a95..f896ff4 100644
--- a/cassandra/util.py
+++ b/cassandra/util.py
@@ -1,1328 +1,2039 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import with_statement
import calendar
import datetime
from functools import total_ordering
+import logging
+from itertools import chain
import random
+import re
+import socket
import six
import uuid
import sys
+_HAS_GEOMET = True
+try:
+ from geomet import wkt
+except:
+ _HAS_GEOMET = False
+
+
+from cassandra import DriverException
+
DATETIME_EPOC = datetime.datetime(1970, 1, 1)
+UTC_DATETIME_EPOC = datetime.datetime.utcfromtimestamp(0)
+
+_nan = float('nan')
+
+log = logging.getLogger(__name__)
assert sys.byteorder in ('little', 'big')
is_little_endian = sys.byteorder == 'little'
+
def datetime_from_timestamp(timestamp):
"""
Creates a timezone-agnostic datetime from timestamp (in seconds) in a consistent manner.
Works around a Windows issue with large negative timestamps (PYTHON-119),
and rounding differences in Python 3.4 (PYTHON-340).
:param timestamp: a unix timestamp, in seconds
"""
dt = DATETIME_EPOC + datetime.timedelta(seconds=timestamp)
return dt
+def utc_datetime_from_ms_timestamp(timestamp):
+ """
+ Creates a UTC datetime from a timestamp in milliseconds. See
+ :meth:`datetime_from_timestamp`.
+
+ Raises an `OverflowError` if the timestamp is out of range for
+ :class:`~datetime.datetime`.
+
+ :param timestamp: timestamp, in milliseconds
+ """
+ return UTC_DATETIME_EPOC + datetime.timedelta(milliseconds=timestamp)
+
+
+def ms_timestamp_from_datetime(dt):
+ """
+ Converts a datetime to a timestamp expressed in milliseconds.
+
+ :param dt: a :class:`datetime.datetime`
+ """
+ return int(round((dt - UTC_DATETIME_EPOC).total_seconds() * 1000))
+
+
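These helpers round-trip the millisecond-precision timestamps used by CQL ``timestamp`` values; for example:
.. code-block:: python

    from datetime import datetime
    from cassandra.util import ms_timestamp_from_datetime, utc_datetime_from_ms_timestamp

    ms = ms_timestamp_from_datetime(datetime(2021, 3, 18, 12, 0, 0))  # 1616068800000
    utc_datetime_from_ms_timestamp(ms)  # datetime.datetime(2021, 3, 18, 12, 0)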
def unix_time_from_uuid1(uuid_arg):
"""
Converts a version 1 :class:`uuid.UUID` to a timestamp with the same precision
as :meth:`time.time()` returns. This is useful for examining the
results of queries returning a v1 :class:`~uuid.UUID`.
:param uuid_arg: a version 1 :class:`~uuid.UUID`
"""
return (uuid_arg.time - 0x01B21DD213814000) / 1e7
def datetime_from_uuid1(uuid_arg):
"""
Creates a timezone-agnostic datetime from the timestamp in the
specified type-1 UUID.
:param uuid_arg: a version 1 :class:`~uuid.UUID`
"""
return datetime_from_timestamp(unix_time_from_uuid1(uuid_arg))
def min_uuid_from_time(timestamp):
"""
Generates the minimum TimeUUID (type 1) for a given timestamp, as compared by Cassandra.
See :func:`uuid_from_time` for argument and return types.
"""
return uuid_from_time(timestamp, 0x808080808080, 0x80) # Cassandra does byte-wise comparison; fill with min signed bytes (0x80 = -128)
def max_uuid_from_time(timestamp):
"""
Generates the maximum TimeUUID (type 1) for a given timestamp, as compared by Cassandra.
See :func:`uuid_from_time` for argument and return types.
"""
return uuid_from_time(timestamp, 0x7f7f7f7f7f7f, 0x3f7f) # Max signed bytes (0x7f = 127)
def uuid_from_time(time_arg, node=None, clock_seq=None):
"""
Converts a datetime or timestamp to a type 1 :class:`uuid.UUID`.
:param time_arg:
The time to use for the timestamp portion of the UUID.
This can either be a :class:`datetime` object or a timestamp
in seconds (as returned from :meth:`time.time()`).
    :type time_arg: :class:`datetime` or timestamp
:param node:
      Node integer for the UUID (up to 48 bits). If not specified, this
field is randomized.
:type node: long
:param clock_seq:
Clock sequence field for the UUID (up to 14 bits). If not specified,
a random sequence is generated.
:type clock_seq: int
:rtype: :class:`uuid.UUID`
"""
if hasattr(time_arg, 'utctimetuple'):
seconds = int(calendar.timegm(time_arg.utctimetuple()))
microseconds = (seconds * 1e6) + time_arg.time().microsecond
else:
microseconds = int(time_arg * 1e6)
# 0x01b21dd213814000 is the number of 100-ns intervals between the
# UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00.
intervals = int(microseconds * 10) + 0x01b21dd213814000
time_low = intervals & 0xffffffff
time_mid = (intervals >> 32) & 0xffff
time_hi_version = (intervals >> 48) & 0x0fff
if clock_seq is None:
clock_seq = random.getrandbits(14)
else:
if clock_seq > 0x3fff:
raise ValueError('clock_seq is out of range (need a 14-bit value)')
clock_seq_low = clock_seq & 0xff
clock_seq_hi_variant = 0x80 | ((clock_seq >> 8) & 0x3f)
if node is None:
node = random.getrandbits(48)
return uuid.UUID(fields=(time_low, time_mid, time_hi_version,
clock_seq_hi_variant, clock_seq_low, node), version=1)
LOWEST_TIME_UUID = uuid.UUID('00000000-0000-1000-8080-808080808080')
""" The lowest possible TimeUUID, as sorted by Cassandra. """
HIGHEST_TIME_UUID = uuid.UUID('ffffffff-ffff-1fff-bf7f-7f7f7f7f7f7f')
""" The highest possible TimeUUID, as sorted by Cassandra. """
+def _addrinfo_or_none(contact_point, port):
+ """
+ A helper function that wraps socket.getaddrinfo and returns None
+ when it fails to, e.g. resolve one of the hostnames. Used to address
+ PYTHON-895.
+ """
+ try:
+ value = socket.getaddrinfo(contact_point, port,
+ socket.AF_UNSPEC, socket.SOCK_STREAM)
+ return value
+ except socket.gaierror:
+ log.debug('Could not resolve hostname "{}" '
+ 'with port {}'.format(contact_point, port))
+ return None
+
+
+def _addrinfo_to_ip_strings(addrinfo):
+ """
+ Helper function that consumes the data output by socket.getaddrinfo and
+ extracts the IP address from the sockaddr portion of the result.
+
+ Since this is meant to be used in conjunction with _addrinfo_or_none,
+ this will pass None and EndPoint instances through unaffected.
+ """
+ if addrinfo is None:
+ return None
+ return [(entry[4][0], entry[4][1]) for entry in addrinfo]
+
+
+def _resolve_contact_points_to_string_map(contact_points):
+ return OrderedDict(
+ ('{cp}:{port}'.format(cp=cp, port=port), _addrinfo_to_ip_strings(_addrinfo_or_none(cp, port)))
+ for cp, port in contact_points
+ )
+
+
try:
from collections import OrderedDict
except ImportError:
# OrderedDict from Python 2.7+
# Copyright (c) 2009 Raymond Hettinger
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
# OTHER DEALINGS IN THE SOFTWARE.
from UserDict import DictMixin
class OrderedDict(dict, DictMixin): # noqa
""" A dictionary which maintains the insertion order of keys. """
def __init__(self, *args, **kwds):
""" A dictionary which maintains the insertion order of keys. """
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
try:
self.__end
except AttributeError:
self.clear()
self.update(*args, **kwds)
def clear(self):
self.__end = end = []
end += [None, end, end] # sentinel node for doubly linked list
self.__map = {} # key --> [key, prev, next]
dict.clear(self)
def __setitem__(self, key, value):
if key not in self:
end = self.__end
curr = end[1]
curr[2] = end[1] = self.__map[key] = [key, curr, end]
dict.__setitem__(self, key, value)
def __delitem__(self, key):
dict.__delitem__(self, key)
key, prev, next = self.__map.pop(key)
prev[2] = next
next[1] = prev
def __iter__(self):
end = self.__end
curr = end[2]
while curr is not end:
yield curr[0]
curr = curr[2]
def __reversed__(self):
end = self.__end
curr = end[1]
while curr is not end:
yield curr[0]
curr = curr[1]
def popitem(self, last=True):
if not self:
raise KeyError('dictionary is empty')
if last:
key = next(reversed(self))
else:
key = next(iter(self))
value = self.pop(key)
return key, value
def __reduce__(self):
items = [[k, self[k]] for k in self]
tmp = self.__map, self.__end
del self.__map, self.__end
inst_dict = vars(self).copy()
self.__map, self.__end = tmp
if inst_dict:
return (self.__class__, (items,), inst_dict)
return self.__class__, (items,)
def keys(self):
return list(self)
setdefault = DictMixin.setdefault
update = DictMixin.update
pop = DictMixin.pop
values = DictMixin.values
items = DictMixin.items
iterkeys = DictMixin.iterkeys
itervalues = DictMixin.itervalues
iteritems = DictMixin.iteritems
def __repr__(self):
if not self:
return '%s()' % (self.__class__.__name__,)
return '%s(%r)' % (self.__class__.__name__, self.items())
def copy(self):
return self.__class__(self)
@classmethod
def fromkeys(cls, iterable, value=None):
d = cls()
for key in iterable:
d[key] = value
return d
def __eq__(self, other):
if isinstance(other, OrderedDict):
if len(self) != len(other):
return False
for p, q in zip(self.items(), other.items()):
if p != q:
return False
return True
return dict.__eq__(self, other)
def __ne__(self, other):
return not self == other
# WeakSet from Python 2.7+ (https://code.google.com/p/weakrefset)
from _weakref import ref
class _IterationGuard(object):
# This context manager registers itself in the current iterators of the
# weak container, such as to delay all removals until the context manager
# exits.
# This technique should be relatively thread-safe (since sets are).
def __init__(self, weakcontainer):
# Don't create cycles
self.weakcontainer = ref(weakcontainer)
def __enter__(self):
w = self.weakcontainer()
if w is not None:
w._iterating.add(self)
return self
def __exit__(self, e, t, b):
w = self.weakcontainer()
if w is not None:
s = w._iterating
s.remove(self)
if not s:
w._commit_removals()
class WeakSet(object):
def __init__(self, data=None):
self.data = set()
def _remove(item, selfref=ref(self)):
self = selfref()
if self is not None:
if self._iterating:
self._pending_removals.append(item)
else:
self.data.discard(item)
self._remove = _remove
# A list of keys to be removed
self._pending_removals = []
self._iterating = set()
if data is not None:
self.update(data)
def _commit_removals(self):
l = self._pending_removals
discard = self.data.discard
while l:
discard(l.pop())
def __iter__(self):
with _IterationGuard(self):
for itemref in self.data:
item = itemref()
if item is not None:
yield item
def __len__(self):
return sum(x() is not None for x in self.data)
def __contains__(self, item):
return ref(item) in self.data
def __reduce__(self):
return (self.__class__, (list(self),),
getattr(self, '__dict__', None))
__hash__ = None
def add(self, item):
if self._pending_removals:
self._commit_removals()
self.data.add(ref(item, self._remove))
def clear(self):
if self._pending_removals:
self._commit_removals()
self.data.clear()
def copy(self):
return self.__class__(self)
def pop(self):
if self._pending_removals:
self._commit_removals()
while True:
try:
itemref = self.data.pop()
except KeyError:
raise KeyError('pop from empty WeakSet')
item = itemref()
if item is not None:
return item
def remove(self, item):
if self._pending_removals:
self._commit_removals()
self.data.remove(ref(item))
def discard(self, item):
if self._pending_removals:
self._commit_removals()
self.data.discard(ref(item))
def update(self, other):
if self._pending_removals:
self._commit_removals()
if isinstance(other, self.__class__):
self.data.update(other.data)
else:
for element in other:
self.add(element)
def __ior__(self, other):
self.update(other)
return self
# Helper functions for simple delegating methods.
def _apply(self, other, method):
if not isinstance(other, self.__class__):
other = self.__class__(other)
newdata = method(other.data)
newset = self.__class__()
newset.data = newdata
return newset
def difference(self, other):
return self._apply(other, self.data.difference)
__sub__ = difference
def difference_update(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
self.data.difference_update(ref(item) for item in other)
def __isub__(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
self.data.difference_update(ref(item) for item in other)
return self
def intersection(self, other):
return self._apply(other, self.data.intersection)
__and__ = intersection
def intersection_update(self, other):
if self._pending_removals:
self._commit_removals()
self.data.intersection_update(ref(item) for item in other)
def __iand__(self, other):
if self._pending_removals:
self._commit_removals()
self.data.intersection_update(ref(item) for item in other)
return self
def issubset(self, other):
return self.data.issubset(ref(item) for item in other)
__lt__ = issubset
def __le__(self, other):
return self.data <= set(ref(item) for item in other)
def issuperset(self, other):
return self.data.issuperset(ref(item) for item in other)
__gt__ = issuperset
def __ge__(self, other):
return self.data >= set(ref(item) for item in other)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return self.data == set(ref(item) for item in other)
def symmetric_difference(self, other):
return self._apply(other, self.data.symmetric_difference)
__xor__ = symmetric_difference
def symmetric_difference_update(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
self.data.symmetric_difference_update(ref(item) for item in other)
def __ixor__(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
self.data.symmetric_difference_update(ref(item) for item in other)
return self
def union(self, other):
return self._apply(other, self.data.union)
__or__ = union
def isdisjoint(self, other):
return len(self.intersection(other)) == 0
class SortedSet(object):
'''
A sorted set based on sorted list
A sorted set implementation is used in this case because it does not
require its elements to be immutable/hashable.
#Not implemented: update functions, inplace operators
'''
def __init__(self, iterable=()):
self._items = []
self.update(iterable)
def __len__(self):
return len(self._items)
def __getitem__(self, i):
return self._items[i]
def __iter__(self):
return iter(self._items)
def __reversed__(self):
return reversed(self._items)
def __repr__(self):
return '%s(%r)' % (
self.__class__.__name__,
self._items)
def __reduce__(self):
return self.__class__, (self._items,)
def __eq__(self, other):
if isinstance(other, self.__class__):
return self._items == other._items
else:
try:
return len(other) == len(self._items) and all(item in self for item in other)
except TypeError:
return NotImplemented
def __ne__(self, other):
if isinstance(other, self.__class__):
return self._items != other._items
else:
try:
return len(other) != len(self._items) or any(item not in self for item in other)
except TypeError:
return NotImplemented
def __le__(self, other):
return self.issubset(other)
def __lt__(self, other):
return len(other) > len(self._items) and self.issubset(other)
def __ge__(self, other):
return self.issuperset(other)
def __gt__(self, other):
return len(self._items) > len(other) and self.issuperset(other)
def __and__(self, other):
return self._intersect(other)
__rand__ = __and__
def __iand__(self, other):
isect = self._intersect(other)
self._items = isect._items
return self
def __or__(self, other):
return self.union(other)
__ror__ = __or__
def __ior__(self, other):
union = self.union(other)
self._items = union._items
return self
def __sub__(self, other):
return self._diff(other)
def __rsub__(self, other):
return sortedset(other) - self
def __isub__(self, other):
diff = self._diff(other)
self._items = diff._items
return self
def __xor__(self, other):
return self.symmetric_difference(other)
__rxor__ = __xor__
def __ixor__(self, other):
sym_diff = self.symmetric_difference(other)
self._items = sym_diff._items
return self
def __contains__(self, item):
i = self._find_insertion(item)
return i < len(self._items) and self._items[i] == item
def __delitem__(self, i):
del self._items[i]
def __delslice__(self, i, j):
del self._items[i:j]
def add(self, item):
i = self._find_insertion(item)
if i < len(self._items):
if self._items[i] != item:
self._items.insert(i, item)
else:
self._items.append(item)
def update(self, iterable):
for i in iterable:
self.add(i)
def clear(self):
del self._items[:]
def copy(self):
new = sortedset()
new._items = list(self._items)
return new
def isdisjoint(self, other):
return len(self._intersect(other)) == 0
def issubset(self, other):
return len(self._intersect(other)) == len(self._items)
def issuperset(self, other):
return len(self._intersect(other)) == len(other)
def pop(self):
if not self._items:
raise KeyError("pop from empty set")
return self._items.pop()
def remove(self, item):
i = self._find_insertion(item)
if i < len(self._items):
if self._items[i] == item:
self._items.pop(i)
return
raise KeyError('%r' % item)
def union(self, *others):
union = sortedset()
union._items = list(self._items)
for other in others:
for item in other:
union.add(item)
return union
def intersection(self, *others):
isect = self.copy()
for other in others:
isect = isect._intersect(other)
if not isect:
break
return isect
def difference(self, *others):
diff = self.copy()
for other in others:
diff = diff._diff(other)
if not diff:
break
return diff
def symmetric_difference(self, other):
diff_self_other = self._diff(other)
diff_other_self = other.difference(self)
return diff_self_other.union(diff_other_self)
def _diff(self, other):
diff = sortedset()
for item in self._items:
if item not in other:
diff.add(item)
return diff
def _intersect(self, other):
isect = sortedset()
for item in self._items:
if item in other:
isect.add(item)
return isect
def _find_insertion(self, x):
# this uses bisect_left algorithm unless it has elements it can't compare,
# in which case it defaults to grouping non-comparable items at the beginning or end,
# and scanning sequentially to find an insertion point
a = self._items
lo = 0
hi = len(a)
try:
while lo < hi:
mid = (lo + hi) // 2
if a[mid] < x: lo = mid + 1
else: hi = mid
except TypeError:
# could not compare a[mid] with x
# start scanning to find insertion point while swallowing type errors
lo = 0
compared_one = False # flag is used to determine whether uncomparables are grouped at the front or back
while lo < hi:
try:
if a[lo] == x or a[lo] >= x: break
compared_one = True
except TypeError:
if compared_one: break
lo += 1
return lo
sortedset = SortedSet # backwards-compatibility
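CQL ``set`` values are returned as this type because their elements (for example, frozen collections) may not be hashable in Python; a quick illustration:
.. code-block:: python

    from cassandra.util import sortedset

    s = sortedset([[3], [1], [2]])  # elements do not need to be hashable
    list(s)      # [[1], [2], [3]]
    [2] in s     # True
    s.add([1])   # duplicates are ignored
    len(s)       # 3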
from cassandra.compat import Mapping
from six.moves import cPickle
class OrderedMap(Mapping):
'''
An ordered map that accepts non-hashable types for keys. It also maintains the
insertion order of items, behaving as OrderedDict in that regard. These maps
    are constructed and read just as normal mapping types, except that they may
contain arbitrary collections and other non-hashable items as keys::
>>> od = OrderedMap([({'one': 1, 'two': 2}, 'value'),
... ({'three': 3, 'four': 4}, 'value2')])
>>> list(od.keys())
[{'two': 2, 'one': 1}, {'three': 3, 'four': 4}]
>>> list(od.values())
['value', 'value2']
These constructs are needed to support nested collections in Cassandra 2.1.3+,
where frozen collections can be specified as parameters to others::
CREATE TABLE example (
...
            value map<frozen<tuple<int, int>>, double>
...
)
This class derives from the (immutable) Mapping API. Objects in these maps
    are not intended to be modified.
-
- Note: Because of the way Cassandra encodes nested types, when using the
- driver with nested collections, :attr:`~.Cluster.protocol_version` must be 3
- or higher.
-
'''
def __init__(self, *args, **kwargs):
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
self._items = []
self._index = {}
if args:
e = args[0]
if callable(getattr(e, 'keys', None)):
for k in e.keys():
self._insert(k, e[k])
else:
for k, v in e:
self._insert(k, v)
for k, v in six.iteritems(kwargs):
self._insert(k, v)
def _insert(self, key, value):
flat_key = self._serialize_key(key)
i = self._index.get(flat_key, -1)
if i >= 0:
self._items[i] = (key, value)
else:
self._items.append((key, value))
self._index[flat_key] = len(self._items) - 1
__setitem__ = _insert
def __getitem__(self, key):
try:
index = self._index[self._serialize_key(key)]
return self._items[index][1]
except KeyError:
raise KeyError(str(key))
def __delitem__(self, key):
# not efficient -- for convenience only
try:
index = self._index.pop(self._serialize_key(key))
self._index = dict((k, i if i < index else i - 1) for k, i in self._index.items())
self._items.pop(index)
except KeyError:
raise KeyError(str(key))
def __iter__(self):
for i in self._items:
yield i[0]
def __len__(self):
return len(self._items)
def __eq__(self, other):
if isinstance(other, OrderedMap):
return self._items == other._items
try:
d = dict(other)
return len(d) == len(self._items) and all(i[1] == d[i[0]] for i in self._items)
except KeyError:
return False
except TypeError:
pass
return NotImplemented
def __repr__(self):
return '%s([%s])' % (
self.__class__.__name__,
', '.join("(%r, %r)" % (k, v) for k, v in self._items))
def __str__(self):
return '{%s}' % ', '.join("%r: %r" % (k, v) for k, v in self._items)
def popitem(self):
try:
kv = self._items.pop()
del self._index[self._serialize_key(kv[0])]
return kv
except IndexError:
raise KeyError()
def _serialize_key(self, key):
return cPickle.dumps(key)
class OrderedMapSerializedKey(OrderedMap):
def __init__(self, cass_type, protocol_version):
super(OrderedMapSerializedKey, self).__init__()
self.cass_key_type = cass_type
self.protocol_version = protocol_version
def _insert_unchecked(self, key, flat_key, value):
self._items.append((key, value))
self._index[flat_key] = len(self._items) - 1
def _serialize_key(self, key):
return self.cass_key_type.serialize(key, self.protocol_version)
import datetime
import time
if six.PY3:
long = int
@total_ordering
class Time(object):
'''
Idealized time, independent of day.
Up to nanosecond resolution
'''
MICRO = 1000
MILLI = 1000 * MICRO
SECOND = 1000 * MILLI
MINUTE = 60 * SECOND
HOUR = 60 * MINUTE
DAY = 24 * HOUR
nanosecond_time = 0
def __init__(self, value):
"""
Initializer value can be:
- integer_type: absolute nanoseconds in the day
- datetime.time: built-in time
- string_type: a string time of the form "HH:MM:SS[.mmmuuunnn]"
"""
if isinstance(value, six.integer_types):
self._from_timestamp(value)
elif isinstance(value, datetime.time):
self._from_time(value)
elif isinstance(value, six.string_types):
self._from_timestring(value)
else:
raise TypeError('Time arguments must be a whole number, datetime.time, or string')
@property
def hour(self):
"""
The hour component of this time (0-23)
"""
return self.nanosecond_time // Time.HOUR
@property
def minute(self):
"""
The minute component of this time (0-59)
"""
minutes = self.nanosecond_time // Time.MINUTE
return minutes % 60
@property
def second(self):
"""
The second component of this time (0-59)
"""
seconds = self.nanosecond_time // Time.SECOND
return seconds % 60
@property
def nanosecond(self):
"""
The fractional seconds component of the time, in nanoseconds
"""
return self.nanosecond_time % Time.SECOND
def time(self):
"""
Return a built-in datetime.time (nanosecond precision truncated to micros).
"""
return datetime.time(hour=self.hour, minute=self.minute, second=self.second,
microsecond=self.nanosecond // Time.MICRO)
def _from_timestamp(self, t):
if t >= Time.DAY:
raise ValueError("value must be less than number of nanoseconds in a day (%d)" % Time.DAY)
self.nanosecond_time = t
def _from_timestring(self, s):
try:
parts = s.split('.')
base_time = time.strptime(parts[0], "%H:%M:%S")
self.nanosecond_time = (base_time.tm_hour * Time.HOUR +
base_time.tm_min * Time.MINUTE +
base_time.tm_sec * Time.SECOND)
if len(parts) > 1:
# right pad to 9 digits
nano_time_str = parts[1] + "0" * (9 - len(parts[1]))
self.nanosecond_time += int(nano_time_str)
except ValueError:
raise ValueError("can't interpret %r as a time" % (s,))
def _from_time(self, t):
self.nanosecond_time = (t.hour * Time.HOUR +
t.minute * Time.MINUTE +
t.second * Time.SECOND +
t.microsecond * Time.MICRO)
def __hash__(self):
return self.nanosecond_time
def __eq__(self, other):
if isinstance(other, Time):
return self.nanosecond_time == other.nanosecond_time
if isinstance(other, six.integer_types):
return self.nanosecond_time == other
return self.nanosecond_time % Time.MICRO == 0 and \
datetime.time(hour=self.hour, minute=self.minute, second=self.second,
microsecond=self.nanosecond // Time.MICRO) == other
def __ne__(self, other):
return not self.__eq__(other)
def __lt__(self, other):
if not isinstance(other, Time):
return NotImplemented
return self.nanosecond_time < other.nanosecond_time
def __repr__(self):
return "Time(%s)" % self.nanosecond_time
def __str__(self):
return "%02d:%02d:%02d.%09d" % (self.hour, self.minute,
self.second, self.nanosecond)
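# Editor's note: illustrative sketch, not part of the driver source, showing the
# three initializer forms accepted by Time (nanoseconds, string, datetime.time).
def _example_time_usage():
    from_nanos = Time(2 * Time.HOUR + 30 * Time.MINUTE)
    from_string = Time("02:30:00")
    from_dt_time = Time(datetime.time(2, 30, 0))
    assert from_nanos == from_string == from_dt_time
    assert (from_nanos.hour, from_nanos.minute, from_nanos.second) == (2, 30, 0)
    assert str(from_nanos) == "02:30:00.000000000"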
@total_ordering
class Date(object):
'''
Idealized date: year, month, day
Offers a wider year range than datetime.date. For Dates that cannot be represented
as a datetime.date (because they fall outside datetime.MINYEAR..datetime.MAXYEAR), this type falls back
to printing the days_from_epoch offset.
'''
MINUTE = 60
HOUR = 60 * MINUTE
DAY = 24 * HOUR
date_format = "%Y-%m-%d"
days_from_epoch = 0
def __init__(self, value):
"""
Initializer value can be:
- integer_type: absolute days from epoch (1970, 1, 1). Can be negative.
- datetime.date: built-in date
- string_type: a string date of the form "yyyy-mm-dd"
"""
if isinstance(value, six.integer_types):
self.days_from_epoch = value
elif isinstance(value, (datetime.date, datetime.datetime)):
self._from_timetuple(value.timetuple())
elif isinstance(value, six.string_types):
self._from_datestring(value)
else:
raise TypeError('Date arguments must be a whole number, datetime.date, or string')
@property
def seconds(self):
"""
Absolute seconds from epoch (can be negative)
"""
return self.days_from_epoch * Date.DAY
def date(self):
"""
Return a built-in datetime.date for Dates falling in the years [datetime.MINYEAR, datetime.MAXYEAR]
ValueError is raised for Dates outside this range.
"""
try:
dt = datetime_from_timestamp(self.seconds)
return datetime.date(dt.year, dt.month, dt.day)
except Exception:
raise ValueError("%r exceeds ranges for built-in datetime.date" % self)
def _from_timetuple(self, t):
self.days_from_epoch = calendar.timegm(t) // Date.DAY
def _from_datestring(self, s):
if s[0] == '+':
s = s[1:]
dt = datetime.datetime.strptime(s, self.date_format)
self._from_timetuple(dt.timetuple())
def __hash__(self):
return self.days_from_epoch
def __eq__(self, other):
if isinstance(other, Date):
return self.days_from_epoch == other.days_from_epoch
if isinstance(other, six.integer_types):
return self.days_from_epoch == other
try:
return self.date() == other
except Exception:
return False
def __ne__(self, other):
return not self.__eq__(other)
def __lt__(self, other):
if not isinstance(other, Date):
return NotImplemented
return self.days_from_epoch < other.days_from_epoch
def __repr__(self):
return "Date(%s)" % self.days_from_epoch
def __str__(self):
try:
dt = datetime_from_timestamp(self.seconds)
return "%04d-%02d-%02d" % (dt.year, dt.month, dt.day)
except:
# If we overflow datetime.[MIN|MAX]
return str(self.days_from_epoch)
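# Editor's note: illustrative sketch, not part of the driver source. It assumes
# the module-level datetime_from_timestamp helper (defined elsewhere in this
# file). Date stores whole days from the 1970-01-01 epoch, so values outside
# datetime.date's range still have a printable day-offset form.
def _example_date_usage():
    d = Date("2021-03-18")
    assert d == Date(datetime.date(2021, 3, 18))
    assert str(d) == "2021-03-18"
    far_past = Date(-1000000000)          # not representable as datetime.date
    assert str(far_past) == "-1000000000"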
import socket
if hasattr(socket, 'inet_pton'):
inet_pton = socket.inet_pton
inet_ntop = socket.inet_ntop
else:
"""
Windows doesn't have socket.inet_pton and socket.inet_ntop until Python 3.4
This is an alternative impl using ctypes, based on this win_inet_pton project:
https://github.com/hickeroar/win_inet_pton
"""
import ctypes
class sockaddr(ctypes.Structure):
"""
Shared struct for ipv4 and ipv6.
https://msdn.microsoft.com/en-us/library/windows/desktop/ms740496(v=vs.85).aspx
``__pad1`` always covers the port.
When being used for ``sockaddr_in6``, ``ipv4_addr`` actually covers ``sin6_flowinfo``, resulting
in proper alignment for ``ipv6_addr``.
"""
_fields_ = [("sa_family", ctypes.c_short),
("__pad1", ctypes.c_ushort),
("ipv4_addr", ctypes.c_byte * 4),
("ipv6_addr", ctypes.c_byte * 16),
("__pad2", ctypes.c_ulong)]
if hasattr(ctypes, 'windll'):
WSAStringToAddressA = ctypes.windll.ws2_32.WSAStringToAddressA
WSAAddressToStringA = ctypes.windll.ws2_32.WSAAddressToStringA
else:
def not_windows(*args):
raise OSError("IPv6 addresses cannot be handled on Windows. "
"Missing ctypes.windll")
WSAStringToAddressA = not_windows
WSAAddressToStringA = not_windows
def inet_pton(address_family, ip_string):
if address_family == socket.AF_INET:
return socket.inet_aton(ip_string)
addr = sockaddr()
addr.sa_family = address_family
addr_size = ctypes.c_int(ctypes.sizeof(addr))
if WSAStringToAddressA(
ip_string,
address_family,
None,
ctypes.byref(addr),
ctypes.byref(addr_size)
) != 0:
raise socket.error(ctypes.FormatError())
if address_family == socket.AF_INET6:
return ctypes.string_at(addr.ipv6_addr, 16)
raise socket.error('unknown address family')
def inet_ntop(address_family, packed_ip):
if address_family == socket.AF_INET:
return socket.inet_ntoa(packed_ip)
addr = sockaddr()
addr.sa_family = address_family
addr_size = ctypes.c_int(ctypes.sizeof(addr))
ip_string = ctypes.create_string_buffer(128)
ip_string_size = ctypes.c_int(ctypes.sizeof(ip_string))
if address_family == socket.AF_INET6:
if len(packed_ip) != ctypes.sizeof(addr.ipv6_addr):
raise socket.error('packed IP wrong length for inet_ntoa')
ctypes.memmove(addr.ipv6_addr, packed_ip, 16)
else:
raise socket.error('unknown address family')
if WSAAddressToStringA(
ctypes.byref(addr),
addr_size,
None,
ip_string,
ctypes.byref(ip_string_size)
) != 0:
raise socket.error(ctypes.FormatError())
return ip_string[:ip_string_size.value - 1]
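# Editor's note: illustrative sketch, not part of the driver source. Whether the
# names above resolve to the socket module or to the ctypes-based Windows
# fallback, they provide the same inet_pton/inet_ntop round trip.
def _example_inet_round_trip():
    packed = inet_pton(socket.AF_INET, "192.0.2.1")
    assert inet_ntop(socket.AF_INET, packed) == "192.0.2.1"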
import keyword
# similar to collections.namedtuple, reproduced here because Python 2.6 did not have the rename logic
def _positional_rename_invalid_identifiers(field_names):
names_out = list(field_names)
for index, name in enumerate(field_names):
if (not all(c.isalnum() or c == '_' for c in name)
or keyword.iskeyword(name)
or not name
or name[0].isdigit()
or name.startswith('_')):
names_out[index] = 'field_%d_' % index
return names_out
def _sanitize_identifiers(field_names):
names_out = _positional_rename_invalid_identifiers(field_names)
if len(names_out) != len(set(names_out)):
observed_names = set()
for index, name in enumerate(names_out):
while names_out[index] in observed_names:
names_out[index] = "%s_" % (names_out[index],)
observed_names.add(names_out[index])
return names_out
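# Editor's note: illustrative sketch, not part of the driver source, showing how
# keyword, leading-digit, and duplicate field names are rewritten before they are
# used as namedtuple attribute names.
def _example_sanitize_identifiers():
    cleaned = _sanitize_identifiers(['valid', 'class', '2bad', 'valid'])
    assert cleaned == ['valid', 'field_1_', 'field_2_', 'valid_']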
+def list_contents_to_tuple(to_convert):
+ if isinstance(to_convert, list):
+ for n, i in enumerate(to_convert):
+ if isinstance(to_convert[n], list):
+ to_convert[n] = tuple(to_convert[n])
+ return tuple(to_convert)
+ else:
+ return to_convert
+
+
+class Point(object):
+ """
+ Represents a point geometry for DSE
+ """
+
+ x = None
+ """
+ x coordinate of the point
+ """
+
+ y = None
+ """
+ y coordinate of the point
+ """
+
+ def __init__(self, x=_nan, y=_nan):
+ self.x = x
+ self.y = y
+
+ def __eq__(self, other):
+ return isinstance(other, Point) and self.x == other.x and self.y == other.y
+
+ def __hash__(self):
+ return hash((self.x, self.y))
+
+ def __str__(self):
+ """
+ Well-known text representation of the point
+ """
+ return "POINT (%r %r)" % (self.x, self.y)
+
+ def __repr__(self):
+ return "%s(%r, %r)" % (self.__class__.__name__, self.x, self.y)
+
+ @staticmethod
+ def from_wkt(s):
+ """
+ Parse a Point geometry from a wkt string and return a new Point object.
+ """
+ if not _HAS_GEOMET:
+ raise DriverException("Geomet is required to deserialize a wkt geometry.")
+
+ try:
+ geom = wkt.loads(s)
+ except ValueError:
+ raise ValueError("Invalid WKT geometry: '{0}'".format(s))
+
+ if geom['type'] != 'Point':
+ raise ValueError("Invalid WKT geometry type. Expected 'Point', got '{0}': '{1}'".format(geom['type'], s))
+
+ coords = geom['coordinates']
+ if len(coords) < 2:
+ x = y = _nan
+ else:
+ x = coords[0]
+ y = coords[1]
+
+ return Point(x=x, y=y)
+
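+# Editor's note: illustrative sketch, not part of the driver source; it assumes
+# the optional geomet dependency is installed so Point.from_wkt can parse WKT.
+def _example_point_wkt_round_trip():
+    p = Point(1.5, -2.0)
+    assert str(p) == "POINT (1.5 -2.0)"
+    assert Point.from_wkt(str(p)) == p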
+
+class LineString(object):
+ """
+ Represents a linestring geometry for DSE
+ """
+
+ coords = None
+ """
+ Tuple of (x, y) coordinates in the linestring
+ """
+ def __init__(self, coords=tuple()):
+ """
+ `coords`: a sequence of (x, y) coordinates of points in the linestring
+ """
+ self.coords = tuple(coords)
+
+ def __eq__(self, other):
+ return isinstance(other, LineString) and self.coords == other.coords
+
+ def __hash__(self):
+ return hash(self.coords)
+
+ def __str__(self):
+ """
+ Well-known text representation of the LineString
+ """
+ if not self.coords:
+ return "LINESTRING EMPTY"
+ return "LINESTRING (%s)" % ', '.join("%r %r" % (x, y) for x, y in self.coords)
+
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, self.coords)
+
+ @staticmethod
+ def from_wkt(s):
+ """
+ Parse a LineString geometry from a wkt string and return a new LineString object.
+ """
+ if not _HAS_GEOMET:
+ raise DriverException("Geomet is required to deserialize a wkt geometry.")
+
+ try:
+ geom = wkt.loads(s)
+ except ValueError:
+ raise ValueError("Invalid WKT geometry: '{0}'".format(s))
+
+ if geom['type'] != 'LineString':
+ raise ValueError("Invalid WKT geometry type. Expected 'LineString', got '{0}': '{1}'".format(geom['type'], s))
+
+ geom['coordinates'] = list_contents_to_tuple(geom['coordinates'])
+
+ return LineString(coords=geom['coordinates'])
+
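+# Editor's note: illustrative sketch, not part of the driver source; it assumes
+# geomet is installed for LineString.from_wkt.
+def _example_linestring_wkt():
+    ls = LineString([(1.0, 2.0), (3.0, 4.0)])
+    assert str(ls) == "LINESTRING (1.0 2.0, 3.0 4.0)"
+    assert LineString.from_wkt(str(ls)) == ls
+    assert str(LineString()) == "LINESTRING EMPTY"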
+
+class _LinearRing(object):
+ # no validation, no implicit closing; just used for poly composition, to
+ # mimic that of shapely.geometry.Polygon
+ def __init__(self, coords=tuple()):
+ self.coords = list_contents_to_tuple(coords)
+
+ def __eq__(self, other):
+ return isinstance(other, _LinearRing) and self.coords == other.coords
+
+ def __hash__(self):
+ return hash(self.coords)
+
+ def __str__(self):
+ if not self.coords:
+ return "LINEARRING EMPTY"
+ return "LINEARRING (%s)" % ', '.join("%r %r" % (x, y) for x, y in self.coords)
+
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, self.coords)
+
+
+class Polygon(object):
+ """
+ Represents a polygon geometry for DSE
+ """
+
+ exterior = None
+ """
+ _LinearRing representing the exterior of the polygon
+ """
+
+ interiors = None
+ """
+ Tuple of _LinearRings representing interior holes in the polygon
+ """
+
+ def __init__(self, exterior=tuple(), interiors=None):
+ """
+ `exterior`: a sequence of (x, y) coordinates of points in the exterior ring
+ `interiors`: None, or a sequence of sequences of (x, y) coordinates describing interior linear rings
+ """
+ self.exterior = _LinearRing(exterior)
+ self.interiors = tuple(_LinearRing(e) for e in interiors) if interiors else tuple()
+
+ def __eq__(self, other):
+ return isinstance(other, Polygon) and self.exterior == other.exterior and self.interiors == other.interiors
+
+ def __hash__(self):
+ return hash((self.exterior, self.interiors))
+
+ def __str__(self):
+ """
+ Well-known text representation of the polygon
+ """
+ if not self.exterior.coords:
+ return "POLYGON EMPTY"
+ rings = [ring.coords for ring in chain((self.exterior,), self.interiors)]
+ rings = ["(%s)" % ', '.join("%r %r" % (x, y) for x, y in ring) for ring in rings]
+ return "POLYGON (%s)" % ', '.join(rings)
+
+ def __repr__(self):
+ return "%s(%r, %r)" % (self.__class__.__name__, self.exterior.coords, [ring.coords for ring in self.interiors])
+
+ @staticmethod
+ def from_wkt(s):
+ """
+ Parse a Polygon geometry from a wkt string and return a new Polygon object.
+ """
+ if not _HAS_GEOMET:
+ raise DriverException("Geomet is required to deserialize a wkt geometry.")
+
+ try:
+ geom = wkt.loads(s)
+ except ValueError:
+ raise ValueError("Invalid WKT geometry: '{0}'".format(s))
+
+ if geom['type'] != 'Polygon':
+ raise ValueError("Invalid WKT geometry type. Expected 'Polygon', got '{0}': '{1}'".format(geom['type'], s))
+
+ coords = geom['coordinates']
+ exterior = coords[0] if len(coords) > 0 else tuple()
+ interiors = coords[1:] if len(coords) > 1 else None
+
+ return Polygon(exterior=exterior, interiors=interiors)
+
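+# Editor's note: illustrative sketch, not part of the driver source. The exterior
+# ring and any interior holes are each wrapped in a _LinearRing.
+def _example_polygon_wkt():
+    poly = Polygon(exterior=[(0.0, 0.0), (4.0, 0.0), (4.0, 4.0), (0.0, 0.0)])
+    assert str(poly) == "POLYGON ((0.0 0.0, 4.0 0.0, 4.0 4.0, 0.0 0.0))"
+    assert poly.interiors == tuple()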
+
+_distance_wkt_pattern = re.compile("distance *\\( *\\( *([\\d\\.-]+) *([\\d+\\.-]+) *\\) *([\\d+\\.-]+) *\\) *$", re.IGNORECASE)
+
+
+class Distance(object):
+ """
+ Represents a Distance geometry for DSE
+ """
+
+ x = None
+ """
+ x coordinate of the center point
+ """
+
+ y = None
+ """
+ y coordinate of the center point
+ """
+
+ radius = None
+ """
+ radius to represent the distance from the center point
+ """
+
+ def __init__(self, x=_nan, y=_nan, radius=_nan):
+ self.x = x
+ self.y = y
+ self.radius = radius
+
+ def __eq__(self, other):
+ return isinstance(other, Distance) and self.x == other.x and self.y == other.y and self.radius == other.radius
+
+ def __hash__(self):
+ return hash((self.x, self.y, self.radius))
+
+ def __str__(self):
+ """
+ Well-known text representation of the point
+ """
+ return "DISTANCE ((%r %r) %r)" % (self.x, self.y, self.radius)
+
+ def __repr__(self):
+ return "%s(%r, %r, %r)" % (self.__class__.__name__, self.x, self.y, self.radius)
+
+ @staticmethod
+ def from_wkt(s):
+ """
+ Parse a Distance geometry from a wkt string and return a new Distance object.
+ """
+
+ distance_match = _distance_wkt_pattern.match(s)
+
+ if distance_match is None:
+ raise ValueError("Invalid WKT geometry: '{0}'".format(s))
+
+ x, y, radius = distance_match.groups()
+ return Distance(x, y, radius)
+
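+# Editor's note: illustrative sketch, not part of the driver source. Note that
+# Distance.from_wkt, as written above, keeps the regex captures as strings rather
+# than converting them to floats.
+def _example_distance_wkt():
+    d = Distance(1.5, 2.5, 3.0)
+    assert str(d) == "DISTANCE ((1.5 2.5) 3.0)"
+    parsed = Distance.from_wkt("DISTANCE ((1.5 2.5) 3.0)")
+    assert (parsed.x, parsed.y, parsed.radius) == ("1.5", "2.5", "3.0")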
+
class Duration(object):
"""
Cassandra Duration Type
"""
months = 0
+ ""
days = 0
+ ""
nanoseconds = 0
+ ""
def __init__(self, months=0, days=0, nanoseconds=0):
self.months = months
self.days = days
self.nanoseconds = nanoseconds
def __eq__(self, other):
return isinstance(other, self.__class__) and self.months == other.months and self.days == other.days and self.nanoseconds == other.nanoseconds
def __repr__(self):
return "Duration({0}, {1}, {2})".format(self.months, self.days, self.nanoseconds)
def __str__(self):
has_negative_values = self.months < 0 or self.days < 0 or self.nanoseconds < 0
return '%s%dmo%dd%dns' % (
'-' if has_negative_values else '',
abs(self.months),
abs(self.days),
abs(self.nanoseconds)
)
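# Editor's note: illustrative sketch, not part of the driver source, showing the
# compact string form used for Duration values.
def _example_duration_str():
    assert str(Duration(1, 2, 3)) == "1mo2d3ns"
    assert str(Duration(months=-1, days=-2, nanoseconds=-3)) == "-1mo2d3ns"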
+class DateRangePrecision(object):
+ """
+ An "enum" representing the valid values for :attr:`DateRange.precision`.
+ """
+ YEAR = 'YEAR'
+ """
+ """
+
+ MONTH = 'MONTH'
+ """
+ """
+
+ DAY = 'DAY'
+ """
+ """
+
+ HOUR = 'HOUR'
+ """
+ """
+
+ MINUTE = 'MINUTE'
+ """
+ """
+
+ SECOND = 'SECOND'
+ """
+ """
+
+ MILLISECOND = 'MILLISECOND'
+ """
+ """
+
+ PRECISIONS = (YEAR, MONTH, DAY, HOUR,
+ MINUTE, SECOND, MILLISECOND)
+ """
+ """
+
+ @classmethod
+ def _to_int(cls, precision):
+ return cls.PRECISIONS.index(precision.upper())
+
+ @classmethod
+ def _round_to_precision(cls, ms, precision, default_dt):
+ try:
+ dt = utc_datetime_from_ms_timestamp(ms)
+ except OverflowError:
+ return ms
+ precision_idx = cls._to_int(precision)
+ replace_kwargs = {}
+ if precision_idx <= cls._to_int(DateRangePrecision.YEAR):
+ replace_kwargs['month'] = default_dt.month
+ if precision_idx <= cls._to_int(DateRangePrecision.MONTH):
+ replace_kwargs['day'] = default_dt.day
+ if precision_idx <= cls._to_int(DateRangePrecision.DAY):
+ replace_kwargs['hour'] = default_dt.hour
+ if precision_idx <= cls._to_int(DateRangePrecision.HOUR):
+ replace_kwargs['minute'] = default_dt.minute
+ if precision_idx <= cls._to_int(DateRangePrecision.MINUTE):
+ replace_kwargs['second'] = default_dt.second
+ if precision_idx <= cls._to_int(DateRangePrecision.SECOND):
+ # truncate to nearest 1000 so we deal in ms, not us
+ replace_kwargs['microsecond'] = (default_dt.microsecond // 1000) * 1000
+ if precision_idx == cls._to_int(DateRangePrecision.MILLISECOND):
+ replace_kwargs['microsecond'] = int(round(dt.microsecond, -3))
+ return ms_timestamp_from_datetime(dt.replace(**replace_kwargs))
+
+ @classmethod
+ def round_up_to_precision(cls, ms, precision):
+ # PYTHON-912: this is the only case in which we can't take as upper bound
+ # datetime.datetime.max because the month from ms may be February and we'd
+ # be setting 31 as the month day
+ if precision == cls.MONTH:
+ date_ms = utc_datetime_from_ms_timestamp(ms)
+ upper_date = datetime.datetime.max.replace(year=date_ms.year, month=date_ms.month,
+ day=calendar.monthrange(date_ms.year, date_ms.month)[1])
+ else:
+ upper_date = datetime.datetime.max
+ return cls._round_to_precision(ms, precision, upper_date)
+
+ @classmethod
+ def round_down_to_precision(cls, ms, precision):
+ return cls._round_to_precision(ms, precision, datetime.datetime.min)
+
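+# Editor's note: illustrative sketch, not part of the driver source. It assumes
+# the module's ms_timestamp_from_datetime/utc_datetime_from_ms_timestamp helpers
+# (defined elsewhere in this file).
+def _example_daterange_precision_rounding():
+    ms = ms_timestamp_from_datetime(datetime.datetime(2021, 3, 18, 12, 34, 56))
+    day_floor = DateRangePrecision.round_down_to_precision(ms, DateRangePrecision.DAY)
+    assert utc_datetime_from_ms_timestamp(day_floor) == datetime.datetime(2021, 3, 18)
+    day_ceil = DateRangePrecision.round_up_to_precision(ms, DateRangePrecision.DAY)
+    assert utc_datetime_from_ms_timestamp(day_ceil).hour == 23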
+
+@total_ordering
+class DateRangeBound(object):
+ """DateRangeBound(value, precision)
+ Represents a single date value and its precision for :class:`DateRange`.
+
+ .. attribute:: milliseconds
+
+ Integer representing milliseconds since the UNIX epoch. May be negative.
+
+ .. attribute:: precision
+
+ String representing the precision of a bound. Must be a valid
+ :class:`DateRangePrecision` member.
+
+ :class:`DateRangeBound` uses a millisecond offset from the UNIX epoch to
+ allow :class:`DateRange` to represent values `datetime.datetime` cannot.
+ For such values, string representations will show this offset rather than the
+ CQL representation.
+ """
+ milliseconds = None
+ precision = None
+
+ def __init__(self, value, precision):
+ """
+ :param value: a value representing ms since the epoch. Accepts an
+ integer or a datetime.
+ :param precision: a string representing precision
+ """
+ if precision is not None:
+ try:
+ self.precision = precision.upper()
+ except AttributeError:
+ raise TypeError('precision must be a string; got %r' % precision)
+
+ if value is None:
+ milliseconds = None
+ elif isinstance(value, six.integer_types):
+ milliseconds = value
+ elif isinstance(value, datetime.datetime):
+ value = value.replace(
+ microsecond=int(round(value.microsecond, -3))
+ )
+ milliseconds = ms_timestamp_from_datetime(value)
+ else:
+ raise ValueError('%r is not a valid value for DateRangeBound' % value)
+
+ self.milliseconds = milliseconds
+ self.validate()
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return (self.milliseconds == other.milliseconds and
+ self.precision == other.precision)
+
+ def __lt__(self, other):
+ return ((str(self.milliseconds), str(self.precision)) <
+ (str(other.milliseconds), str(other.precision)))
+
+ def datetime(self):
+ """
+ Return :attr:`milliseconds` as a :class:`datetime.datetime` if possible.
+ Raises an `OverflowError` if the value is out of range.
+ """
+ return utc_datetime_from_ms_timestamp(self.milliseconds)
+
+ def validate(self):
+ attrs = self.milliseconds, self.precision
+ if attrs == (None, None):
+ return
+ if None in attrs:
+ raise TypeError(
+ ("%s.datetime and %s.precision must not be None unless both "
+ "are None; Got: %r") % (self.__class__.__name__,
+ self.__class__.__name__,
+ self)
+ )
+ if self.precision not in DateRangePrecision.PRECISIONS:
+ raise ValueError(
+ "%s.precision: expected value in %r; got %r" % (
+ self.__class__.__name__,
+ DateRangePrecision.PRECISIONS,
+ self.precision
+ )
+ )
+
+ @classmethod
+ def from_value(cls, value):
+ """
+ Construct a new :class:`DateRangeBound` from a given value. If
+ possible, use the `value['milliseconds']` and `value['precision']` keys
+ of the argument. Otherwise, use the argument as a `(milliseconds,
+ precision)` iterable.
+
+ :param value: a dictlike or iterable object
+ """
+ if isinstance(value, cls):
+ return value
+
+ # if possible, use as a mapping
+ try:
+ milliseconds, precision = value.get('milliseconds'), value.get('precision')
+ except AttributeError:
+ milliseconds = precision = None
+ if milliseconds is not None and precision is not None:
+ return DateRangeBound(value=milliseconds, precision=precision)
+
+ # otherwise, use as an iterable
+ return DateRangeBound(*value)
+
+ def round_up(self):
+ if self.milliseconds is None or self.precision is None:
+ return self
+ self.milliseconds = DateRangePrecision.round_up_to_precision(
+ self.milliseconds, self.precision
+ )
+ return self
+
+ def round_down(self):
+ if self.milliseconds is None or self.precision is None:
+ return self
+ self.milliseconds = DateRangePrecision.round_down_to_precision(
+ self.milliseconds, self.precision
+ )
+ return self
+
+ _formatter_map = {
+ DateRangePrecision.YEAR: '%Y',
+ DateRangePrecision.MONTH: '%Y-%m',
+ DateRangePrecision.DAY: '%Y-%m-%d',
+ DateRangePrecision.HOUR: '%Y-%m-%dT%HZ',
+ DateRangePrecision.MINUTE: '%Y-%m-%dT%H:%MZ',
+ DateRangePrecision.SECOND: '%Y-%m-%dT%H:%M:%SZ',
+ DateRangePrecision.MILLISECOND: '%Y-%m-%dT%H:%M:%S',
+ }
+
+ def __str__(self):
+ if self == OPEN_BOUND:
+ return '*'
+
+ try:
+ dt = self.datetime()
+ except OverflowError:
+ return '%sms' % (self.milliseconds,)
+
+ formatted = dt.strftime(self._formatter_map[self.precision])
+
+ if self.precision == DateRangePrecision.MILLISECOND:
+ # we'd like to just format with '%Y-%m-%dT%H:%M:%S.%fZ', but %f
+ # gives us more precision than we want, so we strftime up to %S and
+ # do the rest ourselves
+ return '%s.%03dZ' % (formatted, dt.microsecond / 1000)
+
+ return formatted
+
+ def __repr__(self):
+ return '%s(milliseconds=%r, precision=%r)' % (
+ self.__class__.__name__, self.milliseconds, self.precision
+ )
+
+
+OPEN_BOUND = DateRangeBound(value=None, precision=None)
+"""
+Represents `*`, an open value or bound for :class:`DateRange`.
+"""
+
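+# Editor's note: illustrative sketch, not part of the driver source, showing how
+# a DateRangeBound renders at a given precision and how OPEN_BOUND renders as '*'.
+def _example_daterange_bound():
+    bound = DateRangeBound(datetime.datetime(2021, 3, 18, 12, 0, 0), DateRangePrecision.DAY)
+    assert bound.precision == 'DAY'
+    assert str(bound) == '2021-03-18'
+    assert str(OPEN_BOUND) == '*'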
+
+@total_ordering
+class DateRange(object):
+ """DateRange(lower_bound=None, upper_bound=None, value=None)
+ DSE DateRange Type
+
+ .. attribute:: lower_bound
+
+ :class:`~DateRangeBound` representing the lower bound of a bounded range.
+
+ .. attribute:: upper_bound
+
+ :class:`~DateRangeBound` representing the upper bound of a bounded range.
+
+ .. attribute:: value
+
+ :class:`~DateRangeBound` representing the value of a single-value range.
+
+ As noted in its documentation, :class:`DateRangeBound` uses a millisecond
+ offset from the UNIX epoch to allow :class:`DateRange` to represent values
+ `datetime.datetime` cannot. For such values, string representations will show
+ this offset rather than the CQL representation.
+ """
+ lower_bound = None
+ upper_bound = None
+ value = None
+
+ def __init__(self, lower_bound=None, upper_bound=None, value=None):
+ """
+ :param lower_bound: a :class:`DateRangeBound` or object accepted by
+ :meth:`DateRangeBound.from_value` to be used as a
+ :attr:`lower_bound`. Mutually exclusive with `value`. If
+ `upper_bound` is specified and this is not, the :attr:`lower_bound`
+ will be open.
+ :param upper_bound: a :class:`DateRangeBound` or object accepted by
+ :meth:`DateRangeBound.from_value` to be used as a
+ :attr:`upper_bound`. Mutually exclusive with `value`. If
+ `lower_bound` is specified and this is not, the :attr:`upper_bound`
+ will be open.
+ :param value: a :class:`DateRangeBound` or object accepted by
+ :meth:`DateRangeBound.from_value` to be used as :attr:`value`. Mutually
+ exclusive with `lower_bound` and `upper_bound`.
+ """
+
+ # if necessary, transform non-None args to DateRangeBounds
+ lower_bound = (DateRangeBound.from_value(lower_bound).round_down()
+ if lower_bound else lower_bound)
+ upper_bound = (DateRangeBound.from_value(upper_bound).round_up()
+ if upper_bound else upper_bound)
+ value = (DateRangeBound.from_value(value).round_down()
+ if value else value)
+
+ # if we're using a 2-ended range but one bound isn't specified, specify
+ # an open bound
+ if lower_bound is None and upper_bound is not None:
+ lower_bound = OPEN_BOUND
+ if upper_bound is None and lower_bound is not None:
+ upper_bound = OPEN_BOUND
+
+ self.lower_bound, self.upper_bound, self.value = (
+ lower_bound, upper_bound, value
+ )
+ self.validate()
+
+ def validate(self):
+ if self.value is None:
+ if self.lower_bound is None or self.upper_bound is None:
+ raise ValueError(
+ '%s instances where value attribute is None must set '
+ 'lower_bound or upper_bound; got %r' % (
+ self.__class__.__name__,
+ self
+ )
+ )
+ else: # self.value is not None
+ if self.lower_bound is not None or self.upper_bound is not None:
+ raise ValueError(
+ '%s instances where value attribute is not None must not '
+ 'set lower_bound or upper_bound; got %r' % (
+ self.__class__.__name__,
+ self
+ )
+ )
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ return (self.lower_bound == other.lower_bound and
+ self.upper_bound == other.upper_bound and
+ self.value == other.value)
+
+ def __lt__(self, other):
+ return ((str(self.lower_bound), str(self.upper_bound), str(self.value)) <
+ (str(other.lower_bound), str(other.upper_bound), str(other.value)))
+
+ def __str__(self):
+ if self.value:
+ return str(self.value)
+ else:
+ return '[%s TO %s]' % (self.lower_bound, self.upper_bound)
+
+ def __repr__(self):
+ return '%s(lower_bound=%r, upper_bound=%r, value=%r)' % (
+ self.__class__.__name__,
+ self.lower_bound, self.upper_bound, self.value
+ )
+
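+# Editor's note: illustrative sketch, not part of the driver source. A DateRange
+# is either a single value or a [lower TO upper] pair; an unspecified bound
+# becomes the open bound '*'.
+def _example_daterange():
+    rng = DateRange(lower_bound=(0, DateRangePrecision.YEAR))
+    assert str(rng) == '[1970 TO *]'
+    single = DateRange(value=(0, DateRangePrecision.YEAR))
+    assert str(single) == '1970'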
+
@total_ordering
class Version(object):
"""
Internal minimalist class to compare versions.
A valid version is: <int>.<int>.<int>.<int or str>.
TODO: when python2 support is removed, use packaging.version.
"""
_version = None
major = None
minor = 0
patch = 0
build = 0
prerelease = 0
def __init__(self, version):
self._version = version
if '-' in version:
version_without_prerelease, self.prerelease = version.split('-', 1)
else:
version_without_prerelease = version
parts = list(reversed(version_without_prerelease.split('.')))
if len(parts) > 4:
- raise ValueError("Invalid version: {}. Only 4 "
- "components plus prerelease are supported".format(version))
+ prerelease_string = "-{}".format(self.prerelease) if self.prerelease else ""
+ log.warning("Unrecognized version: {}. Only 4 components plus prerelease are supported. "
+ "Assuming version as {}{}".format(version, '.'.join(parts[:-5:-1]), prerelease_string))
- self.major = int(parts.pop())
- self.minor = int(parts.pop()) if parts else 0
- self.patch = int(parts.pop()) if parts else 0
+ try:
+ self.major = int(parts.pop())
+ except ValueError:
+ six.reraise(
+ ValueError,
+ ValueError("Couldn't parse version {}. Version should start with a number".format(version)),
+ sys.exc_info()[2]
+ )
+ try:
+ self.minor = int(parts.pop()) if parts else 0
+ self.patch = int(parts.pop()) if parts else 0
- if parts: # we have a build version
- build = parts.pop()
- try:
- self.build = int(build)
- except ValueError:
- self.build = build
+ if parts: # we have a build version
+ build = parts.pop()
+ try:
+ self.build = int(build)
+ except ValueError:
+ self.build = build
+ except ValueError:
+ assumed_version = "{}.{}.{}.{}-{}".format(self.major, self.minor, self.patch, self.build, self.prerelease)
+ log.warning("Unrecognized version {}. Assuming version as {}".format(version, assumed_version))
def __hash__(self):
return hash(self._version)
def __repr__(self):
version_string = "Version({0}, {1}, {2}".format(self.major, self.minor, self.patch)
if self.build:
version_string += ", {}".format(self.build)
if self.prerelease:
version_string += ", {}".format(self.prerelease)
version_string += ")"
return version_string
def __str__(self):
return self._version
@staticmethod
def _compare_version_part(version, other_version, cmp):
if not (isinstance(version, six.integer_types) and
isinstance(other_version, six.integer_types)):
version = str(version)
other_version = str(other_version)
return cmp(version, other_version)
def __eq__(self, other):
if not isinstance(other, Version):
return NotImplemented
return (self.major == other.major and
self.minor == other.minor and
self.patch == other.patch and
self._compare_version_part(self.build, other.build, lambda s, o: s == o) and
self._compare_version_part(self.prerelease, other.prerelease, lambda s, o: s == o)
)
def __gt__(self, other):
if not isinstance(other, Version):
return NotImplemented
is_major_ge = self.major >= other.major
is_minor_ge = self.minor >= other.minor
is_patch_ge = self.patch >= other.patch
is_build_gt = self._compare_version_part(self.build, other.build, lambda s, o: s > o)
is_build_ge = self._compare_version_part(self.build, other.build, lambda s, o: s >= o)
# By definition, a prerelease comes BEFORE the actual release, so if a version
# doesn't have a prerelease, it's automatically greater than anything that does
if self.prerelease and not other.prerelease:
is_prerelease_gt = False
elif other.prerelease and not self.prerelease:
is_prerelease_gt = True
else:
is_prerelease_gt = self._compare_version_part(self.prerelease, other.prerelease, lambda s, o: s > o)
return (self.major > other.major or
(is_major_ge and self.minor > other.minor) or
(is_major_ge and is_minor_ge and self.patch > other.patch) or
(is_major_ge and is_minor_ge and is_patch_ge and is_build_gt) or
(is_major_ge and is_minor_ge and is_patch_ge and is_build_ge and is_prerelease_gt)
)
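# Editor's note: illustrative sketch, not part of the driver source, showing how
# Version orders prereleases before the corresponding release and compares builds.
def _example_version_ordering():
    assert Version('4.0.0') > Version('3.11.9')
    assert Version('4.0.0') > Version('4.0.0-beta1')   # prerelease sorts first
    assert Version('4.0.0.681') > Version('4.0.0')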
diff --git a/docs.yaml b/docs.yaml
index 6212699..8e29b94 100644
--- a/docs.yaml
+++ b/docs.yaml
@@ -1,65 +1,75 @@
-title: DataStax Python Driver for Apache Cassandra
-summary: DataStax Python Driver for Apache Cassandra Documentation
+title: DataStax Python Driver
+summary: DataStax Python Driver for Apache Cassandra®
output: docs/_build/
swiftype_drivers: pythondrivers
checks:
external_links:
exclude:
- 'http://aka.ms/vcpython27'
sections:
- title: N/A
prefix: /
type: sphinx
directory: docs
virtualenv_init: |
set -x
- CASS_DRIVER_NO_CYTHON=1 pip install -r test-requirements.txt
+ CASS_DRIVER_NO_CYTHON=1 pip install -r test-datastax-requirements.txt
# for newer versions this is redundant, but in older versions we need to
# install, e.g., the cassandra driver, and those versions don't specify
# the cassandra driver version in requirements files
CASS_DRIVER_NO_CYTHON=1 python setup.py develop
pip install "jinja2==2.8.1;python_version<'3.6'" "sphinx>=1.3,<2" geomet
# build extensions like libev
CASS_DRIVER_NO_CYTHON=1 python setup.py build_ext --inplace --force
versions:
+ - name: '3.25'
+ ref: a83c36a5
+ - name: '3.24'
+ ref: 21cac12b
+ - name: '3.23'
+ ref: a40a2af7
+ - name: '3.22'
+ ref: 1ccd5b99
+ - name: '3.21'
+ ref: 5589d96b
- name: '3.20'
ref: d30d166f
- name: '3.19'
ref: ac2471f9
- name: '3.18'
ref: ec36b957
- name: '3.17'
ref: 38e359e1
- name: '3.16'
ref: '3.16.0'
- name: '3.15'
ref: '2ce0bd97'
- name: '3.14'
ref: '9af8bd19'
- name: '3.13'
ref: '3.13.0'
- name: '3.12'
ref: '43b9c995'
- name: '3.11'
ref: '3.11.0'
- name: '3.10'
ref: 64572368
- name: 3.9
ref: 3.9-doc
- name: 3.8
ref: 3.8-doc
- name: 3.7
ref: 3.7-doc
- name: 3.6
ref: 3.6-doc
- name: 3.5
ref: 3.5-doc
redirects:
- \A\/(.*)/\Z: /\1.html
rewrites:
- search: cassandra.apache.org/doc/cql3/CQL.html
replace: cassandra.apache.org/doc/cql3/CQL-3.0.html
- search: http://www.datastax.com/documentation/cql/3.1/
replace: https://docs.datastax.com/en/archived/cql/3.1/
- search: http://www.datastax.com/docs/1.2/cql_cli/cql/BATCH
replace: https://docs.datastax.com/en/dse/6.7/cql/cql/cql_reference/cql_commands/cqlBatch.html
diff --git a/docs/.nav b/docs/.nav
index 7b39d90..375f058 100644
--- a/docs/.nav
+++ b/docs/.nav
@@ -1,14 +1,18 @@
installation
getting_started
execution_profiles
lwt
object_mapper
+geo_types
+graph
+graph_fluent
+classic_graph
performance
query_paging
security
upgrading
user_defined_types
dates_and_times
cloud
faq
api
diff --git a/docs/CHANGELOG.rst b/docs/CHANGELOG.rst
new file mode 100644
index 0000000..592a2c0
--- /dev/null
+++ b/docs/CHANGELOG.rst
@@ -0,0 +1,5 @@
+*********
+CHANGELOG
+*********
+
+.. include:: ../CHANGELOG.rst
diff --git a/docs/api/cassandra/cluster.rst b/docs/api/cassandra/cluster.rst
index 81cf1f0..2b3d782 100644
--- a/docs/api/cassandra/cluster.rst
+++ b/docs/api/cassandra/cluster.rst
@@ -1,209 +1,228 @@
``cassandra.cluster`` - Clusters and Sessions
=============================================
.. module:: cassandra.cluster
.. autoclass:: Cluster ([contact_points=('127.0.0.1',)][, port=9042][, executor_threads=2], **attr_kwargs)
.. autoattribute:: contact_points
.. autoattribute:: port
.. autoattribute:: cql_version
.. autoattribute:: protocol_version
.. autoattribute:: compression
.. autoattribute:: auth_provider
.. autoattribute:: load_balancing_policy
.. autoattribute:: reconnection_policy
.. autoattribute:: default_retry_policy
:annotation: =
.. autoattribute:: conviction_policy_factory
.. autoattribute:: address_translator
.. autoattribute:: metrics_enabled
.. autoattribute:: metrics
.. autoattribute:: ssl_context
.. autoattribute:: ssl_options
.. autoattribute:: sockopts
.. autoattribute:: max_schema_agreement_wait
.. autoattribute:: metadata
.. autoattribute:: connection_class
.. autoattribute:: control_connection_timeout
.. autoattribute:: idle_heartbeat_interval
.. autoattribute:: idle_heartbeat_timeout
.. autoattribute:: schema_event_refresh_window
.. autoattribute:: topology_event_refresh_window
.. autoattribute:: status_event_refresh_window
.. autoattribute:: prepare_on_all_hosts
.. autoattribute:: reprepare_on_up
.. autoattribute:: connect_timeout
.. autoattribute:: schema_metadata_enabled
:annotation: = True
.. autoattribute:: token_metadata_enabled
:annotation: = True
.. autoattribute:: timestamp_generator
.. autoattribute:: endpoint_factory
.. autoattribute:: cloud
.. automethod:: connect
.. automethod:: shutdown
.. automethod:: register_user_type
.. automethod:: register_listener
.. automethod:: unregister_listener
.. automethod:: add_execution_profile
.. automethod:: set_max_requests_per_connection
.. automethod:: get_max_requests_per_connection
.. automethod:: set_min_requests_per_connection
.. automethod:: get_min_requests_per_connection
.. automethod:: get_core_connections_per_host
.. automethod:: set_core_connections_per_host
.. automethod:: get_max_connections_per_host
.. automethod:: set_max_connections_per_host
.. automethod:: get_control_connection_host
.. automethod:: refresh_schema_metadata
.. automethod:: refresh_keyspace_metadata
.. automethod:: refresh_table_metadata
.. automethod:: refresh_user_type_metadata
.. automethod:: refresh_user_function_metadata
.. automethod:: refresh_user_aggregate_metadata
.. automethod:: refresh_nodes
.. automethod:: set_meta_refresh_enabled
-.. autoclass:: ExecutionProfile (load_balancing_policy=