Skip to content

Commit 5f8cd7f

Browse files
committed
[SPARK-47172] Addressing reviewer comments
1 parent 71cd10d commit 5f8cd7f

File tree

1 file changed

+100
-42
lines changed

1 file changed

+100
-42
lines changed

common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java

Lines changed: 100 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,33 +19,44 @@
1919

2020
import com.google.common.annotations.VisibleForTesting;
2121
import com.google.common.base.Preconditions;
22+
import com.google.common.primitives.Longs;
2223
import com.google.crypto.tink.subtle.AesGcmHkdfStreaming;
2324
import com.google.crypto.tink.subtle.StreamSegmentDecrypter;
2425
import com.google.crypto.tink.subtle.StreamSegmentEncrypter;
26+
import io.netty.buffer.ByteBuf;
2527
import io.netty.buffer.Unpooled;
2628
import io.netty.channel.*;
2729
import io.netty.util.ReferenceCounted;
2830
import org.apache.spark.network.util.AbstractFileRegion;
29-
import io.netty.buffer.ByteBuf;
3031

3132
import javax.crypto.spec.SecretKeySpec;
3233
import java.io.IOException;
3334
import java.nio.ByteBuffer;
3435
import java.nio.channels.ClosedChannelException;
3536
import java.nio.channels.WritableByteChannel;
3637
import java.security.GeneralSecurityException;
38+
import java.security.InvalidAlgorithmParameterException;
3739

