Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add udf decorator
This PR adds `udf` decorator syntax as proposed in [SPARK-19160](https://issues.apache.org/jira/browse/SPARK-19160).

This allows users to define UDF using simplified syntax:

```
from pyspark.sql.decorators import udf

@udf(IntegerType())
def add_one(x):
    """Adds one"""
    if x is not None:
        return x + 1
```

without the need to define the function and the UDF separately.

Tested with existing unit tests to ensure backward compatibility and additional unit tests covering new functionality.
  • Loading branch information
zero323 committed Feb 15, 2017
commit 8280f424138293e4fd411a6ef11136caec987ef7
27 changes: 4 additions & 23 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,6 @@ def test_explode(self):
self.assertEqual(result[0][0], "a")
self.assertEqual(result[0][1], "b")

with self.assertRaises(ValueError):
data.select(explode(data.mapfield).alias("a", "b", metadata={'max': 99})).count()

def test_and_in_expression(self):
# `&` is the supported boolean conjunction on Columns; Python's `and` cannot
# be overloaded, so applying it to a Column must raise ValueError.
self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
Expand Down Expand Up @@ -543,32 +540,23 @@ def substr(x, start, end):
if x is not None:
return x[start:end]

@udf("long")
def trunc(x):
return int(x)

@udf(returnType="double")
def as_double(x):
return float(x)

df = (
self.spark
.createDataFrame(
[(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float"))
[(1, "Foo", "foobar")], ("one", "Foo", "foobar"))
.select(
add_one("one"), add_two("one"),
to_upper("Foo"), to_lower("Foo"),
substr("foobar", lit(0), lit(3)),
trunc("float"), as_double("one")))
substr("foobar", lit(0), lit(3))))

self.assertListEqual(
[tpe for _, tpe in df.dtypes],
["int", "double", "string", "string", "string", "bigint", "double"]
["int", "double", "string", "string", "string"]
)

self.assertListEqual(
list(df.first()),
[2, 3.0, "FOO", "foo", "foo", 3, 1.0]
[2, 3.0, "FOO", "foo", "foo"]
)

def test_basic_functions(self):
Expand Down Expand Up @@ -955,13 +943,6 @@ def test_column_select(self):
self.assertEqual(self.testData, df.select(df.key, df.value).collect())
self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())

def test_column_alias_metadata(self):
# alias() with a metadata dict should attach it to the resulting schema field.
df = self.df
df_with_meta = df.select(df.key.alias('pk', metadata={'label': 'Primary Key'}))
self.assertEqual(df_with_meta.schema['pk'].metadata['label'], 'Primary Key')
# NOTE(review): 'metdata' below looks like a deliberate misspelling — an
# unknown keyword argument to alias() is expected to raise AssertionError.
with self.assertRaises(AssertionError):
df.select(df.key.alias('pk', metdata={'label': 'Primary Key'}))

def test_freqItems(self):
vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
df = self.sc.parallelize(vals).toDF()
Expand Down