Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
PR feedback.
  • Loading branch information
MrBago committed Jul 12, 2018
commit b98d77284f2397daef2623c012d34ff50a25f498
Original file line number Diff line number Diff line change
Expand Up @@ -488,9 +488,10 @@ class LogisticRegression @Since("1.2.0") (
train(dataset, handlePersistence)
}

import Instrumentation.instrumented
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put import at top of file with the other imports (just to make imports easier to track).

protected[spark] def train(
dataset: Dataset[_],
handlePersistence: Boolean): LogisticRegressionModel = Instrumentation.instrumented { instr =>
handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr =>
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
Expand All @@ -500,7 +501,8 @@ class LogisticRegression @Since("1.2.0") (

if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

instr.logContext(this, dataset)
instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logParams(regParam, elasticNetParam, standardization, threshold,
maxIter, tol, fitIntercept)

Expand Down
91 changes: 43 additions & 48 deletions mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.util.Utils
Expand All @@ -41,41 +41,40 @@ private[spark] class Instrumentation extends Logging {

private val id = UUID.randomUUID()
private val shortId = id.toString.take(8)
private var prefix = s"$shortId:"
private val prefix = s"[$shortId] "

// TODO: update spark.ml to use new Instrumentation APIs and remove this constructor
var estimator: Estimator[_] = _
var stage: Params = _
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd recommend we either plan to remove "stage" or change "logPipelineStage" so it only allows setting "stage" once. If we go with the former, how about leaving a note to remove "stage" once spark.ml code is migrated to use the new logParams() method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, the plan is to remove `stage` once we switch over to the new APIs.

private def this(estimator: Estimator[_], dataset: RDD[_]) = {
this()
logContext(estimator, dataset)
logPipelineStage(estimator)
logDataset(dataset)
}

/**
* Log info about the estimator and dataset being fit.
*
* @param estimator the estimator that is being fit
* @param dataset the training dataset
* Log some info about the pipeline stage being fit.
*/
def logContext(estimator: Estimator[_], dataset: RDD[_]): Unit = {
this.estimator = estimator
prefix = {
// estimator.getClass.getSimpleName can cause Malformed class name error,
// call safer `Utils.getSimpleName` instead
val className = Utils.getSimpleName(estimator.getClass)
s"$shortId-$className-${estimator.uid}-${dataset.hashCode()}:"
}

log(s"training: numPartitions=${dataset.partitions.length}" +
s" storageLevel=${dataset.getStorageLevel}")
def logPipelineStage(stage: PipelineStage): Unit = {
this.stage = stage
// estimator.getClass.getSimpleName can cause Malformed class name error,
// call safer `Utils.getSimpleName` instead
val className = Utils.getSimpleName(stage.getClass)
logInfo(s"Stage class: $className")
logInfo(s"Stage uid: ${stage.uid}")
}

/**
* Log info about the estimator and dataset being fit.
*
* @param e the estimator that is being fit
* @param dataset the training dataset
* Log some data about the dataset being fit.
*/
def logContext(e: Estimator[_], dataset: Dataset[_]): Unit = logContext(e, dataset.rdd)
def logDataset(dataset: Dataset[_]): Unit = logDataset(dataset.rdd)

/**
* Logs, at info level, the number of partitions and the storage level of the dataset being fit.
*
* @param dataset the RDD whose `partitions.length` and `getStorageLevel` are reported
*/
def logDataset(dataset: RDD[_]): Unit = {
logInfo(s"training: numPartitions=${dataset.partitions.length}" +
s" storageLevel=${dataset.getStorageLevel}")
}

/**
* Logs a debug message with a prefix that uniquely identifies the training session.
Expand Down Expand Up @@ -105,29 +104,25 @@ private[spark] class Instrumentation extends Logging {
super.logInfo(prefix + msg)
}

/**
* Alias for logInfo, see above.
*/
def log(msg: String): Unit = logInfo(msg)

/**
* Logs the value of the given parameters for the estimator being used in this session.
*/
def logParams(estimator: Estimator[_], params: Param[_]*): Unit = {
def logParams(hasParams: Params, params: Param[_]*): Unit = {
val pairs: Seq[(String, JValue)] = for {
p <- params
value <- estimator.get(p)
value <- hasParams.get(p)
} yield {
val cast = p.asInstanceOf[Param[Any]]
p.name -> parse(cast.jsonEncode(value))
}
log(compact(render(map2jvalue(pairs.toMap))))
logInfo(compact(render(map2jvalue(pairs.toMap))))
}

// TODO: remove this
def logParams(params: Param[_]*): Unit = {
require(estimator != null, "`logContext` must be called before `logParams`.")
logParams(estimator, params: _*)
require(stage != null, "`logStageParams` must be called before `logParams` (or an instance of" +
" Params must be provided explicitly).")
logParams(stage, params: _*)
}

def logNumFeatures(num: Long): Unit = {
Expand All @@ -146,27 +141,27 @@ private[spark] class Instrumentation extends Logging {
* Logs the value with customized name field.
*/
def logNamedValue(name: String, value: String): Unit = {
log(compact(render(name -> value)))
logInfo(compact(render(name -> value)))
}

def logNamedValue(name: String, value: Long): Unit = {
log(compact(render(name -> value)))
logInfo(compact(render(name -> value)))
}

def logNamedValue(name: String, value: Double): Unit = {
log(compact(render(name -> value)))
logInfo(compact(render(name -> value)))
}

def logNamedValue(name: String, value: Array[String]): Unit = {
log(compact(render(name -> compact(render(value.toSeq)))))
logInfo(compact(render(name -> compact(render(value.toSeq)))))
}

def logNamedValue(name: String, value: Array[Long]): Unit = {
log(compact(render(name -> compact(render(value.toSeq)))))
logInfo(compact(render(name -> compact(render(value.toSeq)))))
}

def logNamedValue(name: String, value: Array[Double]): Unit = {
log(compact(render(name -> compact(render(value.toSeq)))))
logInfo(compact(render(name -> compact(render(value.toSeq)))))
}


Expand All @@ -175,19 +170,19 @@ private[spark] class Instrumentation extends Logging {
* Logs the successful completion of the training session.
*/
def logSuccess(model: Model[_]): Unit = {
log(s"training finished")
logInfo(s"training finished")
}

def logSuccess(): Unit = {
log("training finished")
logInfo("training finished")
}

/**
* Logs an exception raised during a training session.
*/
def logFailure(e: Throwable): Unit = {
val msg = e.getStackTrace.mkString("\n")
super.logInfo(msg)
super.logError(msg)
}
}

Expand Down Expand Up @@ -222,13 +217,13 @@ private[spark] object Instrumentation {

def instrumented[T](body: (Instrumentation => T)): T = {
val instr = new Instrumentation()
Try(body(new Instrumentation())) match {
Try(body(instr)) match {
case Failure(NonFatal(e)) =>
instr.logFailure(e)
throw e
case Success(model) =>
case Success(result) =>
instr.logSuccess()
model
result
}
}
}
Expand Down Expand Up @@ -273,7 +268,7 @@ private[spark] object OptionalInstrumentation {
*/
def create(instr: Instrumentation): OptionalInstrumentation = {
new OptionalInstrumentation(Some(instr),
instr.estimator.getClass.getName.stripSuffix("$"))
instr.stage.getClass.getName.stripSuffix("$"))
}

/**
Expand Down