3840
public class GcmTransportCipher implements TransportCipher {
39-
private static final byte[] DEFAULT_AAD = new byte[0];
41+
private static final String HKDF_ALG = "HmacSha256";
4042
private static final int LENGTH_HEADER_BYTES = 8;
4143
@VisibleForTesting
42-
static final int CIPHERTEXT_BUFFER_SIZE = 1024;
44+
static final int CIPHERTEXT_BUFFER_SIZE = 32 * 1024; // 32KB
4345
private final SecretKeySpec aesKey;
4446

4547
public GcmTransportCipher(SecretKeySpec aesKey) {
4648
this.aesKey = aesKey;
4749
}
4850

51+
AesGcmHkdfStreaming getAesGcmHkdfStreaming() throws InvalidAlgorithmParameterException {
52+
return new AesGcmHkdfStreaming(
53+
aesKey.getEncoded(),
54+
HKDF_ALG,
55+
aesKey.getEncoded().length,
56+
CIPHERTEXT_BUFFER_SIZE,
57+
0);
58+
}
59+
4960
@VisibleForTesting
5061
EncryptionHandler getEncryptionHandler() throws GeneralSecurityException {
5162
return new EncryptionHandler();
@@ -68,13 +79,8 @@ class EncryptionHandler extends ChannelOutboundHandlerAdapter {
6879
private final ByteBuffer ciphertextBuffer;
6980
private final AesGcmHkdfStreaming aesGcmHkdfStreaming;
7081

71-
EncryptionHandler() throws GeneralSecurityException {
72-
aesGcmHkdfStreaming = new AesGcmHkdfStreaming(
73-
aesKey.getEncoded(),
74-
"HmacSha256",
75-
aesKey.getEncoded().length,
76-
CIPHERTEXT_BUFFER_SIZE,
77-
0);
82+
EncryptionHandler() throws InvalidAlgorithmParameterException {
83+
aesGcmHkdfStreaming = getAesGcmHkdfStreaming();
7884
plaintextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize());
7985
ciphertextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize());
8086
}
@@ -95,10 +101,10 @@ static class GcmEncryptedMessage extends AbstractFileRegion {
95101
private final Object plaintextMessage;
96102
private final ByteBuffer plaintextBuffer;
97103
private final ByteBuffer ciphertextBuffer;
104+
private final ByteBuffer headerByteBuffer;
98105
private final long bytesToRead;
99106
private long bytesRead = 0;
100107
private final StreamSegmentEncrypter encrypter;
101-
private boolean headerWritten = false;
102108
private long transferred = 0;
103109
private final long encryptedCount;
104110

@@ -110,14 +116,30 @@ static class GcmEncryptedMessage extends AbstractFileRegion {
110116
plaintextMessage instanceof ByteBuf || plaintextMessage instanceof FileRegion,
111117
"Unrecognized message type: %s", plaintextMessage.getClass().getName());
112118
this.plaintextMessage = plaintextMessage;
113-
this.bytesToRead = getReadableBytes();
114119
this.plaintextBuffer = plaintextBuffer;
115-
this.plaintextBuffer.clear();
116120
this.ciphertextBuffer = ciphertextBuffer;
117-
this.ciphertextBuffer.clear();
118-
this.encrypter = aesGcmHkdfStreaming.newStreamSegmentEncrypter(DEFAULT_AAD);
121+
// If the ciphertext buffer cannot be fully written the target, transferTo may
122+
// return with it containing some unwritten data. The initial call we'll explicitly
123+
// set its limit to 0 to indicate the first call to transferTo.
124+
this.ciphertextBuffer.limit(0);
125+
126+
this.bytesToRead = getReadableBytes();
119127
this.encryptedCount =
120128
LENGTH_HEADER_BYTES + aesGcmHkdfStreaming.expectedCiphertextSize(bytesToRead);
129+
byte[] lengthAad = Longs.toByteArray(encryptedCount);
130+
this.encrypter = aesGcmHkdfStreaming.newStreamSegmentEncrypter(lengthAad);
131+
this.headerByteBuffer = createHeaderByteBuffer();
132+
}
133+
134+
// The format of the output is:
135+
// [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
136+
private ByteBuffer createHeaderByteBuffer() {
137+
ByteBuffer encrypterHeader = encrypter.getHeader();
138+
return ByteBuffer
139+
.allocate(encrypterHeader.remaining() + LENGTH_HEADER_BYTES)
140+
.putLong(encryptedCount)
141+
.put(encrypterHeader)
142+
.flip();
121143
}
122144

123145
@Override
@@ -135,25 +157,61 @@ public long count() {
135157
return encryptedCount;
136158
}
137159

160+
@Override
161+
public GcmEncryptedMessage touch(Object o) {
162+
super.touch(o);
163+
if (plaintextMessage instanceof ByteBuf byteBuf) {
164+
byteBuf.touch(o);
165+
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
166+
fileRegion.touch(o);
167+
}
168+
return this;
169+
}
170+
171+
@Override
172+
public GcmEncryptedMessage retain(int increment) {
173+
super.retain(increment);
174+
if (plaintextMessage instanceof ByteBuf byteBuf) {
175+
byteBuf.retain(increment);
176+
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
177+
fileRegion.retain(increment);
178+
}
179+
return this;
180+
}
181+
182+
@Override
183+
public boolean release(int decrement) {
184+
if (plaintextMessage instanceof ByteBuf byteBuf) {
185+
byteBuf.release(decrement);
186+
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
187+
fileRegion.release(decrement);
188+
}
189+
return super.release(decrement);
190+
}
191+
138192
@Override
139193
public long transferTo(WritableByteChannel target, long position) throws IOException {
140194
Preconditions.checkArgument(position == transferred(),
141195
"Invalid position.");
142196
int transferredThisCall = 0;
143-
// The format of the output is:
144-
// [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
145-
if (!headerWritten) {
146-
ByteBuffer expectedLength = ByteBuffer
147-
.allocate(LENGTH_HEADER_BYTES)
148-
.putLong(encryptedCount)
149-
.flip();
150-
target.write(expectedLength);
151-
int headerWritten = LENGTH_HEADER_BYTES + target.write(encrypter.getHeader());
152-
transferredThisCall += headerWritten;
153-
this.transferred += headerWritten;
154-
this.headerWritten = true;
197+
// If the header has is not empty, try to write it out to the target.
198+
if (headerByteBuffer.hasRemaining()) {
199+
int written = target.write(headerByteBuffer);
200+
transferredThisCall += written;
201+
this.transferred += written;
202+
if (headerByteBuffer.hasRemaining()) {
203+
return written;
204+
}
205+
}
206+
// If the ciphertext buffer is not empty, try to write it to the target.
207+
if (ciphertextBuffer.hasRemaining()) {
208+
int written = target.write(ciphertextBuffer);
209+
transferredThisCall += written;
210+
this.transferred += written;
211+
if (ciphertextBuffer.hasRemaining()) {
212+
return transferredThisCall;
213+
}
155214
}
156-
157215
while (bytesRead < bytesToRead) {
158216
long readableBytes = getReadableBytes();
159217
boolean lastSegment = readableBytes <= plaintextBuffer.capacity();
@@ -186,12 +244,14 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep
186244
throw new RuntimeException(e);
187245
}
188246
ciphertextBuffer.flip();
189-
int outputRemaining = ciphertextBuffer.remaining();
190-
while (ciphertextBuffer.hasRemaining()) {
191-
target.write(ciphertextBuffer);
247+
int written = target.write(ciphertextBuffer);
248+
transferredThisCall += written;
249+
this.transferred += written;
250+
if (ciphertextBuffer.hasRemaining()) {
251+
// In this case, upon calling transferTo again, it will try to write the
252+
// remaining ciphertext buffer in the conditional before this loop.
253+
return transferredThisCall;
192254
}
193-
transferredThisCall += outputRemaining;
194-
transferred += outputRemaining;
195255
}
196256
return transferredThisCall;
197257
}
@@ -229,12 +289,7 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter {
229289
private long ciphertextRead = 0;
230290

231291
DecryptionHandler() throws GeneralSecurityException {
232-
aesGcmHkdfStreaming = new AesGcmHkdfStreaming(
233-
aesKey.getEncoded(),
234-
"HmacSha256",
235-
aesKey.getEncoded().length,
236-
CIPHERTEXT_BUFFER_SIZE,
237-
0);
292+
aesGcmHkdfStreaming = getAesGcmHkdfStreaming();
238293
plaintextBuffer =
239294
ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize());
240295
ciphertextBuffer =
@@ -270,7 +325,8 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
270325
ByteBuffer headerBuffer = ByteBuffer.allocate(headerLength);
271326
ciphertextNettyBuf.readBytes(headerBuffer);
272327
headerBuffer.flip();
273-
decrypter.init(headerBuffer, DEFAULT_AAD);
328+
byte[] lengthAad = Longs.toByteArray(expectedLength);
329+
decrypter.init(headerBuffer, lengthAad);
274330
decrypterInit = true;
275331
ciphertextRead += headerLength;
276332
}
@@ -290,10 +346,12 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
290346
if (readableBytes == 0) {
291347
return;
292348
}
349+
int expectedRemaining = (int) (expectedLength - ciphertextRead);
350+
int bytesToRead = Integer.min(readableBytes, expectedRemaining);
293351
// The smallest ciphertext size is 16 bytes for the auth tag
294-
ciphertextBuffer.limit(readableBytes);
352+
ciphertextBuffer.limit(bytesToRead);
295353
ciphertextNettyBuf.readBytes(ciphertextBuffer);
296-
ciphertextRead += readableBytes;
354+
ciphertextRead += bytesToRead;
297355
// Check if this is the last segment
298356
boolean lastSegment = false;
299357
if (ciphertextRead == expectedLength) {

0 commit comments

Comments
 (0)