11 changes: 11 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -330,6 +330,17 @@ class DataFrame private[sql](
*/
def na: DataFrameNaFunctions = new DataFrameNaFunctions(this)

/**
* Returns a [[DataFrameStatFunctions]] for working with statistic functions.
* {{{
* // Find frequent items in the column named 'a'.
* df.stat.freqItems(Array("a"))
* }}}
*
* @group dfops
*/
def stat: DataFrameStatFunctions = new DataFrameStatFunctions(this)
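A brief usage sketch of the new entry point (hedged: `people` and the column name "age" are placeholders, not part of the patch):

// Hypothetical usage; `people` is any DataFrame with an "age" column.
val stats: DataFrameStatFunctions = people.stat
val frequentAges = stats.freqItems(Array("age"), 0.3)  // Array-based overload added below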

/**
* Cartesian join with another [[DataFrame]].
*
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.stat.FrequentItems

/**
* :: Experimental ::
* Statistic functions for [[DataFrame]]s.
*/
@Experimental
final class DataFrameStatFunctions private[sql](df: DataFrame) {

/**
* Finds frequent items for columns, possibly with false positives, using the
* frequent element count algorithm described in
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
* The `support` must be at least 1e-4.
*
Contributor review comment: make sure you document the range of support allowed.

* @param cols the names of the columns to search frequent items in.
* @param support The minimum frequency for an item to be considered `frequent`. Must be at
*                least 1e-4.
* @return A local DataFrame with the array of frequent items for each column.
*/
def freqItems(cols: Array[String], support: Double): DataFrame = {
FrequentItems.singlePassFreqItems(df, cols, support)
}

/**
* Runs `freqItems` with a default `support` of 1% (0.01).
Contributor review comment: It's better to just put the same javadoc, and say support = 0.01. Rather than saying run freqItems...

*
* @param cols the names of the columns to search frequent items in.
* @return A local DataFrame with the array of frequent items for each column.
*/
def freqItems(cols: Array[String]): DataFrame = {
FrequentItems.singlePassFreqItems(df, cols, 0.01)
}
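A hedged usage sketch of the two Array-based overloads above; `df`, the column names, and the support value are placeholders, and the one-row result is read back the same way the tests later in this patch do:

// Hypothetical usage; `df` is a placeholder DataFrame with columns "payment" and "region".
val explicitSupport = df.stat.freqItems(Array("payment", "region"), 0.25) // support given explicitly
val defaultSupport  = df.stat.freqItems(Array("payment", "region"))       // support defaults to 0.01
// Both calls return a one-row local DataFrame; each column holds an array of candidate items.
val row = explicitSupport.collect().head
val frequentPayments = row.getSeq[Any](0) // candidates for "payment"
val frequentRegions  = row.getSeq[Any](1) // candidates for "region"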

/**
* Python-friendly variant of `freqItems`.
*/
def freqItems(cols: List[String], support: Double): DataFrame = {
Contributor review comment: I think we can just use Seq here, since Python has helper functions that can convert List into Seq.

FrequentItems.singlePassFreqItems(df, cols, support)
}

/**
* Python-friendly variant of `freqItems` with a default `support` of 1% (0.01).
*/
def freqItems(cols: List[String]): DataFrame = {
FrequentItems.singlePassFreqItems(df, cols, 0.01)
}
}
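A minimal sketch of the Seq-based alternative the reviewer suggests above, which would replace the two List-specific overloads; the exact signatures are an assumption, not part of this patch:

// Hypothetical Seq-based overloads (would sit inside DataFrameStatFunctions); per the
// review comment, Python-side helpers can convert a Python list into a Scala Seq.
def freqItems(cols: Seq[String], support: Double): DataFrame = {
  FrequentItems.singlePassFreqItems(df, cols, support)
}

def freqItems(cols: Seq[String]): DataFrame = freqItems(cols, 0.01)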
@@ -0,0 +1,121 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.stat

import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.Logging
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{ArrayType, StructField, StructType}

private[sql] object FrequentItems extends Logging {

/** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */
private class FreqItemCounter(size: Int) extends Serializable {
val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long]

/**
* Adds `count` to `key`'s count if the key is already tracked. Otherwise, inserts the key
* if there is room; if the map is full, drops every entry whose count is at most `count`
* and decrements the remaining counts by `count` (the new key itself is not inserted).
*/
def add(key: Any, count: Long): this.type = {
if (baseMap.contains(key)) {
baseMap(key) += count
} else {
if (baseMap.size < size) {
baseMap += key -> count
} else {
// TODO: Make this more efficient... A flatMap?
baseMap.retain((k, v) => v > count)
baseMap.transform((k, v) => v - count)
}
}
this
}

/**
* Merges the counts from another counter into this one.
* @param other The counter holding the counts from another partition.
*/
def merge(other: FreqItemCounter): this.type = {
other.baseMap.foreach { case (k, v) =>
add(k, v)
}
this
}
}
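To make the decrement-on-overflow behaviour of `add` concrete, here is a small trace (illustrative values; it assumes visibility of the private class, e.g. from within this object or a same-package test):

val c = new FreqItemCounter(2) // track at most 2 distinct keys
c.add("a", 1L)                 // baseMap: a -> 1
c.add("b", 1L)                 // baseMap: a -> 1, b -> 1
c.add("c", 1L)                 // map is full and "c" is new: entries whose counts are not
                               // strictly greater than 1 are dropped and the rest are reduced
                               // by 1, so the map becomes empty; "c" itself is not inserted
c.add("a", 1L)                 // baseMap: a -> 1 again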

/**
* Finds frequent items for columns, possibly with false positives, using the
* frequent element count algorithm described in
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
* The `support` must be at least 1e-4.
* For internal use only.
*
* @param df The input DataFrame
* @param cols the names of the columns to search frequent items in
* @param support The minimum frequency for an item to be considered `frequent`. Must be at
*                least 1e-4.
* @return A local DataFrame with the array of frequent items for each column.
*/
private[sql] def singlePassFreqItems(
df: DataFrame,
cols: Seq[String],
support: Double): DataFrame = {
require(support >= 1e-4, s"support ($support) must be at least 1e-4.")
val numCols = cols.length
// number of max items to keep counts for
val sizeOfMap = (1 / support).toInt
val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap))
val originalSchema = df.schema
val colInfo = cols.map { name =>
val index = originalSchema.fieldIndex(name)
(name, originalSchema.fields(index).dataType)
}

val freqItems = df.select(cols.map(Column(_)):_*).rdd.aggregate(countMaps)(
seqOp = (counts, row) => {
var i = 0
while (i < numCols) {
val thisMap = counts(i)
val key = row.get(i)
thisMap.add(key, 1L)
i += 1
}
counts
},
combOp = (baseCounts, counts) => {
var i = 0
while (i < numCols) {
baseCounts(i).merge(counts(i))
i += 1
}
baseCounts
}
)
val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
val resultRow = Row(justItems:_*)
// append `_freqItems` to each column name for easier debugging
val outputCols = colInfo.map { v =>
StructField(v._1 + "_freqItems", ArrayType(v._2, false))
}
val schema = StructType(outputCols).toAttributes
new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow)))
}
}
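As a concrete reading of the map sizing above: `sizeOfMap = (1 / support).toInt` bounds how many candidates each column's counter tracks, and the Karp et al. guarantee is that any value occurring in more than a `support` fraction of the rows survives the decrements and is reported; rarer values may also survive, which is where the possible false positives come from. A small illustrative check:

// Illustrative only: candidate-map sizes for a few support values.
Seq(0.01, 0.25, 1e-4).foreach { s =>
  val sizeOfMap = (1 / s).toInt
  println(s"support = $s -> at most $sizeOfMap candidates tracked per column")
}
// Prints 100, 4, and 10000 respectively (1e-4 being the smallest allowed support).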
@@ -22,10 +22,7 @@

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.TestData$;
import org.apache.spark.sql.*;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.*;
@@ -178,5 +175,12 @@ public void testCreateDataFrameFromJavaBeans() {
Assert.assertEquals(bean.getD().get(i), d.apply(i));
}
}


@Test
public void testFrequentItems() {
DataFrame df = context.table("testData2");
String[] cols = new String[]{"a"};
DataFrame results = df.stat().freqItems(cols, 0.2);
Assert.assertTrue(results.collect()[0].getSeq(0).contains(1));
}
}
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.scalatest.FunSuite
import org.scalatest.Matchers._

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._

class DataFrameStatSuite extends FunSuite {

val sqlCtx = TestSQLContext

test("Frequent Items") {
def toLetter(i: Int): String = (i + 96).toChar.toString
val rows = Array.tabulate(1000) { i =>
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
}
val df = sqlCtx.sparkContext.parallelize(rows).toDF("numbers", "letters", "negDoubles")

val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
val items = results.collect().head
items.getSeq[Int](0) should contain (1)
items.getSeq[String](1) should contain (toLetter(1))

val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
val items2 = singleColResults.collect().head
items2.getSeq[Double](0) should contain (-1.0)

}
}