add general util functions in DatasetUtils and SchemaUtils
lu-wang-dl committed May 3, 2018
commit 877c126ff493e43edb5a8bcf33e7dd1fe59503b0
mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -22,18 +22,16 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans,
BisectingKMeansModel => MLlibBisectingKMeansModel}
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType}
import org.apache.spark.sql.types.{IntegerType, StructType}


/**
@@ -69,24 +67,13 @@ private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter
@Since("2.0.0")
def getMinDivisibleClusterSize: Double = $(minDivisibleClusterSize)

/**
* Validates the input schema.
* @param schema input schema
*/
private[clustering] def validateSchema(schema: StructType): Unit = {
val typeCandidates = List( new VectorUDT,
new ArrayType(DoubleType, false),
new ArrayType(FloatType, false))

SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates)
}
/**
* Validates and transforms the input schema.
* @param schema input schema
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateSchema(schema)
SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
}
}
@@ -144,11 +131,8 @@ class BisectingKMeansModel private[ml] (
*/
@Since("2.0.0")
def computeCost(dataset: Dataset[_]): Double = {
validateSchema(dataset.schema)
val data: RDD[OldVector] = dataset.select(DatasetUtils.columnToVector(dataset, getFeaturesCol))
.rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}
SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol)
val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
parentModel.computeCost(data)
}

@@ -275,10 +259,7 @@ class BisectingKMeans @Since("2.0.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
transformSchema(dataset.schema, logging = true)
val rdd: RDD[OldVector] = dataset
.select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}
val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)

val instr = Instrumentation.create(this, rdd)
instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize)
mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -34,7 +34,7 @@ import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatr
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType}
import org.apache.spark.sql.types.{IntegerType, StructType}


/**
@@ -56,26 +56,14 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
@Since("2.0.0")
def getK: Int = $(k)

/**
* Validates the input schema.
* @param schema input schema
*/
private[clustering] def validateSchema(schema: StructType): Unit = {
val typeCandidates = List( new VectorUDT,
new ArrayType(DoubleType, false),
new ArrayType(FloatType, false))

SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates)
}

/**
* Validates and transforms the input schema.
*
* @param schema input schema
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateSchema(schema)
SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol)
val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT)
}
29 changes: 6 additions & 23 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
@@ -34,7 +34,7 @@ import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils.majorVersion

@@ -86,24 +86,13 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
@Since("1.5.0")
def getInitSteps: Int = $(initSteps)

/**
* Validates the input schema.
* @param schema input schema
*/
private[clustering] def validateSchema(schema: StructType): Unit = {
val typeCandidates = List( new VectorUDT,
new ArrayType(DoubleType, false),
new ArrayType(FloatType, false))

SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates)
}
/**
* Validates and transforms the input schema.
* @param schema input schema
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateSchema(schema)
SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
}
}
@@ -160,11 +149,8 @@ class KMeansModel private[ml] (
// TODO: Replace the temp fix when we have proper evaluators defined for clustering.
@Since("2.0.0")
def computeCost(dataset: Dataset[_]): Double = {
validateSchema(dataset.schema)
val data: RDD[OldVector] = dataset.select(DatasetUtils.columnToVector(dataset, getFeaturesCol))
.rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}
SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol)
val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
parentModel.computeCost(data)
}

@@ -350,10 +336,7 @@ class KMeans @Since("1.5.0") (
transformSchema(dataset.schema, logging = true)

val handlePersistence = dataset.storageLevel == StorageLevel.NONE
val instances: RDD[OldVector] = dataset
.select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}
val instances = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)

if (handlePersistence) {
instances.persist(StorageLevel.MEMORY_AND_DISK)
14 changes: 1 addition & 13 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -311,18 +311,6 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
@Since("2.0.0")
def getKeepLastCheckpoint: Boolean = $(keepLastCheckpoint)

/**
* Validates the input schema.
* @param schema input schema
*/
private[clustering] def validateSchema(schema: StructType): Unit = {
val typeCandidates = List( new VectorUDT,
new ArrayType(DoubleType, false),
new ArrayType(FloatType, false))

SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates)
}

/**
* Validates and transforms the input schema.
*
@@ -357,7 +345,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
s" must be >= 1. Found value: $getTopicConcentration")
}
}
validateSchema(schema)
SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol)
SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
}

13 changes: 11 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
@@ -17,8 +17,10 @@

package org.apache.spark.ml.util

import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
import org.apache.spark.sql.{Column, Dataset}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Column, Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType}

@@ -60,4 +62,11 @@ private[spark] object DatasetUtils {
throw new IllegalArgumentException(s"$other column cannot be cast to Vector")
}
}

def columnToOldVector(dataset: Dataset[_], colName: String): RDD[OldVector] = {
dataset.select(columnToVector(dataset, colName))
.rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}
}
}
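
A rough usage sketch of the new columnToOldVector helper (DatasetUtils is private[spark], so this pattern only applies to Spark-internal callers such as the clustering classes above; the DataFrame df and the column name "features" are illustrative, not from this PR):

import org.apache.spark.ml.util.DatasetUtils
import org.apache.spark.mllib.linalg.{Vector => OldVector}
import org.apache.spark.rdd.RDD

// Select the features column (Vector, Array[Float] or Array[Double]), convert each row
// to an old mllib Vector, and return the result as an RDD for the legacy mllib APIs.
val data: RDD[OldVector] = DatasetUtils.columnToOldVector(df, "features")
// e.g. parentModel.computeCost(data), as in the computeCost changes above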
16 changes: 15 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -17,7 +17,8 @@

package org.apache.spark.ml.util

import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType}
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.types._


/**
@@ -101,4 +102,17 @@
require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.")
StructType(schema.fields :+ col)
}

/**
* Check whether the given column in the schema is one of the supporting vector type: Vector,
* Array[Dloat]. Array[Double]
Contributor review comment on the line above: nit: Float
* @param schema input schema
* @param colName column name
*/
def validateVectorCompatibleColumn(schema: StructType, colName: String): Unit = {
val typeCandidates = List( new VectorUDT,
new ArrayType(DoubleType, false),
new ArrayType(FloatType, false))
checkColumnTypes(schema, colName, typeCandidates)
}
}
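
For reference, a minimal sketch of how a Params trait inside Spark ML now delegates its feature-column check to the new helper, mirroring the validateAndTransformSchema bodies in the clustering traits above (the surrounding trait and its featuresCol/predictionCol params are assumed, not shown here):

protected def validateAndTransformSchema(schema: StructType): StructType = {
  // require()-based check: fails with IllegalArgumentException unless the column is
  // VectorUDT, ArrayType(FloatType, false) or ArrayType(DoubleType, false).
  SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol)
  // Append the integer prediction column to produce the output schema.
  SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
}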