Feature/ecotricity customisations (#1)
* update .gitignore

* add deltalake support

* add Spark SQL comparison functions
EdHarper-eco authored Dec 8, 2020
commit f0560c84bdd0dd3eefa976a1af901927a4f299be
3 changes: 3 additions & 0 deletions .gitignore
@@ -123,3 +123,6 @@ dmypy.json
.pyre/

.vscode

metastore_db/
spark-warehouse/
65 changes: 65 additions & 0 deletions README.md
@@ -63,6 +63,7 @@ See samples below for more examples.
* PySpark with all Spark features including reading and writing to disk, UDFs and Pandas UDFs
* Databricks Utilities (`dbutils`, `display`) with user-configurable mocks
* Mocking connectors such as Azure Storage, S3 and SQL Data Warehouse
* Helper functions to compare and evaluate the results of Spark SQL queries

## Unsupported features

@@ -71,6 +72,25 @@ See samples below for more examples.
* Writing directly to `/dbfs` mount on local filesystem
* Databricks extensions to Spark such as `spark.read.format("binaryFile")`

## Helper functions

Helper functions are available for making assertions about the results of Spark SQL queries:

### `Session.assert_queries_are_equal(actual_query, expected_query)`

* Asserts that the result sets returned by two supplied Spark SQL queries are equal
* A detailed table comparison is output only if the assertion fails.
  In the comparison output, the first column (`m`) can take three different values, as illustrated in the sample below:
  * the symbol `<` indicates that the row was found in the *expected* results but did not match anything in the *actual* results
  * the symbol `>` indicates that the row was found in the *actual* results but did not match anything in the *expected* results
  * the symbol `=` indicates that the row was matched between the *expected* and *actual* results
  * *this behaviour is inspired by the `tSQLt` unit testing framework for SQL Server*
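
For example, a failure where the *actual* results contain a row that is missing
from the *expected* results produces output like the following (taken from the
test suite):

```
the result sets did not match:
+---+----+----+
|m  |col1|col2|
+---+----+----+
|=  |100 |foo |
|>  |101 |bar |
+---+----+----+
```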

### `Session.assert_query_returns_no_rows(actual_query)`

* Asserts that the result set returned by a supplied Spark SQL query is empty
* A detailed table output is shown only in the event that the result set is not empty, as illustrated in the sample below
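
For example, a query that unexpectedly returns three rows produces output like
the following (taken from the test suite):

```
the result set was not empty:
+----+----+
|col1|col2|
+----+----+
|100 |foo |
|101 |bar |
|102 |baz |
+----+----+
```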

## Sample test

Sample test case for an ETL notebook reading CSV and writing Parquet.
@@ -238,6 +258,51 @@ def test_sqldw(monkeypatch):
        assert_frame_equal(expectedDF, resultDF, check_dtype=False)
```

## Spark SQL comparison functions

A test expecting two Spark SQL queries to return different results, using the
`assert_queries_are_equal` function (since the result sets differ, the
assertion failure is caught with `pytest.raises`):

```python
import databricks_test
import pytest


def test_results_do_not_match():
    with databricks_test.session() as dbrickstest:
        actual_query = """
            SELECT col1, col2
            FROM
            (VALUES
                (100, 'foo'),
                (101, 'bar'),
                (102, 'baz')
            ) AS v (col1, col2)
        """

        expected_query = """
            SELECT col1, col2
            FROM
            (VALUES
                (100, 'foo'),
                (110, 'bar'),
                (999, 'qux')
            ) AS v (col1, col2)
        """

        # The result sets differ, so the assertion raises
        with pytest.raises(Exception):
            dbrickstest.assert_queries_are_equal(actual_query, expected_query)
```

A test validating that a Spark SQL query returns no rows, using the
`assert_query_returns_no_rows` function:

```python
def test_no_rows_returned():
    with databricks_test.session() as dbrickstest:
        query = """
            SELECT 100 AS col1, 'abc' AS col2
            WHERE 1=2
        """

        dbrickstest.assert_query_returns_no_rows(query)
```
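
These assertions can also be combined with `run_notebook` to validate the
output of a notebook under test. A minimal sketch, assuming a hypothetical
`etl_notebook` that writes its result to a table named `output_table`, and an
`expected_table` pre-populated by the test:

```python
def test_etl_notebook_output():
    with databricks_test.session() as dbrickstest:
        # Run the notebook under test (hypothetical name)
        dbrickstest.run_notebook(".", "etl_notebook")

        # Compare the notebook's output table against a reference query
        dbrickstest.assert_queries_are_equal(
            "SELECT * FROM output_table",
            "SELECT * FROM expected_table")
```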

## Issues

Please report issues at [http://github.com/algattik/databricks_test](http://github.com/algattik/databricks_test).
65 changes: 64 additions & 1 deletion databricks_test/__init__.py
@@ -1,7 +1,7 @@
from unittest.mock import MagicMock
import inspect
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.functions import udf, col
import importlib
import sys
import os
@@ -87,6 +87,12 @@ def __init__(self):
        self.spark = (SparkSession.builder
                      .master("local")
                      .appName("test-pyspark")
                      # Add Delta Lake support
                      .config("spark.jars.packages", "io.delta:delta-core_2.12:0.7.0")
                      .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
                      .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
                      # Set the default SerDe to Parquet
                      .config("hive.default.fileformat", "parquet")
                      .enableHiveSupport()
                      .getOrCreate())

@@ -105,6 +111,63 @@ def run_notebook(self, dir, script):
        except WorkflowInterrupted:
            pass

    def get_show_string(self, df, n=20, truncate=True, vertical=False) -> str:
        """
        Returns the output of the `.show()` function as a string
        """
        if isinstance(truncate, bool) and truncate:
            return df._jdf.showString(n, 20, vertical)
        else:
            return df._jdf.showString(n, int(truncate), vertical)

    def assert_queries_are_equal(self, actual_query, expected_query):
        """
        Compares the results of two queries for equality

        The result sets generated by the queries must have the same schema

        The detailed table comparison output is only shown in the event of a failure
        """
        # Classify each row: '=' matched, '>' only in actual, '<' only in expected
        result_df = self.spark.sql(f"""
            WITH actual
            AS
            ({actual_query})
            ,expected
            AS
            ({expected_query})
            ,matched
            AS
            (SELECT * FROM actual INTERSECT SELECT * FROM expected)
            ,missing
            AS
            (SELECT * FROM actual EXCEPT SELECT * FROM expected)
            ,extra
            AS
            (SELECT * FROM expected EXCEPT SELECT * FROM actual)
            SELECT '=' AS m, * FROM matched
            UNION ALL
            SELECT '>' AS m, * FROM missing
            UNION ALL
            SELECT '<' AS m, * FROM extra
            """)

        failure_count = result_df.where(col("m").isin({">", "<"})).count()
        if failure_count > 0:
            msg = self.get_show_string(result_df, n=result_df.count(), truncate=False)
            assert failure_count == 0, f"the result sets did not match:\n{msg}"

    def assert_query_returns_no_rows(self, actual_query):
        """
        Asserts that a query returns an empty result set
        """
        result_df = self.spark.sql(actual_query)

        record_count = result_df.count()
        if record_count > 0:
            msg = self.get_show_string(result_df, n=record_count, truncate=False)
            assert record_count == 0, f"the result set was not empty:\n{msg}"


def inject_variables():
"""
110 changes: 110 additions & 0 deletions tests/assert_queries_are_equal_test.py
@@ -0,0 +1,110 @@
import databricks_test
import pytest


def test_results_match():
    with databricks_test.session() as dbrickstest:
        query = """
            SELECT col1, col2
            FROM
            (VALUES
                (100, 'foo'),
                (101, 'bar'),
                (102, 'baz')
            ) AS v (col1, col2)
        """

        dbrickstest.assert_queries_are_equal(query, query)

def test_results_do_not_match():
    with databricks_test.session() as dbrickstest:
        actual_query = """
            SELECT col1, col2
            FROM
            (VALUES
                (100, 'foo'),
                (101, 'bar'),
                (102, 'baz')
            ) AS v (col1, col2)
        """

        expected_query = """
            SELECT col1, col2
            FROM
            (VALUES
                (100, 'foo'),
                (110, 'bar'),
                (999, 'qux')
            ) AS v (col1, col2)
        """

        with pytest.raises(Exception) as exception_message:
            dbrickstest.assert_queries_are_equal(actual_query, expected_query)

        assert str(exception_message.value).startswith("the result sets did not match:")

def test_unexpected_result():
    with databricks_test.session() as dbrickstest:
        actual_query = """
            SELECT col1, col2
            FROM
            (VALUES
                (100, 'foo'),
                (101, 'bar')
            ) AS v (col1, col2)
        """

        expected_query = """
            SELECT col1, col2
            FROM
            (VALUES
                (100, 'foo')
            ) AS v (col1, col2)
        """

        expected_message = """the result sets did not match:
+---+----+----+
|m  |col1|col2|
+---+----+----+
|=  |100 |foo |
|>  |101 |bar |
+---+----+----+
"""

        with pytest.raises(Exception) as exception_message:
            dbrickstest.assert_queries_are_equal(actual_query, expected_query)

        assert str(exception_message.value) == expected_message

def test_missing_result():
    with databricks_test.session() as dbrickstest:
        actual_query = """
            SELECT col1, col2
            FROM
            (VALUES
                (100, 'foo')
            ) AS v (col1, col2)
        """

        expected_query = """
            SELECT col1, col2
            FROM
            (VALUES
                (100, 'foo'),
                (101, 'bar')
            ) AS v (col1, col2)
        """

        expected_message = """the result sets did not match:
+---+----+----+
|m  |col1|col2|
+---+----+----+
|=  |100 |foo |
|<  |101 |bar |
+---+----+----+
"""

        with pytest.raises(Exception) as exception_message:
            dbrickstest.assert_queries_are_equal(actual_query, expected_query)

        assert str(exception_message.value) == expected_message
45 changes: 45 additions & 0 deletions tests/assert_query_returns_no_rows_test.py
@@ -0,0 +1,45 @@
import databricks_test
import pytest


def test_no_rows_returned():
    with databricks_test.session() as dbrickstest:
        query = """
            SELECT col1, col2
            FROM
            (VALUES
                (100, 'foo'),
                (101, 'bar'),
                (102, 'baz')
            ) AS v (col1, col2)
            WHERE 1=2
        """

        dbrickstest.assert_query_returns_no_rows(query)

def test_rows_returned():
    with databricks_test.session() as dbrickstest:
        query = """
            SELECT col1, col2
            FROM
            (VALUES
                (100, 'foo'),
                (101, 'bar'),
                (102, 'baz')
            ) AS v (col1, col2)
            ORDER BY col1
        """

        expected_message = """the result set was not empty:
+----+----+
|col1|col2|
+----+----+
|100 |foo |
|101 |bar |
|102 |baz |
+----+----+
"""

        with pytest.raises(Exception) as exception_message:
            dbrickstest.assert_query_returns_no_rows(query)

        assert str(exception_message.value) == expected_message
29 changes: 29 additions & 0 deletions tests/deltalake_test.py
@@ -0,0 +1,29 @@
import databricks_test
from tempfile import TemporaryDirectory


def test_deltalake_write():
    with databricks_test.session() as dbrickstest:
        with TemporaryDirectory() as tmp_dir:
            out_dir = f"{tmp_dir}/delta_out"

            # Provide the output location as a widget to the notebook
            switch = {
                "output": out_dir,
            }
            dbrickstest.dbutils.widgets.get.side_effect = lambda x: switch.get(x, "")

            # Run notebook
            dbrickstest.run_notebook(".", "deltalake_write_notebook")

            # Read the Delta table
            df = dbrickstest.spark.read.format("delta").load(out_dir)

            # Validate dataframe contains the expected values
            for n in range(0, 5):
                assert df.filter(df["id"] == n).count() == 1

            # Validate dataframe contains no unexpected values
            assert df.count() == 5
11 changes: 11 additions & 0 deletions tests/deltalake_write_notebook.py
@@ -0,0 +1,11 @@
# Databricks notebook source

# Instrument for unit tests. This is only executed in local unit tests, not in Databricks.
if 'dbutils' not in locals():
    import databricks_test
    databricks_test.inject_variables()

# COMMAND ----------
data = spark.range(0, 5)
data.write.format("delta").save(dbutils.widgets.get('output'))