diff --git a/.gitignore b/.gitignore
index a2190b5..4609c36 100644
--- a/.gitignore
+++ b/.gitignore
@@ -123,3 +123,6 @@ dmypy.json
 .pyre/
 .vscode
+
+metastore_db/
+spark-warehouse/
diff --git a/README.md b/README.md
index 8a110ca..c96f30f 100644
--- a/README.md
+++ b/README.md
@@ -63,6 +63,7 @@ See samples below for more examples.
 * PySpark with all Spark features including reading and writing to disk, UDFs and Pandas UDFs
 * Databricks Utilities (`dbutils`, `display`) with user-configurable mocks
 * Mocking connectors such as Azure Storage, S3 and SQL Data Warehouse
+* Helper functions to compare and evaluate the results of Spark SQL queries
 
 ## Unsupported features
 
@@ -71,6 +72,25 @@ See samples below for more examples.
 * Writing directly to `/dbfs` mount on local filesystem
 * Databricks extensions to Spark such as `spark.read.format("binaryFile")`
 
+## Helper functions
+
+Helper functions are available to assist with testing Spark SQL queries:
+
+### `Session.assert_queries_are_equal(actual_query, expected_query)`
+
+* Asserts that the result sets returned by the two supplied Spark SQL queries are equal
+* A detailed table comparison is printed only when the assertion fails.
+In the comparison output, the first column (`m`) can take three values:
+  * the symbol `<` indicates that the row was found in the *expected* results but did not match anything in the *actual* results
+  * the symbol `>` indicates that the row was found in the *actual* results but not in the *expected* results
+  * the symbol `=` indicates that the row was matched between the *expected* and *actual* results
+  * *this behaviour is inspired by the `tSQLt` unit testing framework for SQL Server*
+
+### `Session.assert_query_returns_no_rows(actual_query)`
+
+* Asserts that the result set returned by the supplied Spark SQL query is empty
+* A detailed table output is printed only when the result set is not empty
+
 ## Sample test
 
 Sample test case for an ETL notebook reading CSV and writing Parquet.
@@ -238,6 +258,51 @@ def test_sqldw(monkeypatch):
     assert_frame_equal(expectedDF, resultDF, check_dtype=False)
 ```
 
+## Spark SQL comparison functions
+
+A test comparing the output of two Spark SQL queries with the
+`assert_queries_are_equal` function. The queries below intentionally
+return different rows, so the resulting assertion failure is caught
+with `pytest.raises`:
+
+```python
+def test_results_do_not_match():
+    with databricks_test.session() as dbrickstest:
+        actual_query = """
+            SELECT col1, col2
+            FROM
+            (VALUES
+                (100, 'foo'),
+                (101, 'bar'),
+                (102, 'baz')
+            ) AS v (col1, col2)
+        """
+
+        expected_query = """
+            SELECT col1, col2
+            FROM
+            (VALUES
+                (100, 'foo'),
+                (110, 'bar'),
+                (999, 'qux')
+            ) AS v (col1, col2)
+        """
+
+        with pytest.raises(AssertionError):
+            dbrickstest.assert_queries_are_equal(
+                actual_query, expected_query)
+```
+
+A test validating that a Spark SQL query returns no rows, using the
+`assert_query_returns_no_rows` function:
+
+```python
+def test_no_rows_returned():
+    with databricks_test.session() as dbrickstest:
+        query = """
+            SELECT 100 AS col1, 'abc' AS col2
+            WHERE 1=2
+        """
+
+        dbrickstest.assert_query_returns_no_rows(query)
+```
+
 ## Issues
 
 Please report issues at [http://github.com/algattik/databricks_test](http://github.com/algattik/databricks_test).
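For illustration (drawn from the expected messages in the accompanying tests): with the mismatching queries in the README example above, the failure message produced by `assert_queries_are_equal` would look roughly as follows. The table layout comes from `DataFrame.show()`, and row order within each marker group is not guaranteed:

```
the result sets did not match:
+---+----+----+
|m  |col1|col2|
+---+----+----+
|=  |100 |foo |
|>  |101 |bar |
|>  |102 |baz |
|<  |110 |bar |
|<  |999 |qux |
+---+----+----+
```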
diff --git a/databricks_test/__init__.py b/databricks_test/__init__.py
index 66deabd..fefa6b5 100644
--- a/databricks_test/__init__.py
+++ b/databricks_test/__init__.py
@@ -1,7 +1,7 @@
 from unittest.mock import MagicMock
 import inspect
 from pyspark.sql import SparkSession
-from pyspark.sql.functions import udf
+from pyspark.sql.functions import udf, col
 import importlib
 import sys
 import os
@@ -87,6 +87,12 @@ def __init__(self):
         self.spark = (SparkSession.builder
                       .master("local")
                       .appName("test-pyspark")
+                      # Add Delta Lake support
+                      .config("spark.jars.packages", "io.delta:delta-core_2.12:0.7.0")
+                      .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
+                      .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
+                      # Set the default Hive file format to Parquet
+                      .config("hive.default.fileformat", "parquet")
                       .enableHiveSupport()
                       .getOrCreate())
 
@@ -105,6 +111,63 @@ def run_notebook(self, dir, script):
         except WorkflowInterrupted:
             pass
 
+    def get_show_string(self, df, n=20, truncate=True, vertical=False) -> str:
+        """
+        Return the output of the DataFrame `.show()` method as a string.
+        """
+        if isinstance(truncate, bool) and truncate:
+            return df._jdf.showString(n, 20, vertical)
+        else:
+            return df._jdf.showString(n, int(truncate), vertical)
+
+    def assert_queries_are_equal(self, actual_query, expected_query):
+        """
+        Compare the result sets of two queries for equality.
+
+        The result sets generated by the queries must have the same schema.
+
+        The detailed table comparison output is only shown in the event
+        of a failure.
+        """
+        result_df = self.spark.sql(f"""
+            WITH actual AS
+            ({actual_query})
+            , expected AS
+            ({expected_query})
+            , matched AS
+            (SELECT * FROM actual INTERSECT SELECT * FROM expected)
+            , unexpected AS
+            (SELECT * FROM actual EXCEPT SELECT * FROM expected)
+            , missing AS
+            (SELECT * FROM expected EXCEPT SELECT * FROM actual)
+            SELECT '=' AS m, * FROM matched
+            UNION ALL
+            SELECT '>' AS m, * FROM unexpected
+            UNION ALL
+            SELECT '<' AS m, * FROM missing
+            """)
+
+        failure_count = result_df.where(col("m").isin({">", "<"})).count()
+        if failure_count > 0:
+            # Render the full comparison table only when the assertion fails
+            msg = self.get_show_string(
+                result_df, n=result_df.count(), truncate=False)
+            assert failure_count == 0, \
+                f"the result sets did not match:\n{msg}"
+
+    def assert_query_returns_no_rows(self, actual_query):
+        """
+        Assert that a query returns an empty result set.
+        """
+        result_df = self.spark.sql(actual_query)
+
+        record_count = result_df.count()
+        if record_count > 0:
+            # Render the offending rows only when the assertion fails
+            msg = self.get_show_string(
+                result_df, n=record_count, truncate=False)
+            assert record_count == 0, \
+                f"the result set was not empty:\n{msg}"
+
 
 def inject_variables():
     """
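As a reading aid, the INTERSECT/EXCEPT query in `assert_queries_are_equal` can be sketched with the DataFrame API as below. This is a hypothetical equivalent, not part of the library: `diff_frames` is an invented name, and `intersect`/`subtract` carry the same distinct-set semantics as SQL `INTERSECT`/`EXCEPT`:

```python
from pyspark.sql import DataFrame, functions as F


def diff_frames(actual_df: DataFrame, expected_df: DataFrame) -> DataFrame:
    # Rows present in both result sets, marked '='
    matched = actual_df.intersect(expected_df).withColumn("m", F.lit("="))
    # Rows only in the actual result set, marked '>'
    unexpected = actual_df.subtract(expected_df).withColumn("m", F.lit(">"))
    # Rows only in the expected result set, marked '<'
    missing = expected_df.subtract(actual_df).withColumn("m", F.lit("<"))
    result = matched.unionByName(unexpected).unionByName(missing)
    # Put the marker column first, as the SQL version does
    return result.select("m", *actual_df.columns)
```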
diff --git a/tests/assert_queries_are_equal_test.py b/tests/assert_queries_are_equal_test.py
new file mode 100644
index 0000000..de38614
--- /dev/null
+++ b/tests/assert_queries_are_equal_test.py
@@ -0,0 +1,110 @@
+import databricks_test
+import pytest
+
+
+def test_results_match():
+    with databricks_test.session() as dbrickstest:
+        query = """
+            SELECT col1, col2
+            FROM
+            (VALUES
+                (100, 'foo'),
+                (101, 'bar'),
+                (102, 'baz')
+            ) AS v (col1, col2)
+        """
+
+        dbrickstest.assert_queries_are_equal(query, query)
+
+
+def test_results_do_not_match():
+    with databricks_test.session() as dbrickstest:
+        actual_query = """
+            SELECT col1, col2
+            FROM
+            (VALUES
+                (100, 'foo'),
+                (101, 'bar'),
+                (102, 'baz')
+            ) AS v (col1, col2)
+        """
+
+        expected_query = """
+            SELECT col1, col2
+            FROM
+            (VALUES
+                (100, 'foo'),
+                (110, 'bar'),
+                (999, 'qux')
+            ) AS v (col1, col2)
+        """
+
+        with pytest.raises(AssertionError) as excinfo:
+            dbrickstest.assert_queries_are_equal(actual_query, expected_query)
+
+        assert str(excinfo.value).startswith("the result sets did not match:")
+
+
+def test_unexpected_result():
+    with databricks_test.session() as dbrickstest:
+        actual_query = """
+            SELECT col1, col2
+            FROM
+            (VALUES
+                (100, 'foo'),
+                (101, 'bar')
+            ) AS v (col1, col2)
+        """
+
+        expected_query = """
+            SELECT col1, col2
+            FROM
+            (VALUES
+                (100, 'foo')
+            ) AS v (col1, col2)
+        """
+
+        expected_message = """the result sets did not match:
++---+----+----+
+|m  |col1|col2|
++---+----+----+
+|=  |100 |foo |
+|>  |101 |bar |
++---+----+----+
+"""
+
+        with pytest.raises(AssertionError) as excinfo:
+            dbrickstest.assert_queries_are_equal(actual_query, expected_query)
+
+        assert str(excinfo.value) == expected_message
+
+
+def test_missing_result():
+    with databricks_test.session() as dbrickstest:
+        actual_query = """
+            SELECT col1, col2
+            FROM
+            (VALUES
+                (100, 'foo')
+            ) AS v (col1, col2)
+        """
+
+        expected_query = """
+            SELECT col1, col2
+            FROM
+            (VALUES
+                (100, 'foo'),
+                (101, 'bar')
+            ) AS v (col1, col2)
+        """
+
+        expected_message = """the result sets did not match:
++---+----+----+
+|m  |col1|col2|
++---+----+----+
+|=  |100 |foo |
+|<  |101 |bar |
++---+----+----+
+"""
+
+        with pytest.raises(AssertionError) as excinfo:
+            dbrickstest.assert_queries_are_equal(actual_query, expected_query)
+
+        assert str(excinfo.value) == expected_message
diff --git a/tests/assert_query_returns_no_rows_test.py b/tests/assert_query_returns_no_rows_test.py
new file mode 100644
index 0000000..b889ea2
--- /dev/null
+++ b/tests/assert_query_returns_no_rows_test.py
@@ -0,0 +1,45 @@
+import databricks_test
+import pytest
+
+
+def test_no_rows_returned():
+    with databricks_test.session() as dbrickstest:
+        query = """
+            SELECT col1, col2
+            FROM
+            (VALUES
+                (100, 'foo'),
+                (101, 'bar'),
+                (102, 'baz')
+            ) AS v (col1, col2)
+            WHERE 1=2
+        """
+
+        dbrickstest.assert_query_returns_no_rows(query)
+
+
+def test_rows_returned():
+    with databricks_test.session() as dbrickstest:
+        query = """
+            SELECT col1, col2
+            FROM
+            (VALUES
+                (100, 'foo'),
+                (101, 'bar'),
+                (102, 'baz')
+            ) AS v (col1, col2)
+            ORDER BY col1
+        """
+
+        expected_message = """the result set was not empty:
++----+----+
+|col1|col2|
++----+----+
+|100 |foo |
+|101 |bar |
+|102 |baz |
++----+----+
+"""
+
+        with pytest.raises(AssertionError) as excinfo:
+            dbrickstest.assert_query_returns_no_rows(query)
+
+        assert str(excinfo.value) == expected_message
diff --git a/tests/deltalake_test.py b/tests/deltalake_test.py
new file mode 100644
index 0000000..386d57d
--- /dev/null
+++ b/tests/deltalake_test.py
@@ -0,0 +1,29 @@
+import databricks_test
+from tempfile import TemporaryDirectory
+
+
+def test_deltalake_write():
+    with databricks_test.session() as dbrickstest:
+        with TemporaryDirectory() as tmp_dir:
+            out_dir = f"{tmp_dir}/delta_out"
+
+            # Provide the output location as a widget to the notebook
+            switch = {
+                "output": out_dir,
+            }
+            dbrickstest.dbutils.widgets.get.side_effect = lambda x: switch.get(
+                x, "")
+
+            # Run notebook
+            dbrickstest.run_notebook(".", "deltalake_write_notebook")
+
+            # Read the Delta table written by the notebook
+            df = dbrickstest.spark.read.format("delta").load(out_dir)
+
+            # Validate that the dataframe contains the expected values
+            for n in range(0, 5):
+                assert df.filter(df["id"] == n).count() == 1
+
+            # Validate that the dataframe contains no unexpected values
+            assert df.count() == 5
diff --git a/tests/deltalake_write_notebook.py b/tests/deltalake_write_notebook.py
new file mode 100644
index 0000000..2108874
--- /dev/null
+++ b/tests/deltalake_write_notebook.py
@@ -0,0 +1,11 @@
+# Databricks notebook source
+
+# Instrument for unit tests. This is only executed in local unit tests,
+# not in Databricks.
+if 'dbutils' not in locals():
+    import databricks_test
+    databricks_test.inject_variables()
+
+# COMMAND ----------
+data = spark.range(0, 5)
+data.write.format("delta").save(dbutils.widgets.get('output'))
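As a closing sketch, the two features added in this change could also be combined: the Delta output can be registered as a temp view and checked with the new SQL helper. This is hypothetical, not part of the change; the test name and the `delta_out` view name are invented, and `range(0, 5)` is Spark SQL's built-in table-valued function:

```python
import databricks_test


def test_delta_output_matches_expected_rows(tmp_path):
    with databricks_test.session() as dbrickstest:
        out_dir = f"{tmp_path}/delta_out"

        # Write a small Delta table, as the sample notebook does
        dbrickstest.spark.range(0, 5).write.format("delta").save(out_dir)

        # Expose the Delta output to Spark SQL through a temp view
        delta_df = dbrickstest.spark.read.format("delta").load(out_dir)
        delta_df.createOrReplaceTempView("delta_out")

        # Both queries return ids 0..4 with the same schema
        dbrickstest.assert_queries_are_equal(
            "SELECT id FROM delta_out",
            "SELECT id FROM range(0, 5)")
```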