Skip to content

Commit d781483

Browse files
committed
Make PowertoolsLogging thread-safe.
1 parent 12a0d43 commit d781483

File tree

2 files changed

+56
-4
lines changed

2 files changed

+56
-4
lines changed

powertools-logging/src/main/java/software/amazon/lambda/powertools/logging/PowertoolsLogging.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@
6464
*/
6565
public final class PowertoolsLogging {
6666
private static final Logger LOG = LoggerFactory.getLogger(PowertoolsLogging.class);
67-
private static final Random SAMPLER = new Random();
68-
private static boolean hasBeenInitialized = false;
67+
private static final ThreadLocal<Random> SAMPLER = ThreadLocal.withInitial(Random::new);
68+
private static volatile boolean hasBeenInitialized = false;
6969

7070
static {
7171
initializeLogLevel();
@@ -186,7 +186,7 @@ public static void initializeLogging(Context context, double samplingRate, Strin
186186
coldStartDone();
187187
}
188188
hasBeenInitialized = true;
189-
189+
190190
addLambdaContextToLoggingContext(context);
191191
setLogLevelBasedOnSamplingRate(samplingRate);
192192
getXrayTraceId().ifPresent(xRayTraceId -> MDC.put(FUNCTION_TRACE_ID.getName(), xRayTraceId));
@@ -219,7 +219,7 @@ private static void setLogLevelBasedOnSamplingRate(double samplingRate) {
219219
return;
220220
}
221221

222-
float sample = SAMPLER.nextFloat();
222+
float sample = SAMPLER.get().nextFloat();
223223
if (effectiveSamplingRate > sample) {
224224
LoggingManager loggingManager = LoggingManagerRegistry.getLoggingManager();
225225
loggingManager.setLogLevel(Level.DEBUG);
@@ -272,5 +272,6 @@ public static void clearState(boolean clearMdcState) {
272272
MDC.clear();
273273
}
274274
clearBuffer();
275+
SAMPLER.remove();
275276
}
276277
}

powertools-logging/src/test/java/software/amazon/lambda/powertools/logging/PowertoolsLoggingTest.java

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,57 @@ void clearState_withoutMdcClear_shouldOnlyClearBuffer() {
365365
assertThat(testManager.isBufferCleared()).isTrue();
366366
}
367367

368+
@Test
369+
void initializeLogging_concurrentCalls_shouldBeThreadSafe() throws InterruptedException {
370+
// GIVEN
371+
int threadCount = 10;
372+
Thread[] threads = new Thread[threadCount];
373+
String[] samplingRates = new String[threadCount];
374+
boolean[] success = new boolean[threadCount];
375+
376+
// WHEN - Multiple threads call initializeLogging with alternating sampling rates
377+
for (int i = 0; i < threadCount; i++) {
378+
final int threadIndex = i;
379+
final double samplingRate = (i % 2 == 0) ? 1.0 : 0.0; // Alternate between 1.0 and 0.0
380+
381+
threads[i] = new Thread(() -> {
382+
try {
383+
PowertoolsLogging.initializeLogging(context, samplingRate);
384+
385+
// Capture the sampling rate set in MDC (thread-local)
386+
samplingRates[threadIndex] = MDC.get(PowertoolsLoggedFields.SAMPLING_RATE.getName());
387+
success[threadIndex] = true;
388+
389+
// Clean up thread-local state
390+
PowertoolsLogging.clearState(true);
391+
} catch (Exception e) {
392+
success[threadIndex] = false;
393+
}
394+
});
395+
}
396+
397+
// Start all threads
398+
for (Thread thread : threads) {
399+
thread.start();
400+
}
401+
402+
// Wait for all threads to complete
403+
for (Thread thread : threads) {
404+
thread.join();
405+
}
406+
407+
// THEN - All threads should complete successfully
408+
for (boolean result : success) {
409+
assertThat(result).isTrue();
410+
}
411+
412+
// THEN - Each thread should have its own sampling rate in MDC
413+
for (int i = 0; i < threadCount; i++) {
414+
String expectedRate = (i % 2 == 0) ? "1.0" : "0.0";
415+
assertThat(samplingRates[i]).as("Thread %d should have sampling rate %s", i, expectedRate).isEqualTo(expectedRate);
416+
}
417+
}
418+
368419
private void reinitializeLogLevel() {
369420
try {
370421
Method initializeLogLevel = PowertoolsLogging.class.getDeclaredMethod("initializeLogLevel");

0 commit comments

Comments
 (0)