Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
addressed comments v2.0
  • Loading branch information
brkyvz committed Apr 29, 2015
commit 25e653436073ae396a8136a11179b42fb5153d12
62 changes: 30 additions & 32 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,36 +33,38 @@
__all__ = ['countDistinct', 'approxCountDistinct', 'udf']


def _function_obj(sc, is_math=False):
if not is_math:
return sc._jvm.functions
else:
return sc._jvm.mathfunctions
def _create_function(name, doc="", is_math=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can now remove the `is_math` parameter.

""" Create a function for aggregator by name"""
def _(col1):
sc = SparkContext._active_spark_context
if is_math:
jvm_class = sc._jvm.mathfunctions
else:
jvm_class = sc._jvm.functions
jc = getattr(jvm_class, name)(col1._jc if isinstance(col1, Column) else col1)
return Column(jc)
_.__name__ = name
_.__doc__ = doc
return _


def _create_function(name, doc="", is_math=False, binary=False):
def _create_binary_function(name, doc=""):
""" Create a function for aggregator by name"""
def _(col1, col2=None):
def _(col1, col2):
sc = SparkContext._active_spark_context
if not binary:
jc = getattr(_function_obj(sc, is_math), name)(col1._jc if isinstance(col1, Column)
else col1)
else:
assert col2, "The second column for %s not provided!" % name
# users might write ints for simplicity. This would throw an error on the JVM side.
if type(col1) is int:
col1 = col1 * 1.0
if type(col2) is int:
col2 = col2 * 1.0
jc = getattr(_function_obj(sc, is_math), name)(col1._jc if isinstance(col1, Column)
else col1,
col2._jc if isinstance(col2, Column)
else col2)
# users might write ints for simplicity. This would throw an error on the JVM side.
if type(col1) is int:
col1 = col1 * 1.0
if type(col2) is int:
col2 = col2 * 1.0
jc = getattr(sc._jvm.mathfunctions, name)(col1._jc if isinstance(col1, Column) else col1,
col2._jc if isinstance(col2, Column) else col2)
return Column(jc)
_.__name__ = name
_.__doc__ = doc
return _


_functions = {
'lit': 'Creates a :class:`Column` of literal value.',
'col': 'Returns a :class:`Column` based on the given column name.',
Expand All @@ -87,26 +89,22 @@ def _(col1, col2=None):
}

# Math functions live under a different JVM object; therefore, they need to be handled separately.
_math_functions = {
_mathfunctions = {
'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' +
'0.0 through pi.',
'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' +
'-pi/2 through pi/2.',
'atan': 'Computes the tangent inverse of the given value.',
'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
'polar coordinates (r, theta).',
'cbrt': 'Computes the cube-root of the given value.',
'ceil': 'Computes the ceiling of the given value.',
'cos': 'Computes the cosine of the given value.',
'cosh': 'Computes the hyperbolic cosine of the given value.',
'exp': 'Computes the exponential of the given value.',
'expm1': 'Computes the exponential of the given value minus one.',
'floor': 'Computes the floor of the given value.',
'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.',
'log': 'Computes the natural logarithm of the given value.',
'log10': 'Computes the logarithm of the given value in Base 10.',
'log1p': 'Computes the natural logarithm of the given value plus one.',
'pow': 'Returns the value of the first argument raised to the power of the second argument.',
'rint': 'Returns the double value that is closest in value to the argument and' +
' is equal to a mathematical integer.',
'signum': 'Computes the signum of the given value.',
Expand All @@ -121,7 +119,7 @@ def _(col1, col2=None):
}

# math functions that take two arguments as input
_binary_math_functions = {
_binary_mathfunctions = {
'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
'polar coordinates (r, theta).',
'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.',
Expand All @@ -131,14 +129,14 @@ def _(col1, col2=None):

for _name, _doc in _functions.items():
globals()[_name] = _create_function(_name, _doc)
for _name, _doc in _math_functions.items():
for _name, _doc in _mathfunctions.items():
globals()[_name] = _create_function(_name, _doc, True)
for _name, _doc in _binary_math_functions.items():
globals()[_name] = _create_function(_name, _doc, True, True)
for _name, _doc in _binary_mathfunctions.items():
globals()[_name] = _create_binary_function(_name, _doc)
del _name, _doc
__all__ += _functions.keys()
__all__ += _math_functions.keys()
__all__ += _binary_math_functions.keys()
__all__ += _mathfunctions.keys()
__all__ += _binary_mathfunctions.keys()
__all__.sort()


Expand Down