Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
MQTT binding done?
  • Loading branch information
bmaidics committed Nov 28, 2023
commit c39621df9c3fa294445078dbf86376d45a153785
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand All @@ -99,6 +100,7 @@
import org.agrona.collections.Int2IntHashMap;
import org.agrona.collections.Int2ObjectHashMap;
import org.agrona.collections.IntArrayList;
import org.agrona.collections.IntArrayQueue;
import org.agrona.collections.Long2ObjectHashMap;
import org.agrona.collections.MutableBoolean;
import org.agrona.collections.Object2IntHashMap;
Expand Down Expand Up @@ -239,6 +241,7 @@ public final class MqttServerFactory implements MqttStreamFactory

private final BeginFW beginRO = new BeginFW();
private final DataFW dataRO = new DataFW();
private final FlushFW flushRO = new FlushFW();
private final EndFW endRO = new EndFW();
private final AbortFW abortRO = new AbortFW();
private final WindowFW windowRO = new WindowFW();
Expand All @@ -254,6 +257,7 @@ public final class MqttServerFactory implements MqttStreamFactory
private final FlushFW.Builder flushRW = new FlushFW.Builder();

private final MqttDataExFW mqttSubscribeDataExRO = new MqttDataExFW();
private final MqttFlushExFW mqttSubscribeFlushExRO = new MqttFlushExFW();
private final MqttResetExFW mqttResetExRO = new MqttResetExFW();
private final MqttBeginExFW mqttBeginExRO = new MqttBeginExFW();

