Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
addressed comments and added tests
  • Loading branch information
brkyvz committed Apr 29, 2015
commit d3f7e0fa2c054aef5b048bad588e81bcfebe94ab
30 changes: 27 additions & 3 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,24 @@ def _function_obj(sc, is_math=False):
return sc._jvm.mathfunctions


def _create_function(name, doc="", is_math=False):
def _create_function(name, doc="", is_math=False, binary=False):
    """Create a wrapper that dispatches a named JVM function over Columns.

    :param name: name of the JVM function to look up on the functions object.
    :param doc: docstring to attach to the generated wrapper.
    :param is_math: when True, resolve the function on the math-functions
        JVM object instead of the default one (see `_function_obj`).
    :param binary: when True, the generated wrapper requires two column
        arguments instead of one.
    :return: a Python function taking (col1, col2=None) and returning a Column.
    """
    def _(col1, col2=None):
        sc = SparkContext._active_spark_context
        if not binary:
            jc = getattr(_function_obj(sc, is_math), name)(
                col1._jc if isinstance(col1, Column) else col1)
        else:
            # Explicit None check: a plain truthiness test (`assert col2`)
            # would wrongly reject the valid literal 0, and asserts are
            # stripped entirely under `python -O`.
            if col2 is None:
                raise ValueError("The second column for %s not provided!" % name)
            # Users might write ints for simplicity; ints would throw an
            # error on the JVM side, so promote them to floats.
            if isinstance(col1, int):
                col1 = float(col1)
            if isinstance(col2, int):
                col2 = float(col2)
            jc = getattr(_function_obj(sc, is_math), name)(
                col1._jc if isinstance(col1, Column) else col1,
                col2._jc if isinstance(col2, Column) else col2)
        return Column(jc)
    _.__name__ = name
    _.__doc__ = doc
Expand Down Expand Up @@ -107,14 +120,25 @@ def _(col):
'measured in radians.'
}

# math functions that take two arguments as input
_binary_math_functions = {
'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
'polar coordinates (r, theta).',
'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.',
'pow': 'Returns the value of the first argument raised to the power of the second argument.'
}


# Materialize the declared function tables as module-level wrappers and
# publish them through __all__.
for _name, _doc in _functions.items():
    globals()[_name] = _create_function(_name, _doc)
for _name, _doc in _math_functions.items():
    globals()[_name] = _create_function(_name, _doc, is_math=True)
for _name, _doc in _binary_math_functions.items():
    globals()[_name] = _create_function(_name, _doc, is_math=True, binary=True)
del _name, _doc
__all__.extend(_functions.keys())
__all__.extend(_math_functions.keys())
__all__.extend(_binary_math_functions.keys())
__all__.sort()


Expand Down
29 changes: 29 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,35 @@ def test_aggregator(self):
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])

def test_math_functions(self):
    """Check unary and binary math functions against the math module.

    Covers Column, string, and item-access column references for the unary
    case, and Column/Column, Column/int, and Column/float argument pairs
    for the binary case.
    """
    df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
    from pyspark.sql import functions
    import math

    def get_values(l):
        # Extract the single column value from each collected Row.
        return [j[0] for j in l]

    def assert_close(a, b):
        c = get_values(b)
        diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
        # Fail the test on mismatch. (Previously this function *returned*
        # the comparison result, which every call site discarded, so the
        # test could never fail.)
        self.assertTrue(all(diff), "values %s are not close to %s" % (a, c))

    assert_close([math.cos(i) for i in range(10)],
                 df.select(functions.cos(df.a)).collect())
    assert_close([math.cos(i) for i in range(10)],
                 df.select(functions.cos("a")).collect())
    assert_close([math.sin(i) for i in range(10)],
                 df.select(functions.sin(df.a)).collect())
    assert_close([math.sin(i) for i in range(10)],
                 df.select(functions.sin(df['a'])).collect())
    assert_close([math.pow(i, 2 * i) for i in range(10)],
                 df.select(functions.pow(df.a, df.b)).collect())
    assert_close([math.pow(i, 2) for i in range(10)],
                 df.select(functions.pow(df.a, 2)).collect())
    assert_close([math.pow(i, 2) for i in range(10)],
                 df.select(functions.pow(df.a, 2.0)).collect())
    assert_close([math.hypot(i, 2 * i) for i in range(10)],
                 df.select(functions.hypot(df.a, df.b)).collect())

def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()
Expand Down