Skip to content
Closed
Prev Previous commit
Next Next commit
Fix decimal support in PySpark
  • Loading branch information
mateiz committed Nov 1, 2014
commit 4dc6bae2304d82888d556607206a11f60dcb9ec7
33 changes: 32 additions & 1 deletion python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import keyword
import warnings
import json
import re
from array import array
from operator import itemgetter
from itertools import imap
Expand Down Expand Up @@ -148,13 +149,35 @@ class TimestampType(PrimitiveType):
"""


class DecimalType(PrimitiveType):
class DecimalType(DataType):

"""Spark SQL DecimalType

The data type representing decimal.Decimal values.
"""

def __init__(self, precision=None, scale=None):
if precision is None:
self.hasPrecisionInfo = False
self.precision = None
self.scale = None
else:
self.hasPrecisionInfo = True
self.precision = precision
self.scale = scale
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: these lines could be simplified as:

self.precision = precision
self.scale = scale
self.hasPrecisionInfo = precision is not None

Maybe we could remove hasPrecisionInfo


def jsonValue(self):
if self.hasPrecisionInfo:
return "decimal(%d,%d)" % (self.precision, self.scale)
else:
return "decimal"

def __repr__(self):
if self.hasPrecisionInfo:
return "DecimalType(%d,%d)" % (self.precision, self.scale)
else:
return "DecimalType()"


class DoubleType(PrimitiveType):

Expand Down Expand Up @@ -446,9 +469,17 @@ def _parse_datatype_json_string(json_string):
return _parse_datatype_json_value(json.loads(json_string))


_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")


def _parse_datatype_json_value(json_value):
if type(json_value) is unicode and json_value in _all_primitive_types.keys():
return _all_primitive_types[json_value]()
elif type(json_value) is unicode and json_value == u'decimal':
return DecimalType()
elif type(json_value) is unicode and _FIXED_DECIMAL.match(json_value):
m = _FIXED_DECIMAL.match(json_value)
return DecimalType(int(m.group(1)), int(m.group(2)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you change these into two levels and raise exception if json_value is unicode but can not recognize the value?

else:
return _all_complex_types[json_value["type"]].fromJson(json_value)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution

import java.util.{List => JList, Map => JMap}

import org.apache.spark.sql.catalyst.types.decimal.Decimal

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._

Expand Down Expand Up @@ -116,7 +118,7 @@ object EvaluatePython {
def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
case (null, _) => null

case (row: Row, struct: StructType) =>
case (row: Seq[Any], struct: StructType) =>
val fields = struct.fields.map(field => field.dataType)
row.zip(fields).map {
case (obj, dataType) => toJava(obj, dataType)
Expand All @@ -133,6 +135,8 @@ object EvaluatePython {
case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
}.asJava

case (dec: BigDecimal, dt: DecimalType) => dec.underlying() // Pyrolite can handle BigDecimal

// Pyrolite can handle Timestamp
case (other, _) => other
}
Expand Down