Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,9 @@ class LogisticRegression @Since("1.2.0") (
(new MultivariateOnlineSummarizer, new MultiClassSummarizer)
)(seqOp, combOp, $(aggregationDepth))
}
instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count)
instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)

val histogram = labelSummarizer.histogram
val numInvalid = labelSummarizer.countInvalid
Expand Down Expand Up @@ -560,15 +563,15 @@ class LogisticRegression @Since("1.2.0") (
if (numInvalid != 0) {
val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
s"Found $numInvalid invalid labels."
logError(msg)
instr.logError(msg)
throw new SparkException(msg)
}

val isConstantLabel = histogram.count(_ != 0.0) == 1

if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) {
logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " +
s"will be zeros. Training is not needed.")
instr.logWarning(s"All labels are the same value and fitIntercept=true, so the " +
s"coefficients will be zeros. Training is not needed.")
val constantLabelIndex = Vectors.dense(histogram).argmax
val coefMatrix = new SparseMatrix(numCoefficientSets, numFeatures,
new Array[Int](numCoefficientSets + 1), Array.empty[Int], Array.empty[Double],
Expand All @@ -581,7 +584,7 @@ class LogisticRegression @Since("1.2.0") (
(coefMatrix, interceptVec, Array.empty[Double])
} else {
if (!$(fitIntercept) && isConstantLabel) {
logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " +
instr.logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " +
s"dangerous ground, so the algorithm may not converge.")
}

Expand All @@ -590,7 +593,7 @@ class LogisticRegression @Since("1.2.0") (

if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) {
logWarning("Fitting LogisticRegressionModel without intercept on dataset with " +
instr.logWarning("Fitting LogisticRegressionModel without intercept on dataset with " +
"constant nonzero column, Spark MLlib outputs zero coefficients for constant " +
"nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.")
}
Expand Down Expand Up @@ -708,7 +711,7 @@ class LogisticRegression @Since("1.2.0") (
(_initialModel.interceptVector.size == numCoefficientSets) &&
(_initialModel.getFitIntercept == $(fitIntercept))
if (!modelIsValid) {
logWarning(s"Initial coefficients will be ignored! Its dimensions " +
instr.logWarning(s"Initial coefficients will be ignored! Its dimensions " +
s"(${providedCoefs.numRows}, ${providedCoefs.numCols}) did not match the " +
s"expected size ($numCoefficientSets, $numFeatures)")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.ml.util

import java.util.concurrent.atomic.AtomicLong
import java.util.UUID

import org.json4s._
import org.json4s.JsonDSL._
Expand All @@ -42,7 +42,7 @@ import org.apache.spark.sql.Dataset
private[spark] class Instrumentation[E <: Estimator[_]] private (
estimator: E, dataset: RDD[_]) extends Logging {

private val id = Instrumentation.counter.incrementAndGet()
private val id = UUID.randomUUID()
private val prefix = {
val className = estimator.getClass.getSimpleName
s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
Expand All @@ -56,12 +56,31 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
}

/**
* Logs a message with a prefix that uniquely identifies the training session.
* Logs a warning message with a prefix that uniquely identifies the training session.
*/
def log(msg: String): Unit = {
logInfo(prefix + msg)
override def logWarning(msg: => String): Unit = {
super.logWarning(prefix + msg)
}

/**
* Logs a error message with a prefix that uniquely identifies the training session.
*/
override def logError(msg: => String): Unit = {
super.logError(prefix + msg)
}

/**
* Logs an info message with a prefix that uniquely identifies the training session.
*/
override def logInfo(msg: => String): Unit = {
super.logInfo(prefix + msg)
}

/**
* Alias for logInfo, see above.
*/
def log(msg: String): Unit = logInfo(msg)

/**
* Logs the value of the given parameters for the estimator being used in this session.
*/
Expand All @@ -77,11 +96,11 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
}

def logNumFeatures(num: Long): Unit = {
log(compact(render("numFeatures" -> num)))
logNamedValue(Instrumentation.loggerTags.numFeatures, num)
}

def logNumClasses(num: Long): Unit = {
log(compact(render("numClasses" -> num)))
logNamedValue(Instrumentation.loggerTags.numClasses, num)
}

/**
Expand All @@ -107,7 +126,12 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
* Some common methods for logging information about a training session.
*/
private[spark] object Instrumentation {
private val counter = new AtomicLong(0)

object loggerTags {
val numFeatures = "numFeatures"
val numClasses = "numClasses"
val numExamples = "numExamples"
}

/**
* Creates an instrumentation object for a training session.
Expand Down