Commits (37)
c78a347
WIP
hvanhovell Jul 25, 2023
0eb42c1
Fix conflicts
hvanhovell Jul 26, 2023
f648bef
Finish move of classes
hvanhovell Jul 26, 2023
be5b320
Fix legacy policy
hvanhovell Jul 26, 2023
c9fc20c
Remove ToJsonUtil
hvanhovell Jul 26, 2023
0b1b803
Add UDTUtilsImpl
hvanhovell Jul 26, 2023
783a1af
Fix stuff
hvanhovell Jul 26, 2023
4fa9124
Merge with master
hvanhovell Jul 26, 2023
4fc0e6d
Simplification & Fixes
hvanhovell Jul 26, 2023
65ea542
Disconnect Arrow encoders from Catalyst
hvanhovell Jul 26, 2023
b70413b
Disconnect Arrow encoders from Catalyst
hvanhovell Jul 26, 2023
67d1592
Fix proto functions
hvanhovell Jul 26, 2023
340304d
Fix proto functions
hvanhovell Jul 26, 2023
8ca1d2b
fix proto for real
hvanhovell Jul 26, 2023
4845456
Make tests compile!
hvanhovell Jul 26, 2023
42d8f0f
Capstone
hvanhovell Jul 26, 2023
69b8868
Fix compilation
hvanhovell Jul 26, 2023
1927189
Style
hvanhovell Jul 26, 2023
a4f3050
Fix formatting
hvanhovell Jul 27, 2023
ad78a8d
Merge branch 'SPARK-44538' into SPARK-41400-v1
hvanhovell Jul 27, 2023
9f10a8c
Merge remote-tracking branch 'apache/master' into SPARK-41400-v1
hvanhovell Jul 27, 2023
cc981b4
Integrate AnalysisException
hvanhovell Jul 27, 2023
bdd4346
Merge with master
hvanhovell Jul 27, 2023
2b8879d
Put back DateTimeUtils and enrich SparkDateTimeUtils.
hvanhovell Jul 27, 2023
d6a57a0
Merge remote-tracking branch 'apache/master' into SPARK-44538
hvanhovell Jul 27, 2023
e2b0dc4
Merge branch 'SPARK-44538' into SPARK-41400-v1
hvanhovell Jul 27, 2023
fcda03c
Put back string to/from bytebuffer
hvanhovell Jul 27, 2023
912f892
Undo change to RebaseDateTimeSuite
hvanhovell Jul 27, 2023
eeaa409
Undo change
hvanhovell Jul 27, 2023
30f3ce4
style...
hvanhovell Jul 28, 2023
6b53488
Merge branch 'SPARK-44538' into SPARK-41400-v1
hvanhovell Jul 28, 2023
470fed4
Merge remote-tracking branch 'apache/master' into SPARK-41400-v1
hvanhovell Jul 28, 2023
c81d847
Fix docs
hvanhovell Jul 28, 2023
064bc40
Merge remote-tracking branch 'apache/master' into SPARK-41400-v1
hvanhovell Jul 28, 2023
ace8fed
Weird Bug Fix
hvanhovell Jul 28, 2023
5c5aa60
Better fix
hvanhovell Jul 28, 2023
b7bc5c4
Fix MiMa
hvanhovell Jul 28, 2023
Disconnect Arrow encoders from Catalyst
hvanhovell committed Jul 26, 2023
commit 65ea54280d01c4133d4fa18c6ca6c7178579e9e0
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.errors.{ExecutionErrors, QueryCompilationErrors}
+import org.apache.spark.sql.errors.{ExecutionErrors, CompilationErrors}
import org.apache.spark.sql.types.Decimal

/**
@@ -436,13 +436,13 @@ object ArrowDeserializers {
val key = toKey(field.getName)
val old = lookup.put(key, field)
if (old.isDefined) {
-throw QueryCompilationErrors.ambiguousColumnOrFieldError(
+throw CompilationErrors.ambiguousColumnOrFieldError(
field.getName :: Nil,
fields.count(f => toKey(f.getName) == key))
}
}
name => {
-lookup.getOrElse(toKey(name), throw QueryCompilationErrors.columnNotFoundError(name))
+lookup.getOrElse(toKey(name), throw CompilationErrors.columnNotFoundError(name))
}
}

@@ -23,12 +23,11 @@ import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, ZoneOffset}
import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, IntervalYearVector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, VarCharVector}
import org.apache.arrow.vector.util.Text

-import org.apache.spark.sql.catalyst.expressions.Cast
-import org.apache.spark.sql.catalyst.util.{DateFormatter, IntervalUtils, StringUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkIntervalUtils, StringUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
-import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, Decimal, YearMonthIntervalType}
+import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, Decimal, UpCastRule, YearMonthIntervalType}
import org.apache.spark.sql.util.ArrowUtils

/**
@@ -69,7 +68,7 @@ object ArrowVectorReader {
vector: FieldVector,
timeZoneId: String): ArrowVectorReader = {
val vectorDataType = ArrowUtils.fromArrowType(vector.getField.getType)
-if (!Cast.canUpCast(vectorDataType, targetDataType)) {
+if (!UpCastRule.canUpCast(vectorDataType, targetDataType)) {
throw new RuntimeException(
s"Reading '$targetDataType' values from a ${vector.getClass} instance is not supported.")
}
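Reviewer note: `UpCastRule` is the Catalyst-free home of the upcast table that `Cast.canUpCast` exposes, so the check above should behave the same. A minimal sketch of what it accepts; the concrete types here are illustrative, not from this diff:

```scala
import org.apache.spark.sql.types.{IntegerType, LongType, UpCastRule}

// Widening a 32-bit int to a 64-bit long is a lossless upcast.
UpCastRule.canUpCast(IntegerType, LongType)  // true
// Narrowing the other way could lose data, so the reader above throws.
UpCastRule.canUpCast(LongType, IntegerType)  // false
```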
@@ -200,8 +199,8 @@ private[arrow] class DurationVectorReader(v: DurationVector)
extends TypedArrowVectorReader[DurationVector](v) {
override def getDuration(i: Int): Duration = vector.getObject(i)
override def getString(i: Int): String = {
-IntervalUtils.toDayTimeIntervalString(
-  IntervalUtils.durationToMicros(getDuration(i)),
+SparkIntervalUtils.toDayTimeIntervalString(
+  SparkIntervalUtils.durationToMicros(getDuration(i)),
ANSI_STYLE,
DayTimeIntervalType.DEFAULT.startField,
DayTimeIntervalType.DEFAULT.endField)
@@ -212,7 +211,7 @@ private[arrow] class IntervalYearVectorReader(v: IntervalYearVector)
extends TypedArrowVectorReader[IntervalYearVector](v) {
override def getPeriod(i: Int): Period = vector.getObject(i).normalized()
override def getString(i: Int): String = {
-IntervalUtils.toYearMonthIntervalString(
+SparkIntervalUtils.toYearMonthIntervalString(
vector.get(i),
ANSI_STYLE,
YearMonthIntervalType.DEFAULT.startField,
@@ -16,10 +16,64 @@
*/
package org.apache.spark.sql.catalyst.util

-import org.apache.spark.sql.catalyst.util.DateTimeConstants.{DAYS_PER_WEEK, MICROS_PER_HOUR, MICROS_PER_MINUTE, MICROS_PER_SECOND, MONTHS_PER_YEAR, NANOS_PER_MICROS, NANOS_PER_SECOND}
+import java.time.Duration
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.{ANSI_STYLE, HIVE_STYLE, IntervalStyle}
import org.apache.spark.sql.types.{DayTimeIntervalType => DT, YearMonthIntervalType => YM}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

trait SparkIntervalUtils {
protected val MAX_DAY: Long = Long.MaxValue / MICROS_PER_DAY
protected val MAX_HOUR: Long = Long.MaxValue / MICROS_PER_HOUR
protected val MAX_MINUTE: Long = Long.MaxValue / MICROS_PER_MINUTE
protected val MAX_SECOND: Long = Long.MaxValue / MICROS_PER_SECOND
protected val MIN_SECOND: Long = Long.MinValue / MICROS_PER_SECOND

// The number of seconds that can cause overflow in the conversion to microseconds
private final val minDurationSeconds = Math.floorDiv(Long.MinValue, MICROS_PER_SECOND)

/**
* Converts this duration to the total length in microseconds.
* <p>
* If this duration is too large to fit in a [[Long]] number of microseconds, an
* exception is thrown.
* <p>
* If this duration has greater than microsecond precision, then the conversion
* will drop any excess precision information as though the amount in nanoseconds
* was subject to integer division by one thousand.
*
* @return The total length of the duration in microseconds
* @throws ArithmeticException If numeric overflow occurs
*/
def durationToMicros(duration: Duration): Long = {
durationToMicros(duration, DT.SECOND)
}

def durationToMicros(duration: Duration, endField: Byte): Long = {
val seconds = duration.getSeconds
val micros = if (seconds == minDurationSeconds) {
val microsInSeconds = (minDurationSeconds + 1) * MICROS_PER_SECOND
val nanoAdjustment = duration.getNano
assert(0 <= nanoAdjustment && nanoAdjustment < NANOS_PER_SECOND,
"Duration.getNano() must return the adjustment to the seconds field " +
"in the range from 0 to 999999999 nanoseconds, inclusive.")
Math.addExact(microsInSeconds, (nanoAdjustment - NANOS_PER_SECOND) / NANOS_PER_MICROS)
} else {
val microsInSeconds = Math.multiplyExact(seconds, MICROS_PER_SECOND)
Math.addExact(microsInSeconds, duration.getNano / NANOS_PER_MICROS)
}

endField match {
case DT.DAY => micros - micros % MICROS_PER_DAY
case DT.HOUR => micros - micros % MICROS_PER_HOUR
case DT.MINUTE => micros - micros % MICROS_PER_MINUTE
case DT.SECOND => micros
}
}
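For orientation, a small usage sketch of these two overloads; the values are hand-computed from the logic above, not taken from the diff:

```scala
import java.time.Duration
import org.apache.spark.sql.catalyst.util.SparkIntervalUtils
import org.apache.spark.sql.types.{DayTimeIntervalType => DT}

// 1 day, 2 hours, 30.5 seconds
val d = Duration.ofDays(1).plusHours(2).plusSeconds(30).plusNanos(500000000L)

// The default end field (SECOND) keeps full microsecond precision.
SparkIntervalUtils.durationToMicros(d)             // 93630500000L

// A coarser end field truncates everything below it.
SparkIntervalUtils.durationToMicros(d, DT.MINUTE)  // 93600000000L
```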

/**
* Converts a string to [[CalendarInterval]] case-insensitively.
*
Expand Down Expand Up @@ -226,22 +280,176 @@ trait SparkIntervalUtils {
result
}

/**
* Converts a year-month interval as a number of months to its textual representation
* which conforms to the ANSI SQL standard.
*
* @param months The number of months, positive or negative
* @param style The style of textual representation of the interval
* @param startField The start field (YEAR or MONTH) which the interval comprises.
* @param endField The end field (YEAR or MONTH) which the interval comprises.
* @return Year-month interval string
*/
def toYearMonthIntervalString(
months: Int,
style: IntervalStyle,
startField: Byte,
endField: Byte): String = {
var sign = ""
var absMonths: Long = months
if (months < 0) {
sign = "-"
absMonths = -absMonths
}
val year = s"$sign${absMonths / MONTHS_PER_YEAR}"
val yearAndMonth = s"$year-${absMonths % MONTHS_PER_YEAR}"
style match {
case ANSI_STYLE =>
val formatBuilder = new StringBuilder("INTERVAL '")
if (startField == endField) {
startField match {
case YM.YEAR => formatBuilder.append(s"$year' YEAR")
case YM.MONTH => formatBuilder.append(s"$months' MONTH")
}
} else {
formatBuilder.append(s"$yearAndMonth' YEAR TO MONTH")
}
formatBuilder.toString
case HIVE_STYLE => s"$yearAndMonth"
}
}
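A quick sketch of the strings this produces, hand-computed from the branches above:

```scala
import org.apache.spark.sql.catalyst.util.SparkIntervalUtils
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.{ANSI_STYLE, HIVE_STYLE}
import org.apache.spark.sql.types.{YearMonthIntervalType => YM}

// 14 months = 1 year and 2 months
SparkIntervalUtils.toYearMonthIntervalString(14, ANSI_STYLE, YM.YEAR, YM.MONTH)
// "INTERVAL '1-2' YEAR TO MONTH"
SparkIntervalUtils.toYearMonthIntervalString(-14, ANSI_STYLE, YM.YEAR, YM.MONTH)
// "INTERVAL '-1-2' YEAR TO MONTH"
SparkIntervalUtils.toYearMonthIntervalString(14, HIVE_STYLE, YM.YEAR, YM.MONTH)
// "1-2"
```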

/**
* Converts a day-time interval as a number of microseconds to its textual representation
* which conforms to the ANSI SQL standard.
*
* @param micros The number of microseconds, positive or negative
* @param style The style of textual representation of the interval
* @param startField The start field (DAY, HOUR, MINUTE, SECOND) which the interval comprises.
* @param endField The end field (DAY, HOUR, MINUTE, SECOND) which the interval comprises.
* @return Day-time interval string
*/
def toDayTimeIntervalString(
micros: Long,
style: IntervalStyle,
startField: Byte,
endField: Byte): String = {
var sign = ""
var rest = micros
// scalastyle:off caselocale
val from = DT.fieldToString(startField).toUpperCase
val to = DT.fieldToString(endField).toUpperCase
// scalastyle:on caselocale
val prefix = "INTERVAL '"
val postfix = s"' ${if (startField == endField) from else s"$from TO $to"}"

if (micros < 0) {
if (micros == Long.MinValue) {
// Special handling of the minimum `Long` value because negation overflows `Long`.
// seconds = 106751991 * (24 * 60 * 60) + 4 * 60 * 60 + 54 = 9223372036854
// microseconds = -9223372036854000000L-775808 == Long.MinValue
val baseStr = "-106751991 04:00:54.775808000"
val minIntervalString = style match {
case ANSI_STYLE =>
val firstStr = startField match {
case DT.DAY => s"-$MAX_DAY"
case DT.HOUR => s"-$MAX_HOUR"
case DT.MINUTE => s"-$MAX_MINUTE"
case DT.SECOND => s"-$MAX_SECOND.775808"
}
val followingStr = if (startField == endField) {
""
} else {
val substrStart = startField match {
case DT.DAY => 10
case DT.HOUR => 13
case DT.MINUTE => 16
}
val substrEnd = endField match {
case DT.HOUR => 13
case DT.MINUTE => 16
case DT.SECOND => 26
}
baseStr.substring(substrStart, substrEnd)
}

s"$prefix$firstStr$followingStr$postfix"
case HIVE_STYLE => baseStr
}
return minIntervalString
} else {
sign = "-"
rest = -rest
}
}
val intervalString = style match {
case ANSI_STYLE =>
val formatBuilder = new mutable.StringBuilder(sign)
val formatArgs = new mutable.ArrayBuffer[Long]()
startField match {
case DT.DAY =>
formatBuilder.append(rest / MICROS_PER_DAY)
rest %= MICROS_PER_DAY
case DT.HOUR =>
formatBuilder.append("%02d")
formatArgs.append(rest / MICROS_PER_HOUR)
rest %= MICROS_PER_HOUR
case DT.MINUTE =>
formatBuilder.append("%02d")
formatArgs.append(rest / MICROS_PER_MINUTE)
rest %= MICROS_PER_MINUTE
case DT.SECOND =>
val leadZero = if (rest < 10 * MICROS_PER_SECOND) "0" else ""
formatBuilder.append(s"$leadZero" +
s"${java.math.BigDecimal.valueOf(rest, 6).stripTrailingZeros.toPlainString}")
}

if (startField < DT.HOUR && DT.HOUR <= endField) {
formatBuilder.append(" %02d")
formatArgs.append(rest / MICROS_PER_HOUR)
rest %= MICROS_PER_HOUR
}
if (startField < DT.MINUTE && DT.MINUTE <= endField) {
formatBuilder.append(":%02d")
formatArgs.append(rest / MICROS_PER_MINUTE)
rest %= MICROS_PER_MINUTE
}
if (startField < DT.SECOND && DT.SECOND <= endField) {
val leadZero = if (rest < 10 * MICROS_PER_SECOND) "0" else ""
formatBuilder.append(
s":$leadZero${java.math.BigDecimal.valueOf(rest, 6).stripTrailingZeros.toPlainString}")
}
s"$prefix${formatBuilder.toString.format(formatArgs.toSeq: _*)}$postfix"
case HIVE_STYLE =>
val secondsWithFraction = rest % MICROS_PER_MINUTE
rest /= MICROS_PER_MINUTE
val minutes = rest % MINUTES_PER_HOUR
rest /= MINUTES_PER_HOUR
val hours = rest % HOURS_PER_DAY
val days = rest / HOURS_PER_DAY
val seconds = secondsWithFraction / MICROS_PER_SECOND
val nanos = (secondsWithFraction % MICROS_PER_SECOND) * NANOS_PER_MICROS
f"$sign$days $hours%02d:$minutes%02d:$seconds%02d.$nanos%09d"
}
intervalString
}
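Likewise a hand-computed sketch for the day-time formatter; the inputs are illustrative:

```scala
import org.apache.spark.sql.catalyst.util.SparkIntervalUtils
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.{ANSI_STYLE, HIVE_STYLE}
import org.apache.spark.sql.types.{DayTimeIntervalType => DT}

// 1 day, 2 hours, 3 minutes, 4.5 seconds, expressed in microseconds
val micros = (26L * 3600 + 3 * 60 + 4) * 1000000 + 500000  // 93784500000L

SparkIntervalUtils.toDayTimeIntervalString(micros, ANSI_STYLE, DT.DAY, DT.SECOND)
// "INTERVAL '1 02:03:04.5' DAY TO SECOND"
SparkIntervalUtils.toDayTimeIntervalString(micros, HIVE_STYLE, DT.DAY, DT.SECOND)
// "1 02:03:04.500000000"

// Fields outside the requested range are folded into the start field:
SparkIntervalUtils.toDayTimeIntervalString(micros, ANSI_STYLE, DT.HOUR, DT.MINUTE)
// "INTERVAL '26:03' HOUR TO MINUTE"
```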

protected def unitToUtf8(unit: String): UTF8String = {
UTF8String.fromString(unit)
}

-protected val intervalStr = unitToUtf8("interval")
+protected val intervalStr: UTF8String = unitToUtf8("interval")

-protected val yearStr = unitToUtf8("year")
-protected val monthStr = unitToUtf8("month")
-protected val weekStr = unitToUtf8("week")
-protected val dayStr = unitToUtf8("day")
-protected val hourStr = unitToUtf8("hour")
-protected val minuteStr = unitToUtf8("minute")
-protected val secondStr = unitToUtf8("second")
-protected val millisStr = unitToUtf8("millisecond")
-protected val microsStr = unitToUtf8("microsecond")
-protected val nanosStr = unitToUtf8("nanosecond")
+protected val yearStr: UTF8String = unitToUtf8("year")
+protected val monthStr: UTF8String = unitToUtf8("month")
+protected val weekStr: UTF8String = unitToUtf8("week")
+protected val dayStr: UTF8String = unitToUtf8("day")
+protected val hourStr: UTF8String = unitToUtf8("hour")
+protected val minuteStr: UTF8String = unitToUtf8("minute")
+protected val secondStr: UTF8String = unitToUtf8("second")
+protected val millisStr: UTF8String = unitToUtf8("millisecond")
+protected val microsStr: UTF8String = unitToUtf8("microsecond")
+protected val nanosStr: UTF8String = unitToUtf8("nanosecond")


private object ParseState extends Enumeration {
Expand All @@ -261,3 +469,9 @@ trait SparkIntervalUtils {
}

object SparkIntervalUtils extends SparkIntervalUtils

// The style of textual representation of intervals
object IntervalStringStyles extends Enumeration {
type IntervalStyle = Value
val ANSI_STYLE, HIVE_STYLE = Value
}
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.errors

import org.apache.spark.SparkException
import org.apache.spark.sql.internal.SqlApiConf

// TODO integrate this with QueryCompilationErrors.
// For this we need the AnalysisException work to land.
private[sql] trait CompilationErrors extends DataTypeErrorsBase {
def ambiguousColumnOrFieldError(
name: Seq[String], numMatches: Int): Throwable = {
new SparkException(
errorClass = "AMBIGUOUS_COLUMN_OR_FIELD",
messageParameters = Map(
"name" -> toSQLId(name),
"n" -> numMatches.toString),
cause = null)
}

def columnNotFoundError(colName: String): Throwable = {
new SparkException(
errorClass = "COLUMN_NOT_FOUND",
messageParameters = Map(
"colName" -> toSQLId(colName),
"caseSensitiveConfig" -> toSQLConf(SqlApiConf.CASE_SENSITIVE_KEY)),
cause = null)
}
}

object CompilationErrors extends CompilationErrors
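For context, a sketch of what the two helpers produce when the Arrow deserializer above calls them; the column names here are made up:

```scala
import org.apache.spark.sql.errors.CompilationErrors

// Two fields collide under case-insensitive resolution.
val ambiguous = CompilationErrors.ambiguousColumnOrFieldError(Seq("value"), 2)
// -> SparkException with error class AMBIGUOUS_COLUMN_OR_FIELD

// No field matches the requested name; the message points users at the
// spark.sql.caseSensitive config via SqlApiConf.CASE_SENSITIVE_KEY.
val notFound = CompilationErrors.columnNotFoundError("price")
// -> SparkException with error class COLUMN_NOT_FOUND
```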
@@ -50,6 +50,7 @@ private[sql] object SqlApiConf {
// Shared keys.
val ANSI_ENABLED_KEY: String = "spark.sql.ansi.enabled"
val LEGACY_TIME_PARSER_POLICY_KEY: String = "spark.sql.legacy.timeParserPolicy"
+val CASE_SENSITIVE_KEY: String = "spark.sql.caseSensitive"

/**
* Defines a getter that returns the [[SqlApiConf]] within scope.
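Closing note: the new key exists so that sql/api code such as `CompilationErrors` can name the config without depending on Catalyst's `SQLConf`. A trivial sketch:

```scala
import org.apache.spark.sql.internal.SqlApiConf

// The COLUMN_NOT_FOUND error message embeds this key so users know which
// setting controls case-sensitive column resolution.
SqlApiConf.CASE_SENSITIVE_KEY  // "spark.sql.caseSensitive"
```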