```diff
@@ -490,7 +490,7 @@ class LogisticRegression @Since("1.2.0") (
 
   protected[spark] def train(
       dataset: Dataset[_],
-      handlePersistence: Boolean): LogisticRegressionModel = {
+      handlePersistence: Boolean): LogisticRegressionModel = Instrumentation.instrumented { instr =>
```
**Contributor:** To keep this line from running too wide, we might want to import `instrumented` so the `Instrumentation.` prefix can be dropped from this line.
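A minimal sketch of what that suggestion would look like (assuming `instrumented` is the helper added in this PR; the method body is elided):

```scala
import org.apache.spark.ml.util.Instrumentation.instrumented

// With the helper imported directly, the signature line stays within the width limit:
protected[spark] def train(
    dataset: Dataset[_],
    handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr =>
  // ... training body unchanged ...
}
```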

```diff
   val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
   val instances: RDD[Instance] =
     dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
@@ -500,7 +500,7 @@ class LogisticRegression @Since("1.2.0") (
 
   if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
-  val instr = Instrumentation.create(this, dataset)
+  instr.logContext(this, dataset)
```
**Contributor:** It doesn't log anything. I think we should auto-generate the prefix and keep it as a constant, so logs would appear as:

```
[PREFIX]: instrumentation started
[PREFIX]: using estimator logReg-abc128
[PREFIX]: using dataset some hashcode
[PREFIX]: param maxIter=10
[PREFIX]: ...
[PREFIX]: run succeeded/failed
[PREFIX]: instrumentation ended
```

We can generate 8 random chars as the PREFIX; that is sufficient to correlate metrics from the same run. The issue with making the prefix mutable is that we have no way to guarantee `logContext` is always called. So I would suggest replacing `logContext` with the following (sketched below):

- `logEstimator` or `logPipelineStage`
- `logDataset`

Btw, we could log the call site by default; it provides more info for instrumentation, though that is not necessary in this PR.
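A rough sketch of the API shape this comment proposes. The method names `logPipelineStage` and `logDataset` come from the comment itself; the bodies and the exact log format are illustrative assumptions, not code from this PR:

```scala
import java.util.UUID

import org.apache.spark.internal.Logging
import org.apache.spark.ml.PipelineStage
import org.apache.spark.sql.Dataset

private[spark] class Instrumentation extends Logging {

  // Constant prefix: 8 random chars generated once per training session,
  // so every line from the same run can be correlated.
  private val prefix = UUID.randomUUID().toString.take(8)

  logInfo(s"[$prefix]: instrumentation started")

  // No mutable context: callers log the stage and dataset explicitly.
  def logPipelineStage(stage: PipelineStage): Unit =
    logInfo(s"[$prefix]: using estimator ${stage.uid}")

  def logDataset(dataset: Dataset[_]): Unit =
    logInfo(s"[$prefix]: using dataset ${dataset.hashCode()}")
}
```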

```diff
   instr.logParams(regParam, elasticNetParam, standardization, threshold,
     maxIter, tol, fitIntercept)
@@ -905,8 +905,6 @@ class LogisticRegression @Since("1.2.0") (
         objectiveHistory)
     }
     model.setSummary(Some(logRegSummary))
-    instr.logSuccess(model)
-    model
   }
 
   @Since("1.4.0")
```
```diff
@@ -91,7 +91,7 @@ private[spark] object RandomForest extends Logging {
       numTrees: Int,
       featureSubsetStrategy: String,
       seed: Long,
-      instr: Option[Instrumentation[_]],
+      instr: Option[Instrumentation],
       prune: Boolean = true, // exposed for testing only, real trees are always pruned
       parentUID: Option[String] = None): Array[DecisionTreeModel] = {
```
```diff
@@ -80,7 +80,7 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
   /**
    * Instrumentation logging for tuning params including the inner estimator and evaluator info.
    */
-  protected def logTuningParams(instrumentation: Instrumentation[_]): Unit = {
+  protected def logTuningParams(instrumentation: Instrumentation): Unit = {
     instrumentation.logNamedValue("estimator", $(estimator).getClass.getCanonicalName)
     instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName)
     instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length)
```
**mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala** (73 additions, 27 deletions)

```diff
@@ -19,7 +19,8 @@ package org.apache.spark.ml.util
 
 import java.util.UUID
 
-import scala.reflect.ClassTag
+import scala.util.{Failure, Success, Try}
+import scala.util.control.NonFatal
 
 import org.json4s._
 import org.json4s.JsonDSL._
```
```diff
@@ -35,32 +36,47 @@ import org.apache.spark.util.Utils
 /**
  * A small wrapper that defines a training session for an estimator, and some methods to log
  * useful information during this session.
- *
- * A new instance is expected to be created within fit().
- *
- * @param estimator the estimator that is being fit
- * @param dataset the training dataset
- * @tparam E the type of the estimator
  */
-private[spark] class Instrumentation[E <: Estimator[_]] private (
-    val estimator: E,
-    val dataset: RDD[_]) extends Logging {
+private[spark] class Instrumentation extends Logging {
 
   private val id = UUID.randomUUID()
-  private val prefix = {
-    // estimator.getClass.getSimpleName can cause Malformed class name error,
-    // call safer `Utils.getSimpleName` instead
-    val className = Utils.getSimpleName(estimator.getClass)
-    s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
-  }
-
-  init()
+  private val shortId = id.toString.take(8)
+  private var prefix = s"$shortId:"
 
+  // TODO: update spark.ml to use new Instrumentation APIs and remove this constructor
+  var estimator: Estimator[_] = _
+  private def this(estimator: Estimator[_], dataset: RDD[_]) = {
+    this()
+    logContext(estimator, dataset)
+  }
 
+  /**
+   * Log info about the estimator and dataset being fit.
+   *
+   * @param estimator the estimator that is being fit
+   * @param dataset the training dataset
+   */
+  def logContext(estimator: Estimator[_], dataset: RDD[_]): Unit = {
```
**Contributor:** See my comment above.

```diff
+    this.estimator = estimator
+    prefix = {
+      // estimator.getClass.getSimpleName can cause Malformed class name error,
+      // call safer `Utils.getSimpleName` instead
+      val className = Utils.getSimpleName(estimator.getClass)
+      s"$shortId-$className-${estimator.uid}-${dataset.hashCode()}:"
+    }
 
-  private def init(): Unit = {
     log(s"training: numPartitions=${dataset.partitions.length}" +
       s" storageLevel=${dataset.getStorageLevel}")
   }
 
+  /**
+   * Log info about the estimator and dataset being fit.
+   *
+   * @param e the estimator that is being fit
+   * @param dataset the training dataset
+   */
+  def logContext(e: Estimator[_], dataset: Dataset[_]): Unit = logContext(e, dataset.rdd)
 
 /**
  * Logs a debug message with a prefix that uniquely identifies the training session.
  */
```
```diff
@@ -97,7 +113,7 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
   /**
    * Logs the value of the given parameters for the estimator being used in this session.
    */
-  def logParams(params: Param[_]*): Unit = {
+  def logParams(estimator: Estimator[_], params: Param[_]*): Unit = {
     val pairs: Seq[(String, JValue)] = for {
       p <- params
       value <- estimator.get(p)
@@ -108,6 +124,12 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
     log(compact(render(map2jvalue(pairs.toMap))))
   }
 
+  // TODO: remove this
+  def logParams(params: Param[_]*): Unit = {
+    require(estimator != null, "`logContext` must be called before `logParams`.")
+    logParams(estimator, params: _*)
+  }
+
   def logNumFeatures(num: Long): Unit = {
     logNamedValue(Instrumentation.loggerTags.numFeatures, num)
   }
```
```diff
@@ -148,12 +170,25 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
   }
 
+  // TODO: Remove this (possibly replace with logModel?)
   /**
    * Logs the successful completion of the training session.
    */
   def logSuccess(model: Model[_]): Unit = {
     log(s"training finished")
   }
 
+  def logSuccess(): Unit = {
+    log("training finished")
```
**Contributor:** We shouldn't have this `log` alias. I was wondering which log level it uses. Just use `logInfo` and remove `log(`.
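A minimal sketch of that suggestion (illustrative only):

```scala
def logSuccess(): Unit = {
  // Log directly at INFO level instead of going through the `log` alias.
  logInfo("training finished")
}
```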

```diff
+  }
 
+  /**
+   * Logs an exception raised during a training session.
+   */
+  def logFailure(e: Throwable): Unit = {
+    val msg = e.getStackTrace.mkString("\n")
+    super.logInfo(msg)
```
**Contributor:** Failures should go to the ERROR level.
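A minimal sketch of the suggested change (illustrative only):

```scala
def logFailure(e: Throwable): Unit = {
  // Surface failures at ERROR level rather than INFO.
  val msg = e.getStackTrace.mkString("\n")
  super.logError(msg)
}
```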

```diff
+  }
 }
 
@@ -169,22 +204,33 @@ private[spark] object Instrumentation {
     val varianceOfLabels = "varianceOfLabels"
   }
 
+  // TODO: Remove these
   /**
    * Creates an instrumentation object for a training session.
    */
-  def create[E <: Estimator[_]](
-      estimator: E, dataset: Dataset[_]): Instrumentation[E] = {
-    create[E](estimator, dataset.rdd)
+  def create(estimator: Estimator[_], dataset: Dataset[_]): Instrumentation = {
+    create(estimator, dataset.rdd)
   }
 
   /**
    * Creates an instrumentation object for a training session.
    */
-  def create[E <: Estimator[_]](
-      estimator: E, dataset: RDD[_]): Instrumentation[E] = {
-    new Instrumentation[E](estimator, dataset)
+  def create(estimator: Estimator[_], dataset: RDD[_]): Instrumentation = {
+    new Instrumentation(estimator, dataset)
   }
+  // end remove
 
+  def instrumented[T](body: (Instrumentation => T)): T = {
+    val instr = new Instrumentation()
+    Try(body(new Instrumentation())) match {
```
**Contributor:** Use the already constructed `instr`.

```diff
+      case Failure(NonFatal(e)) =>
+        instr.logFailure(e)
+        throw e
+      case Success(model) =>
```
**Contributor:** `model` -> `result`; it doesn't need to be a model.

```diff
+        instr.logSuccess()
+        model
+    }
+  }
 
 }
```
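Taken together, the two review comments above point at a version of the helper like this (a sketch of the suggested follow-up, not the code as committed):

```scala
def instrumented[T](body: (Instrumentation => T)): T = {
  val instr = new Instrumentation()
  Try(body(instr)) match {              // reuse the instance constructed above
    case Failure(NonFatal(e)) =>
      instr.logFailure(e)
      throw e
    case Success(result) =>             // renamed from `model`: any result type works
      instr.logSuccess()
      result
  }
}
```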

```diff
@@ -193,7 +239,7 @@ private[spark] object Instrumentation {
  * will log via it, otherwise will log via common logger.
  */
 private[spark] class OptionalInstrumentation private(
-    val instrumentation: Option[Instrumentation[_ <: Estimator[_]]],
+    val instrumentation: Option[Instrumentation],
     val className: String) extends Logging {
 
   protected override def logName: String = className
```
```diff
@@ -225,7 +271,7 @@ private[spark] object OptionalInstrumentation {
   /**
    * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object.
    */
-  def create[E <: Estimator[_]](instr: Instrumentation[E]): OptionalInstrumentation = {
+  def create(instr: Instrumentation): OptionalInstrumentation = {
     new OptionalInstrumentation(Some(instr),
       instr.estimator.getClass.getName.stripSuffix("$"))
   }
```
```diff
@@ -235,7 +235,7 @@ class KMeans private (
 
   private[spark] def run(
       data: RDD[Vector],
-      instr: Option[Instrumentation[NewKMeans]]): KMeansModel = {
+      instr: Option[Instrumentation]): KMeansModel = {
 
     if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data is not directly cached, which may hurt performance if its"
@@ -264,7 +264,7 @@ class KMeans private (
    */
   private def runAlgorithm(
       data: RDD[VectorWithNorm],
-      instr: Option[Instrumentation[NewKMeans]]): KMeansModel = {
+      instr: Option[Instrumentation]): KMeansModel = {
 
     val sc = data.sparkContext
```