22 changes: 22 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.ml.FrequentItems
import org.apache.spark.sql.types._
import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
import org.apache.spark.util.Utils
@@ -1414,4 +1415,25 @@ class DataFrame private[sql](
val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
SerDeUtil.javaToPython(jrdd)
}

/////////////////////////////////////////////////////////////////////////////
// Statistic functions
/////////////////////////////////////////////////////////////////////////////

// scalastyle:off
object stat {
Contributor: does this work in Java?

Contributor (author): I think it looks like df.stat$.MODULE$.freqItems(). I don't know how else we can make it df.stat.freqItems in Scala.

Contributor: take a look at how we implemented na.

Contributor (author): aha! I like it.
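For reference, the `na` pattern the reviewers settle on replaces the inner object with a small wrapper class plus a def on DataFrame, so the call reads df.stat.freqItems(...) from both Scala and Java. A minimal sketch, with a hypothetical wrapper name (this is not the code in this diff):

// Hypothetical wrapper following the existing `na`/DataFrameNaFunctions pattern.
final class DataFrameStatFunctions private[sql](df: DataFrame) {
  def freqItems(cols: Array[String], support: Double): DataFrame =
    FrequentItems.singlePassFreqItems(df, cols, support)
}

// In class DataFrame: a plain method, so Java callers avoid stat$.MODULE$.
def stat: DataFrameStatFunctions = new DataFrameStatFunctions(this)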

// scalastyle:on

/**
* Finding frequent items for columns, possibly with false positives. Using the algorithm
* described in `http://www.cs.umd.edu/~samir/498/karp.pdf`.
Contributor: should do proper citation rather than giving a URL, since this URL might disappear.

*
* @param cols the names of the columns to search frequent items in
* @param support The minimum frequency for an item to be considered `frequent`
* @return A Local DataFrame with the Array of frequent items for each column.
*/
def freqItems(cols: Array[String], support: Double): DataFrame = {
Contributor: in df we usually support List[String] and Seq[String]. This is one reason why we are using a separate namespace.
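A minimal sketch of the extra overload the reviewer is describing (hedged; only the Array[String] variant exists in this diff):

// Hypothetical convenience overload delegating to the Array[String] version.
def freqItems(cols: Seq[String], support: Double): DataFrame =
  freqItems(cols.toArray, support)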

FrequentItems.singlePassFreqItems(toDF(), cols, support)
}
}
}
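To show the call shape this change enables, a hedged usage sketch (column names are invented; it mirrors the test at the end of this diff):

// df is any DataFrame with columns "a" and "b"; the result is a one-row
// local DataFrame whose cells hold arrays of candidate frequent items.
val freq = df.stat.freqItems(Array("a", "b"), 0.1)
freq.show()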
124 changes: 124 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/ml/FrequentItems.scala
@@ -0,0 +1,124 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.ml
Contributor: let's put this in execution.stat? It's annoying to add a top-level package because we have rules to specifically exclude existing packages.

import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{StructType, ArrayType, StructField}

import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.Logging
import org.apache.spark.sql.{Row, DataFrame, functions}

private[sql] object FrequentItems extends Logging {

/**
* Merge two maps of counts. Subtracts the sum of `otherMap` from `baseMap`, and fills in
* any emptied slots with the most frequent of `otherMap`.
* @param baseMap The map containing the global counts
* @param otherMap The map containing the counts for that partition
* @param maxSize The maximum number of counts to keep in memory
*/
private def mergeCounts[A](
Contributor: I think the implementation could be cleaner if we wrap MutableMap[A, Long] with a utility class:

class FreqItemCounter(size: Int) {
  def add(any: Any, count: Long = 1L): this.type
  def merge(other: FreqItemCounter): this.type = {
    other.toSeq.foreach { case (k, c) =>
      add(k, c)
    }
  }
  def freqItems: Array[Any]
  def toSeq: Seq[(Any, Long)]
}
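A hedged completion of that sketch, for concreteness; the names and the decrement rule are illustrative and simply generalize the count-by-one step that the seqOp below performs inline (this is not part of the diff):

// Illustrative only; uses the MutableMap alias imported at the top of this file.
class FreqItemCounter(size: Int) {
  val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long]

  def add(key: Any, count: Long = 1L): this.type = {
    if (baseMap.contains(key)) {
      baseMap(key) += count
    } else if (baseMap.size < size) {
      baseMap += key -> count
    } else {
      // Misra-Gries style step: when the map is full, decrement every tracked
      // count and drop keys whose count would reach zero, instead of adding.
      baseMap.retain((k, v) => v > count)
      baseMap.transform((k, v) => v - count)
    }
    this
  }

  def freqItems: Array[Any] = baseMap.keys.toArray
  def toSeq: Seq[(Any, Long)] = baseMap.toSeq
}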

baseMap: MutableMap[A, Long],
otherMap: MutableMap[A, Long],
maxSize: Int): Unit = {
val otherSum = otherMap.foldLeft(0L) { case (sum, (k, v)) =>
if (!baseMap.contains(k)) sum + v else sum
}
baseMap.retain((k, v) => v > otherSum)
// sort in decreasing order, so that we will add the most frequent items first
val sorted = otherMap.toSeq.sortBy(-_._2)
var i = 0
val otherSize = sorted.length
while (i < otherSize && baseMap.size < maxSize) {
val keyVal = sorted(i)
baseMap += keyVal._1 -> keyVal._2
i += 1
}
}


/**
* Finding frequent items for columns, possibly with false positives. Using the algorithm
* described in `http://www.cs.umd.edu/~samir/498/karp.pdf`.
* For Internal use only.
*
* @param df The input DataFrame
* @param cols the names of the columns to search frequent items in
* @param support The minimum frequency for an item to be considered `frequent`
* @return A Local DataFrame with the Array of frequent items for each column.
*/
private[sql] def singlePassFreqItems(
df: DataFrame,
cols: Array[String],
Contributor: If multiple columns are provided, shall we search the combination of them instead of each individually? For example, if I call freqItems(Array("gender", "title"), 0.01), I'm expecting the frequent combinations instead of each of them. The current implementation is more flexible because users can create a struct from multiple columns, and this allows finding frequent items on multiple columns in parallel. But I'm a little worried about what users expect when they call freqItems(Array("gender", "title")). @rxin
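For reference, a hedged sketch of the struct workaround the reviewer mentions, using org.apache.spark.sql.functions.struct (column names are invented; this is not part of the diff):

import org.apache.spark.sql.functions.struct

// Pack the two columns into one struct column, then ask for frequent
// (gender, title) combinations rather than frequent values of each column.
val combined = df.withColumn("gender_title", struct(df("gender"), df("title")))
val freqPairs = combined.stat.freqItems(Array("gender_title"), 0.01)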

support: Double): DataFrame = {
val numCols = cols.length
Contributor: Check the range of support. Warn if it is too small (e.g., 1e-6).
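A minimal sketch of the suggested guard, with an illustrative threshold (not part of this diff):

// Reject invalid values and warn when support implies a very large count map,
// since the map keeps roughly 1 / support entries per column.
require(support > 0.0 && support < 1.0, s"support must be in (0, 1), but got $support.")
if (support < 1e-6) {
  logWarning(s"support = $support will keep up to ${(1 / support).toLong} counts per column.")
}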

// number of max items to keep counts for
val sizeOfMap = math.floor(1 / support).toInt
Contributor: math.floor is not necessary: (1.0 / support).toInt

val countMaps = Array.tabulate(numCols)(i => MutableMap.empty[Any, Long])
val originalSchema = df.schema
val colInfo = cols.map { name =>
val index = originalSchema.fieldIndex(name)
val dataType = originalSchema.fields(index)
(index, dataType.dataType)
}
val colIndices = colInfo.map(_._1)

val freqItems: Array[MutableMap[Any, Long]] = df.rdd.aggregate(countMaps)(
Contributor: df.select(cols).rdd.aggregate (then you don't need to skip elements)

seqOp = (counts, row) => {
var i = 0
colIndices.foreach { index =>
val thisMap = counts(i)
val key = row.get(index)
if (thisMap.contains(key)) {
thisMap(key) += 1
Contributor: 1 -> 1L

} else {
if (thisMap.size < sizeOfMap) {
thisMap += key -> 1
} else {
// TODO: Make this more efficient... A flatMap?
thisMap.retain((k, v) => v > 1)
thisMap.transform((k, v) => v - 1)
}
}
i += 1
}
counts
},
combOp = (baseCounts, counts) => {
var i = 0
while (i < numCols) {
mergeCounts(baseCounts(i), counts(i), sizeOfMap)
i += 1
}
baseCounts
}
)
//
val justItems = freqItems.map(m => m.keys.toSeq)
val resultRow = Row(justItems:_*)
// append frequent Items to the column name for easy debugging
val outputCols = cols.zip(colInfo).map{ v =>
StructField(v._1 + "-freqItems", ArrayType(v._2._2, false))
Contributor: -freqItems -> _freqItems

}
val schema = StructType(outputCols).toAttributes
new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow)))
}
}
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.ml

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.scalatest.FunSuite

import org.apache.spark.sql.test.TestSQLContext

class FrequentItemsSuite extends FunSuite {
Contributor: move this to .sql package, and call it DataFrameStatSuite?


val sqlCtx = TestSQLContext

test("Frequent Items") {
def toLetter(i: Int): String = (i + 96).toChar.toString
val rows = Array.tabulate(1000)(i => if (i % 3 == 0) (1, toLetter(1)) else (i, toLetter(i)))
val rowRdd = sqlCtx.sparkContext.parallelize(rows.map(v => Row(v._1, v._2)))
val schema = StructType(StructField("numbers", IntegerType, false) ::
StructField("letters", StringType, false) :: Nil)
val df = sqlCtx.createDataFrame(rowRdd, schema)

val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
val items = results.collect().head
assert(items.getSeq(0).contains(1),
"1 should be the frequent item for column 'numbers")
assert(items.getSeq(1).contains(toLetter(1)),
s"${toLetter(1)} should be the frequent item for column 'letters'")
}
}