@@ -30,7 +30,6 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.util.collection.Utils
 
 object StatFunctions extends Logging {
 
@@ -188,47 +187,14 @@ object StatFunctions extends Logging {
 
   /** Generate a table of frequencies for the elements of two columns. */
   def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
-    val tableName = s"${col1}_$col2"
-    val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt)
-    if (counts.length == 1e6.toInt) {
-      logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " +
-        "the pairs. Please try reducing the amount of distinct items in your columns.")
-    }
-    def cleanElement(element: Any): String = {
-      if (element == null) "null" else element.toString
-    }
-    // get the distinct sorted values of column 2, so that we can make them the column names
-    val distinctCol2: Map[Any, Int] =
-      Utils.toMapWithIndex(counts.map(e => cleanElement(e.get(1))).distinct.sorted)
-    val columnSize = distinctCol2.size
-    require(columnSize < 1e4, s"The number of distinct values for $col2, can't " +
-      s"exceed 1e4. Currently $columnSize")
-    val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) =>
-      val countsRow = new GenericInternalRow(columnSize + 1)
-      rows.foreach { (row: Row) =>
-        // row.get(0) is column 1
-        // row.get(1) is column 2
-        // row.get(2) is the frequency
-        val columnIndex = distinctCol2(cleanElement(row.get(1)))
-        countsRow.setLong(columnIndex + 1, row.getLong(2))
-      }
-      // the value of col1 is the first value, the rest are the counts
-      countsRow.update(0, UTF8String.fromString(cleanElement(col1Item)))
-      countsRow
-    }.toSeq
-    // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept
-    // special keywords and `.`, wrap the column names in ``.
-    def cleanColumnName(name: String): String = {
-      name.replace("`", "")
-    }
-    // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in
-    // SPARK-8681. We need to explicitly sort by the column index and assign the column names.
-    val headerNames = distinctCol2.toSeq.sortBy(_._2).map { r =>
-      StructField(cleanColumnName(r._1.toString), LongType)
-    }
-    val schema = StructType(StructField(tableName, StringType) +: headerNames)
-
-    Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
+    df.groupBy(
+      when(isnull(col(col1)), "null")
+        .otherwise(col(col1).cast("string"))
+        .as(s"${col1}_$col2")
+    ).pivot(
+      when(isnull(col(col2)), "null")
+        .otherwise(regexp_replace(col(col2).cast("string"), "`", ""))
+    ).count().na.fill(0L)
   }
 
   /** Calculate selected summary statistics for a dataset */
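For reference, the rewrite expresses the whole cross-tabulation as a single groupBy/pivot query, so the null handling, backtick stripping, and zero-filling that the old code performed by hand on rows collected to the driver (up to 1e6 of them via take) now run inside the engine. Below is a minimal, self-contained sketch of the same query shape against toy data; the SparkSession setup, the object name CrosstabSketch, and the sample columns c1/c2 are illustrative assumptions, not part of this patch:

// Sketch only: mirrors the shape of the new crossTabulate body outside StatFunctions.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object CrosstabSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("crosstab-sketch")
      .getOrCreate()
    import spark.implicits._

    // Toy data covering the two normalizations in the diff: nulls and backticks.
    val df = Seq(
      (Option("a"), Option("x`y")),
      (Option("a"), Option.empty[String]),
      (Option.empty[String], Option("x`y")),
      (Option("b"), Option("z"))
    ).toDF("c1", "c2")

    // Same query shape as the new implementation: nulls in either column become
    // the string "null", backticks are stripped from the pivoted values (they are
    // not allowed in DataFrame column names), and (c1, c2) pairs that never
    // co-occur get a 0 count instead of null.
    val crosstab = df
      .groupBy(
        when(isnull(col("c1")), "null")
          .otherwise(col("c1").cast("string"))
          .as("c1_c2"))
      .pivot(
        when(isnull(col("c2")), "null")
          .otherwise(regexp_replace(col("c2").cast("string"), "`", "")))
      .count()
      .na.fill(0L)

    crosstab.show()
    spark.stop()
  }
}

In user code this path is reached through the public API, e.g. df.stat.crosstab("c1", "c2"), which delegates to StatFunctions.crossTabulate. Note that the old code's explicit 1e4 cap on distinct values of col2 is not carried over literally; the pivot path enforces its own bound through the spark.sql.pivotMaxValues setting instead.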