[SPARK-7242][SQL][MLLIB] Frequent items for DataFrames #5799
DataFrame.scala:

```scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.ml.FrequentItems
import org.apache.spark.sql.types._
import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
import org.apache.spark.util.Utils

@@ -1414,4 +1415,25 @@ class DataFrame private[sql](
    val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
    SerDeUtil.javaToPython(jrdd)
  }

  /////////////////////////////////////////////////////////////////////////////
  // Statistic functions
  /////////////////////////////////////////////////////////////////////////////

  // scalastyle:off
  object stat {
  // scalastyle:on

    /**
     * Finds frequent items for columns, possibly with false positives, using the
     * algorithm described in `http://www.cs.umd.edu/~samir/498/karp.pdf`.
```
Contributor: Should do a proper citation rather than giving a URL, since this URL might disappear.

Contributor: Use the DOI link: http://dx.doi.org/10.1145/762471.762473
```scala
     *
     * @param cols the names of the columns to search frequent items in
     * @param support The minimum frequency for an item to be considered `frequent`
     * @return A local DataFrame with an Array of frequent items for each column.
     */
    def freqItems(cols: Array[String], support: Double): DataFrame = {
      FrequentItems.singlePassFreqItems(toDF(), cols, support)
    }
  }
}
```
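For context, a hedged sketch of how this API is meant to be called from Scala. The DataFrame `df` and the column name "payment" are made up for illustration:

```scala
// Returns a one-row local DataFrame whose "payment-freqItems" column holds an
// array of items that may appear in at least 10% of the rows. Per the
// algorithm above, false positives are possible, but false negatives are not.
val freqs = df.stat.freqItems(Array("payment"), 0.1)
freqs.show()
```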
FrequentItems.scala (new file):

```scala
@@ -0,0 +1,124 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.ml

import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{StructType, ArrayType, StructField}

import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.Logging
import org.apache.spark.sql.{Row, DataFrame, functions}

private[sql] object FrequentItems extends Logging {

  /**
   * Merges two maps of counts. Drops entries of `baseMap` whose counts could have been
   * cancelled out by items that only `otherMap` tracked, and fills any freed slots with
   * the most frequent items of `otherMap`.
   * @param baseMap The map containing the global counts
   * @param otherMap The map containing the counts for that partition
   * @param maxSize The maximum number of counts to keep in memory
   */
  private def mergeCounts[A](
      baseMap: MutableMap[A, Long],
      otherMap: MutableMap[A, Long],
      maxSize: Int): Unit = {
    // total count of the items that baseMap has never tracked
    val otherSum = otherMap.foldLeft(0L) { case (sum, (k, v)) =>
      if (!baseMap.contains(k)) sum + v else sum
    }
    baseMap.retain((k, v) => v > otherSum)
    // sort in decreasing order, so that we will add the most frequent items first
    val sorted = otherMap.toSeq.sortBy(-_._2)
    var i = 0
    val otherSize = sorted.length
    while (i < otherSize && baseMap.size < maxSize) {
      val keyVal = sorted(i)
      baseMap += keyVal._1 -> keyVal._2
      i += 1
    }
  }

  /**
   * Finds frequent items for columns, possibly with false positives, using the
   * algorithm described in `http://www.cs.umd.edu/~samir/498/karp.pdf`.
   * For internal use only.
   *
   * @param df The input DataFrame
   * @param cols the names of the columns to search frequent items in
   * @param support The minimum frequency for an item to be considered `frequent`
   * @return A local DataFrame with an Array of frequent items for each column.
   */
  private[sql] def singlePassFreqItems(
      df: DataFrame,
      cols: Array[String],
      support: Double): DataFrame = {
    val numCols = cols.length
    // maximum number of items to keep counts for
    val sizeOfMap = math.floor(1 / support).toInt
    val countMaps = Array.tabulate(numCols)(i => MutableMap.empty[Any, Long])
    val originalSchema = df.schema
    val colInfo = cols.map { name =>
      val index = originalSchema.fieldIndex(name)
      val field = originalSchema.fields(index)
      (index, field.dataType)
    }
    val colIndices = colInfo.map(_._1)

    val freqItems: Array[MutableMap[Any, Long]] = df.rdd.aggregate(countMaps)(
      seqOp = (counts, row) => {
        var i = 0
        colIndices.foreach { index =>
          val thisMap = counts(i)
          val key = row.get(index)
          if (thisMap.contains(key)) {
            thisMap(key) += 1
          } else {
            if (thisMap.size < sizeOfMap) {
              thisMap += key -> 1
            } else {
              // TODO: Make this more efficient... A flatMap?
              // The map is full: evict the entries that would drop to zero,
              // then decrement the survivors.
              thisMap.retain((k, v) => v > 1)
              thisMap.transform((k, v) => v - 1)
            }
          }
          i += 1
        }
        counts
      },
      combOp = (baseCounts, counts) => {
        var i = 0
        while (i < numCols) {
          mergeCounts(baseCounts(i), counts(i), sizeOfMap)
          i += 1
        }
        baseCounts
      }
    )
    val justItems = freqItems.map(m => m.keys.toSeq)
    val resultRow = Row(justItems : _*)
    // append "-freqItems" to the column names for easy debugging
    val outputCols = cols.zip(colInfo).map { v =>
      StructField(v._1 + "-freqItems", ArrayType(v._2._2, false))
    }
    val schema = StructType(outputCols).toAttributes
    new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow)))
  }
}
```
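To make the `seqOp` above concrete, here is a minimal standalone sketch (not the PR code) of the one-pass counting step on a tiny in-memory stream, assuming support = 0.25, so at most floor(1 / 0.25) = 4 counters are kept:

```scala
import scala.collection.mutable.{Map => MutableMap}

val stream = Seq("a", "b", "a", "c", "a", "d", "a", "e", "b", "a")
val support = 0.25
val maxSize = math.floor(1 / support).toInt  // 4
val counts = MutableMap.empty[String, Long]

stream.foreach { item =>
  if (counts.contains(item)) {
    counts(item) += 1
  } else if (counts.size < maxSize) {
    counts += item -> 1L
  } else {
    // Map is full: decrement every counter and evict the ones that hit zero,
    // mirroring the retain/transform pair in singlePassFreqItems.
    counts.retain((_, v) => v > 1)
    counts.transform((_, v) => v - 1)
  }
}

// Any item with true frequency above the support survives as a key, possibly
// alongside false positives; here "a" (5 of 10 items, 50%) must be present.
println(counts.keys)
```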
FrequentItemsSuite.scala (new file):

```scala
@@ -0,0 +1,45 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.ml

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.scalatest.FunSuite

import org.apache.spark.sql.test.TestSQLContext

class FrequentItemsSuite extends FunSuite {

  val sqlCtx = TestSQLContext

  test("Frequent Items") {
    def toLetter(i: Int): String = (i + 96).toChar.toString
    // every third row is (1, "a"); the remaining rows are distinct
    val rows = Array.tabulate(1000)(i => if (i % 3 == 0) (1, toLetter(1)) else (i, toLetter(i)))
    val rowRdd = sqlCtx.sparkContext.parallelize(rows.map(v => Row(v._1, v._2)))
    val schema = StructType(StructField("numbers", IntegerType, false) ::
      StructField("letters", StringType, false) :: Nil)
    val df = sqlCtx.createDataFrame(rowRdd, schema)

    val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
    val items = results.collect().head
    assert(items.getSeq(0).contains(1),
      "1 should be the frequent item for column 'numbers'")
    assert(items.getSeq(1).contains(toLetter(1)),
      s"${toLetter(1)} should be the frequent item for column 'letters'")
  }
}
```
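A quick standalone sanity check of the test's arithmetic (not part of the suite):

```scala
// Of i = 0 until 1000, exactly 334 values satisfy i % 3 == 0, so the pair
// (1, "a") has frequency 0.334, comfortably above the 0.1 support threshold;
// the counting maps hold at most floor(1 / 0.1) = 10 entries per column.
val hits = (0 until 1000).count(_ % 3 == 0)
assert(hits == 334 && hits.toDouble / 1000 > 0.1)
assert(math.floor(1 / 0.1).toInt == 10)
```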
Does this work in Java?

I think it looks like `df.stat$.MODULE$.freqItems()`. I don't know how we can otherwise make it `df.stat.freqItems` in Scala.

Take a look at how we implemented `na`.

Aha! I like it
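For reference, a toy sketch of the `na`-style pattern being suggested (the names here are illustrative, not the PR's final code): exposing the functionality through a method that returns a helper class keeps the Scala call site as `frame.stat.freqItems(...)` while giving Java the ordinary chain `frame.stat().freqItems(...)`, instead of `frame.stat$.MODULE$.freqItems(...)`.

```scala
// A def that returns a small helper class is callable from both languages.
class Frame {
  def stat: FrameStatFunctions = new FrameStatFunctions(this)
}

// The helper holds a reference to its owner and hosts the statistic methods.
class FrameStatFunctions(frame: Frame) {
  def freqItems(cols: Array[String], support: Double): Seq[String] =
    Seq.empty // placeholder; the real version would delegate to FrequentItems
}
```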