Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 42 additions & 21 deletions src/main/java/io/nats/client/impl/MessageQueue.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@

import io.nats.client.NatsSystemClock;

import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Predicate;

import static io.nats.client.support.NatsConstants.*;

Expand All @@ -33,8 +34,9 @@ class MessageQueue {
protected static final int DRAINING = 2;
protected static final long MIN_OFFER_TIMEOUT_NANOS = 100 * NANOS_PER_MILLI;

protected final AtomicLong length;
protected final AtomicLong sizeInBytes;
protected final AtomicLong length;
protected final AtomicBoolean filtered;
protected final AtomicInteger running;
protected final boolean singleReaderMode;
protected final LinkedBlockingQueue<NatsMessage> queue;
Expand All @@ -45,12 +47,18 @@ class MessageQueue {
protected final long offerTimeoutNanos;
protected final Duration requestCleanupInterval;

static class MarkerMessage extends ProtocolMessage {
MarkerMessage(String mark) {
super(mark.getBytes(StandardCharsets.ISO_8859_1), false);
}
}

// SPECIAL MARKER MESSAGES
// A simple == is used to resolve if any message is exactly the static pill object in question
// ----------
// 1. Poison pill is a graphic, but common term for an item that breaks loops or stop something.
// In this class the poison pill is used to break out of timed waits on the blocking queue.
protected static final NatsMessage POISON_PILL = new NatsMessage("_poison", null, EMPTY_BODY);
protected static final MarkerMessage POISON_PILL = new MarkerMessage("_poison");

MessageQueue(boolean singleReaderMode, Duration requestCleanupInterval) {
this(singleReaderMode, -1, false, requestCleanupInterval, null);
Expand Down Expand Up @@ -81,6 +89,7 @@ class MessageQueue {
this.running = new AtomicInteger(RUNNING);
sizeInBytes = new AtomicLong(0);
length = new AtomicLong(0);
filtered = new AtomicBoolean(true);
this.offerLockNanos = requestCleanupInterval.toNanos();
this.offerTimeoutNanos = Math.max(MIN_OFFER_TIMEOUT_NANOS, requestCleanupInterval.toMillis() * NANOS_PER_MILLI * 95 / 100) ;

Expand All @@ -97,9 +106,11 @@ class MessageQueue {
void drainTo(MessageQueue target) {
editLock.lock();
try {
queue.drainTo(target.queue);
target.length.set(length.getAndSet(0));
this.queue.drainTo(target.queue);
target.sizeInBytes.set(sizeInBytes.getAndSet(0));
target.length.set(length.getAndSet(0));
target.filtered.set(false);
this.filtered.set(true);
} finally {
editLock.unlock();
}
Expand Down Expand Up @@ -178,6 +189,7 @@ boolean push(NatsMessage msg, boolean internal) {
}
sizeInBytes.getAndAdd(msg.getSizeInBytes());
length.incrementAndGet();
filtered.set(false);
return true;

}
Expand Down Expand Up @@ -206,11 +218,11 @@ void poisonTheQueue() {
}

/**
* Marking the queue, like POISON, is a message we don't want to count.
* Marking the queue, like poisonTheQueue, is a message we don't want to count.
* Intended to only be used with an unbounded queue. Use at your own risk.
* @param msg the mark
*/
void markTheQueue(NatsMessage msg) {
void markTheQueue(MarkerMessage msg) {
queue.offer(msg);
}

Expand Down Expand Up @@ -250,7 +262,7 @@ NatsMessage pop(Duration timeout) throws InterruptedException {
}

sizeInBytes.getAndAdd(-msg.getSizeInBytes());
length.decrementAndGet();
filtered.set(length.decrementAndGet() == 0);

return msg;
}
Expand Down Expand Up @@ -286,7 +298,7 @@ NatsMessage accumulate(long maxBytesToAccumulate, long maxMessagesToAccumulate,

if (maxMessagesToAccumulate <= 1 || size >= maxBytesToAccumulate) {
sizeInBytes.addAndGet(-size);
length.decrementAndGet();
filtered.set(length.decrementAndGet() == 0);
return msg;
}

Expand Down Expand Up @@ -320,7 +332,7 @@ NatsMessage accumulate(long maxBytesToAccumulate, long maxMessagesToAccumulate,
}

sizeInBytes.addAndGet(-size);
length.addAndGet(-accumulated);
filtered.set(length.addAndGet(-accumulated) == 0);

return msg;
}
Expand All @@ -338,24 +350,32 @@ long sizeInBytes() {
return sizeInBytes.get();
}

void filter(Predicate<NatsMessage> p) {
void filterOnStop() {
editLock.lock();
try {
if (this.isRunning()) {
throw new IllegalStateException("Filter is only supported when the queue is paused");
}
ArrayList<NatsMessage> newQueue = new ArrayList<>();
NatsMessage cursor = this.queue.poll();
while (cursor != null) {
if (!p.test(cursor)) {
newQueue.add(cursor);
} else {
sizeInBytes.addAndGet(-cursor.getSizeInBytes());
length.decrementAndGet();
if (!filtered.get()) {
long removed = 0;
long removedBytes = 0;
ArrayList<NatsMessage> newQueue = new ArrayList<>();
NatsMessage cursor = this.queue.poll();
while (cursor != null) {
if (cursor.isProtocolFilterOnStop()) {
removedBytes += cursor.getSizeInBytes();
removed++;
}
else {
newQueue.add(cursor);
}
cursor = this.queue.poll();
}
cursor = this.queue.poll();
this.queue.addAll(newQueue);
sizeInBytes.addAndGet(-removedBytes);
length.addAndGet(-removed);
filtered.set(true);
}
this.queue.addAll(newQueue);
} finally {
editLock.unlock();
}
Expand All @@ -367,6 +387,7 @@ void clear() {
this.queue.clear();
length.set(0);
sizeInBytes.set(0);
filtered.set(true);
} finally {
editLock.unlock();
}
Expand Down
7 changes: 4 additions & 3 deletions src/main/java/io/nats/client/impl/NatsConnectionWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
import java.util.concurrent.locks.ReentrantLock;

import static io.nats.client.support.BuilderBase.bufferAllocSize;
import static io.nats.client.support.NatsConstants.*;
import static io.nats.client.support.NatsConstants.CR;
import static io.nats.client.support.NatsConstants.LF;

class NatsConnectionWriter implements Runnable {
enum Mode {
Expand Down Expand Up @@ -114,7 +115,7 @@ Future<Boolean> stop() {
try {
this.normalOutgoing.pause();
this.reconnectOutgoing.pause();
this.normalOutgoing.filter(NatsMessage::isProtocolFilterOnStop);
this.normalOutgoing.filterOnStop();
}
finally {
this.startStopLock.unlock();
Expand All @@ -127,7 +128,7 @@ boolean isRunning() {
return running.get();
}

private static final NatsMessage END_RECONNECT = new NatsMessage("_end", null, EMPTY_BODY);
private static final MessageQueue.MarkerMessage END_RECONNECT = new MessageQueue.MarkerMessage("_end_reconnect");

void sendMessageBatch(NatsMessage msg, DataPort dataPort, StatisticsCollector stats) throws IOException {
writerLock.lock();
Expand Down
110 changes: 44 additions & 66 deletions src/test/java/io/nats/client/impl/MessageQueueTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@

import org.junit.jupiter.api.Test;

import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -29,9 +27,9 @@
public class MessageQueueTests {
static final Duration REQUEST_CLEANUP_INTERVAL = Duration.ofSeconds(5);
static final byte[] PING = "PING".getBytes();
static final byte[] ONE = "one".getBytes();
static final byte[] TWO = "two".getBytes();
static final byte[] THREE = "three".getBytes();
static final byte[] AAA = "aaa".getBytes();
static final byte[] BBB = "bbb".getBytes();
static final byte[] CCC = "ccc".getBytes();

@Test
public void testEmptyPop() throws InterruptedException {
Expand Down Expand Up @@ -457,9 +455,9 @@ public void testLength() throws InterruptedException {
@Test
public void testSizeInBytes() throws InterruptedException {
MessageQueue q = new MessageQueue(true, REQUEST_CLEANUP_INTERVAL);
NatsMessage msg1 = new ProtocolMessage(ONE);
NatsMessage msg2 = new ProtocolMessage(TWO);
NatsMessage msg3 = new ProtocolMessage(THREE);
NatsMessage msg1 = new ProtocolMessage(AAA);
NatsMessage msg2 = new ProtocolMessage(BBB);
NatsMessage msg3 = new ProtocolMessage(CCC);
long expected = 0;

q.push(msg1);
Expand Down Expand Up @@ -548,75 +546,55 @@ public void testDrainTo() {
}

@Test
public void testFilterTail() throws InterruptedException {
MessageQueue q = new MessageQueue(true, REQUEST_CLEANUP_INTERVAL);
NatsMessage msg1 = new ProtocolMessage(ONE);
NatsMessage msg2 = new ProtocolMessage(TWO);
NatsMessage msg3 = new ProtocolMessage(THREE);
byte[] expected = "one".getBytes(StandardCharsets.UTF_8);

q.push(msg1);
q.push(msg2);
q.push(msg3);

long before = q.sizeInBytes();
q.pause();
q.filter((msg) -> Arrays.equals(expected, msg.getProtocolBytes()));
q.resume();
long after = q.sizeInBytes();
public void testFilterFirstIn() throws InterruptedException {
_testFiltered(1);
}

assertEquals(2,q.length());
assertEquals(before, after + expected.length + 2);
assertEquals(q.popNow(), msg2);
assertEquals(q.popNow(), msg3);
@Test
public void testFilterLastIn() throws InterruptedException {
_testFiltered(3);
}

@Test
public void testFilterHead() throws InterruptedException {
MessageQueue q = new MessageQueue(true, REQUEST_CLEANUP_INTERVAL);
NatsMessage msg1 = new ProtocolMessage(ONE);
NatsMessage msg2 = new ProtocolMessage(TWO);
NatsMessage msg3 = new ProtocolMessage(THREE);
byte[] expected = "three".getBytes(StandardCharsets.UTF_8);
public void testFilterMiddle() throws InterruptedException {
_testFiltered(2);
}

private static void _testFiltered(int filtered) throws InterruptedException {
NatsMessage msg1 = new ProtocolMessage(AAA, filtered == 1);
NatsMessage msg2 = new ProtocolMessage(BBB, filtered == 2);
NatsMessage msg3 = new ProtocolMessage(CCC, filtered == 3);

MessageQueue q = new MessageQueue(true, REQUEST_CLEANUP_INTERVAL);
q.push(msg1);
q.push(msg2);
q.push(msg3);

long before = q.sizeInBytes();
q.pause();
q.filter((msg) -> Arrays.equals(expected, msg.getProtocolBytes()));
q.filterOnStop();
q.resume();
long after = q.sizeInBytes();

assertEquals(2,q.length());
assertEquals(before, after + expected.length + 2);
assertEquals(q.popNow(), msg1);
assertEquals(q.popNow(), msg2);
}

@Test
public void testFilterMiddle() throws InterruptedException {
MessageQueue q = new MessageQueue(true, REQUEST_CLEANUP_INTERVAL);
NatsMessage msg1 = new ProtocolMessage(ONE);
NatsMessage msg2 = new ProtocolMessage(TWO);
NatsMessage msg3 = new ProtocolMessage(THREE);
byte[] expected = "two".getBytes(StandardCharsets.UTF_8);

q.push(msg1);
q.push(msg2);
q.push(msg3);
assertEquals(2, q.length());
assertEquals(before, after + 3 + 2);

long before = q.sizeInBytes();
q.pause();
q.filter((msg) -> Arrays.equals(expected, msg.getProtocolBytes()));
q.filterOnStop();
q.resume();
long after = q.sizeInBytes();

assertEquals(2,q.length());
assertEquals(before, after + expected.length + 2);
assertEquals(q.popNow(), msg1);
assertEquals(q.popNow(), msg3);
assertEquals(2, q.length());
assertEquals(before, after + 3 + 2);

if (filtered != 1) {
assertEquals(q.popNow(), msg1);
}
if (filtered != 2) {
assertEquals(q.popNow(), msg2);
}
if (filtered != 3) {
assertEquals(q.popNow(), msg3);
}
}

@Test
Expand All @@ -631,17 +609,17 @@ public void testPausedAccumulate() throws InterruptedException {
public void testThrowOnFilterIfRunning() {
assertThrows(IllegalStateException.class, () -> {
MessageQueue q = new MessageQueue(true, REQUEST_CLEANUP_INTERVAL);
q.filter((msg) -> true);
q.filterOnStop();
fail();
});
}

@Test
public void testExceptionWhenQueueIsFull() {
MessageQueue q = new MessageQueue(true, 2, false, REQUEST_CLEANUP_INTERVAL);
NatsMessage msg1 = new ProtocolMessage(ONE);
NatsMessage msg2 = new ProtocolMessage(TWO);
NatsMessage msg3 = new ProtocolMessage(THREE);
NatsMessage msg1 = new ProtocolMessage(AAA);
NatsMessage msg2 = new ProtocolMessage(BBB);
NatsMessage msg3 = new ProtocolMessage(CCC);

assertTrue(q.push(msg1));
assertTrue(q.push(msg2));
Expand All @@ -656,9 +634,9 @@ public void testExceptionWhenQueueIsFull() {
@Test
public void testDiscardMessageWhenQueueFull() {
MessageQueue q = new MessageQueue(true, 2, true, REQUEST_CLEANUP_INTERVAL);
NatsMessage msg1 = new ProtocolMessage(ONE);
NatsMessage msg2 = new ProtocolMessage(TWO);
NatsMessage msg3 = new ProtocolMessage(THREE);
NatsMessage msg1 = new ProtocolMessage(AAA);
NatsMessage msg2 = new ProtocolMessage(BBB);
NatsMessage msg3 = new ProtocolMessage(CCC);

assertTrue(q.push(msg1));
assertTrue(q.push(msg2));
Expand Down
Loading