Commit 3fd0cd6

panbingkun authored, gengliangwang committed
[SPARK-47598][CORE] MLLib: Migrate logError with variables to structured logging framework
### What changes were proposed in this pull request?

This PR migrates `logError` calls with variables in the `MLLib` module to the structured logging framework.

### Why are the changes needed?

To enhance Apache Spark's logging system by implementing structured logging.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- Pass GA.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #45837 from panbingkun/SPARK-47598.

Authored-by: panbingkun <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
1 parent e3405c1 commit 3fd0cd6

File tree

11 files changed (+99, -42 lines)

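Before the per-file diffs, a minimal sketch of the migration pattern this commit applies throughout MLLib, modeled on the DataValidators change below. The `Example` class and `report` method are illustrative names, not part of the commit; the `log` interpolator, `MDC`, and `LogKey` come from `org.apache.spark.internal` as shown in the diffs:

```scala
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKey.COUNT

class Example extends Logging {
  def report(numInvalid: Long): Unit = {
    // Before: plain interpolation; the variable is baked into flat text.
    //   logError(s"Found $numInvalid invalid labels")
    // After: the variable is tagged with a LogKey, so structured sinks
    // receive COUNT as a discrete MDC field alongside the message.
    logError(log"Found ${MDC(COUNT, numInvalid)} invalid labels")
  }
}
```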

common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala

Lines changed: 7 additions & 0 deletions

```diff
@@ -28,6 +28,7 @@ object LogKey extends Enumeration {
   val BLOCK_MANAGER_ID = Value
   val BROADCAST_ID = Value
   val BUCKET = Value
+  val CATEGORICAL_FEATURES = Value
   val CLASS_LOADER = Value
   val CLASS_NAME = Value
   val COMMAND = Value
@@ -44,17 +45,22 @@ object LogKey extends Enumeration {
   val EXIT_CODE = Value
   val HOST = Value
   val JOB_ID = Value
+  val LEARNING_RATE = Value
   val LINE = Value
   val LINE_NUM = Value
   val MASTER_URL = Value
   val MAX_ATTEMPTS = Value
+  val MAX_CATEGORIES = Value
   val MAX_EXECUTOR_FAILURES = Value
   val MAX_SIZE = Value
   val MIN_SIZE = Value
+  val NUM_ITERATIONS = Value
   val OLD_BLOCK_MANAGER_ID = Value
+  val OPTIMIZER_CLASS_NAME = Value
   val PARTITION_ID = Value
   val PATH = Value
   val POD_ID = Value
+  val RANGE = Value
   val REASON = Value
   val REMOTE_ADDRESS = Value
   val RETRY_COUNT = Value
@@ -63,6 +69,7 @@ object LogKey extends Enumeration {
   val SIZE = Value
   val STAGE_ID = Value
   val SUBMISSION_ID = Value
+  val SUBSAMPLING_RATE = Value
   val TASK_ATTEMPT_ID = Value
   val TASK_ID = Value
   val TASK_NAME = Value
```

common/utils/src/main/scala/org/apache/spark/internal/Logging.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -117,7 +117,7 @@ trait Logging {
     }
   }

-  private def withLogContext(context: java.util.HashMap[String, String])(body: => Unit): Unit = {
+  protected def withLogContext(context: java.util.HashMap[String, String])(body: => Unit): Unit = {
     val threadContext = CloseableThreadContext.putAll(context)
     try {
       body
```
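For context, a sketch of what `withLogContext` does with the entry's key/value map, assuming Log4j2's `CloseableThreadContext` (which `Logging.scala` already uses); the `finally`/`close()` part is inferred from that API's contract rather than shown in this hunk, and the wrapping trait name is illustrative:

```scala
import org.apache.logging.log4j.CloseableThreadContext

trait ContextualLogging {
  protected def withLogContext(context: java.util.HashMap[String, String])(body: => Unit): Unit = {
    // Push the MDC key/value pairs onto the current thread's context...
    val threadContext = CloseableThreadContext.putAll(context)
    try {
      body  // ...so every log call inside body carries them...
    } finally {
      threadContext.close()  // ...then restore the thread's previous context.
    }
  }
}
```

The switch from `private` to `protected` is what allows `Instrumentation`, later in this commit, to wrap its prefixed messages in the same context.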

mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

Lines changed: 8 additions & 7 deletions

```diff
@@ -25,7 +25,8 @@ import org.apache.hadoop.fs.Path

 import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{Logging, MDC}
+import org.apache.spark.internal.LogKey.{COUNT, RANGE}
 import org.apache.spark.ml.feature._
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.optim.aggregator._
@@ -36,6 +37,7 @@ import org.apache.spark.ml.stat._
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.DatasetUtils._
 import org.apache.spark.ml.util.Instrumentation.instrumented
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.storage.StorageLevel
@@ -220,10 +222,11 @@ class LinearSVC @Since("2.2.0") (
     instr.logNumFeatures(numFeatures)

     if (numInvalid != 0) {
-      val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
-        s"Found $numInvalid invalid labels."
+      val msg = log"Classification labels should be in " +
+        log"${MDC(RANGE, s"[0 to ${numClasses - 1}]")}. " +
+        log"Found ${MDC(COUNT, numInvalid)} invalid labels."
       instr.logError(msg)
-      throw new SparkException(msg)
+      throw new SparkException(msg.message)
     }

     val featuresStd = summarizer.std.toArray
@@ -249,9 +252,7 @@ class LinearSVC @Since("2.2.0") (
       regularization, optimizer)

     if (rawCoefficients == null) {
-      val msg = s"${optimizer.getClass.getName} failed."
-      instr.logError(msg)
-      throw new SparkException(msg)
+      MLUtils.optimizerFailed(instr, optimizer.getClass)
     }

     val coefficientArray = Array.tabulate(numFeatures) { i =>
```
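The label-validation hunk shows the commit's dual-sink pattern: the same `LogEntry` feeds the structured logger and, via `.message`, the flat exception text. A condensed, self-contained sketch (`LabelCheck` and `report` are illustrative names, not part of the commit):

```scala
import org.apache.spark.SparkException
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKey.COUNT

object LabelCheck extends Logging {
  def report(numInvalid: Long): Unit = {
    if (numInvalid != 0) {
      val msg = log"Found ${MDC(COUNT, numInvalid)} invalid labels."
      logError(msg)                          // structured: COUNT lands in the MDC
      throw new SparkException(msg.message)  // flat: exception carries rendered text
    }
  }
}
```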

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 7 additions & 7 deletions

```diff
@@ -27,7 +27,8 @@ import org.apache.hadoop.fs.Path

 import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{Logging, MDC}
+import org.apache.spark.internal.LogKey.{COUNT, RANGE}
 import org.apache.spark.ml.feature._
 import org.apache.spark.ml.impl.Utils
 import org.apache.spark.ml.linalg._
@@ -530,10 +531,11 @@ class LogisticRegression @Since("1.2.0") (
     }

     if (numInvalid != 0) {
-      val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
-        s"Found $numInvalid invalid labels."
+      val msg = log"Classification labels should be in " +
+        log"${MDC(RANGE, s"[0 to ${numClasses - 1}]")}. " +
+        log"Found ${MDC(COUNT, numInvalid)} invalid labels."
       instr.logError(msg)
-      throw new SparkException(msg)
+      throw new SparkException(msg.message)
     }

     instr.logNumClasses(numClasses)
@@ -634,9 +636,7 @@ class LogisticRegression @Since("1.2.0") (
       initialSolution.toArray, regularization, optimizer)

     if (allCoefficients == null) {
-      val msg = s"${optimizer.getClass.getName} failed."
-      instr.logError(msg)
-      throw new SparkException(msg)
+      MLUtils.optimizerFailed(instr, optimizer.getClass)
     }

     val allCoefMatrix = new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept,
```

mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala

Lines changed: 1 addition & 4 deletions

```diff
@@ -23,7 +23,6 @@ import breeze.linalg.{DenseVector => BDV}
 import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS}
 import org.apache.hadoop.fs.Path

-import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.PredictorParams
@@ -271,9 +270,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
       optimizer, initialSolution)

     if (rawCoefficients == null) {
-      val msg = s"${optimizer.getClass.getName} failed."
-      instr.logError(msg)
-      throw new SparkException(msg)
+      MLUtils.optimizerFailed(instr, optimizer.getClass)
     }

     val coefficientArray = Array.tabulate(numFeatures) { i =>
```

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 1 addition & 4 deletions

```diff
@@ -25,7 +25,6 @@ import breeze.stats.distributions.Rand.FixedSeed.randBasis
 import breeze.stats.distributions.StudentsT
 import org.apache.hadoop.fs.Path

-import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{PipelineStage, PredictorParams}
@@ -428,9 +427,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       featuresMean, featuresStd, initialSolution, regularization, optimizer)

     if (parameters == null) {
-      val msg = s"${optimizer.getClass.getName} failed."
-      instr.logError(msg)
-      throw new SparkException(msg)
+      MLUtils.optimizerFailed(instr, optimizer.getClass)
     }

     val model = createModel(parameters, yMean, yStd, featuresMean, featuresStd)
```

mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala

Lines changed: 34 additions & 1 deletion

```diff
@@ -27,7 +27,7 @@ import org.json4s._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._

-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{LogEntry, Logging}
 import org.apache.spark.ml.{MLEvents, PipelineStage}
 import org.apache.spark.ml.param.{Param, Params}
 import org.apache.spark.rdd.RDD
@@ -84,20 +84,53 @@ private[spark] class Instrumentation private () extends Logging with MLEvents {
     super.logWarning(prefix + msg)
   }

+  /**
+   * Logs a LogEntry which message with a prefix that uniquely identifies the training session.
+   */
+  override def logWarning(entry: LogEntry): Unit = {
+    if (log.isWarnEnabled) {
+      withLogContext(entry.context) {
+        log.warn(prefix + entry.message)
+      }
+    }
+  }
+
   /**
    * Logs a error message with a prefix that uniquely identifies the training session.
    */
   override def logError(msg: => String): Unit = {
     super.logError(prefix + msg)
   }

+  /**
+   * Logs a LogEntry which message with a prefix that uniquely identifies the training session.
+   */
+  override def logError(entry: LogEntry): Unit = {
+    if (log.isErrorEnabled) {
+      withLogContext(entry.context) {
+        log.error(prefix + entry.message)
+      }
+    }
+  }
+
   /**
    * Logs an info message with a prefix that uniquely identifies the training session.
    */
   override def logInfo(msg: => String): Unit = {
     super.logInfo(prefix + msg)
   }

+  /**
+   * Logs a LogEntry which message with a prefix that uniquely identifies the training session.
+   */
+  override def logInfo(entry: LogEntry): Unit = {
+    if (log.isInfoEnabled) {
+      withLogContext(entry.context) {
+        log.info(prefix + entry.message)
+      }
+    }
+  }
+
   /**
    * Logs the value of the given parameters for the estimator being used in this session.
    */
```
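These overrides mirror the existing `String` overloads but thread the entry's MDC context through the now-`protected` `withLogContext`, guarding with `log.isWarnEnabled` (etc.) so no context is built when the level is disabled. A hypothetical trait mirroring the same pattern (`PrefixedLogging` and `sessionPrefix` are illustrative names, not Spark's):

```scala
import org.apache.spark.internal.{LogEntry, Logging}

trait PrefixedLogging extends Logging {
  def sessionPrefix: String

  // Set the entry's MDC context, then emit the message with the
  // session prefix prepended, exactly as the overrides above do.
  override def logError(entry: LogEntry): Unit = {
    if (log.isErrorEnabled) {              // skip MDC setup when disabled
      withLogContext(entry.context) {
        log.error(sessionPrefix + entry.message)
      }
    }
  }
}
```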

mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala

Lines changed: 7 additions & 4 deletions

```diff
@@ -18,7 +18,8 @@
 package org.apache.spark.mllib.util

 import org.apache.spark.annotation.Since
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{Logging, MDC}
+import org.apache.spark.internal.LogKey.{COUNT, RANGE}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD

@@ -37,7 +38,8 @@ object DataValidators extends Logging {
   val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data =>
     val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count()
     if (numInvalid != 0) {
-      logError("Classification labels should be 0 or 1. Found " + numInvalid + " invalid labels")
+      logError(log"Classification labels should be 0 or 1. " +
+        log"Found ${MDC(COUNT, numInvalid)} invalid labels")
     }
     numInvalid == 0
   }
@@ -53,8 +55,9 @@ object DataValidators extends Logging {
     val numInvalid = data.filter(x =>
       x.label - x.label.toInt != 0.0 || x.label < 0 || x.label > k - 1).count()
     if (numInvalid != 0) {
-      logError("Classification labels should be in {0 to " + (k - 1) + "}. " +
-        "Found " + numInvalid + " invalid labels")
+      logError(log"Classification labels should be in " +
+        log"${MDC(RANGE, s"[0 to ${k - 1}]")}. " +
+        log"Found ${MDC(COUNT, numInvalid)} invalid labels")
     }
     numInvalid == 0
   }
```

mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala

Lines changed: 9 additions & 1 deletion

```diff
@@ -22,8 +22,10 @@ import scala.reflect.ClassTag

 import org.apache.spark.{SparkContext, SparkException}
 import org.apache.spark.annotation.Since
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{Logging, MDC}
+import org.apache.spark.internal.LogKey.OPTIMIZER_CLASS_NAME
 import org.apache.spark.ml.linalg.{MatrixUDT => MLMatrixUDT, VectorUDT => MLVectorUDT}
+import org.apache.spark.ml.util.Instrumentation
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.linalg.BLAS.dot
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -593,4 +595,10 @@ object MLUtils extends Logging {
       math.log1p(math.exp(x))
     }
   }
+
+  def optimizerFailed(instr: Instrumentation, optimizerClass: Class[_]): Unit = {
+    val msg = log"${MDC(OPTIMIZER_CLASS_NAME, optimizerClass.getName)} failed."
+    instr.logError(msg)
+    throw new SparkException(msg.message)
+  }
 }
```
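The new helper centralizes the log-then-throw sequence that was previously duplicated across LinearSVC, LogisticRegression, AFTSurvivalRegression, and LinearRegression. Each call site now reduces to the shape below, excerpted from the diffs above with comments added (`instr` and `optimizer` are the training session's `Instrumentation` and Breeze optimizer in those files):

```scala
if (rawCoefficients == null) {
  // Logs "<optimizer class> failed." with OPTIMIZER_CLASS_NAME in the MDC,
  // then throws SparkException carrying the flat message text.
  MLUtils.optimizerFailed(instr, optimizer.getClass)
}
```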

mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala

Lines changed: 10 additions & 7 deletions

```diff
@@ -18,7 +18,8 @@
 package org.apache.spark.ml.feature

 import org.apache.spark.SparkException
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{Logging, MDC}
+import org.apache.spark.internal.LogKey.{CATEGORICAL_FEATURES, MAX_CATEGORIES}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
@@ -175,8 +176,10 @@ class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging {
       maxCategories: Int,
       categoricalFeatures: Set[Int]): Unit = {
     val collectedData = data.collect().map(_.getAs[Vector](0))
-    val errMsg = s"checkCategoryMaps failed for input with maxCategories=$maxCategories," +
-      s" categoricalFeatures=${categoricalFeatures.mkString(", ")}"
+
+    val errMsg = log"checkCategoryMaps failed for input with " +
+      log"maxCategories=${MDC(MAX_CATEGORIES, maxCategories)} " +
+      log"categoricalFeatures=${MDC(CATEGORICAL_FEATURES, categoricalFeatures.mkString(", "))}"
     try {
       val vectorIndexer = getIndexer.setMaxCategories(maxCategories)
       val model = vectorIndexer.fit(data)
@@ -210,8 +213,8 @@ class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging {
         assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString))
         assert(attr.isOrdinal.get === false)
       case _ =>
-        throw new RuntimeException(errMsg + s". Categorical feature $feature failed" +
-          s" metadata check. Found feature attribute: $featureAttr.")
+        throw new RuntimeException(errMsg.message + s". Categorical feature $feature " +
+          s"failed metadata check. Found feature attribute: $featureAttr.")
     }
   }
   // Check numerical feature metadata.
@@ -222,8 +225,8 @@ class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging {
       case attr: NumericAttribute =>
         assert(featureAttr.index.get === feature)
       case _ =>
-        throw new RuntimeException(errMsg + s". Numerical feature $feature failed" +
-          s" metadata check. Found feature attribute: $featureAttr.")
+        throw new RuntimeException(errMsg.message + s". Numerical feature $feature " +
+          s"failed metadata check. Found feature attribute: $featureAttr.")
     }
   }
 }
```
