|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | + * contributor license agreements. See the NOTICE file distributed with |
| 4 | + * this work for additional information regarding copyright ownership. |
| 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | + * (the "License"); you may not use this file except in compliance with |
| 7 | + * the License. You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | + |
| 18 | +package org.apache.spark.sql.catalyst.expressions.aggregate |
| 19 | + |
| 20 | +import org.apache.spark.sql.AnalysisException |
| 21 | +import org.apache.spark.sql.catalyst.InternalRow |
| 22 | +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder |
| 23 | +import org.apache.spark.sql.catalyst.expressions._ |
| 24 | +import org.apache.spark.sql.catalyst.expressions.aggregate.PercentileApprox.{PercentileBuffer, PercentileUDT} |
| 25 | +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, QuantileSummaries, TypeUtils} |
| 26 | +import org.apache.spark.sql.catalyst.util.QuantileSummaries.defaultCompressThreshold |
| 27 | +import org.apache.spark.sql.types._ |
| 28 | + |
/**
 * Aggregate expression that computes approximate percentile(s) of a double-typed child
 * expression, backed by a [[QuantileSummaries]] sketch held in a [[PercentileBuffer]].
 *
 * NOTE(review): `percentiles` is an `Array`, so the compiler-generated case-class
 * `equals`/`hashCode` compare it by reference; two instances built from equal percentile values
 * are not `==` unless they share the same array instance.
 *
 * @param child the expression whose values are aggregated; evaluated once per input row
 * @param percentiles the percentile(s) to report, each expected to lie in [0, 1]
 * @param B accuracy knob; the relative error used is `max(1 / B, 0.001)`
 * @param dataType result type: `DoubleType` for a single percentile, an `ArrayType` otherwise
 * @param mutableAggBufferOffset offset of this function's mutable aggregation buffer
 * @param inputAggBufferOffset offset of this function's input aggregation buffer
 */
@ExpressionDescription(
  usage = """_FUNC_(col, p [, B]) - Returns an approximate pth percentile of a numeric column in the
     group. The B parameter (default: 1000) controls approximation accuracy at the cost of
     memory. Higher values yield better approximations.
    _FUNC_(col, array(p1 [, p2]...) [, B]) - Same as above, but accepts and returns an array of
     percentile values instead of a single one.
    """)
case class PercentileApprox(
    child: Expression,
    percentiles: Array[Double],
    B: Int,
    dataType: DataType,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[PercentileBuffer] {

  // The error is floored at 0.001, so values of B above 1000 do not tighten it any further.
  private val relativeError = Math.max(1.0d / B, 0.001)

  /**
   * "_FUNC_(col, p, B) / _FUNC_(col, array(p1, ...), B)"
   *
   * Validates the percentile(s) and B expressions eagerly; the companion's validators throw an
   * [[AnalysisException]] when either argument is malformed.
   */
  def this(child: Expression, percentiles: Expression, B: Expression) = {
    this(
      child = child,
      PercentileApprox.validatePercentiles(percentiles),
      PercentileApprox.validateB(B),
      PercentileApprox.getResultDataType(percentiles))
  }

  /** "_FUNC_(col, p) / _FUNC_(col, array(p1, ...))" with B defaulting to 1000. */
  def this(child: Expression, percentile: Expression) = {
    this(child, percentile, Literal(1000))
  }

  /** Creates an empty buffer configured with this expression's relative error. */
  override def createAggregationBuffer(): PercentileBuffer = new PercentileBuffer(relativeError)

  /** Folds the child's value for `input` into `buffer`; null values are skipped. */
  override def update(buffer: PercentileBuffer, input: InternalRow): Unit = {
    val value = child.eval(input)
    // `null.asInstanceOf[Double]` unboxes to 0.0, which would silently count every null input
    // as a zero and skew the result, so nulls must be filtered out before the cast.
    if (value != null) {
      buffer.update(value.asInstanceOf[Double])
    }
  }

  /** Merges the partial state in `input` into `buffer`. */
  override def merge(buffer: PercentileBuffer, input: PercentileBuffer): Unit = {
    buffer.merge(input)
  }

  /**
   * Produces the final result: null when the buffer yields no percentiles, a single double for
   * a `DoubleType` result, or an array of doubles for an `ArrayType` result.
   */
  override def eval(buffer: PercentileBuffer): Any = {
    val result = buffer.query(percentiles)
    if (result.isEmpty) {
      null
    } else {
      dataType match {
        case _: DoubleType => result(0)
        case _: ArrayType => new GenericArrayData(result)
      }
    }
  }

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): PercentileApprox =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): PercentileApprox =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  // Only the aggregated column is a child; percentiles and B are captured as plain values.
  override def children: Seq[Expression] = Seq(child)

  // Returns null for empty inputs
  override def nullable: Boolean = true

  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)

  // Argument list rendered back to SQL, using the scalar or array(...) form per result type.
  private def childSql: String = dataType match {
    case _: DoubleType => s"${child.sql}, ${percentiles(0)}, $B"
    case _: ArrayType => s"${child.sql}, array(${percentiles.mkString(", ")}), $B"
  }

  override def prettyName: String = "percentile_approx"

  override def sql: String = s"$prettyName($childSql)"

  override def sql(isDistinct: Boolean): String = {
    if (isDistinct) s"$prettyName(DISTINCT $childSql)" else s"$prettyName($childSql)"
  }

  override def aggregationBufferClass: Class[PercentileBuffer] = classOf[PercentileBuffer]

  /** Serializes the buffer through [[PercentileBuffer.toBytes]]. */
  override def serialize(buffer: PercentileBuffer): Array[Byte] = buffer.toBytes

  /**
   * Rebuilds a buffer from bytes produced by `serialize`: the bytes are viewed as a
   * single-field [[UnsafeRow]] and decoded with the companion's summary encoder.
   */
  override def deserialize(bytes: Array[Byte]): PercentileBuffer = {
    val row = new UnsafeRow(1)
    row.pointTo(bytes, bytes.length)
    new PercentileBuffer(PercentileApprox.fieldSerializer.fromRow(row))
  }
}
| 116 | + |
/** Argument validation, the summary serializer and the aggregation buffer type. */
object PercentileApprox {

