-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-24747][ML] Make Instrumentation class more flexible #21719
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,45 +19,60 @@ package org.apache.spark.ml.util | |
|
|
||
| import java.util.UUID | ||
|
|
||
| import scala.reflect.ClassTag | ||
| import scala.util.{Failure, Success, Try} | ||
| import scala.util.control.NonFatal | ||
|
|
||
| import org.json4s._ | ||
| import org.json4s.JsonDSL._ | ||
| import org.json4s.jackson.JsonMethods._ | ||
|
|
||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.ml.{Estimator, Model} | ||
| import org.apache.spark.ml.param.Param | ||
| import org.apache.spark.ml.{Estimator, Model, PipelineStage} | ||
| import org.apache.spark.ml.param.{Param, Params} | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.Dataset | ||
| import org.apache.spark.util.Utils | ||
|
|
||
| /** | ||
| * A small wrapper that defines a training session for an estimator, and some methods to log | ||
| * useful information during this session. | ||
| * | ||
| * A new instance is expected to be created within fit(). | ||
| * | ||
| * @param estimator the estimator that is being fit | ||
| * @param dataset the training dataset | ||
| * @tparam E the type of the estimator | ||
| */ | ||
| private[spark] class Instrumentation[E <: Estimator[_]] private ( | ||
| val estimator: E, | ||
| val dataset: RDD[_]) extends Logging { | ||
| private[spark] class Instrumentation extends Logging { | ||
|
|
||
| private val id = UUID.randomUUID() | ||
| private val prefix = { | ||
| private val shortId = id.toString.take(8) | ||
| private val prefix = s"[$shortId] " | ||
|
|
||
| // TODO: update spark.ml to use new Instrumentation APIs and remove this constructor | ||
| var stage: Params = _ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'd recommend we either plan to remove "stage" or change "logPipelineStage" so it only allows setting "stage" once. If we go with the former, how about leaving a note to remove "stage" once spark.ml code is migrated to use the new logParams() method?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yep, the plan is to remove stage once we switch over to the new APIs. |
||
| private def this(estimator: Estimator[_], dataset: RDD[_]) = { | ||
| this() | ||
| logPipelineStage(estimator) | ||
| logDataset(dataset) | ||
| } | ||
|
|
||
| /** | ||
| * Log some info about the pipeline stage being fit. | ||
| */ | ||
| def logPipelineStage(stage: PipelineStage): Unit = { | ||
| this.stage = stage | ||
| // estimator.getClass.getSimpleName can cause Malformed class name error, | ||
| // call safer `Utils.getSimpleName` instead | ||
| val className = Utils.getSimpleName(estimator.getClass) | ||
| s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " | ||
| val className = Utils.getSimpleName(stage.getClass) | ||
| logInfo(s"Stage class: $className") | ||
| logInfo(s"Stage uid: ${stage.uid}") | ||
| } | ||
|
|
||
| init() | ||
| /** | ||
| * Log some data about the dataset being fit. | ||
| */ | ||
| def logDataset(dataset: Dataset[_]): Unit = logDataset(dataset.rdd) | ||
|
|
||
| private def init(): Unit = { | ||
| log(s"training: numPartitions=${dataset.partitions.length}" + | ||
| /** | ||
| * Log some data about the dataset being fit. | ||
| */ | ||
| def logDataset(dataset: RDD[_]): Unit = { | ||
| logInfo(s"training: numPartitions=${dataset.partitions.length}" + | ||
| s" storageLevel=${dataset.getStorageLevel}") | ||
| } | ||
|
|
||
|
|
@@ -89,23 +104,25 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( | |
| super.logInfo(prefix + msg) | ||
| } | ||
|
|
||
| /** | ||
| * Alias for logInfo, see above. | ||
| */ | ||
| def log(msg: String): Unit = logInfo(msg) | ||
|
|
||
| /** | ||
| * Logs the value of the given parameters for the estimator being used in this session. | ||
| */ | ||
| def logParams(params: Param[_]*): Unit = { | ||
| def logParams(hasParams: Params, params: Param[_]*): Unit = { | ||
| val pairs: Seq[(String, JValue)] = for { | ||
| p <- params | ||
| value <- estimator.get(p) | ||
| value <- hasParams.get(p) | ||
| } yield { | ||
| val cast = p.asInstanceOf[Param[Any]] | ||
| p.name -> parse(cast.jsonEncode(value)) | ||
| } | ||
| log(compact(render(map2jvalue(pairs.toMap)))) | ||
| logInfo(compact(render(map2jvalue(pairs.toMap)))) | ||
| } | ||
|
|
||
| // TODO: remove this | ||
| def logParams(params: Param[_]*): Unit = { | ||
| require(stage != null, "`logStageParams` must be called before `logParams` (or an instance of" + | ||
| " Params must be provided explicitly).") | ||
| logParams(stage, params: _*) | ||
| } | ||
|
|
||
| def logNumFeatures(num: Long): Unit = { | ||
|
|
@@ -124,35 +141,48 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( | |
| * Logs the value with customized name field. | ||
| */ | ||
| def logNamedValue(name: String, value: String): Unit = { | ||
| log(compact(render(name -> value))) | ||
| logInfo(compact(render(name -> value))) | ||
| } | ||
|
|
||
| def logNamedValue(name: String, value: Long): Unit = { | ||
| log(compact(render(name -> value))) | ||
| logInfo(compact(render(name -> value))) | ||
| } | ||
|
|
||
| def logNamedValue(name: String, value: Double): Unit = { | ||
| log(compact(render(name -> value))) | ||
| logInfo(compact(render(name -> value))) | ||
| } | ||
|
|
||
| def logNamedValue(name: String, value: Array[String]): Unit = { | ||
| log(compact(render(name -> compact(render(value.toSeq))))) | ||
| logInfo(compact(render(name -> compact(render(value.toSeq))))) | ||
| } | ||
|
|
||
| def logNamedValue(name: String, value: Array[Long]): Unit = { | ||
| log(compact(render(name -> compact(render(value.toSeq))))) | ||
| logInfo(compact(render(name -> compact(render(value.toSeq))))) | ||
| } | ||
|
|
||
| def logNamedValue(name: String, value: Array[Double]): Unit = { | ||
| log(compact(render(name -> compact(render(value.toSeq))))) | ||
| logInfo(compact(render(name -> compact(render(value.toSeq))))) | ||
| } | ||
|
|
||
|
|
||
| // TODO: Remove this (possibly replace with logModel?) | ||
| /** | ||
| * Logs the successful completion of the training session. | ||
| */ | ||
| def logSuccess(model: Model[_]): Unit = { | ||
| log(s"training finished") | ||
| logInfo(s"training finished") | ||
| } | ||
|
|
||
| def logSuccess(): Unit = { | ||
| logInfo("training finished") | ||
| } | ||
|
|
||
| /** | ||
| * Logs an exception raised during a training session. | ||
| */ | ||
| def logFailure(e: Throwable): Unit = { | ||
| val msg = e.getStackTrace.mkString("\n") | ||
| super.logError(msg) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -169,22 +199,33 @@ private[spark] object Instrumentation { | |
| val varianceOfLabels = "varianceOfLabels" | ||
| } | ||
|
|
||
| // TODO: Remove these | ||
| /** | ||
| * Creates an instrumentation object for a training session. | ||
| */ | ||
| def create[E <: Estimator[_]]( | ||
| estimator: E, dataset: Dataset[_]): Instrumentation[E] = { | ||
| create[E](estimator, dataset.rdd) | ||
| def create(estimator: Estimator[_], dataset: Dataset[_]): Instrumentation = { | ||
| create(estimator, dataset.rdd) | ||
| } | ||
|
|
||
| /** | ||
| * Creates an instrumentation object for a training session. | ||
| */ | ||
| def create[E <: Estimator[_]]( | ||
| estimator: E, dataset: RDD[_]): Instrumentation[E] = { | ||
| new Instrumentation[E](estimator, dataset) | ||
| def create(estimator: Estimator[_], dataset: RDD[_]): Instrumentation = { | ||
| new Instrumentation(estimator, dataset) | ||
| } | ||
| // end remove | ||
|
|
||
| def instrumented[T](body: (Instrumentation => T)): T = { | ||
| val instr = new Instrumentation() | ||
| Try(body(instr)) match { | ||
| case Failure(NonFatal(e)) => | ||
| instr.logFailure(e) | ||
| throw e | ||
| case Success(result) => | ||
| instr.logSuccess() | ||
| result | ||
| } | ||
| } | ||
|
|
||
| } | ||
|
|
||
| /** | ||
|
|
@@ -193,7 +234,7 @@ private[spark] object Instrumentation { | |
| * will log via it, otherwise will log via common logger. | ||
| */ | ||
| private[spark] class OptionalInstrumentation private( | ||
| val instrumentation: Option[Instrumentation[_ <: Estimator[_]]], | ||
| val instrumentation: Option[Instrumentation], | ||
| val className: String) extends Logging { | ||
|
|
||
| protected override def logName: String = className | ||
|
|
@@ -225,9 +266,9 @@ private[spark] object OptionalInstrumentation { | |
| /** | ||
| * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object. | ||
| */ | ||
| def create[E <: Estimator[_]](instr: Instrumentation[E]): OptionalInstrumentation = { | ||
| def create(instr: Instrumentation): OptionalInstrumentation = { | ||
| new OptionalInstrumentation(Some(instr), | ||
| instr.estimator.getClass.getName.stripSuffix("$")) | ||
| instr.stage.getClass.getName.stripSuffix("$")) | ||
| } | ||
|
|
||
| /** | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put import at top of file with the other imports (just to make imports easier to track).