-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-24747][ML] Make Instrumentation class more flexible #21719
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,8 +27,8 @@ import org.json4s.JsonDSL._ | |
| import org.json4s.jackson.JsonMethods._ | ||
|
|
||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.ml.{Estimator, Model} | ||
| import org.apache.spark.ml.param.Param | ||
| import org.apache.spark.ml.{Estimator, Model, PipelineStage} | ||
| import org.apache.spark.ml.param.{Param, Params} | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.Dataset | ||
| import org.apache.spark.util.Utils | ||
|
|
@@ -41,41 +41,40 @@ private[spark] class Instrumentation extends Logging { | |
|
|
||
| private val id = UUID.randomUUID() | ||
| private val shortId = id.toString.take(8) | ||
| private var prefix = s"$shortId:" | ||
| private val prefix = s"[$shortId] " | ||
|
|
||
| // TODO: update spark.ml to use new Instrumentation APIs and remove this constructor | ||
| var estimator: Estimator[_] = _ | ||
| var stage: Params = _ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd recommend we either plan to remove "stage" or change "logPipelineStage" so it only allows setting "stage" once. If we go with the former, how about leaving a note to remove "stage" once spark.ml code is migrated to use the new logParams() method?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, the plan is to remove stage once we port switch over to the new APIs |
||
| private def this(estimator: Estimator[_], dataset: RDD[_]) = { | ||
| this() | ||
| logContext(estimator, dataset) | ||
| logPipelineStage(estimator) | ||
| logDataset(dataset) | ||
| } | ||
|
|
||
| /** | ||
| * Log info about the estimator and dataset being fit. | ||
| * | ||
| * @param estimator the estimator that is being fit | ||
| * @param dataset the training dataset | ||
| * Log some info about the pipeline stage being fit. | ||
| */ | ||
| def logContext(estimator: Estimator[_], dataset: RDD[_]): Unit = { | ||
| this.estimator = estimator | ||
| prefix = { | ||
| // estimator.getClass.getSimpleName can cause Malformed class name error, | ||
| // call safer `Utils.getSimpleName` instead | ||
| val className = Utils.getSimpleName(estimator.getClass) | ||
| s"$shortId-$className-${estimator.uid}-${dataset.hashCode()}:" | ||
| } | ||
|
|
||
| log(s"training: numPartitions=${dataset.partitions.length}" + | ||
| s" storageLevel=${dataset.getStorageLevel}") | ||
| def logPipelineStage(stage: PipelineStage): Unit = { | ||
| this.stage = stage | ||
| // estimator.getClass.getSimpleName can cause Malformed class name error, | ||
| // call safer `Utils.getSimpleName` instead | ||
| val className = Utils.getSimpleName(stage.getClass) | ||
| logInfo(s"Stage class: $className") | ||
| logInfo(s"Stage uid: ${stage.uid}") | ||
| } | ||
|
|
||
| /** | ||
| * Log info about the estimator and dataset being fit. | ||
| * | ||
| * @param e the estimator that is being fit | ||
| * @param dataset the training dataset | ||
| * Log some data about the dataset being fit. | ||
| */ | ||
| def logContext(e: Estimator[_], dataset: Dataset[_]): Unit = logContext(e, dataset.rdd) | ||
| def logDataset(dataset: Dataset[_]): Unit = logDataset(dataset.rdd) | ||
|
|
||
| /** | ||
| * Log some data about the dataset being fit. | ||
| */ | ||
| def logDataset(dataset: RDD[_]): Unit = { | ||
| logInfo(s"training: numPartitions=${dataset.partitions.length}" + | ||
| s" storageLevel=${dataset.getStorageLevel}") | ||
| } | ||
|
|
||
| /** | ||
| * Logs a debug message with a prefix that uniquely identifies the training session. | ||
|
|
@@ -105,29 +104,25 @@ private[spark] class Instrumentation extends Logging { | |
| super.logInfo(prefix + msg) | ||
| } | ||
|
|
||
| /** | ||
| * Alias for logInfo, see above. | ||
| */ | ||
| def log(msg: String): Unit = logInfo(msg) | ||
|
|
||
| /** | ||
| * Logs the value of the given parameters for the estimator being used in this session. | ||
| */ | ||
| def logParams(estimator: Estimator[_], params: Param[_]*): Unit = { | ||
| def logParams(hasParams: Params, params: Param[_]*): Unit = { | ||
| val pairs: Seq[(String, JValue)] = for { | ||
| p <- params | ||
| value <- estimator.get(p) | ||
| value <- hasParams.get(p) | ||
| } yield { | ||
| val cast = p.asInstanceOf[Param[Any]] | ||
| p.name -> parse(cast.jsonEncode(value)) | ||
| } | ||
| log(compact(render(map2jvalue(pairs.toMap)))) | ||
| logInfo(compact(render(map2jvalue(pairs.toMap)))) | ||
| } | ||
|
|
||
| // TODO: remove this | ||
| def logParams(params: Param[_]*): Unit = { | ||
| require(estimator != null, "`logContext` must be called before `logParams`.") | ||
| logParams(estimator, params: _*) | ||
| require(stage != null, "`logStageParams` must be called before `logParams` (or an instance of" + | ||
| " Params must be provided explicitly).") | ||
| logParams(stage, params: _*) | ||
| } | ||
|
|
||
| def logNumFeatures(num: Long): Unit = { | ||
|
|
@@ -146,27 +141,27 @@ private[spark] class Instrumentation extends Logging { | |
| * Logs the value with customized name field. | ||
| */ | ||
| def logNamedValue(name: String, value: String): Unit = { | ||
| log(compact(render(name -> value))) | ||
| logInfo(compact(render(name -> value))) | ||
| } | ||
|
|
||
| def logNamedValue(name: String, value: Long): Unit = { | ||
| log(compact(render(name -> value))) | ||
| logInfo(compact(render(name -> value))) | ||
| } | ||
|
|
||
| def logNamedValue(name: String, value: Double): Unit = { | ||
| log(compact(render(name -> value))) | ||
| logInfo(compact(render(name -> value))) | ||
| } | ||
|
|
||
| def logNamedValue(name: String, value: Array[String]): Unit = { | ||
| log(compact(render(name -> compact(render(value.toSeq))))) | ||
| logInfo(compact(render(name -> compact(render(value.toSeq))))) | ||
| } | ||
|
|
||
| def logNamedValue(name: String, value: Array[Long]): Unit = { | ||
| log(compact(render(name -> compact(render(value.toSeq))))) | ||
| logInfo(compact(render(name -> compact(render(value.toSeq))))) | ||
| } | ||
|
|
||
| def logNamedValue(name: String, value: Array[Double]): Unit = { | ||
| log(compact(render(name -> compact(render(value.toSeq))))) | ||
| logInfo(compact(render(name -> compact(render(value.toSeq))))) | ||
| } | ||
|
|
||
|
|
||
|
|
@@ -175,19 +170,19 @@ private[spark] class Instrumentation extends Logging { | |
| * Logs the successful completion of the training session. | ||
| */ | ||
| def logSuccess(model: Model[_]): Unit = { | ||
| log(s"training finished") | ||
| logInfo(s"training finished") | ||
| } | ||
|
|
||
| def logSuccess(): Unit = { | ||
| log("training finished") | ||
| logInfo("training finished") | ||
| } | ||
|
|
||
| /** | ||
| * Logs an exception raised during a training session. | ||
| */ | ||
| def logFailure(e: Throwable): Unit = { | ||
| val msg = e.getStackTrace.mkString("\n") | ||
| super.logInfo(msg) | ||
| super.logError(msg) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -222,13 +217,13 @@ private[spark] object Instrumentation { | |
|
|
||
| def instrumented[T](body: (Instrumentation => T)): T = { | ||
| val instr = new Instrumentation() | ||
| Try(body(new Instrumentation())) match { | ||
| Try(body(instr)) match { | ||
| case Failure(NonFatal(e)) => | ||
| instr.logFailure(e) | ||
| throw e | ||
| case Success(model) => | ||
| case Success(result) => | ||
| instr.logSuccess() | ||
| model | ||
| result | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -273,7 +268,7 @@ private[spark] object OptionalInstrumentation { | |
| */ | ||
| def create(instr: Instrumentation): OptionalInstrumentation = { | ||
| new OptionalInstrumentation(Some(instr), | ||
| instr.estimator.getClass.getName.stripSuffix("$")) | ||
| instr.stage.getClass.getName.stripSuffix("$")) | ||
| } | ||
|
|
||
| /** | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put import at top of file with the other imports (just to make imports easier to track).