Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ object LogKey extends Enumeration {
val MIN_SIZE = Value
val REMOTE_ADDRESS = Value
val POD_ID = Value
val NUM_ITERATIONS = Value
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: let's sort the keys.

val LEARNING_RATE = Value
val SUBSAMPLING_RATE = Value

val MAX_CATEGORIES = Value
val CATEGORICAL_FEATURES = Value

val RANGE_CLASSIFICATION_LABELS = Value
val NUM_CLASSIFICATION_LABELS = Value
val OPTIMIZER_CLASS_NAME = Value

type LogKey = Value
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ case class MessageWithContext(message: String, context: java.util.HashMap[String
resultMap.putAll(mdc.context)
MessageWithContext(message + mdc.message, resultMap)
}

override def toString: String = message
}

/**
Expand Down Expand Up @@ -117,7 +119,7 @@ trait Logging {
}
}

private def withLogContext(context: java.util.HashMap[String, String])(body: => Unit): Unit = {
protected def withLogContext(context: java.util.HashMap[String, String])(body: => Unit): Unit = {
Copy link
Contributor Author

@panbingkun panbingkun Apr 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change private to protected?
Because some classes extend Logging and override the methods logInfo, logWarning, and logError, e.g.:
mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala

val threadContext = CloseableThreadContext.putAll(context)
try {
body
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ class MDCSuite
val log = log"This is a log, exitcode ${MDC(EXIT_CODE, 10086)}"
assert(log.message === "This is a log, exitcode 10086")
assert(log.context === Map("exit_code" -> "10086").asJava)
assert(log.toString === "This is a log, exitcode 10086")
}

test("custom object as MDC value") {
// The same rendered text is expected from both `message` and `toString`.
val expectedMessage = "This is a log, exitcode CustomObjectValue: spark, 10086"
val customValue = CustomObjectValue("spark", 10086)
val entry = log"This is a log, exitcode ${MDC(EXIT_CODE, customValue)}"
assert(entry.message === expectedMessage)
// The context maps the lower-cased key name to the value's rendered string.
assert(entry.context === Map("exit_code" -> "CustomObjectValue: spark, 10086").asJava)
assert(entry.toString === expectedMessage)
}

case class CustomObjectValue(key: String, value: Int) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKey.{NUM_CLASSIFICATION_LABELS, OPTIMIZER_CLASS_NAME, RANGE_CLASSIFICATION_LABELS}
import org.apache.spark.ml.feature._
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.optim.aggregator._
Expand Down Expand Up @@ -220,10 +221,11 @@ class LinearSVC @Since("2.2.0") (
instr.logNumFeatures(numFeatures)

if (numInvalid != 0) {
val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
s"Found $numInvalid invalid labels."
val msg = log"Classification labels should be in " +
log"${MDC(RANGE_CLASSIFICATION_LABELS, s"[0 to ${numClasses - 1}]")}. " +
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am thinking about making the log keys generic. How about making it RANGE here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah

log"Found ${MDC(NUM_CLASSIFICATION_LABELS, numInvalid)} invalid labels."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am thinking about making the log keys generic. How about making it COUNT here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah

instr.logError(msg)
throw new SparkException(msg)
throw new SparkException(msg.message)
}

val featuresStd = summarizer.std.toArray
Expand All @@ -249,9 +251,9 @@ class LinearSVC @Since("2.2.0") (
regularization, optimizer)

if (rawCoefficients == null) {
val msg = s"${optimizer.getClass.getName} failed."
val msg = log"${MDC(OPTIMIZER_CLASS_NAME, optimizer.getClass.getName)} failed."
instr.logError(msg)
throw new SparkException(msg)
throw new SparkException(msg.message)
}

val coefficientArray = Array.tabulate(numFeatures) { i =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKey.{NUM_CLASSIFICATION_LABELS, OPTIMIZER_CLASS_NAME, RANGE_CLASSIFICATION_LABELS}
import org.apache.spark.ml.feature._
import org.apache.spark.ml.impl.Utils
import org.apache.spark.ml.linalg._
Expand Down Expand Up @@ -530,10 +531,11 @@ class LogisticRegression @Since("1.2.0") (
}

if (numInvalid != 0) {
val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
s"Found $numInvalid invalid labels."
val msg = log"Classification labels should be in " +
log"${MDC(RANGE_CLASSIFICATION_LABELS, s"[0 to ${numClasses - 1}]")}. " +
log"Found ${MDC(NUM_CLASSIFICATION_LABELS, numInvalid)} invalid labels."
instr.logError(msg)
throw new SparkException(msg)
throw new SparkException(msg.message)
}

instr.logNumClasses(numClasses)
Expand Down Expand Up @@ -634,9 +636,9 @@ class LogisticRegression @Since("1.2.0") (
initialSolution.toArray, regularization, optimizer)

if (allCoefficients == null) {
val msg = s"${optimizer.getClass.getName} failed."
val msg = log"${MDC(OPTIMIZER_CLASS_NAME, optimizer.getClass.getName)} failed."
instr.logError(msg)
throw new SparkException(msg)
throw new SparkException(msg.message)
}

val allCoefMatrix = new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKey.OPTIMIZER_CLASS_NAME
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.feature._
import org.apache.spark.ml.linalg._
Expand Down Expand Up @@ -271,9 +272,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
optimizer, initialSolution)

if (rawCoefficients == null) {
val msg = s"${optimizer.getClass.getName} failed."
val msg = log"${MDC(OPTIMIZER_CLASS_NAME, optimizer.getClass.getName)} failed."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a lot of duplicated code in the changes of this PR. Let's create a helper method to reduce the duplication.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, very good suggestion

instr.logError(msg)
throw new SparkException(msg)
throw new SparkException(msg.message)
}

val coefficientArray = Array.tabulate(numFeatures) { i =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKey.OPTIMIZER_CLASS_NAME
import org.apache.spark.ml.{PipelineStage, PredictorParams}
import org.apache.spark.ml.feature._
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
Expand Down Expand Up @@ -428,9 +429,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
featuresMean, featuresStd, initialSolution, regularization, optimizer)

if (parameters == null) {
val msg = s"${optimizer.getClass.getName} failed."
val msg = log"${MDC(OPTIMIZER_CLASS_NAME, optimizer.getClass.getName)} failed."
instr.logError(msg)
throw new SparkException(msg)
throw new SparkException(msg.message)
}

val model = createModel(parameters, yMean, yStd, featuresMean, featuresStd)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.internal.Logging
import org.apache.spark.internal.{LogEntry, Logging, MessageWithContext}
import org.apache.spark.ml.{MLEvents, PipelineStage}
import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -84,20 +84,53 @@ private[spark] class Instrumentation private () extends Logging with MLEvents {
super.logWarning(prefix + msg)
}

/**
* Logs a warning from a [[LogEntry]], prepending the prefix that uniquely identifies the
* training session to the entry's message.
*
* @param entry the log entry supplying the message text and the context handed to
*              `withLogContext`
*/
override def logWarning(entry: LogEntry): Unit = {
// Skip both the context setup and the message concatenation when WARN is disabled.
if (log.isWarnEnabled) {
withLogContext(entry.context) {
log.warn(prefix + entry.message)
}
}
}

/**
* Logs an error message with a prefix that uniquely identifies the training session.
*
* @param msg by-name message text; the prefixed text is handed to the parent `logError`
*/
override def logError(msg: => String): Unit = {
super.logError(prefix + msg)
}

/**
* Logs a LogEntry whose message is given a prefix that uniquely identifies the training session.
*/
override def logError(entry: LogEntry): Unit = {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can write it as follows:

override def logError(entry: LogEntry): Unit = {
   super.logError(MessageWithContext(prefix + entry.message, entry.context))
}

But that seems less efficient than the current implementation, because it allocates an extra MessageWithContext for every call.

if (log.isErrorEnabled) {
withLogContext(entry.context) {
log.error(prefix + entry.message)
}
}
}

/**
* Logs an info message with a prefix that uniquely identifies the training session.
*
* @param msg by-name message text; the prefixed text is handed to the parent `logInfo`
*/
override def logInfo(msg: => String): Unit = {
super.logInfo(prefix + msg)
}

/**
* Logs an info message from a [[LogEntry]], prepending the prefix that uniquely identifies the
* training session to the entry's message.
*
* @param entry the log entry supplying the message text and the context handed to
*              `withLogContext`
*/
override def logInfo(entry: LogEntry): Unit = {
// Skip both the context setup and the message concatenation when INFO is disabled.
if (log.isInfoEnabled) {
withLogContext(entry.context) {
log.info(prefix + entry.message)
}
}
}

/**
* Logs the value of the given parameters for the estimator being used in this session.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
package org.apache.spark.mllib.util

import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKey.{NUM_CLASSIFICATION_LABELS, RANGE_CLASSIFICATION_LABELS}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

Expand All @@ -37,7 +38,8 @@ object DataValidators extends Logging {
val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data =>
val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count()
if (numInvalid != 0) {
logError("Classification labels should be 0 or 1. Found " + numInvalid + " invalid labels")
logError(log"Classification labels should be 0 or 1. " +
log"Found ${MDC(NUM_CLASSIFICATION_LABELS, numInvalid)} invalid labels")
}
numInvalid == 0
}
Expand All @@ -53,8 +55,9 @@ object DataValidators extends Logging {
val numInvalid = data.filter(x =>
x.label - x.label.toInt != 0.0 || x.label < 0 || x.label > k - 1).count()
if (numInvalid != 0) {
logError("Classification labels should be in {0 to " + (k - 1) + "}. " +
"Found " + numInvalid + " invalid labels")
logError(log"Classification labels should be in " +
log"${MDC(RANGE_CLASSIFICATION_LABELS, s"[0 to ${k - 1}]")}. " +
log"Found ${MDC(NUM_CLASSIFICATION_LABELS, numInvalid)} invalid labels")
}
numInvalid == 0
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
package org.apache.spark.ml.feature

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKey.{CATEGORICAL_FEATURES, MAX_CATEGORIES}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
Expand Down Expand Up @@ -175,8 +176,10 @@ class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging {
maxCategories: Int,
categoricalFeatures: Set[Int]): Unit = {
val collectedData = data.collect().map(_.getAs[Vector](0))
val errMsg = s"checkCategoryMaps failed for input with maxCategories=$maxCategories," +
s" categoricalFeatures=${categoricalFeatures.mkString(", ")}"

val errMsg = log"checkCategoryMaps failed for input with " +
log"maxCategories=${MDC(MAX_CATEGORIES, maxCategories)} " +
log"categoricalFeatures=${MDC(CATEGORICAL_FEATURES, categoricalFeatures.mkString(", "))}"
try {
val vectorIndexer = getIndexer.setMaxCategories(maxCategories)
val model = vectorIndexer.fit(data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.spark.mllib.tree

import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.LogKey.{LEARNING_RATE, NUM_ITERATIONS, SUBSAMPLING_RATE}
import org.apache.spark.internal.MDC
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
Expand Down Expand Up @@ -51,8 +53,9 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
gbt, GradientBoostedTreesSuite.data.toImmutableArraySeq, 0.06)
} catch {
case e: java.lang.AssertionError =>
logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
s" subsamplingRate=$subsamplingRate")
logError(log"FAILED for numIterations=${MDC(NUM_ITERATIONS, numIterations)}, " +
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's create a method to reduce duplicated code in this file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

log"learningRate=${MDC(LEARNING_RATE, learningRate)}, " +
log"subsamplingRate=${MDC(SUBSAMPLING_RATE, subsamplingRate)}")
throw e
}

Expand Down Expand Up @@ -82,8 +85,9 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
gbt, GradientBoostedTreesSuite.data.toImmutableArraySeq, 0.85, "mae")
} catch {
case e: java.lang.AssertionError =>
logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
s" subsamplingRate=$subsamplingRate")
logError(log"FAILED for numIterations=${MDC(NUM_ITERATIONS, numIterations)}, " +
log"learningRate=${MDC(LEARNING_RATE, learningRate)}, " +
log"subsamplingRate=${MDC(SUBSAMPLING_RATE, subsamplingRate)}")
throw e
}

Expand Down Expand Up @@ -114,8 +118,9 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
gbt, GradientBoostedTreesSuite.data.toImmutableArraySeq, 0.9)
} catch {
case e: java.lang.AssertionError =>
logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
s" subsamplingRate=$subsamplingRate")
logError(log"FAILED for numIterations=${MDC(NUM_ITERATIONS, numIterations)}, " +
log"learningRate=${MDC(LEARNING_RATE, learningRate)}, " +
log"subsamplingRate=${MDC(SUBSAMPLING_RATE, subsamplingRate)}")
throw e
}

Expand Down