diff --git a/.gitignore b/.gitignore
index 5c9cbec957..30ff731f85 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,3 +42,6 @@ tests/unit/cython/bytesio_testhelper.c
#iPython
*.ipynb
+venv
+docs/venv
+.eggs
\ No newline at end of file
diff --git a/.travis.yml b/.travis.yml
index f1fff4bb63..5a483f9a03 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,14 +1,11 @@
-dist: xenial
+dist: jammy
sudo: false
language: python
python:
- - "2.7"
- - "3.5"
- - "3.6"
- - "3.7"
- - "pypy2.7-6.0"
- - "pypy3.5"
+ - "3.8"
+ - "3.9"
+ - "3.10"
env:
- CASS_DRIVER_NO_CYTHON=1
@@ -17,14 +14,16 @@ addons:
apt:
packages:
- build-essential
- - python-dev
+ - python3-dev
- pypy-dev
- libc-ares-dev
- libev4
- libev-dev
install:
- - pip install tox-travis lz4
+ - pip install --upgrade setuptools importlib-metadata
+ - pip install tox-travis
+ - if [[ $TRAVIS_PYTHON_VERSION != pypy3.5 ]]; then pip install lz4; fi
script:
- tox
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index ae1b50a589..9dce17dcb6 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,3 +1,179 @@
+3.29.1
+======
+March 19, 2024
+
+Bug Fixes
+---------
+* cassandra-driver for Python 3.12 Linux is compiled without libev support (PYTHON-1378)
+* Consider moving to native wheel builds for OS X and removing universal2 wheels (PYTHON-1379)
+
+3.29.0
+======
+December 19, 2023
+
+Features
+--------
+* Add support for Python 3.9 through 3.12, drop support for 3.7 (PYTHON-1283)
+* Removal of dependency on six module (PR 1172)
+* Raise explicit exception when deserializing a vector with a subtype that isn’t a constant size (PYTHON-1371)
+
+Others
+------
+* Remove outdated Python pre-3.7 references (PR 1186)
+* Remove backup(.bak) files (PR 1185)
+* Fix doc typo in add_callbacks (PR 1177)
+
+3.28.0
+======
+June 5, 2023
+
+Features
+--------
+* Add support for vector type (PYTHON-1352)
+* Cryptography module is now an optional dependency (PYTHON-1351)
+
+Bug Fixes
+---------
+* Store IV along with encrypted text when using column-level encryption (PYTHON-1350)
+* Create session-specific protocol handlers to contain session-specific CLE policies (PYTHON-1356)
+
+Others
+------
+* Use Cython for smoke builds (PYTHON-1343)
+* Don't fail when inserting UDTs with some missing fields using prepared queries (PR 1151)
+* Convert print statement to function in docs (PR 1157)
+* Update comment for retry policy (DOC-3278)
+* Added error handling blog reference (DOC-2813)
+
+3.27.0
+======
+May 1, 2023
+
+Features
+--------
+* Add support for client-side encryption (PYTHON-1341)
+
+3.26.0
+======
+March 13, 2023
+
+Features
+--------
+* Add support for execution profiles in execute_concurrent (PR 1122)
+
+Bug Fixes
+---------
+* Handle empty non-final result pages (PR 1110)
+* Do not re-use stream IDs for in-flight requests (PR 1114)
+* Asyncore race condition causes logging exception on shutdown (PYTHON-1266)
+
+Others
+------
+* Fix deprecation warning in query tracing (PR 1103)
+* Remove mutable default values from some tests (PR 1116)
+* Remove dependency on unittest2 (PYTHON-1289)
+* Fix deprecation warnings for asyncio.coroutine annotation in asyncioreactor (PYTHON-1290)
+* Fix typos in source files (PR 1126)
+* HostFilterPolicyInitTest fix for Python 3.11 (PR 1131)
+* Fix for DontPrepareOnIgnoredHostsTest (PYTHON-1287)
+* tests.integration.simulacron.test_connection failures (PYTHON-1304)
+* tests.integration.standard.test_single_interface.py appears to be failing for C* 4.0 (PYTHON-1329)
+* Authentication tests appear to be failing spuriously (PYTHON-1328)
+* PreparedStatementTests.test_fail_if_different_query_id_on_reprepare() failing unexpectedly (PYTHON-1327)
+* Refactor deprecated unittest aliases for Python 3.11 compatibility (PR 1112)
+
+Deprecations
+------------
+* This release removes support for Python 2.7.x as well as Python 3.5.x and 3.6.x
+
+3.25.0
+======
+March 18, 2021
+
+Features
+--------
+* Ensure the driver can connect when invalid peer hosts are in system.peers (PYTHON-1260)
+* Implement protocol v5 checksumming (PYTHON-1258)
+* Fix the default cqlengine connection mechanism to work with Astra (PYTHON-1265)
+
+Bug Fixes
+---------
+* Asyncore race condition causes logging exception on shutdown (PYTHON-1266)
+* Update list of reserved keywords (PYTHON-1269)
+
+Others
+------
+* Drop Python 3.4 support (PYTHON-1220)
+* Update security documentation and examples to use PROTOCOL_TLS (PYTHON-1264)
+
+3.24.0
+======
+June 18, 2020
+
+Features
+--------
+* Make geomet an optional dependency at runtime (PYTHON-1237)
+* Add use_default_tempdir cloud config options (PYTHON-1245)
+* Tcp flow control for libevreactor (PYTHON-1248)
+
+Bug Fixes
+---------
+* Unable to connect to a cloud cluster using Ubuntu 20.04 (PYTHON-1238)
+* PlainTextAuthProvider fails with unicode chars and Python3 (PYTHON-1241)
+* [GRAPH] Graph execution profiles consistency level are not set to LOCAL_QUORUM with a cloud cluster (PYTHON-1240)
+* [GRAPH] Can't write data in a Boolean field using the Fluent API (PYTHON-1239)
+* [GRAPH] Fix elementMap() result deserialization (PYTHON-1233)
+
+Others
+------
+* Bump geomet dependency version to 0.2 (PYTHON-1243)
+* Bump gremlinpython dependency version to 3.4.6 (PYTHON-1212)
+* Improve fluent graph documentation for core graphs (PYTHON-1244)
+
+3.23.0
+======
+April 6, 2020
+
+Features
+--------
+* Transient Replication Support (PYTHON-1207)
+* Support system.peers_v2 and port discovery for C* 4.0 (PYTHON-700)
+
+Bug Fixes
+---------
+* Asyncore logging exception on shutdown (PYTHON-1228)
+
+3.22.0
+======
+February 26, 2020
+
+Features
+--------
+
+* Add all() function to the ResultSet API (PYTHON-1203)
+* Parse new schema metadata in NGDG and generate table edges CQL syntax (PYTHON-996)
+* Add GraphSON3 support (PYTHON-788)
+* Use GraphSON3 as default for Native graphs (PYTHON-1004)
+* Add Tuple and UDT types for native graph (PYTHON-1005)
+* Add Duration type for native graph (PYTHON-1000)
+* Add gx:ByteBuffer graphson type support for Blob field (PYTHON-1027)
+* Enable Paging Through DSE Driver for Gremlin Traversals (PYTHON-1045)
+* Provide numerical wrappers to ensure proper graphson schema definition (PYTHON-1051)
+* Resolve the row_factory automatically for native graphs (PYTHON-1056)
+* Add g:TraversalMetrics/g:Metrics graph deserializers (PYTHON-1057)
+* Add g:BulkSet graph deserializers (PYTHON-1060)
+* Update Graph Engine names and the way to create a Classic/Native Graph (PYTHON-1090)
+* Update Native to Core Graph Engine
+* Add graphson3 and native graph support (PYTHON-1039)
+* Expose filter predicates for cql collections (PYTHON-1019)
+* Make graph metadata handling more robust (PYTHON-1204)
+
+Bug Fixes
+---------
+* Make sure to only query the native_transport_address column with DSE (PYTHON-1205)
+
3.21.0
======
January 15, 2020
@@ -31,6 +207,7 @@ Others
* Remove *read_repair_chance table options (PYTHON-1140)
* Avoid warnings about unspecified load balancing policy when connecting to a cloud cluster (PYTHON-1177)
* Add new DSE CQL keywords (PYTHON-1122)
+* Publish binary wheel distributions (PYTHON-1013)
Deprecations
------------
@@ -114,7 +291,7 @@ October 28, 2019
Features
--------
-* DataStax Apollo Support (PYTHON-1074)
+* DataStax Astra Support (PYTHON-1074)
* Use 4.0 schema parser in 4 alpha and snapshot builds (PYTHON-1158)
Bug Fixes
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index cdd742c063..e5da81d74f 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -26,7 +26,6 @@ To protect the community, all contributors are required to `sign the DataStax Co
Design and Implementation Guidelines
------------------------------------
-- We support Python 2.7+, so any changes must work in any of these runtimes (we use ``six``, ``futures``, and some internal backports for compatability)
- We have integrations (notably Cassandra cqlsh) that require pure Python and minimal external dependencies. We try to avoid new external dependencies. Where compiled extensions are concerned, there should always be a pure Python fallback implementation.
- This project follows `semantic versioning `_, so breaking API changes will only be introduced in major versions.
- Legacy ``cqlengine`` has varying degrees of overreaching client-side validation. Going forward, we will avoid client validation where server feedback is adequate and not overly expensive.
diff --git a/Jenkinsfile b/Jenkinsfile
new file mode 100644
index 0000000000..fdc5e74269
--- /dev/null
+++ b/Jenkinsfile
@@ -0,0 +1,658 @@
+#!groovy
+/*
+
+There are multiple combinations to test the python driver.
+
+Test Profiles:
+
+ Full: Execute all unit and integration tests, including long tests.
+ Standard: Execute unit and integration tests.
+ Smoke Tests: Execute a small subset of tests.
+ EVENT_LOOP: Execute a small subset of tests selected to test EVENT_LOOPs.
+
+Matrix Types:
+
+ Full: All server versions, python runtimes tested with and without Cython.
+ Cassandra: All cassandra server versions.
+ Dse: All dse server versions.
+ Smoke: CI-friendly configurations. Currently-supported Python version + modern Cassandra/DSE instances.
+ We also avoid cython since it's tested as part of the nightlies
+
+Parameters:
+
+ EVENT_LOOP: 'LIBEV' (Default), 'GEVENT', 'EVENTLET', 'ASYNCIO', 'ASYNCORE', 'TWISTED'
+ CYTHON: Default, 'True', 'False'
+
+*/
+
+@Library('dsdrivers-pipeline-lib@develop')
+import com.datastax.jenkins.drivers.python.Slack
+
+slack = new Slack()
+
+DEFAULT_CASSANDRA = ['3.0', '3.11', '4.0']
+DEFAULT_DSE = ['dse-5.1.35', 'dse-6.8.30']
+DEFAULT_RUNTIME = ['3.8.16', '3.9.16', '3.10.11', '3.11.3', '3.12.0']
+DEFAULT_CYTHON = ["True", "False"]
+matrices = [
+ "FULL": [
+ "SERVER": DEFAULT_CASSANDRA + DEFAULT_DSE,
+ "RUNTIME": DEFAULT_RUNTIME,
+ "CYTHON": DEFAULT_CYTHON
+ ],
+ "CASSANDRA": [
+ "SERVER": DEFAULT_CASSANDRA,
+ "RUNTIME": DEFAULT_RUNTIME,
+ "CYTHON": DEFAULT_CYTHON
+ ],
+ "DSE": [
+ "SERVER": DEFAULT_DSE,
+ "RUNTIME": DEFAULT_RUNTIME,
+ "CYTHON": DEFAULT_CYTHON
+ ],
+ "SMOKE": [
+ "SERVER": DEFAULT_CASSANDRA.takeRight(2) + DEFAULT_DSE.takeRight(1),
+ "RUNTIME": DEFAULT_RUNTIME.take(1) + DEFAULT_RUNTIME.takeRight(1),
+ "CYTHON": ["True"]
+ ]
+]
+
+def initializeSlackContext() {
+ /*
+ Based on git branch/commit, configure the build context and env vars.
+ */
+
+ def driver_display_name = 'Cassandra Python Driver'
+ if (env.GIT_URL.contains('riptano/python-driver')) {
+ driver_display_name = 'private ' + driver_display_name
+ } else if (env.GIT_URL.contains('python-dse-driver')) {
+ driver_display_name = 'DSE Python Driver'
+ }
+ env.DRIVER_DISPLAY_NAME = driver_display_name
+ env.GIT_SHA = "${env.GIT_COMMIT.take(7)}"
+ env.GITHUB_PROJECT_URL = "https://${GIT_URL.replaceFirst(/(git@|http:\/\/|https:\/\/)/, '').replace(':', '/').replace('.git', '')}"
+ env.GITHUB_BRANCH_URL = "${env.GITHUB_PROJECT_URL}/tree/${env.BRANCH_NAME}"
+ env.GITHUB_COMMIT_URL = "${env.GITHUB_PROJECT_URL}/commit/${env.GIT_COMMIT}"
+}
+
+def getBuildContext() {
+ /*
+ Based on schedule and parameters, configure the build context and env vars.
+ */
+
+ def PROFILE = "${params.PROFILE}"
+ def EVENT_LOOP = "${params.EVENT_LOOP.toLowerCase()}"
+
+ matrixType = params.MATRIX != "DEFAULT" ? params.MATRIX : "SMOKE"
+ matrix = matrices[matrixType].clone()
+
+ // Check if parameters were set explicitly
+ if (params.CYTHON != "DEFAULT") {
+ matrix["CYTHON"] = [params.CYTHON]
+ }
+
+ if (params.SERVER_VERSION != "DEFAULT") {
+ matrix["SERVER"] = [params.SERVER_VERSION]
+ }
+
+ if (params.PYTHON_VERSION != "DEFAULT") {
+ matrix["RUNTIME"] = [params.PYTHON_VERSION]
+ }
+
+ if (params.CI_SCHEDULE == "WEEKNIGHTS") {
+ matrix["SERVER"] = params.CI_SCHEDULE_SERVER_VERSION.split(' ')
+ matrix["RUNTIME"] = params.CI_SCHEDULE_PYTHON_VERSION.split(' ')
+ }
+
+ context = [
+ vars: [
+ "PROFILE=${PROFILE}",
+ "EVENT_LOOP=${EVENT_LOOP}"
+ ],
+ matrix: matrix
+ ]
+
+ return context
+}
+
+def buildAndTest(context) {
+ initializeEnvironment()
+ installDriverAndCompileExtensions()
+
+ try {
+ executeTests()
+ } finally {
+ junit testResults: '*_results.xml'
+ }
+}
+
+def getMatrixBuilds(buildContext) {
+ def tasks = [:]
+ matrix = buildContext.matrix
+
+ matrix["SERVER"].each { serverVersion ->
+ matrix["RUNTIME"].each { runtimeVersion ->
+ matrix["CYTHON"].each { cythonFlag ->
+ def taskVars = [
+ "CASSANDRA_VERSION=${serverVersion}",
+ "PYTHON_VERSION=${runtimeVersion}",
+ "CYTHON_ENABLED=${cythonFlag}"
+ ]
+ def cythonDesc = cythonFlag == "True" ? ", Cython": ""
+ tasks["${serverVersion}, py${runtimeVersion}${cythonDesc}"] = {
+ node("${OS_VERSION}") {
+ scm_variables = checkout scm
+ env.GIT_COMMIT = scm_variables.get('GIT_COMMIT')
+ env.GIT_URL = scm_variables.get('GIT_URL')
+ initializeSlackContext()
+
+ if (env.BUILD_STATED_SLACK_NOTIFIED != 'true') {
+ slack.notifyChannel()
+ }
+
+ withEnv(taskVars) {
+ buildAndTest(context)
+ }
+ }
+ }
+ }
+ }
+ }
+ return tasks
+}
+
+def initializeEnvironment() {
+ sh label: 'Initialize the environment', script: '''#!/bin/bash -lex
+ pyenv global ${PYTHON_VERSION}
+ sudo apt-get install socat
+ pip install --upgrade pip
+ pip install -U setuptools
+
+ # install a version of pyyaml<6.0 compatible with ccm-3.1.5 as of Aug 2023
+ # this works around the python-3.10+ compatibility problem as described in DSP-23524
+ pip install wheel
+ pip install "Cython<3.0" "pyyaml<6.0" --no-build-isolation
+ pip install ${HOME}/ccm
+ '''
+
+ // Determine if server version is Apache CassandraⓇ or DataStax Enterprise
+ if (env.CASSANDRA_VERSION.split('-')[0] == 'dse') {
+ if (env.PYTHON_VERSION =~ /3\.12\.\d+/) {
+ echo "Cannot install DSE dependencies for Python 3.12.x; installing Apache CassandraⓇ requirements only. See PYTHON-1368 for more detail."
+ sh label: 'Install Apache CassandraⓇ requirements', script: '''#!/bin/bash -lex
+ pip install -r test-requirements.txt
+ '''
+ }
+ else {
+ sh label: 'Install DataStax Enterprise requirements', script: '''#!/bin/bash -lex
+ pip install -r test-datastax-requirements.txt
+ '''
+ }
+ } else {
+ sh label: 'Install Apache CassandraⓇ requirements', script: '''#!/bin/bash -lex
+ pip install -r test-requirements.txt
+ '''
+
+ sh label: 'Uninstall the geomet dependency since it is not required for Cassandra', script: '''#!/bin/bash -lex
+ pip uninstall -y geomet
+ '''
+ }
+
+ sh label: 'Install unit test modules', script: '''#!/bin/bash -lex
+ pip install --no-deps nose-ignore-docstring nose-exclude
+ pip install service_identity
+ '''
+
+ if (env.CYTHON_ENABLED == 'True') {
+ sh label: 'Install cython modules', script: '''#!/bin/bash -lex
+ pip install cython numpy
+ '''
+ }
+
+ sh label: 'Download Apache CassandraⓇ or DataStax Enterprise', script: '''#!/bin/bash -lex
+ . ${CCM_ENVIRONMENT_SHELL} ${CASSANDRA_VERSION}
+ '''
+
+ if (env.CASSANDRA_VERSION.split('-')[0] == 'dse') {
+ env.DSE_FIXED_VERSION = env.CASSANDRA_VERSION.split('-')[1]
+ sh label: 'Update environment for DataStax Enterprise', script: '''#!/bin/bash -le
+ cat >> ${HOME}/environment.txt << ENVIRONMENT_EOF
+CCM_CASSANDRA_VERSION=${DSE_FIXED_VERSION} # maintain for backwards compatibility
+CCM_VERSION=${DSE_FIXED_VERSION}
+CCM_SERVER_TYPE=dse
+DSE_VERSION=${DSE_FIXED_VERSION}
+CCM_IS_DSE=true
+CCM_BRANCH=${DSE_FIXED_VERSION}
+DSE_BRANCH=${DSE_FIXED_VERSION}
+ENVIRONMENT_EOF
+ '''
+ }
+
+ sh label: 'Display Python and environment information', script: '''#!/bin/bash -le
+ # Load CCM environment variables
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ python --version
+ pip --version
+ pip freeze
+ printenv | sort
+ '''
+}
+
+def installDriverAndCompileExtensions() {
+ if (env.CYTHON_ENABLED == 'True') {
+ sh label: 'Install the driver and compile with C extensions with Cython', script: '''#!/bin/bash -lex
+ python setup.py build_ext --inplace
+ '''
+ } else {
+ sh label: 'Install the driver and compile with C extensions without Cython', script: '''#!/bin/bash -lex
+ python setup.py build_ext --inplace --no-cython
+ '''
+ }
+}
+
+def executeStandardTests() {
+
+ sh label: 'Execute unit tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variables
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP=${EVENT_LOOP} VERIFY_CYTHON=${CYTHON_ENABLED} pynose -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=unit_results.xml tests/unit/ || true
+ EVENT_LOOP=eventlet VERIFY_CYTHON=${CYTHON_ENABLED} pynose -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=unit_eventlet_results.xml tests/unit/io/test_eventletreactor.py || true
+ EVENT_LOOP=gevent VERIFY_CYTHON=${CYTHON_ENABLED} pynose -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=unit_gevent_results.xml tests/unit/io/test_geventreactor.py || true
+ '''
+
+ sh label: 'Execute Simulacron integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variables
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ SIMULACRON_JAR="${HOME}/simulacron.jar"
+ SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} pynose -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --exclude test_backpressure.py --xunit-file=simulacron_results.xml tests/integration/simulacron/ || true
+
+ # Run backpressure tests separately to avoid memory issue
+ SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} pynose -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --exclude test_backpressure.py --xunit-file=simulacron_backpressure_1_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_paused_connections || true
+ SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} pynose -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --exclude test_backpressure.py --xunit-file=simulacron_backpressure_2_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_queued_requests_timeout || true
+ SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} pynose -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --exclude test_backpressure.py --xunit-file=simulacron_backpressure_3_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_cluster_busy || true
+ SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} pynose -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --exclude test_backpressure.py --xunit-file=simulacron_backpressure_4_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_node_busy || true
+ '''
+
+ sh label: 'Execute CQL engine integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variables
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} pynose -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=cqle_results.xml tests/integration/cqlengine/ || true
+ '''
+
+ sh label: 'Execute Apache CassandraⓇ integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variables
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} pynose -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml tests/integration/standard/ || true
+ '''
+
+ if (env.CASSANDRA_VERSION.split('-')[0] == 'dse' && env.CASSANDRA_VERSION.split('-')[1] != '4.8') {
+ if (env.PYTHON_VERSION =~ /3\.12\.\d+/) {
+ echo "Cannot install DSE dependencies for Python 3.12.x. See PYTHON-1368 for more detail."
+ }
+ else {
+ sh label: 'Execute DataStax Enterprise integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variable
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} DSE_VERSION=${DSE_VERSION} ADS_HOME="${HOME}/" VERIFY_CYTHON=${CYTHON_ENABLED} pynose -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=dse_results.xml tests/integration/advanced/ || true
+ '''
+ }
+ }
+
+ sh label: 'Execute DataStax Astra integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variable
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP=${EVENT_LOOP} CLOUD_PROXY_PATH="${HOME}/proxy/" CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} pynose -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=advanced_results.xml tests/integration/cloud/ || true
+ '''
+
+ if (env.PROFILE == 'FULL') {
+ sh label: 'Execute long running integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variable
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} pynose -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --exclude-dir=tests/integration/long/upgrade --with-ignore-docstrings --with-xunit --xunit-file=long_results.xml tests/integration/long/ || true
+ '''
+ }
+}
+
+def executeDseSmokeTests() {
+ sh label: 'Execute profile DataStax Enterprise smoke test integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variable
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} DSE_VERSION=${DSE_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} pynose -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml tests/integration/standard/test_dse.py || true
+ '''
+}
+
+def executeEventLoopTests() {
+ sh label: 'Execute profile event loop manager integration tests', script: '''#!/bin/bash -lex
+ # Load CCM environment variable
+ set -o allexport
+ . ${HOME}/environment.txt
+ set +o allexport
+
+ EVENT_LOOP_TESTS=(
+ "tests/integration/standard/test_cluster.py"
+ "tests/integration/standard/test_concurrent.py"
+ "tests/integration/standard/test_connection.py"
+ "tests/integration/standard/test_control_connection.py"
+ "tests/integration/standard/test_metrics.py"
+ "tests/integration/standard/test_query.py"
+ "tests/integration/simulacron/test_endpoint.py"
+ "tests/integration/long/test_ssl.py"
+ )
+ EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} pynose -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml ${EVENT_LOOP_TESTS[@]} || true
+ '''
+}
+
+def executeTests() {
+ switch(env.PROFILE) {
+ case 'DSE-SMOKE-TEST':
+ executeDseSmokeTests()
+ break
+ case 'EVENT_LOOP':
+ executeEventLoopTests()
+ break
+ default:
+ executeStandardTests()
+ break
+ }
+}
+
+
+// TODO move this in the shared lib
+def getDriverMetricType() {
+ metric_type = 'oss'
+ if (env.GIT_URL.contains('riptano/python-driver')) {
+ metric_type = 'oss-private'
+ } else if (env.GIT_URL.contains('python-dse-driver')) {
+ metric_type = 'dse'
+ }
+ return metric_type
+}
+
+def describeBuild(buildContext) {
+ script {
+ def runtimes = buildContext.matrix["RUNTIME"]
+ def serverVersions = buildContext.matrix["SERVER"]
+ def numBuilds = runtimes.size() * serverVersions.size() * buildContext.matrix["CYTHON"].size()
+ currentBuild.displayName = "${env.PROFILE} (${env.EVENT_LOOP} | ${numBuilds} builds)"
+ currentBuild.description = "${env.PROFILE} build testing servers (${serverVersions.join(', ')}) against Python (${runtimes.join(', ')}) using ${env.EVENT_LOOP} event loop manager"
+ }
+}
+
+// branch pattern for cron
+def branchPatternCron() {
+ ~"(master)"
+}
+
+pipeline {
+ agent none
+
+ // Global pipeline timeout
+ options {
+ disableConcurrentBuilds()
+ timeout(time: 10, unit: 'HOURS') // TODO timeout should be per build
+ buildDiscarder(logRotator(artifactNumToKeepStr: '10', // Keep only the last 10 artifacts
+ numToKeepStr: '50')) // Keep only the last 50 build records
+ }
+
+ parameters {
+ choice(
+ name: 'ADHOC_BUILD_TYPE',
+ choices: ['BUILD', 'BUILD-AND-EXECUTE-TESTS'],
+ description: '''
+            Perform an adhoc build operation
+
+                Choice                   Description
+                BUILD                    Performs a Per-Commit build
+                BUILD-AND-EXECUTE-TESTS  Performs a build and executes the integration and unit tests
+        ''')
+ choice(
+ name: 'PROFILE',
+ choices: ['STANDARD', 'FULL', 'DSE-SMOKE-TEST', 'EVENT_LOOP'],
+            description: '''Profile to utilize for scheduled or adhoc builds
+
+                STANDARD        Execute the standard tests for the driver
+                FULL            Execute all tests for the driver, including long tests.
+                DSE-SMOKE-TEST  Execute only the DataStax Enterprise smoke tests
+                EVENT_LOOP      Execute only the event loop tests for the specified event loop manager (see: EVENT_LOOP)
+            ''')
+ choice(
+ name: 'MATRIX',
+ choices: ['DEFAULT', 'SMOKE', 'FULL', 'CASSANDRA', 'DSE'],
+            description: '''The matrix for the build.
+
+                DEFAULT    Default to the build context.
+                SMOKE      Basic smoke tests for current Python runtimes + C*/DSE versions, no Cython
+                FULL       All server versions, python runtimes tested with and without Cython.
+                CASSANDRA  All cassandra server versions.
+                DSE        All dse server versions.
+            ''')
+ choice(
+ name: 'PYTHON_VERSION',
+ choices: ['DEFAULT'] + DEFAULT_RUNTIME,
+ description: 'Python runtime version. Default to the build context.')
+ choice(
+ name: 'SERVER_VERSION',
+ choices: ['DEFAULT'] + DEFAULT_CASSANDRA + DEFAULT_DSE,
+            description: '''Apache CassandraⓇ and DataStax Enterprise server version to use for adhoc BUILD-AND-EXECUTE-TESTS ONLY!
+
+                DEFAULT     Default to the build context.
+                3.0         Apache CassandraⓇ v3.0.x
+                3.11        Apache CassandraⓇ v3.11.x
+                4.0         Apache CassandraⓇ v4.0.x
+                dse-5.1.35  DataStax Enterprise v5.1.x
+                dse-6.8.30  DataStax Enterprise v6.8.x (CURRENTLY UNDER DEVELOPMENT)
+            ''')
+ choice(
+ name: 'CYTHON',
+ choices: ['DEFAULT'] + DEFAULT_CYTHON,
+            description: '''Flag to determine if Cython should be enabled
+
+                Default  Default to the build context.
+                True     Enable Cython
+                False    Disable Cython
+            ''')
+ choice(
+ name: 'EVENT_LOOP',
+ choices: ['LIBEV', 'GEVENT', 'EVENTLET', 'ASYNCIO', 'ASYNCORE', 'TWISTED'],
+            description: '''Event loop manager to utilize for scheduled or adhoc builds
+
+                LIBEV     A full-featured and high-performance event loop that is loosely modeled after libevent, but without its limitations and bugs
+                GEVENT    A coroutine-based Python networking library that uses greenlet to provide a high-level synchronous API on top of the libev or libuv event loop
+                EVENTLET  A concurrent networking library for Python that allows you to change how you run your code, not how you write it
+                ASYNCIO   A library to write concurrent code using the async/await syntax
+                ASYNCORE  A module that provides the basic infrastructure for writing asynchronous socket service clients and servers
+                TWISTED   An event-driven networking engine written in Python and licensed under the open source MIT license
+            ''')
+ choice(
+ name: 'CI_SCHEDULE',
+ choices: ['DO-NOT-CHANGE-THIS-SELECTION', 'WEEKNIGHTS', 'WEEKENDS'],
+ description: 'CI testing schedule to execute periodically scheduled builds and tests of the driver (DO NOT CHANGE THIS SELECTION)')
+ string(
+ name: 'CI_SCHEDULE_PYTHON_VERSION',
+ defaultValue: 'DO-NOT-CHANGE-THIS-SELECTION',
+ description: 'CI testing python version to utilize for scheduled test runs of the driver (DO NOT CHANGE THIS SELECTION)')
+ string(
+ name: 'CI_SCHEDULE_SERVER_VERSION',
+ defaultValue: 'DO-NOT-CHANGE-THIS-SELECTION',
+ description: 'CI testing server version to utilize for scheduled test runs of the driver (DO NOT CHANGE THIS SELECTION)')
+ }
+
+ triggers {
+ parameterizedCron(branchPatternCron().matcher(env.BRANCH_NAME).matches() ? """
+ # Every weeknight (Monday - Friday) around 4:00 AM
+ # These schedules will run with and without Cython enabled for Python 3.8.16 and 3.12.0
+ H 4 * * 1-5 %CI_SCHEDULE=WEEKNIGHTS;EVENT_LOOP=LIBEV;CI_SCHEDULE_PYTHON_VERSION=3.8.16 3.12.0;CI_SCHEDULE_SERVER_VERSION=3.11 4.0 dse-5.1.35 dse-6.8.30
+ """ : "")
+ }
+
+ environment {
+ OS_VERSION = 'ubuntu/bionic64/python-driver'
+ CCM_ENVIRONMENT_SHELL = '/usr/local/bin/ccm_environment.sh'
+ CCM_MAX_HEAP_SIZE = '1536M'
+ }
+
+ stages {
+ stage ('Build and Test') {
+ when {
+ beforeAgent true
+ allOf {
+ not { buildingTag() }
+ }
+ }
+
+ steps {
+ script {
+ context = getBuildContext()
+ withEnv(context.vars) {
+ describeBuild(context)
+
+ // build and test all builds
+ parallel getMatrixBuilds(context)
+
+ slack.notifyChannel(currentBuild.currentResult)
+ }
+ }
+ }
+ }
+
+ }
+}
diff --git a/README-dev.rst b/README-dev.rst
index 8294d4efb8..adca510412 100644
--- a/README-dev.rst
+++ b/README-dev.rst
@@ -176,7 +176,7 @@ Use tee to capture logs and see them on your terminal::
Testing Multiple Python Versions
--------------------------------
-If you want to test all of python 2.7, 3.4, 3.5, 3.6, 3.7, and pypy, use tox (this is what
+Use tox to test all of Python 3.8 through 3.12 and pypy (this is what
TravisCI runs)::
tox
@@ -241,11 +241,10 @@ Adding a New Python Runtime Support
* Add the new python version to our jenkins image:
https://github.com/riptano/openstack-jenkins-drivers/
-* Add the new python version in job-creator:
- https://github.com/riptano/job-creator/
+* Add the new python version in the Jenkinsfile and TravisCI configs as appropriate
* Run the tests and ensure they all pass
* also test all event loops
* Update the wheels building repo to support that version:
- https://github.com/riptano/python-dse-driver-wheels
+ https://github.com/datastax/python-driver-wheels
diff --git a/README.rst b/README.rst
index 0b6c1e206d..98884008b0 100644
--- a/README.rst
+++ b/README.rst
@@ -1,20 +1,16 @@
DataStax Driver for Apache Cassandra
====================================
-.. image:: https://travis-ci.org/datastax/python-driver.png?branch=master
- :target: https://travis-ci.org/datastax/python-driver
+.. image:: https://travis-ci.com/datastax/python-driver.png?branch=master
+ :target: https://travis-ci.com/github/datastax/python-driver
A modern, `feature-rich `_ and highly-tunable Python client library for Apache Cassandra (2.1+) and
DataStax Enterprise (4.7+) using exclusively Cassandra's binary protocol and Cassandra Query Language v3.
-The driver supports Python 2.7, 3.4, 3.5, 3.6, 3.7 and 3.8.
+The driver supports Python 3.8 through 3.12.
**Note:** DataStax products do not support big-endian systems.
-Feedback Requested
-------------------
-**Help us focus our efforts!** Provide your input on the `Platform and Runtime Survey `_ (we kept it short).
-
Features
--------
* `Synchronous `_ and `Asynchronous `_ APIs
@@ -26,7 +22,7 @@ Features
* Configurable `load balancing `_ and `retry policies `_
* `Concurrent execution utilities `_
* `Object mapper `_
-* `Connecting to DataStax Apollo database (cloud) `_
+* `Connecting to DataStax Astra database (cloud) `_
* DSE Graph execution API
* DSE Geometric type serialization
* DSE PlainText and GSSAPI authentication
@@ -61,6 +57,10 @@ Contributing
------------
See `CONTRIBUTING.md `_.
+Error Handling
+--------------
+Although it was originally written for the Java driver, the `Cassandra error handling done right blog `_ remains a useful reference for resolving error handling scenarios with Apache Cassandra.
+
Reporting Problems
------------------
Please report any bugs and make any feature requests on the
diff --git a/appveyor.yml b/appveyor.yml
index d1daaa6ec6..f8a3fd7660 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -1,9 +1,6 @@
environment:
matrix:
- - PYTHON: "C:\\Python27-x64"
- cassandra_version: 3.11.2
- ci_type: standard
- - PYTHON: "C:\\Python35-x64"
+ - PYTHON: "C:\\Python37-x64"
cassandra_version: 3.11.2
ci_type: standard
os: Visual Studio 2015
diff --git a/appveyor/appveyor.ps1 b/appveyor/appveyor.ps1
index cc1e6aa76f..5f6840e4e1 100644
--- a/appveyor/appveyor.ps1
+++ b/appveyor/appveyor.ps1
@@ -54,7 +54,7 @@ Start-Process python -ArgumentList "-m pip install psutil pyYaml six numpy" -Wai
# Clone ccm from git and use master.
If (!(Test-Path $env:CCM_PATH)) {
- Start-Process git -ArgumentList "clone https://github.com/pcmanus/ccm.git $($env:CCM_PATH)" -Wait -NoNewWindow
+ Start-Process git -ArgumentList "clone -b cassandra-test https://github.com/pcmanus/ccm.git $($env:CCM_PATH)" -Wait -NoNewWindow
}
diff --git a/benchmarks/callback_full_pipeline.py b/benchmarks/callback_full_pipeline.py
index e3ecfe3be5..a4a4c33315 100644
--- a/benchmarks/callback_full_pipeline.py
+++ b/benchmarks/callback_full_pipeline.py
@@ -18,7 +18,6 @@
from threading import Event
from base import benchmark, BenchmarkThread
-from six.moves import range
log = logging.getLogger(__name__)
diff --git a/benchmarks/future_batches.py b/benchmarks/future_batches.py
index 8cd915ebab..de4484e617 100644
--- a/benchmarks/future_batches.py
+++ b/benchmarks/future_batches.py
@@ -14,7 +14,7 @@
import logging
from base import benchmark, BenchmarkThread
-from six.moves import queue
+import queue
log = logging.getLogger(__name__)
diff --git a/benchmarks/future_full_pipeline.py b/benchmarks/future_full_pipeline.py
index 9a9fcfcd50..901573c18e 100644
--- a/benchmarks/future_full_pipeline.py
+++ b/benchmarks/future_full_pipeline.py
@@ -14,7 +14,7 @@
import logging
from base import benchmark, BenchmarkThread
-from six.moves import queue
+import queue
log = logging.getLogger(__name__)
diff --git a/benchmarks/sync.py b/benchmarks/sync.py
index f2a45fcd7d..96e744f700 100644
--- a/benchmarks/sync.py
+++ b/benchmarks/sync.py
@@ -13,7 +13,6 @@
# limitations under the License.
from base import benchmark, BenchmarkThread
-from six.moves import range
class Runner(BenchmarkThread):
diff --git a/build.yaml b/build.yaml
deleted file mode 100644
index b60c0950c1..0000000000
--- a/build.yaml
+++ /dev/null
@@ -1,272 +0,0 @@
-schedules:
- nightly_master:
- schedule: nightly
- disable_pull_requests: true
- branches:
- include: [master]
- env_vars: |
- EVENT_LOOP_MANAGER='libev'
- matrix:
- exclude:
- - python: [3.4, 3.6, 3.7, 3.8]
- - cassandra: ['2.1', '3.0', 'test-dse']
-
- commit_long_test:
- schedule: per_commit
- disable_pull_requests: true
- branches:
- include: [/long-python.*/]
- env_vars: |
- EVENT_LOOP_MANAGER='libev'
- matrix:
- exclude:
- - python: [3.4, 3.6, 3.7, 3.8]
- - cassandra: ['2.1', '3.0', 'test-dse']
-
- commit_branches:
- schedule: per_commit
- disable_pull_requests: true
- branches:
- include: [/python.*/]
- env_vars: |
- EVENT_LOOP_MANAGER='libev'
- EXCLUDE_LONG=1
- matrix:
- exclude:
- - python: [3.4, 3.6, 3.7, 3.8]
- - cassandra: ['2.1', '3.0', 'test-dse']
-
- commit_branches_dev:
- schedule: per_commit
- disable_pull_requests: true
- branches:
- include: [/dev-python.*/]
- env_vars: |
- EVENT_LOOP_MANAGER='libev'
- EXCLUDE_LONG=1
- matrix:
- exclude:
- - python: [2.7, 3.4, 3.7, 3.8]
- - cassandra: ['2.0', '2.1', '2.2', '3.0', 'test-dse', dse-4.8', 'dse-5.0']
-
- release_test:
- schedule: per_commit
- disable_pull_requests: true
- branches:
- include: [/release-.+/]
- env_vars: |
- EVENT_LOOP_MANAGER='libev'
-
- weekly_master:
- schedule: 0 10 * * 6
- disable_pull_requests: true
- branches:
- include: [master]
- env_vars: |
- EVENT_LOOP_MANAGER='libev'
- matrix:
- exclude:
- - python: [3.5]
- - cassandra: ['2.2', '3.1']
-
- weekly_gevent:
- schedule: 0 14 * * 6
- disable_pull_requests: true
- branches:
- include: [master]
- env_vars: |
- EVENT_LOOP_MANAGER='gevent'
- JUST_EVENT_LOOP=1
- matrix:
- exclude:
- - python: [3.4]
-
- weekly_eventlet:
- schedule: 0 18 * * 6
- disable_pull_requests: true
- branches:
- include: [master]
- env_vars: |
- EVENT_LOOP_MANAGER='eventlet'
- JUST_EVENT_LOOP=1
- matrix:
- exclude:
- - python: [3.4]
-
- weekly_asyncio:
- schedule: 0 22 * * 6
- disable_pull_requests: true
- branches:
- include: [master]
- env_vars: |
- EVENT_LOOP_MANAGER='asyncio'
- JUST_EVENT_LOOP=1
- matrix:
- exclude:
- - python: [2.7]
-
- weekly_async:
- schedule: 0 10 * * 7
- disable_pull_requests: true
- branches:
- include: [master]
- env_vars: |
- EVENT_LOOP_MANAGER='asyncore'
- JUST_EVENT_LOOP=1
- matrix:
- exclude:
- - python: [3.4]
-
- weekly_twister:
- schedule: 0 14 * * 7
- disable_pull_requests: true
- branches:
- include: [master]
- env_vars: |
- EVENT_LOOP_MANAGER='twisted'
- JUST_EVENT_LOOP=1
- matrix:
- exclude:
- - python: [3.4]
-
- upgrade_tests:
- schedule: adhoc
- branches:
- include: [master, python-546]
- env_vars: |
- EVENT_LOOP_MANAGER='libev'
- JUST_UPGRADE=True
- matrix:
- exclude:
- - python: [3.4, 3.6, 3.7, 3.8]
- - cassandra: ['2.0', '2.1', '2.2', '3.0', 'test-dse']
-
-python:
- - 2.7
- - 3.4
- - 3.5
- - 3.6
- - 3.7
- - 3.8
-
-os:
- - ubuntu/bionic64/python-driver
-
-cassandra:
- - '2.1'
- - '2.2'
- - '3.0'
- - '3.11'
- - 'dse-4.8'
- - 'dse-5.0'
- - 'dse-5.1'
- - 'dse-6.0'
- - 'dse-6.7'
-
-env:
- CYTHON:
- - CYTHON
- - NO_CYTHON
-
-build:
- - script: |
- export JAVA_HOME=$CCM_JAVA_HOME
- export PATH=$JAVA_HOME/bin:$PATH
- export PYTHONPATH=""
-
- # Required for unix socket tests
- sudo apt-get install socat
-
- # Install latest setuptools
- pip install --upgrade pip
- pip install -U setuptools
-
- pip install git+ssh://git@github.com/riptano/ccm-private.git
-
- if [ -n "$CCM_IS_DSE" ]; then
- pip install -r test-datastax-requirements.txt
- else
- pip install -r test-requirements.txt
- fi
-
- pip install nose-ignore-docstring
- pip install nose-exclude
- pip install service_identity
-
- FORCE_CYTHON=False
- if [[ $CYTHON == 'CYTHON' ]]; then
- FORCE_CYTHON=True
- pip install cython
- pip install numpy
- # Install the driver & compile C extensions
- python setup.py build_ext --inplace
- else
- # Install the driver & compile C extensions with no cython
- python setup.py build_ext --inplace --no-cython
- fi
-
- echo "JUST_UPGRADE: $JUST_UPGRADE"
- if [[ $JUST_UPGRADE == 'True' ]]; then
- EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=upgrade_results.xml tests/integration/upgrade || true
- exit 0
- fi
-
- if [[ $JUST_SMOKE == 'true' ]]; then
- # When we ONLY want to run the smoke tests
- echo "JUST_SMOKE: $JUST_SMOKE"
- echo "==========RUNNING SMOKE TESTS==========="
- EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CCM_ARGS="$CCM_ARGS" CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION DSE_VERSION='6.7.0' MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml tests/integration/standard/test_dse.py || true
- exit 0
- fi
-
- # Run the unit tests, this is not done in travis because
- # it takes too much time for the whole matrix to build with cython
- if [[ $CYTHON == 'CYTHON' ]]; then
- EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER VERIFY_CYTHON=1 nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=unit_results.xml tests/unit/ || true
- EVENT_LOOP_MANAGER=eventlet VERIFY_CYTHON=1 nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=unit_eventlet_results.xml tests/unit/io/test_eventletreactor.py || true
- EVENT_LOOP_MANAGER=gevent VERIFY_CYTHON=1 nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=unit_gevent_results.xml tests/unit/io/test_geventreactor.py || true
- fi
-
- if [ -n "$JUST_EVENT_LOOP" ]; then
- echo "Running integration event loop subset with $EVENT_LOOP_MANAGER"
- EVENT_LOOP_TESTS=(
- "tests/integration/standard/test_cluster.py"
- "tests/integration/standard/test_concurrent.py"
- "tests/integration/standard/test_connection.py"
- "tests/integration/standard/test_control_connection.py"
- "tests/integration/standard/test_metrics.py"
- "tests/integration/standard/test_query.py"
- "tests/integration/simulacron/test_endpoint.py"
- "tests/integration/long/test_ssl.py"
- )
- EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CCM_ARGS="$CCM_ARGS" DSE_VERSION=$DSE_VERSION CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml ${EVENT_LOOP_TESTS[@]} || true
- exit 0
- fi
-
- echo "Running with event loop manager: $EVENT_LOOP_MANAGER"
- echo "==========RUNNING SIMULACRON TESTS=========="
- SIMULACRON_JAR="$HOME/simulacron.jar"
- SIMULACRON_JAR=$SIMULACRON_JAR EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CASSANDRA_DIR=$CCM_INSTALL_DIR CCM_ARGS="$CCM_ARGS" DSE_VERSION=$DSE_VERSION CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=simulacron_results.xml tests/integration/simulacron/ || true
-
- echo "Running with event loop manager: $EVENT_LOOP_MANAGER"
- echo "==========RUNNING CQLENGINE TESTS=========="
- EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CCM_ARGS="$CCM_ARGS" DSE_VERSION=$DSE_VERSION CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=cqle_results.xml tests/integration/cqlengine/ || true
-
- echo "==========RUNNING INTEGRATION TESTS=========="
- EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CCM_ARGS="$CCM_ARGS" DSE_VERSION=$DSE_VERSION CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml tests/integration/standard/ || true
-
- if [ -n "$DSE_VERSION" ] && ! [[ $DSE_VERSION == "4.8"* ]]; then
- echo "==========RUNNING DSE INTEGRATION TESTS=========="
- EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CASSANDRA_DIR=$CCM_INSTALL_DIR DSE_VERSION=$DSE_VERSION ADS_HOME=$HOME/ VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=dse_results.xml tests/integration/advanced/ || true
- fi
-
- echo "==========RUNNING CLOUD TESTS=========="
- EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CLOUD_PROXY_PATH="$HOME/proxy/" CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=advanced_results.xml tests/integration/cloud/ || true
-
- if [ -z "$EXCLUDE_LONG" ]; then
- echo "==========RUNNING LONG INTEGRATION TESTS=========="
- EVENT_LOOP_MANAGER=$EVENT_LOOP_MANAGER CCM_ARGS="$CCM_ARGS" DSE_VERSION=$DSE_VERSION CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION MAPPED_CASSANDRA_VERSION=$MAPPED_CASSANDRA_VERSION VERIFY_CYTHON=$FORCE_CYTHON nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --exclude-dir=tests/integration/long/upgrade --with-ignore-docstrings --with-xunit --xunit-file=long_results.xml tests/integration/long/ || true
- fi
-
- - xunit:
- - "*_results.xml"
diff --git a/cassandra/__init__.py b/cassandra/__init__.py
index ea0a9b7bdd..4a5b8b29a3 100644
--- a/cassandra/__init__.py
+++ b/cassandra/__init__.py
@@ -22,7 +22,7 @@ def emit(self, record):
logging.getLogger('cassandra').addHandler(NullHandler())
-__version_info__ = (3, 21, 0)
+__version_info__ = (3, 29, 1)
__version__ = '.'.join(map(str, __version_info__))
@@ -55,7 +55,7 @@ class ConsistencyLevel(object):
QUORUM = 4
"""
- ``ceil(RF/2)`` replicas must respond to consider the operation a success
+ ``ceil(RF/2) + 1`` replicas must respond to consider the operation a success
"""
ALL = 5
@@ -161,7 +161,12 @@ class ProtocolVersion(object):
V5 = 5
"""
- v5, in beta from 3.x+
+ v5, in beta from 3.x+. Finalised in 4.0-beta5
+ """
+
+ V6 = 6
+ """
+ v6, in beta from 4.0-beta5
"""
DSE_V1 = 0x41
@@ -174,12 +179,12 @@ class ProtocolVersion(object):
DSE private protocol v2, supported in DSE 6.0+
"""
- SUPPORTED_VERSIONS = (DSE_V2, DSE_V1, V5, V4, V3, V2, V1)
+ SUPPORTED_VERSIONS = (DSE_V2, DSE_V1, V6, V5, V4, V3, V2, V1)
"""
A tuple of all supported protocol versions
"""
- BETA_VERSIONS = (V5,)
+ BETA_VERSIONS = (V6,)
"""
A tuple of all beta protocol versions
"""
@@ -235,6 +240,10 @@ def has_continuous_paging_support(cls, version):
def has_continuous_paging_next_pages(cls, version):
return version >= cls.DSE_V2
+ @classmethod
+ def has_checksumming_support(cls, version):
+ return cls.V5 <= version < cls.DSE_V1
+
class WriteType(object):
"""
@@ -719,3 +728,25 @@ class UnresolvableContactPoints(DriverException):
contact points, only when lookup fails for all hosts
"""
pass
+
+class DependencyException(Exception):
+ """
+ Specific exception class for handling issues with driver dependencies
+ """
+
+ excs = []
+ """
+ A sequence of child exceptions
+ """
+
+ def __init__(self, msg, excs=[]):
+ complete_msg = msg
+ if excs:
+ complete_msg += ("\nThe following exceptions were observed: \n - " + '\n - '.join(str(e) for e in excs))
+ Exception.__init__(self, complete_msg)
+
+class VectorDeserializationFailure(DriverException):
+ """
+ The driver was unable to deserialize a given vector
+ """
+ pass
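
``DependencyException`` folds any child exceptions into its message; a short sketch of how it reads when several loaders fail (the error strings are hypothetical)::

    from cassandra import DependencyException

    try:
        raise DependencyException(
            "Unable to load a default connection class",
            excs=[ImportError("libev extension not built"), ImportError("asyncore unavailable")])
    except DependencyException as exc:
        print(exc)  # the message plus the "The following exceptions were observed" list
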
diff --git a/cassandra/auth.py b/cassandra/auth.py
index 910592f7ac..10200aa387 100644
--- a/cassandra/auth.py
+++ b/cassandra/auth.py
@@ -32,8 +32,6 @@
except ImportError:
SASLClient = None
-import six
-
log = logging.getLogger(__name__)
# Custom payload keys related to DSE Unified Auth
@@ -270,14 +268,15 @@ def __init__(self, username, password):
self.password = password
def get_mechanism(self):
- return six.b("PLAIN")
+ return b"PLAIN"
def get_initial_challenge(self):
- return six.b("PLAIN-START")
+ return b"PLAIN-START"
def evaluate_challenge(self, challenge):
- if challenge == six.b('PLAIN-START'):
- return six.b("\x00%s\x00%s" % (self.username, self.password))
+ if challenge == b'PLAIN-START':
+ data = "\x00%s\x00%s" % (self.username, self.password)
+ return data.encode()
raise Exception('Did not receive a valid challenge response from server')
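
With ``six`` removed, the PLAIN SASL initial response is now built from a str and then encoded; a one-line check of the resulting token::

    username, password = "cassandra", "cassandra"
    token = ("\x00%s\x00%s" % (username, password)).encode()
    assert token == b"\x00cassandra\x00cassandra"
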
@@ -296,13 +295,13 @@ def __init__(self, host, service, qops, properties):
self.sasl = SASLClient(host, service, 'GSSAPI', qops=qops, **properties)
def get_mechanism(self):
- return six.b("GSSAPI")
+ return b"GSSAPI"
def get_initial_challenge(self):
- return six.b("GSSAPI-START")
+ return b"GSSAPI-START"
def evaluate_challenge(self, challenge):
- if challenge == six.b('GSSAPI-START'):
+ if challenge == b'GSSAPI-START':
return self.sasl.process()
else:
return self.sasl.process(challenge)
diff --git a/cassandra/cluster.py b/cassandra/cluster.py
index c9a8b6d397..d5f80290a9 100644
--- a/cassandra/cluster.py
+++ b/cassandra/cluster.py
@@ -21,16 +21,17 @@
import atexit
from binascii import hexlify
from collections import defaultdict
+from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures
from copy import copy
-from functools import partial, wraps
+from functools import partial, reduce, wraps
from itertools import groupby, count, chain
import json
import logging
from warnings import warn
from random import random
-import six
-from six.moves import filter, range, queue as Queue
+import re
+import queue
import socket
import sys
import time
@@ -43,12 +44,12 @@
from cassandra import (ConsistencyLevel, AuthenticationFailed,
OperationTimedOut, UnsupportedOperation,
SchemaTargetType, DriverException, ProtocolVersion,
- UnresolvableContactPoints)
+ UnresolvableContactPoints, DependencyException)
from cassandra.auth import _proxy_execute_key, PlainTextAuthProvider
from cassandra.connection import (ConnectionException, ConnectionShutdown,
ConnectionHeartbeat, ProtocolVersionUnsupported,
EndPoint, DefaultEndPoint, DefaultEndPointFactory,
- ContinuousPagingState, SniEndPointFactory)
+ ContinuousPagingState, SniEndPointFactory, ConnectionBusy)
from cassandra.cqltypes import UserType
from cassandra.encoder import Encoder
from cassandra.protocol import (QueryMessage, ResultMessage,
@@ -63,8 +64,8 @@
BatchMessage, RESULT_KIND_PREPARED,
RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
RESULT_KIND_SCHEMA_CHANGE, ProtocolHandler,
- RESULT_KIND_VOID)
-from cassandra.metadata import Metadata, protect_name, murmur3
+ RESULT_KIND_VOID, ProtocolException)
+from cassandra.metadata import Metadata, protect_name, murmur3, _NodeInfo
from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy,
ExponentialReconnectionPolicy, HostDistance,
RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan,
@@ -79,15 +80,16 @@
HostTargetingStatement)
from cassandra.marshal import int64_pack
from cassandra.timestamps import MonotonicTimestampGenerator
-from cassandra.compat import Mapping
-from cassandra.util import _resolve_contact_points_to_string_map
+from cassandra.util import _resolve_contact_points_to_string_map, Version
from cassandra.datastax.insights.reporter import MonitorReporter
from cassandra.datastax.insights.util import version_supports_insights
from cassandra.datastax.graph import (graph_object_row_factory, GraphOptions, GraphSON1Serializer,
- GraphProtocol, GraphSON2Serializer, GraphStatement, SimpleGraphStatement)
-from cassandra.datastax.graph.query import _request_timeout_key
+ GraphProtocol, GraphSON2Serializer, GraphStatement, SimpleGraphStatement,
+ graph_graphson2_row_factory, graph_graphson3_row_factory,
+ GraphSON3Serializer)
+from cassandra.datastax.graph.query import _request_timeout_key, _GraphSONContextRowFactory
from cassandra.datastax import cloud as dscloud
try:
@@ -97,7 +99,11 @@
try:
from cassandra.io.eventletreactor import EventletConnection
-except ImportError:
+# PYTHON-1364
+#
+# At the moment eventlet initialization is throwing AttributeErrors due to its dependence on pyOpenSSL
+# and some changes in Python 3.12 which have knock-on effects there.
+except (ImportError, AttributeError):
EventletConnection = None
try:
@@ -105,35 +111,67 @@
except ImportError:
from cassandra.util import WeakSet # NOQA
-if six.PY3:
- long = int
-
-def _is_eventlet_monkey_patched():
- if 'eventlet.patcher' not in sys.modules:
- return False
- import eventlet.patcher
- return eventlet.patcher.is_monkey_patched('socket')
-
-
def _is_gevent_monkey_patched():
if 'gevent.monkey' not in sys.modules:
return False
import gevent.socket
return socket.socket is gevent.socket.socket
+def _try_gevent_import():
+ if _is_gevent_monkey_patched():
+ from cassandra.io.geventreactor import GeventConnection
+ return (GeventConnection,None)
+ else:
+ return (None,None)
+
+def _is_eventlet_monkey_patched():
+ if 'eventlet.patcher' not in sys.modules:
+ return False
+ try:
+ import eventlet.patcher
+ return eventlet.patcher.is_monkey_patched('socket')
+ # Another case related to PYTHON-1364
+ except AttributeError:
+ return False
+
+def _try_eventlet_import():
+ if _is_eventlet_monkey_patched():
+ from cassandra.io.eventletreactor import EventletConnection
+ return (EventletConnection,None)
+ else:
+ return (None,None)
+
+def _try_libev_import():
+ try:
+ from cassandra.io.libevreactor import LibevConnection
+ return (LibevConnection,None)
+ except DependencyException as e:
+ return (None, e)
-# default to gevent when we are monkey patched with gevent, eventlet when
-# monkey patched with eventlet, otherwise if libev is available, use that as
-# the default because it's fastest. Otherwise, use asyncore.
-if _is_gevent_monkey_patched():
- from cassandra.io.geventreactor import GeventConnection as DefaultConnection
-elif _is_eventlet_monkey_patched():
- from cassandra.io.eventletreactor import EventletConnection as DefaultConnection
-else:
+def _try_asyncore_import():
try:
- from cassandra.io.libevreactor import LibevConnection as DefaultConnection # NOQA
- except ImportError:
- from cassandra.io.asyncorereactor import AsyncoreConnection as DefaultConnection # NOQA
+ from cassandra.io.asyncorereactor import AsyncoreConnection
+ return (AsyncoreConnection,None)
+ except DependencyException as e:
+ return (None, e)
+
+def _connection_reduce_fn(val,import_fn):
+ (rv, excs) = val
+ # If we've already found a workable Connection class return immediately
+ if rv:
+ return val
+ (import_result, exc) = import_fn()
+ if exc:
+ excs.append(exc)
+ return (rv or import_result, excs)
+
+log = logging.getLogger(__name__)
+
+conn_fns = (_try_gevent_import, _try_eventlet_import, _try_libev_import, _try_asyncore_import)
+(conn_class, excs) = reduce(_connection_reduce_fn, conn_fns, (None,[]))
+if not conn_class:
+ raise DependencyException("Unable to load a default connection class", excs)
+DefaultConnection = conn_class
# Forces load of utf8 encoding module to avoid deadlock that occurs
# if code that is being imported tries to import the module in a seperate
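
The default connection class is now chosen by folding a sequence of loader functions and collecting their failures; a self-contained sketch of the same reduce pattern using hypothetical loaders::

    from functools import reduce

    def _connection_reduce_fn(val, import_fn):
        rv, excs = val
        if rv:                         # the first loader that produced a class wins
            return val
        import_result, exc = import_fn()
        if exc:
            excs.append(exc)
        return (rv or import_result, excs)

    loaders = (lambda: (None, ImportError("libev extension not built")),
               lambda: ("AsyncoreConnection", None))
    conn_class, errors = reduce(_connection_reduce_fn, loaders, (None, []))
    assert conn_class == "AsyncoreConnection" and len(errors) == 1
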
@@ -141,8 +179,6 @@ def _is_gevent_monkey_patched():
# See http://bugs.python.org/issue10923
"".encode('utf8')
-log = logging.getLogger(__name__)
-
DEFAULT_MIN_REQUESTS = 5
DEFAULT_MAX_REQUESTS = 100
@@ -153,6 +189,7 @@ def _is_gevent_monkey_patched():
DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST = 1
DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST = 2
+_GRAPH_PAGING_MIN_DSE_VERSION = Version('6.8.0')
_NOT_SET = object()
@@ -415,21 +452,24 @@ class GraphExecutionProfile(ExecutionProfile):
"""
def __init__(self, load_balancing_policy=_NOT_SET, retry_policy=None,
- consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None,
- request_timeout=30.0, row_factory=graph_object_row_factory,
- graph_options=None):
+ consistency_level=_NOT_SET, serial_consistency_level=None,
+ request_timeout=30.0, row_factory=None,
+ graph_options=None, continuous_paging_options=_NOT_SET):
"""
Default execution profile for graph execution.
- See :class:`.ExecutionProfile`
- for base attributes.
+ See :class:`.ExecutionProfile` for base attributes. Note that if not explicitly set,
+ the row_factory and graph_options.graph_protocol are resolved during query execution.
+ They resolve to graph_graphson3_row_factory and GraphProtocol.GRAPHSON_3_0 for the
+ core graph engine (DSE 6.8+), and to graph_object_row_factory and GraphProtocol.GRAPHSON_1_0 otherwise.
In addition to default parameters shown in the signature, this profile also defaults ``retry_policy`` to
:class:`cassandra.policies.NeverRetryPolicy`.
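+
+ A minimal setup sketch (the graph name and profile mapping are illustrative only)::
+
+     ep = GraphExecutionProfile(graph_options=GraphOptions(graph_name='my_graph'))
+     cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep})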
"""
retry_policy = retry_policy or NeverRetryPolicy()
super(GraphExecutionProfile, self).__init__(load_balancing_policy, retry_policy, consistency_level,
- serial_consistency_level, request_timeout, row_factory)
+ serial_consistency_level, request_timeout, row_factory,
+ continuous_paging_options=continuous_paging_options)
self.graph_options = graph_options or GraphOptions(graph_source=b'g',
graph_language=b'gremlin-groovy')
@@ -437,8 +477,8 @@ def __init__(self, load_balancing_policy=_NOT_SET, retry_policy=None,
class GraphAnalyticsExecutionProfile(GraphExecutionProfile):
def __init__(self, load_balancing_policy=None, retry_policy=None,
- consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None,
- request_timeout=3600. * 24. * 7., row_factory=graph_object_row_factory,
+ consistency_level=_NOT_SET, serial_consistency_level=None,
+ request_timeout=3600. * 24. * 7., row_factory=None,
graph_options=None):
"""
Execution profile with timeout and load balancing appropriate for graph analytics queries.
@@ -575,7 +615,7 @@ class Cluster(object):
contact_points = ['127.0.0.1']
"""
The list of contact points to try connecting for cluster discovery. A
- contact point can be a string (ip, hostname) or a
+ contact point can be a string (ip or hostname), a tuple (ip/hostname, port) or a
:class:`.connection.EndPoint` instance.
Defaults to loopback interface.
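+
+ For example (addresses are illustrative)::
+
+     Cluster(contact_points=['192.168.1.1', ('192.168.1.2', 9043), DefaultEndPoint('192.168.1.3')])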
@@ -771,15 +811,15 @@ def default_retry_policy(self, policy):
Using ssl_options without ssl_context is deprecated and will be removed in the
next major release.
- An optional dict which will be used as kwargs for ``ssl.SSLContext.wrap_socket`` (or
- ``ssl.wrap_socket()`` if used without ssl_context) when new sockets are created.
- This should be used when client encryption is enabled in Cassandra.
+ An optional dict which will be used as kwargs for ``ssl.SSLContext.wrap_socket``
+ when new sockets are created. This should be used when client encryption is enabled
+ in Cassandra.
The following documentation only applies when ssl_options is used without ssl_context.
By default, a ``ca_certs`` value should be supplied (the value should be
a string pointing to the location of the CA certs file), and you probably
- want to specify ``ssl_version`` as ``ssl.PROTOCOL_TLSv1`` to match
+ want to specify ``ssl_version`` as ``ssl.PROTOCOL_TLS`` to match
Cassandra's default protocol.
.. versionchanged:: 3.3.0
@@ -789,6 +829,12 @@ def default_retry_policy(self, policy):
should almost always require the option ``'cert_reqs': ssl.CERT_REQUIRED``. Note also that this functionality was not built into
Python standard library until (2.7.9, 3.2). To enable this mechanism in earlier versions, patch ``ssl.match_hostname``
with a custom or `back-ported function `_.
+
+ .. versionchanged:: 3.29.0
+
+ ``ssl.match_hostname`` has been deprecated since Python 3.7 (and removed in Python 3.12). This functionality is now implemented
+ via ``ssl.SSLContext.check_hostname``. All options specified above (including ``check_hostname``) should continue to behave in a
+ way that is consistent with prior implementations.
"""
ssl_context = None
@@ -984,16 +1030,25 @@ def default_retry_policy(self, policy):
cloud = None
"""
A dict of the cloud configuration. Example::
-
+
{
# path to the secure connect bundle
- 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip'
+ 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip',
+
+ # optional config options
+ 'use_default_tempdir': True # use the system temp dir for the zip extraction
}
The zip file will be temporarily extracted in the same directory to
load the configuration and certificates.
"""
+ column_encryption_policy = None
+ """
+ An instance of :class:`cassandra.policies.ColumnEncryptionPolicy` specifying encryption materials to be
+ used for columns in this cluster.
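+
+ A minimal sketch, assuming the optional ``cryptography`` dependency is installed (the key material and
+ keyspace/table/column names are illustrative, and ``ColDesc`` is assumed to be importable from
+ :mod:`cassandra.policies`)::
+
+     import os
+     from cassandra.policies import ColDesc
+     from cassandra.column_encryption.policies import AES256ColumnEncryptionPolicy, AES256_KEY_SIZE_BYTES
+
+     cle_policy = AES256ColumnEncryptionPolicy()
+     cle_policy.add_column(ColDesc('ks', 'users', 'ssn'), key=os.urandom(AES256_KEY_SIZE_BYTES), type='text')
+     cluster = Cluster(column_encryption_policy=cle_policy)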
+ """
+
@property
def schema_metadata_enabled(self):
"""
@@ -1095,7 +1150,8 @@ def __init__(self,
monitor_reporting_enabled=True,
monitor_reporting_interval=30,
client_id=None,
- cloud=None):
+ cloud=None,
+ column_encryption_policy=None):
"""
``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
establishing connection pools or refreshing metadata.
@@ -1134,7 +1190,7 @@ def __init__(self,
else:
self._contact_points_explicit = True
- if isinstance(contact_points, six.string_types):
+ if isinstance(contact_points, str):
raise TypeError("contact_points should not be a string, it should be a sequence (e.g. list) of strings")
if None in contact_points:
@@ -1143,23 +1199,30 @@ def __init__(self,
self.port = port
+ if column_encryption_policy is not None:
+ self.column_encryption_policy = column_encryption_policy
+
self.endpoint_factory = endpoint_factory or DefaultEndPointFactory(port=self.port)
self.endpoint_factory.configure(self)
- raw_contact_points = [cp for cp in self.contact_points if not isinstance(cp, EndPoint)]
+ raw_contact_points = []
+ for cp in [cp for cp in self.contact_points if not isinstance(cp, EndPoint)]:
+ raw_contact_points.append(cp if isinstance(cp, tuple) else (cp, port))
+
self.endpoints_resolved = [cp for cp in self.contact_points if isinstance(cp, EndPoint)]
self._endpoint_map_for_insights = {repr(ep): '{ip}:{port}'.format(ip=ep.address, port=ep.port)
for ep in self.endpoints_resolved}
- strs_resolved_map = _resolve_contact_points_to_string_map(raw_contact_points, port)
+ strs_resolved_map = _resolve_contact_points_to_string_map(raw_contact_points)
self.endpoints_resolved.extend(list(chain(
*[
- [DefaultEndPoint(x, port) for x in xs if x is not None]
+ [DefaultEndPoint(ip, port) for ip, port in xs if ip is not None]
for xs in strs_resolved_map.values() if xs is not None
]
)))
+
self._endpoint_map_for_insights.update(
- {key: ['{ip}:{port}'.format(ip=ip, port=port) for ip in value]
+ {key: ['{ip}:{port}'.format(ip=ip, port=port) for ip, port in value]
for key, value in strs_resolved_map.items() if value is not None}
)
@@ -1428,7 +1491,7 @@ def __init__(self, street, zipcode):
# results will include Address instances
results = session.execute("SELECT * FROM users")
row = results[0]
- print row.id, row.location.street, row.location.zipcode
+ print(row.id, row.location.street, row.location.zipcode)
"""
if self.protocol_version < 3:
@@ -1557,7 +1620,7 @@ def set_core_connections_per_host(self, host_distance, core_connections):
If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this
is not supported (there is always one connection per host, unless
the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`)
- and using this will result in an :exc:`~.UnsupporteOperation`.
+ and using this will result in an :exc:`~.UnsupportedOperation`.
"""
if self.protocol_version >= 3:
raise UnsupportedOperation(
@@ -1590,7 +1653,7 @@ def set_max_connections_per_host(self, host_distance, max_connections):
If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this
is not supported (there is always one connection per host, unless
the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`)
- and using this will result in an :exc:`~.UnsupporteOperation`.
+ and using this will result in an :exc:`~.UnsupportedOperation`.
"""
if self.protocol_version >= 3:
raise UnsupportedOperation(
@@ -1762,8 +1825,8 @@ def _new_session(self, keyspace):
return session
def _session_register_user_types(self, session):
- for keyspace, type_map in six.iteritems(self._user_types):
- for udt_name, klass in six.iteritems(type_map):
+ for keyspace, type_map in self._user_types.items():
+ for udt_name, klass in type_map.items():
session.user_type_registered(keyspace, udt_name, klass)
def _cleanup_failed_on_up_handling(self, host):
@@ -2378,7 +2441,7 @@ def default_consistency_level(self, cl):
*Deprecated:* use execution profiles instead
"""
warn("Setting the consistency level at the session level will be removed in 4.0. Consider using "
- "execution profiles and setting the desired consitency level to the EXEC_PROFILE_DEFAULT profile."
+ "execution profiles and setting the desired consistency level to the EXEC_PROFILE_DEFAULT profile."
, DeprecationWarning)
self._validate_set_legacy_config('default_consistency_level', cl)
@@ -2506,6 +2569,7 @@ def default_serial_consistency_level(self, cl):
_profile_manager = None
_metrics = None
_request_init_callbacks = None
+ _graph_paging_available = False
def __init__(self, cluster, hosts, keyspace=None):
self.cluster = cluster
@@ -2538,20 +2602,32 @@ def __init__(self, cluster, hosts, keyspace=None):
msg += " using keyspace '%s'" % self.keyspace
raise NoHostAvailable(msg, [h.address for h in hosts])
- cc_host = self.cluster.get_control_connection_host()
- valid_insights_version = (cc_host and version_supports_insights(cc_host.dse_version))
- if self.cluster.monitor_reporting_enabled and valid_insights_version:
- self._monitor_reporter = MonitorReporter(
- interval_sec=self.cluster.monitor_reporting_interval,
- session=self,
- )
- else:
- if cc_host:
- log.debug('Not starting MonitorReporter thread for Insights; '
- 'not supported by server version {v} on '
- 'ControlConnection host {c}'.format(v=cc_host.release_version, c=cc_host))
-
self.session_id = uuid.uuid4()
+ self._graph_paging_available = self._check_graph_paging_available()
+
+ if self.cluster.column_encryption_policy is not None:
+ try:
+ self.client_protocol_handler = type(
+ str(self.session_id) + "-ProtocolHandler",
+ (ProtocolHandler,),
+ {"column_encryption_policy": self.cluster.column_encryption_policy})
+ except AttributeError:
+ log.info("Unable to set column encryption policy for session")
+
+ if self.cluster.monitor_reporting_enabled:
+ cc_host = self.cluster.get_control_connection_host()
+ valid_insights_version = (cc_host and version_supports_insights(cc_host.dse_version))
+ if valid_insights_version:
+ self._monitor_reporter = MonitorReporter(
+ interval_sec=self.cluster.monitor_reporting_interval,
+ session=self,
+ )
+ else:
+ if cc_host:
+ log.debug('Not starting MonitorReporter thread for Insights; '
+ 'not supported by server version {v} on '
+ 'ControlConnection host {c}'.format(v=cc_host.release_version, c=cc_host))
+
log.debug('Started Session with client_id {} and session_id {}'.format(self.cluster.client_id,
self.session_id))
@@ -2639,7 +2715,7 @@ def execute_async(self, query, parameters=None, trace=False, custom_payload=None
"""
custom_payload = custom_payload if custom_payload else {}
if execute_as:
- custom_payload[_proxy_execute_key] = six.b(execute_as)
+ custom_payload[_proxy_execute_key] = execute_as.encode()
future = self._create_response_future(
query, parameters, trace, custom_payload, timeout,
@@ -2677,21 +2753,34 @@ def execute_graph_async(self, query, parameters=None, trace=False, execution_pro
if not isinstance(query, GraphStatement):
query = SimpleGraphStatement(query)
- execution_profile = self._maybe_get_execution_profile(execution_profile) # look up instance here so we can apply the extended attributes
+ # Clone and look up instance here so we can resolve and apply the extended attributes
+ execution_profile = self.execution_profile_clone_update(execution_profile)
+
+ if not hasattr(execution_profile, 'graph_options'):
+ raise ValueError(
+ "Execution profile for graph queries must derive from GraphExecutionProfile, and provide graph_options")
+ self._resolve_execution_profile_options(execution_profile)
+
+ # make sure the graphson context row factory is bound to this cluster
try:
- options = execution_profile.graph_options.copy()
- except AttributeError:
- raise ValueError("Execution profile for graph queries must derive from GraphExecutionProfile, and provide graph_options")
+ if issubclass(execution_profile.row_factory, _GraphSONContextRowFactory):
+ execution_profile.row_factory = execution_profile.row_factory(self.cluster)
+ except TypeError:
+ # issubclass might fail if arg1 is an instance
+ pass
+
+ # set graph paging if needed
+ self._maybe_set_graph_paging(execution_profile)
graph_parameters = None
if parameters:
- graph_parameters = self._transform_params(parameters, graph_options=options)
+ graph_parameters = self._transform_params(parameters, graph_options=execution_profile.graph_options)
- custom_payload = options.get_options_map()
+ custom_payload = execution_profile.graph_options.get_options_map()
if execute_as:
- custom_payload[_proxy_execute_key] = six.b(execute_as)
- custom_payload[_request_timeout_key] = int64_pack(long(execution_profile.request_timeout * 1000))
+ custom_payload[_proxy_execute_key] = execute_as.encode()
+ custom_payload[_request_timeout_key] = int64_pack(int(execution_profile.request_timeout * 1000))
future = self._create_response_future(query, parameters=None, trace=trace, custom_payload=custom_payload,
timeout=_NOT_SET, execution_profile=execution_profile)
@@ -2699,12 +2788,82 @@ def execute_graph_async(self, query, parameters=None, trace=False, execution_pro
future.message.query_params = graph_parameters
future._protocol_handler = self.client_protocol_handler
- if options.is_analytics_source and isinstance(execution_profile.load_balancing_policy, DefaultLoadBalancingPolicy):
+ if execution_profile.graph_options.is_analytics_source and \
+ isinstance(execution_profile.load_balancing_policy, DefaultLoadBalancingPolicy):
self._target_analytics_master(future)
else:
future.send_request()
return future
+ def _maybe_set_graph_paging(self, execution_profile):
+ graph_paging = execution_profile.continuous_paging_options
+ if execution_profile.continuous_paging_options is _NOT_SET:
+ graph_paging = ContinuousPagingOptions() if self._graph_paging_available else None
+
+ execution_profile.continuous_paging_options = graph_paging
+
+ def _check_graph_paging_available(self):
+ """Verify if we can enable graph paging. This executed only once when the session is created."""
+
+ if not ProtocolVersion.has_continuous_paging_next_pages(self._protocol_version):
+ return False
+
+ for host in self.cluster.metadata.all_hosts():
+ if host.dse_version is None:
+ return False
+
+ version = Version(host.dse_version)
+ if version < _GRAPH_PAGING_MIN_DSE_VERSION:
+ return False
+
+ return True
+
+ def _resolve_execution_profile_options(self, execution_profile):
+ """
+ Determine the GraphSON protocol and row factory for a graph query. This is used
+ to automatically configure the execution profile when executing a query against a
+ core graph.
+
+ If `graph_protocol` is not explicitly specified, the following rules apply:
+ - Default to GraphProtocol.GRAPHSON_1_0, or GRAPHSON_2_0 if the `graph_language` is not gremlin-groovy.
+ - If `graph_options.graph_name` is specified and is a Core graph, use GraphProtocol.GRAPHSON_3_0.
+
+ If `row_factory` is not explicitly specified, the following rules apply:
+ - Default to graph_object_row_factory.
+ - If `graph_options.graph_name` is specified and is a Core graph, set graph_graphson3_row_factory.
+ """
+ if execution_profile.graph_options.graph_protocol is not None and \
+ execution_profile.row_factory is not None:
+ return
+
+ graph_options = execution_profile.graph_options
+
+ is_core_graph = False
+ if graph_options.graph_name:
+ # graph_options.graph_name is bytes ...
+ name = graph_options.graph_name.decode('utf-8')
+ if name in self.cluster.metadata.keyspaces:
+ ks_metadata = self.cluster.metadata.keyspaces[name]
+ if ks_metadata.graph_engine == 'Core':
+ is_core_graph = True
+
+ if is_core_graph:
+ graph_protocol = GraphProtocol.GRAPHSON_3_0
+ row_factory = graph_graphson3_row_factory
+ else:
+ if graph_options.graph_language == GraphOptions.DEFAULT_GRAPH_LANGUAGE:
+ graph_protocol = GraphOptions.DEFAULT_GRAPH_PROTOCOL
+ row_factory = graph_object_row_factory
+ else:
+ # if not gremlin-groovy, GraphSON_2_0
+ graph_protocol = GraphProtocol.GRAPHSON_2_0
+ row_factory = graph_graphson2_row_factory
+
+ # Only apply if not set explicitly
+ if graph_options.graph_protocol is None:
+ graph_options.graph_protocol = graph_protocol
+ if execution_profile.row_factory is None:
+ execution_profile.row_factory = row_factory
+
def _transform_params(self, parameters, graph_options):
if not isinstance(parameters, dict):
raise ValueError('The parameters must be a dictionary. Unnamed parameters are not allowed.')
@@ -2712,12 +2871,16 @@ def _transform_params(self, parameters, graph_options):
# Serialize python types to graphson
serializer = GraphSON1Serializer
if graph_options.graph_protocol == GraphProtocol.GRAPHSON_2_0:
- serializer = GraphSON2Serializer
-
- serialized_parameters = {
- p: serializer.serialize(v)
- for p, v in six.iteritems(parameters)
- }
+ serializer = GraphSON2Serializer()
+ elif graph_options.graph_protocol == GraphProtocol.GRAPHSON_3_0:
+ # only required for core graphs
+ context = {
+ 'cluster': self.cluster,
+ 'graph_name': graph_options.graph_name.decode('utf-8') if graph_options.graph_name else None
+ }
+ serializer = GraphSON3Serializer(context)
+
+ serialized_parameters = serializer.serialize(parameters)
return [json.dumps(serialized_parameters).encode('utf-8')]
def _target_analytics_master(self, future):
@@ -2754,7 +2917,7 @@ def _create_response_future(self, query, parameters, trace, custom_payload,
prepared_statement = None
- if isinstance(query, six.string_types):
+ if isinstance(query, str):
query = SimpleStatement(query)
elif isinstance(query, PreparedStatement):
query = query.bind(parameters)
@@ -2970,7 +3133,7 @@ def prepare(self, query, custom_payload=None, keyspace=None):
prepared_keyspace = keyspace if keyspace else None
prepared_statement = PreparedStatement.from_message(
response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace,
- self._protocol_version, response.column_metadata, response.result_metadata_id)
+ self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy)
prepared_statement.custom_payload = future.custom_payload
self.cluster.add_prepared(response.query_id, prepared_statement)
@@ -3222,10 +3385,6 @@ def user_type_registered(self, keyspace, user_type, klass):
'User type %s does not exist in keyspace %s' % (user_type, keyspace))
field_names = type_meta.field_names
- if six.PY2:
- # go from unicode to string to avoid decode errors from implicit
- # decode when formatting non-ascii values
- field_names = [fn.encode('utf-8') for fn in field_names]
def encode(val):
return '{ %s }' % ' , '.join('%s : %s' % (
@@ -3323,7 +3482,16 @@ class ControlConnection(object):
_SELECT_SCHEMA_PEERS_TEMPLATE = "SELECT peer, host_id, {nt_col_name}, schema_version FROM system.peers"
_SELECT_SCHEMA_LOCAL = "SELECT schema_version FROM system.local WHERE key='local'"
- _MINIMUM_NATIVE_ADDRESS_VERSION = "4.0"
+ _SELECT_PEERS_V2 = "SELECT * FROM system.peers_v2"
+ _SELECT_PEERS_NO_TOKENS_V2 = "SELECT host_id, peer, peer_port, data_center, rack, native_address, native_port, release_version, schema_version FROM system.peers_v2"
+ _SELECT_SCHEMA_PEERS_V2 = "SELECT host_id, peer, peer_port, native_address, native_port, schema_version FROM system.peers_v2"
+
+ _MINIMUM_NATIVE_ADDRESS_DSE_VERSION = Version("6.0.0")
+
+ class PeersQueryType(object):
+ """internal Enum for _peers_query"""
+ PEERS = 0
+ PEERS_SCHEMA = 1
_is_shutdown = False
_timeout = None
@@ -3336,6 +3504,8 @@ class ControlConnection(object):
_schema_meta_enabled = True
_token_meta_enabled = True
+ _uses_peers_v2 = True
+
# for testing purposes
_time = time
@@ -3372,7 +3542,7 @@ def connect(self):
self._protocol_version = self._cluster.protocol_version
self._set_new_connection(self._reconnect_internal())
- self._cluster.metadata.dbaas = self._connection._product_type == dscloud.PRODUCT_APOLLO
+ self._cluster.metadata.dbaas = self._connection._product_type == dscloud.DATASTAX_CLOUD_PRODUCT_TYPE
def _set_new_connection(self, conn):
"""
@@ -3433,6 +3603,14 @@ def _try_connect(self, host):
break
except ProtocolVersionUnsupported as e:
self._cluster.protocol_downgrade(host.endpoint, e.startup_version)
+ except ProtocolException as e:
+ # protocol v5 is out of beta in C* >=4.0-beta5 and is now the default driver
+ # protocol version. If the protocol version was not explicitly specified,
+ # and the server raises a beta protocol error, we should downgrade.
+ if not self._cluster._protocol_version_explicit and e.is_beta_protocol_error:
+ self._cluster.protocol_downgrade(host.endpoint, self._cluster.protocol_version)
+ else:
+ raise
log.debug("[control connection] Established new connection %r, "
"registering watchers and refreshing schema and topology",
@@ -3450,13 +3628,25 @@ def _try_connect(self, host):
"SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change')
}, register_timeout=self._timeout)
- sel_peers = self._peers_query_for_version(connection, self._SELECT_PEERS_NO_TOKENS_TEMPLATE)
+ sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection)
sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS
peers_query = QueryMessage(query=sel_peers, consistency_level=ConsistencyLevel.ONE)
local_query = QueryMessage(query=sel_local, consistency_level=ConsistencyLevel.ONE)
- shared_results = connection.wait_for_responses(
- peers_query, local_query, timeout=self._timeout)
+ (peers_success, peers_result), (local_success, local_result) = connection.wait_for_responses(
+ peers_query, local_query, timeout=self._timeout, fail_on_error=False)
+
+ if not local_success:
+ raise local_result
+
+ if not peers_success:
+ # the peers_v2 query failed, so fall back to the peers v1 query
+ self._uses_peers_v2 = False
+ sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection)
+ peers_query = QueryMessage(query=sel_peers, consistency_level=ConsistencyLevel.ONE)
+ peers_result = connection.wait_for_response(
+ peers_query, timeout=self._timeout)
+ shared_results = (peers_result, local_result)
self._refresh_node_list_and_token_map(connection, preloaded_results=shared_results)
self._refresh_schema(connection, preloaded_results=shared_results, schema_agreement_wait=-1)
except Exception:
@@ -3578,20 +3768,18 @@ def refresh_node_list_and_token_map(self, force_token_rebuild=False):
def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
force_token_rebuild=False):
-
if preloaded_results:
log.debug("[control connection] Refreshing node list and token map using preloaded results")
peers_result = preloaded_results[0]
local_result = preloaded_results[1]
else:
cl = ConsistencyLevel.ONE
+ sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection)
if not self._token_meta_enabled:
log.debug("[control connection] Refreshing node list without token map")
- sel_peers = self._peers_query_for_version(connection, self._SELECT_PEERS_NO_TOKENS_TEMPLATE)
sel_local = self._SELECT_LOCAL_NO_TOKENS
else:
log.debug("[control connection] Refreshing node list and token map")
- sel_peers = self._SELECT_PEERS
sel_local = self._SELECT_LOCAL
peers_query = QueryMessage(query=sel_peers, consistency_level=cl)
local_query = QueryMessage(query=sel_local, consistency_level=cl)
@@ -3621,13 +3809,17 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
self._update_location_info(host, datacenter, rack)
host.host_id = local_row.get("host_id")
host.listen_address = local_row.get("listen_address")
- host.broadcast_address = local_row.get("broadcast_address")
+ host.listen_port = local_row.get("listen_port")
+ host.broadcast_address = _NodeInfo.get_broadcast_address(local_row)
+ host.broadcast_port = _NodeInfo.get_broadcast_port(local_row)
- host.broadcast_rpc_address = self._address_from_row(local_row)
+ host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(local_row)
+ host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(local_row)
if host.broadcast_rpc_address is None:
if self._token_meta_enabled:
# local rpc_address is not available, use the connection endpoint
host.broadcast_rpc_address = connection.endpoint.address
+ host.broadcast_rpc_port = connection.endpoint.port
else:
# local rpc_address has not been queried yet, try to fetch it
# separately, which might fail because C* < 2.1.6 doesn't have rpc_address
@@ -3640,9 +3832,11 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
row = dict_factory(
local_rpc_address_result.column_names,
local_rpc_address_result.parsed_rows)
- host.broadcast_rpc_address = row[0]['rpc_address']
+ host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(row[0])
+ host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(row[0])
else:
host.broadcast_rpc_address = connection.endpoint.address
+ host.broadcast_rpc_port = connection.endpoint.port
host.release_version = local_row.get("release_version")
host.dse_version = local_row.get("dse_version")
@@ -3657,12 +3851,14 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
# any new nodes, so we need this additional check. (See PYTHON-90)
should_rebuild_token_map = force_token_rebuild or self._cluster.metadata.partitioner is None
for row in peers_result:
+ if not self._is_valid_peer(row):
+ log.warning(
+ "Found an invalid row for peer (%s). Ignoring host." %
+ _NodeInfo.get_broadcast_rpc_address(row))
+ continue
+
endpoint = self._cluster.endpoint_factory.create(row)
- tokens = row.get("tokens", None)
- if 'tokens' in row and not tokens: # it was selected, but empty
- log.warning("Excluding host (%s) with no tokens in system.peers table of %s." % (endpoint, connection.endpoint))
- continue
if endpoint in found_hosts:
log.warning("Found multiple hosts with the same endpoint (%s). Excluding peer %s", endpoint, row.get("peer"))
continue
@@ -3680,13 +3876,16 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
should_rebuild_token_map |= self._update_location_info(host, datacenter, rack)
host.host_id = row.get("host_id")
- host.broadcast_address = row.get("peer")
- host.broadcast_rpc_address = self._address_from_row(row)
+ host.broadcast_address = _NodeInfo.get_broadcast_address(row)
+ host.broadcast_port = _NodeInfo.get_broadcast_port(row)
+ host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(row)
+ host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(row)
host.release_version = row.get("release_version")
host.dse_version = row.get("dse_version")
host.dse_workload = row.get("workload")
host.dse_workloads = row.get("workloads")
+ tokens = row.get("tokens", None)
if partitioner and tokens and self._token_meta_enabled:
token_map[host] = tokens
@@ -3701,6 +3900,12 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
log.debug("[control connection] Rebuilding token map due to topology changes")
self._cluster.metadata.rebuild_token_map(partitioner, token_map)
+ @staticmethod
+ def _is_valid_peer(row):
+ return bool(_NodeInfo.get_broadcast_rpc_address(row) and row.get("host_id") and
+ row.get("data_center") and row.get("rack") and
+ ('tokens' not in row or row.get('tokens')))
+
def _update_location_info(self, host, datacenter, rack):
if host.datacenter == datacenter and host.rack == rack:
return False
@@ -3737,7 +3942,8 @@ def _refresh_nodes_if_not_up(self, host):
def _handle_topology_change(self, event):
change_type = event["change_type"]
- host = self._cluster.metadata.get_host(event["address"][0])
+ addr, port = event["address"]
+ host = self._cluster.metadata.get_host(addr, port)
if change_type == "NEW_NODE" or change_type == "MOVED_NODE":
if self._topology_event_refresh_window >= 0:
delay = self._delay_for_event_type('topology_change', self._topology_event_refresh_window)
@@ -3747,7 +3953,8 @@ def _handle_topology_change(self, event):
def _handle_status_change(self, event):
change_type = event["change_type"]
- host = self._cluster.metadata.get_host(event["address"][0])
+ addr, port = event["address"]
+ host = self._cluster.metadata.get_host(addr, port)
if change_type == "UP":
delay = self._delay_for_event_type('status_change', self._status_event_refresh_window)
if host is None:
@@ -3801,7 +4008,7 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai
elapsed = 0
cl = ConsistencyLevel.ONE
schema_mismatches = None
- select_peers_query = self._peers_query_for_version(connection, self._SELECT_SCHEMA_PEERS_TEMPLATE)
+ select_peers_query = self._get_peers_query(self.PeersQueryType.PEERS_SCHEMA, connection)
while elapsed < total_timeout:
peers_query = QueryMessage(query=select_peers_query, consistency_level=cl)
@@ -3856,42 +4063,52 @@ def _get_schema_mismatches(self, peers_result, local_result, local_address):
log.debug("[control connection] Schemas match")
return None
- return dict((version, list(nodes)) for version, nodes in six.iteritems(versions))
+ return dict((version, list(nodes)) for version, nodes in versions.items())
- def _address_from_row(self, row):
- """
- Parse the broadcast rpc address from a row and return it untranslated.
+ def _get_peers_query(self, peers_query_type, connection=None):
"""
- addr = None
- if "rpc_address" in row:
- addr = row.get("rpc_address") # peers and local
- if "native_transport_address" in row:
- addr = row.get("native_transport_address")
- if not addr or addr in ["0.0.0.0", "::"]:
- addr = row.get("peer")
- return addr
+ Determine the peers query to use.
+
+ :param peers_query_type: one of the PeersQueryType enum values.
+
+ If _uses_peers_v2 is True, return the proper peers_v2 query (no templating).
+ Else, apply the logic below to choose the peers v1 address column name:
- def _peers_query_for_version(self, connection, peers_query_template):
- """
Given a connection:
- find the server product version running on the connection's host,
- use that to choose the column name for the transport address (see APOLLO-1130), and
- use that column name in the provided peers query template.
-
- The provided template should be a string with a format replacement
- field named nt_col_name.
"""
- host_release_version = self._cluster.metadata.get_host(connection.endpoint).release_version
- if host_release_version:
- use_native_address_query = host_release_version >= self._MINIMUM_NATIVE_ADDRESS_VERSION
- if use_native_address_query:
- select_peers_query = peers_query_template.format(nt_col_name="native_transport_address")
+ if peers_query_type not in (self.PeersQueryType.PEERS, self.PeersQueryType.PEERS_SCHEMA):
+ raise ValueError("Invalid peers query type: %s" % peers_query_type)
+
+ if self._uses_peers_v2:
+ if peers_query_type == self.PeersQueryType.PEERS:
+ query = self._SELECT_PEERS_V2 if self._token_meta_enabled else self._SELECT_PEERS_NO_TOKENS_V2
else:
- select_peers_query = peers_query_template.format(nt_col_name="rpc_address")
+ query = self._SELECT_SCHEMA_PEERS_V2
else:
- select_peers_query = self._SELECT_PEERS
- return select_peers_query
+ if peers_query_type == self.PeersQueryType.PEERS and self._token_meta_enabled:
+ query = self._SELECT_PEERS
+ else:
+ query_template = (self._SELECT_SCHEMA_PEERS_TEMPLATE
+ if peers_query_type == self.PeersQueryType.PEERS_SCHEMA
+ else self._SELECT_PEERS_NO_TOKENS_TEMPLATE)
+
+ host_release_version = self._cluster.metadata.get_host(connection.endpoint).release_version
+ host_dse_version = self._cluster.metadata.get_host(connection.endpoint).dse_version
+ uses_native_address_query = (
+ host_dse_version and Version(host_dse_version) >= self._MINIMUM_NATIVE_ADDRESS_DSE_VERSION)
+
+ if uses_native_address_query:
+ query = query_template.format(nt_col_name="native_transport_address")
+ elif host_release_version:
+ query = query_template.format(nt_col_name="rpc_address")
+ else:
+ query = self._SELECT_PEERS
+
+ return query
def _signal_error(self):
with self._lock:
@@ -3966,7 +4183,7 @@ class _Scheduler(Thread):
is_shutdown = False
def __init__(self, executor):
- self._queue = Queue.PriorityQueue()
+ self._queue = queue.PriorityQueue()
self._scheduled_tasks = set()
self._count = count()
self._executor = executor
@@ -4024,7 +4241,7 @@ def run(self):
else:
self._queue.put_nowait((run_at, i, task))
break
- except Queue.Empty:
+ except queue.Empty:
pass
time.sleep(0.1)
@@ -4081,7 +4298,7 @@ class ResponseFuture(object):
coordinator_host = None
"""
- The host from which we recieved a response
+ The host from which we received a response
"""
attempted_hosts = None
@@ -4199,10 +4416,17 @@ def _on_timeout(self, _attempts=0):
pool = self.session._pools.get(self._current_host)
if pool and not pool.is_shutdown:
+ # Do not return the stream ID to the pool yet. We cannot reuse it
+ # because the node might still be processing the query and could
+ # return a late response on that stream - if we reused the stream
+ # before the response to the previous query arrived, the new query
+ # could receive the old query's response.
with self._connection.lock:
- self._connection.request_ids.append(self._req_id)
+ self._connection.orphaned_request_ids.add(self._req_id)
+ if len(self._connection.orphaned_request_ids) >= self._connection.orphaned_threshold:
+ self._connection.orphaned_threshold_reached = True
- pool.return_connection(self._connection)
+ pool.return_connection(self._connection, stream_was_orphaned=True)
errors = self._errors
if not errors:
@@ -4300,7 +4524,9 @@ def _query(self, host, message=None, cb=None):
except NoConnectionsAvailable as exc:
log.debug("All connections for host %s are at capacity, moving to the next host", host)
self._errors[host] = exc
- return None
+ except ConnectionBusy as exc:
+ log.debug("Connection for host %s is busy, moving to the next host", host)
+ self._errors[host] = exc
except Exception as exc:
log.debug("Error querying host %s", host, exc_info=True)
self._errors[host] = exc
@@ -4308,7 +4534,8 @@ def _query(self, host, message=None, cb=None):
self._metrics.on_connection_error()
if connection:
pool.return_connection(connection)
- return None
+
+ return None
@property
def has_more_pages(self):
@@ -4933,6 +5160,15 @@ def current_rows(self):
"""
return self._current_rows or []
+ def all(self):
+ """
+ Returns all the remaining rows as a list. This is basically
+ a convenient shortcut to `list(result_set)`.
+
+ This function is not recommended for queries that return a large number of elements.
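+
+ Example (the query is illustrative)::
+
+     rows = session.execute("SELECT name FROM users").all()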
+ """
+ return list(self)
+
def one(self):
"""
Return a single row of the results or None if empty. This is basically
@@ -4968,6 +5204,12 @@ def next(self):
self.fetch_next_page()
self._page_iter = iter(self._current_rows)
+ # Some servers can return empty pages in this case; Scylla is known to do
+ # so in some circumstances. Guard against this by recursing to handle
+ # the next(iter) call. If we have an empty page in that case it will
+ # get handled by the StopIteration handler when we recurse.
+ return self.next()
+
return next(self._page_iter)
__next__ = next
diff --git a/cassandra/column_encryption/_policies.py b/cassandra/column_encryption/_policies.py
new file mode 100644
index 0000000000..ef8097bfbd
--- /dev/null
+++ b/cassandra/column_encryption/_policies.py
@@ -0,0 +1,139 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import namedtuple
+from functools import lru_cache
+
+import logging
+import os
+
+log = logging.getLogger(__name__)
+
+from cassandra.cqltypes import _cqltypes
+from cassandra.policies import ColumnEncryptionPolicy
+
+from cryptography.hazmat.primitives import padding
+from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+
+AES256_BLOCK_SIZE = 128
+AES256_BLOCK_SIZE_BYTES = int(AES256_BLOCK_SIZE / 8)
+AES256_KEY_SIZE = 256
+AES256_KEY_SIZE_BYTES = int(AES256_KEY_SIZE / 8)
+
+ColData = namedtuple('ColData', ['key','type'])
+
+class AES256ColumnEncryptionPolicy(ColumnEncryptionPolicy):
+
+ # Fix block cipher mode for now. IV size is a function of block cipher used
+ # so fixing this avoids (possibly unnecessary) validation logic here.
+ mode = modes.CBC
+
+ # "iv" param here expects a bytearray that's the same size as the block
+ # size for AES-256 (128 bits or 16 bytes). If none is provided a new one
+ # will be randomly generated, but in this case the IV should be recorded and
+ # preserved or else you will not be able to decrypt any data encrypted by this
+ # policy.
+ def __init__(self, iv=None):
+
+ # CBC uses an IV that's the same size as the block size
+ #
+ # Avoid defining IV with a default arg in order to stay away from
+ # any issues around the caching of default args
+ self.iv = iv
+ if self.iv:
+ if not len(self.iv) == AES256_BLOCK_SIZE_BYTES:
+ raise ValueError("This policy uses AES-256 with CBC mode and therefore expects a 128-bit initialization vector")
+ else:
+ self.iv = os.urandom(AES256_BLOCK_SIZE_BYTES)
+
+ # ColData for a given ColDesc is always preserved. We only create a Cipher
+ # when there's an actual need to for a given ColDesc
+ self.coldata = {}
+ self.ciphers = {}
+
+ def encrypt(self, coldesc, obj_bytes):
+
+ # AES256 has a 128-bit block size so if the input bytes don't align perfectly on
+ # those blocks we have to pad them. There's plenty of room for optimization here:
+ #
+ # * Instances of the PKCS7 padder should be managed in a bounded pool
+ # * It would be nice if we could get a flag from encrypted data to indicate
+ # whether it was padded or not
+ # * Might be able to make this happen with a leading block of flags in encrypted data
+ padder = padding.PKCS7(AES256_BLOCK_SIZE).padder()
+ padded_bytes = padder.update(obj_bytes) + padder.finalize()
+
+ cipher = self._get_cipher(coldesc)
+ encryptor = cipher.encryptor()
+ return self.iv + encryptor.update(padded_bytes) + encryptor.finalize()
+
+ def decrypt(self, coldesc, bytes):
+
+ iv = bytes[:AES256_BLOCK_SIZE_BYTES]
+ encrypted_bytes = bytes[AES256_BLOCK_SIZE_BYTES:]
+ cipher = self._get_cipher(coldesc, iv=iv)
+ decryptor = cipher.decryptor()
+ padded_bytes = decryptor.update(encrypted_bytes) + decryptor.finalize()
+
+ unpadder = padding.PKCS7(AES256_BLOCK_SIZE).unpadder()
+ return unpadder.update(padded_bytes) + unpadder.finalize()
+
+ def add_column(self, coldesc, key, type):
+
+ if not coldesc:
+ raise ValueError("ColDesc supplied to add_column cannot be None")
+ if not key:
+ raise ValueError("Key supplied to add_column cannot be None")
+ if not type:
+ raise ValueError("Type supplied to add_column cannot be None")
+ if type not in _cqltypes.keys():
+ raise ValueError("Type %s is not a supported type".format(type))
+ if not len(key) == AES256_KEY_SIZE_BYTES:
+ raise ValueError("AES256 column encryption policy expects a 256-bit encryption key")
+ self.coldata[coldesc] = ColData(key, _cqltypes[type])
+
+ def contains_column(self, coldesc):
+ return coldesc in self.coldata
+
+ def encode_and_encrypt(self, coldesc, obj):
+ if not coldesc:
+ raise ValueError("ColDesc supplied to encode_and_encrypt cannot be None")
+ if not obj:
+ raise ValueError("Object supplied to encode_and_encrypt cannot be None")
+ coldata = self.coldata.get(coldesc)
+ if not coldata:
+ raise ValueError("Could not find ColData for ColDesc %s".format(coldesc))
+ return self.encrypt(coldesc, coldata.type.serialize(obj, None))
+
+ def cache_info(self):
+ return AES256ColumnEncryptionPolicy._build_cipher.cache_info()
+
+ def column_type(self, coldesc):
+ return self.coldata[coldesc].type
+
+ def _get_cipher(self, coldesc, iv=None):
+ """
+ Access relevant state from this instance necessary to create a Cipher and then get one,
+ hopefully returning a cached instance if we've already done so (and it hasn't been evicted)
+ """
+ try:
+ coldata = self.coldata[coldesc]
+ return AES256ColumnEncryptionPolicy._build_cipher(coldata.key, iv or self.iv)
+ except KeyError:
+ raise ValueError("Could not find column {}".format(coldesc))
+
+ # Defined without self (and always invoked via the class) so that lru_cache does not retain policy instances
+ @lru_cache(maxsize=128)
+ def _build_cipher(key, iv):
+ return Cipher(algorithms.AES256(key), AES256ColumnEncryptionPolicy.mode(iv))
diff --git a/cassandra/compat.py b/cassandra/column_encryption/policies.py
similarity index 79%
rename from cassandra/compat.py
rename to cassandra/column_encryption/policies.py
index 83c1b104e5..770084bd48 100644
--- a/cassandra/compat.py
+++ b/cassandra/column_encryption/policies.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import six
-
-if six.PY2:
- from collections import Mapping
-elif six.PY3:
- from collections.abc import Mapping
+try:
+ import cryptography
+ from cassandra.column_encryption._policies import *
+except ImportError:
+ # Cryptography is not installed
+ pass
diff --git a/cassandra/concurrent.py b/cassandra/concurrent.py
index a8bddcbdab..fb8f26e1cc 100644
--- a/cassandra/concurrent.py
+++ b/cassandra/concurrent.py
@@ -16,12 +16,10 @@
from collections import namedtuple
from heapq import heappush, heappop
from itertools import cycle
-import six
-from six.moves import xrange, zip
from threading import Condition
import sys
-from cassandra.cluster import ResultSet
+from cassandra.cluster import ResultSet, EXEC_PROFILE_DEFAULT
import logging
log = logging.getLogger(__name__)
@@ -29,7 +27,7 @@
ExecutionResult = namedtuple('ExecutionResult', ['success', 'result_or_exc'])
-def execute_concurrent(session, statements_and_parameters, concurrency=100, raise_on_first_error=True, results_generator=False):
+def execute_concurrent(session, statements_and_parameters, concurrency=100, raise_on_first_error=True, results_generator=False, execution_profile=EXEC_PROFILE_DEFAULT):
"""
Executes a sequence of (statement, parameters) tuples concurrently. Each
``parameters`` item must be a sequence or :const:`None`.
@@ -56,6 +54,9 @@ def execute_concurrent(session, statements_and_parameters, concurrency=100, rais
footprint is marginal CPU overhead (more thread coordination and sorting out-of-order results
on-the-fly).
+ The `execution_profile` argument is the execution profile to use for this
+ request; it is passed directly to :meth:`Session.execute_async`.
+
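+ For example (the profile name is illustrative)::
+
+     execute_concurrent(session, statements_and_parameters, execution_profile='my_profile')
+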
A sequence of ``ExecutionResult(success, result_or_exc)`` namedtuples is returned
in the same order that the statements were passed in. If ``success`` is :const:`False`,
there was an error executing the statement, and ``result_or_exc`` will be
@@ -90,7 +91,8 @@ def execute_concurrent(session, statements_and_parameters, concurrency=100, rais
if not statements_and_parameters:
return []
- executor = ConcurrentExecutorGenResults(session, statements_and_parameters) if results_generator else ConcurrentExecutorListResults(session, statements_and_parameters)
+ executor = ConcurrentExecutorGenResults(session, statements_and_parameters, execution_profile) \
+ if results_generator else ConcurrentExecutorListResults(session, statements_and_parameters, execution_profile)
return executor.execute(concurrency, raise_on_first_error)
@@ -98,9 +100,10 @@ class _ConcurrentExecutor(object):
max_error_recursion = 100
- def __init__(self, session, statements_and_params):
+ def __init__(self, session, statements_and_params, execution_profile):
self.session = session
self._enum_statements = enumerate(iter(statements_and_params))
+ self._execution_profile = execution_profile
self._condition = Condition()
self._fail_fast = False
self._results_queue = []
@@ -114,7 +117,7 @@ def execute(self, concurrency, fail_fast):
self._current = 0
self._exec_count = 0
with self._condition:
- for n in xrange(concurrency):
+ for n in range(concurrency):
if not self._execute_next():
break
return self._results()
@@ -132,23 +135,19 @@ def _execute_next(self):
def _execute(self, idx, statement, params):
self._exec_depth += 1
try:
- future = self.session.execute_async(statement, params, timeout=None)
+ future = self.session.execute_async(statement, params, timeout=None, execution_profile=self._execution_profile)
args = (future, idx)
future.add_callbacks(
callback=self._on_success, callback_args=args,
errback=self._on_error, errback_args=args)
except Exception as exc:
- # exc_info with fail_fast to preserve stack trace info when raising on the client thread
- # (matches previous behavior -- not sure why we wouldn't want stack trace in the other case)
- e = sys.exc_info() if self._fail_fast and six.PY2 else exc
-
# If we're not failing fast and all executions are raising, there is a chance of recursing
# here as subsequent requests are attempted. If we hit this threshold, schedule this result/retry
# and let the event loop thread return.
if self._exec_depth < self.max_error_recursion:
- self._put_result(e, idx, False)
+ self._put_result(exc, idx, False)
else:
- self.session.submit(self._put_result, e, idx, False)
+ self.session.submit(self._put_result, exc, idx, False)
self._exec_depth -= 1
def _on_success(self, result, future, idx):
@@ -158,14 +157,6 @@ def _on_success(self, result, future, idx):
def _on_error(self, result, future, idx):
self._put_result(result, idx, False)
- @staticmethod
- def _raise(exc):
- if six.PY2 and isinstance(exc, tuple):
- (exc_type, value, traceback) = exc
- six.reraise(exc_type, value, traceback)
- else:
- raise exc
-
class ConcurrentExecutorGenResults(_ConcurrentExecutor):
@@ -185,7 +176,7 @@ def _results(self):
try:
self._condition.release()
if self._fail_fast and not res[0]:
- self._raise(res[1])
+ raise res[1]
yield res
finally:
self._condition.acquire()
@@ -216,9 +207,9 @@ def _results(self):
while self._current < self._exec_count:
self._condition.wait()
if self._exception and self._fail_fast:
- self._raise(self._exception)
+ raise self._exception
if self._exception and self._fail_fast: # raise the exception even if there was no wait
- self._raise(self._exception)
+ raise self._exception
return [r[1] for r in sorted(self._results_queue)]
diff --git a/cassandra/connection.py b/cassandra/connection.py
index 66af1f8521..bfe38fc702 100644
--- a/cassandra/connection.py
+++ b/cassandra/connection.py
@@ -19,19 +19,19 @@
from heapq import heappush, heappop
import io
import logging
-import six
-from six.moves import range
import socket
import struct
import sys
from threading import Thread, Event, RLock, Condition
import time
import ssl
+import weakref
+
if 'gevent.monkey' in sys.modules:
from gevent.queue import Queue, Empty
else:
- from six.moves.queue import Queue, Empty # noqa
+ from queue import Queue, Empty # noqa
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut, ProtocolVersion
from cassandra.marshal import int32_pack
@@ -42,11 +42,15 @@
AuthResponseMessage, AuthChallengeMessage,
AuthSuccessMessage, ProtocolException,
RegisterMessage, ReviseRequestMessage)
+from cassandra.segment import SegmentCodec, CrcException
from cassandra.util import OrderedDict
log = logging.getLogger(__name__)
+segment_codec_no_compression = SegmentCodec()
+segment_codec_lz4 = None
+
# We use an ordered dictionary and specifically add lz4 before
# snappy so that lz4 will be preferred. Changing the order of this
# will change the compression preferences for the driver.
@@ -88,6 +92,7 @@ def lz4_decompress(byts):
return lz4_block.decompress(byts[3::-1] + byts[4:])
locally_supported_compressions['lz4'] = (lz4_compress, lz4_decompress)
+ segment_codec_lz4 = SegmentCodec(lz4_compress, lz4_decompress)
try:
import snappy
@@ -214,25 +219,26 @@ class DefaultEndPointFactory(EndPointFactory):
port = None
"""
- If set, force all endpoints to use this port.
+ If no port is discovered in the row, this is the default port
+ used for endpoint creation.
"""
def __init__(self, port=None):
self.port = port
def create(self, row):
- addr = None
- if "rpc_address" in row:
- addr = row.get("rpc_address")
- if "native_transport_address" in row:
- addr = row.get("native_transport_address")
- if not addr or addr in ["0.0.0.0", "::"]:
- addr = row.get("peer")
+ # TODO next major... move this class so we don't need this kind of hack
+ from cassandra.metadata import _NodeInfo
+ addr = _NodeInfo.get_broadcast_rpc_address(row)
+ port = _NodeInfo.get_broadcast_rpc_port(row)
+ if port is None:
+ port = self.port if self.port else 9042
# create the endpoint with the translated address
+ # TODO next major, create a TranslatedEndPoint type
return DefaultEndPoint(
self.cluster.address_translator.translate(addr),
- self.port if self.port is not None else 9042)
+ port)
@total_ordering
@@ -425,6 +431,10 @@ class ProtocolError(Exception):
pass
+class CrcMismatchException(ConnectionException):
+ pass
+
+
class ContinuousPagingState(object):
"""
A class for specifying continuous paging state, only supported starting with DSE_V2.
@@ -593,11 +603,59 @@ def wrapper(self, *args, **kwargs):
DEFAULT_CQL_VERSION = '3.0.0'
-if six.PY3:
- def int_from_buf_item(i):
- return i
-else:
- int_from_buf_item = ord
+
+class _ConnectionIOBuffer(object):
+ """
+ Abstraction class to ease the use of the different connection io buffers. With
+ protocol V5 and checksumming, the data is read, validated and copied to another
+ cql frame buffer.
+ """
+ _io_buffer = None
+ _cql_frame_buffer = None
+ _connection = None
+ _segment_consumed = False
+
+ def __init__(self, connection):
+ self._io_buffer = io.BytesIO()
+ self._connection = weakref.proxy(connection)
+
+ @property
+ def io_buffer(self):
+ return self._io_buffer
+
+ @property
+ def cql_frame_buffer(self):
+ return self._cql_frame_buffer if self.is_checksumming_enabled else \
+ self._io_buffer
+
+ def set_checksumming_buffer(self):
+ self.reset_io_buffer()
+ self._cql_frame_buffer = io.BytesIO()
+
+ @property
+ def is_checksumming_enabled(self):
+ return self._connection._is_checksumming_enabled
+
+ @property
+ def has_consumed_segment(self):
+ return self._segment_consumed
+
+ def readable_io_bytes(self):
+ return self.io_buffer.tell()
+
+ def readable_cql_frame_bytes(self):
+ return self.cql_frame_buffer.tell()
+
+ def reset_io_buffer(self):
+ self._io_buffer = io.BytesIO(self._io_buffer.read())
+ self._io_buffer.seek(0, 2) # 2 == SEEK_END
+
+ def reset_cql_frame_buffer(self):
+ if self.is_checksumming_enabled:
+ self._cql_frame_buffer = io.BytesIO(self._cql_frame_buffer.read())
+ self._cql_frame_buffer.seek(0, 2) # 2 == SEEK_END
+ else:
+ self.reset_io_buffer()
class Connection(object):
@@ -624,6 +682,7 @@ class Connection(object):
# The current number of operations that are in flight. More precisely,
# the number of request IDs that are currently in use.
+ # This includes orphaned requests.
in_flight = 0
# Max concurrent requests allowed per connection. This is set optimistically high, allowing
@@ -641,6 +700,20 @@ class Connection(object):
# request_ids set
highest_request_id = 0
+ # Tracks the request IDs which are no longer waited on (timed out), but
+ # cannot be reused yet because the node might still send a response
+ # on this stream
+ orphaned_request_ids = None
+
+ # Set to True when the orphaned stream ID count crosses the configured threshold;
+ # when that happens the connection will be replaced
+ orphaned_threshold_reached = False
+
+ # If the number of orphaned streams reaches this threshold, this connection
+ # will become marked and will be replaced with a new connection by the
+ # owning pool (currently, only HostConnection supports this)
+ orphaned_threshold = 3 * max_in_flight // 4
+
is_defunct = False
is_closed = False
lock = None
@@ -655,28 +728,35 @@ class Connection(object):
allow_beta_protocol_version = False
- _iobuf = None
_current_frame = None
_socket = None
_socket_impl = socket
- _ssl_impl = ssl
_check_hostname = False
_product_type = None
+ _is_checksumming_enabled = False
+
+ _on_orphaned_stream_released = None
+
+ @property
+ def _iobuf(self):
+ # backward compatibility, to avoid any change in the reactors
+ return self._io_buffer.io_buffer
+
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
ssl_options=None, sockopts=None, compression=True,
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False,
- ssl_context=None):
+ ssl_context=None, on_orphaned_stream_released=None):
# TODO next major rename host to endpoint and remove port kwarg.
self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port)
self.authenticator = authenticator
- self.ssl_options = ssl_options.copy() if ssl_options else None
+ self.ssl_options = ssl_options.copy() if ssl_options else {}
self.ssl_context = ssl_context
self.sockopts = sockopts
self.compression = compression
@@ -689,19 +769,27 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
self.no_compact = no_compact
self._push_watchers = defaultdict(set)
self._requests = {}
- self._iobuf = io.BytesIO()
+ self._io_buffer = _ConnectionIOBuffer(self)
self._continuous_paging_sessions = {}
+ self._socket_writable = True
+ self.orphaned_request_ids = set()
+ self._on_orphaned_stream_released = on_orphaned_stream_released
if ssl_options:
- self._check_hostname = bool(self.ssl_options.pop('check_hostname', False))
- if self._check_hostname:
- if not getattr(ssl, 'match_hostname', None):
- raise RuntimeError("ssl_options specify 'check_hostname', but ssl.match_hostname is not provided. "
- "Patch or upgrade Python to use this option.")
self.ssl_options.update(self.endpoint.ssl_options or {})
elif self.endpoint.ssl_options:
self.ssl_options = self.endpoint.ssl_options
+ # PYTHON-1331
+ #
+ # We always use SSLContext.wrap_socket() now, but legacy configs may carry params that were meant for
+ # ssl.wrap_socket(), and either style could include 'check_hostname'. If no ssl_context was supplied,
+ # read the context-related params out of ssl_options and use them to build an SSLContext here;
+ # _wrap_socket_from_context() later extracts only the params that apply to the wrap_socket() call itself.
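+ #
+ # For example, a legacy dict like {'ca_certs': '/path/to/ca.pem', 'cert_reqs': ssl.CERT_REQUIRED,
+ # 'check_hostname': True} (paths illustrative) feeds the SSLContext built below, while
+ # wrap_socket-only keys such as 'server_hostname' are handled later in _wrap_socket_from_context().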
+ if not self.ssl_context and self.ssl_options:
+ self.ssl_context = self._build_ssl_context_from_options()
if protocol_version >= 3:
self.max_request_id = min(self.max_in_flight - 1, (2 ** 15) - 1)
@@ -768,21 +856,57 @@ def factory(cls, endpoint, timeout, *args, **kwargs):
else:
return conn
+ def _build_ssl_context_from_options(self):
+
+ # Extract a subset of names from self.ssl_options which apply to SSLContext creation
+ ssl_context_opt_names = ['ssl_version', 'cert_reqs', 'check_hostname', 'keyfile', 'certfile', 'ca_certs', 'ciphers']
+ opts = {k:self.ssl_options.get(k, None) for k in ssl_context_opt_names if k in self.ssl_options}
+
+ # Python >= 3.10 requires either PROTOCOL_TLS_CLIENT or PROTOCOL_TLS_SERVER so we'll get ahead of things by always
+ # being explicit
+ ssl_version = opts.get('ssl_version', None) or ssl.PROTOCOL_TLS_CLIENT
+ cert_reqs = opts.get('cert_reqs', None) or ssl.CERT_REQUIRED
+ rv = ssl.SSLContext(protocol=int(ssl_version))
+ rv.check_hostname = bool(opts.get('check_hostname', False))
+ # cert_reqs describes certificate verification, which is carried by verify_mode on an SSLContext
+ rv.verify_mode = cert_reqs
+
+ certfile = opts.get('certfile', None)
+ keyfile = opts.get('keyfile', None)
+ if certfile:
+ rv.load_cert_chain(certfile, keyfile)
+ ca_certs = opts.get('ca_certs', None)
+ if ca_certs:
+ rv.load_verify_locations(ca_certs)
+ ciphers = opts.get('ciphers', None)
+ if ciphers:
+ rv.set_ciphers(ciphers)
+
+ return rv
+
def _wrap_socket_from_context(self):
- ssl_options = self.ssl_options or {}
+
+ # Extract a subset of names from self.ssl_options which apply to SSLContext.wrap_socket (or at least the parts
+ # of it that don't involve building an SSLContext under the covers)
+ wrap_socket_opt_names = ['server_side', 'do_handshake_on_connect', 'suppress_ragged_eofs', 'server_hostname']
+ opts = {k:self.ssl_options.get(k, None) for k in wrap_socket_opt_names if k in self.ssl_options}
+
# PYTHON-1186: set the server_hostname only if the SSLContext has
# check_hostname enabled and it is not already provided by the EndPoint ssl options
- if (self.ssl_context.check_hostname and
- 'server_hostname' not in ssl_options):
- ssl_options = ssl_options.copy()
- ssl_options['server_hostname'] = self.endpoint.address
- self._socket = self.ssl_context.wrap_socket(self._socket, **ssl_options)
+ if (self.ssl_context.check_hostname and 'server_hostname' not in opts):
+ server_hostname = self.endpoint.address
+ opts['server_hostname'] = server_hostname
+
+ return self.ssl_context.wrap_socket(self._socket, **opts)
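For context, a rough usage sketch of the two TLS configuration styles reconciled above: a legacy ssl_options dict (from which the driver now builds an SSLContext) versus an explicit ssl_context plus wrap_socket() arguments. Paths and hostnames below are placeholders.

    import ssl
    from cassandra.cluster import Cluster

    # Legacy style: a plain dict; keys such as 'ca_certs', 'cert_reqs' and
    # 'check_hostname' are split out to build an SSLContext internally.
    cluster = Cluster(['127.0.0.1'],
                      ssl_options={'ca_certs': '/path/to/ca.pem',       # placeholder
                                   'cert_reqs': ssl.CERT_REQUIRED,
                                   'check_hostname': True})

    # Explicit style: build the SSLContext yourself; remaining ssl_options keys
    # are passed straight through to SSLContext.wrap_socket().
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ctx.load_verify_locations('/path/to/ca.pem')                        # placeholder
    cluster = Cluster(['127.0.0.1'], ssl_context=ctx,
                      ssl_options={'server_hostname': 'node1.example.com'})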
def _initiate_connection(self, sockaddr):
self._socket.connect(sockaddr)
- def _match_hostname(self):
- ssl.match_hostname(self._socket.getpeercert(), self.endpoint.address)
+ # PYTHON-1331
+ #
+ # Allow implementations specific to an event loop to add additional behaviours
+ def _validate_hostname(self):
+ pass
def _get_socket_addresses(self):
address, port = self.endpoint.resolve()
@@ -803,16 +927,18 @@ def _connect_socket(self):
try:
self._socket = self._socket_impl.socket(af, socktype, proto)
if self.ssl_context:
- self._wrap_socket_from_context()
- elif self.ssl_options:
- if not self._ssl_impl:
- raise RuntimeError("This version of Python was not compiled with SSL support")
- self._socket = self._ssl_impl.wrap_socket(self._socket, **self.ssl_options)
+ self._socket = self._wrap_socket_from_context()
self._socket.settimeout(self.connect_timeout)
self._initiate_connection(sockaddr)
self._socket.settimeout(None)
+
+ # PYTHON-1331
+ #
+ # Most checking is done via the check_hostname param on the SSLContext.
+ # Subclasses can add additional behaviours via _validate_hostname() so
+ # run that here.
if self._check_hostname:
- self._match_hostname()
+ self._validate_hostname()
sockerr = None
break
except socket.error as err:
@@ -829,6 +955,16 @@ def _connect_socket(self):
for args in self.sockopts:
self._socket.setsockopt(*args)
+ def _enable_compression(self):
+ if self._compressor:
+ self.compressor = self._compressor
+
+ def _enable_checksumming(self):
+ self._io_buffer.set_checksumming_buffer()
+ self._is_checksumming_enabled = True
+ self._segment_codec = segment_codec_lz4 if self.compressor else segment_codec_no_compression
+ log.debug("Enabling protocol checksumming on connection (%s).", id(self))
+
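A minimal sketch of the segment framing enabled here, assuming the codec instances referenced above (segment_codec_no_compression / segment_codec_lz4) are importable from cassandra.segment; the frame bytes are placeholders.

    import io
    from cassandra.segment import segment_codec_no_compression as codec  # assumed location

    cql_frame = b'\x85\x00\x00\x00\x05\x00\x00\x00\x00'  # placeholder frame bytes

    # Outbound: wrap the frame in a checksummed segment, as send_msg() does
    # when checksumming is enabled.
    out = io.BytesIO()
    codec.encode(out, cql_frame)

    # Inbound: read the header and then the payload, mirroring _process_segment_buffer().
    buf = io.BytesIO(out.getvalue())
    header = codec.decode_header(buf)
    segment = codec.decode(buf, header)
    assert segment.payload == cql_frame  # uncompressed codec returns the raw frame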
def close(self):
raise NotImplementedError()
@@ -925,11 +1061,20 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message,
raise ConnectionShutdown("Connection to %s is defunct" % self.endpoint)
elif self.is_closed:
raise ConnectionShutdown("Connection to %s is closed" % self.endpoint)
+ elif not self._socket_writable:
+ raise ConnectionBusy("Connection %s is overloaded" % self.endpoint)
# queue the decoder function with the request
# this allows us to inject custom functions per request to encode, decode messages
self._requests[request_id] = (cb, decoder, result_metadata)
- msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor, allow_beta_protocol_version=self.allow_beta_protocol_version)
+ msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor,
+ allow_beta_protocol_version=self.allow_beta_protocol_version)
+
+ if self._is_checksumming_enabled:
+ buffer = io.BytesIO()
+ self._segment_codec.encode(buffer, msg)
+ msg = buffer.getvalue()
+
self.push(msg)
return len(msg)
@@ -1008,10 +1153,10 @@ def control_conn_disposed(self):
@defunct_on_error
def _read_frame_header(self):
- buf = self._iobuf.getvalue()
+ buf = self._io_buffer.cql_frame_buffer.getvalue()
pos = len(buf)
if pos:
- version = int_from_buf_item(buf[0]) & PROTOCOL_VERSION_MASK
+ version = buf[0] & PROTOCOL_VERSION_MASK
if version not in ProtocolVersion.SUPPORTED_VERSIONS:
raise ProtocolError("This version of the driver does not support protocol version %d" % version)
frame_header = frame_header_v3 if version >= 3 else frame_header_v1_v2
@@ -1024,29 +1169,62 @@ def _read_frame_header(self):
self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size)
return pos
- def _reset_frame(self):
- self._iobuf = io.BytesIO(self._iobuf.read())
- self._iobuf.seek(0, 2) # io.SEEK_END == 2 (constant not present in 2.6)
- self._current_frame = None
+ @defunct_on_error
+ def _process_segment_buffer(self):
+ readable_bytes = self._io_buffer.readable_io_bytes()
+ if readable_bytes >= self._segment_codec.header_length_with_crc:
+ try:
+ self._io_buffer.io_buffer.seek(0)
+ segment_header = self._segment_codec.decode_header(self._io_buffer.io_buffer)
+
+ if readable_bytes >= segment_header.segment_length:
+ segment = self._segment_codec.decode(self._iobuf, segment_header)
+ self._io_buffer._segment_consumed = True
+ self._io_buffer.cql_frame_buffer.write(segment.payload)
+ else:
+ # not enough data to read the segment. reset the buffer pointer at the
+ # beginning to not lose what we previously read (header).
+ self._io_buffer._segment_consumed = False
+ self._io_buffer.io_buffer.seek(0)
+ except CrcException as exc:
+ # re-raise an exception that inherits from ConnectionException
+ raise CrcMismatchException(str(exc), self.endpoint)
+ else:
+ self._io_buffer._segment_consumed = False
def process_io_buffer(self):
while True:
+ if self._is_checksumming_enabled and self._io_buffer.readable_io_bytes():
+ self._process_segment_buffer()
+ self._io_buffer.reset_io_buffer()
+
+ if self._is_checksumming_enabled and not self._io_buffer.has_consumed_segment:
+ # We couldn't read an entire segment from the io buffer, so return
+ # control to allow more bytes to be read off the wire
+ return
+
if not self._current_frame:
pos = self._read_frame_header()
else:
- pos = self._iobuf.tell()
+ pos = self._io_buffer.readable_cql_frame_bytes()
if not self._current_frame or pos < self._current_frame.end_pos:
+ if self._is_checksumming_enabled and self._io_buffer.readable_io_bytes():
+                    # We have a multi-segment message and need to read more
+ # data to complete the current cql frame
+ continue
+
# we don't have a complete header yet or we
# already saw a header, but we don't have a
# complete message yet
return
else:
frame = self._current_frame
- self._iobuf.seek(frame.body_offset)
- msg = self._iobuf.read(frame.end_pos - frame.body_offset)
+ self._io_buffer.cql_frame_buffer.seek(frame.body_offset)
+ msg = self._io_buffer.cql_frame_buffer.read(frame.end_pos - frame.body_offset)
self.process_msg(frame, msg)
- self._reset_frame()
+ self._io_buffer.reset_cql_frame_buffer()
+ self._current_frame = None
@defunct_on_error
def process_msg(self, header, body):
@@ -1063,11 +1241,22 @@ def process_msg(self, header, body):
decoder = paging_session.decoder
result_metadata = None
else:
+ need_notify_of_release = False
+ with self.lock:
+ if stream_id in self.orphaned_request_ids:
+ self.in_flight -= 1
+ self.orphaned_request_ids.remove(stream_id)
+ need_notify_of_release = True
+ if need_notify_of_release and self._on_orphaned_stream_released:
+ self._on_orphaned_stream_released()
+
try:
callback, decoder, result_metadata = self._requests.pop(stream_id)
# This can only happen if the stream_id was
# removed due to an OperationTimedOut
except KeyError:
+ with self.lock:
+ self.request_ids.append(stream_id)
return
try:
@@ -1166,7 +1355,7 @@ def _handle_options_response(self, options_response):
remote_supported_compressions)
else:
compression_type = None
- if isinstance(self.compression, six.string_types):
+ if isinstance(self.compression, str):
# the user picked a specific compression type ('snappy' or 'lz4')
if self.compression not in remote_supported_compressions:
raise ProtocolError(
@@ -1181,11 +1370,19 @@ def _handle_options_response(self, options_response):
compression_type = k
break
- # set the decompressor here, but set the compressor only after
- # a successful Ready message
- self._compression_type = compression_type
- self._compressor, self.decompressor = \
- locally_supported_compressions[compression_type]
+            # If snappy compression is selected with protocol v5 checksumming, the connection
+            # will fail with an OperationTimedOut. Only lz4 is supported with checksumming
+ if (compression_type == 'snappy' and
+ ProtocolVersion.has_checksumming_support(self.protocol_version)):
+ log.debug("Snappy compression is not supported with protocol version %s and "
+ "checksumming. Consider installing lz4. Disabling compression.", self.protocol_version)
+ compression_type = None
+ else:
+ # set the decompressor here, but set the compressor only after
+ # a successful Ready message
+ self._compression_type = compression_type
+ self._compressor, self.decompressor = \
+ locally_supported_compressions[compression_type]
self._send_startup_message(compression_type, no_compact=self.no_compact)
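In practical terms, only lz4 survives negotiation when protocol v5 checksumming is in play, so a client that wants compression there would be configured roughly like this (contact point is a placeholder):

    from cassandra.cluster import Cluster

    # 'lz4' is kept with v5 checksumming (requires the lz4 package);
    # 'snappy' would be dropped by the check above and the connection
    # silently falls back to no compression.
    cluster = Cluster(['127.0.0.1'], compression='lz4')
    session = cluster.connect()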
@@ -1206,6 +1403,7 @@ def _send_startup_message(self, compression=None, no_compact=False):
def _handle_startup_response(self, startup_response, did_authenticate=False):
if self.is_defunct:
return
+
if isinstance(startup_response, ReadyMessage):
if self.authenticator:
log.warning("An authentication challenge was not sent, "
@@ -1214,8 +1412,11 @@ def _handle_startup_response(self, startup_response, did_authenticate=False):
self.authenticator.__class__.__name__)
log.debug("Got ReadyMessage on new connection (%s) from %s", id(self), self.endpoint)
- if self._compressor:
- self.compressor = self._compressor
+ self._enable_compression()
+
+ if ProtocolVersion.has_checksumming_support(self.protocol_version):
+ self._enable_checksumming()
+
self.connected_event.set()
elif isinstance(startup_response, AuthenticateMessage):
log.debug("Got AuthenticateMessage on new connection (%s) from %s: %s",
@@ -1227,6 +1428,10 @@ def _handle_startup_response(self, startup_response, did_authenticate=False):
"if DSE authentication is configured with transitional mode" % (self.host,))
raise AuthenticationFailed('Remote end requires authentication')
+ self._enable_compression()
+ if ProtocolVersion.has_checksumming_support(self.protocol_version):
+ self._enable_checksumming()
+
if isinstance(self.authenticator, dict):
log.debug("Sending credentials-based auth response on %s", self)
cm = CredentialsMessage(creds=self.authenticator)
@@ -1442,7 +1647,7 @@ def __init__(self, connection, owner):
log.debug("Sending options message heartbeat on idle connection (%s) %s",
id(connection), connection.endpoint)
with connection.lock:
- if connection.in_flight <= connection.max_request_id:
+ if connection.in_flight < connection.max_request_id:
connection.in_flight += 1
connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback)
else:
diff --git a/cassandra/cqlengine/__init__.py b/cassandra/cqlengine/__init__.py
index e2a952d682..b9466e961b 100644
--- a/cassandra/cqlengine/__init__.py
+++ b/cassandra/cqlengine/__init__.py
@@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import six
-
-
# Caching constants.
CACHING_ALL = "ALL"
CACHING_KEYS_ONLY = "KEYS_ONLY"
@@ -31,7 +28,4 @@ class ValidationError(CQLEngineException):
class UnicodeMixin(object):
- if six.PY3:
- __str__ = lambda x: x.__unicode__()
- else:
- __str__ = lambda x: six.text_type(x).encode('utf-8')
+ __str__ = lambda x: x.__unicode__()
diff --git a/cassandra/cqlengine/columns.py b/cassandra/cqlengine/columns.py
index 49116129fc..7c20ec6642 100644
--- a/cassandra/cqlengine/columns.py
+++ b/cassandra/cqlengine/columns.py
@@ -15,7 +15,6 @@
from copy import deepcopy, copy
from datetime import date, datetime, timedelta
import logging
-import six
from uuid import UUID as _UUID
from cassandra import util
@@ -327,7 +326,7 @@ class Blob(Column):
def to_database(self, value):
- if not isinstance(value, (six.binary_type, bytearray)):
+ if not isinstance(value, (bytes, bytearray)):
raise Exception("expecting a binary, got a %s" % type(value))
val = super(Bytes, self).to_database(value)
@@ -381,7 +380,7 @@ def __init__(self, min_length=None, max_length=None, **kwargs):
def validate(self, value):
value = super(Text, self).validate(value)
- if not isinstance(value, (six.string_types, bytearray)) and value is not None:
+ if not isinstance(value, (str, bytearray)) and value is not None:
raise ValidationError('{0} {1} is not a string'.format(self.column_name, type(value)))
if self.max_length is not None:
if value and len(value) > self.max_length:
@@ -655,7 +654,7 @@ def validate(self, value):
return
if isinstance(val, _UUID):
return val
- if isinstance(val, six.string_types):
+ if isinstance(val, str):
try:
return _UUID(val)
except ValueError:
diff --git a/cassandra/cqlengine/connection.py b/cassandra/cqlengine/connection.py
index 884e04ed74..588e512a2d 100644
--- a/cassandra/cqlengine/connection.py
+++ b/cassandra/cqlengine/connection.py
@@ -14,7 +14,6 @@
from collections import defaultdict
import logging
-import six
import threading
from cassandra.cluster import Cluster, _ConfigMode, _NOT_SET, NoHostAvailable, UserTypeDoesNotExist, ConsistencyLevel
@@ -98,7 +97,13 @@ def setup(self):
if self.lazy_connect:
return
- self.cluster = Cluster(self.hosts, **self.cluster_options)
+ if 'cloud' in self.cluster_options:
+ if self.hosts:
+ log.warning("Ignoring hosts %s because a cloud config was provided.", self.hosts)
+ self.cluster = Cluster(**self.cluster_options)
+ else:
+ self.cluster = Cluster(self.hosts, **self.cluster_options)
+
try:
self.session = self.cluster.connect()
log.debug(format_log_context("connection initialized with internally created session", connection=self.name))
@@ -301,6 +306,8 @@ def set_session(s):
log.debug("cqlengine default connection initialized with %s", s)
+# TODO next major: if a cloud config is specified in kwargs, hosts will be ignored.
+# This function should be refactored to reflect this change. PYTHON-1265
def setup(
hosts,
default_keyspace,
@@ -338,7 +345,7 @@ def execute(query, params=None, consistency_level=None, timeout=NOT_SET, connect
elif isinstance(query, BaseCQLStatement):
params = query.get_context()
query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size)
- elif isinstance(query, six.string_types):
+ elif isinstance(query, str):
query = SimpleStatement(query, consistency_level=consistency_level)
log.debug(format_log_context('Query: {}, Params: {}'.format(query.query_string, params), connection=connection))
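A hedged sketch of the new cloud path through cqlengine's setup(): when a 'cloud' dict is present in the kwargs, any hosts argument is ignored. Bundle path and credentials are placeholders.

    from cassandra.auth import PlainTextAuthProvider
    from cassandra.cqlengine import connection

    connection.setup(
        None,                                   # hosts are ignored with a cloud config
        'my_keyspace',
        cloud={'secure_connect_bundle': '/path/to/secure-connect.zip'},    # placeholder
        auth_provider=PlainTextAuthProvider('client_id', 'client_secret')  # placeholder
    )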
diff --git a/cassandra/cqlengine/management.py b/cassandra/cqlengine/management.py
index 536bde6349..6790a117c7 100644
--- a/cassandra/cqlengine/management.py
+++ b/cassandra/cqlengine/management.py
@@ -16,7 +16,6 @@
import json
import logging
import os
-import six
import warnings
from itertools import product
@@ -232,7 +231,7 @@ def _sync_table(model, connection=None):
except CQLEngineException as ex:
# 1.2 doesn't return cf names, so we have to examine the exception
# and ignore if it says the column family already exists
- if "Cannot add already existing column family" not in six.text_type(ex):
+ if "Cannot add already existing column family" not in str(ex):
raise
else:
log.debug(format_log_context("sync_table checking existing table %s", keyspace=ks_name, connection=connection), cf_name)
@@ -477,7 +476,7 @@ def _update_options(model, connection=None):
except KeyError:
msg = format_log_context("Invalid table option: '%s'; known options: %s", keyspace=ks_name, connection=connection)
raise KeyError(msg % (name, existing_options.keys()))
- if isinstance(existing_value, six.string_types):
+ if isinstance(existing_value, str):
if value != existing_value:
update_options[name] = value
else:
diff --git a/cassandra/cqlengine/models.py b/cassandra/cqlengine/models.py
index b3c7c9e37f..bc00001666 100644
--- a/cassandra/cqlengine/models.py
+++ b/cassandra/cqlengine/models.py
@@ -14,7 +14,6 @@
import logging
import re
-import six
from warnings import warn
from cassandra.cqlengine import CQLEngineException, ValidationError
@@ -614,7 +613,7 @@ def __iter__(self):
def __getitem__(self, key):
""" Returns column's value. """
- if not isinstance(key, six.string_types):
+ if not isinstance(key, str):
raise TypeError
if key not in self._columns.keys():
raise KeyError
@@ -622,7 +621,7 @@ def __getitem__(self, key):
def __setitem__(self, key, val):
""" Sets a column's value. """
- if not isinstance(key, six.string_types):
+ if not isinstance(key, str):
raise TypeError
if key not in self._columns.keys():
raise KeyError
@@ -1042,8 +1041,7 @@ def _transform_column(col_name, col_obj):
return klass
-@six.add_metaclass(ModelMetaClass)
-class Model(BaseModel):
+class Model(BaseModel, metaclass=ModelMetaClass):
__abstract__ = True
"""
*Optional.* Indicates that this model is only intended to be used as a base class for other models.
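User-facing model definitions are unchanged by the metaclass cleanup; a minimal example for reference:

    import uuid
    from cassandra.cqlengine import columns
    from cassandra.cqlengine.models import Model

    class ExampleUser(Model):                       # hypothetical model
        user_id = columns.UUID(primary_key=True, default=uuid.uuid4)
        name = columns.Text(max_length=128)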
diff --git a/cassandra/cqlengine/operators.py b/cassandra/cqlengine/operators.py
index bba505583c..2adf51758d 100644
--- a/cassandra/cqlengine/operators.py
+++ b/cassandra/cqlengine/operators.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import six
from cassandra.cqlengine import UnicodeMixin
@@ -44,8 +43,7 @@ def __init__(cls, name, bases, dct):
super(OpMapMeta, cls).__init__(name, bases, dct)
-@six.add_metaclass(OpMapMeta)
-class BaseWhereOperator(BaseQueryOperator):
+class BaseWhereOperator(BaseQueryOperator, metaclass=OpMapMeta):
""" base operator used for where clauses """
@classmethod
def get_operator(cls, symbol):
diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py
index 11f664ec02..73f48a5928 100644
--- a/cassandra/cqlengine/query.py
+++ b/cassandra/cqlengine/query.py
@@ -16,7 +16,6 @@
from datetime import datetime, timedelta
from functools import partial
import time
-import six
from warnings import warn
from cassandra.query import SimpleStatement, BatchType as CBatchType, BatchStatement
@@ -103,29 +102,29 @@ def in_(self, item):
used where you'd typically want to use python's `in` operator
"""
- return WhereClause(six.text_type(self), InOperator(), item)
+ return WhereClause(str(self), InOperator(), item)
def contains_(self, item):
"""
Returns a CONTAINS operator
"""
- return WhereClause(six.text_type(self), ContainsOperator(), item)
+ return WhereClause(str(self), ContainsOperator(), item)
def __eq__(self, other):
- return WhereClause(six.text_type(self), EqualsOperator(), self._to_database(other))
+ return WhereClause(str(self), EqualsOperator(), self._to_database(other))
def __gt__(self, other):
- return WhereClause(six.text_type(self), GreaterThanOperator(), self._to_database(other))
+ return WhereClause(str(self), GreaterThanOperator(), self._to_database(other))
def __ge__(self, other):
- return WhereClause(six.text_type(self), GreaterThanOrEqualOperator(), self._to_database(other))
+ return WhereClause(str(self), GreaterThanOrEqualOperator(), self._to_database(other))
def __lt__(self, other):
- return WhereClause(six.text_type(self), LessThanOperator(), self._to_database(other))
+ return WhereClause(str(self), LessThanOperator(), self._to_database(other))
def __le__(self, other):
- return WhereClause(six.text_type(self), LessThanOrEqualOperator(), self._to_database(other))
+ return WhereClause(str(self), LessThanOrEqualOperator(), self._to_database(other))
class BatchType(object):
@@ -231,7 +230,7 @@ def execute(self):
opener = 'BEGIN ' + (str(batch_type) + ' ' if batch_type else '') + ' BATCH'
if self.timestamp:
- if isinstance(self.timestamp, six.integer_types):
+ if isinstance(self.timestamp, int):
ts = self.timestamp
elif isinstance(self.timestamp, (datetime, timedelta)):
ts = self.timestamp
@@ -286,15 +285,15 @@ class ContextQuery(object):
with ContextQuery(Automobile, keyspace='test2') as A:
A.objects.create(manufacturer='honda', year=2008, model='civic')
- print len(A.objects.all()) # 1 result
+ print(len(A.objects.all())) # 1 result
with ContextQuery(Automobile, keyspace='test4') as A:
- print len(A.objects.all()) # 0 result
+ print(len(A.objects.all())) # 0 result
# Multiple models
with ContextQuery(Automobile, Automobile2, connection='cluster2') as (A, A2):
- print len(A.objects.all())
- print len(A2.objects.all())
+ print(len(A.objects.all()))
+ print(len(A2.objects.all()))
"""
@@ -407,7 +406,7 @@ def _execute(self, statement):
return result
def __unicode__(self):
- return six.text_type(self._select_query())
+ return str(self._select_query())
def __str__(self):
return str(self.__unicode__())
@@ -604,7 +603,7 @@ def batch(self, batch_obj):
def first(self):
try:
- return six.next(iter(self))
+ return next(iter(self))
except StopIteration:
return None
@@ -809,11 +808,11 @@ class Comment(Model):
print("Normal")
for comment in Comment.objects(photo_id=u):
- print comment.comment_id
+ print(comment.comment_id)
print("Reversed")
for comment in Comment.objects(photo_id=u).order_by("-comment_id"):
- print comment.comment_id
+ print(comment.comment_id)
"""
if len(colnames) == 0:
clone = copy.deepcopy(self)
@@ -901,7 +900,7 @@ def limit(self, v):
if v is None:
v = 0
- if not isinstance(v, six.integer_types):
+ if not isinstance(v, int):
raise TypeError
if v == self._limit:
return self
@@ -925,7 +924,7 @@ def fetch_size(self, v):
print(user)
"""
- if not isinstance(v, six.integer_types):
+ if not isinstance(v, int):
raise TypeError
if v == self._fetch_size:
return self
diff --git a/cassandra/cqlengine/statements.py b/cassandra/cqlengine/statements.py
index c6ceb16607..d92d0b2452 100644
--- a/cassandra/cqlengine/statements.py
+++ b/cassandra/cqlengine/statements.py
@@ -14,8 +14,6 @@
from datetime import datetime, timedelta
import time
-import six
-from six.moves import filter
from cassandra.query import FETCH_SIZE_UNSET
from cassandra.cqlengine import columns
@@ -114,7 +112,7 @@ def __init__(self, field, operator, value, quote_field=True):
def __unicode__(self):
field = ('"{0}"' if self.quote_field else '{0}').format(self.field)
- return u'{0} {1} {2}'.format(field, self.operator, six.text_type(self.query_value))
+ return u'{0} {1} {2}'.format(field, self.operator, str(self.query_value))
def __hash__(self):
return super(WhereClause, self).__hash__() ^ hash(self.operator)
@@ -186,8 +184,7 @@ def __init__(cls, name, bases, dct):
super(ContainerUpdateTypeMapMeta, cls).__init__(name, bases, dct)
-@six.add_metaclass(ContainerUpdateTypeMapMeta)
-class ContainerUpdateClause(AssignmentClause):
+class ContainerUpdateClause(AssignmentClause, metaclass=ContainerUpdateTypeMapMeta):
def __init__(self, field, value, operation=None, previous=None):
super(ContainerUpdateClause, self).__init__(field, value)
@@ -563,7 +560,7 @@ def add_conditional_clause(self, clause):
self.conditionals.append(clause)
def _get_conditionals(self):
- return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.conditionals]))
+ return 'IF {0}'.format(' AND '.join([str(c) for c in self.conditionals]))
def get_context_size(self):
return len(self.get_context())
@@ -584,7 +581,7 @@ def timestamp_normalized(self):
if not self.timestamp:
return None
- if isinstance(self.timestamp, six.integer_types):
+ if isinstance(self.timestamp, int):
return self.timestamp
if isinstance(self.timestamp, timedelta):
@@ -602,7 +599,7 @@ def __repr__(self):
@property
def _where(self):
- return 'WHERE {0}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses]))
+ return 'WHERE {0}'.format(' AND '.join([str(c) for c in self.where_clauses]))
class SelectStatement(BaseCQLStatement):
@@ -629,10 +626,10 @@ def __init__(self,
fetch_size=fetch_size
)
- self.fields = [fields] if isinstance(fields, six.string_types) else (fields or [])
+ self.fields = [fields] if isinstance(fields, str) else (fields or [])
self.distinct_fields = distinct_fields
self.count = count
- self.order_by = [order_by] if isinstance(order_by, six.string_types) else order_by
+ self.order_by = [order_by] if isinstance(order_by, str) else order_by
self.limit = limit
self.allow_filtering = allow_filtering
@@ -653,7 +650,7 @@ def __unicode__(self):
qs += [self._where]
if self.order_by and not self.count:
- qs += ['ORDER BY {0}'.format(', '.join(six.text_type(o) for o in self.order_by))]
+ qs += ['ORDER BY {0}'.format(', '.join(str(o) for o in self.order_by))]
if self.limit:
qs += ['LIMIT {0}'.format(self.limit)]
@@ -798,7 +795,7 @@ def __unicode__(self):
qs += ["USING {0}".format(" AND ".join(using_options))]
qs += ['SET']
- qs += [', '.join([six.text_type(c) for c in self.assignments])]
+ qs += [', '.join([str(c) for c in self.assignments])]
if self.where_clauses:
qs += [self._where]
@@ -849,7 +846,7 @@ def __init__(self, table, fields=None, where=None, timestamp=None, conditionals=
conditionals=conditionals
)
self.fields = []
- if isinstance(fields, six.string_types):
+ if isinstance(fields, str):
fields = [fields]
for field in fields or []:
self.add_field(field)
@@ -874,7 +871,7 @@ def get_context(self):
return ctx
def add_field(self, field):
- if isinstance(field, six.string_types):
+ if isinstance(field, str):
field = FieldDeleteClause(field)
if not isinstance(field, BaseClause):
raise StatementException("only instances of AssignmentClause can be added to statements")
diff --git a/cassandra/cqlengine/usertype.py b/cassandra/cqlengine/usertype.py
index 155068d99e..7fa85f1919 100644
--- a/cassandra/cqlengine/usertype.py
+++ b/cassandra/cqlengine/usertype.py
@@ -13,7 +13,6 @@
# limitations under the License.
import re
-import six
from cassandra.util import OrderedDict
from cassandra.cqlengine import CQLEngineException
@@ -72,7 +71,7 @@ def __ne__(self, other):
return not self.__eq__(other)
def __str__(self):
- return "{{{0}}}".format(', '.join("'{0}': {1}".format(k, getattr(self, k)) for k, v in six.iteritems(self._values)))
+ return "{{{0}}}".format(', '.join("'{0}': {1}".format(k, getattr(self, k)) for k, v in self._values.items()))
def has_changed_fields(self):
return any(v.changed for v in self._values.values())
@@ -93,14 +92,14 @@ def __getattr__(self, attr):
raise AttributeError(attr)
def __getitem__(self, key):
- if not isinstance(key, six.string_types):
+ if not isinstance(key, str):
raise TypeError
if key not in self._fields.keys():
raise KeyError
return getattr(self, key)
def __setitem__(self, key, val):
- if not isinstance(key, six.string_types):
+ if not isinstance(key, str):
raise TypeError
if key not in self._fields.keys():
raise KeyError
@@ -198,8 +197,7 @@ def _transform_column(field_name, field_obj):
return klass
-@six.add_metaclass(UserTypeMetaClass)
-class UserType(BaseUserType):
+class UserType(BaseUserType, metaclass=UserTypeMetaClass):
"""
This class is used to model User Defined Types. To define a type, declare a class inheriting from this,
and assign field types as class attributes:
diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py
index 7946a63af8..b413b1c9e5 100644
--- a/cassandra/cqltypes.py
+++ b/cassandra/cqltypes.py
@@ -39,8 +39,6 @@
import re
import socket
import time
-import six
-from six.moves import range
import struct
import sys
from uuid import UUID
@@ -51,13 +49,10 @@
float_pack, float_unpack, double_pack, double_unpack,
varint_pack, varint_unpack, point_be, point_le,
vints_pack, vints_unpack)
-from cassandra import util
+from cassandra import util, VectorDeserializationFailure
_little_endian_flag = 1 # we always serialize LE
-if six.PY3:
- import ipaddress
-
-_ord = ord if six.PY2 else lambda x: x
+import ipaddress
apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
@@ -66,16 +61,12 @@
log = logging.getLogger(__name__)
-if six.PY3:
- _number_types = frozenset((int, float))
- long = int
+_number_types = frozenset((int, float))
+
- def _name_from_hex_string(encoded_name):
- bin_str = unhexlify(encoded_name)
- return bin_str.decode('ascii')
-else:
- _number_types = frozenset((int, long, float))
- _name_from_hex_string = unhexlify
+def _name_from_hex_string(encoded_name):
+ bin_str = unhexlify(encoded_name)
+ return bin_str.decode('ascii')
def trim_if_startswith(s, prefix):
@@ -235,13 +226,15 @@ def parse_casstype_args(typestring):
else:
names.append(None)
- ctype = lookup_casstype_simple(tok)
+ try:
+ ctype = int(tok)
+ except ValueError:
+ ctype = lookup_casstype_simple(tok)
types.append(ctype)
# return the first (outer) type, which will have all parameters applied
return args[0][0][0]
-
def lookup_casstype(casstype):
"""
Given a Cassandra type as a string (possibly including parameters), hand
@@ -259,6 +252,7 @@ def lookup_casstype(casstype):
try:
return parse_casstype_args(casstype)
except (ValueError, AssertionError, IndexError) as e:
+ log.debug("Exception in parse_casstype_args: %s" % e)
raise ValueError("Don't know how to parse type string %r: %s" % (casstype, e))
@@ -276,8 +270,7 @@ def __str__(self):
EMPTY = EmptyValue()
-@six.add_metaclass(CassandraTypeType)
-class _CassandraType(object):
+class _CassandraType(object, metaclass=CassandraTypeType):
subtypes = ()
num_subtypes = 0
empty_binary_ok = False
@@ -296,7 +289,7 @@ class _CassandraType(object):
"""
def __repr__(self):
- return '<%s( %r )>' % (self.cql_parameterized_type(), self.val)
+ return '<%s>' % (self.cql_parameterized_type())
@classmethod
def from_binary(cls, byts, protocol_version):
@@ -380,8 +373,6 @@ def apply_parameters(cls, subtypes, names=None):
raise ValueError("%s types require %d subtypes (%d given)"
% (cls.typename, cls.num_subtypes, len(subtypes)))
newname = cls.cass_parameterized_type_with(subtypes)
- if six.PY2 and isinstance(newname, unicode):
- newname = newname.encode('utf-8')
return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname, 'fieldnames': names})
@classmethod
@@ -412,16 +403,10 @@ class _UnrecognizedType(_CassandraType):
num_subtypes = 'UNKNOWN'
-if six.PY3:
- def mkUnrecognizedType(casstypename):
- return CassandraTypeType(casstypename,
- (_UnrecognizedType,),
- {'typename': "'%s'" % casstypename})
-else:
- def mkUnrecognizedType(casstypename): # noqa
- return CassandraTypeType(casstypename.encode('utf8'),
- (_UnrecognizedType,),
- {'typename': "'%s'" % casstypename})
+def mkUnrecognizedType(casstypename):
+ return CassandraTypeType(casstypename,
+ (_UnrecognizedType,),
+ {'typename': "'%s'" % casstypename})
class BytesType(_CassandraType):
@@ -430,7 +415,7 @@ class BytesType(_CassandraType):
@staticmethod
def serialize(val, protocol_version):
- return six.binary_type(val)
+ return bytes(val)
class DecimalType(_CassandraType):
@@ -476,6 +461,7 @@ def serialize(uuid, protocol_version):
class BooleanType(_CassandraType):
typename = 'boolean'
+ serial_size = 1
@staticmethod
def deserialize(byts, protocol_version):
@@ -497,29 +483,25 @@ def serialize(byts, protocol_version):
return int8_pack(byts)
-if six.PY2:
- class AsciiType(_CassandraType):
- typename = 'ascii'
- empty_binary_ok = True
-else:
- class AsciiType(_CassandraType):
- typename = 'ascii'
- empty_binary_ok = True
+class AsciiType(_CassandraType):
+ typename = 'ascii'
+ empty_binary_ok = True
- @staticmethod
- def deserialize(byts, protocol_version):
- return byts.decode('ascii')
+ @staticmethod
+ def deserialize(byts, protocol_version):
+ return byts.decode('ascii')
- @staticmethod
- def serialize(var, protocol_version):
- try:
- return var.encode('ascii')
- except UnicodeDecodeError:
- return var
+ @staticmethod
+ def serialize(var, protocol_version):
+ try:
+ return var.encode('ascii')
+ except UnicodeDecodeError:
+ return var
class FloatType(_CassandraType):
typename = 'float'
+ serial_size = 4
@staticmethod
def deserialize(byts, protocol_version):
@@ -532,6 +514,7 @@ def serialize(byts, protocol_version):
class DoubleType(_CassandraType):
typename = 'double'
+ serial_size = 8
@staticmethod
def deserialize(byts, protocol_version):
@@ -544,6 +527,7 @@ def serialize(byts, protocol_version):
class LongType(_CassandraType):
typename = 'bigint'
+ serial_size = 8
@staticmethod
def deserialize(byts, protocol_version):
@@ -556,6 +540,7 @@ def serialize(byts, protocol_version):
class Int32Type(_CassandraType):
typename = 'int'
+ serial_size = 4
@staticmethod
def deserialize(byts, protocol_version):
@@ -600,7 +585,7 @@ def serialize(addr, protocol_version):
# since we've already determined the AF
return socket.inet_aton(addr)
except:
- if six.PY3 and isinstance(addr, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
+ if isinstance(addr, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
return addr.packed
raise ValueError("can't interpret %r as an inet address" % (addr,))
@@ -659,7 +644,7 @@ def serialize(v, protocol_version):
raise TypeError('DateType arguments must be a datetime, date, or timestamp')
timestamp = v
- return int64_pack(long(timestamp))
+ return int64_pack(int(timestamp))
class TimestampType(DateType):
@@ -668,6 +653,7 @@ class TimestampType(DateType):
class TimeUUIDType(DateType):
typename = 'timeuuid'
+ serial_size = 16
def my_timestamp(self):
return util.unix_time_from_uuid1(self.val)
@@ -703,7 +689,7 @@ def serialize(val, protocol_version):
try:
days = val.days_from_epoch
except AttributeError:
- if isinstance(val, six.integer_types):
+ if isinstance(val, int):
# the DB wants offset int values, but util.Date init takes days from epoch
# here we assume int values are offset, as they would appear in CQL
# short circuit to avoid subtracting just to add offset
@@ -714,6 +700,7 @@ def serialize(val, protocol_version):
class ShortType(_CassandraType):
typename = 'smallint'
+ serial_size = 2
@staticmethod
def deserialize(byts, protocol_version):
@@ -726,6 +713,7 @@ def serialize(byts, protocol_version):
class TimeType(_CassandraType):
typename = 'time'
+ serial_size = 8
@staticmethod
def deserialize(byts, protocol_version):
@@ -823,7 +811,7 @@ def deserialize_safe(cls, byts, protocol_version):
@classmethod
def serialize_safe(cls, items, protocol_version):
- if isinstance(items, six.string_types):
+ if isinstance(items, str):
raise TypeError("Received a string for a type that expects a sequence")
subtype, = cls.subtypes
@@ -897,7 +885,7 @@ def serialize_safe(cls, themap, protocol_version):
buf = io.BytesIO()
buf.write(pack(len(themap)))
try:
- items = six.iteritems(themap)
+ items = themap.items()
except AttributeError:
raise TypeError("Got a non-map object for a map value")
inner_proto = max(3, protocol_version)
@@ -972,9 +960,6 @@ class UserType(TupleType):
def make_udt_class(cls, keyspace, udt_name, field_names, field_types):
assert len(field_names) == len(field_types)
- if six.PY2 and isinstance(udt_name, unicode):
- udt_name = udt_name.encode('utf-8')
-
instance = cls._cache.get((keyspace, udt_name))
if not instance or instance.fieldnames != field_names or instance.subtypes != field_types:
instance = type(udt_name, (cls,), {'subtypes': field_types,
@@ -989,8 +974,6 @@ def make_udt_class(cls, keyspace, udt_name, field_names, field_types):
@classmethod
def evict_udt_class(cls, keyspace, udt_name):
- if six.PY2 and isinstance(udt_name, unicode):
- udt_name = udt_name.encode('utf-8')
try:
del cls._cache[(keyspace, udt_name)]
except KeyError:
@@ -1026,7 +1009,9 @@ def serialize_safe(cls, val, protocol_version):
try:
item = val[i]
except TypeError:
- item = getattr(val, fieldname)
+ item = getattr(val, fieldname, None)
+ if item is None and not hasattr(val, fieldname):
+ log.warning(f"field {fieldname} is part of the UDT {cls.typename} but is not present in the value {val}")
if item is not None:
packed_item = subtype.to_binary(item, proto_version)
@@ -1145,7 +1130,7 @@ def serialize_safe(cls, val, protocol_version):
def is_counter_type(t):
- if isinstance(t, six.string_types):
+ if isinstance(t, str):
t = lookup_casstype(t)
return issubclass(t, CounterColumnType)
@@ -1181,7 +1166,7 @@ def serialize(val, protocol_version):
@staticmethod
def deserialize(byts, protocol_version):
- is_little_endian = bool(_ord(byts[0]))
+ is_little_endian = bool(byts[0])
point = point_le if is_little_endian else point_be
return util.Point(*point.unpack_from(byts, 5)) # ofs = endian byte + int type
@@ -1198,7 +1183,7 @@ def serialize(val, protocol_version):
@staticmethod
def deserialize(byts, protocol_version):
- is_little_endian = bool(_ord(byts[0]))
+ is_little_endian = bool(byts[0])
point = point_le if is_little_endian else point_be
coords = ((point.unpack_from(byts, offset) for offset in range(1 + 4 + 4, len(byts), point.size))) # start = endian + int type + int count
return util.LineString(coords)
@@ -1227,7 +1212,7 @@ def serialize(val, protocol_version):
@staticmethod
def deserialize(byts, protocol_version):
- is_little_endian = bool(_ord(byts[0]))
+ is_little_endian = bool(byts[0])
if is_little_endian:
             int_fmt = '<i'
+        return "%s<%s, %s>" % (cls.typename, cls.subtype.typename, cls.vector_size)
diff --git a/cassandra/cython_marshal.pyx b/cassandra/cython_marshal.pyx
index e4f30e6a85..0a926b6eef 100644
--- a/cassandra/cython_marshal.pyx
+++ b/cassandra/cython_marshal.pyx
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import six
-
from libc.stdint cimport (int8_t, int16_t, int32_t, int64_t,
uint8_t, uint16_t, uint32_t, uint64_t)
from libc.string cimport memcpy
@@ -24,8 +22,6 @@ from cassandra.buffer cimport Buffer, buf_read, to_bytes
cdef bint is_little_endian
from cassandra.util import is_little_endian
-cdef bint PY3 = six.PY3
-
ctypedef fused num_t:
int64_t
int32_t
@@ -57,10 +53,7 @@ cdef inline num_t unpack_num(Buffer *buf, num_t *dummy=NULL): # dummy pointer be
cdef varint_unpack(Buffer *term):
"""Unpack a variable-sized integer"""
- if PY3:
- return varint_unpack_py3(to_bytes(term))
- else:
- return varint_unpack_py2(to_bytes(term))
+ return varint_unpack_py3(to_bytes(term))
# TODO: Optimize these two functions
cdef varint_unpack_py3(bytes term):
@@ -70,13 +63,6 @@ cdef varint_unpack_py3(bytes term):
val -= 1 << shift
return val
-cdef varint_unpack_py2(bytes term): # noqa
- val = int(term.encode('hex'), 16)
- if (ord(term[0]) & 128) != 0:
- shift = len(term) * 8 # * Note below
- val = val - (1 << shift)
- return val
-
# * Note *
# '1 << (len(term) * 8)' Cython tries to do native
# integer shifts, which overflows. We need this to
diff --git a/cassandra/datastax/cloud/__init__.py b/cassandra/datastax/cloud/__init__.py
index 46fd822b87..0f042ff1c8 100644
--- a/cassandra/datastax/cloud/__init__.py
+++ b/cassandra/datastax/cloud/__init__.py
@@ -18,12 +18,11 @@
import sys
import tempfile
import shutil
-import six
-from six.moves.urllib.request import urlopen
+from urllib.request import urlopen
_HAS_SSL = True
try:
- from ssl import SSLContext, PROTOCOL_TLSv1, CERT_REQUIRED
+ from ssl import SSLContext, PROTOCOL_TLS, CERT_REQUIRED
except:
_HAS_SSL = False
@@ -41,7 +40,7 @@
__all__ = ['get_cloud_config']
-PRODUCT_APOLLO = "DATASTAX_APOLLO"
+DATASTAX_CLOUD_PRODUCT_TYPE = "DATASTAX_APOLLO"
class CloudConfig(object):
@@ -97,8 +96,9 @@ def get_cloud_config(cloud_config, create_pyopenssl_context=False):
def read_cloud_config_from_zip(cloud_config, create_pyopenssl_context):
secure_bundle = cloud_config['secure_connect_bundle']
+ use_default_tempdir = cloud_config.get('use_default_tempdir', None)
with ZipFile(secure_bundle) as zipfile:
- base_dir = os.path.dirname(secure_bundle)
+ base_dir = tempfile.gettempdir() if use_default_tempdir else os.path.dirname(secure_bundle)
tmp_dir = tempfile.mkdtemp(dir=base_dir)
try:
zipfile.extractall(path=tmp_dir)
@@ -138,7 +138,7 @@ def read_metadata_info(config, cloud_config):
except Exception as e:
log.exception(e)
raise DriverException("Unable to connect to the metadata service at %s. "
- "Check the cluster status in the Constellation cloud console. " % url)
+ "Check the cluster status in the cloud console. " % url)
if response.code != 200:
raise DriverException(("Error while fetching the metadata at: %s. "
@@ -169,7 +169,7 @@ def parse_metadata_info(config, http_data):
def _ssl_context_from_cert(ca_cert_location, cert_location, key_location):
- ssl_context = SSLContext(PROTOCOL_TLSv1)
+ ssl_context = SSLContext(PROTOCOL_TLS)
ssl_context.load_verify_locations(ca_cert_location)
ssl_context.verify_mode = CERT_REQUIRED
ssl_context.load_cert_chain(certfile=cert_location, keyfile=key_location)
@@ -181,11 +181,9 @@ def _pyopenssl_context_from_cert(ca_cert_location, cert_location, key_location):
try:
from OpenSSL import SSL
except ImportError as e:
- six.reraise(
- ImportError,
- ImportError("PyOpenSSL must be installed to connect to Apollo with the Eventlet or Twisted event loops"),
- sys.exc_info()[2]
- )
+ raise ImportError(
+ "PyOpenSSL must be installed to connect to Astra with the Eventlet or Twisted event loops")\
+ .with_traceback(e.__traceback__)
ssl_context = SSL.Context(SSL.TLSv1_METHOD)
ssl_context.set_verify(SSL.VERIFY_PEER, callback=lambda _1, _2, _3, _4, ok: ok)
ssl_context.use_certificate_file(cert_location)
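A short sketch of the new use_default_tempdir option handled above: when set, the secure bundle is extracted under tempfile.gettempdir() instead of next to the bundle file. Paths and credentials are placeholders.

    from cassandra.auth import PlainTextAuthProvider
    from cassandra.cluster import Cluster

    cloud_config = {
        'secure_connect_bundle': '/read-only/share/secure-connect.zip',  # placeholder
        'use_default_tempdir': True,   # extract into the system temp dir
    }
    cluster = Cluster(cloud=cloud_config,
                      auth_provider=PlainTextAuthProvider('client_id', 'client_secret'))
    session = cluster.connect()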
diff --git a/cassandra/datastax/graph/__init__.py b/cassandra/datastax/graph/__init__.py
index 0c03c9249d..11785c84f6 100644
--- a/cassandra/datastax/graph/__init__.py
+++ b/cassandra/datastax/graph/__init__.py
@@ -13,10 +13,11 @@
# limitations under the License.
-from cassandra.datastax.graph.types import Element, Vertex, VertexProperty, Edge, Path
+from cassandra.datastax.graph.types import Element, Vertex, VertexProperty, Edge, Path, T
from cassandra.datastax.graph.query import (
GraphOptions, GraphProtocol, GraphStatement, SimpleGraphStatement, Result,
graph_object_row_factory, single_object_row_factory,
- graph_result_row_factory, graph_graphson2_row_factory
+ graph_result_row_factory, graph_graphson2_row_factory,
+ graph_graphson3_row_factory
)
from cassandra.datastax.graph.graphson import *
diff --git a/cassandra/datastax/graph/fluent/__init__.py b/cassandra/datastax/graph/fluent/__init__.py
index 5365a59a06..92f148721e 100644
--- a/cassandra/datastax/graph/fluent/__init__.py
+++ b/cassandra/datastax/graph/fluent/__init__.py
@@ -33,29 +33,29 @@
from cassandra.cluster import Session, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT
from cassandra.datastax.graph import GraphOptions, GraphProtocol
+ from cassandra.datastax.graph.query import _GraphSONContextRowFactory
from cassandra.datastax.graph.fluent.serializers import (
- GremlinGraphSONReader,
- deserializers,
- gremlin_deserializers
+ GremlinGraphSONReaderV2,
+ GremlinGraphSONReaderV3,
+ dse_graphson2_deserializers,
+ gremlin_graphson2_deserializers,
+ dse_graphson3_deserializers,
+ gremlin_graphson3_deserializers
)
from cassandra.datastax.graph.fluent.query import _DefaultTraversalBatch, _query_from_traversal
log = logging.getLogger(__name__)
- __all__ = ['BaseGraphRowFactory', 'dse_graphson_reader', 'graphson_reader', 'graph_traversal_row_factory',
+ __all__ = ['BaseGraphRowFactory', 'graph_traversal_row_factory',
'graph_traversal_dse_object_row_factory', 'DSESessionRemoteGraphConnection', 'DseGraph']
- # Create our custom GraphSONReader/Writer
- dse_graphson_reader = GremlinGraphSONReader(deserializer_map=deserializers)
- graphson_reader = GremlinGraphSONReader(deserializer_map=gremlin_deserializers)
-
# Traversal result keys
_bulk_key = 'bulk'
_result_key = 'result'
- class BaseGraphRowFactory(object):
+ class BaseGraphRowFactory(_GraphSONContextRowFactory):
"""
Base row factory for graph traversal. This class basically wraps a
graphson reader function to handle additional features of Gremlin/DSE
@@ -63,37 +63,51 @@ class BaseGraphRowFactory(object):
Currently supported:
- bulk results
+ """
- :param graphson_reader: The function used to read the graphson.
+ def __call__(self, column_names, rows):
+ for row in rows:
+ parsed_row = self.graphson_reader.readObject(row[0])
+ yield parsed_row[_result_key]
+ bulk = parsed_row.get(_bulk_key, 1)
+ for _ in range(bulk - 1):
+ yield copy.deepcopy(parsed_row[_result_key])
- Use example::
- my_custom_row_factory = BaseGraphRowFactory(custom_graphson_reader.readObject)
- """
+ class _GremlinGraphSON2RowFactory(BaseGraphRowFactory):
+ """Row Factory that returns the decoded graphson2."""
+ graphson_reader_class = GremlinGraphSONReaderV2
+ graphson_reader_kwargs = {'deserializer_map': gremlin_graphson2_deserializers}
- def __init__(self, graphson_reader):
- self._graphson_reader = graphson_reader
- def __call__(self, column_names, rows):
- results = []
+ class _DseGraphSON2RowFactory(BaseGraphRowFactory):
+ """Row Factory that returns the decoded graphson2 as DSE types."""
+ graphson_reader_class = GremlinGraphSONReaderV2
+ graphson_reader_kwargs = {'deserializer_map': dse_graphson2_deserializers}
- for row in rows:
- parsed_row = self._graphson_reader(row[0])
- bulk = parsed_row.get(_bulk_key, 1)
- if bulk > 1: # Avoid deepcopy call if bulk <= 1
- results.extend([copy.deepcopy(parsed_row[_result_key])
- for _ in range(bulk - 1)])
+ gremlin_graphson2_traversal_row_factory = _GremlinGraphSON2RowFactory
+ # TODO remove in next major
+ graph_traversal_row_factory = gremlin_graphson2_traversal_row_factory
- results.append(parsed_row[_result_key])
+ dse_graphson2_traversal_row_factory = _DseGraphSON2RowFactory
+ # TODO remove in next major
+ graph_traversal_dse_object_row_factory = dse_graphson2_traversal_row_factory
- return results
+ class _GremlinGraphSON3RowFactory(BaseGraphRowFactory):
+ """Row Factory that returns the decoded graphson2."""
+ graphson_reader_class = GremlinGraphSONReaderV3
+ graphson_reader_kwargs = {'deserializer_map': gremlin_graphson3_deserializers}
- graph_traversal_row_factory = BaseGraphRowFactory(graphson_reader.readObject)
- graph_traversal_row_factory.__doc__ = "Row Factory that returns the decoded graphson."
- graph_traversal_dse_object_row_factory = BaseGraphRowFactory(dse_graphson_reader.readObject)
- graph_traversal_dse_object_row_factory.__doc__ = "Row Factory that returns the decoded graphson as DSE types."
+ class _DseGraphSON3RowFactory(BaseGraphRowFactory):
+ """Row Factory that returns the decoded graphson3 as DSE types."""
+ graphson_reader_class = GremlinGraphSONReaderV3
+ graphson_reader_kwargs = {'deserializer_map': dse_graphson3_deserializers}
+
+
+ gremlin_graphson3_traversal_row_factory = _GremlinGraphSON3RowFactory
+ dse_graphson3_traversal_row_factory = _DseGraphSON3RowFactory
class DSESessionRemoteGraphConnection(RemoteConnection):
@@ -119,24 +133,41 @@ def __init__(self, session, graph_name=None, execution_profile=EXEC_PROFILE_GRAP
self.graph_name = graph_name
self.execution_profile = execution_profile
+ @staticmethod
+ def _traversers_generator(traversers):
+ for t in traversers:
+ yield Traverser(t)
+
def _prepare_query(self, bytecode):
- query = DseGraph.query_from_traversal(bytecode)
- ep = self.session.execution_profile_clone_update(self.execution_profile,
- row_factory=graph_traversal_row_factory)
- graph_options = ep.graph_options.copy()
+ ep = self.session.execution_profile_clone_update(self.execution_profile)
+ graph_options = ep.graph_options
+ graph_options.graph_name = self.graph_name or graph_options.graph_name
graph_options.graph_language = DseGraph.DSE_GRAPH_QUERY_LANGUAGE
- if self.graph_name:
- graph_options.graph_name = self.graph_name
- ep.graph_options = graph_options
+        # We resolve the execution profile options here to know which gremlin row factory to set
+ self.session._resolve_execution_profile_options(ep)
+
+ context = None
+ if graph_options.graph_protocol == GraphProtocol.GRAPHSON_2_0:
+ row_factory = gremlin_graphson2_traversal_row_factory
+ elif graph_options.graph_protocol == GraphProtocol.GRAPHSON_3_0:
+ row_factory = gremlin_graphson3_traversal_row_factory
+ context = {
+ 'cluster': self.session.cluster,
+ 'graph_name': graph_options.graph_name.decode('utf-8')
+ }
+ else:
+ raise ValueError('Unknown graph protocol: {}'.format(graph_options.graph_protocol))
+
+ ep.row_factory = row_factory
+ query = DseGraph.query_from_traversal(bytecode, graph_options.graph_protocol, context)
return query, ep
@staticmethod
def _handle_query_results(result_set, gremlin_future):
try:
- traversers = [Traverser(t) for t in result_set]
gremlin_future.set_result(
- RemoteTraversal(iter(traversers), TraversalSideEffects())
+ RemoteTraversal(DSESessionRemoteGraphConnection._traversers_generator(result_set), TraversalSideEffects())
)
except Exception as e:
gremlin_future.set_exception(e)
@@ -151,8 +182,7 @@ def submit(self, bytecode):
query, ep = self._prepare_query(bytecode)
traversers = self.session.execute_graph(query, execution_profile=ep)
- traversers = [Traverser(t) for t in traversers]
- return RemoteTraversal(iter(traversers), TraversalSideEffects())
+ return RemoteTraversal(self._traversers_generator(traversers), TraversalSideEffects())
def submitAsync(self, bytecode):
query, ep = self._prepare_query(bytecode)
@@ -181,12 +211,20 @@ class DseGraph(object):
Graph query language, Default is 'bytecode-json' (GraphSON).
"""
+ DSE_GRAPH_QUERY_PROTOCOL = GraphProtocol.GRAPHSON_2_0
+ """
+    Graph query protocol. Default is GraphProtocol.GRAPHSON_2_0.
+ """
+
@staticmethod
- def query_from_traversal(traversal):
+ def query_from_traversal(traversal, graph_protocol=DSE_GRAPH_QUERY_PROTOCOL, context=None):
"""
From a GraphTraversal, return a query string based on the language specified in `DseGraph.DSE_GRAPH_QUERY_LANGUAGE`.
:param traversal: The GraphTraversal object
+ :param graph_protocol: The graph protocol. Default is `DseGraph.DSE_GRAPH_QUERY_PROTOCOL`.
+ :param context: The dict of the serialization context, needed for GraphSON3 (tuple, udt).
+ e.g: {'cluster': cluster, 'graph_name': name}
"""
if isinstance(traversal, GraphTraversal):
@@ -197,7 +235,7 @@ def query_from_traversal(traversal):
log.warning("GraphTraversal session, graph_name and execution_profile are "
"only taken into account when executed with TinkerPop.")
- return _query_from_traversal(traversal)
+ return _query_from_traversal(traversal, graph_protocol, context)
@staticmethod
def traversal_source(session=None, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT,
@@ -219,7 +257,7 @@ def traversal_source(session=None, graph_name=None, execution_profile=EXEC_PROFI
session = c.connect()
g = DseGraph.traversal_source(session, 'my_graph')
- print g.V().valueMap().toList()
+ print(g.V().valueMap().toList())
"""
@@ -233,18 +271,27 @@ def traversal_source(session=None, graph_name=None, execution_profile=EXEC_PROFI
return traversal_source
@staticmethod
- def create_execution_profile(graph_name):
+ def create_execution_profile(graph_name, graph_protocol=DSE_GRAPH_QUERY_PROTOCOL, **kwargs):
"""
Creates an ExecutionProfile for GraphTraversal execution. You need to register that execution profile to the
cluster by using `cluster.add_execution_profile`.
:param graph_name: The graph name
+ :param graph_protocol: (Optional) The graph protocol, default is `DSE_GRAPH_QUERY_PROTOCOL`.
"""
- ep = GraphExecutionProfile(row_factory=graph_traversal_dse_object_row_factory,
+ if graph_protocol == GraphProtocol.GRAPHSON_2_0:
+ row_factory = dse_graphson2_traversal_row_factory
+ elif graph_protocol == GraphProtocol.GRAPHSON_3_0:
+ row_factory = dse_graphson3_traversal_row_factory
+ else:
+ raise ValueError('Unknown graph protocol: {}'.format(graph_protocol))
+
+ ep = GraphExecutionProfile(row_factory=row_factory,
graph_options=GraphOptions(graph_name=graph_name,
graph_language=DseGraph.DSE_GRAPH_QUERY_LANGUAGE,
- graph_protocol=GraphProtocol.GRAPHSON_2_0))
+ graph_protocol=graph_protocol),
+ **kwargs)
return ep
@staticmethod
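A sketch of the new graph_protocol argument in use: registering a GraphSON 3.0 execution profile and running a fluent traversal through it. Contact points and the graph name are placeholders.

    from cassandra.cluster import Cluster
    from cassandra.datastax.graph import GraphProtocol
    from cassandra.datastax.graph.fluent import DseGraph

    ep = DseGraph.create_execution_profile('my_graph',
                                           graph_protocol=GraphProtocol.GRAPHSON_3_0)
    cluster = Cluster(['127.0.0.1'], execution_profiles={'core_graph': ep})
    session = cluster.connect()

    g = DseGraph.traversal_source(session, execution_profile='core_graph')
    print(g.V().limit(5).toList())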
diff --git a/cassandra/datastax/graph/fluent/_predicates.py b/cassandra/datastax/graph/fluent/_predicates.py
index b63dd90043..95bd533d5e 100644
--- a/cassandra/datastax/graph/fluent/_predicates.py
+++ b/cassandra/datastax/graph/fluent/_predicates.py
@@ -18,7 +18,7 @@
from cassandra.util import Distance
-__all__ = ['GeoP', 'TextDistanceP', 'Search', 'GeoUnit', 'Geo']
+__all__ = ['GeoP', 'TextDistanceP', 'Search', 'GeoUnit', 'Geo', 'CqlCollection']
class GeoP(object):
@@ -138,6 +138,41 @@ def phrase(value, proximity):
return TextDistanceP.phrase(value, proximity)
+class CqlCollection(object):
+
+ @staticmethod
+ def contains(value):
+ """
+ Search for a value inside a cql list/set column.
+ :param value: the value to look for.
+ """
+ return P('contains', value)
+
+ @staticmethod
+ def contains_value(value):
+ """
+ Search for a map value.
+ :param value: the value to look for.
+ """
+ return P('containsValue', value)
+
+ @staticmethod
+ def contains_key(value):
+ """
+ Search for a map key.
+ :param value: the value to look for.
+ """
+ return P('containsKey', value)
+
+ @staticmethod
+ def entry_eq(value):
+ """
+ Search for a map entry.
+ :param value: the value to look for.
+ """
+ return P('entryEq', value)
+
+
class GeoUnit(object):
_EARTH_MEAN_RADIUS_KM = 6371.0087714
_DEGREES_TO_RADIANS = math.pi / 180
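A hedged usage sketch for the new CqlCollection predicates, assuming g is a fluent traversal source as above; labels and property names are hypothetical.

    from cassandra.datastax.graph.fluent.predicates import CqlCollection

    # list/set property containing a value
    g.V().hasLabel('person').has('nicknames', CqlCollection.contains('alice')).toList()

    # map property containing a given key or value
    g.V().has('attributes', CqlCollection.contains_key('height')).toList()
    g.V().has('attributes', CqlCollection.contains_value('tall')).toList()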
diff --git a/cassandra/datastax/graph/fluent/_query.py b/cassandra/datastax/graph/fluent/_query.py
index b5d24df05b..d5eb7f6373 100644
--- a/cassandra/datastax/graph/fluent/_query.py
+++ b/cassandra/datastax/graph/fluent/_query.py
@@ -14,27 +14,101 @@
import logging
-from cassandra.graph import SimpleGraphStatement
+from cassandra.graph import SimpleGraphStatement, GraphProtocol
from cassandra.cluster import EXEC_PROFILE_GRAPH_DEFAULT
from gremlin_python.process.graph_traversal import GraphTraversal
-from gremlin_python.structure.io.graphsonV2d0 import GraphSONWriter
+from gremlin_python.structure.io.graphsonV2d0 import GraphSONWriter as GraphSONWriterV2
+from gremlin_python.structure.io.graphsonV3d0 import GraphSONWriter as GraphSONWriterV3
-from cassandra.datastax.graph.fluent.serializers import serializers
+from cassandra.datastax.graph.fluent.serializers import GremlinUserTypeIO, \
+ dse_graphson2_serializers, dse_graphson3_serializers
log = logging.getLogger(__name__)
-graphson_writer = GraphSONWriter(serializer_map=serializers)
__all__ = ['TraversalBatch', '_query_from_traversal', '_DefaultTraversalBatch']
-def _query_from_traversal(traversal):
+class _GremlinGraphSONWriterAdapter(object):
+
+ def __init__(self, context, **kwargs):
+ super(_GremlinGraphSONWriterAdapter, self).__init__(**kwargs)
+ self.context = context
+ self.user_types = None
+
+ def serialize(self, value, _):
+ return self.toDict(value)
+
+ def get_serializer(self, value):
+ serializer = None
+ try:
+ serializer = self.serializers[type(value)]
+ except KeyError:
+ for key, ser in self.serializers.items():
+ if isinstance(value, key):
+ serializer = ser
+
+ if self.context:
+ # Check if UDT
+ if self.user_types is None:
+ try:
+ user_types = self.context['cluster']._user_types[self.context['graph_name']]
+ self.user_types = dict(map(reversed, user_types.items()))
+ except KeyError:
+ self.user_types = {}
+
+ # Custom detection to map a namedtuple to udt
+ if (tuple in self.serializers and serializer is self.serializers[tuple] and hasattr(value, '_fields') or
+ (not serializer and type(value) in self.user_types)):
+ serializer = GremlinUserTypeIO
+
+ if serializer:
+ try:
+ # A serializer can have specialized serializers (e.g for Int32 and Int64, so value dependant)
+ serializer = serializer.get_specialized_serializer(value)
+ except AttributeError:
+ pass
+
+ return serializer
+
+ def toDict(self, obj):
+ serializer = self.get_serializer(obj)
+ return serializer.dictify(obj, self) if serializer else obj
+
+ def definition(self, value):
+ serializer = self.get_serializer(value)
+ return serializer.definition(value, self)
+
+
+class GremlinGraphSON2Writer(_GremlinGraphSONWriterAdapter, GraphSONWriterV2):
+ pass
+
+
+class GremlinGraphSON3Writer(_GremlinGraphSONWriterAdapter, GraphSONWriterV3):
+ pass
+
+
+graphson2_writer = GremlinGraphSON2Writer
+graphson3_writer = GremlinGraphSON3Writer
+
+
+def _query_from_traversal(traversal, graph_protocol, context=None):
"""
From a GraphTraversal, return a query string.
:param traversal: The GraphTraversal object
+    :param graph_protocol: The graph protocol that determines the output format.
"""
+ if graph_protocol == GraphProtocol.GRAPHSON_2_0:
+ graphson_writer = graphson2_writer(context, serializer_map=dse_graphson2_serializers)
+ elif graph_protocol == GraphProtocol.GRAPHSON_3_0:
+ if context is None:
+            raise ValueError('GraphSON3 serialization requires a context.')
+ graphson_writer = graphson3_writer(context, serializer_map=dse_graphson3_serializers)
+ else:
+ raise ValueError('Unknown graph protocol: {}'.format(graph_protocol))
+
try:
query = graphson_writer.writeObject(traversal)
except Exception:
@@ -87,9 +161,11 @@ def execute(self):
"""
raise NotImplementedError()
- def as_graph_statement(self):
+ def as_graph_statement(self, graph_protocol=GraphProtocol.GRAPHSON_2_0):
"""
Return the traversal batch as GraphStatement.
+
+ :param graph_protocol: The graph protocol for the GraphSONWriter. Default is GraphProtocol.GRAPHSON_2_0.
"""
raise NotImplementedError()
@@ -115,32 +191,35 @@ def __init__(self, *args, **kwargs):
super(_DefaultTraversalBatch, self).__init__(*args, **kwargs)
self._traversals = []
- @property
- def _query(self):
- return u"[{0}]".format(','.join(self._traversals))
-
def add(self, traversal):
if not isinstance(traversal, GraphTraversal):
raise ValueError('traversal should be a gremlin GraphTraversal')
- query = _query_from_traversal(traversal)
- self._traversals.append(query)
-
+ self._traversals.append(traversal)
return self
def add_all(self, traversals):
for traversal in traversals:
self.add(traversal)
- def as_graph_statement(self):
- return SimpleGraphStatement(self._query)
+ def as_graph_statement(self, graph_protocol=GraphProtocol.GRAPHSON_2_0, context=None):
+ statements = [_query_from_traversal(t, graph_protocol, context) for t in self._traversals]
+ query = u"[{0}]".format(','.join(statements))
+ return SimpleGraphStatement(query)
def execute(self):
if self._session is None:
raise ValueError('A DSE Session must be provided to execute the traversal batch.')
execution_profile = self._execution_profile if self._execution_profile else EXEC_PROFILE_GRAPH_DEFAULT
- return self._session.execute_graph(self._query, execution_profile=execution_profile)
+ graph_options = self._session.get_execution_profile(execution_profile).graph_options
+ context = {
+ 'cluster': self._session.cluster,
+ 'graph_name': graph_options.graph_name
+ }
+ statement = self.as_graph_statement(graph_options.graph_protocol, context=context) \
+ if graph_options.graph_protocol else self.as_graph_statement(context=context)
+ return self._session.execute_graph(statement, execution_profile=execution_profile)
def clear(self):
del self._traversals[:]
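(Illustrative, not part of the patch: the new protocol-aware entry point can be exercised directly; `g` is assumed to be a fluent GraphTraversalSource. GraphSON3 would additionally need the {'cluster': ..., 'graph_name': ...} context built in execute() above.)

    # Serialize one traversal to a GraphSON2 query string and wrap it the same
    # way as_graph_statement() does for a single-element batch.
    query = _query_from_traversal(g.V().has('name', 'alice'), GraphProtocol.GRAPHSON_2_0)
    statement = SimpleGraphStatement(u"[{0}]".format(query))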
diff --git a/cassandra/datastax/graph/fluent/_serializers.py b/cassandra/datastax/graph/fluent/_serializers.py
index 56591603af..83b3afb22d 100644
--- a/cassandra/datastax/graph/fluent/_serializers.py
+++ b/cassandra/datastax/graph/fluent/_serializers.py
@@ -14,36 +14,90 @@
from collections import OrderedDict
-import six
-
from gremlin_python.structure.io.graphsonV2d0 import (
- GraphSONReader,
- GraphSONUtil,
- VertexDeserializer,
- VertexPropertyDeserializer,
- PropertyDeserializer,
- EdgeDeserializer,
- PathDeserializer
+ GraphSONReader as GraphSONReaderV2,
+ GraphSONUtil as GraphSONUtil, # no difference between v2 and v3
+ VertexDeserializer as VertexDeserializerV2,
+ VertexPropertyDeserializer as VertexPropertyDeserializerV2,
+ PropertyDeserializer as PropertyDeserializerV2,
+ EdgeDeserializer as EdgeDeserializerV2,
+ PathDeserializer as PathDeserializerV2
)
-from cassandra.datastax.graph.graphson import (
- GraphSON2Serializer,
- GraphSON2Deserializer
+from gremlin_python.structure.io.graphsonV3d0 import (
+ GraphSONReader as GraphSONReaderV3,
+ VertexDeserializer as VertexDeserializerV3,
+ VertexPropertyDeserializer as VertexPropertyDeserializerV3,
+ PropertyDeserializer as PropertyDeserializerV3,
+ EdgeDeserializer as EdgeDeserializerV3,
+ PathDeserializer as PathDeserializerV3
)
+try:
+ from gremlin_python.structure.io.graphsonV2d0 import (
+ TraversalMetricsDeserializer as TraversalMetricsDeserializerV2,
+ MetricsDeserializer as MetricsDeserializerV2
+ )
+ from gremlin_python.structure.io.graphsonV3d0 import (
+ TraversalMetricsDeserializer as TraversalMetricsDeserializerV3,
+ MetricsDeserializer as MetricsDeserializerV3
+ )
+except ImportError:
+ TraversalMetricsDeserializerV2 = MetricsDeserializerV2 = None
+ TraversalMetricsDeserializerV3 = MetricsDeserializerV3 = None
+
+from cassandra.graph import (
+ GraphSON2Serializer,
+ GraphSON2Deserializer,
+ GraphSON3Serializer,
+ GraphSON3Deserializer
+)
+from cassandra.graph.graphson import UserTypeIO, TypeWrapperTypeIO
from cassandra.datastax.graph.fluent.predicates import GeoP, TextDistanceP
from cassandra.util import Distance
__all__ = ['GremlinGraphSONReader', 'GeoPSerializer', 'TextDistancePSerializer',
- 'DistanceIO', 'gremlin_deserializers', 'deserializers', 'serializers']
+ 'DistanceIO', 'gremlin_deserializers', 'deserializers', 'serializers',
+ 'GremlinGraphSONReaderV2', 'GremlinGraphSONReaderV3', 'dse_graphson2_serializers',
+ 'dse_graphson2_deserializers', 'dse_graphson3_serializers', 'dse_graphson3_deserializers',
+ 'gremlin_graphson2_deserializers', 'gremlin_graphson3_deserializers', 'GremlinUserTypeIO']
class _GremlinGraphSONTypeSerializer(object):
+ TYPE_KEY = "@type"
+ VALUE_KEY = "@value"
+ serializer = None
- @classmethod
- def dictify(cls, v, _):
- return GraphSON2Serializer.serialize(v)
+ def __init__(self, serializer):
+ self.serializer = serializer
+
+ def dictify(self, v, writer):
+ value = self.serializer.serialize(v, writer)
+ if self.serializer is TypeWrapperTypeIO:
+ graphson_base_type = v.type_io.graphson_base_type
+ graphson_type = v.type_io.graphson_type
+ else:
+ graphson_base_type = self.serializer.graphson_base_type
+ graphson_type = self.serializer.graphson_type
+
+ if graphson_base_type is None:
+ out = value
+ else:
+ out = {self.TYPE_KEY: graphson_type}
+ if value is not None:
+ out[self.VALUE_KEY] = value
+
+ return out
+
+ def definition(self, value, writer=None):
+ return self.serializer.definition(value, writer)
+
+ def get_specialized_serializer(self, value):
+ ser = self.serializer.get_specialized_serializer(value)
+ if ser is not self.serializer:
+ return _GremlinGraphSONTypeSerializer(ser)
+ return self
class _GremlinGraphSONTypeDeserializer(object):
@@ -54,22 +108,44 @@ def __init__(self, deserializer):
self.deserializer = deserializer
def objectify(self, v, reader):
- return self.deserializer.deserialize(v, reader=reader)
+ return self.deserializer.deserialize(v, reader)
-def _make_gremlin_deserializer(graphson_type):
+def _make_gremlin_graphson2_deserializer(graphson_type):
return _GremlinGraphSONTypeDeserializer(
GraphSON2Deserializer.get_deserializer(graphson_type.graphson_type)
)
-class GremlinGraphSONReader(GraphSONReader):
+def _make_gremlin_graphson3_deserializer(graphson_type):
+ return _GremlinGraphSONTypeDeserializer(
+ GraphSON3Deserializer.get_deserializer(graphson_type.graphson_type)
+ )
+
+
+class _GremlinGraphSONReader(object):
"""Gremlin GraphSONReader Adapter, required to use gremlin types"""
+ context = None
+
+ def __init__(self, context, deserializer_map=None):
+ self.context = context
+ super(_GremlinGraphSONReader, self).__init__(deserializer_map)
+
def deserialize(self, obj):
return self.toObject(obj)
+class GremlinGraphSONReaderV2(_GremlinGraphSONReader, GraphSONReaderV2):
+ pass
+
+# TODO remove next major
+GremlinGraphSONReader = GremlinGraphSONReaderV2
+
+class GremlinGraphSONReaderV3(_GremlinGraphSONReader, GraphSONReaderV3):
+ pass
+
+
class GeoPSerializer(object):
@classmethod
def dictify(cls, p, writer):
@@ -97,35 +173,88 @@ def dictify(cls, p, writer):
class DistanceIO(object):
@classmethod
def dictify(cls, v, _):
- return GraphSONUtil.typedValue('Distance', six.text_type(v), prefix='dse')
+ return GraphSONUtil.typedValue('Distance', str(v), prefix='dse')
+
+GremlinUserTypeIO = _GremlinGraphSONTypeSerializer(UserTypeIO)
-serializers = OrderedDict([
- (t, _GremlinGraphSONTypeSerializer)
- for t in six.iterkeys(GraphSON2Serializer.get_type_definitions())
+# GraphSON2
+dse_graphson2_serializers = OrderedDict([
+ (t, _GremlinGraphSONTypeSerializer(s))
+ for t, s in GraphSON2Serializer.get_type_definitions().items()
])
-# Predicates
-serializers.update(OrderedDict([
+dse_graphson2_serializers.update(OrderedDict([
(Distance, DistanceIO),
(GeoP, GeoPSerializer),
(TextDistanceP, TextDistancePSerializer)
]))
-deserializers = {
- k: _make_gremlin_deserializer(v)
- for k, v in six.iteritems(GraphSON2Deserializer.get_type_definitions())
+# TODO remove next major, this is just in case someone was using it
+serializers = dse_graphson2_serializers
+
+dse_graphson2_deserializers = {
+ k: _make_gremlin_graphson2_deserializer(v)
+ for k, v in GraphSON2Deserializer.get_type_definitions().items()
}
-deserializers.update({
+dse_graphson2_deserializers.update({
"dse:Distance": DistanceIO,
})
-gremlin_deserializers = deserializers.copy()
-gremlin_deserializers.update({
- 'g:Vertex': VertexDeserializer,
- 'g:VertexProperty': VertexPropertyDeserializer,
- 'g:Edge': EdgeDeserializer,
- 'g:Property': PropertyDeserializer,
- 'g:Path': PathDeserializer
+# TODO remove next major, this is just in case someone was using it
+deserializers = dse_graphson2_deserializers
+
+gremlin_graphson2_deserializers = dse_graphson2_deserializers.copy()
+gremlin_graphson2_deserializers.update({
+ 'g:Vertex': VertexDeserializerV2,
+ 'g:VertexProperty': VertexPropertyDeserializerV2,
+ 'g:Edge': EdgeDeserializerV2,
+ 'g:Property': PropertyDeserializerV2,
+ 'g:Path': PathDeserializerV2
})
+
+if TraversalMetricsDeserializerV2:
+ gremlin_graphson2_deserializers.update({
+ 'g:TraversalMetrics': TraversalMetricsDeserializerV2,
+ 'g:Metrics': MetricsDeserializerV2
+ })
+
+# TODO remove next major, this is just in case someone was using it
+gremlin_deserializers = gremlin_graphson2_deserializers
+
+# GraphSON3
+dse_graphson3_serializers = OrderedDict([
+ (t, _GremlinGraphSONTypeSerializer(s))
+ for t, s in GraphSON3Serializer.get_type_definitions().items()
+])
+
+dse_graphson3_serializers.update(OrderedDict([
+ (Distance, DistanceIO),
+ (GeoP, GeoPSerializer),
+ (TextDistanceP, TextDistancePSerializer)
+]))
+
+dse_graphson3_deserializers = {
+ k: _make_gremlin_graphson3_deserializer(v)
+ for k, v in GraphSON3Deserializer.get_type_definitions().items()
+}
+
+dse_graphson3_deserializers.update({
+ "dse:Distance": DistanceIO
+})
+
+gremlin_graphson3_deserializers = dse_graphson3_deserializers.copy()
+gremlin_graphson3_deserializers.update({
+ 'g:Vertex': VertexDeserializerV3,
+ 'g:VertexProperty': VertexPropertyDeserializerV3,
+ 'g:Edge': EdgeDeserializerV3,
+ 'g:Property': PropertyDeserializerV3,
+ 'g:Path': PathDeserializerV3
+})
+
+if TraversalMetricsDeserializerV3:
+ gremlin_graphson3_deserializers.update({
+ 'g:TraversalMetrics': TraversalMetricsDeserializerV3,
+ 'g:Metrics': MetricsDeserializerV3
+ })
diff --git a/cassandra/datastax/graph/graphson.py b/cassandra/datastax/graph/graphson.py
index 620adf045e..335c7f7825 100644
--- a/cassandra/datastax/graph/graphson.py
+++ b/cassandra/datastax/graph/graphson.py
@@ -19,52 +19,68 @@
import json
from decimal import Decimal
from collections import OrderedDict
+import logging
+import itertools
+from functools import partial
-import six
+import ipaddress
-if six.PY3:
- import ipaddress
-from cassandra.util import Polygon, Point, LineString
-from cassandra.datastax.graph.types import Vertex, VertexProperty, Edge, Path
+from cassandra.cqltypes import cql_types_from_string
+from cassandra.metadata import UserType
+from cassandra.util import Polygon, Point, LineString, Duration
+from cassandra.datastax.graph.types import Vertex, VertexProperty, Edge, Path, T
__all__ = ['GraphSON1Serializer', 'GraphSON1Deserializer', 'GraphSON1TypeDeserializer',
- 'GraphSON2Serializer', 'GraphSON2Deserializer',
- 'GraphSON2Reader', 'BooleanTypeIO', 'Int16TypeIO', 'Int32TypeIO', 'DoubleTypeIO',
+ 'GraphSON2Serializer', 'GraphSON2Deserializer', 'GraphSON2Reader',
+ 'GraphSON3Serializer', 'GraphSON3Deserializer', 'GraphSON3Reader',
+ 'to_bigint', 'to_int', 'to_double', 'to_float', 'to_smallint',
+ 'BooleanTypeIO', 'Int16TypeIO', 'Int32TypeIO', 'DoubleTypeIO',
'FloatTypeIO', 'UUIDTypeIO', 'BigDecimalTypeIO', 'DurationTypeIO', 'InetTypeIO',
'InstantTypeIO', 'LocalDateTypeIO', 'LocalTimeTypeIO', 'Int64TypeIO', 'BigIntegerTypeIO',
- 'LocalDateTypeIO', 'PolygonTypeIO', 'PointTypeIO', 'LineStringTypeIO', 'BlobTypeIO']
+ 'LocalDateTypeIO', 'PolygonTypeIO', 'PointTypeIO', 'LineStringTypeIO', 'BlobTypeIO',
+ 'GraphSON3Serializer', 'GraphSON3Deserializer', 'UserTypeIO', 'TypeWrapperTypeIO']
"""
Supported types:
-DSE Graph GraphSON 2.0 Python Driver
------------- | -------------- | ------------
-text | ------ | str
-boolean | g:Boolean | bool
-bigint | g:Int64 | long
-int | g:Int32 | int
-double | g:Double | float
-float | g:Float | float
-uuid | g:UUID | UUID
-bigdecimal | gx:BigDecimal | Decimal
-duration | gx:Duration | timedelta
-inet | gx:InetAddress | str (unicode), IPV4Address/IPV6Address (PY3)
-timestamp | gx:Instant | datetime.datetime
-date | gx:LocalDate | datetime.date
-time | gx:LocalTime | datetime.time
-smallint | gx:Int16 | int
-varint | gx:BigInteger | long
-date | gx:LocalDate | Date
-polygon | dse:Polygon | Polygon
-point | dse:Point | Point
-linestring | dse:LineString | LineString
-blob | dse:Blob | bytearray, buffer (PY2), memoryview (PY3), bytes (PY3)
+DSE Graph    | GraphSON 2.0   | GraphSON 3.0   | Python Driver
+------------ | -------------- | -------------- | ------------
+text | string | string | str
+boolean | | | bool
+bigint | g:Int64 | g:Int64 | long
+int | g:Int32 | g:Int32 | int
+double | g:Double | g:Double | float
+float | g:Float | g:Float | float
+uuid | g:UUID | g:UUID | UUID
+bigdecimal | gx:BigDecimal | gx:BigDecimal | Decimal
+duration | gx:Duration | N/A | timedelta (Classic graph only)
+DSE Duration | N/A | dse:Duration | Duration (Core graph only)
+inet | gx:InetAddress | gx:InetAddress | str (unicode), IPV4Address/IPV6Address (PY3)
+timestamp | gx:Instant | gx:Instant | datetime.datetime
+date | gx:LocalDate | gx:LocalDate | datetime.date
+time | gx:LocalTime | gx:LocalTime | datetime.time
+smallint | gx:Int16 | gx:Int16 | int
+varint | gx:BigInteger | gx:BigInteger | long
+date | gx:LocalDate | gx:LocalDate | Date
+polygon | dse:Polygon | dse:Polygon | Polygon
+point | dse:Point | dse:Point | Point
+linestring | dse:LineString | dse:LineString | LineString
+blob | dse:Blob | dse:Blob | bytearray, buffer (PY2), memoryview (PY3), bytes (PY3)
+blob | gx:ByteBuffer | gx:ByteBuffer | bytearray, buffer (PY2), memoryview (PY3), bytes (PY3)
+list | N/A | g:List | list (Core graph only)
+map | N/A | g:Map | dict (Core graph only)
+set | N/A | g:Set | set or list (Core graph only)
+ Can return a list due to numerical values returned by Java
+tuple | N/A | dse:Tuple | tuple (Core graph only)
+udt | N/A | dse:UDT | class or namedtuple (Core graph only)
"""
MAX_INT32 = 2 ** 32 - 1
MIN_INT32 = -2 ** 31
+log = logging.getLogger(__name__)
+
class _GraphSONTypeType(type):
"""GraphSONType metaclass, required to create a class property."""
@@ -74,16 +90,20 @@ def graphson_type(cls):
return "{0}:{1}".format(cls.prefix, cls.graphson_base_type)
-@six.add_metaclass(_GraphSONTypeType)
-class GraphSONTypeIO(object):
+class GraphSONTypeIO(object, metaclass=_GraphSONTypeType):
"""Represent a serializable GraphSON type"""
prefix = 'g'
graphson_base_type = None
+ cql_type = None
+
+ @classmethod
+ def definition(cls, value, writer=None):
+ return {'cqlType': cls.cql_type}
@classmethod
- def serialize(cls, value):
- return six.text_type(value)
+ def serialize(cls, value, writer=None):
+ return str(value)
@classmethod
def deserialize(cls, value, reader=None):
@@ -94,23 +114,28 @@ def get_specialized_serializer(cls, value):
return cls
+class TextTypeIO(GraphSONTypeIO):
+ cql_type = 'text'
+
+
class BooleanTypeIO(GraphSONTypeIO):
- graphson_base_type = 'Boolean'
+ graphson_base_type = None
+ cql_type = 'boolean'
@classmethod
- def serialize(cls, value):
+ def serialize(cls, value, writer=None):
return bool(value)
class IntegerTypeIO(GraphSONTypeIO):
@classmethod
- def serialize(cls, value):
+ def serialize(cls, value, writer=None):
return value
@classmethod
def get_specialized_serializer(cls, value):
- if type(value) in six.integer_types and (value > MAX_INT32 or value < MIN_INT32):
+ if type(value) is int and (value > MAX_INT32 or value < MIN_INT32):
return Int64TypeIO
return Int32TypeIO
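(Quick illustration of the value-dependent specialization above, assuming the MAX_INT32/MIN_INT32 bounds defined earlier in this module.)

    IntegerTypeIO.get_specialized_serializer(7)        # -> Int32TypeIO
    IntegerTypeIO.get_specialized_serializer(2 ** 40)  # -> Int64TypeIO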
@@ -119,24 +144,30 @@ def get_specialized_serializer(cls, value):
class Int16TypeIO(IntegerTypeIO):
prefix = 'gx'
graphson_base_type = 'Int16'
+ cql_type = 'smallint'
class Int32TypeIO(IntegerTypeIO):
graphson_base_type = 'Int32'
+ cql_type = 'int'
class Int64TypeIO(IntegerTypeIO):
graphson_base_type = 'Int64'
+ cql_type = 'bigint'
@classmethod
def deserialize(cls, value, reader=None):
- if six.PY3:
- return value
- return long(value)
+ return value
class FloatTypeIO(GraphSONTypeIO):
graphson_base_type = 'Float'
+ cql_type = 'float'
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ return value
@classmethod
def deserialize(cls, value, reader=None):
@@ -145,6 +176,7 @@ def deserialize(cls, value, reader=None):
class DoubleTypeIO(FloatTypeIO):
graphson_base_type = 'Double'
+ cql_type = 'double'
class BigIntegerTypeIO(IntegerTypeIO):
@@ -157,9 +189,10 @@ class LocalDateTypeIO(GraphSONTypeIO):
prefix = 'gx'
graphson_base_type = 'LocalDate'
+ cql_type = 'date'
@classmethod
- def serialize(cls, value):
+ def serialize(cls, value, writer=None):
return value.isoformat()
@classmethod
@@ -170,20 +203,14 @@ def deserialize(cls, value, reader=None):
# negative date
return value
- @classmethod
- def get_specialized_serializer(cls, value):
- if isinstance(value, datetime.datetime):
- return InstantTypeIO
-
- return cls
-
class InstantTypeIO(GraphSONTypeIO):
prefix = 'gx'
graphson_base_type = 'Instant'
+ cql_type = 'timestamp'
@classmethod
- def serialize(cls, value):
+ def serialize(cls, value, writer=None):
if isinstance(value, datetime.datetime):
value = datetime.datetime(*value.utctimetuple()[:6]).replace(microsecond=value.microsecond)
else:
@@ -209,9 +236,10 @@ class LocalTimeTypeIO(GraphSONTypeIO):
prefix = 'gx'
graphson_base_type = 'LocalTime'
+ cql_type = 'time'
@classmethod
- def serialize(cls, value):
+ def serialize(cls, value, writer=None):
return value.strftime(cls.FORMATS[2])
@classmethod
@@ -233,12 +261,12 @@ def deserialize(cls, value, reader=None):
class BlobTypeIO(GraphSONTypeIO):
prefix = 'dse'
graphson_base_type = 'Blob'
+ cql_type = 'blob'
@classmethod
- def serialize(cls, value):
+ def serialize(cls, value, writer=None):
value = base64.b64encode(value)
- if six.PY3:
- value = value.decode('utf-8')
+ value = value.decode('utf-8')
return value
@classmethod
@@ -246,8 +274,14 @@ def deserialize(cls, value, reader=None):
return bytearray(base64.b64decode(value))
+class ByteBufferTypeIO(BlobTypeIO):
+ prefix = 'gx'
+ graphson_base_type = 'ByteBuffer'
+
+
class UUIDTypeIO(GraphSONTypeIO):
graphson_base_type = 'UUID'
+ cql_type = 'uuid'
@classmethod
def deserialize(cls, value, reader=None):
@@ -257,6 +291,7 @@ def deserialize(cls, value, reader=None):
class BigDecimalTypeIO(GraphSONTypeIO):
prefix = 'gx'
graphson_base_type = 'BigDecimal'
+ cql_type = 'bigdecimal'
@classmethod
def deserialize(cls, value, reader=None):
@@ -266,6 +301,7 @@ def deserialize(cls, value, reader=None):
class DurationTypeIO(GraphSONTypeIO):
prefix = 'gx'
graphson_base_type = 'Duration'
+ cql_type = 'duration'
_duration_regex = re.compile(r"""
^P((?P<days>\d+)D)?
@@ -280,7 +316,7 @@ class DurationTypeIO(GraphSONTypeIO):
_seconds_in_day = 24 * _seconds_in_hour
@classmethod
- def serialize(cls, value):
+ def serialize(cls, value, writer=None):
total_seconds = int(value.total_seconds())
days, total_seconds = divmod(total_seconds, cls._seconds_in_day)
hours, total_seconds = divmod(total_seconds, cls._seconds_in_hour)
@@ -298,14 +334,52 @@ def deserialize(cls, value, reader=None):
raise ValueError('Invalid duration: {0}'.format(value))
duration = {k: float(v) if v is not None else 0
- for k, v in six.iteritems(duration.groupdict())}
+ for k, v in duration.groupdict().items()}
return datetime.timedelta(days=duration['days'], hours=duration['hours'],
minutes=duration['minutes'], seconds=duration['seconds'])
+class DseDurationTypeIO(GraphSONTypeIO):
+ prefix = 'dse'
+ graphson_base_type = 'Duration'
+ cql_type = 'duration'
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ return {
+ 'months': value.months,
+ 'days': value.days,
+ 'nanos': value.nanoseconds
+ }
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return Duration(
+ reader.deserialize(value['months']),
+ reader.deserialize(value['days']),
+ reader.deserialize(value['nanos'])
+ )
+
+
+class TypeWrapperTypeIO(GraphSONTypeIO):
+
+ @classmethod
+ def definition(cls, value, writer=None):
+ return {'cqlType': value.type_io.cql_type}
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ return value.type_io.serialize(value.value)
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return value.type_io.deserialize(value.value)
+
+
class PointTypeIO(GraphSONTypeIO):
prefix = 'dse'
graphson_base_type = 'Point'
+ cql_type = "org.apache.cassandra.db.marshal.PointType"
@classmethod
def deserialize(cls, value, reader=None):
@@ -315,6 +389,7 @@ def deserialize(cls, value, reader=None):
class LineStringTypeIO(GraphSONTypeIO):
prefix = 'dse'
graphson_base_type = 'LineString'
+ cql_type = "org.apache.cassandra.db.marshal.LineStringType"
@classmethod
def deserialize(cls, value, reader=None):
@@ -324,6 +399,7 @@ def deserialize(cls, value, reader=None):
class PolygonTypeIO(GraphSONTypeIO):
prefix = 'dse'
graphson_base_type = 'Polygon'
+ cql_type = "org.apache.cassandra.db.marshal.PolygonType"
@classmethod
def deserialize(cls, value, reader=None):
@@ -333,6 +409,7 @@ def deserialize(cls, value, reader=None):
class InetTypeIO(GraphSONTypeIO):
prefix = 'gx'
graphson_base_type = 'InetAddress'
+ cql_type = 'inet'
class VertexTypeIO(GraphSONTypeIO):
@@ -397,13 +474,277 @@ class PathTypeIO(GraphSONTypeIO):
@classmethod
def deserialize(cls, value, reader=None):
- labels = [set(label) for label in value['labels']]
- objects = [reader.deserialize(obj) for obj in value['objects']]
+ labels = [set(label) for label in reader.deserialize(value['labels'])]
+ objects = [obj for obj in reader.deserialize(value['objects'])]
p = Path(labels, [])
p.objects = objects # avoid the object processing in Path.__init__
return p
+class TraversalMetricsTypeIO(GraphSONTypeIO):
+ graphson_base_type = 'TraversalMetrics'
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return reader.deserialize(value)
+
+
+class MetricsTypeIO(GraphSONTypeIO):
+ graphson_base_type = 'Metrics'
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return reader.deserialize(value)
+
+
+class JsonMapTypeIO(GraphSONTypeIO):
+ """In GraphSON2, dict are simply serialized as json map"""
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ out = {}
+ for k, v in value.items():
+ out[k] = writer.serialize(v, writer)
+
+ return out
+
+
+class MapTypeIO(GraphSONTypeIO):
+ """In GraphSON3, dict has its own type"""
+
+ graphson_base_type = 'Map'
+ cql_type = 'map'
+
+ @classmethod
+ def definition(cls, value, writer=None):
+ out = OrderedDict([('cqlType', cls.cql_type)])
+ out['definition'] = []
+ for k, v in value.items():
+ # we just need the first pair to write the def
+ out['definition'].append(writer.definition(k))
+ out['definition'].append(writer.definition(v))
+ break
+ return out
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ out = []
+ for k, v in value.items():
+ out.append(writer.serialize(k, writer))
+ out.append(writer.serialize(v, writer))
+
+ return out
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ out = {}
+ a, b = itertools.tee(value)
+ for key, val in zip(
+ itertools.islice(a, 0, None, 2),
+ itertools.islice(b, 1, None, 2)
+ ):
+ out[reader.deserialize(key)] = reader.deserialize(val)
+ return out
+
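(Hedged sketch of the wire shape implemented above; `writer` and `reader` stand for GraphSON3 serializer/reader instances.)

    # serialize flattens pairs: {'a': 1} -> ['a', {'@type': 'g:Int32', '@value': 1}]
    MapTypeIO.serialize({'a': 1}, writer)
    # deserialize walks the flat list pairwise to rebuild the dict
    MapTypeIO.deserialize(['a', 1], reader)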
+
+class ListTypeIO(GraphSONTypeIO):
+ """In GraphSON3, list has its own type"""
+
+ graphson_base_type = 'List'
+ cql_type = 'list'
+
+ @classmethod
+ def definition(cls, value, writer=None):
+ out = OrderedDict([('cqlType', cls.cql_type)])
+ out['definition'] = []
+ if value:
+ out['definition'].append(writer.definition(value[0]))
+ return out
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ return [writer.serialize(v, writer) for v in value]
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return [reader.deserialize(obj) for obj in value]
+
+
+class SetTypeIO(GraphSONTypeIO):
+ """In GraphSON3, set has its own type"""
+
+ graphson_base_type = 'Set'
+ cql_type = 'set'
+
+ @classmethod
+ def definition(cls, value, writer=None):
+ out = OrderedDict([('cqlType', cls.cql_type)])
+ out['definition'] = []
+ for v in value:
+ # we only take into account the first value for the definition
+ out['definition'].append(writer.definition(v))
+ break
+ return out
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ return [writer.serialize(v, writer) for v in value]
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ lst = [reader.deserialize(obj) for obj in value]
+
+ s = set(lst)
+ if len(s) != len(lst):
+ log.warning("Coercing g:Set to list due to numerical values returned by Java. "
+ "See TINKERPOP-1844 for details.")
+ return lst
+
+ return s
+
+
+class BulkSetTypeIO(GraphSONTypeIO):
+ graphson_base_type = "BulkSet"
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ out = []
+
+ a, b = itertools.tee(value)
+ for val, bulk in zip(
+ itertools.islice(a, 0, None, 2),
+ itertools.islice(b, 1, None, 2)
+ ):
+ val = reader.deserialize(val)
+ bulk = reader.deserialize(bulk)
+ for n in range(bulk):
+ out.append(val)
+
+ return out
+
+
+class TupleTypeIO(GraphSONTypeIO):
+ prefix = 'dse'
+ graphson_base_type = 'Tuple'
+ cql_type = 'tuple'
+
+ @classmethod
+ def definition(cls, value, writer=None):
+ out = OrderedDict()
+ out['cqlType'] = cls.cql_type
+ serializers = [writer.get_serializer(s) for s in value]
+ out['definition'] = [s.definition(v, writer) for v, s in zip(value, serializers)]
+ return out
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ out = cls.definition(value, writer)
+ out['value'] = [writer.serialize(v, writer) for v in value]
+ return out
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return tuple(reader.deserialize(obj) for obj in value['value'])
+
+
+class UserTypeIO(GraphSONTypeIO):
+ prefix = 'dse'
+ graphson_base_type = 'UDT'
+ cql_type = 'udt'
+
+ FROZEN_REMOVAL_REGEX = re.compile(r'frozen<"*([^"]+)"*>')
+
+ @classmethod
+ def cql_types_from_string(cls, typ):
+ # sanitizing: remove frozen references and double quotes...
+ return cql_types_from_string(
+ re.sub(cls.FROZEN_REMOVAL_REGEX, r'\1', typ)
+ )
+
+ @classmethod
+ def get_udt_definition(cls, value, writer):
+ user_type_name = writer.user_types[type(value)]
+ keyspace = writer.context['graph_name']
+ return writer.context['cluster'].metadata.keyspaces[keyspace].user_types[user_type_name]
+
+ @classmethod
+ def is_collection(cls, typ):
+ return typ in ['list', 'tuple', 'map', 'set']
+
+ @classmethod
+ def is_udt(cls, typ, writer):
+ keyspace = writer.context['graph_name']
+ if keyspace in writer.context['cluster'].metadata.keyspaces:
+ return typ in writer.context['cluster'].metadata.keyspaces[keyspace].user_types
+ return False
+
+ @classmethod
+ def field_definition(cls, types, writer, name=None):
+ """
+ Build the udt field definition. This is required when we have a complex udt type.
+ """
+ index = -1
+ out = [OrderedDict() if name is None else OrderedDict([('fieldName', name)])]
+
+ while types:
+ index += 1
+ typ = types.pop(0)
+ if index > 0:
+ out.append(OrderedDict())
+
+ if cls.is_udt(typ, writer):
+ keyspace = writer.context['graph_name']
+ udt = writer.context['cluster'].metadata.keyspaces[keyspace].user_types[typ]
+ out[index].update(cls.definition(udt, writer))
+ elif cls.is_collection(typ):
+ out[index]['cqlType'] = typ
+ definition = cls.field_definition(types, writer)
+ out[index]['definition'] = definition if isinstance(definition, list) else [definition]
+ else:
+ out[index]['cqlType'] = typ
+
+ return out if len(out) > 1 else out[0]
+
+ @classmethod
+ def definition(cls, value, writer=None):
+ udt = value if isinstance(value, UserType) else cls.get_udt_definition(value, writer)
+ return OrderedDict([
+ ('cqlType', cls.cql_type),
+ ('keyspace', udt.keyspace),
+ ('name', udt.name),
+ ('definition', [
+ cls.field_definition(cls.cql_types_from_string(typ), writer, name=name)
+ for name, typ in zip(udt.field_names, udt.field_types)])
+ ])
+
+ @classmethod
+ def serialize(cls, value, writer=None):
+ udt = cls.get_udt_definition(value, writer)
+ out = cls.definition(value, writer)
+ out['value'] = []
+ for name, typ in zip(udt.field_names, udt.field_types):
+ out['value'].append(writer.serialize(getattr(value, name), writer))
+ return out
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ udt_class = reader.context['cluster']._user_types[value['keyspace']][value['name']]
+ kwargs = zip(
+ list(map(lambda v: v['fieldName'], value['definition'])),
+ [reader.deserialize(v) for v in value['value']]
+ )
+ return udt_class(**dict(kwargs))
+
+
+class TTypeIO(GraphSONTypeIO):
+ prefix = 'g'
+ graphson_base_type = 'T'
+
+ @classmethod
+ def deserialize(cls, value, reader=None):
+ return T.name_to_value[value]
+
+
class _BaseGraphSONSerializer(object):
_serializers = OrderedDict()
@@ -448,15 +789,19 @@ def get_serializer(cls, value):
return serializer
@classmethod
- def serialize(cls, value):
+ def serialize(cls, value, writer=None):
"""
- Serialize a python object to graphson.
+ Serialize a python object to GraphSON.
+
+ e.g. 'P42DT10H5M37S'
+ e.g. {'key': value}
:param value: The python object to serialize.
+ :param writer: A graphson serializer for recursive types (Optional)
"""
serializer = cls.get_serializer(value)
if serializer:
- return serializer.serialize(value)
+ return serializer.serialize(value, writer or cls)
return value
@@ -470,26 +815,27 @@ class GraphSON1Serializer(_BaseGraphSONSerializer):
# We want that iteration order to be consistent, so we use an OrderedDict,
# not a dict.
_serializers = OrderedDict([
+ (str, TextTypeIO),
(bool, BooleanTypeIO),
- (bytearray, BlobTypeIO),
+ (bytearray, ByteBufferTypeIO),
(Decimal, BigDecimalTypeIO),
(datetime.date, LocalDateTypeIO),
(datetime.time, LocalTimeTypeIO),
(datetime.timedelta, DurationTypeIO),
+ (datetime.datetime, InstantTypeIO),
(uuid.UUID, UUIDTypeIO),
(Polygon, PolygonTypeIO),
(Point, PointTypeIO),
- (LineString, LineStringTypeIO)
+ (LineString, LineStringTypeIO),
+ (dict, JsonMapTypeIO),
+ (float, FloatTypeIO)
])
-if six.PY2:
- GraphSON1Serializer.register(buffer, BlobTypeIO)
-else:
- GraphSON1Serializer.register(memoryview, BlobTypeIO)
- GraphSON1Serializer.register(bytes, BlobTypeIO)
- GraphSON1Serializer.register(ipaddress.IPv4Address, InetTypeIO)
- GraphSON1Serializer.register(ipaddress.IPv6Address, InetTypeIO)
+GraphSON1Serializer.register(ipaddress.IPv4Address, InetTypeIO)
+GraphSON1Serializer.register(ipaddress.IPv6Address, InetTypeIO)
+GraphSON1Serializer.register(memoryview, ByteBufferTypeIO)
+GraphSON1Serializer.register(bytes, ByteBufferTypeIO)
class _BaseGraphSONDeserializer(object):
@@ -526,7 +872,7 @@ class GraphSON1Deserializer(_BaseGraphSONDeserializer):
"""
Deserialize graphson1 types to python objects.
"""
- _TYPES = [UUIDTypeIO, BigDecimalTypeIO, InstantTypeIO, BlobTypeIO,
+ _TYPES = [UUIDTypeIO, BigDecimalTypeIO, InstantTypeIO, BlobTypeIO, ByteBufferTypeIO,
PointTypeIO, LineStringTypeIO, PolygonTypeIO, LocalDateTypeIO,
LocalTimeTypeIO, DurationTypeIO, InetTypeIO]
@@ -561,9 +907,7 @@ def deserialize_int(cls, value):
@classmethod
def deserialize_bigint(cls, value):
- if six.PY3:
- return cls.deserialize_int(value)
- return long(value)
+ return cls.deserialize_int(value)
@classmethod
def deserialize_double(cls, value):
@@ -581,7 +925,7 @@ def deserialize_decimal(cls, value):
@classmethod
def deserialize_blob(cls, value):
- return cls._deserializers[BlobTypeIO.graphson_type].deserialize(value)
+ return cls._deserializers[ByteBufferTypeIO.graphson_type].deserialize(value)
@classmethod
def deserialize_point(cls, value):
@@ -604,7 +948,7 @@ def deserialize_boolean(cls, value):
return value
-# Remove in the next major
+# TODO Remove in the next major
GraphSON1TypeDeserializer = GraphSON1Deserializer
GraphSON1TypeSerializer = GraphSON1Serializer
@@ -615,8 +959,7 @@ class GraphSON2Serializer(_BaseGraphSONSerializer):
_serializers = GraphSON1Serializer.get_type_definitions()
- @classmethod
- def serialize(cls, value):
+ def serialize(self, value, writer=None):
"""
Serialize a type to GraphSON2.
@@ -624,22 +967,29 @@ def serialize(cls, value):
:param value: The python object to serialize.
"""
- serializer = cls.get_serializer(value)
+ serializer = self.get_serializer(value)
if not serializer:
- # if no serializer found, we can't type it. `value` will be jsonized as string.
- return value
+ raise ValueError("Unable to find a serializer for value of type: ".format(type(value)))
- value = serializer.serialize(value)
- out = {cls.TYPE_KEY: serializer.graphson_type}
- if value is not None:
- out[cls.VALUE_KEY] = value
+ val = serializer.serialize(value, writer or self)
+ if serializer is TypeWrapperTypeIO:
+ graphson_base_type = value.type_io.graphson_base_type
+ graphson_type = value.type_io.graphson_type
+ else:
+ graphson_base_type = serializer.graphson_base_type
+ graphson_type = serializer.graphson_type
+
+ if graphson_base_type is None:
+ out = val
+ else:
+ out = {self.TYPE_KEY: graphson_type}
+ if val is not None:
+ out[self.VALUE_KEY] = val
return out
GraphSON2Serializer.register(int, IntegerTypeIO)
-if six.PY2:
- GraphSON2Serializer.register(long, IntegerTypeIO)
class GraphSON2Deserializer(_BaseGraphSONDeserializer):
@@ -647,7 +997,7 @@ class GraphSON2Deserializer(_BaseGraphSONDeserializer):
_TYPES = GraphSON1Deserializer._TYPES + [
Int16TypeIO, Int32TypeIO, Int64TypeIO, DoubleTypeIO, FloatTypeIO,
BigIntegerTypeIO, VertexTypeIO, VertexPropertyTypeIO, EdgeTypeIO,
- PathTypeIO, PropertyTypeIO]
+ PathTypeIO, PropertyTypeIO, TraversalMetricsTypeIO, MetricsTypeIO]
_deserializers = {
t.graphson_type: t
@@ -660,10 +1010,11 @@ class GraphSON2Reader(object):
GraphSON2 Reader that parse json and deserialize to python objects.
"""
- def __init__(self, extra_deserializer_map=None):
+ def __init__(self, context, extra_deserializer_map=None):
"""
:param extra_deserializer_map: map from GraphSON type tag to deserializer instance implementing `deserialize`
"""
+ self.context = context
self.deserializers = GraphSON2Deserializer.get_type_definitions()
if extra_deserializer_map:
self.deserializers.update(extra_deserializer_map)
@@ -685,8 +1036,97 @@ def deserialize(self, obj):
except KeyError:
pass
# list and map are treated as normal json objs (could be isolated deserializers)
- return {self.deserialize(k): self.deserialize(v) for k, v in six.iteritems(obj)}
+ return {self.deserialize(k): self.deserialize(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self.deserialize(o) for o in obj]
else:
return obj
+
+
+class TypeIOWrapper(object):
+ """Used to force a graphson type during serialization"""
+
+ type_io = None
+ value = None
+
+ def __init__(self, type_io, value):
+ self.type_io = type_io
+ self.value = value
+
+
+def _wrap_value(type_io, value):
+ return TypeIOWrapper(type_io, value)
+
+
+to_bigint = partial(_wrap_value, Int64TypeIO)
+to_int = partial(_wrap_value, Int32TypeIO)
+to_smallint = partial(_wrap_value, Int16TypeIO)
+to_double = partial(_wrap_value, DoubleTypeIO)
+to_float = partial(_wrap_value, FloatTypeIO)
+
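(These partials wrap a value so a specific numeric GraphSON type is forced at serialization time. A hedged usage sketch; the parameter-passing call style is assumed, not shown in this patch:)

    # Without the wrapper, 42 would be typed as g:Int32; to_bigint forces g:Int64.
    session.execute_graph("g.addV('person').property('age', age)",
                          {'age': to_bigint(42)})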
+
+class GraphSON3Serializer(GraphSON2Serializer):
+
+ _serializers = GraphSON2Serializer.get_type_definitions()
+
+ context = None
+ """A dict of the serialization context"""
+
+ def __init__(self, context):
+ self.context = context
+ self.user_types = None
+
+ def definition(self, value):
+ serializer = self.get_serializer(value)
+ return serializer.definition(value, self)
+
+ def get_serializer(self, value):
+ """Custom get_serializer to support UDT/Tuple"""
+
+ serializer = super(GraphSON3Serializer, self).get_serializer(value)
+ is_namedtuple_udt = serializer is TupleTypeIO and hasattr(value, '_fields')
+ if not serializer or is_namedtuple_udt:
+ # Check if UDT
+ if self.user_types is None:
+ try:
+ user_types = self.context['cluster']._user_types[self.context['graph_name']]
+ self.user_types = dict(map(reversed, user_types.items()))
+ except KeyError:
+ self.user_types = {}
+
+ serializer = UserTypeIO if (is_namedtuple_udt or (type(value) in self.user_types)) else serializer
+
+ return serializer
+
+
+GraphSON3Serializer.register(dict, MapTypeIO)
+GraphSON3Serializer.register(list, ListTypeIO)
+GraphSON3Serializer.register(set, SetTypeIO)
+GraphSON3Serializer.register(tuple, TupleTypeIO)
+GraphSON3Serializer.register(Duration, DseDurationTypeIO)
+GraphSON3Serializer.register(TypeIOWrapper, TypeWrapperTypeIO)
+
+
+class GraphSON3Deserializer(GraphSON2Deserializer):
+ _TYPES = GraphSON2Deserializer._TYPES + [MapTypeIO, ListTypeIO,
+ SetTypeIO, TupleTypeIO,
+ UserTypeIO, DseDurationTypeIO,
+ TTypeIO, BulkSetTypeIO]
+
+ _deserializers = {t.graphson_type: t for t in _TYPES}
+
+
+class GraphSON3Reader(GraphSON2Reader):
+ """
+ GraphSON3 Reader that parse json and deserialize to python objects.
+ """
+
+ def __init__(self, context, extra_deserializer_map=None):
+ """
+ :param context: A dict of the context, mostly used as context for udt deserialization.
+ :param extra_deserializer_map: map from GraphSON type tag to deserializer instance implementing `deserialize`
+ """
+ self.context = context
+ self.deserializers = GraphSON3Deserializer.get_type_definitions()
+ if extra_deserializer_map:
+ self.deserializers.update(extra_deserializer_map)
diff --git a/cassandra/datastax/graph/query.py b/cassandra/datastax/graph/query.py
index 50a03b5561..866df7a94c 100644
--- a/cassandra/datastax/graph/query.py
+++ b/cassandra/datastax/graph/query.py
@@ -15,18 +15,16 @@
import json
from warnings import warn
-import six
-
from cassandra import ConsistencyLevel
from cassandra.query import Statement, SimpleStatement
-from cassandra.datastax.graph.types import Vertex, Edge, Path
-from cassandra.datastax.graph.graphson import GraphSON2Reader
+from cassandra.datastax.graph.types import Vertex, Edge, Path, VertexProperty
+from cassandra.datastax.graph.graphson import GraphSON2Reader, GraphSON3Reader
__all__ = [
'GraphProtocol', 'GraphOptions', 'GraphStatement', 'SimpleGraphStatement',
'single_object_row_factory', 'graph_result_row_factory', 'graph_object_row_factory',
- 'graph_graphson2_row_factory', 'Result'
+ 'graph_graphson2_row_factory', 'Result', 'graph_graphson3_row_factory'
]
# (attr, description, server option)
@@ -45,21 +43,24 @@
# this is defined by the execution profile attribute, not in graph options
_request_timeout_key = 'request-timeout'
-_graphson2_reader = GraphSON2Reader()
-
class GraphProtocol(object):
- GRAPHSON_1_0 = 'graphson-1.0'
+ GRAPHSON_1_0 = b'graphson-1.0'
"""
GraphSON1
"""
- GRAPHSON_2_0 = 'graphson-2.0'
+ GRAPHSON_2_0 = b'graphson-2.0'
"""
GraphSON2
"""
+ GRAPHSON_3_0 = b'graphson-3.0'
+ """
+ GraphSON3
+ """
+
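(A hedged sketch of opting into the new sub-protocol through GraphOptions, using the option names defined below:)

    opts = GraphOptions(graph_name='my_core_graph',
                        graph_protocol=GraphProtocol.GRAPHSON_3_0)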
class GraphOptions(object):
"""
@@ -67,12 +68,14 @@ class GraphOptions(object):
"""
# See _graph_options map above for notes on valid options
+ DEFAULT_GRAPH_PROTOCOL = GraphProtocol.GRAPHSON_1_0
+ DEFAULT_GRAPH_LANGUAGE = b'gremlin-groovy'
+
def __init__(self, **kwargs):
self._graph_options = {}
kwargs.setdefault('graph_source', 'g')
- kwargs.setdefault('graph_language', 'gremlin-groovy')
- kwargs.setdefault('graph_protocol', GraphProtocol.GRAPHSON_1_0)
- for attr, value in six.iteritems(kwargs):
+ kwargs.setdefault('graph_language', GraphOptions.DEFAULT_GRAPH_LANGUAGE)
+ for attr, value in kwargs.items():
if attr not in _graph_option_names:
warn("Unknown keyword argument received for GraphOptions: {0}".format(attr))
setattr(self, attr, value)
@@ -98,7 +101,7 @@ def get_options_map(self, other_options=None):
for cl in ('graph-write-consistency', 'graph-read-consistency'):
cl_enum = options.get(cl)
if cl_enum is not None:
- options[cl] = six.b(ConsistencyLevel.value_to_name[cl_enum])
+ options[cl] = ConsistencyLevel.value_to_name[cl_enum].encode()
return options
def set_source_default(self):
@@ -152,8 +155,8 @@ def get(self, key=opt[2]):
def set(self, value, key=opt[2]):
if value is not None:
# normalize text here so it doesn't have to be done every time we get options map
- if isinstance(value, six.text_type) and not isinstance(value, six.binary_type):
- value = six.b(value)
+ if isinstance(value, str):
+ value = value.encode()
self._graph_options[key] = value
else:
self._graph_options.pop(key, None)
@@ -222,11 +225,31 @@ def _graph_object_sequence(objects):
yield res
-def graph_graphson2_row_factory(column_names, rows):
- """
- Row Factory that returns the decoded graphson as DSE types.
- """
- return [_graphson2_reader.read(row[0])['result'] for row in rows]
+class _GraphSONContextRowFactory(object):
+ graphson_reader_class = None
+ graphson_reader_kwargs = None
+
+ def __init__(self, cluster):
+ context = {'cluster': cluster}
+ kwargs = self.graphson_reader_kwargs or {}
+ self.graphson_reader = self.graphson_reader_class(context, **kwargs)
+
+ def __call__(self, column_names, rows):
+ return [self.graphson_reader.read(row[0])['result'] for row in rows]
+
+
+class _GraphSON2RowFactory(_GraphSONContextRowFactory):
+ """Row factory to deserialize GraphSON2 results."""
+ graphson_reader_class = GraphSON2Reader
+
+
+class _GraphSON3RowFactory(_GraphSONContextRowFactory):
+ """Row factory to deserialize GraphSON3 results."""
+ graphson_reader_class = GraphSON3Reader
+
+
+graph_graphson2_row_factory = _GraphSON2RowFactory
+graph_graphson3_row_factory = _GraphSON3RowFactory
class Result(object):
@@ -253,7 +276,7 @@ def __getattr__(self, attr):
raise AttributeError("Result has no top-level attribute %r" % (attr,))
def __getitem__(self, item):
- if isinstance(self.value, dict) and isinstance(item, six.string_types):
+ if isinstance(self.value, dict) and isinstance(item, str):
return self.value[item]
elif isinstance(self.value, list) and isinstance(item, int):
return self.value[item]
@@ -302,3 +325,6 @@ def as_path(self):
return Path(self.labels, self.objects)
except (AttributeError, ValueError, TypeError):
raise TypeError("Could not create Path from %r" % (self,))
+
+ def as_vertex_property(self):
+ return VertexProperty(self.value.get('label'), self.value.get('value'), self.value.get('properties', {}))
diff --git a/cassandra/datastax/graph/types.py b/cassandra/datastax/graph/types.py
index ae22cd4bfe..9817c99d7d 100644
--- a/cassandra/datastax/graph/types.py
+++ b/cassandra/datastax/graph/types.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = ['Element', 'Vertex', 'Edge', 'VertexProperty', 'Path']
+__all__ = ['Element', 'Vertex', 'Edge', 'VertexProperty', 'Path', 'T']
class Element(object):
@@ -159,3 +159,52 @@ def __str__(self):
def __repr__(self):
return "%s(%r, %r)" % (self.__class__.__name__, self.labels, [o.value for o in self.objects])
+
+
+class T(object):
+ """
+ Represents a collection of tokens for more concise Traversal definitions.
+ """
+
+ name = None
+ val = None
+
+ # class attributes
+ id = None
+ """
+ """
+
+ key = None
+ """
+ """
+ label = None
+ """
+ """
+ value = None
+ """
+ """
+
+ def __init__(self, name, val):
+ self.name = name
+ self.val = val
+
+ def __str__(self):
+ return self.name
+
+ def __repr__(self):
+ return "T.%s" % (self.name, )
+
+
+T.id = T("id", 1)
+T.id_ = T("id_", 2)
+T.key = T("key", 3)
+T.label = T("label", 4)
+T.value = T("value", 5)
+
+T.name_to_value = {
+ 'id': T.id,
+ 'id_': T.id_,
+ 'key': T.key,
+ 'label': T.label,
+ 'value': T.value
+}
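(Small illustration of the token singletons above; TTypeIO maps GraphSON "g:T" values onto them.)

    T.name_to_value['label'] is T.label   # True
    repr(T.id)                            # 'T.id'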
diff --git a/cassandra/datastax/insights/registry.py b/cassandra/datastax/insights/registry.py
index 3dd1d255ae..03daebd86e 100644
--- a/cassandra/datastax/insights/registry.py
+++ b/cassandra/datastax/insights/registry.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import six
from collections import OrderedDict
from warnings import warn
@@ -59,7 +58,7 @@ def _get_serializer(self, cls):
try:
return self._mapping_dict[cls]
except KeyError:
- for registered_cls, serializer in six.iteritems(self._mapping_dict):
+ for registered_cls, serializer in self._mapping_dict.items():
if issubclass(cls, registered_cls):
return self._mapping_dict[registered_cls]
raise ValueError
diff --git a/cassandra/datastax/insights/reporter.py b/cassandra/datastax/insights/reporter.py
index b05a88deb0..83205fc458 100644
--- a/cassandra/datastax/insights/reporter.py
+++ b/cassandra/datastax/insights/reporter.py
@@ -24,7 +24,6 @@
import sys
from threading import Event, Thread
import time
-import six
from cassandra.policies import HostDistance
from cassandra.util import ms_timestamp_from_datetime
@@ -199,9 +198,9 @@ def _get_startup_data(self):
},
'platformInfo': {
'os': {
- 'name': uname_info.system if six.PY3 else uname_info[0],
- 'version': uname_info.release if six.PY3 else uname_info[2],
- 'arch': uname_info.machine if six.PY3 else uname_info[4]
+ 'name': uname_info.system,
+ 'version': uname_info.release,
+ 'arch': uname_info.machine
},
'cpus': {
'length': multiprocessing.cpu_count(),
diff --git a/cassandra/datastax/insights/serializers.py b/cassandra/datastax/insights/serializers.py
index aec4467a6a..289c165e8a 100644
--- a/cassandra/datastax/insights/serializers.py
+++ b/cassandra/datastax/insights/serializers.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import six
-
def initialize_registry(insights_registry):
# This will be called from the cluster module, so we put all this behavior
@@ -203,8 +201,8 @@ def graph_options_insights_serializer(options):
'language': options.graph_language,
'graphProtocol': options.graph_protocol
}
- updates = {k: v.decode('utf-8') for k, v in six.iteritems(rv)
- if isinstance(v, six.binary_type)}
+ updates = {k: v.decode('utf-8') for k, v in rv.items()
+ if isinstance(v, bytes)}
rv.update(updates)
return rv
diff --git a/cassandra/deserializers.pyx b/cassandra/deserializers.pyx
index 7de6949099..7c256674b0 100644
--- a/cassandra/deserializers.pyx
+++ b/cassandra/deserializers.pyx
@@ -29,8 +29,6 @@ from uuid import UUID
from cassandra import cqltypes
from cassandra import util
-cdef bint PY2 = six.PY2
-
cdef class Deserializer:
"""Cython-based deserializer class for a cqltype"""
@@ -90,8 +88,6 @@ cdef class DesAsciiType(Deserializer):
cdef deserialize(self, Buffer *buf, int protocol_version):
if buf.size == 0:
return ""
- if PY2:
- return to_bytes(buf)
return to_bytes(buf).decode('ascii')
diff --git a/cassandra/encoder.py b/cassandra/encoder.py
index f2c3f8dfed..31d90549f4 100644
--- a/cassandra/encoder.py
+++ b/cassandra/encoder.py
@@ -27,28 +27,15 @@
import sys
import types
from uuid import UUID
-import six
+import ipaddress
from cassandra.util import (OrderedDict, OrderedMap, OrderedMapSerializedKey,
sortedset, Time, Date, Point, LineString, Polygon)
-if six.PY3:
- import ipaddress
-
-if six.PY3:
- long = int
-
def cql_quote(term):
- # The ordering of this method is important for the result of this method to
- # be a native str type (for both Python 2 and 3)
-
if isinstance(term, str):
return "'%s'" % str(term).replace("'", "''")
- # This branch of the if statement will only be used by Python 2 to catch
- # unicode strings, text_type is used to prevent type errors with Python 3.
- elif isinstance(term, six.text_type):
- return "'%s'" % term.encode('utf8').replace("'", "''")
else:
return str(term)
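(Behavior is unchanged apart from dropping the Python 2 branches; illustrative:)

    cql_quote("it's")   # "'it''s'"
    cql_quote(5)        # '5'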
@@ -97,21 +84,13 @@ def __init__(self):
Polygon: self.cql_encode_str_quoted
}
- if six.PY2:
- self.mapping.update({
- unicode: self.cql_encode_unicode,
- buffer: self.cql_encode_bytes,
- long: self.cql_encode_object,
- types.NoneType: self.cql_encode_none,
- })
- else:
- self.mapping.update({
- memoryview: self.cql_encode_bytes,
- bytes: self.cql_encode_bytes,
- type(None): self.cql_encode_none,
- ipaddress.IPv4Address: self.cql_encode_ipaddress,
- ipaddress.IPv6Address: self.cql_encode_ipaddress
- })
+ self.mapping.update({
+ memoryview: self.cql_encode_bytes,
+ bytes: self.cql_encode_bytes,
+ type(None): self.cql_encode_none,
+ ipaddress.IPv4Address: self.cql_encode_ipaddress,
+ ipaddress.IPv6Address: self.cql_encode_ipaddress
+ })
def cql_encode_none(self, val):
"""
@@ -134,16 +113,8 @@ def cql_encode_str(self, val):
def cql_encode_str_quoted(self, val):
return "'%s'" % val
- if six.PY3:
- def cql_encode_bytes(self, val):
- return (b'0x' + hexlify(val)).decode('utf-8')
- elif sys.version_info >= (2, 7):
- def cql_encode_bytes(self, val): # noqa
- return b'0x' + hexlify(val)
- else:
- # python 2.6 requires string or read-only buffer for hexlify
- def cql_encode_bytes(self, val): # noqa
- return b'0x' + hexlify(buffer(val))
+ def cql_encode_bytes(self, val):
+ return (b'0x' + hexlify(val)).decode('utf-8')
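(Illustrative check of the now-unified implementation; Encoder is the class patched here.)

    Encoder().cql_encode_bytes(b'\x01\xff')   # '0x01ff'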
def cql_encode_object(self, val):
"""
@@ -169,7 +140,7 @@ def cql_encode_datetime(self, val):
with millisecond precision.
"""
timestamp = calendar.timegm(val.utctimetuple())
- return str(long(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3))
+ return str(int(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3))
def cql_encode_date(self, val):
"""
@@ -214,7 +185,7 @@ def cql_encode_map_collection(self, val):
return '{%s}' % ', '.join('%s: %s' % (
self.mapping.get(type(k), self.cql_encode_object)(k),
self.mapping.get(type(v), self.cql_encode_object)(v)
- ) for k, v in six.iteritems(val))
+ ) for k, v in val.items())
def cql_encode_list_collection(self, val):
"""
@@ -236,14 +207,13 @@ def cql_encode_all_types(self, val, as_text_type=False):
if :attr:`~Encoder.mapping` does not contain an entry for the type.
"""
encoded = self.mapping.get(type(val), self.cql_encode_object)(val)
- if as_text_type and not isinstance(encoded, six.text_type):
+ if as_text_type and not isinstance(encoded, str):
return encoded.decode('utf-8')
return encoded
- if six.PY3:
- def cql_encode_ipaddress(self, val):
- """
- Converts an ipaddress (IPV4Address, IPV6Address) to a CQL string. This
- is suitable for ``inet`` type columns.
- """
- return "'%s'" % val.compressed
+ def cql_encode_ipaddress(self, val):
+ """
+ Converts an ipaddress (IPV4Address, IPV6Address) to a CQL string. This
+ is suitable for ``inet`` type columns.
+ """
+ return "'%s'" % val.compressed
diff --git a/cassandra/io/asyncioreactor.py b/cassandra/io/asyncioreactor.py
index 7cb0444a32..95f92e26e0 100644
--- a/cassandra/io/asyncioreactor.py
+++ b/cassandra/io/asyncioreactor.py
@@ -41,14 +41,12 @@ def end(self):
def __init__(self, timeout, callback, loop):
delayed = self._call_delayed_coro(timeout=timeout,
- callback=callback,
- loop=loop)
+ callback=callback)
self._handle = asyncio.run_coroutine_threadsafe(delayed, loop=loop)
@staticmethod
- @asyncio.coroutine
- def _call_delayed_coro(timeout, callback, loop):
- yield from asyncio.sleep(timeout, loop=loop)
+ async def _call_delayed_coro(timeout, callback):
+ await asyncio.sleep(timeout)
return callback()
def __lt__(self, other):
@@ -91,8 +89,8 @@ def __init__(self, *args, **kwargs):
self._connect_socket()
self._socket.setblocking(0)
- self._write_queue = asyncio.Queue(loop=self._loop)
- self._write_queue_lock = asyncio.Lock(loop=self._loop)
+ self._write_queue = asyncio.Queue()
+ self._write_queue_lock = asyncio.Lock()
# see initialize_reactor -- loop is running in a separate thread, so we
# have to use a threadsafe call
@@ -136,8 +134,7 @@ def close(self):
self._close(), loop=self._loop
)
- @asyncio.coroutine
- def _close(self):
+ async def _close(self):
log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint))
if self._write_watcher:
self._write_watcher.cancel()
@@ -174,21 +171,19 @@ def push(self, data):
# avoid races/hangs by just scheduling this, not using threadsafe
self._loop.create_task(self._push_msg(chunks))
- @asyncio.coroutine
- def _push_msg(self, chunks):
+ async def _push_msg(self, chunks):
# This lock ensures all chunks of a message are sequential in the Queue
- with (yield from self._write_queue_lock):
+ with await self._write_queue_lock:
for chunk in chunks:
self._write_queue.put_nowait(chunk)
- @asyncio.coroutine
- def handle_write(self):
+ async def handle_write(self):
while True:
try:
- next_msg = yield from self._write_queue.get()
+ next_msg = await self._write_queue.get()
if next_msg:
- yield from self._loop.sock_sendall(self._socket, next_msg)
+ await self._loop.sock_sendall(self._socket, next_msg)
except socket.error as err:
log.debug("Exception in send for %s: %s", self, err)
self.defunct(err)
@@ -196,18 +191,19 @@ def handle_write(self):
except asyncio.CancelledError:
return
- @asyncio.coroutine
- def handle_read(self):
+ async def handle_read(self):
while True:
try:
- buf = yield from self._loop.sock_recv(self._socket, self.in_buffer_size)
+ buf = await self._loop.sock_recv(self._socket, self.in_buffer_size)
self._iobuf.write(buf)
# sock_recv expects EWOULDBLOCK if socket provides no data, but
# nonblocking ssl sockets raise these instead, so we handle them
# ourselves by yielding to the event loop, where the socket will
# get the reading/writing it "wants" before retrying
except (ssl.SSLWantWriteError, ssl.SSLWantReadError):
- yield
+ # Apparently the preferred way to yield to the event loop from within
+ # a native coroutine based on https://github.com/python/asyncio/issues/284
+ await asyncio.sleep(0)
continue
except socket.error as err:
log.debug("Exception during socket recv for %s: %s",
diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py
index 1a6b9fd3e9..a50b719c5d 100644
--- a/cassandra/io/asyncorereactor.py
+++ b/cassandra/io/asyncorereactor.py
@@ -24,17 +24,25 @@
import sys
import ssl
-from six.moves import range
try:
from weakref import WeakSet
except ImportError:
from cassandra.util import WeakSet # noqa
-import asyncore
+from cassandra import DependencyException
+try:
+ import asyncore
+except ModuleNotFoundError:
+ raise DependencyException(
+ "Unable to import asyncore module. Note that this module has been removed in Python 3.12 "
+ "so when using the driver with this version (or anything newer) you will need to use one of the "
+ "other event loop implementations."
+ )
from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING, Timer, TimerManager
+
log = logging.getLogger(__name__)
_dispatcher_map = {}
@@ -247,12 +255,21 @@ def _run_loop(self):
try:
self._loop_dispatcher.loop(self.timer_resolution)
self._timers.service_timeouts()
- except Exception:
- log.debug("Asyncore event loop stopped unexepectedly", exc_info=True)
+ except Exception as exc:
+ self._maybe_log_debug("Asyncore event loop stopped unexpectedly", exc_info=exc)
break
self._started = False
- log.debug("Asyncore event loop ended")
+ self._maybe_log_debug("Asyncore event loop ended")
+
+ def _maybe_log_debug(self, *args, **kwargs):
+ try:
+ log.debug(*args, **kwargs)
+ except Exception:
+ # TODO: Remove when Python 2 support is removed
+ # PYTHON-1266. If our logger has disappeared, there's nothing we
+ # can do, so just log nothing.
+ pass
def add_timer(self, timer):
self._timers.add_timer(timer)
diff --git a/cassandra/io/eventletreactor.py b/cassandra/io/eventletreactor.py
index 162661f468..c51bfd7591 100644
--- a/cassandra/io/eventletreactor.py
+++ b/cassandra/io/eventletreactor.py
@@ -23,8 +23,6 @@
from threading import Event
import time
-from six.moves import xrange
-
from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager
try:
from eventlet.green.OpenSSL import SSL
@@ -105,11 +103,12 @@ def __init__(self, *args, **kwargs):
def _wrap_socket_from_context(self):
_check_pyopenssl()
- self._socket = SSL.Connection(self.ssl_context, self._socket)
- self._socket.set_connect_state()
+ rv = SSL.Connection(self.ssl_context, self._socket)
+ rv.set_connect_state()
if self.ssl_options and 'server_hostname' in self.ssl_options:
# This is necessary for SNI
- self._socket.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii'))
+ rv.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii'))
+ return rv
def _initiate_connection(self, sockaddr):
if self.uses_legacy_ssl_options:
@@ -119,14 +118,12 @@ def _initiate_connection(self, sockaddr):
if self.ssl_context or self.ssl_options:
self._socket.do_handshake()
- def _match_hostname(self):
- if self.uses_legacy_ssl_options:
- super(EventletConnection, self)._match_hostname()
- else:
+ def _validate_hostname(self):
+ if not self.uses_legacy_ssl_options:
cert_name = self._socket.get_peer_certificate().get_subject().commonName
if cert_name != self.endpoint.address:
raise Exception("Hostname verification failed! Certificate name '{}' "
- "doesn't endpoint '{}'".format(cert_name, self.endpoint.address))
+ "doesn't match endpoint '{}'".format(cert_name, self.endpoint.address))
def close(self):
with self.lock:
@@ -190,5 +187,5 @@ def handle_read(self):
def push(self, data):
chunk_size = self.out_buffer_size
- for i in xrange(0, len(data), chunk_size):
+ for i in range(0, len(data), chunk_size):
self._write_queue.put(data[i:i + chunk_size])
diff --git a/cassandra/io/geventreactor.py b/cassandra/io/geventreactor.py
index ebc664d485..4f1f158aa7 100644
--- a/cassandra/io/geventreactor.py
+++ b/cassandra/io/geventreactor.py
@@ -20,7 +20,6 @@
import logging
import time
-from six.moves import range
from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager
diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py
index 2487419784..4d4098ca7b 100644
--- a/cassandra/io/libevreactor.py
+++ b/cassandra/io/libevreactor.py
@@ -21,14 +21,13 @@
from threading import Lock, Thread
import time
-from six.moves import range
-
+from cassandra import DependencyException
from cassandra.connection import (Connection, ConnectionShutdown,
NONBLOCKING, Timer, TimerManager)
try:
import cassandra.io.libevwrapper as libev
except ImportError:
- raise ImportError(
+ raise DependencyException(
"The C extension needed to use libev was not found. This "
"probably means that you didn't have the required build dependencies "
"when installing the driver. See "
@@ -310,6 +309,8 @@ def handle_write(self, watcher, revents, errno=None):
with self._deque_lock:
next_msg = self.deque.popleft()
except IndexError:
+ if not self._socket_writable:
+ self._socket_writable = True
return
try:
@@ -317,6 +318,8 @@ def handle_write(self, watcher, revents, errno=None):
except socket.error as err:
if (err.args[0] in NONBLOCKING or
err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE)):
+ if err.args[0] in NONBLOCKING:
+ self._socket_writable = False
with self._deque_lock:
self.deque.appendleft(next_msg)
else:
@@ -326,6 +329,11 @@ def handle_write(self, watcher, revents, errno=None):
if sent < len(next_msg):
with self._deque_lock:
self.deque.appendleft(next_msg[sent:])
+ # We've seen cases where 0 is returned instead of a NONBLOCKING error, though
+ # we don't normally expect this to happen. https://bugs.python.org/issue20951
+ if sent == 0:
+ self._socket_writable = False
+ return
def handle_read(self, watcher, revents, errno=None):
if revents & libev.EV_ERROR:
diff --git a/cassandra/marshal.py b/cassandra/marshal.py
index 7533ebd307..726f0819eb 100644
--- a/cassandra/marshal.py
+++ b/cassandra/marshal.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import six
import struct
@@ -28,6 +27,7 @@ def _make_packer(format_string):
int8_pack, int8_unpack = _make_packer('>b')
uint64_pack, uint64_unpack = _make_packer('>Q')
uint32_pack, uint32_unpack = _make_packer('>I')
+uint32_le_pack, uint32_le_unpack = _make_packer('<I')
uint8_pack, uint8_unpack = _make_packer('>B')
float_pack, float_unpack = _make_packer('>f')
@@ -44,35 +44,16 @@ def _make_packer(format_string):
v3_header_unpack = v3_header_struct.unpack
-if six.PY3:
- def byte2int(b):
- return b
-
-
- def varint_unpack(term):
- val = int(''.join("%02x" % i for i in term), 16)
- if (term[0] & 128) != 0:
- len_term = len(term) # pulling this out of the expression to avoid overflow in cython optimized code
- val -= 1 << (len_term * 8)
- return val
-else:
- def byte2int(b):
- return ord(b)
-
-
- def varint_unpack(term): # noqa
- val = int(term.encode('hex'), 16)
- if (ord(term[0]) & 128) != 0:
- len_term = len(term) # pulling this out of the expression to avoid overflow in cython optimized code
- val = val - (1 << (len_term * 8))
- return val
+def varint_unpack(term):
+ val = int(''.join("%02x" % i for i in term), 16)
+ if (term[0] & 128) != 0:
+ len_term = len(term) # pulling this out of the expression to avoid overflow in cython optimized code
+ val -= 1 << (len_term * 8)
+ return val
def bit_length(n):
- if six.PY3 or isinstance(n, int):
- return int.bit_length(n)
- else:
- return long.bit_length(n)
+ return int.bit_length(n)
def varint_pack(big):
@@ -90,7 +71,7 @@ def varint_pack(big):
if pos and revbytes[-1] & 0x80:
revbytes.append(0)
revbytes.reverse()
- return six.binary_type(revbytes)
+ return bytes(revbytes)
point_be = struct.Struct('>dd')
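A few hand-checked values for the unified `varint_unpack` above; the sign-bit adjustment is what gives two's-complement semantics:

    varint_unpack(b'\x01\x00')   # 256
    varint_unpack(b'\xff')       # -1      (sign bit set: 255 - (1 << 8))
    varint_unpack(b'\x80\x00')   # -32768  (32768 - (1 << 16))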
@@ -112,7 +93,7 @@ def vints_unpack(term): # noqa
values = []
n = 0
while n < len(term):
- first_byte = byte2int(term[n])
+ first_byte = term[n]
if (first_byte & 128) == 0:
val = first_byte
@@ -123,7 +104,7 @@ def vints_unpack(term): # noqa
while n < end:
n += 1
val <<= 8
- val |= byte2int(term[n]) & 0xff
+ val |= term[n] & 0xff
n += 1
values.append(decode_zig_zag(val))
@@ -161,4 +142,4 @@ def vints_pack(values):
revbytes.append(abs(v))
revbytes.reverse()
- return six.binary_type(revbytes)
+ return bytes(revbytes)
diff --git a/cassandra/metadata.py b/cassandra/metadata.py
index f7019b7e9f..f52bfd9317 100644
--- a/cassandra/metadata.py
+++ b/cassandra/metadata.py
@@ -15,13 +15,12 @@
from binascii import unhexlify
from bisect import bisect_left
from collections import defaultdict
+from collections.abc import Mapping
from functools import total_ordering
from hashlib import md5
import json
import logging
import re
-import six
-from six.moves import zip
import sys
from threading import RLock
import struct
@@ -42,28 +41,27 @@
from cassandra.util import OrderedDict, Version
from cassandra.pool import HostDistance
from cassandra.connection import EndPoint
-from cassandra.compat import Mapping
log = logging.getLogger(__name__)
cql_keywords = set((
'add', 'aggregate', 'all', 'allow', 'alter', 'and', 'apply', 'as', 'asc', 'ascii', 'authorize', 'batch', 'begin',
'bigint', 'blob', 'boolean', 'by', 'called', 'clustering', 'columnfamily', 'compact', 'contains', 'count',
- 'counter', 'create', 'custom', 'date', 'decimal', 'delete', 'desc', 'describe', 'deterministic', 'distinct', 'double', 'drop',
+ 'counter', 'create', 'custom', 'date', 'decimal', 'default', 'delete', 'desc', 'describe', 'deterministic', 'distinct', 'double', 'drop',
'entries', 'execute', 'exists', 'filtering', 'finalfunc', 'float', 'from', 'frozen', 'full', 'function',
'functions', 'grant', 'if', 'in', 'index', 'inet', 'infinity', 'initcond', 'input', 'insert', 'int', 'into', 'is', 'json',
- 'key', 'keys', 'keyspace', 'keyspaces', 'language', 'limit', 'list', 'login', 'map', 'materialized', 'modify', 'monotonic', 'nan', 'nologin',
- 'norecursive', 'nosuperuser', 'not', 'null', 'of', 'on', 'options', 'or', 'order', 'password', 'permission',
+ 'key', 'keys', 'keyspace', 'keyspaces', 'language', 'limit', 'list', 'login', 'map', 'materialized', 'mbean', 'mbeans', 'modify', 'monotonic',
+ 'nan', 'nologin', 'norecursive', 'nosuperuser', 'not', 'null', 'of', 'on', 'options', 'or', 'order', 'password', 'permission',
'permissions', 'primary', 'rename', 'replace', 'returns', 'revoke', 'role', 'roles', 'schema', 'select', 'set',
'sfunc', 'smallint', 'static', 'storage', 'stype', 'superuser', 'table', 'text', 'time', 'timestamp', 'timeuuid',
- 'tinyint', 'to', 'token', 'trigger', 'truncate', 'ttl', 'tuple', 'type', 'unlogged', 'update', 'use', 'user',
+ 'tinyint', 'to', 'token', 'trigger', 'truncate', 'ttl', 'tuple', 'type', 'unlogged', 'unset', 'update', 'use', 'user',
'users', 'using', 'uuid', 'values', 'varchar', 'varint', 'view', 'where', 'with', 'writetime',
# DSE specifics
"node", "nodes", "plan", "active", "application", "applications", "java", "executor", "executors", "std_out", "std_err",
"renew", "delegation", "no", "redact", "token", "lowercasestring", "cluster", "authentication", "schemes", "scheme",
"internal", "ldap", "kerberos", "remote", "object", "method", "call", "calls", "search", "schema", "config", "rows",
- "columns", "profiles", "commit", "reload", "unset", "rebuild", "field", "workpool", "any", "submission", "indices",
+ "columns", "profiles", "commit", "reload", "rebuild", "field", "workpool", "any", "submission", "indices",
"restrict", "unrestrict"
))
"""
@@ -292,7 +290,7 @@ def rebuild_token_map(self, partitioner, token_map):
token_to_host_owner = {}
ring = []
- for host, token_strings in six.iteritems(token_map):
+ for host, token_strings in token_map.items():
for token_string in token_strings:
token = token_class.from_string(token_string)
ring.append(token)
@@ -338,20 +336,23 @@ def remove_host(self, host):
with self._hosts_lock:
return bool(self._hosts.pop(host.endpoint, False))
- def get_host(self, endpoint_or_address):
+ def get_host(self, endpoint_or_address, port=None):
"""
- Find a host in the metadata for a specific endpoint. If a string inet address is passed,
- iterate all hosts to match the :attr:`~.pool.Host.broadcast_rpc_address` attribute.
+ Find a host in the metadata for a specific endpoint. If a string inet address and port are passed,
+ iterate all hosts to match the :attr:`~.pool.Host.broadcast_rpc_address` and
+ :attr:`~.pool.Host.broadcast_rpc_port` attributes.
"""
if not isinstance(endpoint_or_address, EndPoint):
- return self._get_host_by_address(endpoint_or_address)
+ return self._get_host_by_address(endpoint_or_address, port)
return self._hosts.get(endpoint_or_address)
- def _get_host_by_address(self, address):
- for host in six.itervalues(self._hosts):
- if host.broadcast_rpc_address == address:
+ def _get_host_by_address(self, address, port=None):
+ for host in self._hosts.values():
+ if (host.broadcast_rpc_address == address and
+ (port is None or host.broadcast_rpc_port is None or host.broadcast_rpc_port == port)):
return host
+
return None
def all_hosts(self):
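Both lookup forms are then supported, sketched here with an illustrative address and an assumed connected `cluster`:

    host = cluster.metadata.get_host('10.0.0.5')             # match broadcast_rpc_address only
    host = cluster.metadata.get_host('10.0.0.5', port=9042)  # additionally require a matching broadcast_rpc_port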
@@ -383,8 +384,8 @@ def __new__(metacls, name, bases, dct):
return cls
-@six.add_metaclass(ReplicationStrategyTypeType)
-class _ReplicationStrategy(object):
+
+class _ReplicationStrategy(object, metaclass=ReplicationStrategyTypeType):
options_map = None
@classmethod
@@ -450,18 +451,82 @@ def make_token_replica_map(self, token_to_host_owner, ring):
return {}
+class ReplicationFactor(object):
+ """
+ Represent the replication factor of a keyspace.
+ """
+
+ all_replicas = None
+ """
+ The number of total replicas.
+ """
+
+ full_replicas = None
+ """
+ The number of replicas that own a full copy of the data. This is the same
+ as `all_replicas` when transient replication is not enabled.
+ """
+
+ transient_replicas = None
+ """
+ The number of transient replicas.
+
+ Only set if the keyspace has transient replication enabled.
+ """
+
+ def __init__(self, all_replicas, transient_replicas=None):
+ self.all_replicas = all_replicas
+ self.transient_replicas = transient_replicas
+ self.full_replicas = (all_replicas - transient_replicas) if transient_replicas else all_replicas
+
+ @staticmethod
+ def create(rf):
+ """
+ Given a replication factor string (e.g. '3', or '3/1' with transient replication), parse it and return a ReplicationFactor instance.
+ """
+ transient_replicas = None
+ try:
+ all_replicas = int(rf)
+ except ValueError:
+ try:
+ rf = rf.split('/')
+ all_replicas, transient_replicas = int(rf[0]), int(rf[1])
+ except Exception:
+ raise ValueError("Unable to determine replication factor from: {}".format(rf))
+
+ return ReplicationFactor(all_replicas, transient_replicas)
+
+ def __str__(self):
+ return ("%d/%d" % (self.all_replicas, self.transient_replicas) if self.transient_replicas
+ else "%d" % self.all_replicas)
+
+ def __eq__(self, other):
+ if not isinstance(other, ReplicationFactor):
+ return False
+
+ return self.all_replicas == other.all_replicas and self.full_replicas == other.full_replicas
+
+
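A quick illustration of the parsing rules implemented above, for a plain and a transient replication factor:

    rf = ReplicationFactor.create('3')
    (rf.all_replicas, rf.full_replicas, rf.transient_replicas)   # (3, 3, None)

    rf = ReplicationFactor.create('3/1')
    (rf.all_replicas, rf.full_replicas, rf.transient_replicas)   # (3, 2, 1)
    str(rf)                                                      # '3/1', the form used by export_for_schema()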
class SimpleStrategy(ReplicationStrategy):
- replication_factor = None
+ replication_factor_info = None
"""
- The replication factor for this keyspace.
+ A :class:`cassandra.metadata.ReplicationFactor` instance.
"""
+ @property
+ def replication_factor(self):
+ """
+ The replication factor for this keyspace.
+
+ For backward compatibility, this returns the
+ :attr:`cassandra.metadata.ReplicationFactor.full_replicas` value of
+ :attr:`cassandra.metadata.SimpleStrategy.replication_factor_info`.
+ """
+ return self.replication_factor_info.full_replicas
+
def __init__(self, options_map):
- try:
- self.replication_factor = int(options_map['replication_factor'])
- except Exception:
- raise ValueError("SimpleStrategy requires an integer 'replication_factor' option")
+ self.replication_factor_info = ReplicationFactor.create(options_map['replication_factor'])
def make_token_replica_map(self, token_to_host_owner, ring):
replica_map = {}
@@ -482,30 +547,41 @@ def export_for_schema(self):
Returns a string version of these replication options which are
suitable for use in a CREATE KEYSPACE statement.
"""
- return "{'class': 'SimpleStrategy', 'replication_factor': '%d'}" \
- % (self.replication_factor,)
+ return "{'class': 'SimpleStrategy', 'replication_factor': '%s'}" \
+ % (str(self.replication_factor_info),)
def __eq__(self, other):
if not isinstance(other, SimpleStrategy):
return False
- return self.replication_factor == other.replication_factor
+ return str(self.replication_factor_info) == str(other.replication_factor_info)
class NetworkTopologyStrategy(ReplicationStrategy):
+ dc_replication_factors_info = None
+ """
+ A map of datacenter names to the :class:`cassandra.metadata.ReplicationFactor` instance for that DC.
+ """
+
dc_replication_factors = None
"""
A map of datacenter names to the replication factor for that DC.
+
+ For backward compatibility, this maps to the :attr:`cassandra.metadata.ReplicationFactor.full_replicas`
+ value of the :attr:`cassandra.metadata.NetworkTopologyStrategy.dc_replication_factors_info` dict.
"""
def __init__(self, dc_replication_factors):
+ self.dc_replication_factors_info = dict(
+ (str(k), ReplicationFactor.create(v)) for k, v in dc_replication_factors.items())
self.dc_replication_factors = dict(
- (str(k), int(v)) for k, v in dc_replication_factors.items())
+ (dc, rf.full_replicas) for dc, rf in self.dc_replication_factors_info.items())
def make_token_replica_map(self, token_to_host_owner, ring):
- dc_rf_map = dict((dc, int(rf))
- for dc, rf in self.dc_replication_factors.items() if rf > 0)
+ dc_rf_map = dict(
+ (dc, full_replicas) for dc, full_replicas in self.dc_replication_factors.items()
+ if full_replicas > 0)
# build a map of DCs to lists of indexes into `ring` for tokens that
# belong to that DC
@@ -548,7 +624,7 @@ def make_token_replica_map(self, token_to_host_owner, ring):
racks_this_dc = dc_racks[dc]
hosts_this_dc = len(hosts_per_dc[dc])
- for token_offset_index in six.moves.range(index, index+num_tokens):
+ for token_offset_index in range(index, index+num_tokens):
if token_offset_index >= len(token_offsets):
token_offset_index = token_offset_index - len(token_offsets)
@@ -585,15 +661,15 @@ def export_for_schema(self):
suitable for use in a CREATE KEYSPACE statement.
"""
ret = "{'class': 'NetworkTopologyStrategy'"
- for dc, repl_factor in sorted(self.dc_replication_factors.items()):
- ret += ", '%s': '%d'" % (dc, repl_factor)
+ for dc, rf in sorted(self.dc_replication_factors_info.items()):
+ ret += ", '%s': '%s'" % (dc, str(rf))
return ret + "}"
def __eq__(self, other):
if not isinstance(other, NetworkTopologyStrategy):
return False
- return self.dc_replication_factors == other.dc_replication_factors
+ return self.dc_replication_factors_info == other.dc_replication_factors_info
class LocalStrategy(ReplicationStrategy):
@@ -677,10 +753,15 @@ class KeyspaceMetadata(object):
.. versionadded:: 3.15
"""
+ graph_engine = None
+ """
+ The graph engine enabled for this keyspace, if any ('Core' or 'Classic').
+ """
+
_exc_info = None
""" set if metadata parsing failed """
- def __init__(self, name, durable_writes, strategy_class, strategy_options):
+ def __init__(self, name, durable_writes, strategy_class, strategy_options, graph_engine=None):
self.name = name
self.durable_writes = durable_writes
self.replication_strategy = ReplicationStrategy.create(strategy_class, strategy_options)
@@ -690,17 +771,28 @@ def __init__(self, name, durable_writes, strategy_class, strategy_options):
self.functions = {}
self.aggregates = {}
self.views = {}
+ self.graph_engine = graph_engine
+
+ @property
+ def is_graph_enabled(self):
+ return self.graph_engine is not None
def export_as_string(self):
"""
Returns a CQL query string that can be used to recreate the entire keyspace,
including user-defined types and tables.
"""
- cql = "\n\n".join([self.as_cql_query() + ';'] +
- self.user_type_strings() +
- [f.export_as_string() for f in self.functions.values()] +
- [a.export_as_string() for a in self.aggregates.values()] +
- [t.export_as_string() for t in self.tables.values()])
+ # Make sure tables with vertex are exported before tables with edges
+ tables_with_vertex = [t for t in self.tables.values() if hasattr(t, 'vertex') and t.vertex]
+ other_tables = [t for t in self.tables.values() if t not in tables_with_vertex]
+
+ cql = "\n\n".join(
+ [self.as_cql_query() + ';'] +
+ self.user_type_strings() +
+ [f.export_as_string() for f in self.functions.values()] +
+ [a.export_as_string() for a in self.aggregates.values()] +
+ [t.export_as_string() for t in tables_with_vertex + other_tables])
+
if self._exc_info:
import traceback
ret = "/*\nWarning: Keyspace %s is incomplete because of an error processing metadata.\n" % \
@@ -726,7 +818,10 @@ def as_cql_query(self):
ret = "CREATE KEYSPACE %s WITH replication = %s " % (
protect_name(self.name),
self.replication_strategy.export_for_schema())
- return ret + (' AND durable_writes = %s' % ("true" if self.durable_writes else "false"))
+ ret = ret + (' AND durable_writes = %s' % ("true" if self.durable_writes else "false"))
+ if self.graph_engine is not None:
+ ret = ret + (" AND graph_engine = '%s'" % self.graph_engine)
+ return ret
def user_type_strings(self):
user_type_strings = []
@@ -756,7 +851,7 @@ def _add_table_metadata(self, table_metadata):
# note the intentional order of add before remove
# this makes sure the maps are never absent something that existed before this update
- for index_name, index_metadata in six.iteritems(table_metadata.indexes):
+ for index_name, index_metadata in table_metadata.indexes.items():
self.indexes[index_name] = index_metadata
for index_name in (n for n in old_indexes if n not in table_metadata.indexes):
@@ -1243,7 +1338,7 @@ def _all_as_cql(self):
if self.extensions:
registry = _RegisteredExtensionType._extension_registry
- for k in six.viewkeys(registry) & self.extensions: # no viewkeys on OrderedMapSerializeKey
+ for k in registry.keys() & self.extensions: # no viewkeys on OrderedMapSerializeKey
ext = registry[k]
cql = ext.after_table_cql(self, k, self.extensions[k])
if cql:
@@ -1351,6 +1446,89 @@ def _make_option_strings(cls, options_map):
return list(sorted(ret))
+class TableMetadataV3(TableMetadata):
+ """
+ For C* 3.0+. `option_maps` take a superset of map names, so if nothing
+ changes structurally, new option maps can just be appended to the list.
+ """
+ compaction_options = {}
+
+ option_maps = [
+ 'compaction', 'compression', 'caching',
+ 'nodesync' # added DSE 6.0
+ ]
+
+ @property
+ def is_cql_compatible(self):
+ return True
+
+ @classmethod
+ def _make_option_strings(cls, options_map):
+ ret = []
+ options_copy = dict(options_map.items())
+
+ for option in cls.option_maps:
+ value = options_copy.get(option)
+ if isinstance(value, Mapping):
+ del options_copy[option]
+ params = ("'%s': '%s'" % (k, v) for k, v in value.items())
+ ret.append("%s = {%s}" % (option, ', '.join(params)))
+
+ for name, value in options_copy.items():
+ if value is not None:
+ if name == "comment":
+ value = value or ""
+ ret.append("%s = %s" % (name, protect_value(value)))
+
+ return list(sorted(ret))
+
+
+class TableMetadataDSE68(TableMetadataV3):
+
+ vertex = None
+ """A :class:`.VertexMetadata` instance, if graph enabled"""
+
+ edge = None
+ """A :class:`.EdgeMetadata` instance, if graph enabled"""
+
+ def as_cql_query(self, formatted=False):
+ ret = super(TableMetadataDSE68, self).as_cql_query(formatted)
+
+ if self.vertex:
+ ret += " AND VERTEX LABEL %s" % protect_name(self.vertex.label_name)
+
+ if self.edge:
+ ret += " AND EDGE LABEL %s" % protect_name(self.edge.label_name)
+
+ ret += self._export_edge_as_cql(
+ self.edge.from_label,
+ self.edge.from_partition_key_columns,
+ self.edge.from_clustering_columns, "FROM")
+
+ ret += self._export_edge_as_cql(
+ self.edge.to_label,
+ self.edge.to_partition_key_columns,
+ self.edge.to_clustering_columns, "TO")
+
+ return ret
+
+ @staticmethod
+ def _export_edge_as_cql(label_name, partition_keys,
+ clustering_columns, keyword):
+ ret = " %s %s(" % (keyword, protect_name(label_name))
+
+ if len(partition_keys) == 1:
+ ret += protect_name(partition_keys[0])
+ else:
+ ret += "(%s)" % ", ".join([protect_name(k) for k in partition_keys])
+
+ if clustering_columns:
+ ret += ", %s" % ", ".join([protect_name(k) for k in clustering_columns])
+ ret += ")"
+
+ return ret
+
+
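With made-up label and column names, the helper above produces an edge clause of this shape:

    TableMetadataDSE68._export_edge_as_cql('person', ['person_id'], ['ts'], 'FROM')
    # -> " FROM person(person_id, ts)"

With more than one partition key, the key list is additionally wrapped in parentheses, e.g. "(person_id, company_id)".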
class TableExtensionInterface(object):
"""
Defines CQL/DDL for Cassandra table extensions.
@@ -1376,8 +1554,7 @@ def __new__(mcs, name, bases, dct):
return cls
-@six.add_metaclass(_RegisteredExtensionType)
-class RegisteredTableExtension(TableExtensionInterface):
+class RegisteredTableExtension(TableExtensionInterface, metaclass=_RegisteredExtensionType):
"""
Extending this class registers it by name (associated by key in the `system_schema.tables.extensions` map).
"""
@@ -1683,7 +1860,7 @@ class MD5Token(HashToken):
@classmethod
def hash_fn(cls, key):
- if isinstance(key, six.text_type):
+ if isinstance(key, str):
key = key.encode('UTF-8')
return abs(varint_unpack(md5(key).digest()))
@@ -1697,7 +1874,7 @@ class BytesToken(Token):
def from_string(cls, token_string):
""" `token_string` should be the string representation from the server. """
# unhexlify works fine with unicode input in everything but pypy3, where it raises "TypeError: 'str' does not support the buffer interface"
- if isinstance(token_string, six.text_type):
+ if isinstance(token_string, str):
token_string = token_string.encode('ascii')
# The BOP stores a hex string
return cls(unhexlify(token_string))
@@ -2312,6 +2489,8 @@ class SchemaParserV3(SchemaParserV22):
_function_agg_arument_type_col = 'argument_types'
+ _table_metadata_class = TableMetadataV3
+
recognized_table_options = (
'bloom_filter_fp_chance',
'caching',
@@ -2395,7 +2574,7 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_row
trigger_rows = trigger_rows or self.keyspace_table_trigger_rows[keyspace_name][table_name]
index_rows = index_rows or self.keyspace_table_index_rows[keyspace_name][table_name]
- table_meta = TableMetadataV3(keyspace_name, table_name, virtual=virtual)
+ table_meta = self._table_metadata_class(keyspace_name, table_name, virtual=virtual)
try:
table_meta.options = self._build_table_options(row)
flags = row.get('flags', set())
@@ -2659,15 +2838,15 @@ def _query_all(self):
# ignore them if we got an error
self.virtual_keyspaces_result = self._handle_results(
virtual_ks_success, virtual_ks_result,
- expected_failures=InvalidRequest
+ expected_failures=(InvalidRequest,)
)
self.virtual_tables_result = self._handle_results(
virtual_table_success, virtual_table_result,
- expected_failures=InvalidRequest
+ expected_failures=(InvalidRequest,)
)
self.virtual_columns_result = self._handle_results(
virtual_column_success, virtual_column_result,
- expected_failures=InvalidRequest
+ expected_failures=(InvalidRequest,)
)
self._aggregate_results()
@@ -2722,41 +2901,196 @@ class SchemaParserDSE67(SchemaParserV4):
("nodesync",))
-class TableMetadataV3(TableMetadata):
+class SchemaParserDSE68(SchemaParserDSE67):
"""
- For C* 3.0+. `option_maps` take a superset of map names, so if nothing
- changes structurally, new option maps can just be appended to the list.
+ For DSE 6.8+
"""
- compaction_options = {}
- option_maps = [
- 'compaction', 'compression', 'caching',
- 'nodesync' # added DSE 6.0
- ]
+ _SELECT_VERTICES = "SELECT * FROM system_schema.vertices"
+ _SELECT_EDGES = "SELECT * FROM system_schema.edges"
- @property
- def is_cql_compatible(self):
- return True
+ _table_metadata_class = TableMetadataDSE68
- @classmethod
- def _make_option_strings(cls, options_map):
- ret = []
- options_copy = dict(options_map.items())
+ def __init__(self, connection, timeout):
+ super(SchemaParserDSE68, self).__init__(connection, timeout)
+ self.keyspace_table_vertex_rows = defaultdict(lambda: defaultdict(list))
+ self.keyspace_table_edge_rows = defaultdict(lambda: defaultdict(list))
- for option in cls.option_maps:
- value = options_copy.get(option)
- if isinstance(value, Mapping):
- del options_copy[option]
- params = ("'%s': '%s'" % (k, v) for k, v in value.items())
- ret.append("%s = {%s}" % (option, ', '.join(params)))
+ def get_all_keyspaces(self):
+ for keyspace_meta in super(SchemaParserDSE68, self).get_all_keyspaces():
+ self._build_graph_metadata(keyspace_meta)
+ yield keyspace_meta
- for name, value in options_copy.items():
- if value is not None:
- if name == "comment":
- value = value or ""
- ret.append("%s = %s" % (name, protect_value(value)))
+ def get_table(self, keyspaces, keyspace, table):
+ table_meta = super(SchemaParserDSE68, self).get_table(keyspaces, keyspace, table)
+ cl = ConsistencyLevel.ONE
+ where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col), (keyspace, table), _encoder)
+ vertices_query = QueryMessage(query=self._SELECT_VERTICES + where_clause, consistency_level=cl)
+ edges_query = QueryMessage(query=self._SELECT_EDGES + where_clause, consistency_level=cl)
- return list(sorted(ret))
+ (vertices_success, vertices_result), (edges_success, edges_result) \
+ = self.connection.wait_for_responses(vertices_query, edges_query, timeout=self.timeout, fail_on_error=False)
+ vertices_result = self._handle_results(vertices_success, vertices_result)
+ edges_result = self._handle_results(edges_success, edges_result)
+
+ try:
+ if vertices_result:
+ table_meta.vertex = self._build_table_vertex_metadata(vertices_result[0])
+ elif edges_result:
+ table_meta.edge = self._build_table_edge_metadata(keyspaces[keyspace], edges_result[0])
+ except Exception:
+ table_meta.vertex = None
+ table_meta.edge = None
+ table_meta._exc_info = sys.exc_info()
+ log.exception("Error while parsing graph metadata for table %s.%s.", keyspace, table)
+
+ return table_meta
+
+ @staticmethod
+ def _build_keyspace_metadata_internal(row):
+ name = row["keyspace_name"]
+ durable_writes = row.get("durable_writes", None)
+ replication = dict(row.get("replication")) if 'replication' in row else {}
+ replication_class = replication.pop("class") if 'class' in replication else None
+ graph_engine = row.get("graph_engine", None)
+ return KeyspaceMetadata(name, durable_writes, replication_class, replication, graph_engine)
+
+ def _build_graph_metadata(self, keyspace_meta):
+
+ def _build_table_graph_metadata(table_meta):
+ for row in self.keyspace_table_vertex_rows[keyspace_meta.name][table_meta.name]:
+ table_meta.vertex = self._build_table_vertex_metadata(row)
+
+ for row in self.keyspace_table_edge_rows[keyspace_meta.name][table_meta.name]:
+ table_meta.edge = self._build_table_edge_metadata(keyspace_meta, row)
+
+ try:
+ # Make sure we process vertices before edges
+ for table_meta in [t for t in keyspace_meta.tables.values()
+ if t.name in self.keyspace_table_vertex_rows[keyspace_meta.name]]:
+ _build_table_graph_metadata(table_meta)
+
+ # all other tables...
+ for table_meta in [t for t in keyspace_meta.tables.values()
+ if t.name not in self.keyspace_table_vertex_rows[keyspace_meta.name]]:
+ _build_table_graph_metadata(table_meta)
+ except Exception:
+ # schema error, remove all graph metadata for this keyspace
+ for t in keyspace_meta.tables.values():
+ t.edge = t.vertex = None
+ keyspace_meta._exc_info = sys.exc_info()
+ log.exception("Error while parsing graph metadata for keyspace %s", keyspace_meta.name)
+
+ @staticmethod
+ def _build_table_vertex_metadata(row):
+ return VertexMetadata(row.get("keyspace_name"), row.get("table_name"),
+ row.get("label_name"))
+
+ @staticmethod
+ def _build_table_edge_metadata(keyspace_meta, row):
+ from_table = row.get("from_table")
+ from_table_meta = keyspace_meta.tables.get(from_table)
+ from_label = from_table_meta.vertex.label_name
+ to_table = row.get("to_table")
+ to_table_meta = keyspace_meta.tables.get(to_table)
+ to_label = to_table_meta.vertex.label_name
+
+ return EdgeMetadata(
+ row.get("keyspace_name"), row.get("table_name"),
+ row.get("label_name"), from_table, from_label,
+ row.get("from_partition_key_columns"),
+ row.get("from_clustering_columns"), to_table, to_label,
+ row.get("to_partition_key_columns"),
+ row.get("to_clustering_columns"))
+
+ def _query_all(self):
+ cl = ConsistencyLevel.ONE
+ queries = [
+ # copied from v4
+ QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_TABLES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl),
+ QueryMessage(query=self._SELECT_TYPES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl),
+ QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl),
+ QueryMessage(query=self._SELECT_INDEXES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_VIEWS, consistency_level=cl),
+ QueryMessage(query=self._SELECT_VIRTUAL_KEYSPACES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_VIRTUAL_TABLES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_VIRTUAL_COLUMNS, consistency_level=cl),
+ # dse6.8 only
+ QueryMessage(query=self._SELECT_VERTICES, consistency_level=cl),
+ QueryMessage(query=self._SELECT_EDGES, consistency_level=cl)
+ ]
+
+ responses = self.connection.wait_for_responses(
+ *queries, timeout=self.timeout, fail_on_error=False)
+ (
+ # copied from V4
+ (ks_success, ks_result),
+ (table_success, table_result),
+ (col_success, col_result),
+ (types_success, types_result),
+ (functions_success, functions_result),
+ (aggregates_success, aggregates_result),
+ (triggers_success, triggers_result),
+ (indexes_success, indexes_result),
+ (views_success, views_result),
+ (virtual_ks_success, virtual_ks_result),
+ (virtual_table_success, virtual_table_result),
+ (virtual_column_success, virtual_column_result),
+ # dse6.8 responses
+ (vertices_success, vertices_result),
+ (edges_success, edges_result)
+ ) = responses
+
+ # copied from V4
+ self.keyspaces_result = self._handle_results(ks_success, ks_result)
+ self.tables_result = self._handle_results(table_success, table_result)
+ self.columns_result = self._handle_results(col_success, col_result)
+ self.triggers_result = self._handle_results(triggers_success, triggers_result)
+ self.types_result = self._handle_results(types_success, types_result)
+ self.functions_result = self._handle_results(functions_success, functions_result)
+ self.aggregates_result = self._handle_results(aggregates_success, aggregates_result)
+ self.indexes_result = self._handle_results(indexes_success, indexes_result)
+ self.views_result = self._handle_results(views_success, views_result)
+
+ # These tables don't exist in some DSE versions reporting 4.X so we can
+ # ignore them if we got an error
+ self.virtual_keyspaces_result = self._handle_results(
+ virtual_ks_success, virtual_ks_result,
+ expected_failures=(InvalidRequest,)
+ )
+ self.virtual_tables_result = self._handle_results(
+ virtual_table_success, virtual_table_result,
+ expected_failures=(InvalidRequest,)
+ )
+ self.virtual_columns_result = self._handle_results(
+ virtual_column_success, virtual_column_result,
+ expected_failures=(InvalidRequest,)
+ )
+
+ # dse6.8-only results
+ self.vertices_result = self._handle_results(vertices_success, vertices_result)
+ self.edges_result = self._handle_results(edges_success, edges_result)
+
+ self._aggregate_results()
+
+ def _aggregate_results(self):
+ super(SchemaParserDSE68, self)._aggregate_results()
+
+ m = self.keyspace_table_vertex_rows
+ for row in self.vertices_result:
+ ksname = row["keyspace_name"]
+ cfname = row['table_name']
+ m[ksname][cfname].append(row)
+
+ m = self.keyspace_table_edge_rows
+ for row in self.edges_result:
+ ksname = row["keyspace_name"]
+ cfname = row['table_name']
+ m[ksname][cfname].append(row)
class MaterializedViewMetadata(object):
@@ -2765,8 +3099,7 @@ class MaterializedViewMetadata(object):
"""
keyspace_name = None
-
- """ A string name of the view."""
+ """ A string name of the keyspace of this view."""
name = None
""" A string name of the view."""
@@ -2857,7 +3190,7 @@ def as_cql_query(self, formatted=False):
if self.extensions:
registry = _RegisteredExtensionType._extension_registry
- for k in six.viewkeys(registry) & self.extensions: # no viewkeys on OrderedMapSerializeKey
+ for k in registry.keys() & self.extensions: # no viewkeys on OrderedMapSerializeKey
ext = registry[k]
cql = ext.after_table_cql(self, k, self.extensions[k])
if cql:
@@ -2868,11 +3201,89 @@ def export_as_string(self):
return self.as_cql_query(formatted=True) + ";"
+class VertexMetadata(object):
+ """
+ A representation of a vertex on a table
+ """
+
+ keyspace_name = None
+ """ A string name of the keyspace. """
+
+ table_name = None
+ """ A string name of the table this vertex is on. """
+
+ label_name = None
+ """ A string name of the label of this vertex."""
+
+ def __init__(self, keyspace_name, table_name, label_name):
+ self.keyspace_name = keyspace_name
+ self.table_name = table_name
+ self.label_name = label_name
+
+
+class EdgeMetadata(object):
+ """
+ A representation of an edge on a table
+ """
+
+ keyspace_name = None
+ """A string name of the keyspace """
+
+ table_name = None
+ """A string name of the table this edge is on"""
+
+ label_name = None
+ """A string name of the label of this edge"""
+
+ from_table = None
+ """A string name of the from table of this edge (incoming vertex)"""
+
+ from_label = None
+ """A string name of the from table label of this edge (incoming vertex)"""
+
+ from_partition_key_columns = None
+ """The columns that match the partition key of the incoming vertex table."""
+
+ from_clustering_columns = None
+ """The columns that match the clustering columns of the incoming vertex table."""
+
+ to_table = None
+ """A string name of the to table of this edge (outgoing vertex)"""
+
+ to_label = None
+ """A string name of the to table label of this edge (outgoing vertex)"""
+
+ to_partition_key_columns = None
+ """The columns that match the partition key of the outgoing vertex table."""
+
+ to_clustering_columns = None
+ """The columns that match the clustering columns of the outgoing vertex table."""
+
+ def __init__(
+ self, keyspace_name, table_name, label_name, from_table,
+ from_label, from_partition_key_columns, from_clustering_columns,
+ to_table, to_label, to_partition_key_columns,
+ to_clustering_columns):
+ self.keyspace_name = keyspace_name
+ self.table_name = table_name
+ self.label_name = label_name
+ self.from_table = from_table
+ self.from_label = from_label
+ self.from_partition_key_columns = from_partition_key_columns
+ self.from_clustering_columns = from_clustering_columns
+ self.to_table = to_table
+ self.to_label = to_label
+ self.to_partition_key_columns = to_partition_key_columns
+ self.to_clustering_columns = to_clustering_columns
+
+
def get_schema_parser(connection, server_version, dse_version, timeout):
version = Version(server_version)
if dse_version:
v = Version(dse_version)
- if v >= Version('6.7.0'):
+ if v >= Version('6.8.0'):
+ return SchemaParserDSE68(connection, timeout)
+ elif v >= Version('6.7.0'):
return SchemaParserDSE67(connection, timeout)
elif v >= Version('6.0.0'):
return SchemaParserDSE60(connection, timeout)
@@ -2954,3 +3365,48 @@ def group_keys_by_replica(session, keyspace, table, keys):
return dict(keys_per_host)
+
+# TODO next major reorg
+class _NodeInfo(object):
+ """
+ Internal utility functions to determine the different host addresses/ports
+ from a local or peers row.
+ """
+
+ @staticmethod
+ def get_broadcast_rpc_address(row):
+ # TODO next major, change the parsing logic to avoid any
+ # overriding of a non-null value
+ addr = row.get("rpc_address")
+ if "native_address" in row:
+ addr = row.get("native_address")
+ if "native_transport_address" in row:
+ addr = row.get("native_transport_address")
+ if not addr or addr in ["0.0.0.0", "::"]:
+ addr = row.get("peer")
+
+ return addr
+
+ @staticmethod
+ def get_broadcast_rpc_port(row):
+ port = row.get("rpc_port")
+ if port is None or port == 0:
+ port = row.get("native_port")
+
+ return port if port and port > 0 else None
+
+ @staticmethod
+ def get_broadcast_address(row):
+ addr = row.get("broadcast_address")
+ if addr is None:
+ addr = row.get("peer")
+
+ return addr
+
+ @staticmethod
+ def get_broadcast_port(row):
+ port = row.get("broadcast_port")
+ if port is None or port == 0:
+ port = row.get("peer_port")
+
+ return port if port and port > 0 else None
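To make the precedence concrete, a fabricated peers-style row and the values the helpers above derive from it:

    row = {'rpc_address': '0.0.0.0', 'native_address': '10.0.0.5',
           'native_port': 9042, 'peer': '10.0.0.5'}
    _NodeInfo.get_broadcast_rpc_address(row)   # '10.0.0.5' (native_address overrides the unbound rpc_address)
    _NodeInfo.get_broadcast_rpc_port(row)      # 9042 (no rpc_port in the row, so native_port is used)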
diff --git a/cassandra/murmur3.py b/cassandra/murmur3.py
index 7c8d641b32..282c43578d 100644
--- a/cassandra/murmur3.py
+++ b/cassandra/murmur3.py
@@ -1,4 +1,3 @@
-from six.moves import range
import struct
diff --git a/cassandra/obj_parser.pyx b/cassandra/obj_parser.pyx
index a0b5316a33..cf43771dd7 100644
--- a/cassandra/obj_parser.pyx
+++ b/cassandra/obj_parser.pyx
@@ -17,9 +17,12 @@ include "ioutils.pyx"
from cassandra import DriverException
from cassandra.bytesio cimport BytesIOReader
from cassandra.deserializers cimport Deserializer, from_binary
+from cassandra.deserializers import find_deserializer
from cassandra.parsing cimport ParseDesc, ColumnParser, RowParser
from cassandra.tuple cimport tuple_new, tuple_set
+from cpython.bytes cimport PyBytes_AsStringAndSize
+
cdef class ListParser(ColumnParser):
"""Decode a ResultMessage into a list of tuples (or other objects)"""
@@ -58,18 +61,29 @@ cdef class TupleRowParser(RowParser):
assert desc.rowsize >= 0
cdef Buffer buf
+ cdef Buffer newbuf
cdef Py_ssize_t i, rowsize = desc.rowsize
cdef Deserializer deserializer
cdef tuple res = tuple_new(desc.rowsize)
+ ce_policy = desc.column_encryption_policy
for i in range(rowsize):
# Read the next few bytes
get_buf(reader, &buf)
# Deserialize bytes to python object
deserializer = desc.deserializers[i]
+ coldesc = desc.coldescs[i]
+ uses_ce = ce_policy and ce_policy.contains_column(coldesc)
try:
- val = from_binary(deserializer, &buf, desc.protocol_version)
+ if uses_ce:
+ col_type = ce_policy.column_type(coldesc)
+ decrypted_bytes = ce_policy.decrypt(coldesc, to_bytes(&buf))
+ PyBytes_AsStringAndSize(decrypted_bytes, &newbuf.ptr, &newbuf.size)
+ deserializer = find_deserializer(ce_policy.column_type(coldesc))
+ val = from_binary(deserializer, &newbuf, desc.protocol_version)
+ else:
+ val = from_binary(deserializer, &buf, desc.protocol_version)
except Exception as e:
raise DriverException('Failed decoding result column "%s" of type %s: %s' % (desc.colnames[i],
desc.coltypes[i].cql_parameterized_type(),
diff --git a/cassandra/parsing.pxd b/cassandra/parsing.pxd
index aa9478cd14..27dc368b07 100644
--- a/cassandra/parsing.pxd
+++ b/cassandra/parsing.pxd
@@ -18,6 +18,8 @@ from cassandra.deserializers cimport Deserializer
cdef class ParseDesc:
cdef public object colnames
cdef public object coltypes
+ cdef public object column_encryption_policy
+ cdef public list coldescs
cdef Deserializer[::1] deserializers
cdef public int protocol_version
cdef Py_ssize_t rowsize
diff --git a/cassandra/parsing.pyx b/cassandra/parsing.pyx
index d2bc0a3abe..954767d227 100644
--- a/cassandra/parsing.pyx
+++ b/cassandra/parsing.pyx
@@ -19,9 +19,11 @@ Module containing the definitions and declarations (parsing.pxd) for parsers.
cdef class ParseDesc:
"""Description of what structure to parse"""
- def __init__(self, colnames, coltypes, deserializers, protocol_version):
+ def __init__(self, colnames, coltypes, column_encryption_policy, coldescs, deserializers, protocol_version):
self.colnames = colnames
self.coltypes = coltypes
+ self.column_encryption_policy = column_encryption_policy
+ self.coldescs = coldescs
self.deserializers = deserializers
self.protocol_version = protocol_version
self.rowsize = len(colnames)
diff --git a/cassandra/policies.py b/cassandra/policies.py
index fa1e8cf385..c60e558465 100644
--- a/cassandra/policies.py
+++ b/cassandra/policies.py
@@ -12,26 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from collections import namedtuple
+from functools import lru_cache
from itertools import islice, cycle, groupby, repeat
import logging
from random import randint, shuffle
from threading import Lock
import socket
import warnings
-from cassandra import WriteType as WT
+log = logging.getLogger(__name__)
+
+from cassandra import WriteType as WT
# This is done this way because WriteType was originally
# defined here and in order not to break the API.
# It may be removed in the next major version.
WriteType = WT
-
from cassandra import ConsistencyLevel, OperationTimedOut
-log = logging.getLogger(__name__)
-
-
class HostDistance(object):
"""
A measure of how "distant" a node is from the client, which
@@ -455,7 +455,7 @@ class HostFilterPolicy(LoadBalancingPolicy):
A :class:`.LoadBalancingPolicy` subclass configured with a child policy,
and a single-argument predicate. This policy defers to the child policy for
hosts where ``predicate(host)`` is truthy. Hosts for which
- ``predicate(host)`` is falsey will be considered :attr:`.IGNORED`, and will
+ ``predicate(host)`` is falsy will be considered :attr:`.IGNORED`, and will
not be used in a query plan.
This can be used in the cases where you need a whitelist or blacklist
@@ -491,7 +491,7 @@ def __init__(self, child_policy, predicate):
:param child_policy: an instantiated :class:`.LoadBalancingPolicy`
that this one will defer to.
:param predicate: a one-parameter function that takes a :class:`.Host`.
- If it returns a falsey value, the :class:`.Host` will
+ If it returns a falsy value, the :class:`.Host` will
be :attr:`.IGNORED` and not returned in query plans.
"""
super(HostFilterPolicy, self).__init__()
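A typical whitelist-style setup, sketched with an assumed datacenter name 'dc1':

    from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
    from cassandra.policies import HostFilterPolicy, RoundRobinPolicy

    policy = HostFilterPolicy(
        child_policy=RoundRobinPolicy(),
        predicate=lambda host: host.datacenter == 'dc1',   # hosts outside dc1 are IGNORED
    )
    profile = ExecutionProfile(load_balancing_policy=policy)
    cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: profile})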
@@ -527,7 +527,7 @@ def predicate(self):
def distance(self, host):
"""
Checks if ``predicate(host)``, then returns
- :attr:`~HostDistance.IGNORED` if falsey, and defers to the child policy
+ :attr:`~HostDistance.IGNORED` if falsy, and defers to the child policy
otherwise.
"""
if self.predicate(host):
@@ -616,7 +616,7 @@ class ReconnectionPolicy(object):
def new_schedule(self):
"""
This should return a finite or infinite iterable of delays (each as a
- floating point number of seconds) inbetween each failed reconnection
+ floating point number of seconds) in-between each failed reconnection
attempt. Note that if the iterable is finite, reconnection attempts
will cease once the iterable is exhausted.
"""
@@ -626,12 +626,12 @@ def new_schedule(self):
class ConstantReconnectionPolicy(ReconnectionPolicy):
"""
A :class:`.ReconnectionPolicy` subclass which sleeps for a fixed delay
- inbetween each reconnection attempt.
+ in-between each reconnection attempt.
"""
def __init__(self, delay, max_attempts=64):
"""
- `delay` should be a floating point number of seconds to wait inbetween
+ `delay` should be a floating point number of seconds to wait in-between
each attempt.
`max_attempts` should be a total number of attempts to be made before
@@ -655,7 +655,7 @@ def new_schedule(self):
class ExponentialReconnectionPolicy(ReconnectionPolicy):
"""
A :class:`.ReconnectionPolicy` subclass which exponentially increases
- the length of the delay inbetween each reconnection attempt up to
+ the length of the delay in-between each reconnection attempt up to
a set maximum delay.
A random amount of jitter (+/- 15%) will be added to the pure exponential
@@ -715,7 +715,7 @@ class RetryPolicy(object):
timeout and unavailable failures. These are failures reported from the
server side. Timeouts are configured by
`settings in cassandra.yaml `_.
- Unavailable failures occur when the coordinator cannot acheive the consistency
+ Unavailable failures occur when the coordinator cannot achieve the consistency
level for a request. For further information see the method descriptions
below.
@@ -865,7 +865,7 @@ def on_request_error(self, query, consistency, error, retry_num):
`retry_num` counts how many times the operation has been retried, so
the first time this method is called, `retry_num` will be 0.
- The default, it triggers a retry on the next host in the query plan
+ By default, it triggers a retry on the next host in the query plan
with the same consistency level.
"""
# TODO revisit this for the next major
@@ -1181,3 +1181,62 @@ def _rethrow(self, *args, **kwargs):
on_read_timeout = _rethrow
on_write_timeout = _rethrow
on_unavailable = _rethrow
+
+
+ColDesc = namedtuple('ColDesc', ['ks', 'table', 'col'])
+
+class ColumnEncryptionPolicy(object):
+ """
+ A policy enabling (mostly) transparent encryption and decryption of data before it is
+ sent to the cluster.
+
+ Key materials and other configurations are specified on a per-column basis. This policy can
+ then be used by driver structures which are aware of the underlying columns involved in their
+ work. In practice this includes the following cases:
+
+ * Prepared statements - data for columns specified by the cluster's policy will be transparently
+ encrypted before they are sent
+ * Rows returned from any query - data for columns specified by the cluster's policy will be
+ transparently decrypted before they are returned to the user
+
+ To enable this functionality, create an instance of this class (or more likely a subclass)
+ before creating a cluster. This policy should then be configured and supplied to the Cluster
+ at creation time via the :attr:`.Cluster.column_encryption_policy` attribute.
+ """
+
+ def encrypt(self, coldesc, obj_bytes):
+ """
+ Encrypt the specified bytes using the cryptography materials for the specified column.
+ Largely used internally, although this could also be used to encrypt values supplied
+ to non-prepared statements in a way that is consistent with this policy.
+ """
+ raise NotImplementedError()
+
+ def decrypt(self, coldesc, encrypted_bytes):
+ """
+ Decrypt the specified (encrypted) bytes using the cryptography materials for the
+ specified column. Used internally; could be used externally as well but there's
+ not currently an obvious use case.
+ """
+ raise NotImplementedError()
+
+ def add_column(self, coldesc, key):
+ """
+ Provide cryptography materials to be used when encrypting and/or decrypting data
+ for the specified column.
+ """
+ raise NotImplementedError()
+
+ def contains_column(self, coldesc):
+ """
+ Predicate to determine if a specific column is supported by this policy.
+ Currently only used internally.
+ """
+ raise NotImplementedError()
+
+ def encode_and_encrypt(self, coldesc, obj):
+ """
+ Helper function to enable use of this policy on simple (i.e. non-prepared)
+ statements.
+ """
+ raise NotImplementedError()
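To show the shape of the interface (this is not the concrete policy the driver ships; the class name and key below are illustrative), a minimal no-op subclass could look like:

    class NoOpColumnEncryptionPolicy(ColumnEncryptionPolicy):
        # Illustrative only: remembers keys but performs no real cryptography.
        def __init__(self):
            self._columns = {}

        def add_column(self, coldesc, key):
            self._columns[coldesc] = key

        def contains_column(self, coldesc):
            return coldesc in self._columns

        def encrypt(self, coldesc, obj_bytes):
            return obj_bytes          # a real policy would encrypt with self._columns[coldesc]

        def decrypt(self, coldesc, encrypted_bytes):
            return encrypted_bytes    # ...and decrypt here

    policy = NoOpColumnEncryptionPolicy()
    policy.add_column(ColDesc('ks', 'users', 'ssn'), key=b'\x00' * 32)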
diff --git a/cassandra/pool.py b/cassandra/pool.py
index a4429aeed6..d61e81cd0d 100644
--- a/cassandra/pool.py
+++ b/cassandra/pool.py
@@ -55,21 +55,60 @@ class Host(object):
broadcast_address = None
"""
- broadcast address configured for the node, *if available* ('peer' in system.peers table).
- This is not present in the ``system.local`` table for older versions of Cassandra. It is also not queried if
- :attr:`~.Cluster.token_metadata_enabled` is ``False``.
+ broadcast address configured for the node, *if available*:
+
+ 'system.local.broadcast_address' or 'system.peers.peer' (Cassandra 2-3)
+ 'system.local.broadcast_address' or 'system.peers_v2.peer' (Cassandra 4)
+
+ This is not present in the ``system.local`` table for older versions of Cassandra. It
+ is also not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``.
+ """
+
+ broadcast_port = None
+ """
+ broadcast port configured for the node, *if available*:
+
+ 'system.local.broadcast_port' or 'system.peers_v2.peer_port' (Cassandra 4)
+
+ It is also not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``.
"""
broadcast_rpc_address = None
"""
- The broadcast rpc address of the node (`native_address` or `rpc_address`).
+ The broadcast rpc address of the node:
+
+ 'system.local.rpc_address' or 'system.peers.rpc_address' (Cassandra 3)
+ 'system.local.rpc_address' or 'system.peers.native_transport_address' (DSE 6+)
+ 'system.local.rpc_address' or 'system.peers_v2.native_address' (Cassandra 4)
+ """
+
+ broadcast_rpc_port = None
+ """
+ The broadcast rpc port of the node, *if available*:
+
+ 'system.local.rpc_port' or 'system.peers.native_transport_port' (DSE 6+)
+ 'system.local.rpc_port' or 'system.peers_v2.native_port' (Cassandra 4)
"""
listen_address = None
"""
- listen address configured for the node, *if available*. This is only available in the ``system.local`` table for newer
- versions of Cassandra. It is also not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``.
- Usually the same as ``broadcast_address`` unless configured differently in cassandra.yaml.
+ listen address configured for the node, *if available*:
+
+ 'system.local.listen_address'
+
+ This is only available in the ``system.local`` table for newer versions of Cassandra. It is also not
+ queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. Usually the same as ``broadcast_address``
+ unless configured differently in cassandra.yaml.
+ """
+
+ listen_port = None
+ """
+ listen port configured for the node, *if available*:
+
+ 'system.local.listen_port'
+
+ This is only available in the ``system.local`` table for newer versions of Cassandra. It is also not
+ queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``.
"""
conviction_policy = None
@@ -351,6 +390,10 @@ def __init__(self, host, host_distance, session):
# this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
self._stream_available_condition = Condition(self._lock)
self._is_replacing = False
+ # Contains connections which shouldn't be used anymore
+ # and are waiting until all requests time out or complete
+ # so that we can dispose of them.
+ self._trash = set()
if host_distance == HostDistance.IGNORED:
log.debug("Not opening connection to ignored host %s", self.host)
@@ -360,13 +403,13 @@ def __init__(self, host, host_distance, session):
return
log.debug("Initializing connection for host %s", self.host)
- self._connection = session.cluster.connection_factory(host.endpoint)
+ self._connection = session.cluster.connection_factory(host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
self._keyspace = session.keyspace
if self._keyspace:
self._connection.set_keyspace_blocking(self._keyspace)
log.debug("Finished initializing connection for host %s", self.host)
- def borrow_connection(self, timeout):
+ def _get_connection(self):
if self.is_shutdown:
raise ConnectionException(
"Pool for %s is shutdown" % (self.host,), self.host)
@@ -374,12 +417,25 @@ def borrow_connection(self, timeout):
conn = self._connection
if not conn:
raise NoConnectionsAvailable()
+ return conn
+
+ def borrow_connection(self, timeout):
+ conn = self._get_connection()
+ if conn.orphaned_threshold_reached:
+ with self._lock:
+ if not self._is_replacing:
+ self._is_replacing = True
+ self._session.submit(self._replace, conn)
+ log.debug(
+ "Connection to host %s reached orphaned stream limit, replacing...",
+ self.host
+ )
start = time.time()
remaining = timeout
while True:
with conn.lock:
- if conn.in_flight <= conn.max_request_id:
+ if not (conn.orphaned_threshold_reached and conn.is_closed) and conn.in_flight < conn.max_request_id:
conn.in_flight += 1
return conn, conn.get_request_id()
if timeout is not None:
@@ -387,15 +443,19 @@ def borrow_connection(self, timeout):
if remaining < 0:
break
with self._stream_available_condition:
- self._stream_available_condition.wait(remaining)
+ if conn.orphaned_threshold_reached and conn.is_closed:
+ conn = self._get_connection()
+ else:
+ self._stream_available_condition.wait(remaining)
raise NoConnectionsAvailable("All request IDs are currently in use")
- def return_connection(self, connection):
- with connection.lock:
- connection.in_flight -= 1
- with self._stream_available_condition:
- self._stream_available_condition.notify()
+ def return_connection(self, connection, stream_was_orphaned=False):
+ if not stream_was_orphaned:
+ with connection.lock:
+ connection.in_flight -= 1
+ with self._stream_available_condition:
+ self._stream_available_condition.notify()
if connection.is_defunct or connection.is_closed:
if connection.signaled_error and not self.shutdown_on_error:
@@ -422,6 +482,24 @@ def return_connection(self, connection):
return
self._is_replacing = True
self._session.submit(self._replace, connection)
+ else:
+ if connection in self._trash:
+ with connection.lock:
+ if connection.in_flight == len(connection.orphaned_request_ids):
+ with self._lock:
+ if connection in self._trash:
+ self._trash.remove(connection)
+ log.debug("Closing trashed connection (%s) to %s", id(connection), self.host)
+ connection.close()
+ return
+
+ def on_orphaned_stream_released(self):
+ """
+ Called when a response for an orphaned stream (timed out on the client
+ side) was received.
+ """
+ with self._stream_available_condition:
+ self._stream_available_condition.notify()
def _replace(self, connection):
with self._lock:
@@ -430,7 +508,7 @@ def _replace(self, connection):
log.debug("Replacing connection (%s) to %s", id(connection), self.host)
try:
- conn = self._session.cluster.connection_factory(self.host.endpoint)
+ conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
if self._keyspace:
conn.set_keyspace_blocking(self._keyspace)
self._connection = conn
@@ -438,9 +516,15 @@ def _replace(self, connection):
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
self._session.submit(self._replace, connection)
else:
- with self._lock:
- self._is_replacing = False
- self._stream_available_condition.notify()
+ with connection.lock:
+ with self._lock:
+ if connection.orphaned_threshold_reached:
+ if connection.in_flight == len(connection.orphaned_request_ids):
+ connection.close()
+ else:
+ self._trash.add(connection)
+ self._is_replacing = False
+ self._stream_available_condition.notify()
def shutdown(self):
with self._lock:
@@ -454,6 +538,16 @@ def shutdown(self):
self._connection.close()
self._connection = None
+ trash_conns = None
+ with self._lock:
+ if self._trash:
+ trash_conns = self._trash
+ self._trash = set()
+
+ if trash_conns is not None:
+ for conn in trash_conns:
+ conn.close()
+
def _set_keyspace_for_all_conns(self, keyspace, callback):
if self.is_shutdown or not self._connection:
return
@@ -474,7 +568,9 @@ def get_state(self):
connection = self._connection
open_count = 1 if connection and not (connection.is_closed or connection.is_defunct) else 0
in_flights = [connection.in_flight] if connection else []
- return {'shutdown': self.is_shutdown, 'open_count': open_count, 'in_flights': in_flights}
+ orphan_requests = [connection.orphaned_request_ids] if connection else []
+ return {'shutdown': self.is_shutdown, 'open_count': open_count, \
+ 'in_flights': in_flights, 'orphan_requests': orphan_requests}
@property
def open_count(self):
@@ -509,7 +605,7 @@ def __init__(self, host, host_distance, session):
log.debug("Initializing new connection pool for host %s", self.host)
core_conns = session.cluster.get_core_connections_per_host(host_distance)
- self._connections = [session.cluster.connection_factory(host.endpoint)
+ self._connections = [session.cluster.connection_factory(host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
for i in range(core_conns)]
self._keyspace = session.keyspace
@@ -613,7 +709,7 @@ def _add_conn_if_under_max(self):
log.debug("Going to open new connection to host %s", self.host)
try:
- conn = self._session.cluster.connection_factory(self.host.endpoint)
+ conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
if self._keyspace:
conn.set_keyspace_blocking(self._session.keyspace)
self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL
@@ -673,9 +769,10 @@ def _wait_for_conn(self, timeout):
raise NoConnectionsAvailable()
- def return_connection(self, connection):
+ def return_connection(self, connection, stream_was_orphaned=False):
with connection.lock:
- connection.in_flight -= 1
+ if not stream_was_orphaned:
+ connection.in_flight -= 1
in_flight = connection.in_flight
if connection.is_defunct or connection.is_closed:
@@ -711,6 +808,13 @@ def return_connection(self, connection):
else:
self._signal_available_conn()
+ def on_orphaned_stream_released(self):
+ """
+ Called when a response for an orphaned stream (timed out on the client
+ side) was received.
+ """
+ self._signal_available_conn()
+
def _maybe_trash_connection(self, connection):
core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance)
did_trash = False
@@ -824,4 +928,6 @@ def get_connections(self):
def get_state(self):
in_flights = [c.in_flight for c in self._connections]
- return {'shutdown': self.is_shutdown, 'open_count': self.open_count, 'in_flights': in_flights}
+ orphan_requests = [c.orphaned_request_ids for c in self._connections]
+ return {'shutdown': self.is_shutdown, 'open_count': self.open_count, \
+ 'in_flights': in_flights, 'orphan_requests': orphan_requests}
diff --git a/cassandra/protocol.py b/cassandra/protocol.py
index eac9ebb8b5..3e4e984410 100644
--- a/cassandra/protocol.py
+++ b/cassandra/protocol.py
@@ -18,8 +18,6 @@
import socket
from uuid import UUID
-import six
-from six.moves import range
import io
from cassandra import ProtocolVersion
@@ -29,9 +27,6 @@
AlreadyExists, InvalidRequest, Unauthorized,
UnsupportedOperation, UserFunctionDescriptor,
UserAggregateDescriptor, SchemaTargetType)
-from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
- uint8_pack, int8_unpack, uint64_pack, header_pack,
- v3_header_pack, uint32_pack)
from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
CounterColumnType, DateType, DecimalType,
DoubleType, FloatType, Int32Type,
@@ -40,6 +35,10 @@
UTF8Type, VarcharType, UUIDType, UserType,
TupleType, lookup_casstype, SimpleDateType,
TimeType, ByteType, ShortType, DurationType)
+from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
+ uint8_pack, int8_unpack, uint64_pack, header_pack,
+ v3_header_pack, uint32_pack, uint32_le_unpack, uint32_le_pack)
+from cassandra.policies import ColDesc
from cassandra import WriteType
from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY
from cassandra import util
@@ -85,8 +84,7 @@ def __init__(cls, name, bases, dct):
register_class(cls)
-@six.add_metaclass(_RegisterMessageType)
-class _MessageType(object):
+class _MessageType(object, metaclass=_RegisterMessageType):
tracing = False
custom_payload = None
@@ -136,8 +134,6 @@ def recv_body(cls, f, protocol_version, *args):
def summary_msg(self):
msg = 'Error from server: code=%04x [%s] message="%s"' \
% (self.code, self.summary, self.message)
- if six.PY2 and isinstance(msg, six.text_type):
- msg = msg.encode('utf-8')
return msg
def __str__(self):
@@ -158,8 +154,7 @@ def __init__(cls, name, bases, dct):
error_classes[cls.error_code] = cls
-@six.add_metaclass(ErrorMessageSubclass)
-class ErrorMessageSub(ErrorMessage):
+class ErrorMessageSub(ErrorMessage, metaclass=ErrorMessageSubclass):
error_code = None
@@ -180,6 +175,10 @@ class ProtocolException(ErrorMessageSub):
summary = 'Protocol error'
error_code = 0x000A
+ @property
+ def is_beta_protocol_error(self):
+ return 'USE_BETA flag is unset' in str(self)
+
class BadCredentials(ErrorMessageSub):
summary = 'Bad credentials'
@@ -719,11 +718,11 @@ class ResultMessage(_MessageType):
def __init__(self, kind):
self.kind = kind
- def recv(self, f, protocol_version, user_type_map, result_metadata):
+ def recv(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy):
if self.kind == RESULT_KIND_VOID:
return
elif self.kind == RESULT_KIND_ROWS:
- self.recv_results_rows(f, protocol_version, user_type_map, result_metadata)
+ self.recv_results_rows(f, protocol_version, user_type_map, result_metadata, column_encryption_policy)
elif self.kind == RESULT_KIND_SET_KEYSPACE:
self.new_keyspace = read_string(f)
elif self.kind == RESULT_KIND_PREPARED:
@@ -734,32 +733,40 @@ def recv(self, f, protocol_version, user_type_map, result_metadata):
raise DriverException("Unknown RESULT kind: %d" % self.kind)
@classmethod
- def recv_body(cls, f, protocol_version, user_type_map, result_metadata):
+ def recv_body(cls, f, protocol_version, user_type_map, result_metadata, column_encryption_policy):
kind = read_int(f)
msg = cls(kind)
- msg.recv(f, protocol_version, user_type_map, result_metadata)
+ msg.recv(f, protocol_version, user_type_map, result_metadata, column_encryption_policy)
return msg
- def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata):
+ def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy):
self.recv_results_metadata(f, user_type_map)
column_metadata = self.column_metadata or result_metadata
rowcount = read_int(f)
rows = [self.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
self.column_names = [c[2] for c in column_metadata]
self.column_types = [c[3] for c in column_metadata]
+ col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata]
+
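+ # Column-level encryption: when the policy covers a column, decrypt the raw
+ # bytes first and deserialize them with the type declared by the policy
+ # rather than the type from the result metadata.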
+ def decode_val(val, col_md, col_desc):
+ uses_ce = column_encryption_policy and column_encryption_policy.contains_column(col_desc)
+ col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3]
+ raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val
+ return col_type.from_binary(raw_bytes, protocol_version)
+
+ def decode_row(row):
+ return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs))
+
try:
- self.parsed_rows = [
- tuple(ctype.from_binary(val, protocol_version)
- for ctype, val in zip(self.column_types, row))
- for row in rows]
+ self.parsed_rows = [decode_row(row) for row in rows]
except Exception:
for row in rows:
- for i in range(len(row)):
+ for val, col_md, col_desc in zip(row, column_metadata, col_descs):
try:
- self.column_types[i].from_binary(row[i], protocol_version)
+ decode_val(val, col_md, col_desc)
except Exception as e:
- raise DriverException('Failed decoding result column "%s" of type %s: %s' % (self.column_names[i],
- self.column_types[i].cql_parameterized_type(),
+ raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2],
+ col_md[3].cql_parameterized_type(),
str(e)))
def recv_results_prepared(self, f, protocol_version, user_type_map):
@@ -1095,6 +1102,9 @@ class _ProtocolHandler(object):
result decoding implementations.
"""
+ column_encryption_policy = None
+ """Instance of :class:`cassandra.policies.ColumnEncryptionPolicy` in use by this handler"""
+
@classmethod
def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta_protocol_version):
"""
@@ -1115,7 +1125,9 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta
msg.send_body(body, protocol_version)
body = body.getvalue()
- if compressor and len(body) > 0:
+ # With checksumming, the compression is done at the segment frame encoding
+ if (not ProtocolVersion.has_checksumming_support(protocol_version)
+ and compressor and len(body) > 0):
body = compressor(body)
flags |= COMPRESSED_FLAG
@@ -1155,7 +1167,8 @@ def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcod
:param decompressor: optional decompression function to inflate the body
:return: a message decoded from the body and frame attributes
"""
- if flags & COMPRESSED_FLAG:
+ if (not ProtocolVersion.has_checksumming_support(protocol_version) and
+ flags & COMPRESSED_FLAG):
if decompressor is None:
raise RuntimeError("No de-compressor available for compressed frame!")
body = decompressor(body)
@@ -1186,7 +1199,7 @@ def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcod
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
msg_class = cls.message_types_by_opcode[opcode]
- msg = msg_class.recv_body(body, protocol_version, user_type_map, result_metadata)
+ msg = msg_class.recv_body(body, protocol_version, user_type_map, result_metadata, cls.column_encryption_policy)
msg.stream_id = stream_id
msg.trace_id = trace_id
msg.custom_payload = custom_payload
@@ -1271,6 +1284,33 @@ def read_int(f):
return int32_unpack(f.read(4))
+def read_uint_le(f, size=4):
+ """
+ Read a sequence of little endian bytes and return an unsigned integer.
+ """
+
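+ # Little-endian helper used by the segment framing in cassandra/segment.py:
+ # take the fast uint32 path when possible, otherwise assemble the value byte
+ # by byte, least significant byte first.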
+ if size == 4:
+ value = uint32_le_unpack(f.read(4))
+ else:
+ value = 0
+ for i in range(size):
+ value |= (read_byte(f) & 0xFF) << 8 * i
+
+ return value
+
+
+def write_uint_le(f, i, size=4):
+ """
+ Write an unsigned integer as a sequence of little endian bytes.
+ """
+ if size == 4:
+ f.write(uint32_le_pack(i))
+ else:
+ for j in range(size):
+ shift = j * 8
+ write_byte(f, i >> shift & 0xFF)
+
+
def write_int(f, i):
f.write(int32_pack(i))
@@ -1312,7 +1352,7 @@ def read_binary_string(f):
def write_string(f, s):
- if isinstance(s, six.text_type):
+ if isinstance(s, str):
s = s.encode('utf8')
write_short(f, len(s))
f.write(s)
@@ -1329,7 +1369,7 @@ def read_longstring(f):
def write_longstring(f, s):
- if isinstance(s, six.text_type):
+ if isinstance(s, str):
s = s.encode('utf8')
write_int(f, len(s))
f.write(s)
diff --git a/cassandra/query.py b/cassandra/query.py
index 0e7a41dc2d..e656124403 100644
--- a/cassandra/query.py
+++ b/cassandra/query.py
@@ -23,14 +23,13 @@
import re
import struct
import time
-import six
-from six.moves import range, zip
import warnings
from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.util import unix_time_from_uuid1
from cassandra.encoder import Encoder
import cassandra.encoder
+from cassandra.policies import ColDesc
from cassandra.protocol import _UNSET_VALUE
from cassandra.util import OrderedDict, _sanitize_identifiers
@@ -76,7 +75,7 @@ def tuple_factory(colnames, rows):
>>> session = cluster.connect('mykeyspace')
>>> session.row_factory = tuple_factory
>>> rows = session.execute("SELECT name, age FROM users LIMIT 1")
- >>> print rows[0]
+ >>> print(rows[0])
('Bob', 42)
.. versionchanged:: 2.0.0
@@ -132,16 +131,16 @@ def named_tuple_factory(colnames, rows):
>>> user = rows[0]
>>> # you can access field by their name:
- >>> print "name: %s, age: %d" % (user.name, user.age)
+ >>> print("name: %s, age: %d" % (user.name, user.age))
name: Bob, age: 42
>>> # or you can access fields by their position (like a tuple)
>>> name, age = user
- >>> print "name: %s, age: %d" % (name, age)
+ >>> print("name: %s, age: %d" % (name, age))
name: Bob, age: 42
>>> name = user[0]
>>> age = user[1]
- >>> print "name: %s, age: %d" % (name, age)
+ >>> print("name: %s, age: %d" % (name, age))
name: Bob, age: 42
.. versionchanged:: 2.0.0
@@ -187,7 +186,7 @@ def dict_factory(colnames, rows):
>>> session = cluster.connect('mykeyspace')
>>> session.row_factory = dict_factory
>>> rows = session.execute("SELECT name, age FROM users LIMIT 1")
- >>> print rows[0]
+ >>> print(rows[0])
{u'age': 42, u'name': u'Bob'}
.. versionchanged:: 2.0.0
@@ -442,12 +441,14 @@ class PreparedStatement(object):
query_string = None
result_metadata = None
result_metadata_id = None
+ column_encryption_policy = None
routing_key_indexes = None
_routing_key_index_set = None
serial_consistency_level = None # TODO never used?
def __init__(self, column_metadata, query_id, routing_key_indexes, query,
- keyspace, protocol_version, result_metadata, result_metadata_id):
+ keyspace, protocol_version, result_metadata, result_metadata_id,
+ column_encryption_policy=None):
self.column_metadata = column_metadata
self.query_id = query_id
self.routing_key_indexes = routing_key_indexes
@@ -456,14 +457,17 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query,
self.protocol_version = protocol_version
self.result_metadata = result_metadata
self.result_metadata_id = result_metadata_id
+ self.column_encryption_policy = column_encryption_policy
self.is_idempotent = False
@classmethod
def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
query, prepared_keyspace, protocol_version, result_metadata,
- result_metadata_id):
+ result_metadata_id, column_encryption_policy=None):
if not column_metadata:
- return PreparedStatement(column_metadata, query_id, None, query, prepared_keyspace, protocol_version, result_metadata, result_metadata_id)
+ return PreparedStatement(column_metadata, query_id, None,
+ query, prepared_keyspace, protocol_version, result_metadata,
+ result_metadata_id, column_encryption_policy)
if pk_indexes:
routing_key_indexes = pk_indexes
@@ -489,7 +493,7 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
return PreparedStatement(column_metadata, query_id, routing_key_indexes,
query, prepared_keyspace, protocol_version, result_metadata,
- result_metadata_id)
+ result_metadata_id, column_encryption_policy)
def bind(self, values):
"""
@@ -577,6 +581,7 @@ def bind(self, values):
values = ()
proto_version = self.prepared_statement.protocol_version
col_meta = self.prepared_statement.column_metadata
+ ce_policy = self.prepared_statement.column_encryption_policy
# special case for binding dicts
if isinstance(values, dict):
@@ -623,7 +628,13 @@ def bind(self, values):
raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version)
else:
try:
- self.values.append(col_spec.type.serialize(value, proto_version))
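+ # Column-level encryption: serialize with the type declared by the policy
+ # for this column, then encrypt the serialized bytes before binding.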
+ col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name)
+ uses_ce = ce_policy and ce_policy.contains_column(col_desc)
+ col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type
+ col_bytes = col_type.serialize(value, proto_version)
+ if uses_ce:
+ col_bytes = ce_policy.encrypt(col_desc, col_bytes)
+ self.values.append(col_bytes)
except (TypeError, struct.error) as exc:
actual_type = type(value)
message = ('Received an argument of invalid type for column "%s". '
@@ -804,7 +815,7 @@ def add(self, statement, parameters=None):
Like with other statements, parameters must be a sequence, even
if there is only one item.
"""
- if isinstance(statement, six.string_types):
+ if isinstance(statement, str):
if parameters:
encoder = Encoder() if self._session is None else self._session.encoder
statement = bind_params(statement, parameters, encoder)
@@ -888,10 +899,8 @@ def __str__(self):
def bind_params(query, params, encoder):
- if six.PY2 and isinstance(query, six.text_type):
- query = query.encode('utf-8')
if isinstance(params, dict):
- return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in six.iteritems(params))
+ return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in params.items())
else:
return query % tuple(encoder.cql_encode_all_types(v) for v in params)
@@ -996,7 +1005,8 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None):
SimpleStatement(self._SELECT_SESSIONS_FORMAT, consistency_level=query_cl), (self.trace_id,), time_spent, max_wait)
# PYTHON-730: There is race condition that the duration mutation is written before started_at the for fast queries
- is_complete = session_results and session_results[0].duration is not None and session_results[0].started_at is not None
+ session_row = session_results.one() if session_results else None
+ is_complete = session_row is not None and session_row.duration is not None and session_row.started_at is not None
if not session_results or (wait_for_complete and not is_complete):
time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt))
attempt += 1
@@ -1006,7 +1016,6 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None):
else:
log.debug("Fetching parital trace info for trace ID: %s", self.trace_id)
- session_row = session_results[0]
self.request_type = session_row.request
self.duration = timedelta(microseconds=session_row.duration) if is_complete else None
self.started_at = session_row.started_at
diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx
index 3a4b2f4604..88277a4593 100644
--- a/cassandra/row_parser.pyx
+++ b/cassandra/row_parser.pyx
@@ -13,13 +13,14 @@
# limitations under the License.
from cassandra.parsing cimport ParseDesc, ColumnParser
+from cassandra.policies import ColDesc
from cassandra.obj_parser import TupleRowParser
from cassandra.deserializers import make_deserializers
include "ioutils.pyx"
def make_recv_results_rows(ColumnParser colparser):
- def recv_results_rows(self, f, int protocol_version, user_type_map, result_metadata):
+ def recv_results_rows(self, f, int protocol_version, user_type_map, result_metadata, column_encryption_policy):
"""
Parse protocol data given as a BytesIO f into a set of columns (e.g. list of tuples)
This is used as the recv_results_rows method of (Fast)ResultMessage
@@ -28,11 +29,12 @@ def make_recv_results_rows(ColumnParser colparser):
column_metadata = self.column_metadata or result_metadata
- self.column_names = [c[2] for c in column_metadata]
- self.column_types = [c[3] for c in column_metadata]
+ self.column_names = [md[2] for md in column_metadata]
+ self.column_types = [md[3] for md in column_metadata]
- desc = ParseDesc(self.column_names, self.column_types, make_deserializers(self.column_types),
- protocol_version)
+ desc = ParseDesc(self.column_names, self.column_types, column_encryption_policy,
+ [ColDesc(md[0], md[1], md[2]) for md in column_metadata],
+ make_deserializers(self.column_types), protocol_version)
reader = BytesIOReader(f.read())
try:
self.parsed_rows = colparser.parse_rows(reader, desc)
diff --git a/cassandra/segment.py b/cassandra/segment.py
new file mode 100644
index 0000000000..78161fe520
--- /dev/null
+++ b/cassandra/segment.py
@@ -0,0 +1,220 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import zlib
+
+from cassandra import DriverException
+from cassandra.marshal import int32_pack
+from cassandra.protocol import write_uint_le, read_uint_le
+
+CRC24_INIT = 0x875060
+CRC24_POLY = 0x1974F0B
+CRC24_LENGTH = 3
+CRC32_LENGTH = 4
+CRC32_INITIAL = zlib.crc32(b"\xfa\x2d\x55\xca")
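+# Initial CRC32 value: the checksum of these four seed bytes, reused below when
+# checksumming segment payloads.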
+
+
+class CrcException(Exception):
+ """
+ CRC mismatch error.
+
+ TODO: here to avoid import cycles with cassandra.connection. In the next
+ major, the exceptions should be declared in a separate exceptions.py
+ file.
+ """
+ pass
+
+
+def compute_crc24(data, length):
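+ # `data` is an integer; process `length` of its bytes, least significant
+ # byte first, folding each one into a 24-bit CRC using CRC24_POLY.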
+ crc = CRC24_INIT
+
+ for _ in range(length):
+ crc ^= (data & 0xff) << 16
+ data >>= 8
+
+ for i in range(8):
+ crc <<= 1
+ if crc & 0x1000000 != 0:
+ crc ^= CRC24_POLY
+
+ return crc
+
+
+def compute_crc32(data, value):
+ crc32 = zlib.crc32(data, value)
+ return crc32
+
+
+class SegmentHeader(object):
+
+ payload_length = None
+ uncompressed_payload_length = None
+ is_self_contained = None
+
+ def __init__(self, payload_length, uncompressed_payload_length, is_self_contained):
+ self.payload_length = payload_length
+ self.uncompressed_payload_length = uncompressed_payload_length
+ self.is_self_contained = is_self_contained
+
+ @property
+ def segment_length(self):
+ """
+ Return the total length of the segment, including the CRC.
+ """
+ hl = SegmentCodec.UNCOMPRESSED_HEADER_LENGTH if self.uncompressed_payload_length < 1 \
+ else SegmentCodec.COMPRESSED_HEADER_LENGTH
+ return hl + CRC24_LENGTH + self.payload_length + CRC32_LENGTH
+
+
+class Segment(object):
+
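+ # The payload length is carried on 17 bits in the segment header, hence
+ # this maximum of 128 KiB - 1.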
+ MAX_PAYLOAD_LENGTH = 128 * 1024 - 1
+
+ payload = None
+ is_self_contained = None
+
+ def __init__(self, payload, is_self_contained):
+ self.payload = payload
+ self.is_self_contained = is_self_contained
+
+
+class SegmentCodec(object):
+
+ COMPRESSED_HEADER_LENGTH = 5
+ UNCOMPRESSED_HEADER_LENGTH = 3
+ FLAG_OFFSET = 17
+
+ compressor = None
+ decompressor = None
+
+ def __init__(self, compressor=None, decompressor=None):
+ self.compressor = compressor
+ self.decompressor = decompressor
+
+ @property
+ def header_length(self):
+ return self.COMPRESSED_HEADER_LENGTH if self.compression \
+ else self.UNCOMPRESSED_HEADER_LENGTH
+
+ @property
+ def header_length_with_crc(self):
+ return (self.COMPRESSED_HEADER_LENGTH if self.compression
+ else self.UNCOMPRESSED_HEADER_LENGTH) + CRC24_LENGTH
+
+ @property
+ def compression(self):
+ return self.compressor and self.decompressor
+
+ def compress(self, data):
+ # the uncompressed length is already encoded in the header, so
+ # we remove it here
+ return self.compressor(data)[4:]
+
+ def decompress(self, encoded_data, uncompressed_length):
+ return self.decompressor(int32_pack(uncompressed_length) + encoded_data)
+
+ def encode_header(self, buffer, payload_length, uncompressed_length, is_self_contained):
+ if payload_length > Segment.MAX_PAYLOAD_LENGTH:
+ raise DriverException('Payload length exceeds Segment.MAX_PAYLOAD_LENGTH')
+
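+ # Header layout, written little-endian: 17 bits of payload length, then
+ # (only when compression is negotiated) 17 bits of uncompressed length,
+ # then one self-contained flag bit, followed by a CRC24 of the header.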
+ header_data = payload_length
+
+ flag_offset = self.FLAG_OFFSET
+ if self.compression:
+ header_data |= uncompressed_length << flag_offset
+ flag_offset += 17
+
+ if is_self_contained:
+ header_data |= 1 << flag_offset
+
+ write_uint_le(buffer, header_data, size=self.header_length)
+ header_crc = compute_crc24(header_data, self.header_length)
+ write_uint_le(buffer, header_crc, size=CRC24_LENGTH)
+
+ def _encode_segment(self, buffer, payload, is_self_contained):
+ """
+ Encode a message to a single segment.
+ """
+ uncompressed_payload = payload
+ uncompressed_payload_length = len(payload)
+
+ if self.compression:
+ compressed_payload = self.compress(uncompressed_payload)
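+ # If compression does not actually shrink the payload, send it
+ # uncompressed and signal that with an uncompressed length of 0.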
+ if len(compressed_payload) >= uncompressed_payload_length:
+ encoded_payload = uncompressed_payload
+ uncompressed_payload_length = 0
+ else:
+ encoded_payload = compressed_payload
+ else:
+ encoded_payload = uncompressed_payload
+
+ payload_length = len(encoded_payload)
+ self.encode_header(buffer, payload_length, uncompressed_payload_length, is_self_contained)
+ payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL)
+ buffer.write(encoded_payload)
+ write_uint_le(buffer, payload_crc)
+
+ def encode(self, buffer, msg):
+ """
+ Encode a message to one or more segments.
+ """
+ msg_length = len(msg)
+
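+ # Messages larger than MAX_PAYLOAD_LENGTH are sliced into several segments;
+ # only a single-segment message is flagged as self-contained.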
+ if msg_length > Segment.MAX_PAYLOAD_LENGTH:
+ payloads = []
+ for i in range(0, msg_length, Segment.MAX_PAYLOAD_LENGTH):
+ payloads.append(msg[i:i + Segment.MAX_PAYLOAD_LENGTH])
+ else:
+ payloads = [msg]
+
+ is_self_contained = len(payloads) == 1
+ for payload in payloads:
+ self._encode_segment(buffer, payload, is_self_contained)
+
+ def decode_header(self, buffer):
+ header_data = read_uint_le(buffer, self.header_length)
+
+ expected_header_crc = read_uint_le(buffer, CRC24_LENGTH)
+ actual_header_crc = compute_crc24(header_data, self.header_length)
+ if actual_header_crc != expected_header_crc:
+ raise CrcException('CRC mismatch on header {:x}. Received {:x}, computed {:x}.'.format(
+ header_data, expected_header_crc, actual_header_crc))
+
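+ # Unpack the fields in the same order encode_header packed them:
+ # payload length, optional uncompressed length, self-contained flag.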
+ payload_length = header_data & Segment.MAX_PAYLOAD_LENGTH
+ header_data >>= 17
+
+ if self.compression:
+ uncompressed_payload_length = header_data & Segment.MAX_PAYLOAD_LENGTH
+ header_data >>= 17
+ else:
+ uncompressed_payload_length = -1
+
+ is_self_contained = (header_data & 1) == 1
+
+ return SegmentHeader(payload_length, uncompressed_payload_length, is_self_contained)
+
+ def decode(self, buffer, header):
+ encoded_payload = buffer.read(header.payload_length)
+ expected_payload_crc = read_uint_le(buffer)
+
+ actual_payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL)
+ if actual_payload_crc != expected_payload_crc:
+ raise CrcException('CRC mismatch on payload. Received {:x}, computed {:x}.'.format(
+ expected_payload_crc, actual_payload_crc))
+
+ payload = encoded_payload
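+ # An uncompressed length <= 0 means the payload was not compressed.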
+ if self.compression and header.uncompressed_payload_length > 0:
+ payload = self.decompress(encoded_payload, header.uncompressed_payload_length)
+
+ return Segment(payload, header.is_self_contained)
diff --git a/cassandra/util.py b/cassandra/util.py
index 0651591203..06d338f2e1 100644
--- a/cassandra/util.py
+++ b/cassandra/util.py
@@ -13,17 +13,31 @@
# limitations under the License.
from __future__ import with_statement
+from _weakref import ref
import calendar
+from collections import OrderedDict
+from collections.abc import Mapping
import datetime
from functools import total_ordering
-import logging
-from geomet import wkt
from itertools import chain
+import keyword
+import logging
+import pickle
import random
import re
-import six
-import uuid
+import socket
import sys
+import time
+import uuid
+
+_HAS_GEOMET = True
+try:
+ from geomet import wkt
+except ImportError:
+ _HAS_GEOMET = False
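+# geomet is now optional: the geometry types below raise DriverException from
+# from_wkt() when it is not installed.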
+
+
+from cassandra import DriverException
DATETIME_EPOC = datetime.datetime(1970, 1, 1)
UTC_DATETIME_EPOC = datetime.datetime.utcfromtimestamp(0)
@@ -35,6 +49,7 @@
assert sys.byteorder in ('little', 'big')
is_little_endian = sys.byteorder == 'little'
+
def datetime_from_timestamp(timestamp):
"""
Creates a timezone-agnostic datetime from timestamp (in seconds) in a consistent manner.
@@ -189,161 +204,20 @@ def _addrinfo_to_ip_strings(addrinfo):
extracts the IP address from the sockaddr portion of the result.
Since this is meant to be used in conjunction with _addrinfo_or_none,
- this will pass None and EndPont instances through unaffected.
+ this will pass None and EndPoint instances through unaffected.
"""
if addrinfo is None:
return None
- return [entry[4][0] for entry in addrinfo]
+ return [(entry[4][0], entry[4][1]) for entry in addrinfo]
-def _resolve_contact_points_to_string_map(contact_points, port):
+def _resolve_contact_points_to_string_map(contact_points):
return OrderedDict(
- (cp, _addrinfo_to_ip_strings(_addrinfo_or_none(cp, port)))
- for cp in contact_points
+ ('{cp}:{port}'.format(cp=cp, port=port), _addrinfo_to_ip_strings(_addrinfo_or_none(cp, port)))
+ for cp, port in contact_points
)
-try:
- from collections import OrderedDict
-except ImportError:
- # OrderedDict from Python 2.7+
-
- # Copyright (c) 2009 Raymond Hettinger
- #
- # Permission is hereby granted, free of charge, to any person
- # obtaining a copy of this software and associated documentation files
- # (the "Software"), to deal in the Software without restriction,
- # including without limitation the rights to use, copy, modify, merge,
- # publish, distribute, sublicense, and/or sell copies of the Software,
- # and to permit persons to whom the Software is furnished to do so,
- # subject to the following conditions:
- #
- # The above copyright notice and this permission notice shall be
- # included in all copies or substantial portions of the Software.
- #
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
- # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
- # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
- # HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
- # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
- # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
- # OTHER DEALINGS IN THE SOFTWARE.
- from UserDict import DictMixin
-
- class OrderedDict(dict, DictMixin): # noqa
- """ A dictionary which maintains the insertion order of keys. """
-
- def __init__(self, *args, **kwds):
- """ A dictionary which maintains the insertion order of keys. """
-
- if len(args) > 1:
- raise TypeError('expected at most 1 arguments, got %d' % len(args))
- try:
- self.__end
- except AttributeError:
- self.clear()
- self.update(*args, **kwds)
-
- def clear(self):
- self.__end = end = []
- end += [None, end, end] # sentinel node for doubly linked list
- self.__map = {} # key --> [key, prev, next]
- dict.clear(self)
-
- def __setitem__(self, key, value):
- if key not in self:
- end = self.__end
- curr = end[1]
- curr[2] = end[1] = self.__map[key] = [key, curr, end]
- dict.__setitem__(self, key, value)
-
- def __delitem__(self, key):
- dict.__delitem__(self, key)
- key, prev, next = self.__map.pop(key)
- prev[2] = next
- next[1] = prev
-
- def __iter__(self):
- end = self.__end
- curr = end[2]
- while curr is not end:
- yield curr[0]
- curr = curr[2]
-
- def __reversed__(self):
- end = self.__end
- curr = end[1]
- while curr is not end:
- yield curr[0]
- curr = curr[1]
-
- def popitem(self, last=True):
- if not self:
- raise KeyError('dictionary is empty')
- if last:
- key = next(reversed(self))
- else:
- key = next(iter(self))
- value = self.pop(key)
- return key, value
-
- def __reduce__(self):
- items = [[k, self[k]] for k in self]
- tmp = self.__map, self.__end
- del self.__map, self.__end
- inst_dict = vars(self).copy()
- self.__map, self.__end = tmp
- if inst_dict:
- return (self.__class__, (items,), inst_dict)
- return self.__class__, (items,)
-
- def keys(self):
- return list(self)
-
- setdefault = DictMixin.setdefault
- update = DictMixin.update
- pop = DictMixin.pop
- values = DictMixin.values
- items = DictMixin.items
- iterkeys = DictMixin.iterkeys
- itervalues = DictMixin.itervalues
- iteritems = DictMixin.iteritems
-
- def __repr__(self):
- if not self:
- return '%s()' % (self.__class__.__name__,)
- return '%s(%r)' % (self.__class__.__name__, self.items())
-
- def copy(self):
- return self.__class__(self)
-
- @classmethod
- def fromkeys(cls, iterable, value=None):
- d = cls()
- for key in iterable:
- d[key] = value
- return d
-
- def __eq__(self, other):
- if isinstance(other, OrderedDict):
- if len(self) != len(other):
- return False
- for p, q in zip(self.items(), other.items()):
- if p != q:
- return False
- return True
- return dict.__eq__(self, other)
-
- def __ne__(self, other):
- return not self == other
-
-
-# WeakSet from Python 2.7+ (https://code.google.com/p/weakrefset)
-
-from _weakref import ref
-
-
class _IterationGuard(object):
# This context manager registers itself in the current iterators of the
# weak container, such as to delay all removals until the context manager
@@ -780,15 +654,11 @@ def _find_insertion(self, x):
sortedset = SortedSet # backwards-compatibility
-from cassandra.compat import Mapping
-from six.moves import cPickle
-
-
class OrderedMap(Mapping):
'''
An ordered map that accepts non-hashable types for keys. It also maintains the
insertion order of items, behaving as OrderedDict in that regard. These maps
- are constructed and read just as normal mapping types, exept that they may
+ are constructed and read just as normal mapping types, except that they may
contain arbitrary collections and other non-hashable items as keys::
>>> od = OrderedMap([({'one': 1, 'two': 2}, 'value'),
@@ -826,7 +696,7 @@ def __init__(self, *args, **kwargs):
for k, v in e:
self._insert(k, v)
- for k, v in six.iteritems(kwargs):
+ for k, v in kwargs.items():
self._insert(k, v)
def _insert(self, key, value):
@@ -892,7 +762,7 @@ def popitem(self):
raise KeyError()
def _serialize_key(self, key):
- return cPickle.dumps(key)
+ return pickle.dumps(key)
class OrderedMapSerializedKey(OrderedMap):
@@ -910,13 +780,6 @@ def _serialize_key(self, key):
return self.cass_key_type.serialize(key, self.protocol_version)
-import datetime
-import time
-
-if six.PY3:
- long = int
-
-
@total_ordering
class Time(object):
'''
@@ -942,11 +805,11 @@ def __init__(self, value):
- datetime.time: built-in time
- string_type: a string time of the form "HH:MM:SS[.mmmuuunnn]"
"""
- if isinstance(value, six.integer_types):
+ if isinstance(value, int):
self._from_timestamp(value)
elif isinstance(value, datetime.time):
self._from_time(value)
- elif isinstance(value, six.string_types):
+ elif isinstance(value, str):
self._from_timestring(value)
else:
raise TypeError('Time arguments must be a whole number, datetime.time, or string')
@@ -1022,7 +885,7 @@ def __eq__(self, other):
if isinstance(other, Time):
return self.nanosecond_time == other.nanosecond_time
- if isinstance(other, six.integer_types):
+ if isinstance(other, int):
return self.nanosecond_time == other
return self.nanosecond_time % Time.MICRO == 0 and \
@@ -1071,11 +934,11 @@ def __init__(self, value):
- datetime.date: built-in date
- string_type: a string time of the form "yyyy-mm-dd"
"""
- if isinstance(value, six.integer_types):
+ if isinstance(value, int):
self.days_from_epoch = value
elif isinstance(value, (datetime.date, datetime.datetime)):
self._from_timetuple(value.timetuple())
- elif isinstance(value, six.string_types):
+ elif isinstance(value, str):
self._from_datestring(value)
else:
raise TypeError('Date arguments must be a whole number, datetime.date, or string')
@@ -1115,7 +978,7 @@ def __eq__(self, other):
if isinstance(other, Date):
return self.days_from_epoch == other.days_from_epoch
- if isinstance(other, six.integer_types):
+ if isinstance(other, int):
return self.days_from_epoch == other
try:
@@ -1142,97 +1005,9 @@ def __str__(self):
# If we overflow datetime.[MIN|MAX]
return str(self.days_from_epoch)
-import socket
-if hasattr(socket, 'inet_pton'):
- inet_pton = socket.inet_pton
- inet_ntop = socket.inet_ntop
-else:
- """
- Windows doesn't have socket.inet_pton and socket.inet_ntop until Python 3.4
- This is an alternative impl using ctypes, based on this win_inet_pton project:
- https://github.com/hickeroar/win_inet_pton
- """
- import ctypes
-
- class sockaddr(ctypes.Structure):
- """
- Shared struct for ipv4 and ipv6.
-
- https://msdn.microsoft.com/en-us/library/windows/desktop/ms740496(v=vs.85).aspx
-
- ``__pad1`` always covers the port.
-
- When being used for ``sockaddr_in6``, ``ipv4_addr`` actually covers ``sin6_flowinfo``, resulting
- in proper alignment for ``ipv6_addr``.
- """
- _fields_ = [("sa_family", ctypes.c_short),
- ("__pad1", ctypes.c_ushort),
- ("ipv4_addr", ctypes.c_byte * 4),
- ("ipv6_addr", ctypes.c_byte * 16),
- ("__pad2", ctypes.c_ulong)]
-
- if hasattr(ctypes, 'windll'):
- WSAStringToAddressA = ctypes.windll.ws2_32.WSAStringToAddressA
- WSAAddressToStringA = ctypes.windll.ws2_32.WSAAddressToStringA
- else:
- def not_windows(*args):
- raise OSError("IPv6 addresses cannot be handled on Windows. "
- "Missing ctypes.windll")
- WSAStringToAddressA = not_windows
- WSAAddressToStringA = not_windows
-
- def inet_pton(address_family, ip_string):
- if address_family == socket.AF_INET:
- return socket.inet_aton(ip_string)
-
- addr = sockaddr()
- addr.sa_family = address_family
- addr_size = ctypes.c_int(ctypes.sizeof(addr))
-
- if WSAStringToAddressA(
- ip_string,
- address_family,
- None,
- ctypes.byref(addr),
- ctypes.byref(addr_size)
- ) != 0:
- raise socket.error(ctypes.FormatError())
-
- if address_family == socket.AF_INET6:
- return ctypes.string_at(addr.ipv6_addr, 16)
-
- raise socket.error('unknown address family')
-
- def inet_ntop(address_family, packed_ip):
- if address_family == socket.AF_INET:
- return socket.inet_ntoa(packed_ip)
-
- addr = sockaddr()
- addr.sa_family = address_family
- addr_size = ctypes.c_int(ctypes.sizeof(addr))
- ip_string = ctypes.create_string_buffer(128)
- ip_string_size = ctypes.c_int(ctypes.sizeof(ip_string))
-
- if address_family == socket.AF_INET6:
- if len(packed_ip) != ctypes.sizeof(addr.ipv6_addr):
- raise socket.error('packed IP wrong length for inet_ntoa')
- ctypes.memmove(addr.ipv6_addr, packed_ip, 16)
- else:
- raise socket.error('unknown address family')
-
- if WSAAddressToStringA(
- ctypes.byref(addr),
- addr_size,
- None,
- ip_string,
- ctypes.byref(ip_string_size)
- ) != 0:
- raise socket.error(ctypes.FormatError())
- return ip_string[:ip_string_size.value - 1]
-
-
-import keyword
+inet_pton = socket.inet_pton
+inet_ntop = socket.inet_ntop
# similar to collections.namedtuple, reproduced here because Python 2.6 did not have the rename logic
@@ -1308,6 +1083,9 @@ def from_wkt(s):
"""
Parse a Point geometry from a wkt string and return a new Point object.
"""
+ if not _HAS_GEOMET:
+ raise DriverException("Geomet is required to deserialize a wkt geometry.")
+
try:
geom = wkt.loads(s)
except ValueError:
@@ -1363,6 +1141,9 @@ def from_wkt(s):
"""
Parse a LineString geometry from a wkt string and return a new LineString object.
"""
+ if not _HAS_GEOMET:
+ raise DriverException("Geomet is required to deserialize a wkt geometry.")
+
try:
geom = wkt.loads(s)
except ValueError:
@@ -1444,6 +1225,9 @@ def from_wkt(s):
"""
Parse a Polygon geometry from a wkt string and return a new Polygon object.
"""
+ if not _HAS_GEOMET:
+ raise DriverException("Geomet is required to deserialize a wkt geometry.")
+
try:
geom = wkt.loads(s)
except ValueError:
@@ -1523,8 +1307,11 @@ class Duration(object):
"""
months = 0
+ ""
days = 0
+ ""
nanoseconds = 0
+ ""
def __init__(self, months=0, days=0, nanoseconds=0):
self.months = months
@@ -1667,7 +1454,7 @@ def __init__(self, value, precision):
if value is None:
milliseconds = None
- elif isinstance(value, six.integer_types):
+ elif isinstance(value, int):
milliseconds = value
elif isinstance(value, datetime.datetime):
value = value.replace(
@@ -1935,12 +1722,10 @@ def __init__(self, version):
try:
self.major = int(parts.pop())
- except ValueError:
- six.reraise(
- ValueError,
- ValueError("Couldn't parse version {}. Version should start with a number".format(version)),
- sys.exc_info()[2]
- )
+ except ValueError as e:
+ raise ValueError(
+ "Couldn't parse version {}. Version should start with a number".format(version))\
+ .with_traceback(e.__traceback__)
try:
self.minor = int(parts.pop()) if parts else 0
self.patch = int(parts.pop()) if parts else 0
@@ -1973,8 +1758,8 @@ def __str__(self):
@staticmethod
def _compare_version_part(version, other_version, cmp):
- if not (isinstance(version, six.integer_types) and
- isinstance(other_version, six.integer_types)):
+ if not (isinstance(version, int) and
+ isinstance(other_version, int)):
version = str(version)
other_version = str(other_version)
diff --git a/docs.yaml b/docs.yaml
index e2e1231834..07e2742637 100644
--- a/docs.yaml
+++ b/docs.yaml
@@ -22,6 +22,22 @@ sections:
# build extensions like libev
CASS_DRIVER_NO_CYTHON=1 python setup.py build_ext --inplace --force
versions:
+ - name: '3.29'
+ ref: 1a947f84
+ - name: '3.28'
+ ref: 4325afb6
+ - name: '3.27'
+ ref: 910f0282
+ - name: '3.26'
+ ref: f1e9126
+ - name: '3.25'
+ ref: a83c36a5
+ - name: '3.24'
+ ref: 21cac12b
+ - name: '3.23'
+ ref: a40a2af7
+ - name: '3.22'
+ ref: 1ccd5b99
- name: '3.21'
ref: 5589d96b
- name: '3.20'
@@ -59,9 +75,47 @@ versions:
redirects:
- \A\/(.*)/\Z: /\1.html
rewrites:
- - search: cassandra.apache.org/doc/cql3/CQL.html
- replace: cassandra.apache.org/doc/cql3/CQL-3.0.html
- - search: http://www.datastax.com/documentation/cql/3.1/
- replace: https://docs.datastax.com/en/archived/cql/3.1/
- search: http://www.datastax.com/docs/1.2/cql_cli/cql/BATCH
replace: https://docs.datastax.com/en/dse/6.7/cql/cql/cql_reference/cql_commands/cqlBatch.html
+ - search: http://www.datastax.com/documentation/cql/3.1/
+ replace: https://docs.datastax.com/en/archived/cql/3.1/
+ - search: 'https://community.datastax.com'
+ replace: 'https://www.datastax.com/dev/community'
+ - search: 'https://docs.datastax.com/en/astra/aws/doc/index.html'
+ replace: 'https://docs.datastax.com/en/astra-serverless/docs/connect/drivers/connect-python.html'
+ - search: 'http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun'
+ replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/functions.html#timeuuid-functions'
+ - search: 'http://cassandra.apache.org/doc/cql3/CQL.html#tokenFun'
+ replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/functions.html#token'
+ - search: 'http://cassandra.apache.org/doc/cql3/CQL.html#collections'
+ replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/types.html#collections'
+ - search: 'http://cassandra.apache.org/doc/cql3/CQL.html#batchStmt'
+ replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/dml.html#batch_statement'
+ - search: 'http://cassandra.apache.org/doc/cql3/CQL-3.0.html#timeuuidFun'
+ replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/functions.html#timeuuid-functions'
+ - search: 'http://cassandra.apache.org/doc/cql3/CQL-3.0.html#tokenFun'
+ replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/functions.html#token'
+ - search: 'http://cassandra.apache.org/doc/cql3/CQL-3.0.html#collections'
+ replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/types.html#collections'
+ - search: 'http://cassandra.apache.org/doc/cql3/CQL-3.0.html#batchStmt'
+ replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/dml.html#batch_statement'
+checks:
+ external_links:
+ exclude:
+ - 'https://twitter.com/dsJavaDriver'
+ - 'https://twitter.com/datastaxeng'
+ - 'https://twitter.com/datastax'
+ - 'https://projectreactor.io'
+ - 'https://docs.datastax.com/en/drivers/java/4.[0-9]+/com/datastax/oss/driver/internal/'
+ - 'http://www.planetcassandra.org/blog/user-defined-functions-in-cassandra-3-0/'
+ - 'http://www.planetcassandra.org/making-the-change-from-thrift-to-cql/'
+ - 'https://academy.datastax.com/slack'
+ - 'https://community.datastax.com/index.html'
+ - 'https://micrometer.io/docs'
+ - 'http://datastax.github.io/java-driver/features/shaded_jar/'
+ - 'http://aka.ms/vcpython27'
+ internal_links:
+ exclude:
+ - 'netty_pipeline/'
+ - '../core/'
+ - '%5Bguava%20eviction%5D'
diff --git a/docs/.nav b/docs/.nav
index 7b39d9001d..79f3029073 100644
--- a/docs/.nav
+++ b/docs/.nav
@@ -10,5 +10,12 @@ upgrading
user_defined_types
dates_and_times
cloud
+column_encryption
+geo_types
+graph
+classic_graph
+graph_fluent
+CHANGELOG
faq
api
+
diff --git a/docs/api/cassandra/cluster.rst b/docs/api/cassandra/cluster.rst
index 71e110559e..a9a9d378a4 100644
--- a/docs/api/cassandra/cluster.rst
+++ b/docs/api/cassandra/cluster.rst
@@ -120,13 +120,19 @@
.. automethod:: set_meta_refresh_enabled
-.. autoclass:: ExecutionProfile (load_balancing_policy=