Expand Down Expand Up @@ -1053,7 +1057,7 @@ private int decodePublish(

if (canPublish && (reserved != 0 || payloadSize == 0))
{
server.onDecodePublish(traceId, authorization, reserved, payload);
server.onDecodePublish(traceId, authorization, reserved, packetId, payload);
server.decodeablePacketBytes = 0;
server.decoder = decodePacketType;
progress = publishLimit;
Expand Down Expand Up @@ -1526,7 +1530,6 @@ private final class MqttServer
private final GuardHandler guard;
private final Function<String, String> credentials;
private final MqttConnectProperty authField;
private final IntArrayList unreleasedPacketIds;

private MqttSessionStream session;

Expand Down Expand Up @@ -1581,9 +1584,14 @@ private final class MqttServer
private int state;
private long sessionId;
private int decodableRemainingBytes;
//TODO: use packetId+qos hash instead of maintaining 2 maps?
private final Int2ObjectHashMap<MqttSubscribeStream> qos1Subscribes;
private final Int2ObjectHashMap<MqttSubscribeStream> qos2Subscribes;
private final IntArrayList unreleasedPacketIds;
private final LinkedHashMap<Long, Integer> unAckedReceivedQos1PacketIds;
private final LinkedHashMap<Long, Integer> unAckedReceivedQos2PacketIds;
private final IntArrayQueue unAckedDeliveredPacketIds;
private final Int2IntHashMap deferredAckedPacketIdsWithQos;


private MqttServer(
Function<String, String> credentials,
Expand Down Expand Up @@ -1611,6 +1619,10 @@ private MqttServer(
this.subscribePacketIds = new Int2IntHashMap(-1);
this.unsubscribePacketIds = new Object2IntHashMap<>(-1);
this.unreleasedPacketIds = new IntArrayList();
this.unAckedReceivedQos1PacketIds = new LinkedHashMap<>();
this.unAckedReceivedQos2PacketIds = new LinkedHashMap<>();
this.deferredAckedPacketIdsWithQos = new Int2IntHashMap(-1);
this.unAckedDeliveredPacketIds = new IntArrayQueue();
this.qos1Subscribes = new Int2ObjectHashMap<>();
this.qos2Subscribes = new Int2ObjectHashMap<>();
this.guard = resolveGuard(options, resolveId);
Expand Down Expand Up @@ -2219,6 +2231,7 @@ private void onDecodePublish(
long traceId,
long authorization,
int reserved,
int packetId,
OctetsFW payload)
{
int reasonCode = SUCCESS;
Expand Down Expand Up @@ -2267,19 +2280,14 @@ else if (mqttPublishHeaderRO.retained && !retainAvailable(capabilities))
final MqttDataExFW dataEx = builder.build();
if (stream != null)
{
stream.doPublishData(traceId, reserved, payload, dataEx);
stream.doPublishData(traceId, reserved, packetId, payload, dataEx);
}
}
doSignalKeepAliveTimeout(traceId);

if (mqttPublishHeaderRO.qos == 1)
{
doEncodePuback(traceId, authorization, mqttPublishHeaderRO.packetId);
}
else if (mqttPublishHeaderRO.qos == 2)
else
{
doEncodePubrec(traceId, authorization, mqttPublishHeaderRO.packetId);
doEncodePubrec(traceId, authorization, packetId);
}
doSignalKeepAliveTimeout(traceId);
}
}

Expand All @@ -2293,12 +2301,20 @@ private int onDecodePuback(
{
final int packetId = puback.packetId();

qos1Subscribes.remove(packetId).doSubscribeWindow(traceId, encodeSlotOffset, encodeBudgetMax);
if (unAckedDeliveredPacketIds.peekInt() == packetId)
{
unAckedDeliveredPacketIds.pollInt();
qos1Subscribes.remove(packetId).doSubscribeWindow(traceId, encodeSlotOffset, encodeBudgetMax);
acknowledgeDeliveredPackets(traceId, authorization);
}
else
{
deferredAckedPacketIdsWithQos.put(packetId, 1);
}

progress = puback.limit();
return progress;
}

private int onDecodePubrec(
long traceId,
long authorization,
Expand All @@ -2309,13 +2325,43 @@ private int onDecodePubrec(
{
final int packetId = pubrec.packetId();

qos2Subscribes.get(packetId).doSubscribeFlush(traceId, 0, packetId);
doEncodePubrel(traceId, authorization, packetId);
if (unAckedDeliveredPacketIds.peekInt() == packetId)
{
unAckedDeliveredPacketIds.pollInt();
qos2Subscribes.get(packetId).doSubscribeFlush(traceId, 0, packetId);
acknowledgeDeliveredPackets(traceId, authorization);
}
else
{
deferredAckedPacketIdsWithQos.put(packetId, 2);
}

progress = pubrec.limit();
return progress;
}

private void acknowledgeDeliveredPackets(
long traceId,
long authorization)
{

while (!unAckedDeliveredPacketIds.isEmpty() &&
deferredAckedPacketIdsWithQos.containsKey(unAckedDeliveredPacketIds.peek()))
{
final int packetId = unAckedDeliveredPacketIds.pollInt();
final int qos = deferredAckedPacketIdsWithQos.remove(packetId);

if (qos == 1)
{
qos1Subscribes.remove(packetId).doSubscribeWindow(traceId, encodeSlotOffset, encodeBudgetMax);
}
else
{
qos2Subscribes.get(packetId).doSubscribeFlush(traceId, 0, packetId);
}
}
}

private int onDecodePubrel(
long traceId,
long authorization,
Expand Down Expand Up @@ -3043,18 +3089,20 @@ private void doEncodePublish(
}
else
{
final int packetId = subscribeDataEx.subscribe().packetId();
final MqttPublishQosFW publish =
mqttPublishQosRW.wrap(writeBuffer, DataFW.FIELD_OFFSET_PAYLOAD, writeBuffer.capacity())
.typeAndFlags(publishNetworkTypeAndFlags)
.remainingLength(5 + topicNameLength + propertiesSize.get() + payloadSize + deferred)
.topicName(topicName)
.packetId(subscribeDataEx.subscribe().packetId())
.packetId(packetId)
.properties(p -> p.length(propertiesSize0)
.value(propertyBuffer, 0, propertiesSize0))
.payload(payload)
.build();

doNetworkData(traceId, authorization, 0L, publish);
unAckedDeliveredPacketIds.add(packetId);
}
}
else
Expand Down Expand Up @@ -4162,6 +4210,7 @@ private class MqttPublishStream
private long publishExpiresId = NO_CANCEL_ID;
private long publishExpiresAt;


MqttPublishStream(
long originId,
long routedId,
Expand Down Expand Up @@ -4203,8 +4252,9 @@ private void doPublishBegin(
private void doPublishData(
long traceId,
int reserved,
int packetId,
OctetsFW payload,
Flyweight extension)
MqttDataExFW mqttData)
{
assert MqttState.initialOpening(state);

Expand All @@ -4215,11 +4265,22 @@ private void doPublishData(
assert reserved >= length + initialPad;

doData(application, originId, routedId, initialId, initialSeq, initialAck, initialMax,
traceId, sessionId, budgetId, reserved, buffer, offset, length, extension);
traceId, sessionId, budgetId, reserved, buffer, offset, length, mqttData);

initialSeq += reserved;
assert initialSeq <= initialAck + initialMax;

final int qos = mqttData.publish().qos();

if (qos == 1)
{
unAckedReceivedQos1PacketIds.put(initialSeq, packetId);
}
else if (qos == 2)
{
unAckedReceivedQos2PacketIds.put(initialSeq, packetId);
}

doSignalPublishExpiration(traceId);
}

Expand Down Expand Up @@ -4369,6 +4430,40 @@ else if (decodePublisherKey == topicKey)
{
decodeNetwork(traceId);
}

acknowledgePublishPackets(acknowledge, traceId, authorization);
}

private void acknowledgePublishPackets(
long acknowledge,
long traceId,
long authorization)
{
for (Map.Entry<Long, Integer> e : unAckedReceivedQos1PacketIds.entrySet())
{
if (e.getKey() <= acknowledge)
{
doEncodePuback(traceId, authorization, e.getValue());
unAckedReceivedQos1PacketIds.remove(e.getKey());
}
else
{
break;
}
}

for (Map.Entry<Long, Integer> e : unAckedReceivedQos2PacketIds.entrySet())
{
if (e.getKey() <= acknowledge)
{
doEncodePubrec(traceId, authorization, e.getValue());
unAckedReceivedQos2PacketIds.remove(e.getKey());
}
else
{
break;
}
}
}

private void onPublishReset(
Expand Down Expand Up @@ -4739,6 +4834,10 @@ private void onSubscribe(
final DataFW data = dataRO.wrap(buffer, index, index + length);
onSubscribeData(data);
break;
case FlushFW.TYPE_ID:
final FlushFW flush = flushRO.wrap(buffer, index, index + length);
onSubscribeFlush(flush);
break;
case EndFW.TYPE_ID:
final EndFW end = endRO.wrap(buffer, index, index + length);
onSubscribeEnd(end);
Expand Down Expand Up @@ -4834,6 +4933,29 @@ else if (qos == 2)
}
}

private void onSubscribeFlush(
FlushFW flush)
{
final long sequence = flush.sequence();
final long acknowledge = flush.acknowledge();
final long traceId = flush.traceId();
final long authorization = flush.authorization();
final OctetsFW extension = flush.extension();

assert acknowledge <= sequence;
assert sequence >= replySeq;

replySeq = sequence;

assert replyAck <= replySeq;

final MqttFlushExFW subscribeFlushEx = extension.get(mqttSubscribeFlushExRO::tryWrap);
final int packetId = subscribeFlushEx.subscribe().packetId();

doEncodePubrel(traceId, authorization, packetId);
qos2Subscribes.put(packetId, this);
}


private void onSubscribeReset(
ResetFW reset)
Expand Down
Loading