-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-35261][SQL] Support static magic method for stateless Java ScalarFunction #32407
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
20e8f1c
108123b
d2ff91d
994d3b1
ae18622
1f322a4
2fcb3f3
2b7afc7
5096fac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |
|
|
||
| package org.apache.spark.sql.catalyst.analysis | ||
|
|
||
| import java.lang.reflect.Method | ||
| import java.lang.reflect.{Method, Modifier} | ||
| import java.util | ||
| import java.util.Locale | ||
| import java.util.concurrent.atomic.AtomicBoolean | ||
|
|
@@ -48,7 +48,6 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ | |
| import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} | ||
| import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction} | ||
| import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME | ||
| import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.STATIC_MAGIC_METHOD_NAME | ||
| import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} | ||
| import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} | ||
| import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation | ||
|
|
@@ -2181,30 +2180,27 @@ class Analyzer(override val catalogManager: CatalogManager) | |
| // may also want to check if the parameter types from the magic method | ||
| // match the input type through `BoundFunction.inputTypes`. | ||
| val argClasses = inputType.fields.map(_.dataType) | ||
| findMethod(scalarFunc, STATIC_MAGIC_METHOD_NAME, argClasses) match { | ||
| case Some(_) => | ||
| findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match { | ||
| case Some(m) if Modifier.isStatic(m.getModifiers) => | ||
| StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(), | ||
|
||
| STATIC_MAGIC_METHOD_NAME, arguments, returnNullable = scalarFunc.isResultNullable) | ||
| MAGIC_METHOD_NAME, arguments, returnNullable = scalarFunc.isResultNullable) | ||
| case Some(_) => | ||
| val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) | ||
| Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(), | ||
| arguments, returnNullable = scalarFunc.isResultNullable) | ||
| case _ => | ||
| findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match { | ||
| // TODO: handle functions defined in Scala too - in Scala, even if a | ||
| // subclass do not override the default method in parent interface | ||
| // defined in Java, the method can still be found from | ||
| // `getDeclaredMethod`. | ||
| // since `inputType` is a `StructType`, it is mapped to a `InternalRow` | ||
| // which we can use to lookup the `produceResult` method. | ||
| findMethod(scalarFunc, "produceResult", Seq(inputType)) match { | ||
| case Some(_) => | ||
| val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) | ||
| Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(), | ||
| arguments, returnNullable = scalarFunc.isResultNullable) | ||
| case _ => | ||
| // TODO: handle functions defined in Scala too - in Scala, even if a | ||
| // subclass do not override the default method in parent interface | ||
| // defined in Java, the method can still be found from | ||
| // `getDeclaredMethod`. | ||
| // since `inputType` is a `StructType`, it is mapped to a `InternalRow` | ||
| // which we can use to lookup the `produceResult` method. | ||
| findMethod(scalarFunc, "produceResult", Seq(inputType)) match { | ||
| case Some(_) => | ||
| ApplyFunctionExpression(scalarFunc, arguments) | ||
| case None => | ||
| failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither implement" + | ||
| s" magic method nor override 'produceResult'") | ||
| } | ||
| ApplyFunctionExpression(scalarFunc, arguments) | ||
| case None => | ||
| failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither implement" + | ||
| s" magic method nor override 'produceResult'") | ||
| } | ||
| } | ||
| } | ||
|
|
||
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package test.org.apache.spark.sql.connector.catalog.functions; | ||
|
|
||
| import org.apache.spark.sql.catalyst.InternalRow; | ||
| import org.apache.spark.sql.connector.catalog.functions.BoundFunction; | ||
| import org.apache.spark.sql.connector.catalog.functions.ScalarFunction; | ||
| import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; | ||
| import org.apache.spark.sql.types.DataType; | ||
| import org.apache.spark.sql.types.DataTypes; | ||
| import org.apache.spark.sql.types.LongType; | ||
| import org.apache.spark.sql.types.StructField; | ||
| import org.apache.spark.sql.types.StructType; | ||
|
|
||
| public class JavaLongAdd implements UnboundFunction { | ||
| private final ScalarFunction<Long> impl; | ||
|
|
||
| public JavaLongAdd(ScalarFunction<Long> impl) { | ||
| this.impl = impl; | ||
| } | ||
|
|
||
| @Override | ||
| public String name() { | ||
| return "long_add"; | ||
| } | ||
|
|
||
| @Override | ||
| public BoundFunction bind(StructType inputType) { | ||
| if (inputType.fields().length != 2) { | ||
| throw new UnsupportedOperationException("Expect two arguments"); | ||
| } | ||
| StructField[] fields = inputType.fields(); | ||
| if (!(fields[0].dataType() instanceof LongType)) { | ||
| throw new UnsupportedOperationException("Expect first argument to be LongType"); | ||
| } | ||
| if (!(fields[1].dataType() instanceof LongType)) { | ||
| throw new UnsupportedOperationException("Expect second argument to be LongType"); | ||
| } | ||
| return impl; | ||
| } | ||
|
|
||
| @Override | ||
| public String description() { | ||
| return "long_add"; | ||
| } | ||
|
|
||
| public static abstract class JavaLongAddBase implements ScalarFunction<Long> { | ||
| private final boolean isResultNullable; | ||
|
|
||
| public JavaLongAddBase(boolean isResultNullable) { | ||
| this.isResultNullable = isResultNullable; | ||
| } | ||
|
|
||
| @Override | ||
| public DataType[] inputTypes() { | ||
| return new DataType[] { DataTypes.LongType, DataTypes.LongType }; | ||
| } | ||
|
|
||
| @Override | ||
| public DataType resultType() { | ||
| return DataTypes.LongType; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean isResultNullable() { | ||
| return isResultNullable; | ||
| } | ||
| } | ||
|
|
||
| public static class JavaLongAddDefault extends JavaLongAddBase { | ||
| public JavaLongAddDefault(boolean isResultNullable) { | ||
| super(isResultNullable); | ||
| } | ||
dongjoon-hyun marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| @Override | ||
| public String name() { | ||
| return "long_add_default"; | ||
| } | ||
|
|
||
| @Override | ||
| public Long produceResult(InternalRow input) { | ||
| return input.getLong(0) + input.getLong(1); | ||
| } | ||
| } | ||
|
|
||
| public static class JavaLongAddMagic extends JavaLongAddBase { | ||
| public JavaLongAddMagic(boolean isResultNullable) { | ||
| super(isResultNullable); | ||
| } | ||
|
|
||
| @Override | ||
| public String name() { | ||
| return "add_long_magic"; | ||
sunchao marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| public long invoke(long left, long right) { | ||
| return left + right; | ||
| } | ||
| } | ||
|
|
||
| public static class JavaLongAddStaticMagic extends JavaLongAddBase { | ||
| public JavaLongAddStaticMagic(boolean isResultNullable) { | ||
| super(isResultNullable); | ||
| } | ||
|
|
||
| @Override | ||
| public String name() { | ||
| return "long_add_static_magic"; | ||
| } | ||
|
|
||
| public static long invoke(long left, long right) { | ||
| return left + right; | ||
| } | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.