diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 216a8c4cc4326..14b4cb8cbdaab 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2894,6 +2894,8 @@ setMethod("from_json", signature(x = "Column", schema = "characterOrstructTypeOr # treated as struct or element type of array in order to make it more # R-friendly. if (class(schema) == "Column") { + df <- createDataFrame(list(list(0))) + jschema <- collect(select(df, schema))[[1]][[1]] jschema <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createArrayType", jschema) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index b94a33007b1f1..6f79de1c91555 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -7102,7 +7102,8 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def sequence(start: Column, stop: Column): Column = sequence(start, stop, lit(1L)) + def sequence(start: Column, stop: Column): Column = + Column.fn("sequence", start, stop) /** * Creates an array containing the left argument repeated the number of times given by the right diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala index 1240fbc9ada88..ab470e9aaaa2c 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala @@ -218,7 +218,6 @@ class FunctionTestSuite extends ConnectFunSuite { to_json(a, Collections.emptyMap[String, String]), to_json(a, Map.empty[String, String])) testEquals("sort_array", sort_array(a), sort_array(a, asc = true)) - testEquals("sequence", sequence(lit(1), 
lit(10)), sequence(lit(1), lit(10), lit(1L))) testEquals( "from_csv", from_csv(a, lit(schema.toDDL), Collections.emptyMap[String, String]), diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_arrays_zip.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_arrays_zip.explain index 0dc3f43b074dc..36f6e81a424a0 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_arrays_zip.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_arrays_zip.explain @@ -1,2 +1,2 @@ -Project [arrays_zip(e#0, sequence(cast(1 as bigint), cast(20 as bigint), Some(cast(1 as bigint)), Some(America/Los_Angeles)), e, 1) AS arrays_zip(e, sequence(1, 20, 1))#0] +Project [arrays_zip(e#0, sequence(1, 20, None, Some(America/Los_Angeles)), e, 1) AS arrays_zip(e, sequence(1, 20))#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_concat.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_concat.explain index 4d765e5a9c3e6..55faccce1c77c 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_concat.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_concat.explain @@ -1,2 +1,2 @@ -Project [concat(cast(e#0 as array), cast(array(1, 2) as array), sequence(cast(33 as bigint), cast(40 as bigint), Some(cast(1 as bigint)), Some(America/Los_Angeles))) AS concat(e, array(1, 2), sequence(33, 40, 1))#0] +Project [concat(e#0, array(1, 2), sequence(33, 40, None, Some(America/Los_Angeles))) AS concat(e, array(1, 2), sequence(33, 40))#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_flatten.explain 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_flatten.explain index ebdb5617a86a4..8bef8d240dc8f 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_flatten.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_flatten.explain @@ -1,2 +1,2 @@ -Project [flatten(array(cast(e#0 as array), sequence(cast(1 as bigint), cast(10 as bigint), Some(cast(1 as bigint)), Some(America/Los_Angeles)))) AS flatten(array(e, sequence(1, 10, 1)))#0] +Project [flatten(array(e#0, sequence(1, 10, None, Some(America/Los_Angeles)))) AS flatten(array(e, sequence(1, 10)))#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_sequence.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_sequence.explain index 2a71190c269c7..ad6db4c38dda0 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_sequence.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_sequence.explain @@ -1,2 +1,2 @@ -Project [sequence(cast(1 as bigint), cast(10 as bigint), Some(cast(1 as bigint)), Some(America/Los_Angeles)) AS sequence(1, 10, 1)#0] +Project [sequence(1, 10, None, Some(America/Los_Angeles)) AS sequence(1, 10)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_arrays_zip.json b/connector/connect/common/src/test/resources/query-tests/queries/function_arrays_zip.json index 14769082725f1..f24ee44835eb4 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_arrays_zip.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_arrays_zip.json @@ -29,10 +29,6 @@ "literal": { "integer": 20 } - }, { - "literal": { - "long": "1" - } }] } 
}] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_arrays_zip.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_arrays_zip.proto.bin index 609f52db32478..67c867e6d450c 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_arrays_zip.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_arrays_zip.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_concat.json b/connector/connect/common/src/test/resources/query-tests/queries/function_concat.json index 4a053d9c3c354..bad1ad6f3b90e 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_concat.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_concat.json @@ -42,10 +42,6 @@ "literal": { "integer": 40 } - }, { - "literal": { - "long": "1" - } }] } }] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_concat.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_concat.proto.bin index e53eb7a75b8a2..7411f55f14747 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_concat.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_concat.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_flatten.json b/connector/connect/common/src/test/resources/query-tests/queries/function_flatten.json index 32da97271d2dd..1f04630fd5f31 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_flatten.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_flatten.json @@ -32,10 +32,6 @@ "literal": { "integer": 10 } - }, { - "literal": { - "long": "1" - } }] } }] diff --git 
a/connector/connect/common/src/test/resources/query-tests/queries/function_flatten.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_flatten.proto.bin index e6bb018a37005..9a684850f9cfa 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_flatten.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_flatten.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_sequence.json b/connector/connect/common/src/test/resources/query-tests/queries/function_sequence.json index 84bced640ff37..b8bd1b68c9a8f 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_sequence.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_sequence.json @@ -22,10 +22,6 @@ "literal": { "integer": 10 } - }, { - "literal": { - "long": "1" - } }] } }] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_sequence.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_sequence.proto.bin index 09e1ab3be7dab..36f1980f4ec2b 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_sequence.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_sequence.proto.bin differ diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index e0de99e7a6dd1..d91cfdf529513 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -712,11 +712,11 @@ def __getitem__(self, k: Any) -> "Column": -------- >>> df = spark.createDataFrame([('abcedfg', {"key": "value"})], ["l", "d"]) >>> df.select(df.l[slice(1, 3)], df.d['key']).show() - +------------------+------+ - |substring(l, 1, 3)|d[key]| - +------------------+------+ - | abc| value| - +------------------+------+ + +---------------+------+ + |substr(l, 1, 3)|d[key]| + 
+---------------+------+ + | abc| value| + +---------------+------+ """ if isinstance(k, slice): if k.step is not None: diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index 0529293816338..19ec93151f00e 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -258,7 +258,7 @@ def substr(self, startPos: Union[int, "Column"], length: Union[int, "Column"]) - error_class="NOT_COLUMN_OR_INT", message_parameters={"arg_name": "length", "arg_type": type(length).__name__}, ) - return Column(UnresolvedFunction("substring", [self._expr, start_expr, length_expr])) + return Column(UnresolvedFunction("substr", [self._expr, start_expr, length_expr])) substr.__doc__ = PySparkColumn.substr.__doc__ diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index 24b552a45e642..4749c642975bd 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -1358,7 +1358,7 @@ def var_samp(col: "ColumnOrName") -> Column: def variance(col: "ColumnOrName") -> Column: - return var_samp(col) + return _invoke_function_over_columns("variance", col) variance.__doc__ = pysparkfuncs.variance.__doc__ @@ -1944,7 +1944,7 @@ def map_concat( def map_contains_key(col: "ColumnOrName", value: Any) -> Column: - return array_contains(map_keys(col), lit(value)) + return _invoke_function("map_contains_key", _to_col(col), lit(value)) map_contains_key.__doc__ = pysparkfuncs.map_contains_key.__doc__ diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5474873df7b21..6297ce6869965 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2960,7 +2960,7 @@ def variance(col: "ColumnOrName") -> Column: >>> df = spark.range(6) >>> df.select(variance(df.id)).show() +------------+ - |var_samp(id)| + |variance(id)| +------------+ | 3.5| +------------+ @@ -13779,17 +13779,17 @@ def map_contains_key(col: 
"ColumnOrName", value: Any) -> Column: >>> from pyspark.sql.functions import map_contains_key >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data") >>> df.select(map_contains_key("data", 1)).show() - +---------------------------------+ - |array_contains(map_keys(data), 1)| - +---------------------------------+ - | true| - +---------------------------------+ + +-------------------------+ + |map_contains_key(data, 1)| + +-------------------------+ + | true| + +-------------------------+ >>> df.select(map_contains_key("data", -1)).show() - +----------------------------------+ - |array_contains(map_keys(data), -1)| - +----------------------------------+ - | false| - +----------------------------------+ + +--------------------------+ + |map_contains_key(data, -1)| + +--------------------------+ + | false| + +--------------------------+ """ return _invoke_function("map_contains_key", _to_java_column(col), value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 2051219131219..64a31ed44b2b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -374,6 +374,7 @@ object CreateStruct { case (u @ UnresolvedExtractValue(_, e: Literal), _) if e.dataType == StringType => Seq(e, u) case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e) case (e: NamedExpression, _) => Seq(NamePlaceholder, e) + case (g @ GetStructField(_, _, Some(name)), _) => Seq(Literal(name), g) case (e, index) => Seq(Literal(s"col${index + 1}"), e) }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 3298daa81f29b..643cbc3cbdb9a 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -319,12 +319,12 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "type" -> v.getClass.toString)) } - def pivotColumnUnsupportedError(v: Any, dataType: DataType): RuntimeException = { + def pivotColumnUnsupportedError(v: Any, expr: Expression): RuntimeException = { new SparkRuntimeException( errorClass = "UNSUPPORTED_FEATURE.PIVOT_TYPE", messageParameters = Map( "value" -> v.toString, - "type" -> toSQLType(dataType))) + "type" -> (if (expr.resolved) toSQLType(expr.dataType) else "unknown"))) } def noDefaultForDataTypeError(dataType: DataType): SparkException = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 39e4815fc57c5..bb326119ab49c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -57,6 +57,22 @@ private[sql] object Column { .build() a.withMetadata(metadataWithoutId) } + + private[sql] def fn(name: String, inputs: Column*): Column = { + fn(name, isDistinct = false, ignoreNulls = false, inputs: _*) + } + + private[sql] def fn(name: String, isDistinct: Boolean, inputs: Column*): Column = { + fn(name, isDistinct = isDistinct, ignoreNulls = false, inputs: _*) + } + + private[sql] def fn( + name: String, + isDistinct: Boolean, + ignoreNulls: Boolean, + inputs: Column*): Column = Column { + UnresolvedFunction(Seq(name), inputs.map(_.expr), isDistinct, ignoreNulls = ignoreNulls) + } } /** @@ -140,6 +156,16 @@ class Column(val expr: Expression) extends Logging { case _ => UnresolvedAttribute.quotedString(name) }) + private def fn(name: String): Column = { + Column.fn(name, this) + } + private def fn(name: String, other: Column): Column = { + Column.fn(name, this, other) + } + private 
def fn(name: String, other: Any): Column = { + Column.fn(name, this, lit(other)) + } + override def toString: String = toPrettySQL(expr) override def equals(that: Any): Boolean = that match { @@ -218,7 +244,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def unary_- : Column = withExpr { UnaryMinus(expr) } + def unary_- : Column = fn("negative") /** * Inversion of boolean expression, i.e. NOT. @@ -234,7 +260,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def unary_! : Column = withExpr { Not(expr) } + def unary_! : Column = fn("!") /** * Equality test. @@ -250,14 +276,14 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def === (other: Any): Column = withExpr { + def ===(other: Any): Column = { val right = lit(other).expr if (this.expr == right) { logWarning( s"Constructing trivially true equals predicate, '${this.expr} = $right'. " + - "Perhaps you need to use aliases.") + "Perhaps you need to use aliases.") } - EqualTo(expr, right) + fn("=", other) } /** @@ -291,7 +317,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.0.0 */ - def =!= (other: Any): Column = withExpr{ Not(EqualTo(expr, lit(other).expr)) } + def =!= (other: Any): Column = !(this === other) /** * Inequality test. @@ -326,7 +352,7 @@ class Column(val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def notEqual(other: Any): Column = withExpr { Not(EqualTo(expr, lit(other).expr)) } + def notEqual(other: Any): Column = this =!= other /** * Greater than. @@ -342,7 +368,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def > (other: Any): Column = withExpr { GreaterThan(expr, lit(other).expr) } + def >(other: Any): Column = fn(">", other) /** * Greater than. 
@@ -373,7 +399,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def < (other: Any): Column = withExpr { LessThan(expr, lit(other).expr) } + def <(other: Any): Column = fn("<", other) /** * Less than. @@ -403,7 +429,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def <= (other: Any): Column = withExpr { LessThanOrEqual(expr, lit(other).expr) } + def <=(other: Any): Column = fn("<=", other) /** * Less than or equal to. @@ -433,7 +459,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def >= (other: Any): Column = withExpr { GreaterThanOrEqual(expr, lit(other).expr) } + def >=(other: Any): Column = fn(">=", other) /** * Greater than or equal to an expression. @@ -456,14 +482,14 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def <=> (other: Any): Column = withExpr { + def <=>(other: Any): Column = { val right = lit(other).expr if (this.expr == right) { logWarning( s"Constructing trivially true equals predicate, '${this.expr} <=> $right'. 
" + "Perhaps you need to use aliases.") } - EqualNullSafe(expr, right) + fn("<=>", other) } /** @@ -495,15 +521,17 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def when(condition: Column, value: Any): Column = this.expr match { - case CaseWhen(branches, None) => - withExpr { CaseWhen(branches :+ ((condition.expr, lit(value).expr))) } - case CaseWhen(branches, Some(_)) => - throw new IllegalArgumentException( - "when() cannot be applied once otherwise() is applied") - case _ => - throw new IllegalArgumentException( - "when() can only be applied on a Column previously generated by when() function") + def when(condition: Column, value: Any): Column = withExpr { + this.expr match { + case CaseWhen(branches, None) => + CaseWhen(branches :+ ((condition.expr, lit(value).expr))) + case CaseWhen(_, Some(_)) => + throw new IllegalArgumentException( + "when() cannot be applied once otherwise() is applied") + case _ => + throw new IllegalArgumentException( + "when() can only be applied on a Column previously generated by when() function") + } } /** @@ -527,15 +555,17 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def otherwise(value: Any): Column = this.expr match { - case CaseWhen(branches, None) => - withExpr { CaseWhen(branches, Option(lit(value).expr)) } - case CaseWhen(branches, Some(_)) => - throw new IllegalArgumentException( - "otherwise() can only be applied once on a Column previously generated by when()") - case _ => - throw new IllegalArgumentException( - "otherwise() can only be applied on a Column previously generated by when()") + def otherwise(value: Any): Column = withExpr { + this.expr match { + case CaseWhen(branches, None) => + CaseWhen(branches, Option(lit(value).expr)) + case CaseWhen(_, Some(_)) => + throw new IllegalArgumentException( + "otherwise() can only be applied once on a Column previously generated by when()") + case _ => + throw new 
IllegalArgumentException( + "otherwise() can only be applied on a Column previously generated by when()") + } } /** @@ -554,7 +584,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.5.0 */ - def isNaN: Column = withExpr { IsNaN(expr) } + def isNaN: Column = fn("isNaN") /** * True if the current expression is null. @@ -562,7 +592,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def isNull: Column = withExpr { IsNull(expr) } + def isNull: Column = fn("isNull") /** * True if the current expression is NOT null. @@ -570,7 +600,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def isNotNull: Column = withExpr { IsNotNull(expr) } + def isNotNull: Column = fn("isNotNull") /** * Boolean OR. @@ -585,7 +615,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def || (other: Any): Column = withExpr { Or(expr, lit(other).expr) } + def ||(other: Any): Column = fn("or", other) /** * Boolean OR. @@ -615,7 +645,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def && (other: Any): Column = withExpr { And(expr, lit(other).expr) } + def &&(other: Any): Column = fn("and", other) /** * Boolean AND. @@ -645,7 +675,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def + (other: Any): Column = withExpr { Add(expr, lit(other).expr) } + def +(other: Any): Column = fn("+", other) /** * Sum of this expression and another expression. @@ -675,7 +705,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def - (other: Any): Column = withExpr { Subtract(expr, lit(other).expr) } + def -(other: Any): Column = fn("-", other) /** * Subtraction. Subtract the other expression from this expression. 
@@ -705,7 +735,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def * (other: Any): Column = withExpr { Multiply(expr, lit(other).expr) } + def *(other: Any): Column = fn("*", other) /** * Multiplication of this expression and another expression. @@ -735,7 +765,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def / (other: Any): Column = withExpr { Divide(expr, lit(other).expr) } + def /(other: Any): Column = fn("/", other) /** * Division this expression by another expression. @@ -758,7 +788,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def % (other: Any): Column = withExpr { Remainder(expr, lit(other).expr) } + def %(other: Any): Column = fn("%", other) /** * Modulo (a.k.a. remainder) expression. @@ -826,7 +856,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def like(literal: String): Column = withExpr { new Like(expr, lit(literal).expr) } + def like(literal: String): Column = fn("like", literal) /** * SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex @@ -835,7 +865,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def rlike(literal: String): Column = withExpr { RLike(expr, lit(literal).expr) } + def rlike(literal: String): Column = fn("rlike", literal) /** * SQL ILIKE expression (case insensitive LIKE). 
@@ -843,7 +873,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 3.3.0 */ - def ilike(literal: String): Column = withExpr { new ILike(expr, lit(literal).expr) } + def ilike(literal: String): Column = fn("ilike", literal) /** * An expression that gets an item at position `ordinal` out of an array, @@ -852,7 +882,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def getItem(key: Any): Column = withExpr { UnresolvedExtractValue(expr, Literal(key)) } + def getItem(key: Any): Column = apply(key) // scalastyle:off line.size.limit /** @@ -983,9 +1013,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def getField(fieldName: String): Column = withExpr { - UnresolvedExtractValue(expr, Literal(fieldName)) - } + def getField(fieldName: String): Column = apply(fieldName) /** * An expression that returns a substring. @@ -995,9 +1023,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def substr(startPos: Column, len: Column): Column = withExpr { - Substring(expr, startPos.expr, len.expr) - } + def substr(startPos: Column, len: Column): Column = Column.fn("substr", this, startPos, len) /** * An expression that returns a substring. @@ -1007,9 +1033,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def substr(startPos: Int, len: Int): Column = withExpr { - Substring(expr, lit(startPos).expr, lit(len).expr) - } + def substr(startPos: Int, len: Int): Column = substr(lit(startPos), lit(len)) /** * Contains the other element. Returns a boolean column based on a string match. @@ -1017,7 +1041,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def contains(other: Any): Column = withExpr { Contains(expr, lit(other).expr) } + def contains(other: Any): Column = fn("contains", other) /** * String starts with. 
Returns a boolean column based on a string match. @@ -1025,7 +1049,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def startsWith(other: Column): Column = withExpr { StartsWith(expr, lit(other).expr) } + def startsWith(other: Column): Column = fn("startswith", other) /** * String starts with another string literal. Returns a boolean column based on a string match. @@ -1033,7 +1057,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def startsWith(literal: String): Column = this.startsWith(lit(literal)) + def startsWith(literal: String): Column = startsWith(lit(literal)) /** * String ends with. Returns a boolean column based on a string match. @@ -1041,7 +1065,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def endsWith(other: Column): Column = withExpr { EndsWith(expr, lit(other).expr) } + def endsWith(other: Column): Column = fn("endswith", other) /** * String ends with another string literal. Returns a boolean column based on a string match. @@ -1049,7 +1073,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def endsWith(literal: String): Column = this.endsWith(lit(literal)) + def endsWith(literal: String): Column = endsWith(lit(literal)) /** * Gives the column an alias. Same as `as`. @@ -1308,7 +1332,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def bitwiseOR(other: Any): Column = withExpr { BitwiseOr(expr, lit(other).expr) } + def bitwiseOR(other: Any): Column = fn("|", other) /** * Compute bitwise AND of this expression with another expression. 
@@ -1319,7 +1343,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def bitwiseAND(other: Any): Column = withExpr { BitwiseAnd(expr, lit(other).expr) } + def bitwiseAND(other: Any): Column = fn("&", other) /** * Compute bitwise XOR of this expression with another expression. @@ -1330,7 +1354,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def bitwiseXOR(other: Any): Column = withExpr { BitwiseXor(expr, lit(other).expr) } + def bitwiseXOR(other: Any): Column = fn("^", other) /** * Defines a windowing column. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 11327cdf7d1d3..f67f2b2cfde20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -35,7 +35,6 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{NumericType, StructType} @@ -464,7 +463,7 @@ class RelationalGroupedDataset protected[sql]( Literal.apply(v) } catch { case _: SparkRuntimeException => - throw QueryExecutionErrors.pivotColumnUnsupportedError(v, pivotColumn.expr.dataType) + throw QueryExecutionErrors.pivotColumnUnsupportedError(v, pivotColumn.expr) } }) new RelationalGroupedDataset( @@ -706,9 +705,8 @@ private[sql] object RelationalGroupedDataset { private def alias(expr: Expression): NamedExpression = expr match { case expr: NamedExpression => expr - case a: AggregateExpression if 
a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => - UnresolvedAlias(a, Some(Column.generateAlias)) - case u: UnresolvedFunction => UnresolvedAlias(expr, None) + case a: AggregateExpression => UnresolvedAlias(a, Some(Column.generateAlias)) + case _ if !expr.resolved => UnresolvedAlias(expr, None) case expr: Expression => Alias(expr, toPrettySQL(expr))() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 350622eb3958f..a76465ff42add 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -30,7 +30,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericRowWithSchema} +import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericRowWithSchema, Literal} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION @@ -220,8 +220,8 @@ private[sql] object SQLUtils extends Logging { sparkSession.catalog.listTables(db).collect().map(_.name) } - def createArrayType(column: Column): ArrayType = { - new ArrayType(ExprUtils.evalTypeExpr(column.expr), true) + def createArrayType(elementType: String): ArrayType = { + ArrayType(ExprUtils.evalTypeExpr(Literal(elementType)), true) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a4b8f1b1b6849..1b832fc437cda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.util.Collections + 
import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.Try @@ -30,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.xml._ import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint} -import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.errors.{DataTypeErrors, QueryCompilationErrors} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} @@ -272,9 +274,7 @@ object functions { * @group agg_funcs * @since 2.1.0 */ - def approx_count_distinct(e: Column): Column = withAggregateFunction { - HyperLogLogPlusPlus(e.expr) - } + def approx_count_distinct(e: Column): Column = Column.fn("approx_count_distinct", e) /** * Aggregate function: returns the approximate number of distinct items in a group. @@ -292,8 +292,8 @@ object functions { * @group agg_funcs * @since 2.1.0 */ - def approx_count_distinct(e: Column, rsd: Double): Column = withAggregateFunction { - HyperLogLogPlusPlus(e.expr, rsd, 0, 0) + def approx_count_distinct(e: Column, rsd: Double): Column = { + Column.fn("approx_count_distinct", e, lit(rsd)) } /** @@ -314,7 +314,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def avg(e: Column): Column = withAggregateFunction { Average(e.expr) } + def avg(e: Column): Column = Column.fn("avg", e) /** * Aggregate function: returns the average of the values in a group. 
@@ -333,7 +333,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def collect_list(e: Column): Column = withAggregateFunction { CollectList(e.expr) } + def collect_list(e: Column): Column = Column.fn("collect_list", e) /** * Aggregate function: returns a list of objects with duplicates. @@ -355,7 +355,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def collect_set(e: Column): Column = withAggregateFunction { CollectSet(e.expr) } + def collect_set(e: Column): Column = Column.fn("collect_set", e) /** * Aggregate function: returns a set of objects with duplicate elements eliminated. @@ -377,13 +377,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def count_min_sketch( - e: Column, - eps: Column, - confidence: Column, - seed: Column): Column = withAggregateFunction { - new CountMinSketchAgg(e.expr, eps.expr, confidence.expr, seed.expr) - } + def count_min_sketch(e: Column, eps: Column, confidence: Column, seed: Column): Column = + Column.fn("count_min_sketch", e, eps, confidence, seed) private[spark] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column = withAggregateFunction { CollectTopK(e.expr, num, reverse) } @@ -394,9 +389,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def corr(column1: Column, column2: Column): Column = withAggregateFunction { - Corr(column1.expr, column2.expr) - } + def corr(column1: Column, column2: Column): Column = Column.fn("corr", column1, column2) /** * Aggregate function: returns the Pearson Correlation Coefficient for two columns. 
@@ -414,12 +407,13 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def count(e: Column): Column = withAggregateFunction { - e.expr match { + def count(e: Column): Column = { + val withoutStar = e.expr match { // Turn count(*) into count(1) - case s: Star => Count(Literal(1)) - case _ => Count(e.expr) + case _: Star => Column(Literal(1)) + case _ => e } + Column.fn("count", withoutStar) } /** @@ -462,9 +456,7 @@ object functions { */ @scala.annotation.varargs def count_distinct(expr: Column, exprs: Column*): Column = - // For usage like countDistinct("*"), we should let analyzer expand star and - // resolve function. - Column(UnresolvedFunction("count", (expr +: exprs).map(_.expr), isDistinct = true)) + Column.fn("count", isDistinct = true, expr +: exprs: _*) /** * Aggregate function: returns the population covariance for two columns. @@ -472,9 +464,8 @@ object functions { * @group agg_funcs * @since 2.0.0 */ - def covar_pop(column1: Column, column2: Column): Column = withAggregateFunction { - CovPopulation(column1.expr, column2.expr) - } + def covar_pop(column1: Column, column2: Column): Column = + Column.fn("covar_pop", column1, column2) /** * Aggregate function: returns the population covariance for two columns. @@ -492,9 +483,8 @@ object functions { * @group agg_funcs * @since 2.0.0 */ - def covar_samp(column1: Column, column2: Column): Column = withAggregateFunction { - CovSample(column1.expr, column2.expr) - } + def covar_samp(column1: Column, column2: Column): Column = + Column.fn("covar_samp", column1, column2) /** * Aggregate function: returns the sample covariance for two columns. @@ -518,9 +508,8 @@ object functions { * @group agg_funcs * @since 2.0.0 */ - def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { - First(e.expr, ignoreNulls) - } + def first(e: Column, ignoreNulls: Boolean): Column = + Column.fn("first", false, ignoreNulls, e) /** * Aggregate function: returns the first value of a column in a group. 
@@ -575,7 +564,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def first_value(e: Column): Column = call_function("first_value", e) + def first_value(e: Column): Column = Column.fn("first_value", e) /** * Aggregate function: returns the first value in a group. @@ -590,7 +579,7 @@ object functions { * @since 3.5.0 */ def first_value(e: Column, ignoreNulls: Column): Column = - call_function("first_value", e, ignoreNulls) + Column.fn("first_value", e, ignoreNulls) /** * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated @@ -599,7 +588,7 @@ object functions { * @group agg_funcs * @since 2.0.0 */ - def grouping(e: Column): Column = Column(Grouping(e.expr)) + def grouping(e: Column): Column = Column.fn("grouping", e) /** * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated @@ -623,7 +612,7 @@ object functions { * @group agg_funcs * @since 2.0.0 */ - def grouping_id(cols: Column*): Column = Column(GroupingID(cols.map(_.expr))) + def grouping_id(cols: Column*): Column = Column.fn("grouping_id", cols: _*) /** * Aggregate function: returns the level of grouping, equals to @@ -648,9 +637,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def hll_sketch_agg(e: Column, lgConfigK: Column): Column = withAggregateFunction { - HllSketchAgg(e.expr, lgConfigK.expr) - } + def hll_sketch_agg(e: Column, lgConfigK: Column): Column = + Column.fn("hll_sketch_agg", e, lgConfigK) /** * Aggregate function: returns the updatable binary representation of the Datasketches @@ -660,9 +648,7 @@ object functions { * @since 3.5.0 */ def hll_sketch_agg(e: Column, lgConfigK: Int): Column = - withAggregateFunction { - new HllSketchAgg(e.expr, Literal(lgConfigK)) - } + Column.fn("hll_sketch_agg", e, lit(lgConfigK)) /** * Aggregate function: returns the updatable binary representation of the Datasketches @@ -682,9 +668,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def hll_sketch_agg(e: 
Column): Column = withAggregateFunction { - new HllSketchAgg(e.expr) - } + def hll_sketch_agg(e: Column): Column = + Column.fn("hll_sketch_agg", e) /** * Aggregate function: returns the updatable binary representation of the Datasketches @@ -706,9 +691,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def hll_union_agg(e: Column, allowDifferentLgConfigK: Column): Column = withAggregateFunction { - new HllUnionAgg(e.expr, allowDifferentLgConfigK.expr) - } + def hll_union_agg(e: Column, allowDifferentLgConfigK: Column): Column = + Column.fn("hll_union_agg", e, allowDifferentLgConfigK) /** * Aggregate function: returns the updatable binary representation of the Datasketches @@ -719,9 +703,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def hll_union_agg(e: Column, allowDifferentLgConfigK: Boolean): Column = withAggregateFunction { - new HllUnionAgg(e.expr, allowDifferentLgConfigK) - } + def hll_union_agg(e: Column, allowDifferentLgConfigK: Boolean): Column = + Column.fn("hll_union_agg", e, lit(allowDifferentLgConfigK)) /** * Aggregate function: returns the updatable binary representation of the Datasketches @@ -745,9 +728,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def hll_union_agg(e: Column): Column = withAggregateFunction { - new HllUnionAgg(e.expr) - } + def hll_union_agg(e: Column): Column = + Column.fn("hll_union_agg", e) /** * Aggregate function: returns the updatable binary representation of the Datasketches @@ -768,7 +750,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def kurtosis(e: Column): Column = withAggregateFunction { Kurtosis(e.expr) } + def kurtosis(e: Column): Column = Column.fn("kurtosis", e) /** * Aggregate function: returns the kurtosis of the values in a group. 
@@ -790,9 +772,8 @@ object functions { * @group agg_funcs * @since 2.0.0 */ - def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { - Last(e.expr, ignoreNulls) - } + def last(e: Column, ignoreNulls: Boolean): Column = + Column.fn("last", false, ignoreNulls, e) /** * Aggregate function: returns the last value of the column in a group. @@ -847,7 +828,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def last_value(e: Column): Column = call_function("last_value", e) + def last_value(e: Column): Column = Column.fn("last_value", e) /** * Aggregate function: returns the last value in a group. @@ -862,7 +843,7 @@ object functions { * @since 3.5.0 */ def last_value(e: Column, ignoreNulls: Column): Column = - call_function("last_value", e, ignoreNulls) + Column.fn("last_value", e, ignoreNulls) /** * Aggregate function: returns the most frequent value in a group. @@ -870,7 +851,7 @@ object functions { * @group agg_funcs * @since 3.4.0 */ - def mode(e: Column): Column = mode(e, deterministic = false) + def mode(e: Column): Column = Column.fn("mode", e) /** * Aggregate function: returns the most frequent value in a group. @@ -882,9 +863,7 @@ object functions { * @group agg_funcs * @since 4.0.0 */ - def mode(e: Column, deterministic: Boolean): Column = withAggregateFunction { - Mode(e.expr, deterministicExpr = lit(deterministic).expr) - } + def mode(e: Column, deterministic: Boolean): Column = Column.fn("mode", e, lit(deterministic)) /** * Aggregate function: returns the maximum value of the expression in a group. @@ -892,7 +871,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def max(e: Column): Column = withAggregateFunction { Max(e.expr) } + def max(e: Column): Column = Column.fn("max", e) /** * Aggregate function: returns the maximum value of the column in a group. 
@@ -908,7 +887,7 @@ object functions { * @group agg_funcs * @since 3.3.0 */ - def max_by(e: Column, ord: Column): Column = withAggregateFunction { MaxBy(e.expr, ord.expr) } + def max_by(e: Column, ord: Column): Column = Column.fn("max_by", e, ord) /** * Aggregate function: returns the average of the values in a group. @@ -934,7 +913,7 @@ object functions { * @group agg_funcs * @since 3.4.0 */ - def median(e: Column): Column = withAggregateFunction { Median(e.expr) } + def median(e: Column): Column = Column.fn("median", e) /** * Aggregate function: returns the minimum value of the expression in a group. @@ -942,7 +921,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def min(e: Column): Column = withAggregateFunction { Min(e.expr) } + def min(e: Column): Column = Column.fn("min", e) /** * Aggregate function: returns the minimum value of the column in a group. @@ -958,7 +937,7 @@ object functions { * @group agg_funcs * @since 3.3.0 */ - def min_by(e: Column, ord: Column): Column = withAggregateFunction { MinBy(e.expr, ord.expr) } + def min_by(e: Column, ord: Column): Column = Column.fn("min_by", e, ord) /** * Aggregate function: returns the exact percentile(s) of numeric column `expr` at the @@ -967,11 +946,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def percentile(e: Column, percentage: Column): Column = { - withAggregateFunction { - new Percentile(e.expr, percentage.expr) - } - } + def percentile(e: Column, percentage: Column): Column = Column.fn("percentile", e, percentage) /** * Aggregate function: returns the exact percentile(s) of numeric column `expr` at the @@ -980,14 +955,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def percentile( - e: Column, - percentage: Column, - frequency: Column): Column = { - withAggregateFunction { - new Percentile(e.expr, percentage.expr, frequency.expr) - } - } + def percentile(e: Column, percentage: Column, frequency: Column): Column = + Column.fn("percentile", e, percentage, 
frequency) /** * Aggregate function: returns the approximate `percentile` of the numeric column `col` which @@ -1005,13 +974,8 @@ object functions { * @group agg_funcs * @since 3.1.0 */ - def percentile_approx(e: Column, percentage: Column, accuracy: Column): Column = { - withAggregateFunction { - new ApproximatePercentile( - e.expr, percentage.expr, accuracy.expr - ) - } - } + def percentile_approx(e: Column, percentage: Column, accuracy: Column): Column = + Column.fn("percentile_approx", e, percentage, accuracy) /** * Aggregate function: returns the approximate `percentile` of the numeric column `col` which @@ -1029,8 +993,9 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def approx_percentile(e: Column, percentage: Column, accuracy: Column): Column = - call_function("approx_percentile", e, percentage, accuracy) + def approx_percentile(e: Column, percentage: Column, accuracy: Column): Column = { + Column.fn("approx_percentile", e, percentage, accuracy) + } /** * Aggregate function: returns the product of all numerical elements in a group. @@ -1047,7 +1012,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def skewness(e: Column): Column = withAggregateFunction { Skewness(e.expr) } + def skewness(e: Column): Column = Column.fn("skewness", e) /** * Aggregate function: returns the skewness of the values in a group. @@ -1063,7 +1028,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def std(e: Column): Column = call_function("std", e) + def std(e: Column): Column = Column.fn("std", e) /** * Aggregate function: alias for `stddev_samp`. @@ -1071,7 +1036,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev(e: Column): Column = call_function("stddev", e) + def stddev(e: Column): Column = Column.fn("stddev", e) /** * Aggregate function: alias for `stddev_samp`. 
@@ -1088,7 +1053,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev_samp(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } + def stddev_samp(e: Column): Column = Column.fn("stddev_samp", e) /** * Aggregate function: returns the sample standard deviation of @@ -1106,7 +1071,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev_pop(e: Column): Column = withAggregateFunction { StddevPop(e.expr) } + def stddev_pop(e: Column): Column = Column.fn("stddev_pop", e) /** * Aggregate function: returns the population standard deviation of @@ -1123,7 +1088,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def sum(e: Column): Column = withAggregateFunction { Sum(e.expr) } + def sum(e: Column): Column = Column.fn("sum", e) /** * Aggregate function: returns the sum of all values in the given column. @@ -1140,7 +1105,7 @@ object functions { * @since 1.3.0 */ @deprecated("Use sum_distinct", "3.2.0") - def sumDistinct(e: Column): Column = withAggregateFunction(Sum(e.expr), isDistinct = true) + def sumDistinct(e: Column): Column = sum_distinct(e) /** * Aggregate function: returns the sum of distinct values in the expression. @@ -1157,7 +1122,7 @@ object functions { * @group agg_funcs * @since 3.2.0 */ - def sum_distinct(e: Column): Column = withAggregateFunction(Sum(e.expr), isDistinct = true) + def sum_distinct(e: Column): Column = Column.fn("sum", isDistinct = true, e) /** * Aggregate function: alias for `var_samp`. @@ -1165,7 +1130,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def variance(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } + def variance(e: Column): Column = Column.fn("variance", e) /** * Aggregate function: alias for `var_samp`. 
@@ -1181,7 +1146,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def var_samp(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } + def var_samp(e: Column): Column = Column.fn("var_samp", e) /** * Aggregate function: returns the unbiased variance of the values in a group. @@ -1197,7 +1162,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def var_pop(e: Column): Column = withAggregateFunction { VariancePop(e.expr) } + def var_pop(e: Column): Column = Column.fn("var_pop", e) /** * Aggregate function: returns the population variance of the values in a group. @@ -1214,7 +1179,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def regr_avgx(y: Column, x: Column): Column = withAggregateFunction { RegrAvgX(y.expr, x.expr) } + def regr_avgx(y: Column, x: Column): Column = Column.fn("regr_avgx", y, x) /** * Aggregate function: returns the average of the independent variable for non-null pairs @@ -1223,7 +1188,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def regr_avgy(y: Column, x: Column): Column = withAggregateFunction { RegrAvgY(y.expr, x.expr) } + def regr_avgy(y: Column, x: Column): Column = Column.fn("regr_avgy", y, x) /** * Aggregate function: returns the number of non-null number pairs @@ -1232,7 +1197,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def regr_count(y: Column, x: Column): Column = withAggregateFunction { RegrCount(y.expr, x.expr) } + def regr_count(y: Column, x: Column): Column = Column.fn("regr_count", y, x) /** * Aggregate function: returns the intercept of the univariate linear regression line @@ -1242,8 +1207,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def regr_intercept(y: Column, x: Column): Column = - withAggregateFunction { RegrIntercept(y.expr, x.expr) } + def regr_intercept(y: Column, x: Column): Column = Column.fn("regr_intercept", y, x) /** * Aggregate function: returns the coefficient of determination for non-null pairs @@ -1252,7 
+1216,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def regr_r2(y: Column, x: Column): Column = withAggregateFunction { RegrR2(y.expr, x.expr) } + def regr_r2(y: Column, x: Column): Column = Column.fn("regr_r2", y, x) /** * Aggregate function: returns the slope of the linear regression line for non-null pairs @@ -1261,8 +1225,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def regr_slope(y: Column, x: Column): Column = - withAggregateFunction { RegrSlope(y.expr, x.expr) } + def regr_slope(y: Column, x: Column): Column = Column.fn("regr_slope", y, x) /** * Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(x) for non-null pairs @@ -1271,7 +1234,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def regr_sxx(y: Column, x: Column): Column = withAggregateFunction { RegrSXX(y.expr, x.expr) } + def regr_sxx(y: Column, x: Column): Column = Column.fn("regr_sxx", y, x) /** * Aggregate function: returns REGR_COUNT(y, x) * COVAR_POP(y, x) for non-null pairs @@ -1280,7 +1243,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def regr_sxy(y: Column, x: Column): Column = withAggregateFunction { RegrSXY(y.expr, x.expr) } + def regr_sxy(y: Column, x: Column): Column = Column.fn("regr_sxy", y, x) /** * Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(y) for non-null pairs @@ -1289,7 +1252,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def regr_syy(y: Column, x: Column): Column = withAggregateFunction { RegrSYY(y.expr, x.expr) } + def regr_syy(y: Column, x: Column): Column = Column.fn("regr_syy", y, x) /** * Aggregate function: returns some value of `e` for a group of rows. @@ -1297,7 +1260,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def any_value(e: Column): Column = withAggregateFunction { new AnyValue(e.expr) } + def any_value(e: Column): Column = Column.fn("any_value", e) /** * Aggregate function: returns some value of `e` for a group of rows. 
@@ -1307,7 +1270,7 @@ object functions { * @since 3.5.0 */ def any_value(e: Column, ignoreNulls: Column): Column = - withAggregateFunction { new AnyValue(e.expr, ignoreNulls.expr) } + Column.fn("any_value", e, ignoreNulls) /** * Aggregate function: returns the number of `TRUE` values for the expression. @@ -1315,7 +1278,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def count_if(e: Column): Column = withAggregateFunction { CountIf(e.expr) } + def count_if(e: Column): Column = Column.fn("count_if", e) /** * Aggregate function: computes a histogram on numeric 'expr' using nb bins. @@ -1333,7 +1296,7 @@ object functions { * @since 3.5.0 */ def histogram_numeric(e: Column, nBins: Column): Column = - withAggregateFunction { new HistogramNumeric(e.expr, nBins.expr) } + Column.fn("histogram_numeric", e, nBins) /** * Aggregate function: returns true if all values of `e` are true. @@ -1341,7 +1304,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def every(e: Column): Column = call_function("every", e) + def every(e: Column): Column = Column.fn("every", e) /** * Aggregate function: returns true if all values of `e` are true. @@ -1349,7 +1312,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def bool_and(e: Column): Column = withAggregateFunction { BoolAnd(e.expr) } + def bool_and(e: Column): Column = Column.fn("bool_and", e) /** * Aggregate function: returns true if at least one value of `e` is true. @@ -1357,7 +1320,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def some(e: Column): Column = call_function("some", e) + def some(e: Column): Column = Column.fn("some", e) /** * Aggregate function: returns true if at least one value of `e` is true. @@ -1365,7 +1328,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def any(e: Column): Column = call_function("any", e) + def any(e: Column): Column = Column.fn("any", e) /** * Aggregate function: returns true if at least one value of `e` is true. 
@@ -1373,7 +1336,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def bool_or(e: Column): Column = withAggregateFunction { BoolOr(e.expr) } + def bool_or(e: Column): Column = Column.fn("bool_or", e) /** * Aggregate function: returns the bitwise AND of all non-null input values, or null if none. @@ -1381,7 +1344,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def bit_and(e: Column): Column = withAggregateFunction { BitAndAgg(e.expr) } + def bit_and(e: Column): Column = Column.fn("bit_and", e) /** * Aggregate function: returns the bitwise OR of all non-null input values, or null if none. @@ -1389,7 +1352,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def bit_or(e: Column): Column = withAggregateFunction { BitOrAgg(e.expr) } + def bit_or(e: Column): Column = Column.fn("bit_or", e) /** * Aggregate function: returns the bitwise XOR of all non-null input values, or null if none. @@ -1397,7 +1360,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def bit_xor(e: Column): Column = withAggregateFunction { BitXorAgg(e.expr) } + def bit_xor(e: Column): Column = Column.fn("bit_xor", e) ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions @@ -1415,7 +1378,7 @@ object functions { * @group window_funcs * @since 1.6.0 */ - def cume_dist(): Column = withExpr { new CumeDist } + def cume_dist(): Column = Column.fn("cume_dist") /** * Window function: returns the rank of rows within a window partition, without any gaps. 
@@ -1431,7 +1394,7 @@ object functions { * @group window_funcs * @since 1.6.0 */ - def dense_rank(): Column = withExpr { new DenseRank } + def dense_rank(): Column = Column.fn("dense_rank") /** * Window function: returns the value that is `offset` rows before the current row, and @@ -1497,9 +1460,8 @@ object functions { * @group window_funcs * @since 3.2.0 */ - def lag(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = withExpr { - Lag(e.expr, Literal(offset), Literal(defaultValue), ignoreNulls) - } + def lag(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = + Column.fn("lag", false, ignoreNulls, e, lit(offset), lit(defaultValue)) /** * Window function: returns the value that is `offset` rows after the current row, and @@ -1565,9 +1527,8 @@ object functions { * @group window_funcs * @since 3.2.0 */ - def lead(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = withExpr { - Lead(e.expr, Literal(offset), Literal(defaultValue), ignoreNulls) - } + def lead(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = + Column.fn("lead", false, ignoreNulls, e, lit(offset), lit(defaultValue)) /** * Window function: returns the value that is the `offset`th row of the window frame @@ -1581,9 +1542,8 @@ object functions { * @group window_funcs * @since 3.1.0 */ - def nth_value(e: Column, offset: Int, ignoreNulls: Boolean): Column = withExpr { - NthValue(e.expr, Literal(offset), ignoreNulls) - } + def nth_value(e: Column, offset: Int, ignoreNulls: Boolean): Column = + Column.fn("nth_value", false, ignoreNulls, e, lit(offset)) /** * Window function: returns the value that is the `offset`th row of the window frame @@ -1594,9 +1554,7 @@ object functions { * @group window_funcs * @since 3.1.0 */ - def nth_value(e: Column, offset: Int): Column = withExpr { - NthValue(e.expr, Literal(offset), false) - } + def nth_value(e: Column, offset: Int): Column = nth_value(e, offset, false) /** * Window 
function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window @@ -1608,7 +1566,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def ntile(n: Int): Column = withExpr { new NTile(Literal(n)) } + def ntile(n: Int): Column = Column.fn("ntile", lit(n)) /** * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. @@ -1623,7 +1581,7 @@ object functions { * @group window_funcs * @since 1.6.0 */ - def percent_rank(): Column = withExpr { new PercentRank } + def percent_rank(): Column = Column.fn("percent_rank") /** * Window function: returns the rank of rows within a window partition. @@ -1639,7 +1597,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def rank(): Column = withExpr { new Rank } + def rank(): Column = Column.fn("rank") /** * Window function: returns a sequential number starting at 1 within a window partition. @@ -1647,7 +1605,7 @@ object functions { * @group window_funcs * @since 1.6.0 */ - def row_number(): Column = withExpr { RowNumber() } + def row_number(): Column = Column.fn("row_number") ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions @@ -1660,7 +1618,7 @@ object functions { * @since 1.4.0 */ @scala.annotation.varargs - def array(cols: Column*): Column = withExpr { CreateArray(cols.map(_.expr)) } + def array(cols: Column*): Column = Column.fn("array", cols: _*) /** * Creates a new array column. The input columns must all have the same data type. @@ -1682,7 +1640,7 @@ object functions { * @since 2.0 */ @scala.annotation.varargs - def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) } + def map(cols: Column*): Column = Column.fn("map", cols: _*) /** * Creates a struct with the given field names and values. 
@@ -1690,7 +1648,7 @@ object functions { * @group normal_funcs * @since 3.5.0 */ - def named_struct(cols: Column*): Column = withExpr { CreateNamedStruct(cols.map(_.expr)) } + def named_struct(cols: Column*): Column = Column.fn("named_struct", cols: _*) /** * Creates a new map column. The array in the first column is used for keys. The array in the @@ -1699,9 +1657,8 @@ object functions { * @group normal_funcs * @since 2.4 */ - def map_from_arrays(keys: Column, values: Column): Column = withExpr { - MapFromArrays(keys.expr, values.expr) - } + def map_from_arrays(keys: Column, values: Column): Column = + Column.fn("map_from_arrays", keys, values) /** * Creates a map after splitting the text into key/value pairs using delimiters. @@ -1710,9 +1667,8 @@ object functions { * @group map_funcs * @since 3.5.0 */ - def str_to_map(text: Column, pairDelim: Column, keyValueDelim: Column): Column = withExpr { - StringToMap(text.expr, pairDelim.expr, keyValueDelim.expr) - } + def str_to_map(text: Column, pairDelim: Column, keyValueDelim: Column): Column = + Column.fn("str_to_map", text, pairDelim, keyValueDelim) /** * Creates a map after splitting the text into key/value pairs using delimiters. @@ -1721,9 +1677,8 @@ object functions { * @group map_funcs * @since 3.5.0 */ - def str_to_map(text: Column, pairDelim: Column): Column = withExpr { - new StringToMap(text.expr, pairDelim.expr) - } + def str_to_map(text: Column, pairDelim: Column): Column = + Column.fn("str_to_map", text, pairDelim) /** * Creates a map after splitting the text into key/value pairs using delimiters. @@ -1731,9 +1686,7 @@ object functions { * @group map_funcs * @since 3.5.0 */ - def str_to_map(text: Column): Column = withExpr { - new StringToMap(text.expr) - } + def str_to_map(text: Column): Column = Column.fn("str_to_map", text) /** * Marks a DataFrame as small enough for use in broadcast joins. 
@@ -1762,7 +1715,7 @@ object functions { * @since 1.3.0 */ @scala.annotation.varargs - def coalesce(e: Column*): Column = withExpr { Coalesce(e.map(_.expr)) } + def coalesce(e: Column*): Column = Column.fn("coalesce", e: _*) /** * Creates a string column for the file name of the current Spark task. @@ -1770,7 +1723,7 @@ object functions { * @group normal_funcs * @since 1.6.0 */ - def input_file_name(): Column = withExpr { InputFileName() } + def input_file_name(): Column = Column.fn("input_file_name") /** * Return true iff the column is NaN. @@ -1778,7 +1731,7 @@ object functions { * @group normal_funcs * @since 1.6.0 */ - def isnan(e: Column): Column = withExpr { IsNaN(e.expr) } + def isnan(e: Column): Column = e.isNaN /** * Return true iff the column is null. @@ -1786,7 +1739,7 @@ object functions { * @group normal_funcs * @since 1.6.0 */ - def isnull(e: Column): Column = withExpr { IsNull(e.expr) } + def isnull(e: Column): Column = e.isNull /** * A column expression that generates monotonically increasing 64-bit integers. @@ -1827,7 +1780,7 @@ object functions { * @group normal_funcs * @since 1.6.0 */ - def monotonically_increasing_id(): Column = withExpr { MonotonicallyIncreasingID() } + def monotonically_increasing_id(): Column = Column.fn("monotonically_increasing_id") /** * Returns col1 if it is not NaN, or col2 if col1 is NaN. @@ -1837,7 +1790,7 @@ object functions { * @group normal_funcs * @since 1.5.0 */ - def nanvl(col1: Column, col2: Column): Column = withExpr { NaNvl(col1.expr, col2.expr) } + def nanvl(col1: Column, col2: Column): Column = Column.fn("nanvl", col1, col2) /** * Unary minus, i.e. negate the expression. @@ -1922,7 +1875,7 @@ object functions { * @group normal_funcs * @since 1.6.0 */ - def spark_partition_id(): Column = withExpr { SparkPartitionID() } + def spark_partition_id(): Column = Column.fn("spark_partition_id") /** * Computes the square root of the specified float value. 
@@ -1930,7 +1883,7 @@ object functions { * @group math_funcs * @since 1.3.0 */ - def sqrt(e: Column): Column = withExpr { Sqrt(e.expr) } + def sqrt(e: Column): Column = Column.fn("sqrt", e) /** * Computes the square root of the specified float value. @@ -1947,7 +1900,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_add(left: Column, right: Column): Column = call_function("try_add", left, right) + def try_add(left: Column, right: Column): Column = Column.fn("try_add", left, right) /** * Returns the mean calculated from values of a group and the result is null on overflow. @@ -1955,8 +1908,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_avg(e: Column): Column = - call_function("try_avg", e) + def try_avg(e: Column): Column = Column.fn("try_avg", e) /** * Returns `dividend``/``divisor`. It always performs floating point division. Its result is @@ -1965,8 +1917,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_divide(dividend: Column, divisor: Column): Column = - call_function("try_divide", dividend, divisor) + def try_divide(left: Column, right: Column): Column = Column.fn("try_divide", left, right) /** * Returns `left``*``right` and the result is null on overflow. The acceptable input types are @@ -1975,8 +1926,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_multiply(left: Column, right: Column): Column = - call_function("try_multiply", left, right) + def try_multiply(left: Column, right: Column): Column = Column.fn("try_multiply", left, right) /** * Returns `left``-``right` and the result is null on overflow. 
The acceptable input types are @@ -1985,8 +1935,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_subtract(left: Column, right: Column): Column = - call_function("try_subtract", left, right) + def try_subtract(left: Column, right: Column): Column = Column.fn("try_subtract", left, right) /** * Returns the sum calculated from values of a group and the result is null on overflow. @@ -1994,7 +1943,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_sum(e: Column): Column = call_function("try_sum", e) + def try_sum(e: Column): Column = Column.fn("try_sum", e) /** * Creates a new struct column. @@ -2007,7 +1956,7 @@ object functions { * @since 1.4.0 */ @scala.annotation.varargs - def struct(cols: Column*): Column = withExpr { CreateStruct.create(cols.map(_.expr)) } + def struct(cols: Column*): Column = Column.fn("struct", cols: _*) /** * Creates a new struct column that composes multiple input columns. @@ -2060,7 +2009,7 @@ object functions { * @group normal_funcs * @since 3.2.0 */ - def bitwise_not(e: Column): Column = withExpr { BitwiseNot(e.expr) } + def bitwise_not(e: Column): Column = Column.fn("~", e) /** * Returns the number of bits that are set in the argument expr as an unsigned 64-bit integer, @@ -2069,7 +2018,7 @@ object functions { * @group bitwise_funcs * @since 3.5.0 */ - def bit_count(e: Column): Column = withExpr { BitwiseCount(e.expr) } + def bit_count(e: Column): Column = Column.fn("bit_count", e) /** * Returns the value of the bit (0 or 1) at the specified position. @@ -2079,7 +2028,7 @@ object functions { * @group bitwise_funcs * @since 3.5.0 */ - def bit_get(e: Column, pos: Column): Column = withExpr { BitwiseGet(e.expr, pos.expr) } + def bit_get(e: Column, pos: Column): Column = Column.fn("bit_get", e, pos) /** * Returns the value of the bit (0 or 1) at the specified position. 
@@ -2089,7 +2038,7 @@ object functions { * @group bitwise_funcs * @since 3.5.0 */ - def getbit(e: Column, pos: Column): Column = call_function("getbit", e, pos) + def getbit(e: Column, pos: Column): Column = Column.fn("getbit", e, pos) /** * Parses the expression string into the column that it represents, similar to @@ -2101,11 +2050,11 @@ object functions { * * @group normal_funcs */ - def expr(expr: String): Column = { + def expr(expr: String): Column = withExpr { val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { new SparkSqlParser() } - Column(parser.parseExpression(expr)) + parser.parseExpression(expr) } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -2118,7 +2067,7 @@ object functions { * @group math_funcs * @since 1.3.0 */ - def abs(e: Column): Column = withExpr { Abs(e.expr) } + def abs(e: Column): Column = Column.fn("abs", e) /** * @return inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos` @@ -2126,7 +2075,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def acos(e: Column): Column = withExpr { Acos(e.expr) } + def acos(e: Column): Column = Column.fn("acos", e) /** * @return inverse cosine of `columnName`, as if computed by `java.lang.Math.acos` @@ -2142,7 +2091,7 @@ object functions { * @group math_funcs * @since 3.1.0 */ - def acosh(e: Column): Column = withExpr { Acosh(e.expr) } + def acosh(e: Column): Column = Column.fn("acosh", e) /** * @return inverse hyperbolic cosine of `columnName` @@ -2158,7 +2107,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def asin(e: Column): Column = withExpr { Asin(e.expr) } + def asin(e: Column): Column = Column.fn("asin", e) /** * @return inverse sine of `columnName`, as if computed by `java.lang.Math.asin` @@ -2174,7 +2123,7 @@ object functions { * @group math_funcs * @since 3.1.0 */ - def asinh(e: Column): Column = withExpr { Asinh(e.expr) } + def asinh(e: Column): Column = 
Column.fn("asinh", e) /** * @return inverse hyperbolic sine of `columnName` @@ -2190,7 +2139,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan(e: Column): Column = withExpr { Atan(e.expr) } + def atan(e: Column): Column = Column.fn("atan", e) /** * @return inverse tangent of `columnName`, as if computed by `java.lang.Math.atan` @@ -2212,7 +2161,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(y: Column, x: Column): Column = withExpr { Atan2(y.expr, x.expr) } + def atan2(y: Column, x: Column): Column = Column.fn("atan2", y, x) /** * @param y coordinate on y-axis @@ -2319,7 +2268,7 @@ object functions { * @group math_funcs * @since 3.1.0 */ - def atanh(e: Column): Column = withExpr { Atanh(e.expr) } + def atanh(e: Column): Column = Column.fn("atanh", e) /** * @return inverse hyperbolic tangent of `columnName` @@ -2336,7 +2285,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def bin(e: Column): Column = withExpr { Bin(e.expr) } + def bin(e: Column): Column = Column.fn("bin", e) /** * An expression that returns the string representation of the binary value of the given long @@ -2353,7 +2302,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cbrt(e: Column): Column = withExpr { Cbrt(e.expr) } + def cbrt(e: Column): Column = Column.fn("cbrt", e) /** * Computes the cube-root of the given column. @@ -2369,7 +2318,7 @@ object functions { * @group math_funcs * @since 3.3.0 */ - def ceil(e: Column, scale: Column): Column = call_function("ceil", e, scale) + def ceil(e: Column, scale: Column): Column = Column.fn("ceil", e, scale) /** * Computes the ceiling of the given value of `e` to 0 decimal places. @@ -2377,7 +2326,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def ceil(e: Column): Column = call_function("ceil", e) + def ceil(e: Column): Column = Column.fn("ceil", e) /** * Computes the ceiling of the given value of `e` to 0 decimal places. 
@@ -2393,8 +2342,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ceiling(e: Column, scale: Column): Column = - call_function("ceiling", e, scale) + def ceiling(e: Column, scale: Column): Column = Column.fn("ceiling", e, scale) /** * Computes the ceiling of the given value of `e` to 0 decimal places. @@ -2402,7 +2350,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ceiling(e: Column): Column = call_function("ceiling", e) + def ceiling(e: Column): Column = Column.fn("ceiling", e) /** * Convert a number in a string column from one base to another. @@ -2410,9 +2358,8 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def conv(num: Column, fromBase: Int, toBase: Int): Column = withExpr { - Conv(num.expr, lit(fromBase).expr, lit(toBase).expr) - } + def conv(num: Column, fromBase: Int, toBase: Int): Column = + Column.fn("conv", num, lit(fromBase), lit(toBase)) /** * @param e angle in radians @@ -2421,7 +2368,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cos(e: Column): Column = withExpr { Cos(e.expr) } + def cos(e: Column): Column = Column.fn("cos", e) /** * @param columnName angle in radians @@ -2439,7 +2386,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cosh(e: Column): Column = withExpr { Cosh(e.expr) } + def cosh(e: Column): Column = Column.fn("cosh", e) /** * @param columnName hyperbolic angle @@ -2457,7 +2404,7 @@ object functions { * @group math_funcs * @since 3.3.0 */ - def cot(e: Column): Column = withExpr { Cot(e.expr) } + def cot(e: Column): Column = Column.fn("cot", e) /** * @param e angle in radians @@ -2466,7 +2413,7 @@ object functions { * @group math_funcs * @since 3.3.0 */ - def csc(e: Column): Column = withExpr { Csc(e.expr) } + def csc(e: Column): Column = Column.fn("csc", e) /** * Returns Euler's number. 
@@ -2474,7 +2421,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def e(): Column = withExpr { EulerNumber() } + def e(): Column = Column.fn("e") /** * Computes the exponential of the given value. @@ -2482,7 +2429,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def exp(e: Column): Column = withExpr { Exp(e.expr) } + def exp(e: Column): Column = Column.fn("exp", e) /** * Computes the exponential of the given column. @@ -2498,7 +2445,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def expm1(e: Column): Column = withExpr { Expm1(e.expr) } + def expm1(e: Column): Column = Column.fn("expm1", e) /** * Computes the exponential of the given column minus one. @@ -2514,7 +2461,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def factorial(e: Column): Column = withExpr { Factorial(e.expr) } + def factorial(e: Column): Column = Column.fn("factorial", e) /** * Computes the floor of the given value of `e` to `scale` decimal places. @@ -2522,7 +2469,7 @@ object functions { * @group math_funcs * @since 3.3.0 */ - def floor(e: Column, scale: Column): Column = call_function("floor", e, scale) + def floor(e: Column, scale: Column): Column = Column.fn("floor", e, scale) /** * Computes the floor of the given value of `e` to 0 decimal places. @@ -2530,7 +2477,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def floor(e: Column): Column = call_function("floor", e) + def floor(e: Column): Column = Column.fn("floor", e) /** * Computes the floor of the given column value to 0 decimal places. @@ -2548,7 +2495,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(exprs: Column*): Column = withExpr { Greatest(exprs.map(_.expr)) } + def greatest(exprs: Column*): Column = Column.fn("greatest", exprs: _*) /** * Returns the greatest value of the list of column names, skipping null values. 
@@ -2568,7 +2515,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def hex(column: Column): Column = withExpr { Hex(column.expr) } + def hex(column: Column): Column = Column.fn("hex", column) /** * Inverse of hex. Interprets each pair of characters as a hexadecimal number @@ -2577,7 +2524,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def unhex(column: Column): Column = withExpr { Unhex(column.expr) } + def unhex(column: Column): Column = Column.fn("unhex", column) /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -2585,7 +2532,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Column, r: Column): Column = withExpr { Hypot(l.expr, r.expr) } + def hypot(l: Column, r: Column): Column = Column.fn("hypot", l, r) /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -2652,7 +2599,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(exprs: Column*): Column = withExpr { Least(exprs.map(_.expr)) } + def least(exprs: Column*): Column = Column.fn("least", exprs: _*) /** * Returns the least value of the list of column names, skipping null values. @@ -2672,7 +2619,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ln(e: Column): Column = log(e) + def ln(e: Column): Column = Column.fn("ln", e) /** * Computes the natural logarithm of the given value. @@ -2680,7 +2627,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log(e: Column): Column = withExpr { Log(e.expr) } + def log(e: Column): Column = ln(e) /** * Computes the natural logarithm of the given column. @@ -2696,7 +2643,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log(base: Double, a: Column): Column = withExpr { Logarithm(lit(base).expr, a.expr) } + def log(base: Double, a: Column): Column = Column.fn("log", lit(base), a) /** * Returns the first argument-base logarithm of the second argument. 
@@ -2712,7 +2659,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log10(e: Column): Column = withExpr { Log10(e.expr) } + def log10(e: Column): Column = Column.fn("log10", e) /** * Computes the logarithm of the given value in base 10. @@ -2728,7 +2675,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log1p(e: Column): Column = withExpr { Log1p(e.expr) } + def log1p(e: Column): Column = Column.fn("log1p", e) /** * Computes the natural logarithm of the given column plus one. @@ -2744,7 +2691,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def log2(expr: Column): Column = withExpr { Log2(expr.expr) } + def log2(expr: Column): Column = Column.fn("log2", expr) /** * Computes the logarithm of the given value in base 2. @@ -2760,7 +2707,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def negative(e: Column): Column = call_function("negative", e) + def negative(e: Column): Column = Column.fn("negative", e) /** * Returns Pi. @@ -2768,7 +2715,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def pi(): Column = withExpr { Pi() } + def pi(): Column = Column.fn("pi") /** * Returns the value. @@ -2776,7 +2723,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def positive(e: Column): Column = withExpr { UnaryPositive(e.expr) } + def positive(e: Column): Column = Column.fn("positive", e) /** * Returns the value of the first argument raised to the power of the second argument. @@ -2784,7 +2731,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Column, r: Column): Column = withExpr { Pow(l.expr, r.expr) } + def pow(l: Column, r: Column): Column = Column.fn("power", l, r) /** * Returns the value of the first argument raised to the power of the second argument. 
@@ -2848,7 +2795,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def power(l: Column, r: Column): Column = pow(l, r) + def power(l: Column, r: Column): Column = Column.fn("power", l, r) /** * Returns the positive value of dividend mod divisor. @@ -2856,9 +2803,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def pmod(dividend: Column, divisor: Column): Column = withExpr { - Pmod(dividend.expr, divisor.expr) - } + def pmod(dividend: Column, divisor: Column): Column = Column.fn("pmod", dividend, divisor) /** * Returns the double value that is closest in value to the argument and @@ -2867,7 +2812,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def rint(e: Column): Column = withExpr { Rint(e.expr) } + def rint(e: Column): Column = Column.fn("rint", e) /** * Returns the double value that is closest in value to the argument and @@ -2893,7 +2838,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) } + def round(e: Column, scale: Int): Column = Column.fn("round", e, lit(scale)) /** * Round the value of `e` to `scale` decimal places with HALF_UP round mode @@ -2902,9 +2847,7 @@ object functions { * @group math_funcs * @since 4.0.0 */ - def round(e: Column, scale: Column): Column = withExpr { - Round(e.expr, scale.expr) - } + def round(e: Column, scale: Column): Column = Column.fn("round", e, scale) /** * Returns the value of the column `e` rounded to 0 decimal places with HALF_EVEN round mode. 
@@ -2921,7 +2864,7 @@ object functions { * @group math_funcs * @since 2.0.0 */ - def bround(e: Column, scale: Int): Column = withExpr { BRound(e.expr, Literal(scale)) } + def bround(e: Column, scale: Int): Column = Column.fn("bround", e, lit(scale)) /** * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode @@ -2930,9 +2873,7 @@ object functions { * @group math_funcs * @since 4.0.0 */ - def bround(e: Column, scale: Column): Column = withExpr { - BRound(e.expr, scale.expr) - } + def bround(e: Column, scale: Column): Column = Column.fn("bround", e, scale) /** * @param e angle in radians @@ -2941,7 +2882,7 @@ object functions { * @group math_funcs * @since 3.3.0 */ - def sec(e: Column): Column = withExpr { Sec(e.expr) } + def sec(e: Column): Column = Column.fn("sec", e) /** * Shift the given value numBits left. If the given value is a long value, this function @@ -2960,7 +2901,7 @@ object functions { * @group math_funcs * @since 3.2.0 */ - def shiftleft(e: Column, numBits: Int): Column = withExpr { ShiftLeft(e.expr, lit(numBits).expr) } + def shiftleft(e: Column, numBits: Int): Column = Column.fn("shiftleft", e, lit(numBits)) /** * (Signed) shift the given value numBits right. If the given value is a long value, it will @@ -2979,9 +2920,7 @@ object functions { * @group math_funcs * @since 3.2.0 */ - def shiftright(e: Column, numBits: Int): Column = withExpr { - ShiftRight(e.expr, lit(numBits).expr) - } + def shiftright(e: Column, numBits: Int): Column = Column.fn("shiftright", e, lit(numBits)) /** * Unsigned shift the given value numBits right. If the given value is a long value, @@ -3000,9 +2939,8 @@ object functions { * @group math_funcs * @since 3.2.0 */ - def shiftrightunsigned(e: Column, numBits: Int): Column = withExpr { - ShiftRightUnsigned(e.expr, lit(numBits).expr) - } + def shiftrightunsigned(e: Column, numBits: Int): Column = + Column.fn("shiftrightunsigned", e, lit(numBits)) /** * Computes the signum of the given value. 
@@ -3010,7 +2948,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def sign(e: Column): Column = call_function("sign", e) + def sign(e: Column): Column = Column.fn("sign", e) /** * Computes the signum of the given value. @@ -3018,7 +2956,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def signum(e: Column): Column = withExpr { Signum(e.expr) } + def signum(e: Column): Column = Column.fn("signum", e) /** * Computes the signum of the given column. @@ -3035,7 +2973,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def sin(e: Column): Column = withExpr { Sin(e.expr) } + def sin(e: Column): Column = Column.fn("sin", e) /** * @param columnName angle in radians @@ -3053,7 +2991,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def sinh(e: Column): Column = withExpr { Sinh(e.expr) } + def sinh(e: Column): Column = Column.fn("sinh", e) /** * @param columnName hyperbolic angle @@ -3071,7 +3009,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def tan(e: Column): Column = withExpr { Tan(e.expr) } + def tan(e: Column): Column = Column.fn("tan", e) /** * @param columnName angle in radians @@ -3089,7 +3027,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def tanh(e: Column): Column = withExpr { Tanh(e.expr) } + def tanh(e: Column): Column = Column.fn("tanh", e) /** * @param columnName hyperbolic angle @@ -3123,7 +3061,7 @@ object functions { * @group math_funcs * @since 2.1.0 */ - def degrees(e: Column): Column = withExpr { ToDegrees(e.expr) } + def degrees(e: Column): Column = Column.fn("degrees", e) /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. 
@@ -3159,7 +3097,7 @@ object functions { * @group math_funcs * @since 2.1.0 */ - def radians(e: Column): Column = withExpr { ToRadians(e.expr) } + def radians(e: Column): Column = Column.fn("radians", e) /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. @@ -3185,9 +3123,8 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def width_bucket(v: Column, min: Column, max: Column, numBucket: Column): Column = withExpr { - WidthBucket(v.expr, min.expr, max.expr, numBucket.expr) - } + def width_bucket(v: Column, min: Column, max: Column, numBucket: Column): Column = + Column.fn("width_bucket", v, min, max, numBucket) ////////////////////////////////////////////////////////////////////////////////////////////// // Misc functions @@ -3199,7 +3136,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def current_catalog(): Column = withExpr { CurrentCatalog() } + def current_catalog(): Column = Column.fn("current_catalog") /** * Returns the current database. @@ -3207,7 +3144,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def current_database(): Column = withExpr { CurrentDatabase() } + def current_database(): Column = Column.fn("current_database") /** * Returns the current schema. @@ -3215,7 +3152,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def current_schema(): Column = call_function("current_schema") + def current_schema(): Column = Column.fn("current_schema") /** * Returns the user name of current execution context. 
@@ -3223,7 +3160,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def current_user(): Column = withExpr { CurrentUser() } + def current_user(): Column = Column.fn("current_user") /** * Calculates the MD5 digest of a binary column and returns the value @@ -3232,7 +3169,7 @@ object functions { * @group misc_funcs * @since 1.5.0 */ - def md5(e: Column): Column = withExpr { Md5(e.expr) } + def md5(e: Column): Column = Column.fn("md5", e) /** * Calculates the SHA-1 digest of a binary column and returns the value @@ -3241,7 +3178,7 @@ object functions { * @group misc_funcs * @since 1.5.0 */ - def sha1(e: Column): Column = withExpr { Sha1(e.expr) } + def sha1(e: Column): Column = Column.fn("sha1", e) /** * Calculates the SHA-2 family of hash functions of a binary column and @@ -3254,9 +3191,10 @@ object functions { * @since 1.5.0 */ def sha2(e: Column, numBits: Int): Column = { - require(Seq(0, 224, 256, 384, 512).contains(numBits), + require( + Seq(0, 224, 256, 384, 512).contains(numBits), s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)") - withExpr { Sha2(e.expr, lit(numBits).expr) } + Column.fn("sha2", e, lit(numBits)) } /** @@ -3266,7 +3204,7 @@ object functions { * @group misc_funcs * @since 1.5.0 */ - def crc32(e: Column): Column = withExpr { Crc32(e.expr) } + def crc32(e: Column): Column = Column.fn("crc32", e) /** * Calculates the hash code of given columns, and returns the result as an int column. 
@@ -3275,9 +3213,7 @@ object functions { * @since 2.0.0 */ @scala.annotation.varargs - def hash(cols: Column*): Column = withExpr { - new Murmur3Hash(cols.map(_.expr)) - } + def hash(cols: Column*): Column = Column.fn("hash", cols: _*) /** * Calculates the hash code of given columns using the 64-bit @@ -3288,9 +3224,7 @@ object functions { * @since 3.0.0 */ @scala.annotation.varargs - def xxhash64(cols: Column*): Column = withExpr { - new XxHash64(cols.map(_.expr)) - } + def xxhash64(cols: Column*): Column = Column.fn("xxhash64", cols: _*) /** * Returns null if the condition is true, and throws an exception otherwise. @@ -3298,9 +3232,7 @@ object functions { * @group misc_funcs * @since 3.1.0 */ - def assert_true(c: Column): Column = withExpr { - new AssertTrue(c.expr) - } + def assert_true(c: Column): Column = Column.fn("assert_true", c) /** * Returns null if the condition is true; throws an exception with the error message otherwise. @@ -3308,9 +3240,7 @@ object functions { * @group misc_funcs * @since 3.1.0 */ - def assert_true(c: Column, e: Column): Column = withExpr { - new AssertTrue(c.expr, e.expr) - } + def assert_true(c: Column, e: Column): Column = Column.fn("assert_true", c, e) /** * Throws an exception with the provided error message. 
@@ -3318,9 +3248,7 @@ object functions { * @group misc_funcs * @since 3.1.0 */ - def raise_error(c: Column): Column = withExpr { - RaiseError(c.expr) - } + def raise_error(c: Column): Column = Column.fn("raise_error", c) /** * Returns the estimated number of unique values given the binary representation @@ -3329,9 +3257,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def hll_sketch_estimate(c: Column): Column = withExpr { - HllSketchEstimate(c.expr) - } + def hll_sketch_estimate(c: Column): Column = Column.fn("hll_sketch_estimate", c) /** * Returns the estimated number of unique values given the binary representation @@ -3352,9 +3278,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def hll_union(c1: Column, c2: Column): Column = withExpr { - new HllUnion(c1.expr, c2.expr) - } + def hll_union(c1: Column, c2: Column): Column = + Column.fn("hll_union", c1, c2) /** * Merges two binary representations of Datasketches HllSketch objects, using a @@ -3376,9 +3301,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def hll_union(c1: Column, c2: Column, allowDifferentLgConfigK: Boolean): Column = withExpr { - new HllUnion(c1.expr, c2.expr, Literal(allowDifferentLgConfigK)) - } + def hll_union(c1: Column, c2: Column, allowDifferentLgConfigK: Boolean): Column = + Column.fn("hll_union", c1, c2, lit(allowDifferentLgConfigK)) /** * Merges two binary representations of Datasketches HllSketch objects, using a @@ -3399,7 +3323,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def user(): Column = call_function("user") + def user(): Column = Column.fn("user") /** * Returns the user name of current execution context. @@ -3407,7 +3331,7 @@ object functions { * @group misc_funcs * @since 4.0.0 */ - def session_user(): Column = withExpr { CurrentUser() } + def session_user(): Column = Column.fn("session_user") /** * Returns an universally unique identifier (UUID) string. 
The value is returned as a canonical @@ -3455,9 +3379,7 @@ object functions { mode: Column, padding: Column, iv: Column, - aad: Column): Column = withExpr { - AesEncrypt(input.expr, key.expr, mode.expr, padding.expr, iv.expr, aad.expr) - } + aad: Column): Column = Column.fn("aes_encrypt", input, key, mode, padding, iv, aad) /** * Returns an encrypted value of `input`. @@ -3469,14 +3391,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def aes_encrypt( - input: Column, - key: Column, - mode: Column, - padding: Column, - iv: Column): Column = withExpr { - new AesEncrypt(input.expr, key.expr, mode.expr, padding.expr, iv.expr) - } + def aes_encrypt(input: Column, key: Column, mode: Column, padding: Column, iv: Column): Column = + Column.fn("aes_encrypt", input, key, mode, padding, iv) /** * Returns an encrypted value of `input`. @@ -3488,9 +3404,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def aes_encrypt(input: Column, key: Column, mode: Column, padding: Column): Column = withExpr { - new AesEncrypt(input.expr, key.expr, mode.expr, padding.expr) - } + def aes_encrypt(input: Column, key: Column, mode: Column, padding: Column): Column = + Column.fn("aes_encrypt", input, key, mode, padding) /** * Returns an encrypted value of `input`. @@ -3502,9 +3417,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def aes_encrypt(input: Column, key: Column, mode: Column): Column = withExpr { - new AesEncrypt(input.expr, key.expr, mode.expr) - } + def aes_encrypt(input: Column, key: Column, mode: Column): Column = + Column.fn("aes_encrypt", input, key, mode) /** * Returns an encrypted value of `input`. 
@@ -3516,9 +3430,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def aes_encrypt(input: Column, key: Column): Column = withExpr { - new AesEncrypt(input.expr, key.expr) - } + def aes_encrypt(input: Column, key: Column): Column = + Column.fn("aes_encrypt", input, key) /** * Returns a decrypted value of `input` using AES in `mode` with `padding`. Key lengths of 16, @@ -3550,9 +3463,8 @@ object functions { key: Column, mode: Column, padding: Column, - aad: Column): Column = withExpr { - AesDecrypt(input.expr, key.expr, mode.expr, padding.expr, aad.expr) - } + aad: Column): Column = + Column.fn("aes_decrypt", input, key, mode, padding, aad) /** * Returns a decrypted value of `input`. @@ -3563,13 +3475,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def aes_decrypt( - input: Column, - key: Column, - mode: Column, - padding: Column): Column = withExpr { - new AesDecrypt(input.expr, key.expr, mode.expr, padding.expr) - } + def aes_decrypt(input: Column, key: Column, mode: Column, padding: Column): Column = + Column.fn("aes_decrypt", input, key, mode, padding) /** * Returns a decrypted value of `input`. @@ -3580,9 +3487,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def aes_decrypt(input: Column, key: Column, mode: Column): Column = withExpr { - new AesDecrypt(input.expr, key.expr, mode.expr) - } + def aes_decrypt(input: Column, key: Column, mode: Column): Column = + Column.fn("aes_decrypt", input, key, mode) /** * Returns a decrypted value of `input`. 
@@ -3593,9 +3499,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def aes_decrypt(input: Column, key: Column): Column = withExpr { - new AesDecrypt(input.expr, key.expr) - } + def aes_decrypt(input: Column, key: Column): Column = + Column.fn("aes_decrypt", input, key) /** * This is a special version of `aes_decrypt` that performs the same operation, but returns a @@ -3624,9 +3529,8 @@ object functions { key: Column, mode: Column, padding: Column, - aad: Column): Column = withExpr { - new TryAesDecrypt(input.expr, key.expr, mode.expr, padding.expr, aad.expr) - } + aad: Column): Column = + Column.fn("try_aes_decrypt", input, key, mode, padding, aad) /** * Returns a decrypted value of `input`. @@ -3637,13 +3541,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def try_aes_decrypt( - input: Column, - key: Column, - mode: Column, - padding: Column): Column = withExpr { - new TryAesDecrypt(input.expr, key.expr, mode.expr, padding.expr) - } + def try_aes_decrypt(input: Column, key: Column, mode: Column, padding: Column): Column = + Column.fn("try_aes_decrypt", input, key, mode, padding) /** * Returns a decrypted value of `input`. @@ -3654,9 +3553,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def try_aes_decrypt(input: Column, key: Column, mode: Column): Column = withExpr { - new TryAesDecrypt(input.expr, key.expr, mode.expr) - } + def try_aes_decrypt(input: Column, key: Column, mode: Column): Column = + Column.fn("try_aes_decrypt", input, key, mode) /** * Returns a decrypted value of `input`. @@ -3667,9 +3565,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def try_aes_decrypt(input: Column, key: Column): Column = withExpr { - new TryAesDecrypt(input.expr, key.expr) - } + def try_aes_decrypt(input: Column, key: Column): Column = + Column.fn("try_aes_decrypt", input, key) /** * Returns a sha1 hash value as a hex string of the `col`. 
@@ -3677,7 +3574,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def sha(col: Column): Column = call_function("sha", col) + def sha(col: Column): Column = Column.fn("sha", col) /** * Returns the length of the block being read, or -1 if not available. @@ -3685,9 +3582,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def input_file_block_length(): Column = withExpr { - InputFileBlockLength() - } + def input_file_block_length(): Column = Column.fn("input_file_block_length") /** * Returns the start offset of the block being read, or -1 if not available. @@ -3695,9 +3590,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def input_file_block_start(): Column = withExpr { - InputFileBlockStart() - } + def input_file_block_start(): Column = Column.fn("input_file_block_start") /** * Calls a method with reflection. @@ -3705,9 +3598,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def reflect(cols: Column*): Column = withExpr { - CallMethodViaReflection(cols.map(_.expr)) - } + def reflect(cols: Column*): Column = Column.fn("reflect", cols: _*) /** * Calls a method with reflection. @@ -3715,8 +3606,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def java_method(cols: Column*): Column = - call_function("java_method", cols: _*) + def java_method(cols: Column*): Column = Column.fn("java_method", cols: _*) /** * This is a special version of `reflect` that performs the same operation, but returns a NULL @@ -3725,9 +3615,7 @@ object functions { * @group misc_funcs * @since 4.0.0 */ - def try_reflect(cols: Column*): Column = withExpr { - new TryReflect(cols.map(_.expr)) - } + def try_reflect(cols: Column*): Column = Column.fn("try_reflect", cols: _*) /** * Returns the Spark version. 
The string contains 2 fields, the first being a release version @@ -3736,9 +3624,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def version(): Column = withExpr { - SparkVersion() - } + def version(): Column = Column.fn("version") /** * Return DDL-formatted type string for the data type of the input. @@ -3746,9 +3632,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def typeof(col: Column): Column = withExpr { - TypeOf(col.expr) - } + def typeof(col: Column): Column = Column.fn("typeof", col) /** * Separates `col1`, ..., `colk` into `n` rows. Uses column names col0, col1, etc. by default @@ -3757,9 +3641,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def stack(cols: Column*): Column = withExpr { - Stack(cols.map(_.expr)) - } + def stack(cols: Column*): Column = Column.fn("stack", cols: _*) /** * Returns a random value with independent and identically distributed (i.i.d.) uniformly @@ -3785,9 +3667,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def bitmap_bucket_number(col: Column): Column = withExpr { - BitmapBucketNumber(col.expr) - } + def bitmap_bucket_number(col: Column): Column = + Column.fn("bitmap_bucket_number", col) /** * Returns the bit position for the given input column. * @group misc_funcs * @since 3.5.0 */ - def bitmap_bit_position(col: Column): Column = withExpr { - BitmapBitPosition(col.expr) - } + def bitmap_bit_position(col: Column): Column = + Column.fn("bitmap_bit_position", col) /** * Returns a bitmap with the positions of the bits set from all the values from the input column. * @@ -3806,9 +3686,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def bitmap_construct_agg(col: Column): Column = withAggregateFunction { - BitmapConstructAgg(col.expr) - } + def bitmap_construct_agg(col: Column): Column = + Column.fn("bitmap_construct_agg", col) /** * Returns the number of set bits in the input bitmap. 
@@ -3816,9 +3695,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def bitmap_count(col: Column): Column = withExpr { - BitmapCount(col.expr) - } + def bitmap_count(col: Column): Column = Column.fn("bitmap_count", col) /** * Returns a bitmap that is the bitwise OR of all of the bitmaps from the input column. @@ -3827,9 +3704,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def bitmap_or_agg(col: Column): Column = withAggregateFunction { - BitmapOrAgg(col.expr) - } + def bitmap_or_agg(col: Column): Column = Column.fn("bitmap_or_agg", col) ////////////////////////////////////////////////////////////////////////////////////////////// // String functions @@ -3842,7 +3717,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def ascii(e: Column): Column = withExpr { Ascii(e.expr) } + def ascii(e: Column): Column = Column.fn("ascii", e) /** * Computes the BASE64 encoding of a binary column and returns it as a string column. @@ -3851,7 +3726,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def base64(e: Column): Column = withExpr { Base64(e.expr) } + def base64(e: Column): Column = Column.fn("base64", e) /** * Calculates the bit length for the specified string column. 
@@ -3859,7 +3734,7 @@ object functions { * @group string_funcs * @since 3.3.0 */ - def bit_length(e: Column): Column = withExpr { BitLength(e.expr) } + def bit_length(e: Column): Column = Column.fn("bit_length", e) /** * Concatenates multiple input string columns together into a single string column, @@ -3871,9 +3746,8 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def concat_ws(sep: String, exprs: Column*): Column = withExpr { - ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr)) - } + def concat_ws(sep: String, exprs: Column*): Column = + Column.fn("concat_ws", lit(sep) +: exprs: _*) /** * Computes the first argument into a string from a binary using the provided character set @@ -3883,9 +3757,8 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def decode(value: Column, charset: String): Column = withExpr { - StringDecode(value.expr, lit(charset).expr) - } + def decode(value: Column, charset: String): Column = + Column.fn("decode", value, lit(charset)) /** * Computes the first argument into a binary from a string using the provided character set @@ -3895,9 +3768,8 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def encode(value: Column, charset: String): Column = withExpr { - Encode(value.expr, lit(charset).expr) - } + def encode(value: Column, charset: String): Column = + Column.fn("encode", value, lit(charset)) /** * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places @@ -3909,9 +3781,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def format_number(x: Column, d: Int): Column = withExpr { - FormatNumber(x.expr, lit(d).expr) - } + def format_number(x: Column, d: Int): Column = Column.fn("format_number", x, lit(d)) /** * Formats the arguments in printf-style and returns the result as a string column. 
@@ -3920,9 +3790,8 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def format_string(format: String, arguments: Column*): Column = withExpr { - FormatString((lit(format) +: arguments).map(_.expr): _*) - } + def format_string(format: String, arguments: Column*): Column = + Column.fn("format_string", lit(format) +: arguments: _*) /** * Returns a new string column by converting the first letter of each word to uppercase. @@ -3933,7 +3802,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def initcap(e: Column): Column = withExpr { InitCap(e.expr) } + def initcap(e: Column): Column = Column.fn("initcap", e) /** * Locate the position of the first occurrence of substr column in the given string. @@ -3945,9 +3814,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def instr(str: Column, substring: String): Column = withExpr { - StringInstr(str.expr, lit(substring).expr) - } + def instr(str: Column, substring: String): Column = Column.fn("instr", str, lit(substring)) /** * Computes the character length of a given string or number of bytes of a binary string. @@ -3957,7 +3824,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def length(e: Column): Column = withExpr { Length(e.expr) } + def length(e: Column): Column = Column.fn("length", e) /** * Computes the character length of a given string or number of bytes of a binary string. @@ -3967,7 +3834,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def len(e: Column): Column = withExpr { Length(e.expr) } + def len(e: Column): Column = Column.fn("len", e) /** * Converts a string column to lower case. 
@@ -3975,7 +3842,7 @@ object functions { * @group string_funcs * @since 1.3.0 */ - def lower(e: Column): Column = withExpr { Lower(e.expr) } + def lower(e: Column): Column = Column.fn("lower", e) /** * Computes the Levenshtein distance of the two given string columns if it's less than or @@ -3984,16 +3851,15 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def levenshtein(l: Column, r: Column, threshold: Int): Column = withExpr { - Levenshtein(l.expr, r.expr, Some(Literal(threshold))) - } + def levenshtein(l: Column, r: Column, threshold: Int): Column = + Column.fn("levenshtein", l, r, lit(threshold)) /** * Computes the Levenshtein distance of the two given string columns. * @group string_funcs * @since 1.5.0 */ - def levenshtein(l: Column, r: Column): Column = withExpr { Levenshtein(l.expr, r.expr, None) } + def levenshtein(l: Column, r: Column): Column = Column.fn("levenshtein", l, r) /** * Locate the position of the first occurrence of substr. @@ -4004,9 +3870,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def locate(substr: String, str: Column): Column = withExpr { - new StringLocate(lit(substr).expr, str.expr) - } + def locate(substr: String, str: Column): Column = Column.fn("locate", lit(substr), str) /** * Locate the position of the first occurrence of substr in a string column, after position pos. @@ -4017,9 +3881,8 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def locate(substr: String, str: Column, pos: Int): Column = withExpr { - StringLocate(lit(substr).expr, str.expr, lit(pos).expr) - } + def locate(substr: String, str: Column, pos: Int): Column = + Column.fn("locate", lit(substr), str, lit(pos)) /** * Left-pad the string column with pad to a length of len. 
If the string column is longer @@ -4028,9 +3891,8 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def lpad(str: Column, len: Int, pad: String): Column = withExpr { - StringLPad(str.expr, lit(len).expr, lit(pad).expr) - } + def lpad(str: Column, len: Int, pad: String): Column = + Column.fn("lpad", str, lit(len), lit(pad)) /** * Left-pad the binary column with pad to a byte length of len. If the binary column is longer @@ -4040,7 +3902,7 @@ object functions { * @since 3.3.0 */ def lpad(str: Column, len: Int, pad: Array[Byte]): Column = - call_function("lpad", str, lit(len), lit(pad)) + Column.fn("lpad", str, lit(len), lit(pad)) /** * Trim the spaces from left end for the specified string value. @@ -4048,16 +3910,14 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr) } + def ltrim(e: Column): Column = Column.fn("ltrim", e) /** * Trim the specified character string from left end for the specified string column. * @group string_funcs * @since 2.3.0 */ - def ltrim(e: Column, trimString: String): Column = withExpr { - StringTrimLeft(e.expr, Literal(trimString)) - } + def ltrim(e: Column, trimString: String): Column = Column.fn("ltrim", lit(trimString), e) /** * Calculates the byte length for the specified string column. @@ -4065,7 +3925,7 @@ object functions { * @group string_funcs * @since 3.3.0 */ - def octet_length(e: Column): Column = withExpr { OctetLength(e.expr) } + def octet_length(e: Column): Column = Column.fn("octet_length", e) /** * Returns true if `str` matches `regexp`, or false otherwise. @@ -4073,9 +3933,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def rlike(str: Column, regexp: Column): Column = withExpr { - RLike(str.expr, regexp.expr) - } + def rlike(str: Column, regexp: Column): Column = Column.fn("rlike", str, regexp) /** * Returns true if `str` matches `regexp`, or false otherwise. 
@@ -4083,8 +3941,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def regexp(str: Column, regexp: Column): Column = - call_function("regexp", str, regexp) + def regexp(str: Column, regexp: Column): Column = Column.fn("regexp", str, regexp) /** * Returns true if `str` matches `regexp`, or false otherwise. @@ -4092,8 +3949,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def regexp_like(str: Column, regexp: Column): Column = - call_function("regexp_like", str, regexp) + def regexp_like(str: Column, regexp: Column): Column = Column.fn("regexp_like", str, regexp) /** * Returns a count of the number of times that the regular expression pattern `regexp` @@ -4102,9 +3958,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def regexp_count(str: Column, regexp: Column): Column = withExpr { - RegExpCount(str.expr, regexp.expr) - } + def regexp_count(str: Column, regexp: Column): Column = Column.fn("regexp_count", str, regexp) /** * Extract a specific group matched by a Java regex, from the specified string column. 
@@ -4115,9 +3969,8 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = withExpr { - RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr) - } + def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = + Column.fn("regexp_extract", e, lit(exp), lit(groupIdx)) /** * Extract all strings in the `str` that match the `regexp` expression and @@ -4126,9 +3979,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def regexp_extract_all(str: Column, regexp: Column): Column = withExpr { - new RegExpExtractAll(str.expr, regexp.expr) - } + def regexp_extract_all(str: Column, regexp: Column): Column = + Column.fn("regexp_extract_all", str, regexp) /** * Extract all strings in the `str` that match the `regexp` expression and @@ -4137,9 +3989,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def regexp_extract_all(str: Column, regexp: Column, idx: Column): Column = withExpr { - RegExpExtractAll(str.expr, regexp.expr, idx.expr) - } + def regexp_extract_all(str: Column, regexp: Column, idx: Column): Column = + Column.fn("regexp_extract_all", str, regexp, idx) /** * Replace all substrings of the specified string value that match regexp with rep. @@ -4147,9 +3998,8 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def regexp_replace(e: Column, pattern: String, replacement: String): Column = withExpr { - RegExpReplace(e.expr, lit(pattern).expr, lit(replacement).expr) - } + def regexp_replace(e: Column, pattern: String, replacement: String): Column = + regexp_replace(e, lit(pattern), lit(replacement)) /** * Replace all substrings of the specified string value that match regexp with rep. 
@@ -4157,9 +4007,8 @@ object functions { * @group string_funcs * @since 2.1.0 */ - def regexp_replace(e: Column, pattern: Column, replacement: Column): Column = withExpr { - RegExpReplace(e.expr, pattern.expr, replacement.expr) - } + def regexp_replace(e: Column, pattern: Column, replacement: Column): Column = + Column.fn("regexp_replace", e, pattern, replacement) /** * Returns the substring that matches the regular expression `regexp` within the string `str`. @@ -4168,9 +4017,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def regexp_substr(str: Column, regexp: Column): Column = withExpr { - RegExpSubStr(str.expr, regexp.expr) - } + def regexp_substr(str: Column, regexp: Column): Column = Column.fn("regexp_substr", str, regexp) /** * Searches a string for a regular expression and returns an integer that indicates @@ -4180,9 +4027,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def regexp_instr(str: Column, regexp: Column): Column = withExpr { - new RegExpInStr(str.expr, regexp.expr) - } + def regexp_instr(str: Column, regexp: Column): Column = Column.fn("regexp_instr", str, regexp) /** * Searches a string for a regular expression and returns an integer that indicates @@ -4192,9 +4037,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def regexp_instr(str: Column, regexp: Column, idx: Column): Column = withExpr { - RegExpInStr(str.expr, regexp.expr, idx.expr) - } + def regexp_instr(str: Column, regexp: Column, idx: Column): Column = + Column.fn("regexp_instr", str, regexp, idx) /** * Decodes a BASE64 encoded string column and returns it as a binary column. @@ -4203,7 +4047,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def unbase64(e: Column): Column = withExpr { UnBase64(e.expr) } + def unbase64(e: Column): Column = Column.fn("unbase64", e) /** * Right-pad the string column with pad to a length of len. 
If the string column is longer @@ -4212,9 +4056,8 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def rpad(str: Column, len: Int, pad: String): Column = withExpr { - StringRPad(str.expr, lit(len).expr, lit(pad).expr) - } + def rpad(str: Column, len: Int, pad: String): Column = + Column.fn("rpad", str, lit(len), lit(pad)) /** * Right-pad the binary column with pad to a byte length of len. If the binary column is longer @@ -4224,7 +4067,7 @@ object functions { * @since 3.3.0 */ def rpad(str: Column, len: Int, pad: Array[Byte]): Column = - call_function("rpad", str, lit(len), lit(pad)) + Column.fn("rpad", str, lit(len), lit(pad)) /** * Repeats a string column n times, and returns it as a new string column. @@ -4232,9 +4075,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def repeat(str: Column, n: Int): Column = withExpr { - StringRepeat(str.expr, lit(n).expr) - } + def repeat(str: Column, n: Int): Column = Column.fn("repeat", str, lit(n)) /** * Repeats a string column n times, and returns it as a new string column. @@ -4242,9 +4083,7 @@ object functions { * @group string_funcs * @since 4.0.0 */ - def repeat(str: Column, n: Column): Column = withExpr { - StringRepeat(str.expr, n.expr) - } + def repeat(str: Column, n: Column): Column = Column.fn("repeat", str, n) /** * Trim the spaces from right end for the specified string value. @@ -4252,16 +4091,14 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def rtrim(e: Column): Column = withExpr { StringTrimRight(e.expr) } + def rtrim(e: Column): Column = Column.fn("rtrim", e) /** * Trim the specified character string from right end for the specified string column. * @group string_funcs * @since 2.3.0 */ - def rtrim(e: Column, trimString: String): Column = withExpr { - StringTrimRight(e.expr, Literal(trimString)) - } + def rtrim(e: Column, trimString: String): Column = Column.fn("rtrim", lit(trimString), e) /** * Returns the soundex code for the specified expression. 
@@ -4269,7 +4106,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def soundex(e: Column): Column = withExpr { SoundEx(e.expr) } + def soundex(e: Column): Column = Column.fn("soundex", e) /** * Splits str around matches of the given pattern. @@ -4281,9 +4118,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def split(str: Column, pattern: String): Column = withExpr { - StringSplit(str.expr, Literal(pattern), Literal(-1)) - } + def split(str: Column, pattern: String): Column = Column.fn("split", str, lit(pattern)) /** * Splits str around matches of the given pattern. @@ -4303,9 +4138,8 @@ object functions { * @group string_funcs * @since 3.0.0 */ - def split(str: Column, pattern: String, limit: Int): Column = withExpr { - StringSplit(str.expr, Literal(pattern), Literal(limit)) - } + def split(str: Column, pattern: String, limit: Int): Column = + Column.fn("split", str, lit(pattern), lit(limit)) /** * Substring starts at `pos` and is of length `len` when str is String type or @@ -4317,9 +4151,8 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def substring(str: Column, pos: Int, len: Int): Column = withExpr { - Substring(str.expr, lit(pos).expr, lit(len).expr) - } + def substring(str: Column, pos: Int, len: Int): Column = + Column.fn("substring", str, lit(pos), lit(len)) /** * Returns the substring from string str before count occurrences of the delimiter delim. 
@@ -4329,9 +4162,8 @@ object functions { * * @group string_funcs */ - def substring_index(str: Column, delim: String, count: Int): Column = withExpr { - SubstringIndex(str.expr, lit(delim).expr, lit(count).expr) - } + def substring_index(str: Column, delim: String, count: Int): Column = + Column.fn("substring_index", str, lit(delim), lit(count)) /** * Overlay the specified portion of `src` with `replace`, @@ -4340,9 +4172,8 @@ object functions { * @group string_funcs * @since 3.0.0 */ - def overlay(src: Column, replace: Column, pos: Column, len: Column): Column = withExpr { - Overlay(src.expr, replace.expr, pos.expr, len.expr) - } + def overlay(src: Column, replace: Column, pos: Column, len: Column): Column = + Column.fn("overlay", src, replace, pos, len) /** * Overlay the specified portion of `src` with `replace`, @@ -4351,18 +4182,16 @@ object functions { * @group string_funcs * @since 3.0.0 */ - def overlay(src: Column, replace: Column, pos: Column): Column = withExpr { - new Overlay(src.expr, replace.expr, pos.expr) - } + def overlay(src: Column, replace: Column, pos: Column): Column = + Column.fn("overlay", src, replace, pos) /** * Splits a string into arrays of sentences, where each sentence is an array of words. * @group string_funcs * @since 3.2.0 */ - def sentences(string: Column, language: Column, country: Column): Column = withExpr { - Sentences(string.expr, language.expr, country.expr) - } + def sentences(string: Column, language: Column, country: Column): Column = + Column.fn("sentences", string, language, country) /** * Splits a string into arrays of sentences, where each sentence is an array of words. @@ -4370,9 +4199,7 @@ object functions { * @group string_funcs * @since 3.2.0 */ - def sentences(string: Column): Column = withExpr { - Sentences(string.expr) - } + def sentences(string: Column): Column = Column.fn("sentences", string) /** * Translate any character in the src by a character in replaceString. 
@@ -4383,9 +4210,8 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def translate(src: Column, matchingString: String, replaceString: String): Column = withExpr { - StringTranslate(src.expr, lit(matchingString).expr, lit(replaceString).expr) - } + def translate(src: Column, matchingString: String, replaceString: String): Column = + Column.fn("translate", src, lit(matchingString), lit(replaceString)) /** * Trim the spaces from both ends for the specified string column. @@ -4393,16 +4219,14 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def trim(e: Column): Column = withExpr { StringTrim(e.expr) } + def trim(e: Column): Column = Column.fn("trim", e) /** * Trim the specified character from both ends for the specified string column. * @group string_funcs * @since 2.3.0 */ - def trim(e: Column, trimString: String): Column = withExpr { - StringTrim(e.expr, Literal(trimString)) - } + def trim(e: Column, trimString: String): Column = Column.fn("trim", lit(trimString), e) /** * Converts a string column to upper case. @@ -4410,7 +4234,7 @@ object functions { * @group string_funcs * @since 1.3.0 */ - def upper(e: Column): Column = withExpr { Upper(e.expr) } + def upper(e: Column): Column = Column.fn("upper", e) /** * Converts the input `e` to a binary value based on the supplied `format`. @@ -4421,9 +4245,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def to_binary(e: Column, format: Column): Column = withExpr { - new ToBinary(e.expr, format.expr) - } + def to_binary(e: Column, format: Column): Column = Column.fn("to_binary", e, format) /** * Converts the input `e` to a binary value based on the default format "hex".
@@ -4432,9 +4254,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def to_binary(e: Column): Column = withExpr { - new ToBinary(e.expr) - } + def to_binary(e: Column): Column = Column.fn("to_binary", e) // scalastyle:off line.size.limit /** @@ -4469,7 +4289,7 @@ object functions { * @since 3.5.0 */ // scalastyle:on line.size.limit - def to_char(e: Column, format: Column): Column = call_function("to_char", e, format) + def to_char(e: Column, format: Column): Column = Column.fn("to_char", e, format) // scalastyle:off line.size.limit /** @@ -4504,7 +4324,7 @@ object functions { * @since 3.5.0 */ // scalastyle:on line.size.limit - def to_varchar(e: Column, format: Column): Column = call_function("to_varchar", e, format) + def to_varchar(e: Column, format: Column): Column = Column.fn("to_varchar", e, format) /** * Convert string 'e' to a number based on the string format 'format'. @@ -4529,9 +4349,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def to_number(e: Column, format: Column): Column = withExpr { - ToNumber(e.expr, format.expr) - } + def to_number(e: Column, format: Column): Column = Column.fn("to_number", e, format) /** * Replaces all occurrences of `search` with `replace`. @@ -4547,9 +4365,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def replace(src: Column, search: Column, replace: Column): Column = withExpr { - StringReplace(src.expr, search.expr, replace.expr) - } + def replace(src: Column, search: Column, replace: Column): Column = + Column.fn("replace", src, search, replace) /** * Replaces all occurrences of `search` with `replace`. @@ -4562,9 +4379,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def replace(src: Column, search: Column): Column = withExpr { - new StringReplace(src.expr, search.expr) - } + def replace(src: Column, search: Column): Column = Column.fn("replace", src, search) /** * Splits `str` by delimiter and return requested part of the split (1-based). 
@@ -4576,9 +4391,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def split_part(str: Column, delimiter: Column, partNum: Column): Column = withExpr { - SplitPart(str.expr, delimiter.expr, partNum.expr) - } + def split_part(str: Column, delimiter: Column, partNum: Column): Column = + Column.fn("split_part", str, delimiter, partNum) /** * Returns the substring of `str` that starts at `pos` and is of length `len`, @@ -4588,7 +4402,7 @@ object functions { * @since 3.5.0 */ def substr(str: Column, pos: Column, len: Column): Column = - call_function("substr", str, pos, len) + Column.fn("substr", str, pos, len) /** * Returns the substring of `str` that starts at `pos`, @@ -4597,8 +4411,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def substr(str: Column, pos: Column): Column = - call_function("substr", str, pos) + def substr(str: Column, pos: Column): Column = Column.fn("substr", str, pos) /** * Extracts a part from a URL. @@ -4606,9 +4419,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def parse_url(url: Column, partToExtract: Column, key: Column): Column = withExpr { - ParseUrl(Seq(url.expr, partToExtract.expr, key.expr)) - } + def parse_url(url: Column, partToExtract: Column, key: Column): Column = + Column.fn("parse_url", url, partToExtract, key) /** * Extracts a part from a URL. @@ -4616,9 +4428,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def parse_url(url: Column, partToExtract: Column): Column = withExpr { - ParseUrl(Seq(url.expr, partToExtract.expr)) - } + def parse_url(url: Column, partToExtract: Column): Column = + Column.fn("parse_url", url, partToExtract) /** * Formats the arguments in printf-style and returns the result as a string column. 
@@ -4627,7 +4438,7 @@ object functions { * @since 3.5.0 */ def printf(format: Column, arguments: Column*): Column = - call_function("printf", (format +: arguments): _*) + Column.fn("printf", (format +: arguments): _*) /** * Decodes a `str` in 'application/x-www-form-urlencoded' format @@ -4636,9 +4447,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def url_decode(str: Column): Column = withExpr { - UrlDecode(str.expr) - } + def url_decode(str: Column): Column = Column.fn("url_decode", str) /** * Translates a string into 'application/x-www-form-urlencoded' format @@ -4647,9 +4456,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def url_encode(str: Column): Column = withExpr { - UrlEncode(str.expr) - } + def url_encode(str: Column): Column = Column.fn("url_encode", str) /** * Returns the position of the first occurrence of `substr` in `str` after position `start`. @@ -4659,7 +4466,7 @@ object functions { * @since 3.5.0 */ def position(substr: Column, str: Column, start: Column): Column = - call_function("position", substr, str, start) + Column.fn("position", substr, str, start) /** * Returns the position of the first occurrence of `substr` in `str` after position `1`. @@ -4669,7 +4476,7 @@ object functions { * @since 3.5.0 */ def position(substr: Column, str: Column): Column = - call_function("position", substr, str) + Column.fn("position", substr, str) /** * Returns a boolean. The value is True if str ends with suffix. @@ -4679,7 +4486,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def endswith(str: Column, suffix: Column): Column = call_function("endswith", str, suffix) + def endswith(str: Column, suffix: Column): Column = + Column.fn("endswith", str, suffix) /** * Returns a boolean. The value is True if str starts with prefix. 
@@ -4689,7 +4497,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def startswith(str: Column, prefix: Column): Column = call_function("startswith", str, prefix) + def startswith(str: Column, prefix: Column): Column = + Column.fn("startswith", str, prefix) /** * Returns the ASCII character having the binary equivalent to `n`. @@ -4698,7 +4507,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def char(n: Column): Column = call_function("char", n) + def char(n: Column): Column = Column.fn("char", n) /** * Removes the leading and trailing space characters from `str`. @@ -4706,9 +4515,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def btrim(str: Column): Column = withExpr { - new StringTrimBoth(str.expr) - } + def btrim(str: Column): Column = Column.fn("btrim", str) /** * Remove the leading and trailing `trim` characters from `str`. @@ -4716,9 +4523,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def btrim(str: Column, trim: Column): Column = withExpr { - new StringTrimBoth(str.expr, trim.expr) - } + def btrim(str: Column, trim: Column): Column = Column.fn("btrim", str, trim) /** * This is a special version of `to_binary` that performs the same operation, but returns a NULL @@ -4727,9 +4532,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def try_to_binary(e: Column, format: Column): Column = withExpr { - new TryToBinary(e.expr, format.expr) - } + def try_to_binary(e: Column, format: Column): Column = Column.fn("try_to_binary", e, format) /** * This is a special version of `to_binary` that performs the same operation, but returns a NULL @@ -4738,9 +4541,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def try_to_binary(e: Column): Column = withExpr { - new TryToBinary(e.expr) - } + def try_to_binary(e: Column): Column = Column.fn("try_to_binary", e) /** * Convert string `e` to a number based on the string format `format`.
Returns NULL if the @@ -4750,9 +4551,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def try_to_number(e: Column, format: Column): Column = withExpr { - TryToNumber(e.expr, format.expr) - } + def try_to_number(e: Column, format: Column): Column = Column.fn("try_to_number", e, format) /** * Returns the character length of string data or number of bytes of binary data. @@ -4762,7 +4561,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def char_length(str: Column): Column = call_function("char_length", str) + def char_length(str: Column): Column = Column.fn("char_length", str) /** * Returns the character length of string data or number of bytes of binary data. @@ -4772,7 +4571,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def character_length(str: Column): Column = call_function("character_length", str) + def character_length(str: Column): Column = Column.fn("character_length", str) /** * Returns the ASCII character having the binary equivalent to `n`. @@ -4781,9 +4580,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def chr(n: Column): Column = withExpr { - Chr(n.expr) - } + def chr(n: Column): Column = Column.fn("chr", n) /** * Returns a boolean. The value is True if right is found inside left. @@ -4793,7 +4590,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def contains(left: Column, right: Column): Column = call_function("contains", left, right) + def contains(left: Column, right: Column): Column = Column.fn("contains", left, right) /** * Returns the `n`-th input, e.g., returns `input2` when `n` is 2. 
@@ -4805,9 +4602,7 @@ object functions { * @since 3.5.0 */ @scala.annotation.varargs - def elt(inputs: Column*): Column = withExpr { - Elt(inputs.map(_.expr)) - } + def elt(inputs: Column*): Column = Column.fn("elt", inputs: _*) /** * Returns the index (1-based) of the given string (`str`) in the comma-delimited @@ -4817,9 +4612,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def find_in_set(str: Column, strArray: Column): Column = withExpr { - FindInSet(str.expr, strArray.expr) - } + def find_in_set(str: Column, strArray: Column): Column = Column.fn("find_in_set", str, strArray) /** * Returns true if str matches `pattern` with `escapeChar`, null if any arguments are null, @@ -4844,9 +4637,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def like(str: Column, pattern: Column): Column = withExpr { - new Like(str.expr, pattern.expr) - } + def like(str: Column, pattern: Column): Column = Column.fn("like", str, pattern) /** * Returns true if str matches `pattern` with `escapeChar` case-insensitively, null if any @@ -4871,9 +4662,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def ilike(str: Column, pattern: Column): Column = withExpr { - new ILike(str.expr, pattern.expr) - } + def ilike(str: Column, pattern: Column): Column = Column.fn("ilike", str, pattern) /** * Returns `str` with all characters changed to lowercase. @@ -4881,7 +4670,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def lcase(str: Column): Column = call_function("lcase", str) + def lcase(str: Column): Column = Column.fn("lcase", str) /** * Returns `str` with all characters changed to uppercase. 
@@ -4889,7 +4678,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def ucase(str: Column): Column = call_function("ucase", str) + def ucase(str: Column): Column = Column.fn("ucase", str) /** * Returns the leftmost `len`(`len` can be string type) characters from the string `str`, @@ -4898,9 +4687,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def left(str: Column, len: Column): Column = withExpr { - Left(str.expr, len.expr) - } + def left(str: Column, len: Column): Column = Column.fn("left", str, len) /** * Returns the rightmost `len`(`len` can be string type) characters from the string `str`, @@ -4909,9 +4696,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def right(str: Column, len: Column): Column = withExpr { - Right(str.expr, len.expr) - } + def right(str: Column, len: Column): Column = Column.fn("right", str, len) ////////////////////////////////////////////////////////////////////////////////////////////// // DateTime functions @@ -4940,9 +4725,8 @@ object functions { * @group datetime_funcs * @since 3.0.0 */ - def add_months(startDate: Column, numMonths: Column): Column = withExpr { - AddMonths(startDate.expr, numMonths.expr) - } + def add_months(startDate: Column, numMonths: Column): Column = + Column.fn("add_months", startDate, numMonths) /** * Returns the current date at the start of query evaluation as a date column. @@ -4951,7 +4735,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def curdate(): Column = call_function("curdate") + def curdate(): Column = Column.fn("curdate") /** * Returns the current date at the start of query evaluation as a date column. @@ -4960,7 +4744,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def current_date(): Column = withExpr { CurrentDate() } + def current_date(): Column = Column.fn("current_date") /** * Returns the current session local timezone. 
@@ -4968,7 +4752,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def current_timezone(): Column = withExpr { CurrentTimeZone() } + def current_timezone(): Column = Column.fn("current_timezone") /** * Returns the current timestamp at the start of query evaluation as a timestamp column. @@ -4977,7 +4761,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def current_timestamp(): Column = withExpr { CurrentTimestamp() } + def current_timestamp(): Column = Column.fn("current_timestamp") /** * Returns the current timestamp at the start of query evaluation. @@ -4985,7 +4769,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def now(): Column = withExpr { Now() } + def now(): Column = Column.fn("now") /** * Returns the current timestamp without time zone at the start of query evaluation @@ -4995,7 +4779,7 @@ object functions { * @group datetime_funcs * @since 3.3.0 */ - def localtimestamp(): Column = withExpr { LocalTimestamp() } + def localtimestamp(): Column = Column.fn("localtimestamp") /** * Converts a date/timestamp/string to a value of string in the format specified by the date @@ -5015,9 +4799,8 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def date_format(dateExpr: Column, format: String): Column = withExpr { - DateFormatClass(dateExpr.expr, Literal(format)) - } + def date_format(dateExpr: Column, format: String): Column = + Column.fn("date_format", dateExpr, lit(format)) /** * Returns the date that is `days` days after `start` @@ -5041,7 +4824,7 @@ object functions { * @group datetime_funcs * @since 3.0.0 */ - def date_add(start: Column, days: Column): Column = withExpr { DateAdd(start.expr, days.expr) } + def date_add(start: Column, days: Column): Column = Column.fn("date_add", start, days) /** * Returns the date that is `days` days after `start` @@ -5053,8 +4836,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def dateadd(start: Column, days: Column): Column = - 
call_function("dateadd", start, days) + def dateadd(start: Column, days: Column): Column = Column.fn("dateadd", start, days) /** * Returns the date that is `days` days before `start` @@ -5079,7 +4861,8 @@ object functions { * @group datetime_funcs * @since 3.0.0 */ - def date_sub(start: Column, days: Column): Column = withExpr { DateSub(start.expr, days.expr) } + def date_sub(start: Column, days: Column): Column = + Column.fn("date_sub", start, days) /** * Returns the number of days from `start` to `end`. @@ -5099,7 +4882,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def datediff(end: Column, start: Column): Column = withExpr { DateDiff(end.expr, start.expr) } + def datediff(end: Column, start: Column): Column = Column.fn("datediff", end, start) /** * Returns the number of days from `start` to `end`. @@ -5119,8 +4902,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def date_diff(end: Column, start: Column): Column = - call_function("date_diff", end, start) + def date_diff(end: Column, start: Column): Column = Column.fn("date_diff", end, start) /** * Create date from the number of `days` since 1970-01-01. @@ -5128,7 +4910,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def date_from_unix_date(days: Column): Column = withExpr { DateFromUnixDate(days.expr) } + def date_from_unix_date(days: Column): Column = Column.fn("date_from_unix_date", days) /** * Extracts the year as an integer from a given date/timestamp/string. @@ -5136,7 +4918,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def year(e: Column): Column = withExpr { Year(e.expr) } + def year(e: Column): Column = Column.fn("year", e) /** * Extracts the quarter as an integer from a given date/timestamp/string. 
@@ -5144,7 +4926,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def quarter(e: Column): Column = withExpr { Quarter(e.expr) } + def quarter(e: Column): Column = Column.fn("quarter", e) /** * Extracts the month as an integer from a given date/timestamp/string. @@ -5152,7 +4934,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def month(e: Column): Column = withExpr { Month(e.expr) } + def month(e: Column): Column = Column.fn("month", e) /** * Extracts the day of the week as an integer from a given date/timestamp/string. @@ -5161,7 +4943,7 @@ object functions { * @group datetime_funcs * @since 2.3.0 */ - def dayofweek(e: Column): Column = withExpr { DayOfWeek(e.expr) } + def dayofweek(e: Column): Column = Column.fn("dayofweek", e) /** * Extracts the day of the month as an integer from a given date/timestamp/string. @@ -5169,7 +4951,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def dayofmonth(e: Column): Column = withExpr { DayOfMonth(e.expr) } + def dayofmonth(e: Column): Column = Column.fn("dayofmonth", e) /** * Extracts the day of the month as an integer from a given date/timestamp/string. @@ -5177,7 +4959,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def day(e: Column): Column = call_function("day", e) + def day(e: Column): Column = Column.fn("day", e) /** * Extracts the day of the year as an integer from a given date/timestamp/string. @@ -5185,7 +4967,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def dayofyear(e: Column): Column = withExpr { DayOfYear(e.expr) } + def dayofyear(e: Column): Column = Column.fn("dayofyear", e) /** * Extracts the hours as an integer from a given date/timestamp/string. @@ -5193,7 +4975,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def hour(e: Column): Column = withExpr { Hour(e.expr) } + def hour(e: Column): Column = Column.fn("hour", e) /** * Extracts a part of the date/timestamp or interval source. 
@@ -5204,7 +4986,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def extract(field: Column, source: Column): Column = call_function("extract", field, source) + def extract(field: Column, source: Column): Column = { + Column.fn("extract", field, source) + } /** * Extracts a part of the date/timestamp or interval source. @@ -5216,8 +5000,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def date_part(field: Column, source: Column): Column = - call_function("date_part", field, source) + def date_part(field: Column, source: Column): Column = { + Column.fn("date_part", field, source) + } /** * Extracts a part of the date/timestamp or interval source. @@ -5229,8 +5014,9 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def datepart(field: Column, source: Column): Column = - call_function("datepart", field, source) + def datepart(field: Column, source: Column): Column = { + Column.fn("datepart", field, source) + } /** * Returns the last day of the month which the given date belongs to. @@ -5243,7 +5029,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def last_day(e: Column): Column = withExpr { LastDay(e.expr) } + def last_day(e: Column): Column = Column.fn("last_day", e) /** * Extracts the minutes as an integer from a given date/timestamp/string. @@ -5251,7 +5037,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def minute(e: Column): Column = withExpr { Minute(e.expr) } + def minute(e: Column): Column = Column.fn("minute", e) /** * Returns the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday). @@ -5259,16 +5045,15 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def weekday(e: Column): Column = withExpr { WeekDay(e.expr) } + def weekday(e: Column): Column = Column.fn("weekday", e) /** * @return A date created from year, month and day fields. 
* @group datetime_funcs * @since 3.3.0 */ - def make_date(year: Column, month: Column, day: Column): Column = withExpr { - MakeDate(year.expr, month.expr, day.expr) - } + def make_date(year: Column, month: Column, day: Column): Column = + Column.fn("make_date", year, month, day) /** * Returns number of months between dates `start` and `end`. @@ -5292,9 +5077,8 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def months_between(end: Column, start: Column): Column = withExpr { - new MonthsBetween(end.expr, start.expr) - } + def months_between(end: Column, start: Column): Column = + Column.fn("months_between", end, start) /** * Returns number of months between dates `end` and `start`. If `roundOff` is set to true, the @@ -5302,9 +5086,8 @@ object functions { * @group datetime_funcs * @since 2.4.0 */ - def months_between(end: Column, start: Column, roundOff: Boolean): Column = withExpr { - MonthsBetween(end.expr, start.expr, lit(roundOff).expr) - } + def months_between(end: Column, start: Column, roundOff: Boolean): Column = + Column.fn("months_between", end, start, lit(roundOff)) /** * Returns the first date which is later than the value of the `date` column that is on the @@ -5339,9 +5122,8 @@ object functions { * @group datetime_funcs * @since 3.2.0 */ - def next_day(date: Column, dayOfWeek: Column): Column = withExpr { - NextDay(date.expr, dayOfWeek.expr) - } + def next_day(date: Column, dayOfWeek: Column): Column = + Column.fn("next_day", date, dayOfWeek) /** * Extracts the seconds as an integer from a given date/timestamp/string. @@ -5349,7 +5131,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def second(e: Column): Column = withExpr { Second(e.expr) } + def second(e: Column): Column = Column.fn("second", e) /** * Extracts the week number as an integer from a given date/timestamp/string. 
@@ -5361,7 +5143,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def weekofyear(e: Column): Column = withExpr { WeekOfYear(e.expr) } + def weekofyear(e: Column): Column = Column.fn("weekofyear", e) /** * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string @@ -5374,9 +5156,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def from_unixtime(ut: Column): Column = withExpr { - FromUnixTime(ut.expr, Literal(TimestampFormatter.defaultPattern)) - } + def from_unixtime(ut: Column): Column = Column.fn("from_unixtime", ut) /** * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string @@ -5395,9 +5175,8 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def from_unixtime(ut: Column, f: String): Column = withExpr { - FromUnixTime(ut.expr, Literal(f)) - } + def from_unixtime(ut: Column, f: String): Column = + Column.fn("from_unixtime", ut, lit(f)) /** * Returns the current Unix timestamp (in seconds) as a long. @@ -5408,9 +5187,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(): Column = withExpr { - UnixTimestamp(CurrentTimestamp(), Literal(TimestampFormatter.defaultPattern)) - } + def unix_timestamp(): Column = unix_timestamp(current_timestamp()) /** * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), @@ -5422,9 +5199,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(s: Column): Column = withExpr { - UnixTimestamp(s.expr, Literal(TimestampFormatter.defaultPattern)) - } + def unix_timestamp(s: Column): Column = Column.fn("unix_timestamp", s) /** * Converts time string with given pattern to Unix timestamp (in seconds). 
@@ -5441,7 +5216,8 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(s: Column, p: String): Column = withExpr { UnixTimestamp(s.expr, Literal(p)) } + def unix_timestamp(s: Column, p: String): Column = + Column.fn("unix_timestamp", s, lit(p)) /** * Converts to a timestamp by casting rules to `TimestampType`. @@ -5452,9 +5228,7 @@ object functions { * @group datetime_funcs * @since 2.2.0 */ - def to_timestamp(s: Column): Column = withExpr { - new ParseToTimestamp(s.expr) - } + def to_timestamp(s: Column): Column = Column.fn("to_timestamp", s) /** * Converts time string with the given pattern to timestamp. @@ -5471,9 +5245,7 @@ object functions { * @group datetime_funcs * @since 2.2.0 */ - def to_timestamp(s: Column, fmt: String): Column = withExpr { - new ParseToTimestamp(s.expr, Literal(fmt)) - } + def to_timestamp(s: Column, fmt: String): Column = Column.fn("to_timestamp", s, lit(fmt)) /** * Parses the `s` with the `format` to a timestamp. The function always returns null on an @@ -5484,7 +5256,7 @@ object functions { * @since 3.5.0 */ def try_to_timestamp(s: Column, format: Column): Column = - call_function("try_to_timestamp", s, format) + Column.fn("try_to_timestamp", s, format) /** * Parses the `s` to a timestamp. The function always returns null on an invalid @@ -5494,8 +5266,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def try_to_timestamp(s: Column): Column = - call_function("try_to_timestamp", s) + def try_to_timestamp(s: Column): Column = Column.fn("try_to_timestamp", s) /** * Converts the column into `DateType` by casting rules to `DateType`. 
@@ -5503,7 +5274,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def to_date(e: Column): Column = withExpr { new ParseToDate(e.expr) } + def to_date(e: Column): Column = Column.fn("to_date", e) /** * Converts the column into a `DateType` with a specified format @@ -5520,9 +5291,7 @@ object functions { * @group datetime_funcs * @since 2.2.0 */ - def to_date(e: Column, fmt: String): Column = withExpr { - new ParseToDate(e.expr, Literal(fmt)) - } + def to_date(e: Column, fmt: String): Column = Column.fn("to_date", e, lit(fmt)) /** * Returns the number of days since 1970-01-01. @@ -5530,9 +5299,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def unix_date(e: Column): Column = withExpr { - UnixDate(e.expr) - } + def unix_date(e: Column): Column = Column.fn("unix_date", e) /** * Returns the number of microseconds since 1970-01-01 00:00:00 UTC. @@ -5540,9 +5307,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def unix_micros(e: Column): Column = withExpr { - UnixMicros(e.expr) - } + def unix_micros(e: Column): Column = Column.fn("unix_micros", e) /** * Returns the number of milliseconds since 1970-01-01 00:00:00 UTC. @@ -5551,9 +5316,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def unix_millis(e: Column): Column = withExpr { - UnixMillis(e.expr) - } + def unix_millis(e: Column): Column = Column.fn("unix_millis", e) /** * Returns the number of seconds since 1970-01-01 00:00:00 UTC. @@ -5562,9 +5325,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def unix_seconds(e: Column): Column = withExpr { - UnixSeconds(e.expr) - } + def unix_seconds(e: Column): Column = Column.fn("unix_seconds", e) /** * Returns date truncated to the unit specified by the format. 
@@ -5582,9 +5343,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def trunc(date: Column, format: String): Column = withExpr { - TruncDate(date.expr, Literal(format)) - } + def trunc(date: Column, format: String): Column = Column.fn("trunc", date, lit(format)) /** * Returns timestamp truncated to the unit specified by the format. @@ -5603,9 +5362,8 @@ object functions { * @group datetime_funcs * @since 2.3.0 */ - def date_trunc(format: String, timestamp: Column): Column = withExpr { - TruncTimestamp(Literal(format), timestamp.expr) - } + def date_trunc(format: String, timestamp: Column): Column = + Column.fn("date_trunc", lit(format), timestamp) /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders @@ -5625,9 +5383,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def from_utc_timestamp(ts: Column, tz: String): Column = withExpr { - FromUTCTimestamp(ts.expr, Literal(tz)) - } + def from_utc_timestamp(ts: Column, tz: String): Column = from_utc_timestamp(ts, lit(tz)) /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders @@ -5636,9 +5392,8 @@ object functions { * @group datetime_funcs * @since 2.4.0 */ - def from_utc_timestamp(ts: Column, tz: Column): Column = withExpr { - FromUTCTimestamp(ts.expr, tz.expr) - } + def from_utc_timestamp(ts: Column, tz: Column): Column = + Column.fn("from_utc_timestamp", ts, tz) /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time @@ -5658,9 +5413,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def to_utc_timestamp(ts: Column, tz: String): Column = withExpr { - ToUTCTimestamp(ts.expr, Literal(tz)) - } + def to_utc_timestamp(ts: Column, tz: String): Column = to_utc_timestamp(ts, lit(tz)) /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time @@ -5669,9 +5422,7 @@ object functions { * @group datetime_funcs * @since 
2.4.0 */ - def to_utc_timestamp(ts: Column, tz: Column): Column = withExpr { - ToUTCTimestamp(ts.expr, tz.expr) - } + def to_utc_timestamp(ts: Column, tz: Column): Column = Column.fn("to_utc_timestamp", ts, tz) /** * Bucketize rows into one or more time windows given a timestamp specifying column. Window @@ -5722,12 +5473,8 @@ object functions { timeColumn: Column, windowDuration: String, slideDuration: String, - startTime: String): Column = { - withExpr { - TimeWindow(timeColumn.expr, windowDuration, slideDuration, startTime) - }.as("window") - } - + startTime: String): Column = + Column.fn("window", timeColumn, lit(windowDuration), lit(slideDuration), lit(startTime)) /** * Bucketize rows into one or more time windows given a timestamp specifying column. Window @@ -5824,9 +5571,7 @@ object functions { * @group datetime_funcs * @since 3.4.0 */ - def window_time(windowColumn: Column): Column = withExpr { - WindowTime(windowColumn.expr) - } + def window_time(windowColumn: Column): Column = Column.fn("window_time", windowColumn) /** * Generates session window given a timestamp specifying column. @@ -5889,11 +5634,8 @@ object functions { * @group datetime_funcs * @since 3.2.0 */ - def session_window(timeColumn: Column, gapDuration: Column): Column = { - withExpr { - SessionWindow(timeColumn.expr, gapDuration.expr) - }.as("session_window") - } + def session_window(timeColumn: Column, gapDuration: Column): Column = + Column.fn("session_window", timeColumn, gapDuration).as("session_window") /** * Converts the number of seconds from the Unix epoch (1970-01-01T00:00:00Z) @@ -5901,9 +5643,7 @@ object functions { * @group datetime_funcs * @since 3.1.0 */ - def timestamp_seconds(e: Column): Column = withExpr { - SecondsToTimestamp(e.expr) - } + def timestamp_seconds(e: Column): Column = Column.fn("timestamp_seconds", e) /** * Creates timestamp from the number of milliseconds since UTC epoch. 
@@ -5911,9 +5651,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def timestamp_millis(e: Column): Column = withExpr { - MillisToTimestamp(e.expr) - } + def timestamp_millis(e: Column): Column = Column.fn("timestamp_millis", e) /** * Creates timestamp from the number of microseconds since UTC epoch. @@ -5921,9 +5659,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def timestamp_micros(e: Column): Column = withExpr { - MicrosToTimestamp(e.expr) - } + def timestamp_micros(e: Column): Column = Column.fn("timestamp_micros", e) /** * Parses the `timestamp` expression with the `format` expression @@ -5933,7 +5669,7 @@ object functions { * @since 3.5.0 */ def to_timestamp_ltz(timestamp: Column, format: Column): Column = - call_function("to_timestamp_ltz", timestamp, format) + Column.fn("to_timestamp_ltz", timestamp, format) /** * Parses the `timestamp` expression with the default format to a timestamp without time zone. @@ -5943,7 +5679,7 @@ object functions { * @since 3.5.0 */ def to_timestamp_ltz(timestamp: Column): Column = - call_function("to_timestamp_ltz", timestamp) + Column.fn("to_timestamp_ltz", timestamp) /** * Parses the `timestamp_str` expression with the `format` expression @@ -5953,7 +5689,7 @@ object functions { * @since 3.5.0 */ def to_timestamp_ntz(timestamp: Column, format: Column): Column = - call_function("to_timestamp_ntz", timestamp, format) + Column.fn("to_timestamp_ntz", timestamp, format) /** * Parses the `timestamp` expression with the default format to a timestamp without time zone. @@ -5963,7 +5699,7 @@ object functions { * @since 3.5.0 */ def to_timestamp_ntz(timestamp: Column): Column = - call_function("to_timestamp_ntz", timestamp) + Column.fn("to_timestamp_ntz", timestamp) /** * Returns the UNIX timestamp of the given time. 
@@ -5971,9 +5707,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def to_unix_timestamp(e: Column, format: Column): Column = withExpr { - new ToUnixTimestamp(e.expr, format.expr) - } + def to_unix_timestamp(timeExp: Column, format: Column): Column = + Column.fn("to_unix_timestamp", timeExp, format) /** * Returns the UNIX timestamp of the given time. @@ -5981,9 +5716,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def to_unix_timestamp(e: Column): Column = withExpr { - new ToUnixTimestamp(e.expr) - } + def to_unix_timestamp(timeExp: Column): Column = + Column.fn("to_unix_timestamp", timeExp) ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions @@ -5994,9 +5728,8 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def array_contains(column: Column, value: Any): Column = withExpr { - ArrayContains(column.expr, lit(value).expr) - } + def array_contains(column: Column, value: Any): Column = + Column.fn("array_contains", column, lit(value)) /** * Returns an ARRAY containing all elements from the source ARRAY as well as the new element. @@ -6005,10 +5738,8 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def array_append(column: Column, element: Any): Column = withExpr { - ArrayAppend(column.expr, lit(element).expr) - } - + def array_append(column: Column, element: Any): Column = + Column.fn("array_append", column, lit(element)) /** * Returns `true` if `a1` and `a2` have at least one non-null element in common. 
If not and both @@ -6017,9 +5748,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def arrays_overlap(a1: Column, a2: Column): Column = withExpr { - ArraysOverlap(a1.expr, a2.expr) - } + def arrays_overlap(a1: Column, a2: Column): Column = Column.fn("arrays_overlap", a1, a2) /** * Returns an array containing all the elements in `x` from index `start` (or starting from the @@ -6046,9 +5775,8 @@ object functions { * @group collection_funcs * @since 3.1.0 */ - def slice(x: Column, start: Column, length: Column): Column = withExpr { - Slice(x.expr, start.expr, length.expr) - } + def slice(x: Column, start: Column, length: Column): Column = + Column.fn("slice", x, start, length) /** * Concatenates the elements of `column` using the `delimiter`. Null values are replaced with @@ -6056,18 +5784,16 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def array_join(column: Column, delimiter: String, nullReplacement: String): Column = withExpr { - ArrayJoin(column.expr, Literal(delimiter), Some(Literal(nullReplacement))) - } + def array_join(column: Column, delimiter: String, nullReplacement: String): Column = + Column.fn("array_join", column, lit(delimiter), lit(nullReplacement)) /** * Concatenates the elements of `column` using the `delimiter`. * @group collection_funcs * @since 2.4.0 */ - def array_join(column: Column, delimiter: String): Column = withExpr { - ArrayJoin(column.expr, Literal(delimiter), None) - } + def array_join(column: Column, delimiter: String): Column = + Column.fn("array_join", column, lit(delimiter)) /** * Concatenates multiple input columns together into a single column. @@ -6079,7 +5805,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } + def concat(exprs: Column*): Column = Column.fn("concat", exprs: _*) /** * Locates the position of the first occurrence of the value in the given array as long. 
@@ -6091,9 +5817,8 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def array_position(column: Column, value: Any): Column = withExpr { - ArrayPosition(column.expr, lit(value).expr) - } + def array_position(column: Column, value: Any): Column = + Column.fn("array_position", column, lit(value)) /** * Returns element of array at given index in value if column is array. Returns value for @@ -6102,9 +5827,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def element_at(column: Column, value: Any): Column = withExpr { - ElementAt(column.expr, lit(value).expr) - } + def element_at(column: Column, value: Any): Column = Column.fn("element_at", column, lit(value)) /** * (array, index) - Returns element of array at given (1-based) index. If Index is 0, Spark will @@ -6117,9 +5840,8 @@ object functions { * @group map_funcs * @since 3.5.0 */ - def try_element_at(column: Column, value: Column): Column = withExpr { - new TryElementAt(column.expr, value.expr) - } + def try_element_at(column: Column, value: Column): Column = + Column.fn("try_element_at", column, value) /** * Returns element of array at given (0-based) index. If the index points @@ -6128,9 +5850,7 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def get(column: Column, index: Column): Column = withExpr { - new Get(column.expr, index.expr) - } + def get(column: Column, index: Column): Column = Column.fn("get", column, index) /** * Sorts the input array in ascending order. The elements of the input array must be orderable. @@ -6140,7 +5860,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def array_sort(e: Column): Column = withExpr { new ArraySort(e.expr) } + def array_sort(e: Column): Column = Column.fn("array_sort", e) /** * Sorts the input array based on the given comparator function. 
The comparator will take two @@ -6151,9 +5871,8 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def array_sort(e: Column, comparator: (Column, Column) => Column): Column = withExpr { - new ArraySort(e.expr, createLambda(comparator)) - } + def array_sort(e: Column, comparator: (Column, Column) => Column): Column = + Column.fn("array_sort", e, createLambda(comparator)) /** * Remove all elements that equal to element from the given array. @@ -6161,9 +5880,8 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def array_remove(column: Column, element: Any): Column = withExpr { - ArrayRemove(column.expr, lit(element).expr) - } + def array_remove(column: Column, element: Any): Column = + Column.fn("array_remove", column, lit(element)) /** * Remove all null elements from the given array. @@ -6171,9 +5889,7 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def array_compact(column: Column): Column = withExpr { - ArrayCompact(column.expr) - } + def array_compact(column: Column): Column = Column.fn("array_compact", column) /** * Returns an array containing value as well as all elements from array. The new element is @@ -6182,16 +5898,15 @@ object functions { * @group collection_funcs * @since 3.5.0 */ - def array_prepend(column: Column, element: Any): Column = withExpr { - ArrayPrepend(column.expr, lit(element).expr) - } + def array_prepend(column: Column, element: Any): Column = + Column.fn("array_prepend", column, lit(element)) /** * Removes duplicate values from the array. 
* @group collection_funcs * @since 2.4.0 */ - def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + def array_distinct(e: Column): Column = Column.fn("array_distinct", e) /** * Returns an array of the elements in the intersection of the given two arrays, @@ -6200,9 +5915,8 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def array_intersect(col1: Column, col2: Column): Column = withExpr { - ArrayIntersect(col1.expr, col2.expr) - } + def array_intersect(col1: Column, col2: Column): Column = + Column.fn("array_intersect", col1, col2) /** * Adds an item into a given array at a specified position @@ -6210,9 +5924,8 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def array_insert(arr: Column, pos: Column, value: Column): Column = withExpr { - new ArrayInsert(arr.expr, pos.expr, value.expr) - } + def array_insert(arr: Column, pos: Column, value: Column): Column = + Column.fn("array_insert", arr, pos, value) /** * Returns an array of the elements in the union of the given two arrays, without duplicates. 
@@ -6220,9 +5933,8 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def array_union(col1: Column, col2: Column): Column = withExpr { - ArrayUnion(col1.expr, col2.expr) - } + def array_union(col1: Column, col2: Column): Column = + Column.fn("array_union", col1, col2) /** * Returns an array of the elements in the first array but not in the second array, @@ -6231,24 +5943,23 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def array_except(col1: Column, col2: Column): Column = withExpr { - ArrayExcept(col1.expr, col2.expr) - } + def array_except(col1: Column, col2: Column): Column = + Column.fn("array_except", col1, col2) - private def createLambda(f: Column => Column) = { + private def createLambda(f: Column => Column) = Column { val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) val function = f(Column(x)).expr LambdaFunction(function, Seq(x)) } - private def createLambda(f: (Column, Column) => Column) = { + private def createLambda(f: (Column, Column) => Column) = Column { val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) val function = f(Column(x), Column(y)).expr LambdaFunction(function, Seq(x, y)) } - private def createLambda(f: (Column, Column, Column) => Column) = { + private def createLambda(f: (Column, Column, Column) => Column) = Column { val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z"))) @@ -6269,9 +5980,8 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def transform(column: Column, f: Column => Column): Column = withExpr { - ArrayTransform(column.expr, createLambda(f)) - } + def transform(column: Column, 
f: Column => Column): Column = + Column.fn("transform", column, createLambda(f)) /** * Returns an array of elements after applying a transformation to each element @@ -6287,9 +5997,8 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def transform(column: Column, f: (Column, Column) => Column): Column = withExpr { - ArrayTransform(column.expr, createLambda(f)) - } + def transform(column: Column, f: (Column, Column) => Column): Column = + Column.fn("transform", column, createLambda(f)) /** * Returns whether a predicate holds for one or more elements in the array. @@ -6303,9 +6012,8 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def exists(column: Column, f: Column => Column): Column = withExpr { - ArrayExists(column.expr, createLambda(f)) - } + def exists(column: Column, f: Column => Column): Column = + Column.fn("exists", column, createLambda(f)) /** * Returns whether a predicate holds for every element in the array. @@ -6319,9 +6027,8 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def forall(column: Column, f: Column => Column): Column = withExpr { - ArrayForAll(column.expr, createLambda(f)) - } + def forall(column: Column, f: Column => Column): Column = + Column.fn("forall", column, createLambda(f)) /** * Returns an array of elements for which a predicate holds in a given array. @@ -6335,9 +6042,8 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def filter(column: Column, f: Column => Column): Column = withExpr { - ArrayFilter(column.expr, createLambda(f)) - } + def filter(column: Column, f: Column => Column): Column = + Column.fn("filter", column, createLambda(f)) /** * Returns an array of elements for which a predicate holds in a given array. 
@@ -6352,9 +6058,8 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def filter(column: Column, f: (Column, Column) => Column): Column = withExpr { - ArrayFilter(column.expr, createLambda(f)) - } + def filter(column: Column, f: (Column, Column) => Column): Column = + Column.fn("filter", column, createLambda(f)) /** * Applies a binary operator to an initial state and all elements in the array, @@ -6378,14 +6083,8 @@ object functions { expr: Column, initialValue: Column, merge: (Column, Column) => Column, - finish: Column => Column): Column = withExpr { - ArrayAggregate( - expr.expr, - initialValue.expr, - createLambda(merge), - createLambda(finish) - ) - } + finish: Column => Column): Column = + Column.fn("aggregate", expr, initialValue, createLambda(merge), createLambda(finish)) /** * Applies a binary operator to an initial state and all elements in the array, @@ -6426,7 +6125,8 @@ object functions { expr: Column, initialValue: Column, merge: (Column, Column) => Column, - finish: Column => Column): Column = aggregate(expr, initialValue, merge, finish) + finish: Column => Column): Column = + Column.fn("reduce", expr, initialValue, createLambda(merge), createLambda(finish)) /** * Applies a binary operator to an initial state and all elements in the array, @@ -6443,7 +6143,7 @@ object functions { * @since 3.5.0 */ def reduce(expr: Column, initialValue: Column, merge: (Column, Column) => Column): Column = - aggregate(expr, initialValue, merge, c => c) + reduce(expr, initialValue, merge, c => c) /** * Merge two given arrays, element-wise, into a single array using a function. 
@@ -6460,9 +6160,8 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def zip_with(left: Column, right: Column, f: (Column, Column) => Column): Column = withExpr { - ZipWith(left.expr, right.expr, createLambda(f)) - } + def zip_with(left: Column, right: Column, f: (Column, Column) => Column): Column = + Column.fn("zip_with", left, right, createLambda(f)) /** * Applies a function to every key-value pair in a map and returns @@ -6477,9 +6176,8 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def transform_keys(expr: Column, f: (Column, Column) => Column): Column = withExpr { - TransformKeys(expr.expr, createLambda(f)) - } + def transform_keys(expr: Column, f: (Column, Column) => Column): Column = + Column.fn("transform_keys", expr, createLambda(f)) /** * Applies a function to every key-value pair in a map and returns @@ -6495,9 +6193,8 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def transform_values(expr: Column, f: (Column, Column) => Column): Column = withExpr { - TransformValues(expr.expr, createLambda(f)) - } + def transform_values(expr: Column, f: (Column, Column) => Column): Column = + Column.fn("transform_values", expr, createLambda(f)) /** * Returns a map whose key-value pairs satisfy a predicate. @@ -6511,9 +6208,8 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def map_filter(expr: Column, f: (Column, Column) => Column): Column = withExpr { - MapFilter(expr.expr, createLambda(f)) - } + def map_filter(expr: Column, f: (Column, Column) => Column): Column = + Column.fn("map_filter", expr, createLambda(f)) /** * Merge two given maps, key-wise into a single map using a function. 
@@ -6528,12 +6224,8 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def map_zip_with( - left: Column, - right: Column, - f: (Column, Column, Column) => Column): Column = withExpr { - MapZipWith(left.expr, right.expr, createLambda(f)) - } + def map_zip_with(left: Column, right: Column, f: (Column, Column, Column) => Column): Column = + Column.fn("map_zip_with", left, right, createLambda(f)) /** * Creates a new row for each element in the given array or map column. @@ -6543,7 +6235,7 @@ object functions { * @group collection_funcs * @since 1.3.0 */ - def explode(e: Column): Column = withExpr { Explode(e.expr) } + def explode(e: Column): Column = Column.fn("explode", e) /** * Creates a new row for each element in the given array or map column. @@ -6554,7 +6246,7 @@ object functions { * @group collection_funcs * @since 2.2.0 */ - def explode_outer(e: Column): Column = withExpr { GeneratorOuter(Explode(e.expr)) } + def explode_outer(e: Column): Column = Column.fn("explode_outer", e) /** * Creates a new row for each element with position in the given array or map column. @@ -6564,7 +6256,7 @@ object functions { * @group collection_funcs * @since 2.1.0 */ - def posexplode(e: Column): Column = withExpr { PosExplode(e.expr) } + def posexplode(e: Column): Column = Column.fn("posexplode", e) /** * Creates a new row for each element with position in the given array or map column. @@ -6575,7 +6267,7 @@ object functions { * @group collection_funcs * @since 2.2.0 */ - def posexplode_outer(e: Column): Column = withExpr { GeneratorOuter(PosExplode(e.expr)) } + def posexplode_outer(e: Column): Column = Column.fn("posexplode_outer", e) /** * Creates a new row for each element in the given array of structs. 
@@ -6583,7 +6275,7 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def inline(e: Column): Column = withExpr { Inline(e.expr) } + def inline(e: Column): Column = Column.fn("inline", e) /** * Creates a new row for each element in the given array of structs. @@ -6592,7 +6284,7 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def inline_outer(e: Column): Column = withExpr { GeneratorOuter(Inline(e.expr)) } + def inline_outer(e: Column): Column = Column.fn("inline_outer", e) /** * Extracts json object from a json string based on json path specified, and returns json string @@ -6601,9 +6293,8 @@ object functions { * @group collection_funcs * @since 1.6.0 */ - def get_json_object(e: Column, path: String): Column = withExpr { - GetJsonObject(e.expr, lit(path).expr) - } + def get_json_object(e: Column, path: String): Column = + Column.fn("get_json_object", e, lit(path)) /** * Creates a new row for a json column according to the given field names. @@ -6612,9 +6303,9 @@ object functions { * @since 1.6.0 */ @scala.annotation.varargs - def json_tuple(json: Column, fields: String*): Column = withExpr { + def json_tuple(json: Column, fields: String*): Column = { require(fields.nonEmpty, "at least 1 field name should be given.") - JsonTuple(json.expr +: fields.map(Literal.apply)) + Column.fn("json_tuple", json +: fields.map(lit): _*) } // scalastyle:off line.size.limit @@ -6657,8 +6348,8 @@ object functions { * @since 2.2.0 */ // scalastyle:on line.size.limit - def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr { - JsonToStructs(CharVarcharUtils.failIfHasCharVarchar(schema), options, e.expr) + def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = { + from_json(e, lit(schema.sql), options.iterator) } // scalastyle:off line.size.limit @@ -6817,7 +6508,33 @@ object functions { */ // scalastyle:on line.size.limit def from_json(e: Column, schema: Column, options: 
java.util.Map[String, String]): Column = { - withExpr(new JsonToStructs(e.expr, schema.expr, options.asScala.toMap)) + from_json(e, schema, options.asScala.iterator) + } + + /** + * Invoke a function with an options map as its last argument. If there are no options, its + * column is dropped. + */ + private def fnWithOptions( + name: String, + options: Iterator[(String, String)], + arguments: Column*): Column = { + val augmentedArguments = if (options.hasNext) { + val flattenedKeyValueIterator = options.flatMap { case (k, v) => + Iterator(lit(k), lit(v)) + } + arguments :+ map(flattenedKeyValueIterator.toSeq: _*) + } else { + arguments + } + Column.fn(name, augmentedArguments: _*) + } + + private def from_json( + e: Column, + schema: Column, + options: Iterator[(String, String)]): Column = { + fnWithOptions("from_json", options, e, schema) } /** @@ -6838,7 +6555,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def schema_of_json(json: Column): Column = withExpr(new SchemaOfJson(json.expr)) + def schema_of_json(json: Column): Column = Column.fn("schema_of_json", json) // scalastyle:off line.size.limit /** @@ -6857,9 +6574,8 @@ object functions { * @since 3.0.0 */ // scalastyle:on line.size.limit - def schema_of_json(json: Column, options: java.util.Map[String, String]): Column = { - withExpr(SchemaOfJson(json.expr, options.asScala.toMap)) - } + def schema_of_json(json: Column, options: java.util.Map[String, String]): Column = + fnWithOptions("schema_of_json", options.asScala.iterator, json) /** * Returns the number of elements in the outermost JSON array. `NULL` is returned in case of @@ -6868,9 +6584,7 @@ object functions { * @group collection_funcs * @since 3.5.0 */ - def json_array_length(jsonArray: Column): Column = withExpr { - LengthOfJsonArray(jsonArray.expr) - } + def json_array_length(jsonArray: Column): Column = Column.fn("json_array_length", jsonArray) /** * Returns all the keys of the outermost JSON object as an array. 
If a valid JSON object is @@ -6880,9 +6594,7 @@ object functions { * @group collection_funcs * @since 3.5.0 */ - def json_object_keys(json: Column): Column = withExpr { - JsonObjectKeys(json.expr) - } + def json_object_keys(json: Column): Column = Column.fn("json_object_keys", json) // scalastyle:off line.size.limit /** @@ -6904,9 +6616,8 @@ object functions { * @since 2.1.0 */ // scalastyle:on line.size.limit - def to_json(e: Column, options: Map[String, String]): Column = withExpr { - StructsToJson(options, e.expr) - } + def to_json(e: Column, options: Map[String, String]): Column = + fnWithOptions("to_json", options.iterator, e) // scalastyle:off line.size.limit /** @@ -6954,9 +6665,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def mask(input: Column): Column = withExpr { - new Mask(input.expr) - } + def mask(input: Column): Column = Column.fn("mask", input) /** * Masks the given string value. The function replaces upper-case characters with specific @@ -6971,9 +6680,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def mask(input: Column, upperChar: Column): Column = withExpr { - new Mask(input.expr, upperChar.expr) - } + def mask(input: Column, upperChar: Column): Column = + Column.fn("mask", input, upperChar) /** * Masks the given string value. The function replaces upper-case and lower-case characters with @@ -6990,9 +6698,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def mask(input: Column, upperChar: Column, lowerChar: Column): Column = withExpr { - new Mask(input.expr, upperChar.expr, lowerChar.expr) - } + def mask(input: Column, upperChar: Column, lowerChar: Column): Column = + Column.fn("mask", input, upperChar, lowerChar) /** * Masks the given string value. 
The function replaces upper-case, lower-case characters and @@ -7011,11 +6718,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def mask(input: Column, upperChar: Column, lowerChar: Column, digitChar: Column): Column = { - withExpr { - new Mask(input.expr, upperChar.expr, lowerChar.expr, digitChar.expr) - } - } + def mask(input: Column, upperChar: Column, lowerChar: Column, digitChar: Column): Column = + Column.fn("mask", input, upperChar, lowerChar, digitChar) /** * Masks the given string value. This can be useful for creating copies of tables with sensitive @@ -7036,15 +6740,12 @@ object functions { * @since 3.5.0 */ def mask( - input: Column, - upperChar: Column, - lowerChar: Column, - digitChar: Column, - otherChar: Column): Column = { - withExpr { - Mask(input.expr, upperChar.expr, lowerChar.expr, digitChar.expr, otherChar.expr) - } - } + input: Column, + upperChar: Column, + lowerChar: Column, + digitChar: Column, + otherChar: Column): Column = + Column.fn("mask", input, upperChar, lowerChar, digitChar, otherChar) /** * Returns length of array or map. @@ -7056,7 +6757,7 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def size(e: Column): Column = withExpr { Size(e.expr) } + def size(e: Column): Column = Column.fn("size", e) /** * Returns length of array or map. This is an alias of `size` function. @@ -7068,7 +6769,7 @@ object functions { * @group collection_funcs * @since 3.5.0 */ - def cardinality(e: Column): Column = call_function("cardinality", e) + def cardinality(e: Column): Column = Column.fn("cardinality", e) /** * Sorts the input array for the given column in ascending order, @@ -7090,7 +6791,7 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + def sort_array(e: Column, asc: Boolean): Column = Column.fn("sort_array", e, lit(asc)) /** * Returns the minimum value in the array. 
NaN is greater than any non-NaN elements for @@ -7099,7 +6800,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def array_min(e: Column): Column = withExpr { ArrayMin(e.expr) } + def array_min(e: Column): Column = Column.fn("array_min", e) /** * Returns the maximum value in the array. NaN is greater than any non-NaN elements for @@ -7108,7 +6809,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + def array_max(e: Column): Column = Column.fn("array_max", e) /** * Returns the total number of elements in the array. The function returns null for null input. @@ -7116,7 +6817,7 @@ object functions { * @group collection_funcs * @since 3.5.0 */ - def array_size(e: Column): Column = withExpr { ArraySize(e.expr) } + def array_size(e: Column): Column = Column.fn("array_size", e) /** * Aggregate function: returns a list of objects with duplicates. @@ -7126,7 +6827,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def array_agg(e: Column): Column = call_function("array_agg", e) + def array_agg(e: Column): Column = Column.fn("array_agg", e) /** * Returns a random permutation of the given array. @@ -7143,7 +6844,7 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def reverse(e: Column): Column = withExpr { Reverse(e.expr) } + def reverse(e: Column): Column = Column.fn("reverse", e) /** * Creates a single array from an array of arrays. If a structure of nested arrays is deeper than @@ -7151,7 +6852,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + def flatten(e: Column): Column = Column.fn("flatten", e) /** * Generate a sequence of integers from start to stop, incrementing by step. 
@@ -7159,9 +6860,8 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def sequence(start: Column, stop: Column, step: Column): Column = withExpr { - new Sequence(start.expr, stop.expr, step.expr) - } + def sequence(start: Column, stop: Column, step: Column): Column = + Column.fn("sequence", start, stop, step) /** * Generate a sequence of integers from start to stop, @@ -7170,9 +6870,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def sequence(start: Column, stop: Column): Column = withExpr { - new Sequence(start.expr, stop.expr) - } + def sequence(start: Column, stop: Column): Column = Column.fn("sequence", start, stop) /** * Creates an array containing the left argument repeated the number of times given by the @@ -7181,9 +6879,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def array_repeat(left: Column, right: Column): Column = withExpr { - ArrayRepeat(left.expr, right.expr) - } + def array_repeat(left: Column, right: Column): Column = Column.fn("array_repeat", left, right) /** * Creates an array containing the left argument repeated the number of times given by the @@ -7199,37 +6895,36 @@ object functions { * @group collection_funcs * @since 3.3.0 */ - def map_contains_key(column: Column, key: Any): Column = withExpr { - ArrayContains(MapKeys(column.expr), lit(key).expr) - } + def map_contains_key(column: Column, key: Any): Column = + Column.fn("map_contains_key", column, lit(key)) /** * Returns an unordered array containing the keys of the map. * @group collection_funcs * @since 2.3.0 */ - def map_keys(e: Column): Column = withExpr { MapKeys(e.expr) } + def map_keys(e: Column): Column = Column.fn("map_keys", e) /** * Returns an unordered array containing the values of the map. 
* @group collection_funcs * @since 2.3.0 */ - def map_values(e: Column): Column = withExpr { MapValues(e.expr) } + def map_values(e: Column): Column = Column.fn("map_values", e) /** * Returns an unordered array of all entries in the given map. * @group collection_funcs * @since 3.0.0 */ - def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + def map_entries(e: Column): Column = Column.fn("map_entries", e) /** * Returns a map created from the given array of entries. * @group collection_funcs * @since 2.4.0 */ - def map_from_entries(e: Column): Column = withExpr { MapFromEntries(e.expr) } + def map_from_entries(e: Column): Column = Column.fn("map_from_entries", e) /** * Returns a merged array of structs in which the N-th struct contains all N-th values of input @@ -7238,7 +6933,7 @@ object functions { * @since 2.4.0 */ @scala.annotation.varargs - def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) } + def arrays_zip(e: Column*): Column = Column.fn("arrays_zip", e: _*) /** * Returns the union of all the given maps. 
@@ -7246,7 +6941,7 @@ object functions { * @since 2.4.0 */ @scala.annotation.varargs - def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) } + def map_concat(cols: Column*): Column = Column.fn("map_concat", cols: _*) // scalastyle:off line.size.limit /** @@ -7266,10 +6961,8 @@ object functions { * @since 3.0.0 */ // scalastyle:on line.size.limit - def from_csv(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr { - val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] - CsvToStructs(replaced, options, e.expr) - } + def from_csv(e: Column, schema: StructType, options: Map[String, String]): Column = + from_csv(e, lit(schema.toDDL), options.iterator) // scalastyle:off line.size.limit /** @@ -7289,9 +6982,11 @@ object functions { * @since 3.0.0 */ // scalastyle:on line.size.limit - def from_csv(e: Column, schema: Column, options: java.util.Map[String, String]): Column = { - withExpr(new CsvToStructs(e.expr, schema.expr, options.asScala.toMap)) - } + def from_csv(e: Column, schema: Column, options: java.util.Map[String, String]): Column = + from_csv(e, schema, options.asScala.iterator) + + private def from_csv(e: Column, schema: Column, options: Iterator[(String, String)]): Column = + fnWithOptions("from_csv", options, e, schema) /** * Parses a CSV string and infers its schema in DDL format. 
@@ -7311,7 +7006,7 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def schema_of_csv(csv: Column): Column = withExpr(new SchemaOfCsv(csv.expr)) + def schema_of_csv(csv: Column): Column = schema_of_csv(csv, Collections.emptyMap()) // scalastyle:off line.size.limit /** @@ -7330,9 +7025,8 @@ object functions { * @since 3.0.0 */ // scalastyle:on line.size.limit - def schema_of_csv(csv: Column, options: java.util.Map[String, String]): Column = { - withExpr(SchemaOfCsv(csv.expr, options.asScala.toMap)) - } + def schema_of_csv(csv: Column, options: java.util.Map[String, String]): Column = + fnWithOptions("schema_of_csv", options.asScala.iterator, csv) // scalastyle:off line.size.limit /** @@ -7351,9 +7045,8 @@ object functions { * @since 3.0.0 */ // scalastyle:on line.size.limit - def to_csv(e: Column, options: java.util.Map[String, String]): Column = withExpr { - StructsToCsv(options.asScala.toMap, e.expr) - } + def to_csv(e: Column, options: java.util.Map[String, String]): Column = + fnWithOptions("to_csv", options.asScala.iterator, e) /** * Converts a column containing a `StructType` into a CSV string with the specified schema. @@ -7367,7 +7060,6 @@ object functions { def to_csv(e: Column): Column = to_csv(e, Map.empty[String, String].asJava) // scalastyle:off line.size.limit - /** * Parses a column containing a XML string into the data type corresponding to the specified schema. * Returns `null`, in the case of an unparseable string. 
@@ -7384,10 +7076,8 @@ object functions { * @since 4.0.0 */ // scalastyle:on line.size.limit - def from_xml(e: Column, schema: StructType, options: java.util.Map[String, String]): Column = { - withExpr(XmlToStructs(CharVarcharUtils.failIfHasCharVarchar(schema), - options.asScala.toMap, e.expr)) - } + def from_xml(e: Column, schema: StructType, options: java.util.Map[String, String]): Column = + from_xml(e, lit(CharVarcharUtils.failIfHasCharVarchar(schema).sql), options.asScala.toIterator) // scalastyle:off line.size.limit /** @@ -7408,10 +7098,8 @@ object functions { */ // scalastyle:on line.size.limit def from_xml(e: Column, schema: String, options: java.util.Map[String, String]): Column = { - val dataType = parseTypeWithFallback( - schema, - DataType.fromJson, - fallbackParser = DataType.fromDDL) + val dataType = + parseTypeWithFallback(schema, DataType.fromJson, fallbackParser = DataType.fromDDL) val structType = dataType match { case t: StructType => t case _ => throw DataTypeErrors.failedParsingStructTypeError(schema) @@ -7420,7 +7108,6 @@ object functions { } // scalastyle:off line.size.limit - /** * (Java-specific) Parses a column containing a XML string into a `StructType` * with the specified schema. Returns `null`, in the case of an unparseable string. 
@@ -7432,7 +7119,7 @@ object functions { */ // scalastyle:on line.size.limit def from_xml(e: Column, schema: Column): Column = { - from_xml(e, schema, Map.empty[String, String].asJava) + from_xml(e, schema, Iterator.empty) } // scalastyle:off line.size.limit @@ -7452,9 +7139,8 @@ object functions { * @since 4.0.0 */ // scalastyle:on line.size.limit - def from_xml(e: Column, schema: Column, options: java.util.Map[String, String]): Column = { - withExpr(new XmlToStructs(e.expr, schema.expr, options.asScala.toMap)) - } + def from_xml(e: Column, schema: Column, options: java.util.Map[String, String]): Column = + from_xml(e, schema, options.asScala.iterator) /** * Parses a column containing a XML string into the data type @@ -7470,6 +7156,10 @@ object functions { def from_xml(e: Column, schema: StructType): Column = from_xml(e, schema, Map.empty[String, String].asJava) + private def from_xml(e: Column, schema: Column, options: Iterator[(String, String)]): Column = { + fnWithOptions("from_xml", options, e, schema) + } + /** * Parses a XML string and infers its schema in DDL format. * @@ -7539,9 +7229,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath(x: Column, p: Column): Column = withExpr { - XPathList(x.expr, p.expr) - } + def xpath(xml: Column, path: Column): Column = + Column.fn("xpath", xml, path) /** * Returns true if the XPath expression evaluates to true, or if a matching node is found. 
@@ -7549,9 +7238,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_boolean(x: Column, p: Column): Column = withExpr { - XPathBoolean(x.expr, p.expr) - } + def xpath_boolean(xml: Column, path: Column): Column = + Column.fn("xpath_boolean", xml, path) /** * Returns a double value, the value zero if no match is found, @@ -7560,9 +7248,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_double(x: Column, p: Column): Column = withExpr { - XPathDouble(x.expr, p.expr) - } + def xpath_double(xml: Column, path: Column): Column = + Column.fn("xpath_double", xml, path) /** * Returns a double value, the value zero if no match is found, @@ -7571,8 +7258,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_number(x: Column, p: Column): Column = - call_function("xpath_number", x, p) + def xpath_number(xml: Column, path: Column): Column = + Column.fn("xpath_number", xml, path) /** * Returns a float value, the value zero if no match is found, @@ -7581,9 +7268,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_float(x: Column, p: Column): Column = withExpr { - XPathFloat(x.expr, p.expr) - } + def xpath_float(xml: Column, path: Column): Column = + Column.fn("xpath_float", xml, path) /** * Returns an integer value, or the value zero if no match is found, @@ -7592,9 +7278,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_int(x: Column, p: Column): Column = withExpr { - XPathInt(x.expr, p.expr) - } + def xpath_int(xml: Column, path: Column): Column = + Column.fn("xpath_int", xml, path) /** * Returns a long integer value, or the value zero if no match is found, @@ -7603,9 +7288,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_long(x: Column, p: Column): Column = withExpr { - XPathLong(x.expr, p.expr) - } + def xpath_long(xml: Column, path: Column): Column = + Column.fn("xpath_long", xml, path) /** * Returns a short integer value, or the value 
zero if no match is found, @@ -7614,9 +7298,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_short(x: Column, p: Column): Column = withExpr { - XPathShort(x.expr, p.expr) - } + def xpath_short(xml: Column, path: Column): Column = + Column.fn("xpath_short", xml, path) /** * Returns the text contents of the first xml node that matches the XPath expression. @@ -7624,11 +7307,10 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_string(x: Column, p: Column): Column = withExpr { - XPathString(x.expr, p.expr) - } + def xpath_string(xml: Column, path: Column): Column = + Column.fn("xpath_string", xml, path) - /** + /** * A transform for timestamps to partition data into hours. * * @group partition_transforms @@ -7647,9 +7329,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def convert_timezone(sourceTz: Column, targetTz: Column, sourceTs: Column): Column = withExpr { - ConvertTimezone(sourceTz.expr, targetTz.expr, sourceTs.expr) - } + def convert_timezone(sourceTz: Column, targetTz: Column, sourceTs: Column): Column = + Column.fn("convert_timezone", sourceTz, targetTz, sourceTs) /** * Converts the timestamp without time zone `sourceTs` @@ -7660,9 +7341,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def convert_timezone(targetTz: Column, sourceTs: Column): Column = withExpr { - new ConvertTimezone(targetTz.expr, sourceTs.expr) - } + def convert_timezone(targetTz: Column, sourceTs: Column): Column = + Column.fn("convert_timezone", targetTz, sourceTs) /** * Make DayTimeIntervalType duration from days, hours, mins and secs. 
@@ -7670,9 +7350,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def make_dt_interval(days: Column, hours: Column, mins: Column, secs: Column): Column = withExpr { - MakeDTInterval(days.expr, hours.expr, mins.expr, secs.expr) - } + def make_dt_interval(days: Column, hours: Column, mins: Column, secs: Column): Column = + Column.fn("make_dt_interval", days, hours, mins, secs) /** * Make DayTimeIntervalType duration from days, hours and mins. @@ -7680,9 +7359,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def make_dt_interval(days: Column, hours: Column, mins: Column): Column = withExpr { - new MakeDTInterval(days.expr, hours.expr, mins.expr) - } + def make_dt_interval(days: Column, hours: Column, mins: Column): Column = + Column.fn("make_dt_interval", days, hours, mins) /** * Make DayTimeIntervalType duration from days and hours. @@ -7690,9 +7368,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def make_dt_interval(days: Column, hours: Column): Column = withExpr { - new MakeDTInterval(days.expr, hours.expr) - } + def make_dt_interval(days: Column, hours: Column): Column = + Column.fn("make_dt_interval", days, hours) /** * Make DayTimeIntervalType duration from days. @@ -7700,9 +7377,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def make_dt_interval(days: Column): Column = withExpr { - new MakeDTInterval(days.expr) - } + def make_dt_interval(days: Column): Column = + Column.fn("make_dt_interval", days) /** * Make DayTimeIntervalType duration. @@ -7710,9 +7386,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def make_dt_interval(): Column = withExpr { - new MakeDTInterval() - } + def make_dt_interval(): Column = + Column.fn("make_dt_interval") /** * Make interval from years, months, weeks, days, hours, mins and secs. 
@@ -7727,9 +7402,8 @@ object functions { days: Column, hours: Column, mins: Column, - secs: Column): Column = withExpr { - MakeInterval(years.expr, months.expr, weeks.expr, days.expr, hours.expr, mins.expr, secs.expr) - } + secs: Column): Column = + Column.fn("make_interval", years, months, weeks, days, hours, mins, secs) /** * Make interval from years, months, weeks, days, hours and mins. @@ -7743,9 +7417,8 @@ object functions { weeks: Column, days: Column, hours: Column, - mins: Column): Column = withExpr { - new MakeInterval(years.expr, months.expr, weeks.expr, days.expr, hours.expr, mins.expr) - } + mins: Column): Column = + Column.fn("make_interval", years, months, weeks, days, hours, mins) /** * Make interval from years, months, weeks, days and hours. @@ -7758,9 +7431,8 @@ object functions { months: Column, weeks: Column, days: Column, - hours: Column): Column = withExpr { - new MakeInterval(years.expr, months.expr, weeks.expr, days.expr, hours.expr) - } + hours: Column): Column = + Column.fn("make_interval", years, months, weeks, days, hours) /** * Make interval from years, months, weeks and days. @@ -7768,13 +7440,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def make_interval( - years: Column, - months: Column, - weeks: Column, - days: Column): Column = withExpr { - new MakeInterval(years.expr, months.expr, weeks.expr, days.expr) - } + def make_interval(years: Column, months: Column, weeks: Column, days: Column): Column = + Column.fn("make_interval", years, months, weeks, days) /** * Make interval from years, months and weeks. @@ -7782,9 +7449,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def make_interval(years: Column, months: Column, weeks: Column): Column = withExpr { - new MakeInterval(years.expr, months.expr, weeks.expr) - } + def make_interval(years: Column, months: Column, weeks: Column): Column = + Column.fn("make_interval", years, months, weeks) /** * Make interval from years and months. 
@@ -7792,9 +7458,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def make_interval(years: Column, months: Column): Column = withExpr { - new MakeInterval(years.expr, months.expr) - } + def make_interval(years: Column, months: Column): Column = + Column.fn("make_interval", years, months) /** * Make interval from years. @@ -7802,9 +7467,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def make_interval(years: Column): Column = withExpr { - new MakeInterval(years.expr) - } + def make_interval(years: Column): Column = + Column.fn("make_interval", years) /** * Make interval. @@ -7812,9 +7476,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def make_interval(): Column = withExpr { - new MakeInterval() - } + def make_interval(): Column = + Column.fn("make_interval") /** * Create timestamp from years, months, days, hours, mins, secs and timezone fields. The result @@ -7832,10 +7495,8 @@ object functions { hours: Column, mins: Column, secs: Column, - timezone: Column): Column = withExpr { - MakeTimestamp(years.expr, months.expr, days.expr, hours.expr, - mins.expr, secs.expr, Some(timezone.expr)) - } + timezone: Column): Column = + Column.fn("make_timestamp", years, months, days, hours, mins, secs, timezone) /** * Create timestamp from years, months, days, hours, mins and secs fields. 
The result data type @@ -7852,9 +7513,8 @@ object functions { days: Column, hours: Column, mins: Column, - secs: Column): Column = withExpr { - MakeTimestamp(years.expr, months.expr, days.expr, hours.expr, mins.expr, secs.expr) - } + secs: Column): Column = + Column.fn("make_timestamp", years, months, days, hours, mins, secs) /** * Create the current timestamp with local time zone from years, months, days, hours, mins, secs @@ -7872,8 +7532,7 @@ object functions { mins: Column, secs: Column, timezone: Column): Column = - call_function("make_timestamp_ltz", - years, months, days, hours, mins, secs, timezone) + Column.fn("make_timestamp_ltz", years, months, days, hours, mins, secs, timezone) /** * Create the current timestamp with local time zone from years, months, days, hours, mins and @@ -7890,8 +7549,7 @@ object functions { hours: Column, mins: Column, secs: Column): Column = - call_function("make_timestamp_ltz", - years, months, days, hours, mins, secs) + Column.fn("make_timestamp_ltz", years, months, days, hours, mins, secs) /** * Create local date-time from years, months, days, hours, mins, secs fields. If the @@ -7908,8 +7566,7 @@ object functions { hours: Column, mins: Column, secs: Column): Column = - call_function("make_timestamp_ntz", - years, months, days, hours, mins, secs) + Column.fn("make_timestamp_ntz", years, months, days, hours, mins, secs) /** * Make year-month interval from years, months. @@ -7917,9 +7574,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def make_ym_interval(years: Column, months: Column): Column = withExpr { - MakeYMInterval(years.expr, months.expr) - } + def make_ym_interval(years: Column, months: Column): Column = + Column.fn("make_ym_interval", years, months) /** * Make year-month interval from years. 
@@ -7927,9 +7583,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def make_ym_interval(years: Column): Column = withExpr { - new MakeYMInterval(years.expr) - } + def make_ym_interval(years: Column): Column = Column.fn("make_ym_interval", years) /** * Make year-month interval. @@ -7937,9 +7591,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def make_ym_interval(): Column = withExpr { - new MakeYMInterval() - } + def make_ym_interval(): Column = Column.fn("make_ym_interval") /** * A transform for any type that partitions by a hash of the input column. @@ -7976,8 +7628,7 @@ object functions { * @group predicates_funcs * @since 3.5.0 */ - def ifnull(col1: Column, col2: Column): Column = - call_function("ifnull", col1, col2) + def ifnull(col1: Column, col2: Column): Column = Column.fn("ifnull", col1, col2) /** * Returns true if `col` is not null, or false otherwise. @@ -7985,9 +7636,7 @@ object functions { * @group predicates_funcs * @since 3.5.0 */ - def isnotnull(col: Column): Column = withExpr { - IsNotNull(col.expr) - } + def isnotnull(col: Column): Column = Column.fn("isnotnull", col) /** * Returns same result as the EQUAL(=) operator for non-null operands, @@ -7996,9 +7645,7 @@ object functions { * @group predicates_funcs * @since 3.5.0 */ - def equal_null(col1: Column, col2: Column): Column = withExpr { - new EqualNull(col1.expr, col2.expr) - } + def equal_null(col1: Column, col2: Column): Column = Column.fn("equal_null", col1, col2) /** * Returns null if `col1` equals to `col2`, or `col1` otherwise. @@ -8006,9 +7653,7 @@ object functions { * @group predicates_funcs * @since 3.5.0 */ - def nullif(col1: Column, col2: Column): Column = withExpr { - new NullIf(col1.expr, col2.expr) - } + def nullif(col1: Column, col2: Column): Column = Column.fn("nullif", col1, col2) /** * Returns `col2` if `col1` is null, or `col1` otherwise. 
@@ -8016,9 +7661,7 @@ object functions { * @group predicates_funcs * @since 3.5.0 */ - def nvl(col1: Column, col2: Column): Column = withExpr { - new Nvl(col1.expr, col2.expr) - } + def nvl(col1: Column, col2: Column): Column = Column.fn("nvl", col1, col2) /** * Returns `col2` if `col1` is not null, or `col3` otherwise. @@ -8026,9 +7669,7 @@ object functions { * @group predicates_funcs * @since 3.5.0 */ - def nvl2(col1: Column, col2: Column, col3: Column): Column = withExpr { - new Nvl2(col1.expr, col2.expr, col3.expr) - } + def nvl2(col1: Column, col2: Column, col3: Column): Column = Column.fn("nvl2", col1, col2, col3) // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 9e8d77c53f366..0baded3323c68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -2566,7 +2566,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { intDf.select(assert_true($"a" > $"b")).collect() } assert(e3.getCause.isInstanceOf[RuntimeException]) - assert(e3.getCause.getMessage == "'('a > 'b)' is not true!") + assert(e3.getCause.getMessage.matches("'\\(a#\\d+ > b#\\d+\\)' is not true!")) } test("raise_error") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d78771a8f19bc..80862eec41e0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -24,6 +24,7 @@ import scala.util.Random import org.scalatest.matchers.must.Matchers.the import org.apache.spark.{SparkException, SparkThrowable} +import 
org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} @@ -849,7 +850,7 @@ class DataFrameAggregateSuite extends QueryTest assert(testData.groupBy(col("key")).toString.contains( "[grouping expressions: [key], value: [key: int, value: string], type: GroupBy]")) assert(testData.groupBy(current_date()).toString.contains( - "grouping expressions: [current_date(None)], value: [key: int, value: string], " + + "grouping expressions: ['current_date()], value: [key: int, value: string], " + "type: GroupBy]")) } @@ -1410,20 +1411,21 @@ class DataFrameAggregateSuite extends QueryTest Duration.ofSeconds(14)) :: Nil) assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) + val metadata = new MetadataBuilder().putString(AUTO_GENERATED_ALIAS, "true").build() assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false), - StructField("sum(year-month)", YearMonthIntervalType()), - StructField("sum(year)", YearMonthIntervalType(YEAR)), - StructField("sum(month)", YearMonthIntervalType(MONTH)), - StructField("sum(day-second)", DayTimeIntervalType()), - StructField("sum(day-minute)", DayTimeIntervalType(DAY, MINUTE)), - StructField("sum(day-hour)", DayTimeIntervalType(DAY, HOUR)), - StructField("sum(day)", DayTimeIntervalType(DAY)), - StructField("sum(hour-second)", DayTimeIntervalType(HOUR, SECOND)), - StructField("sum(hour-minute)", DayTimeIntervalType(HOUR, MINUTE)), - StructField("sum(hour)", DayTimeIntervalType(HOUR)), - StructField("sum(minute-second)", DayTimeIntervalType(MINUTE, SECOND)), - StructField("sum(minute)", DayTimeIntervalType(MINUTE)), - StructField("sum(second)", DayTimeIntervalType(SECOND))))) + StructField("sum(year-month)", YearMonthIntervalType(), metadata = 
metadata), + StructField("sum(year)", YearMonthIntervalType(YEAR), metadata = metadata), + StructField("sum(month)", YearMonthIntervalType(MONTH), metadata = metadata), + StructField("sum(day-second)", DayTimeIntervalType(), metadata = metadata), + StructField("sum(day-minute)", DayTimeIntervalType(DAY, MINUTE), metadata = metadata), + StructField("sum(day-hour)", DayTimeIntervalType(DAY, HOUR), metadata = metadata), + StructField("sum(day)", DayTimeIntervalType(DAY), metadata = metadata), + StructField("sum(hour-second)", DayTimeIntervalType(HOUR, SECOND), metadata = metadata), + StructField("sum(hour-minute)", DayTimeIntervalType(HOUR, MINUTE), metadata = metadata), + StructField("sum(hour)", DayTimeIntervalType(HOUR), metadata = metadata), + StructField("sum(minute-second)", DayTimeIntervalType(MINUTE, SECOND), metadata = metadata), + StructField("sum(minute)", DayTimeIntervalType(MINUTE), metadata = metadata), + StructField("sum(second)", DayTimeIntervalType(SECOND), metadata = metadata)))) val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)), (Period.ofMonths(10), Duration.ofDays(10))) @@ -1545,21 +1547,22 @@ class DataFrameAggregateSuite extends QueryTest Duration.ofMinutes(4).plusSeconds(20), Duration.ofSeconds(7)) :: Nil) assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) + val metadata = new MetadataBuilder().putString(AUTO_GENERATED_ALIAS, "true").build() assert(avgDF2.schema == StructType(Seq( StructField("class", IntegerType, false), - StructField("avg(year-month)", YearMonthIntervalType()), - StructField("avg(year)", YearMonthIntervalType()), - StructField("avg(month)", YearMonthIntervalType()), - StructField("avg(day-second)", DayTimeIntervalType()), - StructField("avg(day-minute)", DayTimeIntervalType()), - StructField("avg(day-hour)", DayTimeIntervalType()), - StructField("avg(day)", DayTimeIntervalType()), - StructField("avg(hour-second)", DayTimeIntervalType()), - 
StructField("avg(hour-minute)", DayTimeIntervalType()), - StructField("avg(hour)", DayTimeIntervalType()), - StructField("avg(minute-second)", DayTimeIntervalType()), - StructField("avg(minute)", DayTimeIntervalType()), - StructField("avg(second)", DayTimeIntervalType())))) + StructField("avg(year-month)", YearMonthIntervalType(), metadata = metadata), + StructField("avg(year)", YearMonthIntervalType(), metadata = metadata), + StructField("avg(month)", YearMonthIntervalType(), metadata = metadata), + StructField("avg(day-second)", DayTimeIntervalType(), metadata = metadata), + StructField("avg(day-minute)", DayTimeIntervalType(), metadata = metadata), + StructField("avg(day-hour)", DayTimeIntervalType(), metadata = metadata), + StructField("avg(day)", DayTimeIntervalType(), metadata = metadata), + StructField("avg(hour-second)", DayTimeIntervalType(), metadata = metadata), + StructField("avg(hour-minute)", DayTimeIntervalType(), metadata = metadata), + StructField("avg(hour)", DayTimeIntervalType(), metadata = metadata), + StructField("avg(minute-second)", DayTimeIntervalType(), metadata = metadata), + StructField("avg(minute)", DayTimeIntervalType(), metadata = metadata), + StructField("avg(second)", DayTimeIntervalType(), metadata = metadata)))) val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)), (Period.ofMonths(10), Duration.ofDays(10))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7044b7c90c226..4020688bc3194 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -764,7 +764,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { }, errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( - "sqlExpr" -> """"array_sort\(a, lambdafunction\(\(x_\d+ - y_\d+\), 
x_\d+, y_\d+\)\)"""", + "sqlExpr" -> """"array_sort\(a, lambdafunction\(`-`\(x_\d+, y_\d+\), x_\d+, y_\d+\)\)"""", "paramIndex" -> "1", "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"a\"", @@ -970,7 +970,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { val qualifiedDF = df.as("foo") // Fields are UnresolvedAttribute - val zippedDF1 = qualifiedDF.select(arrays_zip($"foo.val1", $"foo.val2") as "zipped") + val zippedDF1 = + qualifiedDF.select(Column(ArraysZip(Seq($"foo.val1".expr, $"foo.val2".expr))) as "zipped") val maybeAlias1 = zippedDF1.queryExecution.logical.expressions.head assert(maybeAlias1.isInstanceOf[Alias]) val maybeArraysZip1 = maybeAlias1.children.head @@ -984,7 +985,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { assert(fieldNames1.toSeq === Seq("val1", "val2")) // Fields are resolved NamedExpression - val zippedDF2 = df.select(arrays_zip(df("val1"), df("val2")) as "zipped") + val zippedDF2 = + df.select(Column(ArraysZip(Seq(df("val1").expr, df("val2").expr))) as "zipped") val maybeAlias2 = zippedDF2.queryExecution.logical.expressions.head assert(maybeAlias2.isInstanceOf[Alias]) val maybeArraysZip2 = maybeAlias2.children.head @@ -999,7 +1001,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { assert(fieldNames2.toSeq === Seq("val1", "val2")) // Fields are unresolved NamedExpression - val zippedDF3 = df.select(arrays_zip($"val1" as "val3", $"val2" as "val4") as "zipped") + val zippedDF3 = df.select( + Column(ArraysZip(Seq(($"val1" as "val3").expr, ($"val2" as "val4").expr))) as "zipped") val maybeAlias3 = zippedDF3.queryExecution.logical.expressions.head assert(maybeAlias3.isInstanceOf[Alias]) val maybeArraysZip3 = maybeAlias3.children.head @@ -1013,7 +1016,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { assert(fieldNames3.toSeq === Seq("val3", "val4")) // Fields are neither UnresolvedAttribute nor NamedExpression - val zippedDF4 = 
df.select(arrays_zip(array_sort($"val1"), array_sort($"val2")) as "zipped") + val zippedDF4 = df.select( + Column(ArraysZip(Seq(array_sort($"val1").expr, array_sort($"val2").expr))) as "zipped") val maybeAlias4 = zippedDF4.queryExecution.logical.expressions.head assert(maybeAlias4.isInstanceOf[Alias]) val maybeArraysZip4 = maybeAlias4.children.head @@ -3742,7 +3746,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", matchPVals = true, parameters = Map( - "sqlExpr" -> """"map_filter\(i, lambdafunction\(\(x_\d+ > y_\d+\), x_\d+, y_\d+\)\)"""", + "sqlExpr" -> """"map_filter\(i, lambdafunction\(`>`\(x_\d+, y_\d+\), x_\d+, y_\d+\)\)"""", "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", @@ -5227,7 +5231,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { matchPVals = true, parameters = Map( "sqlExpr" -> - """"transform_values\(x, lambdafunction\(\(x_\d+ \+ 1\), x_\d+, y_\d+\)\)"""", + """"transform_values\(x, lambdafunction\(`\+`\(x_\d+, 1\), x_\d+, y_\d+\)\)"""", "paramIndex" -> "1", "inputSql" -> "\"x\"", "inputType" -> "\"ARRAY\"", @@ -5950,7 +5954,7 @@ object DataFrameFunctionsSuite { case class CodegenFallbackExpr(child: Expression) extends UnaryExpression with CodegenFallback { override def nullable: Boolean = child.nullable override def dataType: DataType = child.dataType - override lazy val resolved = true + override lazy val resolved = child.resolved override def eval(input: InternalRow): Any = child.eval(input) override protected def withNewChildInternal(newChild: Expression): CodegenFallbackExpr = copy(child = newChild) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2eba9f1810982..805bb1ccc287d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -348,9 +348,19 @@ class DataFrameSuite extends QueryTest exception = intercept[AnalysisException] { df.select(explode($"*")) }, - errorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", - parameters = Map("elem" -> "'*'", "prettyName" -> "expression `explode`") + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"explode(csv)\"", + "paramIndex" -> "1", + "inputSql"-> "\"csv\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "(\"ARRAY\" or \"MAP\")") ) + + val df2 = Seq(Array("1", "2"), Array("4"), Array("7", "8", "9")).toDF("csv") + checkAnswer( + df2.select(explode($"*")), + Row("1") :: Row("2") :: Row("4") :: Row("7") :: Row("8") :: Row("9") :: Nil) } test("explode on output of array-valued function") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index bc7a68732a168..86804ceed4f88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -363,7 +363,7 @@ object IntegratedUDFTestUtils extends SQLHelper { * casted_col.cast(df.schema["col"].dataType) * }}} */ - case class TestPythonUDF(name: String) extends TestUDF { + case class TestPythonUDF(name: String, returnType: Option[DataType] = None) extends TestUDF { private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction( name = name, func = SimplePythonFunction( @@ -381,11 +381,14 @@ object IntegratedUDFTestUtils extends SQLHelper { override def builder(e: Seq[Expression]): Expression = { assert(e.length == 1, "Defined UDF only has one column") val expr = e.head - assert(expr.resolved, "column should be resolved to use the same type " + - "as input. 
Try df(name) or df.col(name)") + val rt = returnType.getOrElse { + assert(expr.resolved, "column should be resolved to use the same type " + + "as input. Try df(name) or df.col(name)") + expr.dataType + } val pythonUDF = new PythonUDFWithoutId( super.builder(Cast(expr, StringType) :: Nil).asInstanceOf[PythonUDF]) - Cast(pythonUDF, expr.dataType) + Cast(pythonUDF, rt) } } @@ -705,7 +708,9 @@ object IntegratedUDFTestUtils extends SQLHelper { * casted_col.cast(df.schema["col"].dataType) * }}} */ - case class TestScalarPandasUDF(name: String) extends TestUDF { + case class TestScalarPandasUDF( + name: String, + returnType: Option[DataType] = None) extends TestUDF { private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction( name = name, func = SimplePythonFunction( @@ -723,11 +728,14 @@ object IntegratedUDFTestUtils extends SQLHelper { override def builder(e: Seq[Expression]): Expression = { assert(e.length == 1, "Defined UDF only has one column") val expr = e.head - assert(expr.resolved, "column should be resolved to use the same type " + - "as input. Try df(name) or df.col(name)") + val rt = returnType.getOrElse { + assert(expr.resolved, "column should be resolved to use the same type " + + "as input. 
Try df(name) or df.col(name)") + expr.dataType + } val pythonUDF = new PythonUDFWithoutId( super.builder(Cast(expr, StringType) :: Nil).asInstanceOf[PythonUDF]) - Cast(pythonUDF, expr.dataType) + Cast(pythonUDF, rt) } } @@ -826,7 +834,9 @@ object IntegratedUDFTestUtils extends SQLHelper { * casted_col.cast(df.schema("col").dataType) * }}} */ - class TestInternalScalaUDF(name: String) extends SparkUserDefinedFunction( + class TestInternalScalaUDF( + name: String, + returnType: Option[DataType] = None) extends SparkUserDefinedFunction( (input: Any) => if (input == null) { null } else { @@ -839,9 +849,12 @@ object IntegratedUDFTestUtils extends SQLHelper { override def apply(exprs: Column*): Column = { assert(exprs.length == 1, "Defined UDF only has one column") val expr = exprs.head.expr - assert(expr.resolved, "column should be resolved to use the same type " + - "as input. Try df(name) or df.col(name)") - Column(Cast(createScalaUDF(Cast(expr, StringType) :: Nil), expr.dataType)) + val rt = returnType.getOrElse { + assert(expr.resolved, "column should be resolved to use the same type " + + "as input. 
Try df(name) or df.col(name)") + expr.dataType + } + Column(Cast(createScalaUDF(Cast(expr, StringType) :: Nil), rt)) } override def withName(name: String): TestInternalScalaUDF = { @@ -851,8 +864,8 @@ object IntegratedUDFTestUtils extends SQLHelper { } } - case class TestScalaUDF(name: String) extends TestUDF { - private[IntegratedUDFTestUtils] lazy val udf = new TestInternalScalaUDF(name) + case class TestScalaUDF(name: String, returnType: Option[DataType] = None) extends TestUDF { + private[IntegratedUDFTestUtils] lazy val udf = new TestInternalScalaUDF(name, returnType) def apply(exprs: Column*): Column = udf(exprs: _*) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 7f938deaaa645..3f665d637748b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.expressions.SparkUserDefinedFunction import org.apache.spark.sql.functions.{array, from_json, grouping, grouping_id, lit, struct, sum, udf} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{IntegerType, MapType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{BooleanType, IntegerType, MapType, StringType, StructField, StructType} import org.apache.spark.util.Utils case class StringLongClass(a: String, b: Long) @@ -169,7 +169,7 @@ class QueryCompilationErrorsSuite ).toDF("CustomerName", "CustomerID") val e = intercept[AnalysisException] { - val pythonTestUDF = TestPythonUDF(name = "python_udf") + val pythonTestUDF = TestPythonUDF(name = "python_udf", Some(BooleanType)) df1.join( df2, pythonTestUDF(df1("CustomerID") === df2("CustomerID")), "leftouter").collect() } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index fb10e90b6ccea..d133270e2956c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -222,7 +222,7 @@ class QueryExecutionErrorsSuite exception = e2, errorClass = "UNSUPPORTED_FEATURE.PIVOT_TYPE", parameters = Map("value" -> "[dotnet,Dummies]", - "type" -> "\"STRUCT\""), + "type" -> "unknown"), sqlState = "0A000") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index 588757931f0f9..d3538cf65a50a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -65,15 +65,15 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession { Row(6, null) )), new StructType().add("c", IntegerType).add("d", DoubleType)) - private lazy val singleConditionEQ = (left.col("a") === right.col("c")).expr + private lazy val singleConditionEQ = EqualTo(left.col("a").expr, right.col("c").expr) private lazy val composedConditionEQ = { - And((left.col("a") === right.col("c")).expr, + And(EqualTo(left.col("a").expr, right.col("c").expr), LessThan(left.col("b").expr, right.col("d").expr)) } private lazy val composedConditionNEQ = { - And((left.col("a") < right.col("c")).expr, + And(LessThan(left.col("a").expr, right.col("c").expr), LessThan(left.col("b").expr, right.col("d").expr)) } @@ -298,7 +298,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession { LeftAnti, left, rightUniqueKey, - Some((left.col("a") === rightUniqueKey.col("c") && left.col("b") < rightUniqueKey.col("d")) - 
.expr), + Some(And(EqualTo(left.col("a").expr, rightUniqueKey.col("c").expr), + LessThan(left.col("b").expr, rightUniqueKey.col("d").expr))), Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(null, null), Row(null, 5.0), Row(6, null))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index c496a5ae5d80e..c283d39425812 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{EqualNullSafe, EqualTo, Expression} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner @@ -227,7 +227,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { "inner join, one match per row", myUpperCaseData, myLowerCaseData, - () => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr, + () => EqualTo(myUpperCaseData.col("N").expr, myLowerCaseData.col("n").expr), Seq( (1, "A", 1, "a"), (2, "B", 2, "b"), @@ -243,7 +243,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { "inner join, multiple matches", left, right, - () => (left.col("a") === right.col("a")).expr, + () => EqualTo(left.col("a").expr, right.col("a").expr), Seq( (1, 1, 1, 1), (1, 1, 1, 2), @@ -260,7 +260,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { "inner join, no matches", left, right, - () => (left.col("a") === right.col("a")).expr, + () => EqualTo(left.col("a").expr, right.col("a").expr), Seq.empty ) } @@ -272,7 +272,7 @@ class InnerJoinSuite 
extends SparkPlanTest with SharedSparkSession { "inner join, null safe", left, right, - () => (left.col("b") <=> right.col("b")).expr, + () => EqualNullSafe(left.col("b").expr, right.col("b").expr), Seq( (1, 0, 1, 0), (2, null, 2, null) @@ -288,7 +288,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { "SPARK-15822 - test structs as keys", left, right, - () => (left.col("key") === right.col("key")).expr, + () => EqualTo(left.col("key").expr, right.col("key").expr), Seq( (Row(0, 0), "L0", Row(0, 0), "R0"), (Row(1, 1), "L1", Row(1, 1), "R1"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 4f78833abdb9f..962021604e717 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression, LessThan} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ @@ -60,7 +60,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession { )), new StructType().add("c", IntegerType).add("d", DoubleType)) private lazy val condition = { - And((left.col("a") === right.col("c")).expr, + And(EqualTo(left.col("a").expr, right.col("c").expr), LessThan(left.col("b").expr, right.col("d").expr)) } @@ -86,7 +86,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession { )), new StructType().add("c", IntegerType).add("d", DoubleType)) private lazy val uniqueCondition = { - And((uniqueLeft.col("a") === 
uniqueRight.col("c")).expr, + And(EqualTo(uniqueLeft.col("a").expr, uniqueRight.col("c").expr), LessThan(uniqueLeft.col("b").expr, uniqueRight.col("d").expr)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index d86faec1a7bbd..84c23b728dca0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -20,14 +20,15 @@ package org.apache.spark.sql.execution.python import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest} import org.apache.spark.sql.functions.count import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.LongType class PythonUDFSuite extends QueryTest with SharedSparkSession { import testImplicits._ import IntegratedUDFTestUtils._ - val scalaTestUDF = TestScalaUDF(name = "scalaUDF") - val pythonTestUDF = TestPythonUDF(name = "pyUDF") + val scalaTestUDF = TestScalaUDF(name = "scalaUDF", Some(LongType)) + val pythonTestUDF = TestPythonUDF(name = "pyUDF", Some(LongType)) lazy val base = Seq( (Some(1), Some(1)), (Some(1), Some(2)), (Some(2), Some(1)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala index 69b7154895341..26039b9185dbf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala @@ -17,23 +17,17 @@ package org.apache.spark.sql.streaming -import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.dsl.expressions._ +import 
org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinConditionSplitPredicates import org.apache.spark.sql.types._ class StreamingSymmetricHashJoinHelperSuite extends StreamTest { - import org.apache.spark.sql.functions._ - val leftAttributeA = AttributeReference("a", IntegerType)() val leftAttributeB = AttributeReference("b", IntegerType)() val rightAttributeC = AttributeReference("c", IntegerType)() val rightAttributeD = AttributeReference("d", IntegerType)() - val leftColA = new Column(leftAttributeA) - val leftColB = new Column(leftAttributeB) - val rightColC = new Column(rightAttributeC) - val rightColD = new Column(rightAttributeD) val left = new LocalTableScanExec(Seq(leftAttributeA, leftAttributeB), Seq()) val right = new LocalTableScanExec(Seq(rightAttributeC, rightAttributeD), Seq()) @@ -49,7 +43,7 @@ class StreamingSymmetricHashJoinHelperSuite extends StreamTest { test("only literals") { // Literal-only conjuncts end up on the left side because that's the first bucket they fit in. // There's no semantic reason they couldn't be in any bucket. 
- val predicate = (lit(1) < lit(5) && lit(6) < lit(7) && lit(0) === lit(-1)).expr + val predicate = Literal(1) < Literal(5) && Literal(6) < Literal(7) && Literal(0) === Literal(-1) val split = JoinConditionSplitPredicates(Some(predicate), left, right) assert(split.leftSideOnly.contains(predicate)) @@ -59,7 +53,8 @@ class StreamingSymmetricHashJoinHelperSuite extends StreamTest { } test("only left") { - val predicate = (leftColA > lit(1) && leftColB > lit(5) && leftColA < leftColB).expr + val predicate = + leftAttributeA > Literal(1) && leftAttributeB > Literal(5) && leftAttributeA < leftAttributeB val split = JoinConditionSplitPredicates(Some(predicate), left, right) assert(split.leftSideOnly.contains(predicate)) @@ -69,7 +64,8 @@ class StreamingSymmetricHashJoinHelperSuite extends StreamTest { } test("only right") { - val predicate = (rightColC > lit(1) && rightColD > lit(5) && rightColD < rightColC).expr + val predicate = rightAttributeC > Literal(1) && rightAttributeD > Literal(5) && + rightAttributeD < rightAttributeC val split = JoinConditionSplitPredicates(Some(predicate), left, right) assert(split.leftSideOnly.isEmpty) @@ -80,47 +76,55 @@ class StreamingSymmetricHashJoinHelperSuite extends StreamTest { test("mixed conjuncts") { val predicate = - (leftColA > leftColB - && rightColC > rightColD - && leftColA === rightColC - && lit(1) === lit(1)).expr + (leftAttributeA > leftAttributeB + && rightAttributeC > rightAttributeD + && leftAttributeA === rightAttributeC + && Literal(1) === Literal(1)) val split = JoinConditionSplitPredicates(Some(predicate), left, right) - assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr)) - assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr)) - assert(split.bothSides.contains((leftColA === rightColC).expr)) + assert(split.leftSideOnly.contains( + leftAttributeA > leftAttributeB && Literal(1) === Literal(1))) + assert(split.rightSideOnly.contains( + rightAttributeC > 
rightAttributeD && Literal(1) === Literal(1))) + assert(split.bothSides.contains((leftAttributeA === rightAttributeC))) assert(split.full.contains(predicate)) } test("conjuncts after nondeterministic") { val predicate = - (rand(9) > lit(0) - && leftColA > leftColB - && rightColC > rightColD - && leftColA === rightColC - && lit(1) === lit(1)).expr + (rand(9) > Literal(0) + && leftAttributeA > leftAttributeB + && rightAttributeC > rightAttributeD + && leftAttributeA === rightAttributeC + && Literal(1) === Literal(1)) val split = JoinConditionSplitPredicates(Some(predicate), left, right) - assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr)) - assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr)) - assert(split.bothSides.contains((leftColA === rightColC && rand(9) > lit(0)).expr)) + assert(split.leftSideOnly.contains( + leftAttributeA > leftAttributeB && Literal(1) === Literal(1))) + assert(split.rightSideOnly.contains( + rightAttributeC > rightAttributeD && Literal(1) === Literal(1))) + assert(split.bothSides.contains( + leftAttributeA === rightAttributeC && rand(9).expr > Literal(0))) assert(split.full.contains(predicate)) } test("conjuncts before nondeterministic") { - val randCol = rand() + val randAttribute = rand(0) val predicate = - (leftColA > leftColB - && rightColC > rightColD - && leftColA === rightColC - && lit(1) === lit(1) - && randCol > lit(0)).expr + (leftAttributeA > leftAttributeB + && rightAttributeC > rightAttributeD + && leftAttributeA === rightAttributeC + && Literal(1) === Literal(1) + && randAttribute > Literal(0)) val split = JoinConditionSplitPredicates(Some(predicate), left, right) - assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr)) - assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr)) - assert(split.bothSides.contains((leftColA === rightColC && randCol > lit(0)).expr)) + assert(split.leftSideOnly.contains( 
+ leftAttributeA > leftAttributeB && Literal(1) === Literal(1))) + assert(split.rightSideOnly.contains( + rightAttributeC > rightAttributeD && Literal(1) === Literal(1))) + assert(split.bothSides.contains( + leftAttributeA === rightAttributeC && randAttribute > Literal(0))) assert(split.full.contains(predicate)) } }