Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add the base of fallback logic fo UnsafeProjection.
  • Loading branch information
viirya committed Apr 19, 2018
commit fe2a1cdd9002f14422c812d51041ed4f6d361cfc
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.codehaus.commons.compiler.CompileException
import org.codehaus.janino.InternalCompilerException

object CodegenObjectFactory {
Copy link
Contributor

@hvanhovell hvanhovell Apr 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this an abstract class in which we just need to implement hooks for the code generated and interpreted versions, i.e.:

object CodegenError {
  def unapply(throwable: Throwable): Option[Exception] = throwable match {
    case e: InternalCompilerException => Some(e)
    case e: CompileException => Some(e)
    case _ => None
  }
}

abstract class CodegenObjectFactory[IN, OUT] {
  def create(in: IN): OUT = try createCodeGeneratedObject(in) catch {
    case CodegenError(_) =>createInterpretedObject(in)
  }

  protected def createCodeGeneratedObject(in: IN): OUT
  protected def createInterpretedObject(in: IN): OUT
}

object UnsafeProjectionCreator extends CodegenObjectFactory[Seq[Expression], UnsafeProjection] {
  ...
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me, except that I think UnsafeProjectionCreator should still support UnsafeProjectionCreator's APIs. So we can replace all usage of UnsafeProjection with UnsafeProjectionCreator.

def codegenOrInterpreted[T](codegenCreator: () => T, interpretedCreator: () => T): T = {
try {
codegenCreator()
} catch {
// Catch compile error related exceptions
case e: InternalCompilerException => interpretedCreator()
case e: CompileException => interpretedCreator()
}
}
}

object UnsafeProjectionFactory extends UnsafeProjectionCreator {
import CodegenObjectFactory._

private val codegenCreator = UnsafeProjection
private lazy val interpretedCreator = InterpretedUnsafeProjection

/**
* Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
*/
override protected[sql] def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
codegenOrInterpreted[UnsafeProjection](() => codegenCreator.createProjection(exprs),
() => interpretedCreator.createProjection(exprs))
}

/**
* Same as other create()'s but allowing enabling/disabling subexpression elimination.
* The param `subexpressionEliminationEnabled` doesn't guarantee to work. For example,
* when fallbacking to interpreted execution, it is not supported.
*/
def create(
exprs: Seq[Expression],
inputSchema: Seq[Attribute],
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
codegenOrInterpreted[UnsafeProjection](
() => codegenCreator.create(exprs, inputSchema, subexpressionEliminationEnabled),
() => interpretedCreator.create(exprs, inputSchema))
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
/**
* Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
*/
override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
override protected[sql] def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
// We need to make sure that we do not reuse stateful expressions.
val cleanedExpressions = exprs.map(_.transform {
case s: Stateful => s.freshCopy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,18 @@ abstract class UnsafeProjection extends Projection {
}

trait UnsafeProjectionCreator {
protected def toBoundExprs(
exprs: Seq[Expression],
inputSchema: Seq[Attribute]): Seq[Expression] = {
exprs.map(BindReferences.bindReference(_, inputSchema))
}

protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = {
exprs.map(_ transform {
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
}

/**
* Returns an UnsafeProjection for given StructType.
*
Expand All @@ -129,10 +141,7 @@ trait UnsafeProjectionCreator {
* Returns an UnsafeProjection for given sequence of bound Expressions.
*/
def create(exprs: Seq[Expression]): UnsafeProjection = {
val unsafeExprs = exprs.map(_ transform {
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
createProjection(unsafeExprs)
createProjection(toUnsafeExprs(exprs))
}

def create(expr: Expression): UnsafeProjection = create(Seq(expr))
Expand All @@ -142,18 +151,18 @@ trait UnsafeProjectionCreator {
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
create(exprs.map(BindReferences.bindReference(_, inputSchema)))
create(toBoundExprs(exprs, inputSchema))
}

/**
* Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
*/
protected def createProjection(exprs: Seq[Expression]): UnsafeProjection
protected[sql] def createProjection(exprs: Seq[Expression]): UnsafeProjection
}

object UnsafeProjection extends UnsafeProjectionCreator {

override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
override protected[sql] def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
GenerateUnsafeProjection.generate(exprs)
}

Expand All @@ -165,11 +174,8 @@ object UnsafeProjection extends UnsafeProjectionCreator {
exprs: Seq[Expression],
inputSchema: Seq[Attribute],
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
val e = exprs.map(BindReferences.bindReference(_, inputSchema))
.map(_ transform {
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled)
val unsafeExprs = toUnsafeExprs(toBoundExprs(exprs, inputSchema))
GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled)
}
}

Expand Down