Commit 0c8147f

Merge remote-tracking branch 'apache/master' into light-analysis-exception
2 parents d2d5e73 + d97a4e2

File tree: 41 files changed, +1042 -267 lines


connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 3 additions & 2 deletions

@@ -126,7 +126,6 @@ class SparkSession private[sql] (
   private def createDataset[T](encoder: AgnosticEncoder[T], data: Iterator[T]): Dataset[T] = {
     newDataset(encoder) { builder =>
       if (data.nonEmpty) {
-        val timeZoneId = conf.get("spark.sql.session.timeZone")
         val arrowData = ArrowSerializer.serialize(data, encoder, allocator, timeZoneId)
         if (arrowData.size() <= conf.get("spark.sql.session.localRelationCacheThreshold").toInt) {
           builder.getLocalRelationBuilder
@@ -529,9 +528,11 @@ class SparkSession private[sql] (
     client.semanticHash(plan).getSemanticHash.getResult
   }

+  private[sql] def timeZoneId: String = conf.get("spark.sql.session.timeZone")
+
   private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = {
     val value = client.execute(plan)
-    val result = new SparkResult(value, allocator, encoder)
+    val result = new SparkResult(value, allocator, encoder, timeZoneId)
     cleaner.register(result)
     result
   }
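
With these two hunks, the session time zone is resolved once on the SparkSession and handed to every SparkResult, instead of being looked up ad hoc while building local relations. A hedged sketch of the client-side effect (assumes a running Spark Connect session with spark.implicits._ in scope; the zone and literal are illustrative, not from this commit):

    spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
    // A TIMESTAMP_NTZ value read back as java.sql.Timestamp is now interpreted
    // in the session time zone rather than the client JVM's default zone.
    val ts = spark
      .sql("SELECT TIMESTAMP_NTZ '2023-01-01 00:00:00' AS t")
      .as[java.sql.Timestamp]
      .head()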

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala

Lines changed: 4 additions & 2 deletions

@@ -36,7 +36,8 @@ import org.apache.spark.sql.util.ArrowUtils
 private[sql] class SparkResult[T](
     responses: java.util.Iterator[proto.ExecutePlanResponse],
     allocator: BufferAllocator,
-    encoder: AgnosticEncoder[T])
+    encoder: AgnosticEncoder[T],
+    timeZoneId: String)
     extends AutoCloseable
     with Cleanable { self =>

@@ -213,7 +214,8 @@ private[sql] class SparkResult[T](
       new ConcatenatingArrowStreamReader(
         allocator,
         Iterator.single(new ResultMessageIterator(destructive)),
-        destructive))
+        destructive),
+      timeZoneId)
   }
 }

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala

Lines changed: 121 additions & 100 deletions (large diff not rendered by default)
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala

Lines changed: 259 additions & 0 deletions (new file)

@@ -0,0 +1,259 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.connect.client.arrow

import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Date, Timestamp}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, ZoneOffset}

import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, IntervalYearVector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, VarCharVector}
import org.apache.arrow.vector.util.Text

import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.util.{DateFormatter, IntervalUtils, StringUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, Decimal, YearMonthIntervalType}
import org.apache.spark.sql.util.ArrowUtils

/**
 * Base class for reading leaf values from an arrow vector. This reader has read methods for all
 * leaf data types supported by the encoder framework. A subclass should always implement one of
 * the read methods. If upcasting is allowed for the given vector, then all allowed read methods
 * must be implemented.
 */
private[arrow] abstract class ArrowVectorReader {
  def isNull(i: Int): Boolean
  def getBoolean(i: Int): Boolean = unsupported()
  def getByte(i: Int): Byte = unsupported()
  def getShort(i: Int): Short = unsupported()
  def getInt(i: Int): Int = unsupported()
  def getLong(i: Int): Long = unsupported()
  def getFloat(i: Int): Float = unsupported()
  def getDouble(i: Int): Double = unsupported()
  def getString(i: Int): String = unsupported()
  def getBytes(i: Int): Array[Byte] = unsupported()
  def getJavaDecimal(i: Int): JBigDecimal = unsupported()
  def getJavaBigInt(i: Int): java.math.BigInteger = getJavaDecimal(i).toBigInteger
  def getScalaDecimal(i: Int): BigDecimal = BigDecimal(getJavaDecimal(i))
  def getScalaBigInt(i: Int): BigInt = BigInt(getJavaBigInt(i))
  def getDecimal(i: Int): Decimal = Decimal(getJavaDecimal(i))
  def getPeriod(i: Int): java.time.Period = unsupported()
  def getDuration(i: Int): java.time.Duration = unsupported()
  def getDate(i: Int): java.sql.Date = unsupported()
  def getTimestamp(i: Int): java.sql.Timestamp = unsupported()
  def getInstant(i: Int): java.time.Instant = unsupported()
  def getLocalDate(i: Int): java.time.LocalDate = unsupported()
  def getLocalDateTime(i: Int): java.time.LocalDateTime = unsupported()
  private def unsupported(): Nothing = throw new UnsupportedOperationException()
}

object ArrowVectorReader {
  def apply(
      targetDataType: DataType,
      vector: FieldVector,
      timeZoneId: String): ArrowVectorReader = {
    val vectorDataType = ArrowUtils.fromArrowType(vector.getField.getType)
    if (!Cast.canUpCast(vectorDataType, targetDataType)) {
      throw new RuntimeException(
        s"Reading '$targetDataType' values from a ${vector.getClass} instance is not supported.")
    }
    vector match {
      case v: BitVector => new BitVectorReader(v)
      case v: TinyIntVector => new TinyIntVectorReader(v)
      case v: SmallIntVector => new SmallIntVectorReader(v)
      case v: IntVector => new IntVectorReader(v)
      case v: BigIntVector => new BigIntVectorReader(v)
      case v: Float4Vector => new Float4VectorReader(v)
      case v: Float8Vector => new Float8VectorReader(v)
      case v: DecimalVector => new DecimalVectorReader(v)
      case v: VarCharVector => new VarCharVectorReader(v)
      case v: VarBinaryVector => new VarBinaryVectorReader(v)
      case v: DurationVector => new DurationVectorReader(v)
      case v: IntervalYearVector => new IntervalYearVectorReader(v)
      case v: DateDayVector => new DateDayVectorReader(v, timeZoneId)
      case v: TimeStampMicroTZVector => new TimeStampMicroTZVectorReader(v)
      case v: TimeStampMicroVector => new TimeStampMicroVectorReader(v, timeZoneId)
      case _: NullVector => NullVectorReader
      case _ => throw new RuntimeException("Unsupported Vector Type: " + vector.getClass)
    }
  }
}

private[arrow] object NullVectorReader extends ArrowVectorReader {
  override def isNull(i: Int): Boolean = true
}

private[arrow] abstract class TypedArrowVectorReader[E <: FieldVector](val vector: E)
    extends ArrowVectorReader {
  override def isNull(i: Int): Boolean = vector.isNull(i)
}

private[arrow] class BitVectorReader(v: BitVector) extends TypedArrowVectorReader[BitVector](v) {
  override def getBoolean(i: Int): Boolean = vector.get(i) > 0
  override def getString(i: Int): String = String.valueOf(getBoolean(i))
}

private[arrow] class TinyIntVectorReader(v: TinyIntVector)
    extends TypedArrowVectorReader[TinyIntVector](v) {
  override def getByte(i: Int): Byte = vector.get(i)
  override def getShort(i: Int): Short = getByte(i)
  override def getInt(i: Int): Int = getByte(i)
  override def getLong(i: Int): Long = getByte(i)
  override def getFloat(i: Int): Float = getByte(i)
  override def getDouble(i: Int): Double = getByte(i)
  override def getString(i: Int): String = String.valueOf(getByte(i))
  override def getJavaDecimal(i: Int): JBigDecimal = JBigDecimal.valueOf(getByte(i))
}

private[arrow] class SmallIntVectorReader(v: SmallIntVector)
    extends TypedArrowVectorReader[SmallIntVector](v) {
  override def getShort(i: Int): Short = vector.get(i)
  override def getInt(i: Int): Int = getShort(i)
  override def getLong(i: Int): Long = getShort(i)
  override def getFloat(i: Int): Float = getShort(i)
  override def getDouble(i: Int): Double = getShort(i)
  override def getString(i: Int): String = String.valueOf(getShort(i))
  override def getJavaDecimal(i: Int): JBigDecimal = JBigDecimal.valueOf(getShort(i))
}

private[arrow] class IntVectorReader(v: IntVector) extends TypedArrowVectorReader[IntVector](v) {
  override def getInt(i: Int): Int = vector.get(i)
  override def getLong(i: Int): Long = getInt(i)
  override def getFloat(i: Int): Float = getInt(i)
  override def getDouble(i: Int): Double = getInt(i)
  override def getString(i: Int): String = String.valueOf(getInt(i))
  override def getJavaDecimal(i: Int): JBigDecimal = JBigDecimal.valueOf(getInt(i))
}

private[arrow] class BigIntVectorReader(v: BigIntVector)
    extends TypedArrowVectorReader[BigIntVector](v) {
  override def getLong(i: Int): Long = vector.get(i)
  override def getFloat(i: Int): Float = getLong(i)
  override def getDouble(i: Int): Double = getLong(i)
  override def getString(i: Int): String = String.valueOf(getLong(i))
  override def getJavaDecimal(i: Int): JBigDecimal = JBigDecimal.valueOf(getLong(i))
  override def getTimestamp(i: Int): Timestamp = toJavaTimestamp(getLong(i) * MICROS_PER_SECOND)
  override def getInstant(i: Int): Instant = microsToInstant(getLong(i) * MICROS_PER_SECOND)
}

private[arrow] class Float4VectorReader(v: Float4Vector)
    extends TypedArrowVectorReader[Float4Vector](v) {
  override def getFloat(i: Int): Float = vector.get(i)
  override def getDouble(i: Int): Double = getFloat(i)
  override def getString(i: Int): String = String.valueOf(getFloat(i))
  override def getJavaDecimal(i: Int): JBigDecimal = JBigDecimal.valueOf(getFloat(i))
}

private[arrow] class Float8VectorReader(v: Float8Vector)
    extends TypedArrowVectorReader[Float8Vector](v) {
  override def getDouble(i: Int): Double = vector.get(i)
  override def getString(i: Int): String = String.valueOf(getDouble(i))
  override def getJavaDecimal(i: Int): JBigDecimal = JBigDecimal.valueOf(getDouble(i))
}

private[arrow] class DecimalVectorReader(v: DecimalVector)
    extends TypedArrowVectorReader[DecimalVector](v) {
  override def getByte(i: Int): Byte = getJavaDecimal(i).byteValueExact()
  override def getShort(i: Int): Short = getJavaDecimal(i).shortValueExact()
  override def getInt(i: Int): Int = getJavaDecimal(i).intValueExact()
  override def getLong(i: Int): Long = getJavaDecimal(i).longValueExact()
  override def getFloat(i: Int): Float = getJavaDecimal(i).floatValue()
  override def getDouble(i: Int): Double = getJavaDecimal(i).doubleValue()
  override def getJavaDecimal(i: Int): JBigDecimal = vector.getObject(i)
  override def getString(i: Int): String = getJavaDecimal(i).toPlainString
}

private[arrow] class VarCharVectorReader(v: VarCharVector)
    extends TypedArrowVectorReader[VarCharVector](v) {
  // This is currently a bit heavy on allocations:
  // - byte array created in VarCharVector.get
  // - CharBuffer created by CharSetEncoder
  // - char array in String
  // By using direct buffers and reusing the char buffer
  // we could get rid of the first two allocations.
  override def getString(i: Int): String = Text.decode(vector.get(i))
}

private[arrow] class VarBinaryVectorReader(v: VarBinaryVector)
    extends TypedArrowVectorReader[VarBinaryVector](v) {
  override def getBytes(i: Int): Array[Byte] = vector.get(i)
  override def getString(i: Int): String = StringUtils.getHexString(getBytes(i))
}

private[arrow] class DurationVectorReader(v: DurationVector)
    extends TypedArrowVectorReader[DurationVector](v) {
  override def getDuration(i: Int): Duration = vector.getObject(i)
  override def getString(i: Int): String = {
    IntervalUtils.toDayTimeIntervalString(
      IntervalUtils.durationToMicros(getDuration(i)),
      ANSI_STYLE,
      DayTimeIntervalType.DEFAULT.startField,
      DayTimeIntervalType.DEFAULT.endField)
  }
}

private[arrow] class IntervalYearVectorReader(v: IntervalYearVector)
    extends TypedArrowVectorReader[IntervalYearVector](v) {
  override def getPeriod(i: Int): Period = vector.getObject(i).normalized()
  override def getString(i: Int): String = {
    IntervalUtils.toYearMonthIntervalString(
      vector.get(i),
      ANSI_STYLE,
      YearMonthIntervalType.DEFAULT.startField,
      YearMonthIntervalType.DEFAULT.endField)
  }
}

private[arrow] class DateDayVectorReader(v: DateDayVector, timeZoneId: String)
    extends TypedArrowVectorReader[DateDayVector](v) {
  private val zone = getZoneId(timeZoneId)
  private lazy val formatter = DateFormatter()
  private def days(i: Int): Int = vector.get(i)
  private def micros(i: Int): Long = daysToMicros(days(i), zone)
  override def getDate(i: Int): Date = toJavaDate(days(i))
  override def getLocalDate(i: Int): LocalDate = daysToLocalDate(days(i))
  override def getTimestamp(i: Int): Timestamp = toJavaTimestamp(micros(i))
  override def getInstant(i: Int): Instant = microsToInstant(micros(i))
  override def getLocalDateTime(i: Int): LocalDateTime = microsToLocalDateTime(micros(i))
  override def getString(i: Int): String = formatter.format(getLocalDate(i))
}

private[arrow] class TimeStampMicroTZVectorReader(v: TimeStampMicroTZVector)
    extends TypedArrowVectorReader[TimeStampMicroTZVector](v) {
  private val zone = getZoneId(v.getTimeZone)
  private lazy val formatter = TimestampFormatter.getFractionFormatter(zone)
  private def utcMicros(i: Int): Long = convertTz(vector.get(i), zone, ZoneOffset.UTC)
  override def getLong(i: Int): Long = Math.floorDiv(vector.get(i), MICROS_PER_SECOND)
  override def getTimestamp(i: Int): Timestamp = toJavaTimestamp(vector.get(i))
  override def getInstant(i: Int): Instant = microsToInstant(vector.get(i))
  override def getLocalDateTime(i: Int): LocalDateTime = microsToLocalDateTime(utcMicros(i))
  override def getString(i: Int): String = formatter.format(vector.get(i))
}

private[arrow] class TimeStampMicroVectorReader(v: TimeStampMicroVector, timeZoneId: String)
    extends TypedArrowVectorReader[TimeStampMicroVector](v) {
  private val zone = getZoneId(timeZoneId)
  private lazy val formatter = TimestampFormatter.getFractionFormatter(ZoneOffset.UTC)
  private def tzMicros(i: Int): Long = convertTz(utcMicros(i), ZoneOffset.UTC, zone)
  private def utcMicros(i: Int): Long = vector.get(i)
  override def getTimestamp(i: Int): Timestamp = toJavaTimestamp(tzMicros(i))
  override def getInstant(i: Int): Instant = microsToInstant(tzMicros(i))
  override def getLocalDateTime(i: Int): LocalDateTime = microsToLocalDateTime(utcMicros(i))
  override def getString(i: Int): String = formatter.format(utcMicros(i))
}
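
This new file is where the upcasting happens: ArrowVectorReader.apply verifies Cast.canUpCast(vectorType, targetType) and then returns a typed reader whose widening getters delegate to the vector's natural accessor. A hedged usage sketch, not part of the commit (the readers are private[arrow], so this only compiles inside that package; the names and values are illustrative):

    import org.apache.arrow.memory.RootAllocator
    import org.apache.arrow.vector.TinyIntVector
    import org.apache.spark.sql.types.IntegerType

    val allocator = new RootAllocator()
    val vector = new TinyIntVector("v", allocator)
    vector.allocateNew(2)
    vector.set(0, 1)
    vector.set(1, 2)
    vector.setValueCount(2)

    // ByteType upcasts to IntegerType, so TINYINT data can be read as Int...
    val reader = ArrowVectorReader(IntegerType, vector, "UTC")
    assert(reader.getInt(0) == 1)
    // ...and every reader can also render its values as a string.
    assert(reader.getString(1) == "2")

    vector.close()
    allocator.close()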

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 7 additions & 3 deletions

@@ -7923,15 +7923,19 @@ object functions {
   def call_udf(udfName: String, cols: Column*): Column = call_function(udfName, cols: _*)

   /**
-   * Call a builtin or temp function.
+   * Call a SQL function.
    *
    * @param funcName
-   *   function name
+   *   function name that follows the SQL identifier syntax (can be quoted, can be qualified)
    * @param cols
    *   the expression parameters of function
    * @since 3.5.0
    */
   @scala.annotation.varargs
-  def call_function(funcName: String, cols: Column*): Column = Column.fn(funcName, cols: _*)
+  def call_function(funcName: String, cols: Column*): Column = Column { builder =>
+    builder.getCallFunctionBuilder
+      .setFunctionName(funcName)
+      .addAllArguments(cols.map(_.expr).asJava)
+  }

 }
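
Routing the call through a dedicated CallFunction plan node means the server parses funcName as a SQL identifier instead of treating it as a bare UDF name. A hedged usage sketch (df, the catalog, and the function names are made up for illustration):

    // Unqualified builtin, as before:
    df.select(call_function("lower", col("name")))
    // Qualified persistent function, now supported:
    df.select(call_function("my_catalog.my_db.my_func", col("name")))
    // Quoted identifier, now supported:
    df.select(call_function("`my func`", col("name")))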

connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala

Lines changed: 4 additions & 6 deletions

@@ -570,8 +570,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
       (col("id") / lit(10.0d)).as("b"),
       col("id"),
       lit("world").as("d"),
-      // TODO SPARK-44449 make this int again when upcasting is in.
-      (col("id") % 2).cast("double").as("a"))
+      (col("id") % 2).as("a"))

   private def validateMyTypeResult(result: Array[MyType]): Unit = {
     result.zipWithIndex.foreach { case (MyType(id, a, b), i) =>
@@ -818,11 +817,10 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
   }

   test("toJSON") {
-    // TODO SPARK-44449 make this int again when upcasting is in.
     val expected = Array(
-      """{"b":0.0,"id":0,"d":"world","a":0.0}""",
-      """{"b":0.1,"id":1,"d":"world","a":1.0}""",
-      """{"b":0.2,"id":2,"d":"world","a":0.0}""")
+      """{"b":0.0,"id":0,"d":"world","a":0}""",
+      """{"b":0.1,"id":1,"d":"world","a":1}""",
+      """{"b":0.2,"id":2,"d":"world","a":0}""")
     val result = spark
       .range(3)
       .select(generateMyTypeColumns: _*)

connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala

Lines changed: 2 additions & 3 deletions

@@ -483,9 +483,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
     val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1))
       .toDF("key", "seq", "value")
     val grouped = ds.groupBy($"value").as[String, (String, Int, Int)]
-    // TODO SPARK-44449 make this string again when upcasting is in.
-    val keys = grouped.keyAs[Int].keys.sort($"value")
-    checkDataset(keys, 1, 2, 10, 20)
+    val keys = grouped.keyAs[String].keys.sort($"value")
+    checkDataset(keys, "1", "2", "10", "20")
   }

   test("flatMapGroupsWithState") {

connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala

Lines changed: 10 additions & 0 deletions

@@ -239,4 +239,14 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
     val output = runCommandsInShell(input)
     assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output)
   }
+
+  test("call_function") {
+    val input = """
+        |val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
+        |spark.udf.register("simpleUDF", (v: Int) => v * v)
+        |df.select($"id", call_function("simpleUDF", $"value")).collect()
+      """.stripMargin
+    val output = runCommandsInShell(input)
+    assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output)
+  }
 }
