Skip to content

Commit e9d5625

Browse files
committed
ByteType and ShortType pushdown to parquet
1 parent d54d8b8 commit e9d5625

File tree

2 files changed

+104
-0
lines changed

2 files changed

+104
-0
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith:
4242
private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
4343
case BooleanType =>
4444
(n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
45+
case ByteType =>
46+
(n: String, v: Any) => FilterApi.eq(
47+
intColumn(n),
48+
Option(v).map(b => b.asInstanceOf[java.lang.Byte].toInt.asInstanceOf[Integer]).orNull)
49+
case ShortType =>
50+
(n: String, v: Any) => FilterApi.eq(
51+
intColumn(n),
52+
Option(v).map(b => b.asInstanceOf[java.lang.Short].toInt.asInstanceOf[Integer]).orNull)
4553
case IntegerType =>
4654
(n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer])
4755
case LongType =>
@@ -69,6 +77,14 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith:
6977
private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
7078
case BooleanType =>
7179
(n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
80+
case ByteType =>
81+
(n: String, v: Any) => FilterApi.notEq(
82+
intColumn(n),
83+
Option(v).map(b => b.asInstanceOf[java.lang.Byte].toInt.asInstanceOf[Integer]).orNull)
84+
case ShortType =>
85+
(n: String, v: Any) => FilterApi.notEq(
86+
intColumn(n),
87+
Option(v).map(b => b.asInstanceOf[java.lang.Short].toInt.asInstanceOf[Integer]).orNull)
7288
case IntegerType =>
7389
(n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer])
7490
case LongType =>
@@ -93,6 +109,14 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith:
93109
}
94110

95111
private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
112+
case ByteType =>
113+
(n: String, v: Any) => FilterApi.lt(
114+
intColumn(n),
115+
Option(v).map(b => b.asInstanceOf[java.lang.Byte].toInt.asInstanceOf[Integer]).orNull)
116+
case ShortType =>
117+
(n: String, v: Any) => FilterApi.lt(
118+
intColumn(n),
119+
Option(v).map(b => b.asInstanceOf[java.lang.Short].toInt.asInstanceOf[Integer]).orNull)
96120
case IntegerType =>
97121
(n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer])
98122
case LongType =>
@@ -116,6 +140,14 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith:
116140
}
117141

118142
private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
143+
case ByteType =>
144+
(n: String, v: Any) => FilterApi.ltEq(
145+
intColumn(n),
146+
Option(v).map(b => b.asInstanceOf[java.lang.Byte].toInt.asInstanceOf[Integer]).orNull)
147+
case ShortType =>
148+
(n: String, v: Any) => FilterApi.ltEq(
149+
intColumn(n),
150+
Option(v).map(b => b.asInstanceOf[java.lang.Short].toInt.asInstanceOf[Integer]).orNull)
119151
case IntegerType =>
120152
(n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer])
121153
case LongType =>
@@ -139,6 +171,14 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith:
139171
}
140172

141173
private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
174+
case ByteType =>
175+
(n: String, v: Any) => FilterApi.gt(
176+
intColumn(n),
177+
Option(v).map(b => b.asInstanceOf[java.lang.Byte].toInt.asInstanceOf[Integer]).orNull)
178+
case ShortType =>
179+
(n: String, v: Any) => FilterApi.gt(
180+
intColumn(n),
181+
Option(v).map(b => b.asInstanceOf[java.lang.Short].toInt.asInstanceOf[Integer]).orNull)
142182
case IntegerType =>
143183
(n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer])
144184
case LongType =>
@@ -162,6 +202,14 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith:
162202
}
163203

164204
private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
205+
case ByteType =>
206+
(n: String, v: Any) => FilterApi.gtEq(
207+
intColumn(n),
208+
Option(v).map(b => b.asInstanceOf[java.lang.Byte].toInt.asInstanceOf[Integer]).orNull)
209+
case ShortType =>
210+
(n: String, v: Any) => FilterApi.gtEq(
211+
intColumn(n),
212+
Option(v).map(b => b.asInstanceOf[java.lang.Short].toInt.asInstanceOf[Integer]).orNull)
165213
case IntegerType =>
166214
(n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer])
167215
case LongType =>

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,62 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
178178
}
179179
}
180180

