From 8e7b05f0f59d58bf6482e069b31c102eff5a1b83 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Mon, 11 Jul 2016 17:45:32 -0700 Subject: [PATCH 1/2] Fix codegen variable namespace collision for pmod --- .../sql/catalyst/expressions/arithmetic.scala | 25 ++++++++++--------- .../sql/test/DataFrameReaderWriterSuite.scala | 14 +++++++++++ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 4db1352291e0..7b634f64f71e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -498,34 +498,35 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic wi override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val r = ctx.freshName("r") dataType match { case dt: DecimalType => val decimalAdd = "$plus" s""" - ${ctx.javaType(dataType)} r = $eval1.remainder($eval2); - if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { - ${ev.value} = (r.$decimalAdd($eval2)).remainder($eval2); + ${ctx.javaType(dataType)} $r = $eval1.remainder($eval2); + if ($r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { + ${ev.value} = ($r.$decimalAdd($eval2)).remainder($eval2); } else { - ${ev.value} = r; + ${ev.value} = $r; } """ // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => s""" - ${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2); - if (r < 0) { - ${ev.value} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2); + ${ctx.javaType(dataType)} $r = (${ctx.javaType(dataType)})($eval1 % $eval2); + if ($r < 0) { + ${ev.value} = (${ctx.javaType(dataType)})(($r + $eval2) % $eval2); } else { - ${ev.value} = r; + ${ev.value} = $r; } """ case _ => s""" - ${ctx.javaType(dataType)} r = $eval1 % $eval2; - if (r < 0) { - ${ev.value} = (r + $eval2) % $eval2; + ${ctx.javaType(dataType)} $r = $eval1 % $eval2; + if ($r < 0) { + ${ev.value} = ($r + $eval2) % $eval2; } else { - ${ev.value} = r; + ${ev.value} = $r; } """ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 05935cec4b67..f706b20364c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -449,6 +449,20 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be } } + test("pmod with partitionBy") { + val spark = this.spark + import spark.implicits._ + + case class Test(a: Int, b: String) + val data = Seq((0, "a"), (1, "b"), (1, "a")) + spark.createDataset(data).createOrReplaceTempView("test") + sql("select * from test distribute by pmod(_1, 2)") + .write + .partitionBy("_2") + .mode("overwrite") + .parquet(dir) + } + private def testRead( df: => DataFrame, expectedResult: Seq[String], From 8b2639f60975f07157790309052949ff0daabe38 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Mon, 11 Jul 2016 19:04:28 -0700 Subject: [PATCH 2/2] s/r/remainder --- .../sql/catalyst/expressions/arithmetic.scala | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 7b634f64f71e..91ffac0ba2a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -498,35 +498,35 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic wi override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - val r = ctx.freshName("r") + val remainder = ctx.freshName("remainder") dataType match { case dt: DecimalType => val decimalAdd = "$plus" s""" - ${ctx.javaType(dataType)} $r = $eval1.remainder($eval2); - if ($r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { - ${ev.value} = ($r.$decimalAdd($eval2)).remainder($eval2); + ${ctx.javaType(dataType)} $remainder = $eval1.remainder($eval2); + if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { + ${ev.value} = ($remainder.$decimalAdd($eval2)).remainder($eval2); } else { - ${ev.value} = $r; + ${ev.value} = $remainder; } """ // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => s""" - ${ctx.javaType(dataType)} $r = (${ctx.javaType(dataType)})($eval1 % $eval2); - if ($r < 0) { - ${ev.value} = (${ctx.javaType(dataType)})(($r + $eval2) % $eval2); + ${ctx.javaType(dataType)} $remainder = (${ctx.javaType(dataType)})($eval1 % $eval2); + if ($remainder < 0) { + ${ev.value} = (${ctx.javaType(dataType)})(($remainder + $eval2) % $eval2); } else { - ${ev.value} = $r; + ${ev.value} = $remainder; } """ case _ => s""" - ${ctx.javaType(dataType)} $r = $eval1 % $eval2; - if ($r < 0) { - ${ev.value} = ($r + $eval2) % $eval2; + ${ctx.javaType(dataType)} $remainder = $eval1 % $eval2; + if ($remainder < 0) { + ${ev.value} = ($remainder + $eval2) % $eval2; } else { - ${ev.value} = $r; + ${ev.value} = $remainder; } """ }