diff --git a/.gitignore b/.gitignore index f9a3ac8bd..e025587de 100644 --- a/.gitignore +++ b/.gitignore @@ -35,4 +35,6 @@ nosetests.xml .pydevproject # Pycharm -/.idea \ No newline at end of file +/.idea + +.venv diff --git a/HACKING.txt b/HACKING.txt index 4b3763681..e6bb1403c 100644 --- a/HACKING.txt +++ b/HACKING.txt @@ -4,7 +4,11 @@ Development setup Running nose tests with IPython is tricky, so there's a run_tests.sh script for it. -To temporarily insert breakpoints for debugging: `from nose.tools import set_trace; set_trace()` + pip install -e . + ./run_tests.sh + +To temporarily insert breakpoints for debugging: `from nose.tools import set_trace; set_trace()`. +Or, if running tests, use `pytest.set_trace()`. Tests have requirements not installed by setup.py: diff --git a/MANIFEST.in b/MANIFEST.in index 910adaed4..c27afb39e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,3 @@ include README.rst -include NEWS.txt +include NEWS.rst include LICENSE diff --git a/NEWS.txt b/NEWS.rst similarity index 76% rename from NEWS.txt rename to NEWS.rst index 989c61648..f540decbf 100644 --- a/NEWS.txt +++ b/NEWS.rst @@ -1,15 +1,15 @@ News -==== +---- 0.1 ---- +~~~ *Release date: 21-Mar-2013* * Initial release 0.1.1 ------ +~~~~~ *Release date: 29-Mar-2013* @@ -23,7 +23,7 @@ News 0.1.2 ------ +~~~~~ *Release date: 29-Mar-2013* @@ -34,14 +34,14 @@ News * allow multiple SQL per cell 0.2.0 ------ +~~~~~ *Release date: 30-May-2013* * Accept bind variables (Thanks Mike Wilson!) 
0.2.1 ------ +~~~~~ *Release date: 15-June-2013* @@ -50,21 +50,21 @@ News * Bugfix - issue 4 (remember existing connections by case) 0.2.2 ------ +~~~~~ *Release date: 30-July-2013* Converted from an IPython Plugin to an Extension for 1.0 compatibility 0.2.2.1 -------- +~~~~~~~ *Release date: 01-Aug-2013* Deleted Plugin import left behind in 0.2.2 0.2.3 ------ +~~~~~ *Release date: 20-Sep-2013* @@ -78,7 +78,7 @@ Deleted Plugin import left behind in 0.2.2 result sets 0.3.0 ------ +~~~~~ *Release date: 13-Oct-2013* @@ -91,58 +91,58 @@ Deleted Plugin import left behind in 0.2.2 * dict-style access for result sets by primary key 0.3.1 ------ +~~~~~ * Reporting of number of rows affected configurable with ``feedback`` * Local variables usable as SQL bind variables 0.3.2 ------ +~~~~~ * ``.csv(filename=None)`` method added to result sets 0.3.3 ------ +~~~~~ * Python 3 compatibility restored * DSN access supported (thanks Berton Earnshaw) 0.3.4 ------ +~~~~~ * PERSIST pseudo-SQL command added 0.3.5 ------ +~~~~~ * Indentations visible in HTML cells * COMMIT each SQL statement immediately - prevent locks 0.3.6 ------ +~~~~~ * Fixed issue #30, commit failures for sqlite (thanks stonebig, jandot) 0.3.7 ------ +~~~~~ * New `column_local_vars` config option submitted by darikg * Avoid contaminating user namespace from locals (thanks alope107) 0.3.7.1 -------- +~~~~~~~ * Avoid "connection busy" error for SQL Server (thanks Andrés Celis) 0.3.8 ------ +~~~~~ * Stop warnings for deprecated use of IPython 3 traitlets in IPython 4 (thanks graphaelli; also stonebig, aebrahim, mccahill) * README update for keeping connection info private, from eshilts 0.3.9 ------ +~~~~~ * Fix truth value of DataFrame error (thanks michael-erasmus) * `<<` operator (thanks xiaochuanyu) @@ -152,6 +152,28 @@ Deleted Plugin import left behind in 0.2.2 * conceal passwords in connection strings (thanks jstoebel) 0.3.9 ------ +~~~~~ -* Restored Python 2 compatibility (thanks tokenmathguy) \ No newline at end 
of file +* Restored Python 2 compatibility (thanks tokenmathguy) + +0.4.0 +~~~~~ + +* Changed most non-SQL commands to argparse arguments (thanks pik) +* User can specify a creator for connections (thanks pik) +* Bogus pseudo-SQL command `PERSIST` removed, replaced with `--persist` arg +* Turn off echo of connection information with `displaycon` in config +* Consistent support for {} variables (thanks Lucas) + +0.4.1 +~~~~~ + +* Fixed .rst file location in MANIFEST.in +* Parse SQL comments in first line +* Bugfixes for DSN, `--close`, others + +0.5.0 +~~~~~ + +* Use SQLAlchemy 2.0 +* Drop undocumented support for dict-style access to raw row instances \ No newline at end of file diff --git a/README.rst b/README.rst index 2b1d5166d..1c48d0c0e 100644 --- a/README.rst +++ b/README.rst @@ -6,7 +6,15 @@ ipython-sql Introduces a %sql (or %%sql) magic. -Connect to a database, using SQLAlchemy connect strings, then issue SQL +Legacy project +-------------- + +IPython-SQL's functionality and maintenance have been eclipsed by JupySQL_, a fork maintained and developed by the Ploomber team. Future work will be directed into JupySQL - please file issues there, as well! + +Description +----------- + +Connect to a database, using `SQLAlchemy URL`_ connect strings, then issue SQL commands within IPython or IPython Notebook. .. image:: https://raw.github.com/catherinedevlin/ipython-sql/master/examples/writers.png @@ -92,30 +100,47 @@ makes sense for statements with no output Out[11]: [] -Bind variables (bind parameters) can be used in the "named" (:x) style. -The variable names used should be defined in the local namespace - -.. code-block:: python - - In [12]: name = 'Countess' - - In [13]: %sql select description from character where charname = :name - Out[13]: [(u'mother to Bertram',)] - As a convenience, dict-style access for result sets is supported, with the leftmost column serving as key, for unique values. .. 
code-block:: python - In [14]: result = %sql select * from work + In [12]: result = %sql select * from work 43 rows affected. - In [15]: result['richard2'] - Out[15]: (u'richard2', u'Richard II', u'History of Richard II', 1595, u'h', None, u'Moby', 22411, 628) + In [13]: result['richard2'] + Out[14]: (u'richard2', u'Richard II', u'History of Richard II', 1595, u'h', None, u'Moby', 22411, 628) Results can also be retrieved as an iterator of dictionaries (``result.dicts()``) or a single dictionary with a tuple of scalar values per key (``result.dict()``) +Variable substitution +--------------------- + +Bind variables (bind parameters) can be used in the "named" (:x) style. +The variable names used should be defined in the local namespace. + +.. code-block:: python + + In [15]: name = 'Countess' + + In [16]: %sql select description from character where charname = :name + Out[16]: [(u'mother to Bertram',)] + + In [17]: %sql select description from character where charname = '{name}' + Out[17]: [(u'mother to Bertram',)] + +Alternately, ``$variable_name`` or ``{variable_name}`` can be +used to inject variables from the local namespace into the SQL +statement before it is formed and passed to the SQL engine. +(Using ``$`` and ``{}`` together, as in ``${variable_name}``, +is not supported.) + +Bind variables are passed through to the SQL engine and can only +be used to replace strings passed to SQL. ``$`` and ``{}`` are +substituted before passing to SQL and can be used to form SQL +statements dynamically. + Assignment ---------- @@ -123,7 +148,7 @@ Ordinary IPython assignment works for single-line `%sql` queries: .. code-block:: python - In [16]: works = %sql SELECT title, year FROM work + In [18]: works = %sql SELECT title, year FROM work 43 rows affected. The `<<` operator captures query results in a local variable, and @@ -131,7 +156,7 @@ can be used in multi-line ``%%sql``: .. 
code-block:: python - In [17]: %%sql works << SELECT title, year + In [19]: %%sql works << SELECT title, year ...: FROM work ...: 43 rows affected. @@ -140,7 +165,7 @@ can be used in multi-line ``%%sql``: Connecting ---------- -Connection strings are `SQLAlchemy`_ standard. +Connection strings are `SQLAlchemy URL`_ standard. Some example connection strings:: @@ -150,7 +175,7 @@ Some example connection strings:: sqlite:///foo.db mssql+pyodbc://username:password@host/database?driver=SQL+Server+Native+Client+11.0 -.. _SQLAlchemy: http://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls +.. _`SQLAlchemy URL`: http://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls Note that ``mysql`` and ``mysql+pymysql`` connections (and perhaps others) don't read your client character set information from .my.cnf. You need @@ -158,13 +183,44 @@ to specify it in the connection string:: mysql+pymysql://scott:tiger@localhost/foo?charset=utf8 -Note that ``impala`` connecion with `impyla`_ for HiveServer2 requires to disable autocommit:: +Note that an ``impala`` connection with `impyla`_ for HiveServer2 requires disabling autocommit:: %config SqlMagic.autocommit=False %sql impala://hserverhost:port/default?kerberos_service_name=hive&auth_mechanism=GSSAPI .. _impyla: https://github.com/cloudera/impyla +Connection arguments not whitelisted by SQLAlchemy can be provided as +a flag with (-a|--connection_arguments) the connection string as a JSON string. +See `SQLAlchemy Args`_. + + | %sql --connection_arguments {"timeout":10,"mode":"ro"} sqlite:// SELECT * FROM work; + | %sql -a '{"timeout":10, "mode":"ro"}' sqlite:// SELECT * from work; + +.. _`SQLAlchemy Args`: https://docs.sqlalchemy.org/en/13/core/engines.html#custom-dbapi-args + +DSN connections +~~~~~~~~~~~~~~~ + +Alternately, you can store connection info in a +configuration file, under a section name chosen to +refer to your database. 
+ +For example, if dsn.ini contains + + | [DB_CONFIG_1] + | drivername=postgres + | host=my.remote.host + | port=5433 + | database=mydatabase + | username=myuser + | password=1234 + +then you can + + | %config SqlMagic.dsn_filename='./dsn.ini' + | %sql --section DB_CONFIG_1 + Configuration ------------- @@ -177,34 +233,45 @@ only the screen display is truncated. .. code-block:: python - In [2]: %config SqlMagic - SqlMagic options - -------------- - SqlMagic.autocommit= - Current: True - Set autocommit mode - SqlMagic.autolimit= - Current: 0 - Automatically limit the size of the returned result sets - SqlMagic.autopandas= - Current: False - Return Pandas DataFrames instead of regular result sets - SqlMagic.displaylimit= - Current: 0 - Automatically limit the number of rows displayed (full result set is still - stored) - SqlMagic.feedback= - Current: True - Print number of rows affected by DML - SqlMagic.short_errors= - Current: True - Don't display the full traceback on SQL Programming Error - SqlMagic.style= - Current: 'DEFAULT' - Set the table printing style to any of prettytable's defined styles - (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM) - - In[3]: %config SqlMagic.feedback = False + In [2]: %config SqlMagic + SqlMagic options + -------------- + SqlMagic.autocommit= + Current: True + Set autocommit mode + SqlMagic.autolimit= + Current: 0 + Automatically limit the size of the returned result sets + SqlMagic.autopandas= + Current: False + Return Pandas DataFrames instead of regular result sets + SqlMagic.column_local_vars= + Current: False + Return data into local variables from column names + SqlMagic.displaycon= + Current: False + Show connection string after execute + SqlMagic.displaylimit= + Current: None + Automatically limit the number of rows displayed (full result set is still + stored) + SqlMagic.dsn_filename= + Current: 'odbc.ini' + Path to DSN file. 
When the first argument is of the form [section], a + sqlalchemy connection string is formed from the matching section in the DSN + file. + SqlMagic.feedback= + Current: False + Print number of rows affected by DML + SqlMagic.short_errors= + Current: True + Don't display the full traceback on SQL Programming Error + SqlMagic.style= + Current: 'DEFAULT' + Set the table printing style to any of prettytable's defined styles + (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM) + + In[3]: %config SqlMagic.feedback = False Please note: if you have autopandas set to true, the displaylimit option will not apply. You can set the pandas display limit by using the pandas ``max_rows`` option as described in the `pandas documentation `_. @@ -220,12 +287,17 @@ If you have installed ``pandas``, you can use a result set's In [4]: dataframe = result.DataFrame() -The bogus non-standard pseudo-SQL command ``PERSIST`` will create a table name -in the database from the named DataFrame. + +The ``--persist`` argument, with the name of a +DataFrame object in memory, +will create a table name +in the database from the named DataFrame. +Or use ``--append`` to add rows to an existing +table by that name. .. code-block:: python - In [5]: %sql PERSIST dataframe + In [5]: %sql --persist dataframe In [6]: %sql SELECT * FROM dataframe; @@ -275,10 +347,55 @@ are provided by `PGSpecial`_. Example: .. 
_meta-commands: https://www.postgresql.org/docs/9.6/static/app-psql.html#APP-PSQL-META-COMMANDS + +Options +------- + +``-l`` / ``--connections`` + List all active connections + +``-x`` / ``--close `` + Close named connection + +``-c`` / ``--creator `` + Specify creator function for new connection + +``-s`` / ``--section `` + Section of dsn_file to be used for generating a connection string + +``-p`` / ``--persist`` + Create a table name in the database from the named DataFrame + +``--append`` + Like ``--persist``, but appends to the table if it already exists + +``-a`` / ``--connection_arguments <"{connection arguments}">`` + Specify dictionary of connection arguments to pass to SQL driver + +``-f`` / ``--file `` + Run SQL from file at this path + +Caution +------- + +Comments +~~~~~~~~ + +Because ipython-sql accepts ``--``-delimited options like ``--persist``, but ``--`` +is also the syntax to denote a SQL comment, the parser needs to make some assumptions. + +- If you try to pass an unsupported argument, like ``--lutefisk``, it will + be interpreted as a SQL comment and will not throw an unsupported argument + exception. +- If the SQL statement begins with a first-line comment that looks like one + of the accepted arguments - like ``%sql --persist is great!`` - it will be + parsed like an argument, not a comment. Moving the comment to the second + line or later will avoid this. 
+ Installing ---------- -Install the lastest release with:: +Install the latest release with:: pip install ipython-sql @@ -303,12 +420,23 @@ Credits - Mike Wilson for bind variable code - Thomas Kluyver and Steve Holden for debugging help - Berton Earnshaw for DSN connection syntax +- Bruno Harbulot for DSN example - Andrés Celis for SQL Server bugfix - Michael Erasmus for DataFrame truth bugfix - Noam Finkelstein for README clarification - Xiaochuan Yu for `<<` operator, syntax colorization - Amjith Ramanujam for PGSpecial and incorporating it here +- Alexander Maznev for better arg parsing, connections accepting specified creator +- Jonathan Larkin for configurable displaycon +- Jared Moore for ``connection-arguments`` support +- Gilbert Brault for ``--append`` +- Lucas Zeer for multi-line bugfixes for var substitution, ``<<`` +- vkk800 for ``--file`` +- Jens Albrecht for MySQL DatabaseError bugfix +- meihkv for connection-closing bugfix +- Abhinav C for SQLAlchemy 2.0 compatibility .. _Distribute: http://pypi.python.org/pypi/distribute .. _Buildout: http://www.buildout.org/ .. _modern-package-template: http://pypi.python.org/pypi/modern-package-template +.. 
_JupySQL: https://github.com/ploomber/jupysql diff --git a/examples/writers.ipynb b/examples/writers.ipynb index 48ae7357b..9b41b00ac 100644 --- a/examples/writers.ipynb +++ b/examples/writers.ipynb @@ -1,112 +1,305 @@ { - "metadata": { - "name": "writers" - }, - "nbformat": 3, - "nbformat_minor": 0, - "worksheets": [ + "cells": [ { - "cells": [ + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "%load_ext sql" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "%sql sqlite://" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ { - "cell_type": "code", - "collapsed": false, - "input": [ - "%load_ext sql" - ], - "language": "python", - "metadata": {}, - "outputs": [], - "prompt_number": 3 + "name": "stdout", + "output_type": "stream", + "text": [ + " * sqlite://\n", + "Done.\n", + "1 rows affected.\n", + "1 rows affected.\n" + ] }, { - "cell_type": "code", - "collapsed": false, - "input": [ - "%sql sqlite://" - ], - "language": "python", + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 3, "metadata": {}, - "outputs": [ - { - "output_type": "pyout", - "prompt_number": 4, - "text": [ - "'Connected: None@None'" - ] - } - ], - "prompt_number": 4 + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "CREATE TABLE writer (first_name, last_name, year_of_death);\n", + "INSERT INTO writer VALUES ('William', 'Shakespeare', 1616);\n", + "INSERT INTO writer VALUES ('Bertold', 'Brecht', 1956);" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", 
+ "output_type": "stream", + "text": [ + " * sqlite://\n", + "Done.\n" + ] }, { - "cell_type": "code", - "collapsed": false, - "input": [ - "%%sql\n", - "CREATE TABLE writer (first_name, last_name, year_of_death);\n", - "INSERT INTO writer VALUES ('William', 'Shakespeare', 1616);\n", - "INSERT INTO writer VALUES ('Bertold', 'Brecht', 1956);" - ], - "language": "python", + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
first_namelast_nameyear_of_death
WilliamShakespeare1616
BertoldBrecht1956
" + ], + "text/plain": [ + "[('William', 'Shakespeare', 1616), ('Bertold', 'Brecht', 1956)]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%sql select * from writer" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " * sqlite://\n", + "Done.\n", + "Returning data to local variable writers\n" + ] + } + ], + "source": [ + "%%sql writers << select first_name, year_of_death\n", + "from writer" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
first_nameyear_of_death
William1616
Bertold1956
" + ], + "text/plain": [ + "[('William', 1616), ('Bertold', 1956)]" + ] + }, + "execution_count": 6, "metadata": {}, - "outputs": [ - { - "output_type": "pyout", - "prompt_number": 5, - "text": [ - "[]" - ] - } - ], - "prompt_number": 5 + "output_type": "execute_result" + } + ], + "source": [ + "writers" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "var = 'last_name'" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " * sqlite://\n", + "Done.\n" + ] }, { - "cell_type": "code", - "collapsed": false, - "input": [ - "%sql select * from writer" - ], - "language": "python", + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
first_namelast_nameyear_of_death
BertoldBrecht1956
" + ], + "text/plain": [ + "[('Bertold', 'Brecht', 1956)]" + ] + }, + "execution_count": 8, "metadata": {}, - "outputs": [ - { - "html": [ - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
first_namelast_nameyear_of_death
WilliamShakespeare1616
BertoldBrecht1956
" - ], - "output_type": "pyout", - "prompt_number": 6, - "text": [ - "[('William', 'Shakespeare', 1616), ('Bertold', 'Brecht', 1956)]" - ] - } - ], - "prompt_number": 6 + "output_type": "execute_result" + } + ], + "source": [ + "%sql select * from writer where {var} = 'Brecht'" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " * sqlite://\n", + "Done.\n" + ] }, { - "cell_type": "code", - "collapsed": false, - "input": [], - "language": "python", + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
first_namelast_nameyear_of_death
BertoldBrecht1956
" + ], + "text/plain": [ + "[('Bertold', 'Brecht', 1956)]" + ] + }, + "execution_count": 9, "metadata": {}, - "outputs": [] + "output_type": "execute_result" } ], - "metadata": {} + "source": [ + "%%sql select * from writer \n", + "where {var} = 'Brecht'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" } - ] -} \ No newline at end of file + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/requirements-dev.txt b/requirements-dev.txt index b096622f2..f48c0c875 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,8 @@ -psycopg2 +psycopg2-binary pandas pytest - +wheel +twine +readme-renderer +black +isort diff --git a/requirements.txt b/requirements.txt index f528f3b49..4b9734adf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,8 @@ -prettytable=0.7.2 -ipython>=1.0 -sqlalchemy>=0.6.7 +prettytable +ipython +sqlalchemy>=2.0 sqlparse six -ipython-genutils>=0.1.0 +ipython-genutils +traitlets +matplotlib diff --git a/setup.py b/setup.py index 7047857d6..d628b4b6b 100644 --- a/setup.py +++ b/setup.py @@ -1,44 +1,49 @@ -from io import open -from setuptools import setup, find_packages import os +from io import open + +from setuptools import find_packages, setup here = os.path.abspath(os.path.dirname(__file__)) -README = open(os.path.join(here, 'README.rst'), encoding='utf-8').read() -NEWS = open(os.path.join(here, 'NEWS.txt'), encoding='utf-8').read() +README = open(os.path.join(here, "README.rst"), encoding="utf-8").read() +NEWS = open(os.path.join(here, "NEWS.rst"), encoding="utf-8").read() -version = 
'0.3.9' +version = "0.5.0" install_requires = [ - 'prettytable', - 'ipython>=1.0', - 'sqlalchemy>=0.6.7', - 'sqlparse', - 'six', - 'ipython-genutils>=0.1.0', + "prettytable", + "ipython", + "sqlalchemy>=2.0", + "sqlparse", + "six", + "ipython-genutils", ] -setup(name='ipython-sql', +setup( + name="ipython-sql", version=version, description="RDBMS access via IPython", - long_description=README + '\n\n' + NEWS, + long_description=README + "\n\n" + NEWS, + long_description_content_type="text/x-rst", classifiers=[ - 'Development Status :: 3 - Alpha', - 'Environment :: Console', - 'License :: OSI Approved :: MIT License', - 'Topic :: Database', - 'Topic :: Database :: Front-Ends', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 2', + "Development Status :: 3 - Alpha", + "Environment :: Console", + "License :: OSI Approved :: MIT License", + "Topic :: Database", + "Topic :: Database :: Front-Ends", + "Programming Language :: Python :: 3", ], - keywords='database ipython postgresql mysql', - author='Catherine Devlin', - author_email='catherine.devlin@gmail.com', - url='https://pypi.python.org/pypi/ipython-sql', - license='MIT', - packages=find_packages('src'), - package_dir = {'': 'src'}, + keywords="database ipython postgresql mysql", + author="Catherine Devlin", + author_email="catherine.devlin@gmail.com", + url="https://github.com/catherinedevlin/ipython-sql", + project_urls={ + "Source": "https://github.com/catherinedevlin/ipython-sql", + }, + license="MIT", + packages=find_packages("src"), + package_dir={"": "src"}, include_package_data=True, zip_safe=False, install_requires=install_requires, diff --git a/src/sql/column_guesser.py b/src/sql/column_guesser.py index 8e46ac289..33e40c47a 100644 --- a/src/sql/column_guesser.py +++ b/src/sql/column_guesser.py @@ -4,26 +4,33 @@ (X values, Y values, and text labels). 
""" + class Column(list): - 'Store a column of tabular data; record its name and whether it is numeric' + """Store a column of tabular data; record its name and whether it is numeric""" is_quantity = True - name = '' + name = "" + def __init__(self, *arg, **kwarg): pass - + def is_quantity(val): """Is ``val`` a quantity (int, float, datetime, etc) (not str, bool)? Relies on presence of __sub__. """ - return hasattr(val, '__sub__') + return hasattr(val, "__sub__") + class ColumnGuesserMixin(object): """ plot: [x, y, y...], y pie: ... y """ + + def __init__(self): + self.keys = None + def _build_columns(self): self.columns = [Column() for col in self.keys] for row in self: @@ -32,39 +39,40 @@ def _build_columns(self): col.append(col_val) if (col_val is not None) and (not is_quantity(col_val)): col.is_quantity = False - + for (idx, key_name) in enumerate(self.keys): self.columns[idx].name = key_name - + self.x = Column() self.ys = [] - + def _get_y(self): - for idx in range(len(self.columns)-1,-1,-1): + for idx in range(len(self.columns) - 1, -1, -1): if self.columns[idx].is_quantity: self.ys.insert(0, self.columns.pop(idx)) return True - def _get_x(self): + def _get_x(self): for idx in range(len(self.columns)): if self.columns[idx].is_quantity: self.x = self.columns.pop(idx) return True - + def _get_xlabel(self, xlabel_sep=" "): self.xlabels = [] if self.columns: for row_idx in range(len(self.columns[0])): - self.xlabels.append(xlabel_sep.join( - str(c[row_idx]) for c in self.columns)) + self.xlabels.append( + xlabel_sep.join(str(c[row_idx]) for c in self.columns) + ) self.xlabel = ", ".join(c.name for c in self.columns) - + def _guess_columns(self): self._build_columns() self._get_y() if not self.ys: raise AttributeError("No quantitative columns found for chart") - + def guess_pie_columns(self, xlabel_sep=" "): """ Assigns x, y, and x labels from the data set for a pie chart. 
@@ -75,7 +83,7 @@ def guess_pie_columns(self, xlabel_sep=" "): """ self._guess_columns() self._get_xlabel(xlabel_sep) - + def guess_plot_columns(self): """ Assigns ``x`` and ``y`` series from the data set for a plot. @@ -88,4 +96,4 @@ def guess_plot_columns(self): self._guess_columns() self._get_x() while self._get_y(): - pass \ No newline at end of file + pass diff --git a/src/sql/connection.py b/src/sql/connection.py index 986c043a7..fa6f4cdb7 100644 --- a/src/sql/connection.py +++ b/src/sql/connection.py @@ -1,20 +1,22 @@ -import sqlalchemy import os -import re +import traceback + +import sqlalchemy + class ConnectionError(Exception): pass def rough_dict_get(dct, sought, default=None): - ''' + """ Like dct.get(sought), but any key containing sought will do. If there is a `@` in sought, seek each piece separately. This lets `me@server` match `me:***@myserver/db` - ''' - - sought = sought.split('@') + """ + + sought = sought.split("@") for (key, val) in dct.items(): if not any(s.lower() not in key.lower() for s in sought): return val @@ -29,54 +31,90 @@ class Connection(object): def tell_format(cls): return """Connection info needed in SQLAlchemy format, example: postgresql://username:password@hostname/dbname - or an existing connection: %s""" % str(cls.connections.keys()) + or an existing connection: %s""" % str( + cls.connections.keys() + ) - def __init__(self, connect_str=None): + def __init__(self, connect_str=None, connect_args={}, creator=None): try: - engine = sqlalchemy.create_engine(connect_str) - except: # TODO: bare except; but what's an ArgumentError? + if creator: + engine = sqlalchemy.create_engine( + connect_str, connect_args=connect_args, creator=creator + ) + else: + engine = sqlalchemy.create_engine( + connect_str, connect_args=connect_args + ) + except Exception as ex: # TODO: bare except; but what's an ArgumentError? 
+ print(traceback.format_exc()) print(self.tell_format()) raise + self.url = engine.url self.dialect = engine.url.get_dialect() - self.metadata = sqlalchemy.MetaData(bind=engine) self.name = self.assign_name(engine) - self.session = engine.connect() - self.connections[repr(self.metadata.bind.url)] = self + self.internal_connection = engine.connect() + self.connections[repr(self.url)] = self + self.connect_args = connect_args Connection.current = self @classmethod - def set(cls, descriptor): - "Sets the current database connection" + def set(cls, descriptor, displaycon, connect_args={}, creator=None): + """Sets the current database connection""" if descriptor: if isinstance(descriptor, Connection): cls.current = descriptor else: existing = rough_dict_get(cls.connections, descriptor) - cls.current = existing or Connection(descriptor) + # http://docs.sqlalchemy.org/en/rel_0_9/core/engines.html#custom-dbapi-connect-arguments + cls.current = existing or Connection(descriptor, connect_args, creator) else: + if cls.connections: - print(cls.connection_list()) + if displaycon: + print(cls.connection_list()) else: - if os.getenv('DATABASE_URL'): - cls.current = Connection(os.getenv('DATABASE_URL')) + if os.getenv("DATABASE_URL"): + cls.current = Connection( + os.getenv("DATABASE_URL"), connect_args, creator + ) else: - raise ConnectionError('Environment variable $DATABASE_URL not set, and no connect string given.') + raise ConnectionError( + "Environment variable $DATABASE_URL not set, and no connect string given." 
+ ) return cls.current @classmethod def assign_name(cls, engine): - name = '%s@%s' % (engine.url.username or '', engine.url.database) + name = "%s@%s" % (engine.url.username or "", engine.url.database) return name @classmethod def connection_list(cls): result = [] for key in sorted(cls.connections): - engine_url = cls.connections[key].metadata.bind.url # type: sqlalchemy.engine.url.URL + engine_url = cls.connections[ + key + ].url # type: sqlalchemy.engine.url.URL if cls.connections[key] == cls.current: - template = ' * {}' + template = " * {}" else: - template = ' {}' + template = " {}" result.append(template.format(engine_url.__repr__())) - return '\n'.join(result) + return "\n".join(result) + + @classmethod + def close(cls, descriptor): + if isinstance(descriptor, Connection): + conn = descriptor + else: + conn = cls.connections.get(descriptor) or cls.connections.get( + descriptor.lower() + ) + if not conn: + raise Exception( + "Could not close connection because it was not found amongst these: %s" + % str(cls.connections.keys()) + ) + cls.connections.pop(str(conn.url)) + conn.internal_connection.close() diff --git a/src/sql/magic.py b/src/sql/magic.py index 58812a40f..d3c876e6f 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -1,6 +1,21 @@ +import json import re -from IPython.core.magic import Magics, magics_class, cell_magic, line_magic, needs_local_scope -from IPython.display import display_javascript +import traceback + +from IPython.core.magic import ( + Magics, + cell_magic, + line_magic, + magics_class, + needs_local_scope, +) +from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring +from sqlalchemy.exc import OperationalError, ProgrammingError, DatabaseError + +import sql.connection +import sql.parse +import sql.run + try: from traitlets.config.configurable import Configurable from traitlets import Bool, Int, Unicode @@ -13,12 +28,6 @@ DataFrame = None Series = None -from sqlalchemy.exc import ProgrammingError, 
OperationalError - -import sql.connection -import sql.parse -import sql.run - @magics_class class SqlMagic(Magics, Configurable): @@ -26,31 +35,93 @@ class SqlMagic(Magics, Configurable): Provides the %%sql magic.""" - autolimit = Int(0, config=True, allow_none=True, help="Automatically limit the size of the returned result sets") - style = Unicode('DEFAULT', config=True, help="Set the table printing style to any of prettytable's defined styles (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)") - short_errors = Bool(True, config=True, help="Don't display the full traceback on SQL Programming Error") - displaylimit = Int(None, config=True, allow_none=True, help="Automatically limit the number of rows displayed (full result set is still stored)") - autopandas = Bool(False, config=True, help="Return Pandas DataFrames instead of regular result sets") - column_local_vars = Bool(False, config=True, help="Return data into local variables from column names") + displaycon = Bool(True, config=True, help="Show connection string after execute") + autolimit = Int( + 0, + config=True, + allow_none=True, + help="Automatically limit the size of the returned result sets", + ) + style = Unicode( + "DEFAULT", + config=True, + help="Set the table printing style to any of prettytable's defined styles " + "(currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)", + ) + short_errors = Bool( + True, + config=True, + help="Don't display the full traceback on SQL Programming Error", + ) + displaylimit = Int( + None, + config=True, + allow_none=True, + help="Automatically limit the number of rows displayed (full result set is still stored)", + ) + autopandas = Bool( + False, + config=True, + help="Return Pandas DataFrames instead of regular result sets", + ) + column_local_vars = Bool( + False, config=True, help="Return data into local variables from column names" + ) feedback = Bool(True, config=True, help="Print number of rows affected by DML") - dsn_filename = 
Unicode('odbc.ini', config=True, help="Path to DSN file. " - "When the first argument is of the form [section], " - "a sqlalchemy connection string is formed from the " - "matching section in the DSN file.") + dsn_filename = Unicode( + "odbc.ini", + config=True, + help="Path to DSN file. " + "When the first argument is of the form [section], " + "a sqlalchemy connection string is formed from the " + "matching section in the DSN file.", + ) autocommit = Bool(True, config=True, help="Set autocommit mode") - def __init__(self, shell): Configurable.__init__(self, config=shell.config) Magics.__init__(self, shell=shell) - # Add ourself to the list of module configurable via %config + # Add ourselves to the list of module configurable via %config self.shell.configurables.append(self) @needs_local_scope - @line_magic('sql') - @cell_magic('sql') - def execute(self, line, cell='', local_ns={}): + @line_magic("sql") + @cell_magic("sql") + @magic_arguments() + @argument("line", default="", nargs="*", type=str, help="sql") + @argument( + "-l", "--connections", action="store_true", help="list active connections" + ) + @argument("-x", "--close", type=str, help="close a session by name") + @argument( + "-c", "--creator", type=str, help="specify creator function for new connection" + ) + @argument( + "-s", + "--section", + type=str, + help="section of dsn_file to be used for generating a connection string", + ) + @argument( + "-p", + "--persist", + action="store_true", + help="create a table name in the database from the named DataFrame", + ) + @argument( + "--append", + action="store_true", + help="create, or append to, a table name in the database from the named DataFrame", + ) + @argument( + "-a", + "--connection_arguments", + type=str, + help="specify dictionary of connection arguments to pass to SQL driver", + ) + @argument("-f", "--file", type=str, help="Run SQL from file at this path") + def execute(self, line="", cell="", local_ns=None): """Runs SQL statement against a 
database, specified by SQLAlchemy connect string. If no database connection has been established, first word @@ -75,28 +146,85 @@ def execute(self, line, cell='', local_ns={}): mysql+pymysql://me:mypw@localhost/mydb """ - # save globals and locals so they can be referenced in bind vars + # Parse variables (words wrapped in {}) for %%sql magic (for %sql this is done automatically) + if local_ns is None: + local_ns = {} + cell = self.shell.var_expand(cell) + line = sql.parse.without_sql_comment(parser=self.execute.parser, line=line) + args = parse_argstring(self.execute, line) + if args.connections: + return sql.connection.Connection.connections + elif args.close: + return sql.connection.Connection.close(args.close) + + # save globals and locals, so they can be referenced in bind vars user_ns = self.shell.user_ns.copy() user_ns.update(local_ns) - parsed = sql.parse.parse('%s\n%s' % (line, cell), self) - flags = parsed['flags'] + command_text = " ".join(args.line) + "\n" + cell + + if args.file: + with open(args.file, "r") as infile: + command_text = infile.read() + "\n" + command_text + + parsed = sql.parse.parse(command_text, self) + + connect_str = parsed["connection"] + if args.section: + connect_str = sql.parse.connection_from_dsn_section(args.section, self) + + if args.connection_arguments: + try: + # check for string delineators, we need to strip them for json parse + raw_args = args.connection_arguments + if len(raw_args) > 1: + targets = ['"', "'"] + head = raw_args[0] + tail = raw_args[-1] + if head in targets and head == tail: + raw_args = raw_args[1:-1] + args.connection_arguments = json.loads(raw_args) + except Exception as e: + print(traceback.format_exc()) + raise e + else: + args.connection_arguments = {} + if args.creator: + args.creator = user_ns[args.creator] + try: - conn = sql.connection.Connection.set(parsed['connection']) - except Exception as e: - print(e) + conn = sql.connection.Connection.set( + connect_str, + displaycon=self.displaycon, + 
connect_args=args.connection_arguments, + creator=args.creator, + ) + # Rollback just in case there was an error in previous statement + conn.internal_connection.rollback() + except Exception: + print(traceback.format_exc()) print(sql.connection.Connection.tell_format()) return None - if flags.get('persist'): - return self._persist_dataframe(parsed['sql'], conn, user_ns) + if args.persist: + return self._persist_dataframe(parsed["sql"], conn, user_ns, append=False) + + if args.append: + return self._persist_dataframe(parsed["sql"], conn, user_ns, append=True) + + if not parsed["sql"]: + return try: - result = sql.run.run(conn, parsed['sql'], self, user_ns) + result = sql.run.run(conn, parsed["sql"], self, user_ns) - if result is not None and not isinstance(result, str) and self.column_local_vars: - #Instead of returning values, set variables directly in the - #users namespace. Variable names given by column names + if ( + result is not None + and not isinstance(result, str) + and self.column_local_vars + ): + # Instead of returning values, set variables directly in the + # user's namespace. 
Variable names given by column names if self.autopandas: keys = result.keys() @@ -105,51 +233,59 @@ def execute(self, line, cell='', local_ns={}): result = result.dict() if self.feedback: - print('Returning data to local variables [{}]'.format( - ', '.join(keys))) + print( + "Returning data to local variables [{}]".format(", ".join(keys)) + ) self.shell.user_ns.update(result) return None else: - if flags.get('result_var'): - result_var = flags['result_var'] + if parsed["result_var"]: + result_var = parsed["result_var"] print("Returning data to local variable {}".format(result_var)) self.shell.user_ns.update({result_var: result}) return None - #Return results into the default ipython _ variable + # Return results into the default ipython _ variable return result - except (ProgrammingError, OperationalError) as e: + # JA: added DatabaseError for MySQL + except (ProgrammingError, OperationalError, DatabaseError) as e: # Sqlite apparently return all errors as OperationalError :/ if self.short_errors: print(e) else: - raise + print(traceback.format_exc()) + raise e - legal_sql_identifier = re.compile(r'^[A-Za-z0-9#_$]+') - def _persist_dataframe(self, raw, conn, user_ns): + legal_sql_identifier = re.compile(r"^[A-Za-z0-9#_$]+") + + def _persist_dataframe(self, raw, conn, user_ns, append=False): """Implements PERSIST, which writes a DataFrame to the RDBMS""" if not DataFrame: raise ImportError("Must `pip install pandas` to use DataFrames") - frame_name = raw.strip(';') + frame_name = raw.strip(";") # Get the DataFrame from the user namespace if not frame_name: - raise SyntaxError('Syntax: %sql PERSIST ') - frame = eval(frame_name, user_ns) + raise SyntaxError("Syntax: %sql --persist ") + try: + frame = eval(frame_name, user_ns) + except SyntaxError: + raise SyntaxError("Syntax: %sql --persist ") if not isinstance(frame, DataFrame) and not isinstance(frame, Series): - raise TypeError('%s is not a Pandas DataFrame or Series' % frame_name) + raise TypeError("%s is not a 
Pandas DataFrame or Series" % frame_name) - # Make a suitable name for the resulting database table + # Make a suitable name for the resulting database table table_name = frame_name.lower() table_name = self.legal_sql_identifier.search(table_name).group(0) - frame.to_sql(table_name, conn.session.engine) - return 'Persisted %s' % table_name + if_exists = "append" if append else "fail" + frame.to_sql(table_name, conn.internal_connection.engine, if_exists=if_exists) + return "Persisted %s" % table_name def load_ipython_extension(ip): diff --git a/src/sql/parse.py b/src/sql/parse.py index 800729570..71f3edd06 100644 --- a/src/sql/parse.py +++ b/src/sql/parse.py @@ -1,54 +1,91 @@ +import itertools +import shlex from os.path import expandvars -import six + from six.moves import configparser as CP from sqlalchemy.engine.url import URL -def parse(cell, config): - """Separate input into (connection info, SQL statement)""" +def connection_from_dsn_section(section, config): + parser = CP.ConfigParser() + parser.read(config.dsn_filename) + cfg_dict = dict(parser.items(section)) + return URL.create(**cfg_dict) - parts = [part.strip() for part in cell.split(None, 1)] - if not parts: - return {'connection': '', 'sql': '', 'flags': {}} - parts[0] = expandvars(parts[0]) # for environment variables - if parts[0].startswith('[') and parts[0].endswith(']'): - section = parts[0].lstrip('[').rstrip(']') +def _connection_string(s, config): + s = expandvars(s) # for environment variables + if "@" in s or "://" in s: + return s + if s.startswith("[") and s.endswith("]"): + section = s.lstrip("[").rstrip("]") parser = CP.ConfigParser() parser.read(config.dsn_filename) cfg_dict = dict(parser.items(section)) + return str(URL(**cfg_dict)) + return "" + + +def parse(cell, config): + """Extract connection info and result variable from SQL + + Please don't add any more syntax requiring + special parsing. + Instead, add @arguments to SqlMagic.execute. 
+ + We're grandfathering the + connection string and `<<` operator in. + """ + + result = {"connection": "", "sql": "", "result_var": None} + + pieces = cell.split(None, 1) + if not pieces: + return result + result["connection"] = _connection_string(pieces[0], config) + if result["connection"]: + if len(pieces) == 1: + return result + cell = pieces[1] + + pieces = cell.split(None, 2) + if len(pieces) > 1 and pieces[1] == "<<": + result["result_var"] = pieces[0] + if len(pieces) == 2: + return result + cell = pieces[2] + + result["sql"] = cell + return result + + +def _option_strings_from_parser(parser): + """Extracts the expected option strings (-a, --append, etc) from argparse parser + + Thanks Martijn Pieters + https://stackoverflow.com/questions/28881456/how-can-i-list-all-registered-arguments-from-an-argumentparser-instance + + :param parser: [description] + :type parser: IPython.core.magic_arguments.MagicArgumentParser + """ + opts = [a.option_strings for a in parser._actions] + return list(itertools.chain.from_iterable(opts)) + + +def without_sql_comment(parser, line): + """Strips -- comment from a line + + The argparser unfortunately expects -- to precede an option, + but in SQL that delineates a comment. So this removes comments + so a line can safely be fed to the argparser. 
+ + :param line: A line of SQL, possibly mixed with option strings + :type line: str + """ - connection = str(URL(**cfg_dict)) - sql = parts[1] if len(parts) > 1 else '' - elif '@' in parts[0] or '://' in parts[0]: - connection = parts[0] - if len(parts) > 1: - sql = parts[1] - else: - sql = '' - else: - connection = '' - sql = cell - flags, sql = parse_sql_flags(sql.strip()) - return {'connection': connection.strip(), - 'sql': sql, - 'flags': flags} - - -def parse_sql_flags(sql): - words = sql.split() - flags = { - 'persist': False, - 'result_var': None - } - if not words: - return (flags, "") - num_words = len(words) - trimmed_sql = sql - if words[0].lower() == 'persist': - flags['persist'] = True - trimmed_sql = " ".join(words[1:]) - elif num_words >= 2 and words[1] == '<<': - flags['result_var'] = words[0] - trimmed_sql = " ".join(words[2:]) - return (flags, trimmed_sql.strip()) + args = _option_strings_from_parser(parser) + result = itertools.takewhile( + lambda word: (not word.startswith("--")) or (word in args), + shlex.split(line, posix=False), + ) + return " ".join(result) diff --git a/src/sql/run.py b/src/sql/run.py index 816f81800..bce0ec775 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -3,6 +3,7 @@ import operator import os.path import re +import traceback from functools import reduce import prettytable @@ -24,9 +25,9 @@ def unduplicate_field_names(field_names): for k in field_names: if k in res: i = 1 - while k + '_' + str(i) in res: + while k + "_" + str(i) in res: i += 1 - k += '_' + str(i) + k += "_" + str(i) res.append(k) return res @@ -46,8 +47,7 @@ def __init__(self, f, dialect=csv.excel, encoding="utf-8", **kwds): def writerow(self, row): if six.PY2: - _row = [s.encode("utf-8") if hasattr(s, "encode") else s - for s in row] + _row = [s.encode("utf-8") if hasattr(s, "encode") else s for s in row] else: _row = row self.writer.writerow(_row) @@ -55,7 +55,7 @@ def writerow(self, row): data = self.queue.getvalue() if six.PY2: data = 
data.decode("utf-8") - # ... and reencode it into the target encoding + # ... and re-encode it into the target encoding data = self.encoder.encode(data) # write to the target stream self.stream.write(data) @@ -75,12 +75,12 @@ def __init__(self, file_path): self.file_path = file_path def __repr__(self): - return 'CSV results at %s' % os.path.join( - os.path.abspath('.'), self.file_path) + return "CSV results at %s" % os.path.join(os.path.abspath("."), self.file_path) def _repr_html_(self): - return 'CSV results' % os.path.join('.', 'files', - self.file_path) + return 'CSV results' % os.path.join( + ".", "files", self.file_path + ) def _nonbreaking_spaces(match_obj): @@ -90,11 +90,11 @@ def _nonbreaking_spaces(match_obj): Call with a ``re`` match object. Retain group 1, replace group 2 with nonbreaking speaces. """ - spaces = ' ' * len(match_obj.group(2)) - return '%s%s' % (match_obj.group(1), spaces) + spaces = " " * len(match_obj.group(2)) + return "%s%s" % (match_obj.group(1), spaces) -_cell_with_spaces_pattern = re.compile(r'()( {2,})') +_cell_with_spaces_pattern = re.compile(r"()( {2,})") class ResultSet(list, ColumnGuesserMixin): @@ -104,42 +104,38 @@ class ResultSet(list, ColumnGuesserMixin): Can access rows listwise, or by string value of leftmost column. 
""" - def __init__(self, sqlaproxy, sql, config): - self.keys = sqlaproxy.keys() - self.sql = sql + def __init__(self, sqlaproxy, config): self.config = config - self.limit = config.autolimit - style_name = config.style - self.style = prettytable.__dict__[style_name.upper()] if sqlaproxy.returns_rows: - if self.limit: - list.__init__(self, sqlaproxy.fetchmany(size=self.limit)) + self.keys = sqlaproxy.keys() + if config.autolimit: + list.__init__(self, sqlaproxy.fetchmany(size=config.autolimit)) else: list.__init__(self, sqlaproxy.fetchall()) self.field_names = unduplicate_field_names(self.keys) - self.pretty = PrettyTable(self.field_names, style=self.style) - # self.pretty.set_style(self.style) + self.pretty = PrettyTable(self.field_names, style=prettytable.__dict__[config.style.upper()]) else: list.__init__(self, []) self.pretty = None def _repr_html_(self): - _cell_with_spaces_pattern = re.compile(r'()( {2,})') + _cell_with_spaces_pattern = re.compile(r"()( {2,})") if self.pretty: self.pretty.add_rows(self) result = self.pretty.get_html_string() result = _cell_with_spaces_pattern.sub(_nonbreaking_spaces, result) - if self.config.displaylimit and len( - self) > self.config.displaylimit: - result = '%s\n%d rows, truncated to displaylimit of %d' % ( - result, len(self), self.config.displaylimit) + if self.config.displaylimit and len(self) > self.config.displaylimit: + result = ( + '%s\n%d rows, truncated to displaylimit of %d' + % (result, len(self), self.config.displaylimit) + ) return result else: return None def __str__(self, *arg, **kwarg): self.pretty.add_rows(self) - return str(self.pretty or '') + return str(self.pretty or "") def __getitem__(self, key): """ @@ -163,13 +159,14 @@ def dict(self): return dict(zip(self.keys, zip(*self))) def dicts(self): - "Iterator yielding a dict for each row" + """Iterator yielding a dict for each row""" for row in self: yield dict(zip(self.keys, row)) def DataFrame(self): - "Returns a Pandas DataFrame instance built from the 
result set." + """Returns a Pandas DataFrame instance built from the result set.""" import pandas as pd + frame = pd.DataFrame(self, columns=(self and self.keys) or []) return frame @@ -196,6 +193,7 @@ def pie(self, key_word_sep=" ", title=None, **kwargs): """ self.guess_pie_columns(xlabel_sep=key_word_sep) import matplotlib.pylab as plt + pie = plt.pie(self.ys[0], labels=self.xlabels, **kwargs) plt.title(title or self.ys[0].name) return pie @@ -219,11 +217,12 @@ def plot(self, title=None, **kwargs): through to ``matplotlib.pylab.plot``. """ import matplotlib.pylab as plt + self.guess_plot_columns() self.x = self.x or range(len(self.ys[0])) coords = reduce(operator.add, [(self.x, y) for y in self.ys]) plot = plt.plot(*coords, **kwargs) - if hasattr(self.x, 'name'): + if hasattr(self.x, "name"): plt.xlabel(self.x.name) ylabel = ", ".join(y.name for y in self.ys) plt.title(title or ylabel) @@ -251,6 +250,7 @@ def bar(self, key_word_sep=" ", title=None, **kwargs): through to ``matplotlib.pylab.bar``. """ import matplotlib.pylab as plt + self.guess_pie_columns(xlabel_sep=key_word_sep) plot = plt.bar(range(len(self.ys[0])), self.ys[0], **kwargs) if self.xlabels: @@ -266,11 +266,11 @@ def csv(self, filename=None, **format_params): return None # no results self.pretty.add_rows(self) if filename: - encoding = format_params.get('encoding', 'utf-8') + encoding = format_params.get("encoding", "utf-8") if six.PY2: - outfile = open(filename, 'wb') + outfile = open(filename, "wb") else: - outfile = open(filename, 'w', newline='', encoding=encoding) + outfile = open(filename, "w", newline="", encoding=encoding) else: outfile = six.StringIO() writer = UnicodeWriter(outfile, **format_params) @@ -286,9 +286,9 @@ def csv(self, filename=None, **format_params): def interpret_rowcount(rowcount): if rowcount < 0: - result = 'Done.' + result = "Done." else: - result = '%d rows affected.' % rowcount + result = "%d rows affected." 
% rowcount return result @@ -311,67 +311,74 @@ def __init__(self, cursor, headers): self.returns_rows = True def from_list(self, source_list): - "Simulates SQLA ResultProxy from a list." + """Simulates SQLA ResultProxy from a list.""" - self.fetchall = lambda: source_list + self.fetchall = lambda: source_list self.rowcount = len(source_list) def fetchmany(size): - pos = 0 + pos = 0 while pos < len(source_list): - yield source_list[pos:pos+size] - pos += size + yield source_list[pos: pos + size] + pos += size self.fetchmany = fetchmany - # some dialects have autocommit # specific dialects break when commit is used: -_COMMIT_BLACKLIST_DIALECTS = ('mssql', 'clickhouse', 'teradata') + +_COMMIT_BLACKLIST_DIALECTS = ("athena", "bigquery", "clickhouse", "ingres", "mssql", "teradata", "vertica") def _commit(conn, config): """Issues a commit, if appropriate for current config and dialect""" _should_commit = config.autocommit and all( - dialect not in str(conn.dialect) - for dialect in _COMMIT_BLACKLIST_DIALECTS) + dialect not in str(conn.dialect) for dialect in _COMMIT_BLACKLIST_DIALECTS + ) if _should_commit: try: - conn.session.execute('commit') + conn.internal_connection.commit() except sqlalchemy.exc.OperationalError: pass # not all engines can commit + except Exception as ex: + conn.internal_connection.rollback() + print(traceback.format_exc()) + raise ex def run(conn, sql, config, user_namespace): if sql.strip(): for statement in sqlparse.split(sql): first_word = sql.strip().split()[0].lower() - if first_word == 'begin': + if first_word == "begin": raise Exception("ipython_sql does not support transactions") - if first_word.startswith('\\') and 'postgres' in str(conn.dialect): + if first_word.startswith("\\") and \ + ("postgres" in str(conn.dialect) or + "redshift" in str(conn.dialect)): if not PGSpecial: - raise ImportError('pgspecial not installed') + raise ImportError("pgspecial not installed") pgspecial = PGSpecial() _, cur, headers, _ = pgspecial.execute( - 
conn.session.connection.cursor(), statement)[0] + conn.internal_connection.connection.cursor(), statement + )[0] result = FakeResultProxy(cur, headers) else: txt = sqlalchemy.sql.text(statement) - result = conn.session.execute(txt, user_namespace) + result = conn.internal_connection.execute(txt, user_namespace) _commit(conn=conn, config=config) if result and config.feedback: print(interpret_rowcount(result.rowcount)) - resultset = ResultSet(result, statement, config) + resultset = ResultSet(result, config) if config.autopandas: return resultset.DataFrame() else: return resultset - #returning only last result, intentionally + # returning only last result, intentionally else: - return 'Connected: %s' % conn.name + return "Connected: %s" % conn.name class PrettyTable(prettytable.PrettyTable): @@ -391,5 +398,5 @@ def add_rows(self, data): self.row_count = len(data) else: self.row_count = min(len(data), self.displaylimit) - for row in data[:self.displaylimit]: + for row in data[: self.displaylimit]: self.add_row(row) diff --git a/src/tests/test_column_guesser.py b/src/tests/test_column_guesser.py index ebbc781ff..d10b8b9f4 100644 --- a/src/tests/test_column_guesser.py +++ b/src/tests/test_column_guesser.py @@ -1,6 +1,3 @@ -import re -import sys - import pytest from sql.magic import SqlMagic @@ -13,10 +10,10 @@ def __init__(self, connectstr): self.connectstr = connectstr def query(self, txt): - return ip.run_line_magic('sql', "%s %s" % (self.connectstr, txt)) + return ip.run_line_magic("sql", "%s %s" % (self.connectstr, txt)) -sql_env = SqlEnv('sqlite://') +sql_env = SqlEnv("sqlite://") @pytest.fixture @@ -54,14 +51,14 @@ def test_pie(self, tbl): assert results.ys == [[1.01, 2.01, 3.01]] assert results.x == [] assert results.xlabels == [] - assert results.xlabel == '' + assert results.xlabel == "" def test_plot(self, tbl): results = self.run_query() results.guess_plot_columns() assert results.ys == [[1.01, 2.01, 3.01]] assert results.x == [] - assert results.x.name == '' 
+ assert results.x.name == "" class TestOneStrOneNum(Harness): @@ -72,8 +69,8 @@ def test_pie(self, tbl): results.guess_pie_columns(xlabel_sep="//") assert results.ys[0].is_quantity assert results.ys == [[1.01, 2.01, 3.01]] - assert results.xlabels == ['r1-txt1', 'r2-txt1', 'r3-txt1'] - assert results.xlabel == 'name' + assert results.xlabels == ["r1-txt1", "r2-txt1", "r3-txt1"] + assert results.xlabel == "name" def test_plot(self, tbl): results = self.run_query() @@ -91,10 +88,11 @@ def test_pie(self, tbl): assert results.ys[0].is_quantity assert results.ys == [[1.01, 2.01, 3.01]] assert results.xlabels == [ - 'r1-txt2//1.04//r1-txt1', 'r2-txt2//2.04//r2-txt1', - 'r3-txt2//3.04//r3-txt1' + "r1-txt2//1.04//r1-txt1", + "r2-txt2//2.04//r2-txt1", + "r3-txt2//3.04//r3-txt1", ] - assert results.xlabel == 'name2, y3, name' + assert results.xlabel == "name2, y3, name" def test_plot(self, tbl): results = self.run_query() @@ -112,8 +110,9 @@ def test_pie(self, tbl): assert results.ys[0].is_quantity assert results.ys == [[1.04, 2.04, 3.04]] assert results.xlabels == [ - 'r1-txt1//1.01//r1-txt2//1.02', 'r2-txt1//2.01//r2-txt2//2.02', - 'r3-txt1//3.01//r3-txt2//3.02' + "r1-txt1//1.01//r1-txt2//1.02", + "r2-txt1//2.01//r2-txt2//2.02", + "r3-txt1//3.01//r3-txt2//3.02", ] def test_plot(self, tbl): diff --git a/src/tests/test_dsn_config.ini b/src/tests/test_dsn_config.ini new file mode 100644 index 000000000..e9e5db9c8 --- /dev/null +++ b/src/tests/test_dsn_config.ini @@ -0,0 +1,14 @@ +[DB_CONFIG_1] +drivername = postgres +host = my.remote.host +port = 5432 +database = pgmain +username = goesto11 +password = seentheelephant + +[DB_CONFIG_2] +drivername = mysql +host = 127.0.0.1 +database = dolfin +username = thefin +password = fishputsfishonthetable diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py index 92454c8d3..116cd6638 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -5,16 +5,12 @@ import pytest -from sql.magic import SqlMagic - def 
runsql(ip_session, statements): if isinstance(statements, str): - statements = [ - statements, - ] + statements = [statements] for statement in statements: - result = ip_session.run_line_magic('sql', 'sqlite:// %s' % statement) + result = ip_session.run_line_magic("sql", "sqlite:// %s" % statement) return result # returns only last result @@ -23,77 +19,108 @@ def ip(): """Provides an IPython session in which tables have been created""" ip_session = get_ipython() - runsql(ip_session, [ - "CREATE TABLE test (n INT, name TEXT)", - "INSERT INTO test VALUES (1, 'foo')", - "INSERT INTO test VALUES (2, 'bar')", - "CREATE TABLE author (first_name, last_name, year_of_death)", - "INSERT INTO author VALUES ('William', 'Shakespeare', 1616)", - "INSERT INTO author VALUES ('Bertold', 'Brecht', 1956)" - ]) + runsql( + ip_session, + [ + "CREATE TABLE test (n INT, name TEXT)", + "INSERT INTO test VALUES (1, 'foo')", + "INSERT INTO test VALUES (2, 'bar')", + "CREATE TABLE author (first_name, last_name, year_of_death)", + "INSERT INTO author VALUES ('William', 'Shakespeare', 1616)", + "INSERT INTO author VALUES ('Bertold', 'Brecht', 1956)", + ], + ) yield ip_session - runsql(ip_session, 'DROP TABLE test') - runsql(ip_session, 'DROP TABLE author') + runsql(ip_session, "DROP TABLE test") + runsql(ip_session, "DROP TABLE author") def test_memory_db(ip): assert runsql(ip, "SELECT * FROM test;")[0][0] == 1 - assert runsql(ip, "SELECT * FROM test;")[1]['name'] == 'bar' + assert runsql(ip, "SELECT * FROM test;")[1].name == "bar" def test_html(ip): result = runsql(ip, "SELECT * FROM test;") - assert 'foo' in result._repr_html_().lower() + assert "foo" in result._repr_html_().lower() def test_print(ip): result = runsql(ip, "SELECT * FROM test;") - assert re.search(r'1\s+\|\s+foo', str(result)) + assert re.search(r"1\s+\|\s+foo", str(result)) def test_plain_style(ip): - ip.run_line_magic('config', "SqlMagic.style = 'PLAIN_COLUMNS'") + ip.run_line_magic("config", "SqlMagic.style = 
'PLAIN_COLUMNS'") result = runsql(ip, "SELECT * FROM test;") - assert re.search(r'1\s+\|\s+foo', str(result)) + assert re.search(r"1\s+\|\s+foo", str(result)) +@pytest.mark.skip def test_multi_sql(ip): - result = ip.run_cell_magic('sql', '', """ + result = ip.run_cell_magic( + "sql", + "", + """ sqlite:// SELECT last_name FROM author; - """) - assert 'Shakespeare' in str(result) and 'Brecht' in str(result) + """, + ) + assert "Shakespeare" in str(result) and "Brecht" in str(result) def test_result_var(ip): - ip.run_cell_magic('sql', '', """ + ip.run_cell_magic( + "sql", + "", + """ sqlite:// x << SELECT last_name FROM author; - """) - result = ip.user_global_ns['x'] - assert 'Shakespeare' in str(result) and 'Brecht' in str(result) + """, + ) + result = ip.user_global_ns["x"] + assert "Shakespeare" in str(result) and "Brecht" in str(result) + + +def test_result_var_multiline_shovel(ip): + ip.run_cell_magic( + "sql", + "", + """ + sqlite:// x << SELECT last_name + FROM author; + """, + ) + result = ip.user_global_ns["x"] + assert "Shakespeare" in str(result) and "Brecht" in str(result) def test_access_results_by_keys(ip): - assert runsql(ip, - "SELECT * FROM author;")['William'] == (u'William', - u'Shakespeare', 1616) + assert runsql(ip, "SELECT * FROM author;")["William"] == ( + u"William", + u"Shakespeare", + 1616, + ) def test_duplicate_column_names_accepted(ip): - result = ip.run_cell_magic('sql', '', """ + result = ip.run_cell_magic( + "sql", + "", + """ sqlite:// SELECT last_name, last_name FROM author; - """) - assert (u'Brecht', u'Brecht') in result + """, + ) + assert (u"Brecht", u"Brecht") in result def test_autolimit(ip): - ip.run_line_magic('config', "SqlMagic.autolimit = 0") + ip.run_line_magic("config", "SqlMagic.autolimit = 0") result = runsql(ip, "SELECT * FROM test;") assert len(result) == 2 - ip.run_line_magic('config', "SqlMagic.autolimit = 1") + ip.run_line_magic("config", "SqlMagic.autolimit = 1") result = runsql(ip, "SELECT * FROM test;") assert 
len(result) == 1 @@ -101,36 +128,69 @@ def test_autolimit(ip): def test_persist(ip): runsql(ip, "") ip.run_cell("results = %sql SELECT * FROM test;") - ip.runcode("results_dframe = results.DataFrame()") - runsql(ip, 'PERSIST results_dframe') - persisted = runsql(ip, 'SELECT * FROM results_dframe') - assert 'foo' in str(persisted) + ip.run_cell("results_dframe = results.DataFrame()") + ip.run_cell("%sql --persist sqlite:// results_dframe") + persisted = runsql(ip, "SELECT * FROM results_dframe") + assert "foo" in str(persisted) + + +def test_append(ip): + runsql(ip, "") + ip.run_cell("results = %sql SELECT * FROM test;") + ip.run_cell("results_dframe = results.DataFrame()") + ip.run_cell("%sql --persist sqlite:// results_dframe") + persisted = runsql(ip, "SELECT COUNT(*) FROM results_dframe") + ip.run_cell("%sql --append sqlite:// results_dframe") + appended = runsql(ip, "SELECT COUNT(*) FROM results_dframe") + assert appended[0][0] == persisted[0][0] * 2 def test_persist_nonexistent_raises(ip): runsql(ip, "") - with pytest.raises(NameError): - runsql(ip, 'PERSIST no_such_dataframe') + result = ip.run_cell("%sql --persist sqlite:// no_such_dataframe") + assert result.error_in_exec def test_persist_non_frame_raises(ip): ip.run_cell("not_a_dataframe = 22") runsql(ip, "") - with pytest.raises(TypeError): - runsql(ip, 'PERSIST not_a_dataframe') + result = ip.run_cell("%sql --persist sqlite:// not_a_dataframe") + assert result.error_in_exec def test_persist_bare(ip): - ip.run_line_magic('sql', "sqlite://") - with pytest.raises(SyntaxError): - runsql(ip, 'PERSIST') + result = ip.run_cell("%sql --persist sqlite://") + assert result.error_in_exec def test_persist_frame_at_its_creation(ip): ip.run_cell("results = %sql SELECT * FROM author;") - runsql(ip, 'PERSIST results.DataFrame()') - persisted = runsql(ip, 'SELECT * FROM results') - assert 'Shakespeare' in str(persisted) + ip.run_cell("%sql --persist sqlite:// results.DataFrame()") + persisted = runsql(ip, "SELECT * FROM 
results") + assert "Shakespeare" in str(persisted) + + +def test_connection_args_enforce_json(ip): + result = ip.run_cell('%sql --connection_arguments {"badlyformed":true') + assert result.error_in_exec + + +def test_connection_args_in_connection(ip): + ip.run_cell('%sql --connection_arguments {"timeout":10} sqlite:///:memory:') + result = ip.run_cell("%sql --connections") + assert "timeout" in result.result["sqlite:///:memory:"].connect_args + + +def test_connection_args_single_quotes(ip): + ip.run_cell("%sql --connection_arguments '{\"timeout\": 10}' sqlite:///:memory:") + result = ip.run_cell("%sql --connections") + assert "timeout" in result.result["sqlite:///:memory:"].connect_args + + +def test_connection_args_double_quotes(ip): + ip.run_cell('%sql --connection_arguments "{\\"timeout\\": 10}" sqlite:///:memory:') + result = ip.run_cell("%sql --connections") + assert "timeout" in result.result["sqlite:///:memory:"].connect_args # TODO: support @@ -143,95 +203,191 @@ def test_persist_frame_at_its_creation(ip): def test_displaylimit(ip): - ip.run_line_magic('config', "SqlMagic.autolimit = None") - ip.run_line_magic('config', "SqlMagic.displaylimit = None") + ip.run_line_magic("config", "SqlMagic.autolimit = None") + ip.run_line_magic("config", "SqlMagic.displaylimit = None") result = runsql( ip, - "SELECT * FROM (VALUES ('apple'), ('banana'), ('cherry')) AS Result ORDER BY 1;" + "SELECT * FROM (VALUES ('apple'), ('banana'), ('cherry')) AS Result ORDER BY 1;", ) - assert 'apple' in result._repr_html_() - assert 'banana' in result._repr_html_() - assert 'cherry' in result._repr_html_() - ip.run_line_magic('config', "SqlMagic.displaylimit = 1") + assert "apple" in result._repr_html_() + assert "banana" in result._repr_html_() + assert "cherry" in result._repr_html_() + ip.run_line_magic("config", "SqlMagic.displaylimit = 1") result = runsql( ip, - "SELECT * FROM (VALUES ('apple'), ('banana'), ('cherry')) AS Result ORDER BY 1;" + "SELECT * FROM (VALUES ('apple'), 
('banana'), ('cherry')) AS Result ORDER BY 1;", ) - assert 'apple' in result._repr_html_() - assert 'cherry' not in result._repr_html_() + assert "apple" in result._repr_html_() + assert "cherry" not in result._repr_html_() def test_column_local_vars(ip): - ip.run_line_magic('config', "SqlMagic.column_local_vars = True") + ip.run_line_magic("config", "SqlMagic.column_local_vars = True") result = runsql(ip, "SELECT * FROM author;") assert result is None - assert 'William' in ip.user_global_ns['first_name'] - assert 'Shakespeare' in ip.user_global_ns['last_name'] - assert len(ip.user_global_ns['first_name']) == 2 - ip.run_line_magic('config', "SqlMagic.column_local_vars = False") + assert "William" in ip.user_global_ns["first_name"] + assert "Shakespeare" in ip.user_global_ns["last_name"] + assert len(ip.user_global_ns["first_name"]) == 2 + ip.run_line_magic("config", "SqlMagic.column_local_vars = False") def test_userns_not_changed(ip): ip.run_cell( - dedent(""" + dedent( + """ def function(): local_var = 'local_val' %sql sqlite:// INSERT INTO test VALUES (2, 'bar'); - function()""")) - assert 'local_var' not in ip.user_ns + function()""" + ) + ) + assert "local_var" not in ip.user_ns def test_bind_vars(ip): - ip.user_global_ns['x'] = 22 + ip.user_global_ns["x"] = 22 result = runsql(ip, "SELECT :x") assert result[0][0] == 22 def test_autopandas(ip): - ip.run_line_magic('config', "SqlMagic.autopandas = True") + ip.run_line_magic("config", "SqlMagic.autopandas = True") dframe = runsql(ip, "SELECT * FROM test;") assert not dframe.empty assert dframe.ndim == 2 - assert dframe.name[0] == 'foo' + assert dframe.name[0] == "foo" def test_csv(ip): - ip.run_line_magic('config', "SqlMagic.autopandas = False") # uh-oh + ip.run_line_magic("config", "SqlMagic.autopandas = False") # uh-oh result = runsql(ip, "SELECT * FROM test;") result = result.csv() for row in result.splitlines(): - assert row.count(',') == 1 + assert row.count(",") == 1 assert len(result.splitlines()) == 3 def 
test_csv_to_file(ip): - ip.run_line_magic('config', "SqlMagic.autopandas = False") # uh-oh + ip.run_line_magic("config", "SqlMagic.autopandas = False") # uh-oh result = runsql(ip, "SELECT * FROM test;") with tempfile.TemporaryDirectory() as tempdir: - fname = os.path.join(tempdir, 'test.csv') + fname = os.path.join(tempdir, "test.csv") output = result.csv(fname) assert os.path.exists(output.file_path) with open(output.file_path) as csvfile: content = csvfile.read() for row in content.splitlines(): - assert row.count(',') == 1 + assert row.count(",") == 1 assert len(content.splitlines()) == 3 +def test_sql_from_file(ip): + ip.run_line_magic("config", "SqlMagic.autopandas = False") + with tempfile.TemporaryDirectory() as tempdir: + fname = os.path.join(tempdir, "test.sql") + with open(fname, "w") as tempf: + tempf.write("SELECT * FROM test;") + result = ip.run_cell("%sql --file " + fname) + assert result.result == [(1, "foo"), (2, "bar")] + + +def test_sql_from_nonexistent_file(ip): + ip.run_line_magic("config", "SqlMagic.autopandas = False") + with tempfile.TemporaryDirectory() as tempdir: + fname = os.path.join(tempdir, "nonexistent.sql") + result = ip.run_cell("%sql --file " + fname) + assert isinstance(result.error_in_exec, FileNotFoundError) + + def test_dict(ip): result = runsql(ip, "SELECT * FROM author;") result = result.dict() assert isinstance(result, dict) - assert 'first_name' in result - assert 'last_name' in result - assert 'year_of_death' in result - assert len(result['last_name']) == 2 + assert "first_name" in result + assert "last_name" in result + assert "year_of_death" in result + assert len(result["last_name"]) == 2 def test_dicts(ip): result = runsql(ip, "SELECT * FROM author;") for row in result.dicts(): assert isinstance(row, dict) - assert 'first_name' in row - assert 'last_name' in row - assert 'year_of_death' in row + assert "first_name" in row + assert "last_name" in row + assert "year_of_death" in row + + +def 
test_bracket_var_substitution(ip): + ip.user_global_ns["col"] = "first_name" + assert runsql(ip, "SELECT * FROM author" " WHERE {col} = 'William' ")[0] == ( + u"William", + u"Shakespeare", + 1616, + ) + + ip.user_global_ns["col"] = "last_name" + result = runsql(ip, "SELECT * FROM author" " WHERE {col} = 'William' ") + assert not result + + +def test_multiline_bracket_var_substitution(ip): + ip.user_global_ns["col"] = "first_name" + assert runsql(ip, "SELECT * FROM author\n" " WHERE {col} = 'William' ")[0] == ( + u"William", + u"Shakespeare", + 1616, + ) + + ip.user_global_ns["col"] = "last_name" + result = runsql(ip, "SELECT * FROM author" " WHERE {col} = 'William' ") + assert not result + + +def test_multiline_bracket_var_substitution(ip): + ip.user_global_ns["col"] = "first_name" + result = ip.run_cell_magic( + "sql", + "", + """ + sqlite:// SELECT * FROM author + WHERE {col} = 'William' + """, + ) + assert (u"William", u"Shakespeare", 1616) in result + + ip.user_global_ns["col"] = "last_name" + result = ip.run_cell_magic( + "sql", + "", + """ + sqlite:// SELECT * FROM author + WHERE {col} = 'William' + """, + ) + assert not result + + +def test_json_in_select(ip): + # Variable expansion does not work within json, but + # at least the two usages of curly braces do not collide + ip.user_global_ns["person"] = "prince" + result = ip.run_cell_magic( + "sql", + "", + """ + sqlite:// + SELECT + '{"greeting": "Farewell sweet {person}"}' + AS json + """, + ) + assert ('{"greeting": "Farewell sweet {person}"}',) + + +def test_close_connection(ip): + connections = runsql(ip, "%sql -l") + connection_name = list(connections)[0] + runsql(ip, f"%sql -x {connection_name}") + connections_afterward = runsql(ip, "%sql -l") + assert connection_name not in connections_afterward diff --git a/src/tests/test_parse.py b/src/tests/test_parse.py index 495425450..46d3678bd 100644 --- a/src/tests/test_parse.py +++ b/src/tests/test_parse.py @@ -1,41 +1,184 @@ import os -from sql.parse import 
parse -from six.moves import configparser +from pathlib import Path + +from sql.parse import connection_from_dsn_section, parse, without_sql_comment + try: from traitlets.config.configurable import Configurable except ImportError: from IPython.config.configurable import Configurable empty_config = Configurable() -default_flags = {'persist': False, 'result_var': None} +default_connect_args = {"options": "-csearch_path=test"} + + def test_parse_no_sql(): - assert parse("will:longliveliz@localhost/shakes", empty_config) == \ - {'connection': "will:longliveliz@localhost/shakes", - 'sql': '', - 'flags': default_flags} + assert parse("will:longliveliz@localhost/shakes", empty_config) == { + "connection": "will:longliveliz@localhost/shakes", + "sql": "", + "result_var": None, + } + def test_parse_with_sql(): - assert parse("postgresql://will:longliveliz@localhost/shakes SELECT * FROM work", - empty_config) == \ - {'connection': "postgresql://will:longliveliz@localhost/shakes", - 'sql': 'SELECT * FROM work', - 'flags': default_flags} + assert parse( + "postgresql://will:longliveliz@localhost/shakes SELECT * FROM work", + empty_config, + ) == { + "connection": "postgresql://will:longliveliz@localhost/shakes", + "sql": "SELECT * FROM work", + "result_var": None, + } + def test_parse_sql_only(): - assert parse("SELECT * FROM work", empty_config) == \ - {'connection': "", - 'sql': 'SELECT * FROM work', - 'flags': default_flags} + assert parse("SELECT * FROM work", empty_config) == { + "connection": "", + "sql": "SELECT * FROM work", + "result_var": None, + } + def test_parse_postgresql_socket_connection(): - assert parse("postgresql:///shakes SELECT * FROM work", empty_config) == \ - {'connection': "postgresql:///shakes", - 'sql': 'SELECT * FROM work', - 'flags': default_flags} + assert parse("postgresql:///shakes SELECT * FROM work", empty_config) == { + "connection": "postgresql:///shakes", + "sql": "SELECT * FROM work", + "result_var": None, + } + def 
test_expand_environment_variables_in_connection(): - os.environ['DATABASE_URL'] = 'postgresql:///shakes' - assert parse("$DATABASE_URL SELECT * FROM work", empty_config) == \ - {'connection': "postgresql:///shakes", - 'sql': 'SELECT * FROM work', - 'flags': default_flags} + os.environ["DATABASE_URL"] = "postgresql:///shakes" + assert parse("$DATABASE_URL SELECT * FROM work", empty_config) == { + "connection": "postgresql:///shakes", + "sql": "SELECT * FROM work", + "result_var": None, + } + + +def test_parse_shovel_operator(): + assert parse("dest << SELECT * FROM work", empty_config) == { + "connection": "", + "sql": "SELECT * FROM work", + "result_var": "dest", + } + + +def test_parse_connect_plus_shovel(): + assert parse("sqlite:// dest << SELECT * FROM work", empty_config) == { + "connection": "sqlite://", + "sql": "SELECT * FROM work", + "result_var": None, + } + + +def test_parse_shovel_operator(): + assert parse("dest << SELECT * FROM work", empty_config) == { + "connection": "", + "sql": "SELECT * FROM work", + "result_var": "dest", + } + + +def test_parse_connect_plus_shovel(): + assert parse("sqlite:// dest << SELECT * FROM work", empty_config) == { + "connection": "sqlite://", + "sql": "SELECT * FROM work", + "result_var": "dest", + } + + +def test_parse_early_newlines(): + assert parse("--comment\nSELECT *\n--comment\nFROM work", empty_config) == { + "connection": "", + "sql": "--comment\nSELECT *\n--comment\nFROM work", + "result_var": None + } + + +def test_parse_connect_shovel_over_newlines(): + assert parse("\nsqlite://\ndest\n<<\nSELECT *\nFROM work", empty_config) == { + "connection": "sqlite://", + "sql": "SELECT *\nFROM work", + "result_var": "dest" + } + + +class DummyConfig: + dsn_filename = Path("src/tests/test_dsn_config.ini") + + +def test_connection_from_dsn_section(): + result = connection_from_dsn_section(section="DB_CONFIG_1", config=DummyConfig()) + assert str(result) == "postgres://goesto11:***@my.remote.host:5432/pgmain" + result = 
connection_from_dsn_section(section="DB_CONFIG_2", config=DummyConfig()) + assert str(result) == "mysql://thefin:***@127.0.0.1/dolfin" + + +class Bunch: + def __init__(self, **kwds): + self.__dict__.update(kwds) + + +class ParserStub: + opstrs = [ + [], + ["-l", "--connections"], + ["-x", "--close"], + ["-c", "--creator"], + ["-s", "--section"], + ["-p", "--persist"], + ["--append"], + ["-a", "--connection_arguments"], + ["-f", "--file"], + ] + _actions = [Bunch(option_strings=o) for o in opstrs] + + +parser_stub = ParserStub() + + +def test_without_sql_comment_plain(): + line = "SELECT * FROM author" + assert without_sql_comment(parser=parser_stub, line=line) == line + + +def test_without_sql_comment_with_arg(): + line = "--file moo.txt --persist SELECT * FROM author" + assert without_sql_comment(parser=parser_stub, line=line) == line + + +def test_without_sql_comment_with_comment(): + line = "SELECT * FROM author -- uff da" + expected = "SELECT * FROM author" + assert without_sql_comment(parser=parser_stub, line=line) == expected + + +def test_without_sql_comment_with_arg_and_comment(): + line = "--file moo.txt --persist SELECT * FROM author -- uff da" + expected = "--file moo.txt --persist SELECT * FROM author" + assert without_sql_comment(parser=parser_stub, line=line) == expected + + +def test_without_sql_comment_unspaced_comment(): + line = "SELECT * FROM author --uff da" + expected = "SELECT * FROM author" + assert without_sql_comment(parser=parser_stub, line=line) == expected + + +def test_without_sql_comment_dashes_in_string(): + line = "SELECT '--very --confusing' FROM author -- uff da" + expected = "SELECT '--very --confusing' FROM author" + assert without_sql_comment(parser=parser_stub, line=line) == expected + + +def test_without_sql_comment_with_arg_and_leading_comment(): + line = "--file moo.txt --persist --comment, not arg" + expected = "--file moo.txt --persist" + assert without_sql_comment(parser=parser_stub, line=line) == expected + + +def 
test_without_sql_persist(): + line = "--persist my_table --uff da" + expected = "--persist my_table" + assert without_sql_comment(parser=parser_stub, line=line) == expected