
Commit 0a777cc

percentile
1 parent 24d892f commit 0a777cc

5 files changed, +240 -3 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions

@@ -250,6 +250,7 @@ object FunctionRegistry {
     expression[Average]("mean"),
     expression[Min]("min"),
     expression[Skewness]("skewness"),
+    expression[PercentileApprox]("percentile_approx"),
     expression[StddevSamp]("std"),
     expression[StddevSamp]("stddev"),
     expression[StddevPop]("stddev_pop"),
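Registered this way, percentile_approx resolves like any other native aggregate, without going through the Hive UDAF bridge. A minimal usage sketch, assuming a SparkSession named spark and a temp view nums with a double column x (hypothetical names):

  // Single percentile: returns a double
  spark.sql("SELECT percentile_approx(x, 0.95) FROM nums").show()

  // Array form: returns one value per requested percentile; 100 is the accuracy knob B
  spark.sql("SELECT percentile_approx(x, array(0.25, 0.5, 0.75), 100) FROM nums").show()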
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileApprox.scala

Lines changed: 182 additions & 0 deletions

@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.PercentileApprox.PercentileBuffer
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, QuantileSummaries, TypeUtils}
+import org.apache.spark.sql.catalyst.util.QuantileSummaries.defaultCompressThreshold
+import org.apache.spark.sql.types._
+
+@ExpressionDescription(
+  usage = """_FUNC_(col, p [, B]) - Returns an approximate pth percentile of a numeric column
+      in the group. The B parameter (default: 1000) controls approximation accuracy at the
+      cost of memory. Higher values yield better approximations.
+      _FUNC_(col, array(p1 [, p2]...) [, B]) - Same as above, but accepts and returns an array
+      of percentile values instead of a single one.
+    """)
+case class PercentileApprox(
+    child: Expression,
+    percentiles: Array[Double],
+    B: Int,
+    dataType: DataType,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[PercentileBuffer] {
+
+  private val relativeError = Math.max(1.0d / B, 0.001)
+
+  // "_FUNC_(col, p, B) / _FUNC_(col, array(p1, ...), B)"
+  def this(child: Expression, percentiles: Expression, B: Expression) = {
+    this(
+      child,
+      PercentileApprox.validatePercentiles(percentiles),
+      PercentileApprox.validateB(B),
+      PercentileApprox.getResultDataType(percentiles))
+  }
+
+  // "_FUNC_(col, p) / _FUNC_(col, array(p1, ...))"
+  def this(child: Expression, percentile: Expression) = {
+    this(child, percentile, Literal(1000))
+  }
+
+  override def createAggregationBuffer(): PercentileBuffer = new PercentileBuffer(relativeError)
+
+  override def update(buffer: PercentileBuffer, input: InternalRow): Unit = {
+    buffer.update(child.eval(input).asInstanceOf[Double])
+  }
+
+  override def merge(buffer: PercentileBuffer, input: PercentileBuffer): Unit = {
+    buffer.merge(input)
+  }
+
+  override def eval(buffer: PercentileBuffer): Any = {
+    val result = buffer.query(percentiles)
+    dataType match {
+      case _ if result.isEmpty => null
+      case _: DoubleType => result(0)
+      case _: ArrayType => new GenericArrayData(result)
+    }
+  }
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): PercentileApprox =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): PercentileApprox =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def children: Seq[Expression] = Seq(child)
+
+  // Returns null for empty inputs
+  override def nullable: Boolean = true
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
+
+  private def childSql: String = dataType match {
+    case _: DoubleType => s"${child.sql}, ${percentiles(0)}, $B"
+    case _: ArrayType => s"${child.sql}, array(${percentiles.mkString(", ")}), $B"
+  }
+
+  override def prettyName: String = "percentile_approx"
+
+  override def sql: String = s"$prettyName($childSql)"
+
+  override def sql(isDistinct: Boolean): String = {
+    if (isDistinct) s"$prettyName(DISTINCT $childSql)" else s"$prettyName($childSql)"
+  }
+
+  override def aggregationBufferClass: Class[PercentileBuffer] = classOf[PercentileBuffer]
+
+  override def serialize(buffer: PercentileBuffer): Array[Byte] = buffer.toBytes
+
+  override def deserialize(bytes: Array[Byte]): PercentileBuffer = {
+    val row = new UnsafeRow(1)
+    row.pointTo(bytes, bytes.length)
+    new PercentileBuffer(PercentileApprox.fieldSerializer.fromRow(row))
+  }
+}
+
+object PercentileApprox {
+
+  /** Validates the B expression and extracts its value. */
+  private def validateB(B: Expression): Int = B match {
+    case Literal(i: Int, IntegerType) if i > 0 => i
+    case _ => throw new AnalysisException("The B argument must be a positive integer literal")
+  }
+
+  /** Validates the percentile(s) expression and extracts the percentile(s). */
+  private def validatePercentiles(percentileExpression: Expression): Array[Double] = {
+    val percentiles = (percentileExpression.dataType, percentileExpression.eval()) match {
+      case (num: NumericType, v) => Array(TypeUtils.getNumeric(num).toDouble(v))
+      case (ArrayType(elementType: NumericType, false), v: ArrayData) if v.numElements() > 0 =>
+        v.array.map(TypeUtils.getNumeric(elementType).toDouble)
+      case _ =>
+        throw new AnalysisException("The percentile(s) must be a double or double array literal.")
+    }
+
+    percentiles.foreach { percentile =>
+      if (percentile < 0D || percentile > 1D) {
+        throw new AnalysisException("The percentile(s) must be >= 0.0 and <= 1.0.")
+      }
+    }
+    percentiles
+  }
+
+  def getResultDataType(percentileExpression: Expression): DataType = {
+    percentileExpression.dataType match {
+      case _: NumericType => DoubleType
+      case _ => ArrayType(DoubleType, containsNull = false)
+    }
+  }
+
+  // Serializer for QuantileSummaries
+  private val fieldSerializer = ExpressionEncoder[QuantileSummaries].resolveAndBind()
+
+  class PercentileBuffer(private var summary: QuantileSummaries) {
+
+    def this(relativeError: Double) = {
+      this(new QuantileSummaries(defaultCompressThreshold, relativeError))
+    }
+
+    def merge(input: PercentileBuffer): Unit = summary = summary.merge(input.summary)
+
+    def update(input: Double): Unit = summary = summary.insert(input)
+
+    def query(percentiles: Array[Double]): Array[Double] = {
+      if (summary.count == 0 || percentiles.length == 0) {
+        Array.empty[Double]
+      } else {
+        val result = new Array[Double](percentiles.length)
+        var i = 0
+        while (i < percentiles.length) {
+          result(i) = summary.query(percentiles(i))
+          i += 1
+        }
+        result
+      }
+    }
+
+    def toBytes: Array[Byte] = {
+      summary = summary.compress()
+      fieldSerializer.toRow(summary).asInstanceOf[UnsafeRow].getBytes
+    }
+  }
+}
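The expression delegates the actual sketching to QuantileSummaries, with PercentileBuffer as a thin mutable wrapper around an immutable summary. A standalone sketch of the buffer lifecycle (createAggregationBuffer / update / merge / eval), using the QuantileSummaries API as this commit uses it; the partition split is illustrative:

  import org.apache.spark.sql.catalyst.util.QuantileSummaries
  import org.apache.spark.sql.catalyst.util.QuantileSummaries.defaultCompressThreshold

  // createAggregationBuffer: one summary per partition; with the default B = 1000,
  // relativeError = max(1.0 / 1000, 0.001) = 0.001
  var left = new QuantileSummaries(defaultCompressThreshold, 0.001)
  var right = new QuantileSummaries(defaultCompressThreshold, 0.001)

  // update: insert is persistent and returns a new summary, hence the reassignment
  // (this is why PercentileBuffer holds its summary in a var)
  (1 to 500).foreach(n => left = left.insert(n.toDouble))
  (501 to 1000).foreach(n => right = right.insert(n.toDouble))

  // merge: summaries must be compressed first, as toBytes also does before shipping
  val merged = left.compress().merge(right.compress())

  // eval: query each requested percentile; ~500.0 here, within 0.1% rank error
  println(merged.query(0.5))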

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala

Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats
  * @param count the count of all the elements *inserted in the sampled buffer*
  *              (excluding the head buffer)
  */
-class QuantileSummaries(
+case class QuantileSummaries(
     val compressThreshold: Int,
     val relativeError: Double,
     val sampled: Array[Stats] = Array.empty,
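This one-line change is what makes the fieldSerializer above work: ExpressionEncoder derives its serializer by reflecting over Product (case class) types, so the old plain class could not be encoded directly. A minimal round-trip sketch (not from the commit) of the derived encoder:

  import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
  import org.apache.spark.sql.catalyst.util.QuantileSummaries
  import org.apache.spark.sql.catalyst.util.QuantileSummaries.defaultCompressThreshold

  // Deriving this encoder requires QuantileSummaries to be a Product type
  val enc = ExpressionEncoder[QuantileSummaries].resolveAndBind()

  val summary = new QuantileSummaries(defaultCompressThreshold, 0.01)
    .insert(1.0).insert(2.0).compress()

  // toRow flattens the constructor fields into an InternalRow (copy() detaches it
  // from the encoder's reused buffer); fromRow rebuilds an equivalent summary
  val restored = enc.fromRow(enc.toRow(summary).copy())
  assert(restored.count == summary.count)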

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala

Lines changed: 1 addition & 2 deletions

@@ -237,7 +237,6 @@ private[sql] class HiveSessionCatalog(
   private val hiveFunctions = Seq(
     "hash",
     "histogram_numeric",
-    "percentile",
-    "percentile_approx"
+    "percentile"
   )
 }
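With PercentileApprox registered natively in FunctionRegistry, percentile_approx no longer needs to be looked up as a Hive UDAF, so it is dropped from this fallback list and only percentile remains delegated to Hive.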

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 55 additions & 0 deletions

@@ -728,6 +728,61 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
       Row(0, null, 1, 1, null, 0) :: Nil)
   }
 
+  test("percentile_approx") {
+    val table = "percentile_approx_nums"
+    (1 to 1000).toDF("num").createOrReplaceTempView(table)
+
+    // with default B
+    checkAnswer(
+      spark.sql(
+        s"""
+          |SELECT
+          |  percentile_approx(num, 0.1),
+          |  percentile_approx(num, 0.2d),
+          |  percentile_approx(num, array(0.3)),
+          |  percentile_approx(num, array(0.4d)),
+          |  percentile_approx(num, array(0.5, 0.6, 0.7)),
+          |  percentile_approx(num, array(0.8d, 0.9d, 1.0d))
+          |FROM $table
+        """.stripMargin),
+      Row(100.0, 200.0, Seq(300.0), Seq(400.0),
+        Seq(500.0, 600.0, 700.0),
+        Seq(800.0, 900.0, 1000.0)) :: Nil) // results are stable
+
+    // with B specified
+    checkAnswer(
+      spark.sql(
+        s"""
+          |SELECT
+          |  percentile_approx(num, 0.1, 10),
+          |  percentile_approx(num, 0.2d, 10),
+          |  percentile_approx(num, array(0.3), 10),
+          |  percentile_approx(num, array(0.4d), 10),
+          |  percentile_approx(num, array(0.5, 0.6, 0.7), 10),
+          |  percentile_approx(num, array(0.8d, 0.9d, 1.0d), 10)
+          |FROM $table
+        """.stripMargin),
+      Row(1.0, 156.0, Seq(296.0), Seq(413.0),
+        Seq(413.0, 510.0, 659.0),
+        Seq(715.0, 1000.0, 1000.0)) :: Nil) // results are stable
+  }
+
+  test("percentile_approx with empty inputs") {
+    val table = "percentile_approx_nums"
+    (1 to 1000).toDF("num").createOrReplaceTempView(table)
+
+    checkAnswer(
+      spark.sql(
+        s"""
+          |SELECT
+          |  percentile_approx(num, 0.5),
+          |  percentile_approx(num, array(0.5d))
+          |FROM $table
+          |WHERE num > 2000
+        """.stripMargin),
+      Row(null, null) :: Nil)
+  }
+
   test("pearson correlation") {
     val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")
     val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
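A note on the B = 10 expectations above: relativeError = max(1.0 / 10, 0.001) = 0.1, and the summary's Greenwald-Khanna-style guarantee bounds the rank error by roughly relativeError * count, i.e. about 100 positions over 1000 rows. That is why percentile_approx(num, 0.1, 10) may legitimately return 1.0 where the exact answer is 100.0, and why the assertions can still be exact: the sketch is deterministic for a fixed input order, which appears to be what the "results are stable" comments rely on.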