  /**
   * Validates the B expression and extracts its value.
   *
   * @throws AnalysisException if `B` is not a positive integer literal
   */
  private def validateB(B: Expression): Int = B match {
    case Literal(i: Int, IntegerType) if i > 0 => i
    case _ => throw new AnalysisException("The B argument must be a positive integer literal")
  }

  /**
   * Validates the percentile(s) expression and extracts the percentile(s).
   *
   * @param percentileExpression a constant numeric value, or a non-empty array of non-null
   *                             numeric values
   * @return the percentile(s) converted to doubles, each within [0, 1]
   * @throws AnalysisException if the expression is not a constant, evaluates to null, has an
   *                           unsupported type, or holds a value outside [0, 1]
   */
  private def validatePercentiles(percentileExpression: Expression): Array[Double] = {
    def invalidArgument: Nothing =
      throw new AnalysisException("The percentile(s) must be a double or double array literal.")

    // `eval()` without an input row is only meaningful for constant expressions; reject anything
    // else up front so the user gets an analysis error rather than an internal failure.
    if (!percentileExpression.foldable) {
      invalidArgument
    }

    val percentiles = (percentileExpression.dataType, percentileExpression.eval()) match {
      // A null constant must not be handed to the numeric conversion.
      case (num: NumericType, v) if v != null => Array(TypeUtils.getNumeric(num).toDouble(v))
      case (ArrayType(elementType: NumericType, false), v: ArrayData) if v.numElements() > 0 =>
        v.array.map(TypeUtils.getNumeric(elementType).toDouble)
      case _ => invalidArgument
    }

    percentiles.foreach { percentile =>
      if (percentile < 0D || percentile > 1D) {
        throw new AnalysisException("the percentile(s) must be >= 0 and <= 1.0")
      }
    }
    percentiles
  }

  /**
   * Result type implied by the percentile argument: a double for a single numeric percentile,
   * otherwise an array of non-null doubles.
   */
  def getResultDataType(percentileExpression: Expression): DataType = {
    percentileExpression.dataType match {
      case _: NumericType => DoubleType
      case _ => ArrayType(DoubleType, containsNull = false)
    }
  }

  // serializer for QuantileSummaries
  private val fieldSerializer = ExpressionEncoder[QuantileSummaries].resolveAndBind()

  /**
   * Mutable wrapper around a [[QuantileSummaries]], used as the aggregation buffer.
   *
   * NOTE(review): this relies on the QuantileSummaries contract that inserted values are
   * buffered and only folded into the sketch (and its `count`) by `compress()`, and that
   * `query`/`merge` expect compressed summaries -- confirm against the QuantileSummaries version
   * in this tree. Every read path below therefore compresses first.
   *
   * @param summary the initial summary state
   */
  class PercentileBuffer(private var summary: QuantileSummaries) {

    // Conservatively false: the state of a summary handed to the constructor is unknown.
    private var compressed = false

    /** Creates an empty buffer with the default compress threshold. */
    def this(relativeError: Double) = {
      this(new QuantileSummaries(defaultCompressThreshold, relativeError))
    }

    // Returns the summary after making sure pending inserts have been compressed into it.
    private def compressedSummary: QuantileSummaries = {
      if (!compressed) {
        summary = summary.compress()
        compressed = true
      }
      summary
    }

    /** Merges `input` into this buffer; both sides are compressed beforehand. */
    def merge(input: PercentileBuffer): Unit = {
      summary = compressedSummary.merge(input.compressedSummary)
      // The state of the merged summary is not assumed; it is re-compressed on next read.
      compressed = false
    }

    /** Adds a single value to the summary. */
    def update(input: Double): Unit = {
      summary = summary.insert(input)
      compressed = false
    }

    /**
     * Looks up each requested percentile.
     *
     * @return one value per requested percentile, or an empty array when no values have been
     *         added or no percentiles were requested
     */
    def query(percentiles: Array[Double]): Array[Double] = {
      val current = compressedSummary
      if (current.count == 0 || percentiles.length == 0) {
        Array.empty[Double]
      } else {
        percentiles.map(current.query)
      }
    }

    /** Compresses the summary and encodes it as the bytes of an [[UnsafeRow]]. */
    def toBytes: Array[Byte] = {
      fieldSerializer.toRow(compressedSummary).asInstanceOf[UnsafeRow].getBytes
    }
  }
}
0 commit comments