37 commits
ed932d2
Temporarily renames Dataset to DS
liancheng Mar 1, 2016
e59e940
Renames DataFrame to Dataset[T]
liancheng Mar 1, 2016
b357371
Fixes Java API compilation failures
liancheng Mar 1, 2016
3783e31
Fixes styling issues
liancheng Mar 1, 2016
a02a922
Fixes compilation failure introduced while rebasing
liancheng Mar 1, 2016
3db81f8
Temporarily disables MiMA check for convenience
liancheng Mar 1, 2016
f67f497
Fixes infinite recursion in Dataset constructor
liancheng Mar 1, 2016
f921583
Fixes test failures
liancheng Mar 3, 2016
fa22261
Migrates encoder stuff to the new Dataset
liancheng Mar 3, 2016
8cf5672
Makes some shape-keeping operations typed
liancheng Mar 5, 2016
712ee19
Adds collectRows() for Java API
liancheng Mar 6, 2016
c73b91f
Migrates joinWith operations
liancheng Mar 6, 2016
54cb36a
Migrates typed select
liancheng Mar 7, 2016
cbd7519
Renames typed groupBy to groupByKey
liancheng Mar 7, 2016
f1a2903
Migrates typed groupBy
liancheng Mar 7, 2016
15b4193
Migrates functional transformers
liancheng Mar 7, 2016
9aff0e2
Removes the old DS class and gets everything compiled
liancheng Mar 7, 2016
f053852
Fixes compilation error
liancheng Mar 7, 2016
3a7aff4
Row encoder should produce GenericRowWithSchema
liancheng Mar 8, 2016
9f8e0ad
Fixes compilation error after rebasing
liancheng Mar 8, 2016
bc081a9
Migrated Dataset.showString should handle typed objects
liancheng Mar 8, 2016
6b69f43
MapObjects should also handle DecimalType and DateType
liancheng Mar 8, 2016
29cb70f
Always use eager analysis
liancheng Mar 9, 2016
ba86e09
Fixes compilation error after rebasing
liancheng Mar 10, 2016
5727f48
Fixes compilation error after rebasing
liancheng Mar 10, 2016
cd63810
Temporarily ignores Python UDT test cases
liancheng Mar 10, 2016
4c8d139
fix python
cloud-fan Mar 10, 2016
cf0c112
Merge pull request #9 from cloud-fan/ds-to-df
liancheng Mar 10, 2016
91942cf
fix pymllib
cloud-fan Mar 10, 2016
736fbcb
Merge pull request #10 from cloud-fan/ds-to-df
liancheng Mar 10, 2016
488a82e
MIMA
yhuai Mar 10, 2016
343c611
Fix typo...
yhuai Mar 10, 2016
63d4d69
MIMA: Exclude DataFrame class.
yhuai Mar 10, 2016
f6bcd50
Revert "MIMA: Exclude DataFrame class."
yhuai Mar 10, 2016
49c6fc2
Revert "Fix typo..."
yhuai Mar 10, 2016
d52ce17
Revert "MIMA"
yhuai Mar 10, 2016
7d29c06
Merge remote-tracking branch 'upstream/master' into ds-to-df
yhuai Mar 11, 2016
fix python
cloud-fan committed Mar 10, 2016
commit 4c8d13928ef6ecafdb88a19d40736039d205d824
140 changes: 70 additions & 70 deletions python/pyspark/sql/tests.py
@@ -528,76 +528,76 @@ def check_datatype(datatype):
         _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
         self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
 
-    # def test_infer_schema_with_udt(self):
-    #     from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
-    #     row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-    #     df = self.sqlCtx.createDataFrame([row])
-    #     schema = df.schema
-    #     field = [f for f in schema.fields if f.name == "point"][0]
-    #     self.assertEqual(type(field.dataType), ExamplePointUDT)
-    #     df.registerTempTable("labeled_point")
-    #     point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
-    #     self.assertEqual(point, ExamplePoint(1.0, 2.0))
-
-    #     row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
-    #     df = self.sqlCtx.createDataFrame([row])
-    #     schema = df.schema
-    #     field = [f for f in schema.fields if f.name == "point"][0]
-    #     self.assertEqual(type(field.dataType), PythonOnlyUDT)
-    #     df.registerTempTable("labeled_point")
-    #     point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
-    #     self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
-
-    # def test_apply_schema_with_udt(self):
-    #     from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
-    #     row = (1.0, ExamplePoint(1.0, 2.0))
-    #     schema = StructType([StructField("label", DoubleType(), False),
-    #                          StructField("point", ExamplePointUDT(), False)])
-    #     df = self.sqlCtx.createDataFrame([row], schema)
-    #     point = df.head().point
-    #     self.assertEqual(point, ExamplePoint(1.0, 2.0))
-
-    #     row = (1.0, PythonOnlyPoint(1.0, 2.0))
-    #     schema = StructType([StructField("label", DoubleType(), False),
-    #                          StructField("point", PythonOnlyUDT(), False)])
-    #     df = self.sqlCtx.createDataFrame([row], schema)
-    #     point = df.head().point
-    #     self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
-
-    # def test_udf_with_udt(self):
-    #     from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
-    #     row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-    #     df = self.sqlCtx.createDataFrame([row])
-    #     self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
-    #     udf = UserDefinedFunction(lambda p: p.y, DoubleType())
-    #     self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
-    #     udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
-    #     self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
-
-    #     row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
-    #     df = self.sqlCtx.createDataFrame([row])
-    #     self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
-    #     udf = UserDefinedFunction(lambda p: p.y, DoubleType())
-    #     self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
-    #     udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
-    #     self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
-
-    # def test_parquet_with_udt(self):
-    #     from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
-    #     row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-    #     df0 = self.sqlCtx.createDataFrame([row])
-    #     output_dir = os.path.join(self.tempdir.name, "labeled_point")
-    #     df0.write.parquet(output_dir)
-    #     df1 = self.sqlCtx.read.parquet(output_dir)
-    #     point = df1.head().point
-    #     self.assertEqual(point, ExamplePoint(1.0, 2.0))
-
-    #     row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
-    #     df0 = self.sqlCtx.createDataFrame([row])
-    #     df0.write.parquet(output_dir, mode='overwrite')
-    #     df1 = self.sqlCtx.read.parquet(output_dir)
-    #     point = df1.head().point
-    #     self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+    def test_infer_schema_with_udt(self):
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        df = self.sqlCtx.createDataFrame([row])
+        schema = df.schema
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), ExamplePointUDT)
+        df.registerTempTable("labeled_point")
+        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df = self.sqlCtx.createDataFrame([row])
+        schema = df.schema
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), PythonOnlyUDT)
+        df.registerTempTable("labeled_point")
+        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
+    def test_apply_schema_with_udt(self):
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+        row = (1.0, ExamplePoint(1.0, 2.0))
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", ExamplePointUDT(), False)])
+        df = self.sqlCtx.createDataFrame([row], schema)
+        point = df.head().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+        row = (1.0, PythonOnlyPoint(1.0, 2.0))
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", PythonOnlyUDT(), False)])
+        df = self.sqlCtx.createDataFrame([row], schema)
+        point = df.head().point
+        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
+    def test_udf_with_udt(self):
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        df = self.sqlCtx.createDataFrame([row])
+        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
+        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
+        self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df = self.sqlCtx.createDataFrame([row])
+        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
+        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+        udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
+        self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+
+    def test_parquet_with_udt(self):
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        df0 = self.sqlCtx.createDataFrame([row])
+        output_dir = os.path.join(self.tempdir.name, "labeled_point")
+        df0.write.parquet(output_dir)
+        df1 = self.sqlCtx.read.parquet(output_dir)
+        point = df1.head().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df0 = self.sqlCtx.createDataFrame([row])
+        df0.write.parquet(output_dir, mode='overwrite')
+        df1 = self.sqlCtx.read.parquet(output_dir)
+        point = df1.head().point
+        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
 
     def test_unionAll_with_udt(self):
        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
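The re-enabled tests above exercise two kinds of UDT: ExamplePoint/ExamplePointUDT, which is backed by a Scala UDT class, and PythonOnlyPoint/PythonOnlyUDT, which has no JVM counterpart. For context, here is a minimal sketch of the shape of a Python-only UDT, using hypothetical Point/PointUDT names rather than the real test classes. The absent scalaUDT() override is what forces the JVM-side RowEncoder to fall back to the UDT's underlying SQL type, as the next file's diff does.

```python
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType


class PointUDT(UserDefinedType):
    """Hypothetical Python-only UDT: it does not override scalaUDT(),
    so no JVM class backs it and values cross the JVM boundary encoded
    as the SQL type returned by sqlType()."""

    @classmethod
    def sqlType(cls):
        return ArrayType(DoubleType(), False)  # storage representation

    @classmethod
    def module(cls):
        return "__main__"

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        # Point is defined below; this is only called at runtime.
        return Point(datum[0], datum[1])


class Point(object):
    """Plain Python value object carried by PointUDT."""

    __UDT__ = PointUDT()  # opts the class into UDT handling

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __eq__(self, other):
        return isinstance(other, Point) and (self.x, self.y) == (other.x, other.y)
```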
10 changes: 8 additions & 2 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -52,6 +52,8 @@ object RowEncoder {
     case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
          FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject
 
+    case p: PythonUserDefinedType => extractorsFor(inputObject, p.sqlType)
+
     case udt: UserDefinedType[_] =>
       val obj = NewInstance(
         udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
@@ -151,10 +153,14 @@
 
   private def constructorFor(schema: StructType): Expression = {
     val fields = schema.zipWithIndex.map { case (f, i) =>
-      val field = BoundReference(i, f.dataType, f.nullable)
+      val dt = f.dataType match {
+        case p: PythonUserDefinedType => p.sqlType
+        case other => other
+      }
+      val field = BoundReference(i, dt, f.nullable)
       If(
         IsNull(field),
-        Literal.create(null, externalDataTypeFor(f.dataType)),
+        Literal.create(null, externalDataTypeFor(dt)),
         constructorFor(field)
       )
     }
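Taken together, the two RowEncoder hunks let a Python-only UDT round-trip through the JVM: extractorsFor unwraps a PythonUserDefinedType to its sqlType when serializing a row, and constructorFor binds and null-checks the field at the unwrapped type when deserializing. A minimal usage sketch under the assumptions above (the hypothetical Point/PointUDT pair, plus an existing SQLContext named sqlCtx standing in for the tests' self.sqlCtx):

```python
from pyspark.sql import Row

# Assumes the hypothetical Point/PointUDT sketch above and a SQLContext
# named `sqlCtx`. The Point value reaches the JVM encoded as its sqlType,
# ArrayType(DoubleType()), and is rebuilt into a Point on the way back.
row = Row(label=1.0, point=Point(1.0, 2.0))
df = sqlCtx.createDataFrame([row])

# The schema still reports the UDT, mirroring test_infer_schema_with_udt.
field = [f for f in df.schema.fields if f.name == "point"][0]
assert type(field.dataType) is PointUDT

# Collecting deserializes back to the Python class.
assert df.head().point == Point(1.0, 2.0)
```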