181+
// Filter-pushdown test for a single optional Byte column (`_1`) holding the values 1 to 4.
// Each checkFilterPredicate call passes a predicate on `_1`, a filter class, and an expected
// result; presumably the helper checks that the predicate is pushed down as that filter class
// and that the query returns that result -- confirm against checkFilterPredicate's definition.
test(s"filter pushdown - ByteType") {
182+
withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toByte)))) { implicit df =>
183+
// Guard: the column must be read back as ByteType, not a wider integral type.
assert(df.schema.head.dataType === ByteType)
184+
// Null checks: isNull is paired with Eq (no rows), isNotNull with NotEq (all four rows).
checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
185+
checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
186+
// Equality, null-safe equality (<=>) and inequality against a Byte literal.
187+
checkFilterPredicate('_1 === 1.toByte, classOf[Eq[_]], 1)
188+
checkFilterPredicate('_1 <=> 1.toByte, classOf[Eq[_]], 1)
189+
checkFilterPredicate('_1 =!= 1.toByte, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
190+
// Ordering comparisons with the column on the left-hand side.
191+
checkFilterPredicate('_1 < 2.toByte, classOf[Lt[_]], 1)
192+
checkFilterPredicate('_1 > 3.toByte, classOf[Gt[_]], 4)
193+
checkFilterPredicate('_1 <= 1.toByte, classOf[LtEq[_]], 1)
194+
checkFilterPredicate('_1 >= 4.toByte, classOf[GtEq[_]], 4)
195+
// Same comparisons with the literal on the left-hand side; the paired filter class is the
// mirrored one (e.g. `Literal(2) > '_1` is paired with Lt).
196+
checkFilterPredicate(Literal(1.toByte) === '_1, classOf[Eq[_]], 1)
197+
checkFilterPredicate(Literal(1.toByte) <=> '_1, classOf[Eq[_]], 1)
198+
checkFilterPredicate(Literal(2.toByte) > '_1, classOf[Lt[_]], 1)
199+
checkFilterPredicate(Literal(3.toByte) < '_1, classOf[Gt[_]], 4)
200+
checkFilterPredicate(Literal(1.toByte) >= '_1, classOf[LtEq[_]], 1)
201+
checkFilterPredicate(Literal(4.toByte) <= '_1, classOf[GtEq[_]], 4)
202+
// Negation and disjunction: `!('_1 < 4)` is paired with GtEq, the `||` with Operators.Or.
203+
checkFilterPredicate(!('_1 < 4.toByte), classOf[GtEq[_]], 4)
204+
checkFilterPredicate('_1 < 2.toByte || '_1 > 3.toByte,
205+
classOf[Operators.Or], Seq(Row(1), Row(4)))
206+
}
207+
}
208+
209+
// Filter-pushdown test for a single optional Short column (`_1`) holding the values 1 to 4.
// Each checkFilterPredicate call passes a predicate on `_1`, a filter class, and an expected
// result; presumably the helper checks that the predicate is pushed down as that filter class
// and that the query returns that result -- confirm against checkFilterPredicate's definition.
test(s"filter pushdown - ShortType") {
210+
withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df =>
211+
// Guard: the column must be read back as ShortType, not a wider integral type.
assert(df.schema.head.dataType === ShortType)
212+
// Null checks: isNull is paired with Eq (no rows), isNotNull with NotEq (all four rows).
checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
213+
checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
214+
// Equality, null-safe equality (<=>) and inequality against a Short literal.
215+
checkFilterPredicate('_1 === 1.toShort, classOf[Eq[_]], 1)
216+
checkFilterPredicate('_1 <=> 1.toShort, classOf[Eq[_]], 1)
217+
checkFilterPredicate('_1 =!= 1.toShort, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
218+
// Ordering comparisons with the column on the left-hand side.
219+
checkFilterPredicate('_1 < 2.toShort, classOf[Lt[_]], 1)
220+
checkFilterPredicate('_1 > 3.toShort, classOf[Gt[_]], 4)
221+
checkFilterPredicate('_1 <= 1.toShort, classOf[LtEq[_]], 1)
222+
checkFilterPredicate('_1 >= 4.toShort, classOf[GtEq[_]], 4)
223+
// Same comparisons with the literal on the left-hand side; the paired filter class is the
// mirrored one (e.g. `Literal(2) > '_1` is paired with Lt).
224+
checkFilterPredicate(Literal(1.toShort) === '_1, classOf[Eq[_]], 1)
225+
checkFilterPredicate(Literal(1.toShort) <=> '_1, classOf[Eq[_]], 1)
226+
checkFilterPredicate(Literal(2.toShort) > '_1, classOf[Lt[_]], 1)
227+
checkFilterPredicate(Literal(3.toShort) < '_1, classOf[Gt[_]], 4)
228+
checkFilterPredicate(Literal(1.toShort) >= '_1, classOf[LtEq[_]], 1)
229+
checkFilterPredicate(Literal(4.toShort) <= '_1, classOf[GtEq[_]], 4)
230+
// Negation and disjunction: `!('_1 < 4)` is paired with GtEq, the `||` with Operators.Or.
231+
checkFilterPredicate(!('_1 < 4.toShort), classOf[GtEq[_]], 4)
232+
checkFilterPredicate('_1 < 2.toShort || '_1 > 3.toShort,
233+
classOf[Operators.Or], Seq(Row(1), Row(4)))
234+
}
235+
}
236+
181237
test("filter pushdown - integer") {
182238
withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df =>
183239
checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])

0 commit comments

Comments
 (0)