[SPARK-29783][SQL] Support SQL Standard output style for interval type
yaooqinn committed Nov 7, 2019
commit 88418e05f335db2ead7377824689774f32a449a7
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}

object Cast {
@@ -280,6 +280,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit

  // UDFToString
  private[this] def castToString(from: DataType): Any => Any = from match {
    case CalendarIntervalType if ansiEnabled => buildCast[CalendarInterval](_,
      i => UTF8String.fromString(IntervalUtils.toSqlStandardString(i)))
    case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes)
    case DateType => buildCast[Int](_, d => UTF8String.fromString(dateFormatter.format(d)))
    case TimestampType => buildCast[Long](_,
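
For context, a minimal standalone sketch (not part of the commit) of what the ANSI-gated branch above produces, assuming this patch's IntervalUtils.toSqlStandardString is on the classpath; the interval value is illustrative only:

import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

// 14 months, 3 days, 0 microseconds
val i = new CalendarInterval(14, 3, 0L)
// With spark.sql.ansi.enabled=true, the cast goes through the SQL standard formatter:
UTF8String.fromString(IntervalUtils.toSqlStandardString(i))  // "1-2 3"
// With the flag off, it falls through to CalendarInterval.toString,
// i.e. the existing multi-unit style (e.g. "interval 1 years 2 months 3 days").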
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -409,6 +409,8 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
        DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
      s"TIMESTAMP('${formatter.format(v)}')"
    case (v: Array[Byte], BinaryType) => s"X'${DatatypeConverter.printHexBinary(v)}'"
    case (v: CalendarInterval, CalendarIntervalType) if SQLConf.get.ansiEnabled =>
      IntervalUtils.toSqlStandardString(v)
Contributor:
This should be something that can be parsed; I think we need to output something like INTERVAL '1 year 2 days'.

Member Author:
We don't support SQL standard input for intervals? I missed that; this may cause a user-facing behavior change. But if we stay with the multi-unit style here, wouldn't there be a conflict between literals and other expressions?

Member Author:
Changed back, thanks.

    case _ => value.toString
  }
}
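
On the round-trip concern in this thread: Literal#sql is expected to emit text the parser can read back. A hedged sketch (not part of the diff; hypothetical value) of why the bare SQL standard form is problematic, and hence the "changed back" above:

import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.unsafe.types.CalendarInterval

val v = new CalendarInterval(12, 2, 0L)  // 1 year 2 days
IntervalUtils.toSqlStandardString(v)     // "1-0 2", not valid interval literal syntax on its own
// whereas something like INTERVAL '1 years 2 days' parses back to the same literal.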
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.util

import java.math.BigDecimal
import java.util.concurrent.TimeUnit

import scala.util.control.NonFatal
@@ -388,4 +389,41 @@ object IntervalUtils {
  def divide(interval: CalendarInterval, num: Double): CalendarInterval = {
    fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num)
  }

  def toSqlStandardString(interval: CalendarInterval): String = {
    val yearMonthPart = if (interval.months != 0) {
      interval.months / 12 + "-" + math.abs(interval.months) % 12
    } else {
      ""
    }

    val dayPart = if (interval.days != 0) interval.days.toString else ""

    val timePart = if (interval.microseconds != 0) {
      val sb = new StringBuilder()
      var rest = interval.microseconds
      sb.append(rest / MICROS_PER_HOUR)
      sb.append(':')
      rest = math.abs(rest % MICROS_PER_HOUR)
      val minutes = rest / MICROS_PER_MINUTE
      if (minutes < 10) {
        sb.append(0)
      }
      sb.append(minutes)
      sb.append(':')
      rest %= MICROS_PER_MINUTE
      val db = BigDecimal.valueOf(rest, 6)
      if (db.compareTo(new BigDecimal(10)) < 0) {
        sb.append(0)
      }
      val s = db.stripTrailingZeros().toPlainString
      sb.append(s)
      sb.toString()
    } else {
      ""
    }

    val intervalList = Seq(yearMonthPart, dayPart, timePart).filter(_.nonEmpty)
    if (intervalList.nonEmpty) intervalList.mkString(" ") else "0"
Contributor:
Wow, a single 0 is also SQL standard?

Member Author:
postgres=# set IntervalStyle=sql_standard;
SET
postgres=# select interval '0';
 interval
----------
 0
(1 row)

postgres=# set IntervalStyle=postgres;
SET
postgres=# select interval '0';
 interval
----------
 00:00:00
(1 row)

  }
}
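
A side note on the seconds formatting in toSqlStandardString: java.math.BigDecimal.valueOf(rest, 6) treats the remaining microseconds as an unscaled value with scale 6, which converts them to seconds directly. A minimal sketch of that behavior (not part of the diff; plain JVM, no Spark dependencies, illustrative values):

import java.math.BigDecimal

// 123 microseconds with scale 6 -> 0.000123 seconds
println(BigDecimal.valueOf(123L, 6).toPlainString)  // 0.000123

// 30 whole seconds: stripTrailingZeros drops the fractional zeros,
// and toPlainString avoids scientific notation ("3E+1")
println(BigDecimal.valueOf(30000000L, 6).stripTrailingZeros().toPlainString)  // 30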
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.catalyst.util
import java.util.concurrent.TimeUnit

import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{MICROS_PER_MILLIS, MICROS_PER_SECOND}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.catalyst.util.IntervalUtils._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.unsafe.types.CalendarInterval.{MICROS_PER_HOUR, MICROS_PER_MINUTE}

class IntervalUtilsSuite extends SparkFunSuite {

@@ -225,4 +226,27 @@ class IntervalUtilsSuite extends SparkFunSuite {
      assert(e.getMessage.contains("overflow"))
    }
  }

test("to ansi sql standard string") {
val i1 = new CalendarInterval(0, 0, 0)
assert(IntervalUtils.toSqlStandardString(i1) === "0")
val i2 = new CalendarInterval(34, 0, 0)
assert(IntervalUtils.toSqlStandardString(i2) === "2-10")
val i3 = new CalendarInterval(-34, 0, 0)
assert(IntervalUtils.toSqlStandardString(i3) === "-2-10")
val i4 = new CalendarInterval(0, 31, 0)
assert(IntervalUtils.toSqlStandardString(i4) === "31")
val i5 = new CalendarInterval(0, -31, 0)
assert(IntervalUtils.toSqlStandardString(i5) === "-31")
val i6 = new CalendarInterval(0, 0, 3 * MICROS_PER_HOUR + 13 * MICROS_PER_MINUTE + 123)
assert(IntervalUtils.toSqlStandardString(i6) === "3:13:00.000123")
val i7 = new CalendarInterval(0, 0, -3 * MICROS_PER_HOUR - 13 * MICROS_PER_MINUTE - 123)
assert(IntervalUtils.toSqlStandardString(i7) === "-3:13:00.000123")
val i8 = new CalendarInterval(-34, 31, 3 * MICROS_PER_HOUR + 13 * MICROS_PER_MINUTE + 123)
assert(IntervalUtils.toSqlStandardString(i8) === "-2-10 31 3:13:00.000123")
val i9 = new CalendarInterval(0, 0, -3000 * MICROS_PER_HOUR)
assert(IntervalUtils.toSqlStandardString(i9) === "-3000:00:00")


}
}
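
Not part of the diff: to run just this suite, the usual Spark invocation is something like build/sbt "catalyst/testOnly *IntervalUtilsSuite" (or test-only on older sbt versions).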
sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
@@ -21,10 +21,11 @@ import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, IntervalUtils, TimestampFormatter}
import org.apache.spark.sql.execution.command.{DescribeCommandBase, ExecutedCommandExec, ShowTablesCommand}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

/**
 * Runs a query returning the result in Hive compatible form.
@@ -80,6 +81,7 @@ object HiveResult {
  private lazy val zoneId = DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)
  private lazy val dateFormatter = DateFormatter(zoneId)
  private lazy val timestampFormatter = TimestampFormatter.getFractionFormatter(zoneId)
  private lazy val ansiEnabled = SQLConf.get.ansiEnabled

  /** Hive outputs fields of structs slightly differently than top level attributes. */
  private def toHiveStructString(a: (Any, DataType)): String = a match {
@@ -97,6 +99,8 @@
    case (null, _) => "null"
    case (s: String, StringType) => "\"" + s + "\""
    case (decimal, DecimalType()) => decimal.toString
    case (interval: CalendarInterval, CalendarIntervalType) if ansiEnabled =>
      IntervalUtils.toSqlStandardString(interval)
    case (interval, CalendarIntervalType) => interval.toString
    case (other, tpe) if primitiveTypes contains tpe => other.toString
  }
@@ -120,6 +124,8 @@
      DateTimeUtils.timestampToString(timestampFormatter, DateTimeUtils.fromJavaTimestamp(t))
    case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8)
    case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal)
    case (interval: CalendarInterval, CalendarIntervalType) if ansiEnabled =>
      IntervalUtils.toSqlStandardString(interval)
    case (interval, CalendarIntervalType) => interval.toString
    case (other, _: UserDefinedType[_]) => other.toString
    case (other, tpe) if primitiveTypes.contains(tpe) => other.toString
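
With both branches in place, Hive-compatible output switches styles with the ANSI flag. A hedged sketch (not part of the diff; illustrative value, and the legacy string is approximate):

import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.unsafe.types.CalendarInterval.MICROS_PER_HOUR

// interval 1 day 1 hour
val iv = new CalendarInterval(0, 1, MICROS_PER_HOUR)
IntervalUtils.toSqlStandardString(iv)  // "1 1:00:00", the spark.sql.ansi.enabled=true path
iv.toString                            // multi-unit style, e.g. "interval 1 days 1 hours"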