diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index ea2c9e9caf15..12d9cd4fb1d3 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -24,6 +24,8 @@ license: | ## Upgrading from Spark SQL 3.2 to 3.3 + - Since Spark 3.3, `DayTimeIntervalType` in Spark SQL is mapped to Arrow's `Duration` type in the `ArrowWriter` and `ArrowColumnVector` developer APIs. Previously, `DayTimeIntervalType` was mapped to Arrow's `Interval` type, which does not match the types that Spark SQL maps to in other languages. For example, `DayTimeIntervalType` is mapped to `java.time.Duration` in Java. + - Since Spark 3.3, the functions `lpad` and `rpad` have been overloaded to support byte sequences. When the first argument is a byte sequence, the optional padding pattern must also be a byte sequence and the result is a BINARY value. The default padding pattern in this case is the zero byte. - Since Spark 3.3, Spark turns a non-nullable schema into nullable for API `DataFrameReader.schema(schema: StructType).json(jsonDataset: Dataset[String])` and `DataFrameReader.schema(schema: StructType).csv(csvDataset: Dataset[String])` when the schema is specified by the user and contains non-nullable fields. diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index 85b0176d74d4..6f55db290a6f 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -21,8 +21,8 @@ from pyspark.rdd import _load_from_socket # type: ignore[attr-defined] from pyspark.sql.pandas.serializers import ArrowCollectSerializer -from pyspark.sql.types import IntegralType from pyspark.sql.types import ( + IntegralType, ByteType, ShortType, IntegerType, @@ -33,6 +33,7 @@ MapType, TimestampType, TimestampNTZType, + DayTimeIntervalType, StructType, DataType, ) @@ -85,6 +86,7 @@ def toPandas(self) -> "PandasDataFrameLike": import numpy as np import pandas as pd + from pandas.core.dtypes.common import is_timedelta64_dtype timezone = self.sql_ctx._conf.sessionLocalTimeZone() # type: ignore[attr-defined] @@ -225,7 +227,10 @@ def toPandas(self) -> "PandasDataFrameLike": else: series = pdf[column_name] - if t is not None: + # For timedelta, a non-empty series already has the correct dtype; only an empty series needs the cast. + should_check_timedelta = is_timedelta64_dtype(t) and len(pdf) == 0 + + if (t is not None and not is_timedelta64_dtype(t)) or should_check_timedelta: series = series.astype(t, copy=False) # `insert` API makes copy of data, we only do it for Series of duplicate column names.
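The conversion changes above determine how a day-time interval surfaces on the pandas side. As a usage illustration only (not part of the patch), here is a minimal PySpark sketch of the new round trip, based on the tests later in this patch; it assumes an active `SparkSession` bound to `spark`, and the column name `td` is made up for the example:

```python
import datetime

import pandas as pd

# After this change, a day-time interval column arrives in pandas as timedelta64[ns].
df = spark.createDataFrame(
    [(datetime.timedelta(days=1, microseconds=123),)], schema="td interval day to second"
)
pdf = df.toPandas()
assert str(pdf.dtypes["td"]) == "timedelta64[ns]"

# The reverse direction also works: pandas timedeltas are inferred as DayTimeIntervalType.
df2 = spark.createDataFrame(pd.DataFrame({"td": [datetime.timedelta(microseconds=123)]}))
assert df2.schema["td"].dataType.simpleString() == "interval day to second"
```

The same behavior is expected with and without `spark.sql.execution.arrow.pyspark.enabled`, which is what the Arrow-toggled tests in this patch exercise.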
@@ -278,6 +283,8 @@ def _to_corrected_pandas_type(dt: DataType) -> Optional[Type]: return np.datetime64 elif type(dt) == TimestampNTZType: return np.datetime64 + elif type(dt) == DayTimeIntervalType: + return np.timedelta64 else: return None @@ -424,13 +431,14 @@ def _convert_from_pandas( list list of records """ + import pandas as pd from pyspark.sql import SparkSession assert isinstance(self, SparkSession) if timezone is not None: from pyspark.sql.pandas.types import _check_series_convert_timestamps_tz_local - from pandas.core.dtypes.common import is_datetime64tz_dtype + from pandas.core.dtypes.common import is_datetime64tz_dtype, is_timedelta64_dtype copied = False if isinstance(schema, StructType): @@ -459,6 +467,19 @@ def _convert_from_pandas( copied = True pdf[column] = s + for column, series in pdf.iteritems(): + if is_timedelta64_dtype(series): + if not copied: + pdf = pdf.copy() + copied = True + # Explicitly set the timedelta dtype to object so the numpy records output can + # hold the timedelta instances as they are. Otherwise, they are converted to their + # internal numeric values. + ser = pdf[column] + pdf[column] = pd.Series( + ser.dt.to_pytimedelta(), index=ser.index, dtype="object", name=ser.name + ) + # Convert pandas.DataFrame to list of numpy records np_records = pdf.to_records(index=False) diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 5ad83001ac59..4ff011d1947c 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -35,6 +35,7 @@ DateType, TimestampType, TimestampNTZType, + DayTimeIntervalType, ArrayType, MapType, StructType, @@ -81,6 +82,8 @@ def to_arrow_type(dt: DataType) -> "pa.DataType": arrow_type = pa.timestamp("us", tz="UTC") elif type(dt) == TimestampNTZType: arrow_type = pa.timestamp("us", tz=None) + elif type(dt) == DayTimeIntervalType: + arrow_type = pa.duration("us") elif type(dt) == ArrayType: if type(dt.elementType) in [StructType, TimestampType]: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) @@ -153,6 +156,8 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da spark_type = TimestampNTZType() elif types.is_timestamp(at): spark_type = TimestampType() + elif types.is_duration(at): + spark_type = DayTimeIntervalType() elif types.is_list(at): if types.is_timestamp(at.value_type): raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index a3ce7c24f05b..0c690257c9c9 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -26,7 +26,7 @@ from pyspark import SparkContext, SparkConf from pyspark.sql import Row, SparkSession -from pyspark.sql.functions import rand, udf +from pyspark.sql.functions import rand, udf, assert_true, lit from pyspark.sql.types import ( StructType, StringType, @@ -241,6 +241,18 @@ def test_create_data_frame_to_pandas_timestamp_ntz(self): assert_frame_equal(origin, pdf) assert_frame_equal(pdf, pdf_arrow) + def test_create_data_frame_to_pandas_day_time_interval(self): + # SPARK-37279: Test DayTimeInterval in createDataFrame and toPandas + origin = pd.DataFrame({"a": [datetime.timedelta(microseconds=123)]}) + df = self.spark.createDataFrame(origin) + df.select( + assert_true(lit("INTERVAL '0 00:00:00.000123' DAY TO SECOND") == df.a.cast("string")) + ).collect() + + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) + assert_frame_equal(origin, pdf) +
assert_frame_equal(pdf, pdf_arrow) + def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index a1966e5f2121..75301edc8d46 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -37,6 +37,7 @@ TimestampType, TimestampNTZType, FloatType, + DayTimeIntervalType, ) from pyspark.sql.utils import AnalysisException, IllegalArgumentException from pyspark.testing.sqlutils import ( @@ -678,7 +679,7 @@ def test_cache(self): ) def _to_pandas(self): - from datetime import datetime, date + from datetime import datetime, date, timedelta schema = ( StructType() @@ -689,6 +690,7 @@ def _to_pandas(self): .add("dt", DateType()) .add("ts", TimestampType()) .add("ts_ntz", TimestampNTZType()) + .add("dt_interval", DayTimeIntervalType()) ) data = [ ( @@ -699,8 +701,9 @@ def _to_pandas(self): date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1), datetime(1969, 1, 1, 1, 1, 1), + timedelta(days=1), ), - (2, "foo", True, 5.0, None, None, None), + (2, "foo", True, 5.0, None, None, None, None), ( 3, "bar", @@ -709,6 +712,7 @@ def _to_pandas(self): date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3), datetime(2012, 3, 3, 3, 3, 3), + timedelta(hours=-1, milliseconds=421), ), ( 4, @@ -718,6 +722,7 @@ def _to_pandas(self): date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4), datetime(2100, 4, 4, 4, 4, 4), + timedelta(microseconds=123), ), ] df = self.spark.createDataFrame(data, schema) @@ -736,6 +741,7 @@ def test_to_pandas(self): self.assertEqual(types[4], np.object) # datetime.date self.assertEqual(types[5], "datetime64[ns]") self.assertEqual(types[6], "datetime64[ns]") + self.assertEqual(types[7], "timedelta64[ns]") @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore def test_to_pandas_with_duplicated_column_names(self): @@ -808,7 +814,8 @@ def test_to_pandas_from_empty_dataframe(self): CAST(1 AS BOOLEAN) AS boolean, CAST('foo' AS STRING) AS string, CAST('2019-01-01' AS TIMESTAMP) AS timestamp, - CAST('2019-01-01' AS TIMESTAMP_NTZ) AS timestamp_ntz + CAST('2019-01-01' AS TIMESTAMP_NTZ) AS timestamp_ntz, + INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval """ dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes @@ -830,7 +837,8 @@ def test_to_pandas_from_null_dataframe(self): CAST(NULL AS BOOLEAN) AS boolean, CAST(NULL AS STRING) AS string, CAST(NULL AS TIMESTAMP) AS timestamp, - CAST(NULL AS TIMESTAMP_NTZ) AS timestamp_ntz + CAST(NULL AS TIMESTAMP_NTZ) AS timestamp_ntz, + INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval """ pdf = self.spark.sql(sql).toPandas() types = pdf.dtypes @@ -844,6 +852,7 @@ def test_to_pandas_from_null_dataframe(self): self.assertEqual(types[7], np.object) self.assertTrue(np.can_cast(np.datetime64, types[8])) self.assertTrue(np.can_cast(np.datetime64, types[9])) + self.assertTrue(np.can_cast(np.timedelta64, types[10])) @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore def test_to_pandas_from_mixed_dataframe(self): @@ -861,9 +870,10 @@ def test_to_pandas_from_mixed_dataframe(self): CAST(col7 AS BOOLEAN) AS boolean, CAST(col8 AS STRING) AS string, timestamp_seconds(col9) AS timestamp, - timestamp_seconds(col10) AS timestamp_ntz - FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1), - (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL) + 
timestamp_seconds(col10) AS timestamp_ntz, + INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval + FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), + (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL) """ pdf_with_some_nulls = self.spark.sql(sql).toPandas() pdf_with_only_nulls = self.spark.sql(sql).filter("tinyint is null").toPandas() @@ -937,6 +947,15 @@ def test_create_dataframe_from_pandas_with_dst(self): os.environ["TZ"] = orig_env_tz time.tzset() + @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore + def test_create_dataframe_from_pandas_with_day_time_interval(self): + # SPARK-37277: Test DayTimeIntervalType in createDataFrame without Arrow. + import pandas as pd + from datetime import timedelta + + df = self.spark.createDataFrame(pd.DataFrame({"a": [timedelta(microseconds=123)]})) + self.assertEqual(df.toPandas().a.iloc[0], timedelta(microseconds=123)) + def test_repr_behaviors(self): import re diff --git a/python/pyspark/sql/tests/test_pandas_udf.py b/python/pyspark/sql/tests/test_pandas_udf.py index 49a5ea6e2296..be80d7a56260 100644 --- a/python/pyspark/sql/tests/test_pandas_udf.py +++ b/python/pyspark/sql/tests/test_pandas_udf.py @@ -19,8 +19,8 @@ import datetime from typing import cast -from pyspark.sql.functions import udf, pandas_udf, PandasUDFType -from pyspark.sql.types import DoubleType, StructType, StructField, LongType +from pyspark.sql.functions import udf, pandas_udf, PandasUDFType, assert_true, lit +from pyspark.sql.types import DoubleType, StructType, StructField, LongType, DayTimeIntervalType from pyspark.sql.utils import ParseException, PythonException from pyspark.rdd import PythonEvalType from pyspark.testing.sqlutils import ( @@ -272,6 +272,25 @@ def noop(s): self.assertEqual(df.schema[0].dataType.typeName(), "timestamp_ntz") self.assertEqual(df.first()[0], datetime.datetime(1970, 1, 1, 0, 0)) + def test_pandas_udf_day_time_interval_type(self): + # SPARK-37277: Test DayTimeIntervalType in pandas UDF + import pandas as pd + + @pandas_udf(DayTimeIntervalType(DayTimeIntervalType.DAY, DayTimeIntervalType.SECOND)) + def noop(s: pd.Series) -> pd.Series: + assert s.iloc[0] == datetime.timedelta(microseconds=123) + return s + + df = self.spark.createDataFrame( + [(datetime.timedelta(microseconds=123),)], schema="td interval day to second" + ).select(noop("td").alias("td")) + + df.select( + assert_true(lit("INTERVAL '0 00:00:00.000123' DAY TO SECOND") == df.td.cast("string")) + ).collect() + self.assertEqual(df.schema[0].dataType.simpleString(), "interval day to second") + self.assertEqual(df.first()[0], datetime.timedelta(microseconds=123)) + if __name__ == "__main__": from pyspark.sql.tests.test_pandas_udf import * # noqa: F401 diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 08137018ae6b..9aee1050370d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -19,19 +19,14 @@ import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.*; -import org.apache.arrow.vector.holders.NullableIntervalDayHolder; import org.apache.arrow.vector.holders.NullableVarCharHolder; import org.apache.spark.sql.util.ArrowUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.UTF8String; -import static 
org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY; -import static org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS; - /** - * A column vector backed by Apache Arrow. Currently calendar interval type and map type are not - * supported. + * A column vector backed by Apache Arrow. */ public final class ArrowColumnVector extends ColumnVector { @@ -180,8 +175,8 @@ public ArrowColumnVector(ValueVector vector) { accessor = new NullAccessor((NullVector) vector); } else if (vector instanceof IntervalYearVector) { accessor = new IntervalYearAccessor((IntervalYearVector) vector); - } else if (vector instanceof IntervalDayVector) { - accessor = new IntervalDayAccessor((IntervalDayVector) vector); + } else if (vector instanceof DurationVector) { + accessor = new DurationAccessor((DurationVector) vector); } else { throw new UnsupportedOperationException(); } @@ -549,21 +544,18 @@ int getInt(int rowId) { } } - private static class IntervalDayAccessor extends ArrowVectorAccessor { + private static class DurationAccessor extends ArrowVectorAccessor { - private final IntervalDayVector accessor; - private final NullableIntervalDayHolder intervalDayHolder = new NullableIntervalDayHolder(); + private final DurationVector accessor; - IntervalDayAccessor(IntervalDayVector vector) { + DurationAccessor(DurationVector vector) { super(vector); this.accessor = vector; } @Override - long getLong(int rowId) { - accessor.get(rowId, intervalDayHolder); - return Math.addExact(Math.multiplyExact(intervalDayHolder.days, MICROS_PER_DAY), - intervalDayHolder.milliseconds * MICROS_PER_MILLIS); + final long getLong(int rowId) { + return DurationVector.get(accessor.getDataBuffer(), rowId); } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 4065d2354c8f..4254c045ca6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -57,7 +57,7 @@ private[sql] object ArrowUtils { new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) case NullType => ArrowType.Null.INSTANCE case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) - case _: DayTimeIntervalType => new ArrowType.Interval(IntervalUnit.DAY_TIME) + case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND) case _ => throw QueryExecutionErrors.unsupportedDataTypeError(dt.catalogString) } @@ -81,7 +81,7 @@ private[sql] object ArrowUtils { case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType case ArrowType.Null.INSTANCE => NullType case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType() - case di: ArrowType.Interval if di.getUnit == IntervalUnit.DAY_TIME => DayTimeIntervalType() + case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType() case _ => throw QueryExecutionErrors.unsupportedDataTypeError(dt.toString) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index c216d92aa432..7abca5f0e332 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -24,7 +24,6 @@ import org.apache.arrow.vector.complex._ import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters -import org.apache.spark.sql.catalyst.util.DateTimeConstants.{MICROS_PER_DAY, MICROS_PER_MILLIS} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils @@ -77,7 +76,7 @@ object ArrowWriter { new StructWriter(vector, children.toArray) case (NullType, vector: NullVector) => new NullWriter(vector) case (_: YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector) - case (_: DayTimeIntervalType, vector: IntervalDayVector) => new IntervalDayWriter(vector) + case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector) case (dt, _) => throw QueryExecutionErrors.unsupportedDataTypeError(dt.catalogString) } @@ -422,16 +421,13 @@ private[arrow] class IntervalYearWriter(val valueVector: IntervalYearVector) } } -private[arrow] class IntervalDayWriter(val valueVector: IntervalDayVector) +private[arrow] class DurationWriter(val valueVector: DurationVector) extends ArrowFieldWriter { override def setNull(): Unit = { valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - val totalMicroseconds = input.getLong(ordinal) - val days = totalMicroseconds / MICROS_PER_DAY - val millis = (totalMicroseconds % MICROS_PER_DAY) / MICROS_PER_MILLIS - valueVector.set(count, days.toInt, millis.toInt) + valueVector.set(count, input.getLong(ordinal)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index f980a84727ea..a88f423ae01f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.arrow -import org.apache.arrow.vector.IntervalDayVector - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util._ @@ -90,30 +88,6 @@ class ArrowWriterSuite extends SparkFunSuite { Seq(null, 0L, 1000L, -1000L, (Long.MaxValue - 807L), (Long.MinValue + 808L)))) } - test("long overflow for DayTimeIntervalType") - { - val schema = new StructType().add("value", DayTimeIntervalType(), nullable = true) - val writer = ArrowWriter.create(schema, null) - val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - val valueVector = writer.root.getFieldVectors().get(0).asInstanceOf[IntervalDayVector] - - valueVector.set(0, 106751992, 0) - valueVector.set(1, 106751991, Int.MaxValue) - - // first long overflow for test Math.multiplyExact() - val msg = intercept[java.lang.ArithmeticException] { - reader.getLong(0) - }.getMessage - assert(msg.equals("long overflow")) - - // second long overflow for test Math.addExact() - val msg1 = intercept[java.lang.ArithmeticException] { - reader.getLong(1) - }.getMessage - assert(msg1.equals("long overflow")) - writer.root.close() - } - test("get multiple") { def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { val avroDatatype = dt match {
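Taken together, the Scala and Java changes replace the old `IntervalDayVector` encoding (separate day and millisecond fields, which required `Math.multiplyExact`/`Math.addExact` and the overflow test removed above) with Arrow's `DurationVector`, whose signed 64-bit microsecond values match Spark's internal representation of `DayTimeIntervalType`, so writing and reading become a plain long copy. As an end-to-end illustration only — a sketch built on the tests in this patch, assuming an active `SparkSession` bound to `spark`, with made-up UDF and column names — a pandas UDF over a day-time interval column now exercises exactly this `ArrowWriter`/`ArrowColumnVector` path:

```python
import datetime

import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DayTimeIntervalType


@pandas_udf(DayTimeIntervalType())
def double_interval(s: pd.Series) -> pd.Series:
    # The column is serialized through Arrow's Duration type and shows up as timedelta64[ns].
    return s * 2


df = spark.createDataFrame(
    [(datetime.timedelta(hours=1, microseconds=123),)], schema="td interval day to second"
)
doubled = df.select(double_interval("td").alias("doubled")).first().doubled
assert doubled == datetime.timedelta(hours=2, microseconds=246)
```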