diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 06ca37bc7514..109df6335e69 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -35,6 +35,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
@@ -490,7 +491,7 @@ class LogisticRegression @Since("1.2.0") (

   protected[spark] def train(
       dataset: Dataset[_],
-      handlePersistence: Boolean): LogisticRegressionModel = {
+      handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr =>
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
       dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
@@ -500,7 +501,8 @@ class LogisticRegression @Since("1.2.0") (

     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

-    val instr = Instrumentation.create(this, dataset)
+    instr.logPipelineStage(this)
+    instr.logDataset(dataset)
     instr.logParams(regParam, elasticNetParam, standardization, threshold,
       maxIter, tol, fitIntercept)

@@ -905,8 +907,6 @@ class LogisticRegression @Since("1.2.0") (
         objectiveHistory)
     }
     model.setSummary(Some(logRegSummary))
-    instr.logSuccess(model)
-    model
   }

   @Since("1.4.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 905870178e54..bb3f3a015c71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -91,7 +91,7 @@ private[spark] object RandomForest extends Logging {
       numTrees: Int,
       featureSubsetStrategy: String,
       seed: Long,
-      instr: Option[Instrumentation[_]],
+      instr: Option[Instrumentation],
       prune: Boolean = true, // exposed for testing only, real trees are always pruned
       parentUID: Option[String] = None): Array[DecisionTreeModel] = {

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 363304ef1014..135828815504 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -80,7 +80,7 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
   /**
    * Instrumentation logging for tuning params including the inner estimator and evaluator info.
    */
-  protected def logTuningParams(instrumentation: Instrumentation[_]): Unit = {
+  protected def logTuningParams(instrumentation: Instrumentation): Unit = {
     instrumentation.logNamedValue("estimator", $(estimator).getClass.getCanonicalName)
     instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName)
     instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
index 11f46eb9e435..2e43a9ef49ee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -19,15 +19,16 @@ package org.apache.spark.ml.util

 import java.util.UUID

-import scala.reflect.ClassTag
+import scala.util.{Failure, Success, Try}
+import scala.util.control.NonFatal

 import org.json4s._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._

 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.{Estimator, Model, PipelineStage}
+import org.apache.spark.ml.param.{Param, Params}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
 import org.apache.spark.util.Utils
@@ -35,29 +36,44 @@ import org.apache.spark.util.Utils
 /**
  * A small wrapper that defines a training session for an estimator, and some methods to log
  * useful information during this session.
- *
- * A new instance is expected to be created within fit().
- *
- * @param estimator the estimator that is being fit
- * @param dataset the training dataset
- * @tparam E the type of the estimator
  */
-private[spark] class Instrumentation[E <: Estimator[_]] private (
-    val estimator: E,
-    val dataset: RDD[_]) extends Logging {
+private[spark] class Instrumentation extends Logging {

   private val id = UUID.randomUUID()
-  private val prefix = {
+  private val shortId = id.toString.take(8)
+  private val prefix = s"[$shortId] "
+
+  // TODO: remove stage
+  var stage: Params = _
+  // TODO: update spark.ml to use new Instrumentation APIs and remove this constructor
+  private def this(estimator: Estimator[_], dataset: RDD[_]) = {
+    this()
+    logPipelineStage(estimator)
+    logDataset(dataset)
+  }
+
+  /**
+   * Log some info about the pipeline stage being fit.
+   */
+  def logPipelineStage(stage: PipelineStage): Unit = {
+    this.stage = stage
     // estimator.getClass.getSimpleName can cause Malformed class name error,
     // call safer `Utils.getSimpleName` instead
-    val className = Utils.getSimpleName(estimator.getClass)
-    s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
+    val className = Utils.getSimpleName(stage.getClass)
+    logInfo(s"Stage class: $className")
+    logInfo(s"Stage uid: ${stage.uid}")
   }

-  init()
+  /**
+   * Log some data about the dataset being fit.
+   */
+  def logDataset(dataset: Dataset[_]): Unit = logDataset(dataset.rdd)

-  private def init(): Unit = {
-    log(s"training: numPartitions=${dataset.partitions.length}" +
+  /**
+   * Log some data about the dataset being fit.
+   */
+  def logDataset(dataset: RDD[_]): Unit = {
+    logInfo(s"training: numPartitions=${dataset.partitions.length}" +
       s" storageLevel=${dataset.getStorageLevel}")
   }

@@ -89,23 +105,25 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
     super.logInfo(prefix + msg)
   }

-  /**
-   * Alias for logInfo, see above.
-   */
-  def log(msg: String): Unit = logInfo(msg)
-
   /**
    * Logs the value of the given parameters for the estimator being used in this session.
    */
-  def logParams(params: Param[_]*): Unit = {
+  def logParams(hasParams: Params, params: Param[_]*): Unit = {
     val pairs: Seq[(String, JValue)] = for {
       p <- params
-      value <- estimator.get(p)
+      value <- hasParams.get(p)
     } yield {
       val cast = p.asInstanceOf[Param[Any]]
       p.name -> parse(cast.jsonEncode(value))
     }
-    log(compact(render(map2jvalue(pairs.toMap))))
+    logInfo(compact(render(map2jvalue(pairs.toMap))))
+  }
+
+  // TODO: remove this
+  def logParams(params: Param[_]*): Unit = {
+    require(stage != null, "`logStageParams` must be called before `logParams` (or an instance of" +
+      " Params must be provided explicitly).")
+    logParams(stage, params: _*)
   }

   def logNumFeatures(num: Long): Unit = {
@@ -124,35 +142,48 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
    * Logs the value with customized name field.
    */
   def logNamedValue(name: String, value: String): Unit = {
-    log(compact(render(name -> value)))
+    logInfo(compact(render(name -> value)))
   }

   def logNamedValue(name: String, value: Long): Unit = {
-    log(compact(render(name -> value)))
+    logInfo(compact(render(name -> value)))
   }

   def logNamedValue(name: String, value: Double): Unit = {
-    log(compact(render(name -> value)))
+    logInfo(compact(render(name -> value)))
   }

   def logNamedValue(name: String, value: Array[String]): Unit = {
-    log(compact(render(name -> compact(render(value.toSeq)))))
+    logInfo(compact(render(name -> compact(render(value.toSeq)))))
   }

   def logNamedValue(name: String, value: Array[Long]): Unit = {
-    log(compact(render(name -> compact(render(value.toSeq)))))
+    logInfo(compact(render(name -> compact(render(value.toSeq)))))
   }

   def logNamedValue(name: String, value: Array[Double]): Unit = {
-    log(compact(render(name -> compact(render(value.toSeq)))))
+    logInfo(compact(render(name -> compact(render(value.toSeq)))))
   }

+  // TODO: Remove this (possibly replace with logModel?)
   /**
    * Logs the successful completion of the training session.
    */
   def logSuccess(model: Model[_]): Unit = {
-    log(s"training finished")
+    logInfo(s"training finished")
+  }
+
+  def logSuccess(): Unit = {
+    logInfo("training finished")
+  }
+
+  /**
+   * Logs an exception raised during a training session.
+   */
+  def logFailure(e: Throwable): Unit = {
+    val msg = e.getStackTrace.mkString("\n")
+    super.logError(msg)
   }
 }

@@ -169,22 +200,33 @@ private[spark] object Instrumentation {
     val varianceOfLabels = "varianceOfLabels"
   }

+  // TODO: Remove these
   /**
    * Creates an instrumentation object for a training session.
    */
-  def create[E <: Estimator[_]](
-      estimator: E, dataset: Dataset[_]): Instrumentation[E] = {
-    create[E](estimator, dataset.rdd)
+  def create(estimator: Estimator[_], dataset: Dataset[_]): Instrumentation = {
+    create(estimator, dataset.rdd)
   }

   /**
    * Creates an instrumentation object for a training session.
    */
-  def create[E <: Estimator[_]](
-      estimator: E, dataset: RDD[_]): Instrumentation[E] = {
-    new Instrumentation[E](estimator, dataset)
+  def create(estimator: Estimator[_], dataset: RDD[_]): Instrumentation = {
+    new Instrumentation(estimator, dataset)
+  }
+  // end remove
+
+  def instrumented[T](body: (Instrumentation => T)): T = {
+    val instr = new Instrumentation()
+    Try(body(instr)) match {
+      case Failure(NonFatal(e)) =>
+        instr.logFailure(e)
+        throw e
+      case Success(result) =>
+        instr.logSuccess()
+        result
+    }
   }
-
 }

 /**
@@ -193,7 +235,7 @@ private[spark] object Instrumentation {
  * will log via it, otherwise will log via common logger.
  */
 private[spark] class OptionalInstrumentation private(
-    val instrumentation: Option[Instrumentation[_ <: Estimator[_]]],
+    val instrumentation: Option[Instrumentation],
     val className: String) extends Logging {

   protected override def logName: String = className
@@ -225,9 +267,9 @@ private[spark] object OptionalInstrumentation {
   /**
    * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object.
    */
-  def create[E <: Estimator[_]](instr: Instrumentation[E]): OptionalInstrumentation = {
+  def create(instr: Instrumentation): OptionalInstrumentation = {
     new OptionalInstrumentation(Some(instr),
-      instr.estimator.getClass.getName.stripSuffix("$"))
+      instr.stage.getClass.getName.stripSuffix("$"))
   }

   /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index b5b1be349049..d77299cc821e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -235,7 +235,7 @@ class KMeans private (

   private[spark] def run(
       data: RDD[Vector],
-      instr: Option[Instrumentation[NewKMeans]]): KMeansModel = {
+      instr: Option[Instrumentation]): KMeansModel = {

     if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data is not directly cached, which may hurt performance if its"
@@ -264,7 +264,7 @@ class KMeans private (
    */
   private def runAlgorithm(
       data: RDD[VectorWithNorm],
-      instr: Option[Instrumentation[NewKMeans]]): KMeansModel = {
+      instr: Option[Instrumentation]): KMeansModel = {

     val sc = data.sparkContext
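Usage note (illustration only, not part of the patch): the sketch below shows how a spark.ml training method would adopt the `instrumented` wrapper introduced above. `MyModel`, `trainImpl`, and the `maxIter`/`regParam` params are hypothetical placeholders; only the `Instrumentation` calls come from this change, and the code assumes it sits inside the `org.apache.spark.ml` package, since `Instrumentation` is `private[spark]`.

```scala
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.sql.Dataset

// Sketch of the new pattern; MyModel, trainImpl, maxIter and regParam are assumptions.
protected def train(dataset: Dataset[_]): MyModel = instrumented { instr =>
  // Replaces the old `val instr = Instrumentation.create(this, dataset)` call.
  instr.logPipelineStage(this)              // logs the stage class name and uid
  instr.logDataset(dataset)                 // logs partition count and storage level
  instr.logParams(this, maxIter, regParam)  // the Params owner is now passed explicitly

  val model = trainImpl(dataset)            // actual fitting logic goes here
  // No explicit logSuccess()/logFailure() call: `instrumented` wraps the body in a Try,
  // logs success when it returns normally, and logs the stack trace before rethrowing
  // any non-fatal exception.
  model
}
```

The compatibility overload `logParams(params: Param[_]*)` that routes through `stage` requires `logPipelineStage` to have been called first (see the `require` added above), which is why `LogisticRegression.train` now calls `logPipelineStage` and `logDataset` before `logParams`.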