[SPARK-7242][SQL][MLLIB] Frequent items for DataFrames #5799
Changes from all commits: 3d82168, 8279d4d, 38e784d, 482e741, 3a5c177, 0915e23, 39b1bba, a6ec82c.
**DataFrameStatFunctions.scala** (new file, +68 lines)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.stat.FrequentItems

/**
 * :: Experimental ::
 * Statistic functions for [[DataFrame]]s.
 */
@Experimental
final class DataFrameStatFunctions private[sql](df: DataFrame) {

  /**
   * Finding frequent items for columns, possibly with false positives, using the
   * frequent element count algorithm described in
   * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
   * The `support` should be greater than 1e-4.
   *
   * @param cols the names of the columns to search frequent items in.
   * @param support The minimum frequency for an item to be considered `frequent`.
   *                Should be greater than 1e-4.
   * @return A Local DataFrame with the Array of frequent items for each column.
   */
  def freqItems(cols: Array[String], support: Double): DataFrame = {
    FrequentItems.singlePassFreqItems(df, cols, support)
  }

  /**
   * Runs `freqItems` with a default `support` of 1%.
```
> **Contributor:** It's better to just put the same javadoc and say `support = 0.01`, rather than saying "runs `freqItems`…".
```scala
   *
   * @param cols the names of the columns to search frequent items in.
   * @return A Local DataFrame with the Array of frequent items for each column.
   */
  def freqItems(cols: Array[String]): DataFrame = {
    FrequentItems.singlePassFreqItems(df, cols, 0.01)
  }

  /**
   * Python friendly implementation for `freqItems`.
   */
  def freqItems(cols: List[String], support: Double): DataFrame = {
```
> **Contributor:** I think we can just use `Seq` here, since Python has helper functions that can convert `List` into `Seq`.
```scala
    FrequentItems.singlePassFreqItems(df, cols, support)
  }

  /**
   * Python friendly implementation for `freqItems` with a default `support` of 1%.
   */
  def freqItems(cols: List[String]): DataFrame = {
    FrequentItems.singlePassFreqItems(df, cols, 0.01)
  }
}
```
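For context, here is a minimal usage sketch of the new API. The data and column names are made up for illustration, and it assumes a Spark 1.4-era `SQLContext` named `sqlContext` with its implicits in scope:

```scala
import sqlContext.implicits._

val df = sqlContext.sparkContext
  .parallelize(Seq((1, "a"), (1, "b"), (1, "a"), (2, "a")))
  .toDF("id", "value")

// Candidate items that occur in at least 40% of the rows, computed per column.
val freq = df.stat.freqItems(Array("id", "value"), 0.4)
freq.show()
// Single local row with one array column per input column, e.g.
// id_freqItems = [1, 2], value_freqItems = [a, b] -- the algorithm may
// include false positives (here 2 and b, each at 25% frequency), but it
// never misses a truly frequent item.
```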
**FrequentItems.scala** (new file, +121 lines)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.stat

import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.Logging
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{ArrayType, StructField, StructType}

private[sql] object FrequentItems extends Logging {

  /** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */
  private class FreqItemCounter(size: Int) extends Serializable {
    val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long]

    /**
     * Add a new example to the counts: increment its count if it is already tracked,
     * insert it if there is room, and otherwise deduct its count from all existing items.
     */
    def add(key: Any, count: Long): this.type = {
      if (baseMap.contains(key)) {
        baseMap(key) += count
      } else {
        if (baseMap.size < size) {
          baseMap += key -> count
        } else {
          // TODO: Make this more efficient... A flatMap?
          baseMap.retain((k, v) => v > count)
          baseMap.transform((k, v) => v - count)
        }
      }
      this
    }

    /**
     * Merge two maps of counts.
     * @param other The map containing the counts for that partition
     */
    def merge(other: FreqItemCounter): this.type = {
      other.baseMap.foreach { case (k, v) =>
        add(k, v)
      }
      this
    }
  }

  /**
   * Finding frequent items for columns, possibly with false positives, using the
   * frequent element count algorithm described in
   * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
   * The `support` should be greater than 1e-4.
   * For internal use only.
   *
   * @param df The input DataFrame
   * @param cols the names of the columns to search frequent items in
   * @param support The minimum frequency for an item to be considered `frequent`.
   *                Should be greater than 1e-4.
   * @return A Local DataFrame with the Array of frequent items for each column.
   */
  private[sql] def singlePassFreqItems(
      df: DataFrame,
      cols: Seq[String],
      support: Double): DataFrame = {
    require(support >= 1e-4, s"support ($support) must be at least 1e-4.")
    val numCols = cols.length
    // number of max items to keep counts for
    val sizeOfMap = (1 / support).toInt
    val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap))
    val originalSchema = df.schema
    val colInfo = cols.map { name =>
      val index = originalSchema.fieldIndex(name)
      (name, originalSchema.fields(index).dataType)
    }

    val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
      seqOp = (counts, row) => {
        var i = 0
        while (i < numCols) {
          val thisMap = counts(i)
          val key = row.get(i)
          thisMap.add(key, 1L)
          i += 1
        }
        counts
      },
      combOp = (baseCounts, counts) => {
        var i = 0
        while (i < numCols) {
          baseCounts(i).merge(counts(i))
          i += 1
        }
        baseCounts
      }
    )
    val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
    val resultRow = Row(justItems : _*)
    // append frequent Items to the column name for easy debugging
    val outputCols = colInfo.map { v =>
      StructField(v._1 + "_freqItems", ArrayType(v._2, false))
    }
    val schema = StructType(outputCols).toAttributes
    new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow)))
  }
}
```
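The 1e-4 lower bound on `support` exists because the counter map holds up to `(1 / support).toInt` entries, so the smallest allowed support caps memory at 10,000 tracked items per column. To make the counting scheme itself concrete, here is a self-contained sketch of the same idea over a plain Scala sequence (the function name and example data are illustrative, not part of the PR):

```scala
import scala.collection.mutable

// Karp/Schenker/Papadimitriou-style counting: keep at most k = floor(1/support)
// counters. When a new item arrives and no counter is free, decrement every
// counter and drop the ones that reach zero. Any item occurring in more than a
// `support` fraction of the stream is guaranteed to survive; some infrequent
// items may survive too (false positives), so the result is a candidate set.
def freqItemCandidates[T](stream: Seq[T], support: Double): Set[T] = {
  val k = (1 / support).toInt
  val counts = mutable.Map.empty[T, Long]
  for (item <- stream) {
    if (counts.contains(item)) {
      counts(item) += 1
    } else if (counts.size < k) {
      counts(item) = 1L
    } else {
      counts.transform((_, v) => v - 1) // charge the new item against everyone
      counts.retain((_, v) => v > 0)    // evict counters that hit zero
    }
  }
  counts.keySet.toSet
}

// "a" fills half of the stream, so with support = 0.3 it must be in the result:
// freqItemCandidates(Seq("a", "b", "a", "c", "a", "d", "a", "e"), 0.3) contains "a"
```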
**DataFrameStatSuite.scala** (new file, +47 lines)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import org.scalatest.FunSuite
import org.scalatest.Matchers._

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._

class DataFrameStatSuite extends FunSuite {

  val sqlCtx = TestSQLContext

  test("Frequent Items") {
    def toLetter(i: Int): String = (i + 96).toChar.toString
    val rows = Array.tabulate(1000) { i =>
      if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
    }
    val df = sqlCtx.sparkContext.parallelize(rows).toDF("numbers", "letters", "negDoubles")

    val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
    val items = results.collect().head
    items.getSeq[Int](0) should contain (1)
    items.getSeq[String](1) should contain (toLetter(1))

    val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
    val items2 = singleColResults.collect().head
    items2.getSeq[Double](0) should contain (-1.0)
  }
}
```
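A quick sanity check on the test's expectations: `Array.tabulate(1000)` runs `i` from 0 to 999, and `i % 3 == 0` holds for 334 of those values, so the row `(1, "a", -1.0)` fills roughly a third of the DataFrame (`toLetter(1)` is `"a"`, since `1 + 96 = 97`). A frequency of ~0.334 is well above the 0.1 support threshold, and the algorithm never misses a truly frequent item, so `1`, `"a"`, and `-1.0` must all appear in their columns' results.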
> Make sure you document the range of `support` allowed.