@@ -488,9 +488,10 @@ class LogisticRegression @Since("1.2.0") (
     train(dataset, handlePersistence)
   }
 
+  import Instrumentation.instrumented
Member commented:
Put import at top of file with the other imports (just to make imports easier to track).

   protected[spark] def train(
       dataset: Dataset[_],
-      handlePersistence: Boolean): LogisticRegressionModel = {
+      handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr =>
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
       dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
@@ -500,7 +501,8 @@ class LogisticRegression @Since("1.2.0") (

     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
-    val instr = Instrumentation.create(this, dataset)
+    instr.logPipelineStage(this)
+    instr.logDataset(dataset)
     instr.logParams(regParam, elasticNetParam, standardization, threshold,
       maxIter, tol, fitIntercept)

@@ -905,8 +907,6 @@ class LogisticRegression @Since("1.2.0") (
         objectiveHistory)
     }
     model.setSummary(Some(logRegSummary))
-    instr.logSuccess(model)
-    model
   }
 
   @Since("1.4.0")
@@ -91,7 +91,7 @@ private[spark] object RandomForest extends Logging {
       numTrees: Int,
       featureSubsetStrategy: String,
       seed: Long,
-      instr: Option[Instrumentation[_]],
+      instr: Option[Instrumentation],
       prune: Boolean = true, // exposed for testing only, real trees are always pruned
       parentUID: Option[String] = None): Array[DecisionTreeModel] = {
 
@@ -80,7 +80,7 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
   /**
    * Instrumentation logging for tuning params including the inner estimator and evaluator info.
    */
-  protected def logTuningParams(instrumentation: Instrumentation[_]): Unit = {
+  protected def logTuningParams(instrumentation: Instrumentation): Unit = {
     instrumentation.logNamedValue("estimator", $(estimator).getClass.getCanonicalName)
     instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName)
     instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length)
mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala (127 changes: 84 additions & 43 deletions)
@@ -19,45 +19,60 @@ package org.apache.spark.ml.util

 import java.util.UUID
 
+import scala.reflect.ClassTag
+import scala.util.{Failure, Success, Try}
+import scala.util.control.NonFatal
+
 import org.json4s._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.{Estimator, Model, PipelineStage}
+import org.apache.spark.ml.param.{Param, Params}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
 import org.apache.spark.util.Utils
 
 /**
  * A small wrapper that defines a training session for an estimator, and some methods to log
  * useful information during this session.
- *
- * A new instance is expected to be created within fit().
- *
- * @param estimator the estimator that is being fit
- * @param dataset the training dataset
- * @tparam E the type of the estimator
  */
-private[spark] class Instrumentation[E <: Estimator[_]] private (
-    val estimator: E,
-    val dataset: RDD[_]) extends Logging {
+private[spark] class Instrumentation extends Logging {
 
   private val id = UUID.randomUUID()
-  private val prefix = {
+  private val shortId = id.toString.take(8)
+  private val prefix = s"[$shortId] "
 
+  // TODO: update spark.ml to use new Instrumentation APIs and remove this constructor
+  var stage: Params = _
Member commented:
I'd recommend we either plan to remove "stage" or change "logPipelineStage" so it only allows setting "stage" once. If we go with the former, how about leaving a note to remove "stage" once spark.ml code is migrated to use the new logParams() method?

Contributor Author replied:
Yep, the plan is to remove "stage" once we switch over to the new APIs.
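A minimal sketch of the set-once variant suggested above (illustrative only, not part of this patch; the require guard is an assumption about the desired behavior):

def logPipelineStage(stage: PipelineStage): Unit = {
  // Assumption: a session instruments exactly one stage, so fail fast on reuse.
  require(this.stage == null, s"logPipelineStage was already called for ${this.stage}")
  this.stage = stage
  val className = Utils.getSimpleName(stage.getClass)
  logInfo(s"Stage class: $className")
  logInfo(s"Stage uid: ${stage.uid}")
}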

+  private def this(estimator: Estimator[_], dataset: RDD[_]) = {
+    this()
+    logPipelineStage(estimator)
+    logDataset(dataset)
+  }
 
+  /**
+   * Log some info about the pipeline stage being fit.
+   */
+  def logPipelineStage(stage: PipelineStage): Unit = {
+    this.stage = stage
     // estimator.getClass.getSimpleName can cause Malformed class name error,
     // call safer `Utils.getSimpleName` instead
-    val className = Utils.getSimpleName(estimator.getClass)
-    s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
+    val className = Utils.getSimpleName(stage.getClass)
+    logInfo(s"Stage class: $className")
+    logInfo(s"Stage uid: ${stage.uid}")
   }
 
-  init()
+  /**
+   * Log some data about the dataset being fit.
+   */
+  def logDataset(dataset: Dataset[_]): Unit = logDataset(dataset.rdd)
 
-  private def init(): Unit = {
-    log(s"training: numPartitions=${dataset.partitions.length}" +
+  /**
+   * Log some data about the dataset being fit.
+   */
+  def logDataset(dataset: RDD[_]): Unit = {
+    logInfo(s"training: numPartitions=${dataset.partitions.length}" +
       s" storageLevel=${dataset.getStorageLevel}")
   }

@@ -89,23 +104,25 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
     super.logInfo(prefix + msg)
   }
 
-  /**
-   * Alias for logInfo, see above.
-   */
-  def log(msg: String): Unit = logInfo(msg)
-
   /**
    * Logs the value of the given parameters for the estimator being used in this session.
    */
-  def logParams(params: Param[_]*): Unit = {
+  def logParams(hasParams: Params, params: Param[_]*): Unit = {
     val pairs: Seq[(String, JValue)] = for {
       p <- params
-      value <- estimator.get(p)
+      value <- hasParams.get(p)
     } yield {
       val cast = p.asInstanceOf[Param[Any]]
       p.name -> parse(cast.jsonEncode(value))
     }
-    log(compact(render(map2jvalue(pairs.toMap))))
+    logInfo(compact(render(map2jvalue(pairs.toMap))))
   }
 
+  // TODO: remove this
+  def logParams(params: Param[_]*): Unit = {
+    require(stage != null, "`logPipelineStage` must be called before `logParams` (or an instance of" +
+      " Params must be provided explicitly).")
+    logParams(stage, params: _*)
+  }
 
   def logNumFeatures(num: Long): Unit = {
@@ -124,35 +141,48 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
    * Logs the value with customized name field.
    */
   def logNamedValue(name: String, value: String): Unit = {
-    log(compact(render(name -> value)))
+    logInfo(compact(render(name -> value)))
   }
 
   def logNamedValue(name: String, value: Long): Unit = {
-    log(compact(render(name -> value)))
+    logInfo(compact(render(name -> value)))
   }
 
   def logNamedValue(name: String, value: Double): Unit = {
-    log(compact(render(name -> value)))
+    logInfo(compact(render(name -> value)))
   }
 
   def logNamedValue(name: String, value: Array[String]): Unit = {
-    log(compact(render(name -> compact(render(value.toSeq)))))
+    logInfo(compact(render(name -> compact(render(value.toSeq)))))
   }
 
   def logNamedValue(name: String, value: Array[Long]): Unit = {
-    log(compact(render(name -> compact(render(value.toSeq)))))
+    logInfo(compact(render(name -> compact(render(value.toSeq)))))
   }
 
   def logNamedValue(name: String, value: Array[Double]): Unit = {
-    log(compact(render(name -> compact(render(value.toSeq)))))
+    logInfo(compact(render(name -> compact(render(value.toSeq)))))
   }
 
+  // TODO: Remove this (possibly replace with logModel?)
   /**
    * Logs the successful completion of the training session.
    */
   def logSuccess(model: Model[_]): Unit = {
-    log(s"training finished")
+    logInfo(s"training finished")
   }
 
+  def logSuccess(): Unit = {
+    logInfo("training finished")
+  }
+
+  /**
+   * Logs an exception raised during a training session.
+   */
+  def logFailure(e: Throwable): Unit = {
+    val msg = e.getStackTrace.mkString("\n")
+    super.logError(msg)
+  }
 }

@@ -169,22 +199,33 @@ private[spark] object Instrumentation {
     val varianceOfLabels = "varianceOfLabels"
   }
 
+  // TODO: Remove these
   /**
    * Creates an instrumentation object for a training session.
    */
-  def create[E <: Estimator[_]](
-      estimator: E, dataset: Dataset[_]): Instrumentation[E] = {
-    create[E](estimator, dataset.rdd)
+  def create(estimator: Estimator[_], dataset: Dataset[_]): Instrumentation = {
+    create(estimator, dataset.rdd)
   }
 
   /**
    * Creates an instrumentation object for a training session.
    */
-  def create[E <: Estimator[_]](
-      estimator: E, dataset: RDD[_]): Instrumentation[E] = {
-    new Instrumentation[E](estimator, dataset)
+  def create(estimator: Estimator[_], dataset: RDD[_]): Instrumentation = {
+    new Instrumentation(estimator, dataset)
   }
+  // end remove
 
+  def instrumented[T](body: (Instrumentation => T)): T = {
+    val instr = new Instrumentation()
+    Try(body(instr)) match {
+      case Failure(NonFatal(e)) =>
+        instr.logFailure(e)
+        throw e
+      case Success(result) =>
+        instr.logSuccess()
+        result
+    }
+  }
 
 }
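For reviewers: a minimal sketch of how a fit() implementation would adopt the new instrumented helper once the create() entry points are removed. MyEstimator, MyModel, and trainImpl are hypothetical names used only for illustration:

import org.apache.spark.ml.util.Instrumentation.instrumented

// Inside a hypothetical MyEstimator:
override def fit(dataset: Dataset[_]): MyModel = instrumented { instr =>
  instr.logPipelineStage(this)    // records stage class name and uid
  instr.logDataset(dataset)       // records partition count and storage level
  instr.logParams(this, maxIter, tol)
  val model = trainImpl(dataset)  // hypothetical training routine
  model                           // logSuccess()/logFailure() are emitted by instrumented
}

With this pattern the success/failure logging no longer has to be hand-written in each estimator; the instrumented wrapper handles both paths.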

 /**
@@ -193,7 +234,7 @@ private[spark] object Instrumentation {
  * will log via it, otherwise will log via common logger.
  */
 private[spark] class OptionalInstrumentation private(
-    val instrumentation: Option[Instrumentation[_ <: Estimator[_]]],
+    val instrumentation: Option[Instrumentation],
     val className: String) extends Logging {
 
   protected override def logName: String = className
@@ -225,9 +266,9 @@ private[spark] object OptionalInstrumentation {
   /**
    * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object.
    */
-  def create[E <: Estimator[_]](instr: Instrumentation[E]): OptionalInstrumentation = {
+  def create(instr: Instrumentation): OptionalInstrumentation = {
     new OptionalInstrumentation(Some(instr),
-      instr.estimator.getClass.getName.stripSuffix("$"))
+      instr.stage.getClass.getName.stripSuffix("$"))
   }
 
   /**
@@ -235,7 +235,7 @@ class KMeans private (

   private[spark] def run(
       data: RDD[Vector],
-      instr: Option[Instrumentation[NewKMeans]]): KMeansModel = {
+      instr: Option[Instrumentation]): KMeansModel = {
 
     if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data is not directly cached, which may hurt performance if its"
@@ -264,7 +264,7 @@ class KMeans private (
    */
   private def runAlgorithm(
       data: RDD[VectorWithNorm],
-      instr: Option[Instrumentation[NewKMeans]]): KMeansModel = {
+      instr: Option[Instrumentation]): KMeansModel = {
 
     val sc = data.sparkContext
 