From e1fca69cf548940b712c170afc62dd7d1b38ea6b Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 26 Mar 2018 09:20:54 +0000
Subject: [PATCH] Use RandomUUIDGenerator in Uuid expression.

---
 .../sql/catalyst/analysis/Analyzer.scala       | 16 ++++
 .../spark/sql/catalyst/expressions/misc.scala  | 27 +++++--
 .../ResolvedUuidExpressionsSuite.scala         | 73 +++++++++++++++++++
 .../expressions/ExpressionEvalHelper.scala     |  1 +
 .../expressions/MiscExpressionsSuite.scala     | 23 +++++-
 .../org/apache/spark/sql/DataFrameSuite.scala  |  6 ++
 6 files changed, 139 insertions(+), 7 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 251099f750cf..9cc928c8ac10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst._
@@ -177,6 +178,7 @@ class Analyzer(
       TimeWindowing ::
       ResolveInlineTables(conf) ::
       ResolveTimeZone(conf) ::
+      ResolvedUuidExpressions ::
       TypeCoercion.typeCoercionRules(conf) ++
       extendedResolutionRules : _*),
     Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
@@ -1994,6 +1996,20 @@ class Analyzer(
     }
   }
 
+  /**
+   * Set the seed for random number generation in Uuid expressions.
+   */
+  object ResolvedUuidExpressions extends Rule[LogicalPlan] {
+    private lazy val random = new Random()
+
+    override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+      case p if p.resolved => p
+      case p => p transformExpressionsUp {
+        case Uuid(None) => Uuid(Some(random.nextLong()))
+      }
+    }
+  }
+
   /**
    * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the
    * null check. When user defines a UDF with primitive parameters, there is no way to tell if the
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 4b9006ab5b42..cdbe611fdd06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -21,6 +21,7 @@ import java.util.UUID
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -117,18 +118,34 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable {
        46707d92-02f4-4817-8116-a4c3b23e6266
   """)
 // scalastyle:on line.size.limit
-case class Uuid() extends LeafExpression {
+case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Nondeterministic {
 
-  override lazy val deterministic: Boolean = false
+  def this() = this(None)
+
+  override lazy val resolved: Boolean = randomSeed.isDefined
 
   override def nullable: Boolean = false
 
   override def dataType: DataType = StringType
 
-  override def eval(input: InternalRow): Any = UTF8String.fromString(UUID.randomUUID().toString)
+  @transient private[this] var randomGenerator: RandomUUIDGenerator = _
+
+
+  override protected def initializeInternal(partitionIndex: Int): Unit =
+    randomGenerator = RandomUUIDGenerator(randomSeed.get + partitionIndex)
+
+  override protected def evalInternal(input: InternalRow): Any =
+    randomGenerator.getNextUUIDUTF8String()
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    ev.copy(code = s"final UTF8String ${ev.value} = " +
-      s"UTF8String.fromString(java.util.UUID.randomUUID().toString());", isNull = "false")
+    val randomGen = ctx.freshName("randomGen")
+    ctx.addMutableState("org.apache.spark.sql.catalyst.util.RandomUUIDGenerator", randomGen,
+      forceInline = true,
+      useFreshName = false)
+    ctx.addPartitionInitializationStatement(s"$randomGen = " +
+      "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" +
+      s"${randomSeed.get}L + partitionIndex);")
+    ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();",
+      isNull = "false")
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala
new file mode 100644
index 000000000000..fe57c199b874
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+
+/**
+ * Test suite for resolving Uuid expressions.
+ */
+class ResolvedUuidExpressionsSuite extends AnalysisTest {
+
+  private lazy val a = 'a.int
+  private lazy val r = LocalRelation(a)
+  private lazy val uuid1 = Uuid().as('_uuid1)
+  private lazy val uuid2 = Uuid().as('_uuid2)
+  private lazy val uuid3 = Uuid().as('_uuid3)
+  private lazy val uuid1Ref = uuid1.toAttribute
+
+  private val analyzer = getAnalyzer(caseSensitive = true)
+
+  private def getUuidExpressions(plan: LogicalPlan): Seq[Uuid] = {
+    plan.flatMap {
+      case p =>
+        p.expressions.flatMap(_.collect {
+          case u: Uuid => u
+        })
+    }
+  }
+
+  test("analyzed plan sets random seed for Uuid expression") {
+    val plan = r.select(a, uuid1)
+    val resolvedPlan = analyzer.executeAndCheck(plan)
+    getUuidExpressions(resolvedPlan).foreach { u =>
+      assert(u.resolved)
+      assert(u.randomSeed.isDefined)
+    }
+  }
+
+  test("Uuid expressions should have different random seeds") {
+    val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3)
+    val resolvedPlan = analyzer.executeAndCheck(plan)
+    assert(getUuidExpressions(resolvedPlan).map(_.randomSeed.get).distinct.length == 3)
+  }
+
+  test("Different analyzed plans should have different random seeds in Uuids") {
+    val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3)
+    val resolvedPlan1 = analyzer.executeAndCheck(plan)
+    val resolvedPlan2 = analyzer.executeAndCheck(plan)
+    val uuids1 = getUuidExpressions(resolvedPlan1)
+    val uuids2 = getUuidExpressions(resolvedPlan2)
+    assert(uuids1.distinct.length == 3)
+    assert(uuids2.distinct.length == 3)
+    assert(uuids1.intersect(uuids2).length == 0)
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index b4c8eab19c5c..fad26f0055d6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -162,6 +162,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
         Alias(expression, s"Optimized($expression)1")() ::
           Alias(expression, s"Optimized($expression)2")() :: Nil),
       expression)
+    plan.initialize(0)
     val unsafeRow = plan(inputRow)
     val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
index facc86308130..593bcd02078c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
@@ -17,8 +17,11 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.util.Random
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -40,7 +43,23 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("uuid") {
-    checkEvaluation(Length(Uuid()), 36)
-    assert(evaluate(Uuid()) !== evaluate(Uuid()))
+    def assertIncorrectEval(f: () => Unit): Unit = {
+      assert(intercept[Exception] {
+        f()
+      }.getMessage().contains("Incorrect evaluation"))
+    }
+
+    checkEvaluation(Length(Uuid(Some(0))), 36)
+    val r = new Random()
+    val seed1 = Some(r.nextLong())
+    val uuid1 = evaluate(Uuid(seed1)).asInstanceOf[UTF8String]
+    checkEvaluation(Uuid(seed1), uuid1.toString)
+
+    val seed2 = Some(r.nextLong())
+    val uuid2 = evaluate(Uuid(seed2)).asInstanceOf[UTF8String]
+    assertIncorrectEval(() => checkEvaluationWithoutCodegen(Uuid(seed1), uuid2))
+    assertIncorrectEval(() => checkEvaluationWithGeneratedMutableProjection(Uuid(seed1), uuid2))
+    assertIncorrectEval(() => checkEvalutionWithUnsafeProjection(Uuid(seed1), uuid2))
+    assertIncorrectEval(() => checkEvaluationWithOptimization(Uuid(seed1), uuid2))
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 8b66f77b2f92..f7b3393f65cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -28,6 +28,7 @@ import org.scalatest.Matchers._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.expressions.Uuid
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
 import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -2264,4 +2265,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     checkAnswer(df, Row(0, 10) :: Nil)
     assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
  }
+
+  test("Uuid expressions should produce same results at retries in the same DataFrame") {
+    val df = spark.range(1).select($"id", new Column(Uuid()))
+    checkAnswer(df, df.collect())
+  }
 }
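
Note (not part of the patch): a minimal sketch of the determinism the change above relies on. At analysis time ResolvedUuidExpressions assigns each Uuid a fixed randomSeed, and at execution initializeInternal seeds a RandomUUIDGenerator with randomSeed + partitionIndex, so re-running the same analyzed plan replays the same UUID sequence. The sketch uses only the RandomUUIDGenerator(seed) constructor and getNextUUIDUTF8String() call that appear in the diff; the object name SeededUuidDemo and the concrete seed and partition values are illustrative assumptions.

// Illustrative sketch only; SeededUuidDemo, the seed value and the partition
// index below are invented for the example, not taken from the patch.
import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator

object SeededUuidDemo {
  def main(args: Array[String]): Unit = {
    val analysisTimeSeed = 42L   // stands in for the analyzer's random.nextLong()
    val partitionIndex = 0       // supplied to initializeInternal at runtime

    // Same seed + partition index => same UUID sequence, which is why a retried
    // task (or re-collecting the same analyzed DataFrame) sees identical values.
    val g1 = RandomUUIDGenerator(analysisTimeSeed + partitionIndex)
    val g2 = RandomUUIDGenerator(analysisTimeSeed + partitionIndex)
    assert(g1.getNextUUIDUTF8String().toString == g2.getNextUUIDUTF8String().toString)

    // A different seed, as a distinct Uuid expression would receive from the
    // analyzer rule, yields an independent sequence.
    val g3 = RandomUUIDGenerator(analysisTimeSeed + 1 + partitionIndex)
    println(g3.getNextUUIDUTF8String())
  }
}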