Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,9 @@ class LogisticRegression @Since("1.2.0") (
(new MultivariateOnlineSummarizer, new MultiClassSummarizer)
)(seqOp, combOp, $(aggregationDepth))
}
instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count)
instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.min.toString)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not log the whole histogram ( each label -> its weightSum ).
Only log min/max weightSum seems useless and user even do not know they related to which label.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just proxy for balance in the dataset. We can log more, I just wanted to start by logging something.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be useful to add a utility for logging arrays and vectors (we can json encode them), but in the meantime I wanted to capture at least minimal information about the data balance.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm OK with not logging the full histogram here. There's a typo, where "highestLabelWeight" is actually logging the min (not max)


val histogram = labelSummarizer.histogram
val numInvalid = labelSummarizer.countInvalid
Expand Down Expand Up @@ -567,8 +570,8 @@ class LogisticRegression @Since("1.2.0") (
val isConstantLabel = histogram.count(_ != 0.0) == 1

if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) {
logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " +
s"will be zeros. Training is not needed.")
instr.logWarning(s"All labels are the same value and fitIntercept=true, so the " +
s"coefficients will be zeros. Training is not needed.")
val constantLabelIndex = Vectors.dense(histogram).argmax
val coefMatrix = new SparseMatrix(numCoefficientSets, numFeatures,
new Array[Int](numCoefficientSets + 1), Array.empty[Int], Array.empty[Double],
Expand All @@ -581,7 +584,7 @@ class LogisticRegression @Since("1.2.0") (
(coefMatrix, interceptVec, Array.empty[Double])
} else {
if (!$(fitIntercept) && isConstantLabel) {
logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " +
instr.logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " +
s"dangerous ground, so the algorithm may not converge.")
}

Expand All @@ -590,7 +593,7 @@ class LogisticRegression @Since("1.2.0") (

if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) {
logWarning("Fitting LogisticRegressionModel without intercept on dataset with " +
instr.logWarning("Fitting LogisticRegressionModel without intercept on dataset with " +
"constant nonzero column, Spark MLlib outputs zero coefficients for constant " +
"nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.")
}
Expand Down Expand Up @@ -708,7 +711,7 @@ class LogisticRegression @Since("1.2.0") (
(_initialModel.interceptVector.size == numCoefficientSets) &&
(_initialModel.getFitIntercept == $(fitIntercept))
if (!modelIsValid) {
logWarning(s"Initial coefficients will be ignored! Its dimensions " +
instr.logWarning(s"Initial coefficients will be ignored! Its dimensions " +
s"(${providedCoefs.numRows}, ${providedCoefs.numCols}) did not match the " +
s"expected size ($numCoefficientSets, $numFeatures)")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.ml.util

import java.util.concurrent.atomic.AtomicLong
import java.util.UUID

import org.json4s._
import org.json4s.JsonDSL._
Expand All @@ -42,7 +42,7 @@ import org.apache.spark.sql.Dataset
private[spark] class Instrumentation[E <: Estimator[_]] private (
estimator: E, dataset: RDD[_]) extends Logging {

private val id = Instrumentation.counter.incrementAndGet()
private val id = UUID.randomUUID()
private val prefix = {
val className = estimator.getClass.getSimpleName
s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
Expand All @@ -56,12 +56,24 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
}

/**
* Logs a message with a prefix that uniquely identifies the training session.
* Logs a warning message with a prefix that uniquely identifies the training session.
*/
def log(msg: String): Unit = {
logInfo(prefix + msg)
override def logWarning(msg: => String): Unit = {
super.logWarning(prefix + msg)
}

/**
* Logs an info message with a prefix that uniquely identifies the training session.
*/
override def logInfo(msg: => String): Unit = {
super.logInfo(prefix + msg)
}

/**
* Alias for logInfo, see above.
*/
def log(msg: String): Unit = logInfo(msg)

/**
* Logs the value of the given parameters for the estimator being used in this session.
*/
Expand All @@ -77,11 +89,11 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
}

def logNumFeatures(num: Long): Unit = {
log(compact(render("numFeatures" -> num)))
logNamedValue(Instrumentation.loggerTags.numFeatures, num)
}

def logNumClasses(num: Long): Unit = {
log(compact(render("numClasses" -> num)))
logNamedValue(Instrumentation.loggerTags.numClasses, num)
}

/**
Expand All @@ -107,7 +119,12 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
* Some common methods for logging information about a training session.
*/
private[spark] object Instrumentation {
private val counter = new AtomicLong(0)

object loggerTags {
val numFeatures = "numFeatures"
val numClasses = "numClasses"
val numExamples = "numExamples"
}

/**
* Creates an instrumentation object for a training session.
Expand Down