Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[SPARK-47172] Updating GcmTransportCipher to use buffered streaming
  • Loading branch information
sweisdb committed May 15, 2024
commit b9baf2277c4d17a37aa5ffe91ba36fa9fd58f834
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ private TransportCipher generateTransportCipher(
return new GcmTransportCipher(sessionKey);
} else {
throw new IllegalArgumentException(
"Unsupported cipher transformation: " + conf.cipherTransformation());
String.format("Unsupported cipher mode: %s. %s and %s are supported.",
conf.cipherTransformation(), CIPHER_ALGORITHM, LEGACY_CIPHER_ALGORITHM));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
package org.apache.spark.network.crypto;

import com.google.common.annotations.VisibleForTesting;
import com.google.crypto.tink.subtle.AesGcmJce;
import com.google.common.base.Preconditions;
import com.google.crypto.tink.subtle.AesGcmHkdfStreaming;
import com.google.crypto.tink.subtle.StreamSegmentEncrypter;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.util.ReferenceCounted;
import org.apache.spark.network.util.AbstractFileRegion;
import io.netty.buffer.ByteBuf;

Expand All @@ -33,80 +36,284 @@

public class GcmTransportCipher implements TransportCipher {
private static final byte[] DEFAULT_AAD = new byte[0];

private static final int LENGTH_HEADER_BYTES = 8;
@VisibleForTesting
static final int CIPHERTEXT_BUFFER_SIZE = 1024;
private final SecretKeySpec aesKey;

public GcmTransportCipher(SecretKeySpec aesKey) {
this.aesKey = aesKey;
}

@VisibleForTesting
EncryptionHandler getEncryptionHandler() {
EncryptionHandler getEncryptionHandler() throws GeneralSecurityException {
return new EncryptionHandler();
}

@VisibleForTesting
DecryptionHandler getDecryptionHandler() {
DecryptionHandler getDecryptionHandler() throws GeneralSecurityException {
return new DecryptionHandler();
}

public void addToChannel(Channel ch) {
public void addToChannel(Channel ch) throws GeneralSecurityException {
ch.pipeline()
.addFirst("GcmTransportEncryption", getEncryptionHandler())
.addFirst("GcmTransportDecryption", getDecryptionHandler());
}

@VisibleForTesting
class EncryptionHandler extends ChannelOutboundHandlerAdapter {
private final ByteBuffer plaintextBuffer;
private final ByteBuffer ciphertextBuffer;
private final AesGcmHkdfStreaming aesGcmHkdfStreaming;

EncryptionHandler() throws GeneralSecurityException {
aesGcmHkdfStreaming = new AesGcmHkdfStreaming(
aesKey.getEncoded(),
"HmacSha256",
aesKey.getEncoded().length,
CIPHERTEXT_BUFFER_SIZE,
0);
plaintextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize());
ciphertextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize());
}

@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
ByteBuffer inputBuffer;
int bytesToRead;
if (msg instanceof ByteBuf byteBuf) {
bytesToRead = byteBuf.readableBytes();
// This is allocating a buffer that is the size of the input
inputBuffer = ByteBuffer.allocate(bytesToRead);
// This will block while copying
while (inputBuffer.position() < bytesToRead) {
byteBuf.readBytes(inputBuffer);
GcmEncryptedMessage encryptedMessage = new GcmEncryptedMessage(
aesGcmHkdfStreaming,
msg,
plaintextBuffer,
ciphertextBuffer);
ctx.write(encryptedMessage, promise);
}
}

static class GcmEncryptedMessage extends AbstractFileRegion {
private final Object plaintextMessage;
private final ByteBuffer plaintextBuffer;
private final ByteBuffer ciphertextBuffer;
private final long bytesToRead;
private long bytesRead = 0;
private final StreamSegmentEncrypter encrypter;
private boolean headerWritten = false;
private long transferred = 0;
private final long encryptedCount;

GcmEncryptedMessage(AesGcmHkdfStreaming aesGcmHkdfStreaming,
Object plaintextMessage,
ByteBuffer plaintextBuffer,
ByteBuffer ciphertextBuffer) throws GeneralSecurityException {
Preconditions.checkArgument(
plaintextMessage instanceof ByteBuf || plaintextMessage instanceof FileRegion,
"Unrecognized message type: %s", plaintextMessage.getClass().getName());
this.plaintextMessage = plaintextMessage;
this.bytesToRead = getReadableBytes();
this.plaintextBuffer = plaintextBuffer;
this.plaintextBuffer.clear();
this.ciphertextBuffer = ciphertextBuffer;
this.ciphertextBuffer.clear();
this.encrypter = aesGcmHkdfStreaming.newStreamSegmentEncrypter(DEFAULT_AAD);
this.encryptedCount =
LENGTH_HEADER_BYTES + aesGcmHkdfStreaming.expectedCiphertextSize(bytesToRead);
}

@Override
public long position() {
return 0;
}

@Override
public long transferred() {
return transferred;
}

@Override
public long count() {
return encryptedCount;
}

@Override
public long transferTo(WritableByteChannel target, long position) throws IOException {
Preconditions.checkArgument(position == transferred(),
"Invalid position.");
int transferredThisCall = 0;
// The format of the output is:
// [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
if (!headerWritten) {
ByteBuffer expectedLength = ByteBuffer
.allocate(LENGTH_HEADER_BYTES)
.putLong(encryptedCount)
.flip();
target.write(expectedLength);
int headerWritten = LENGTH_HEADER_BYTES + target.write(encrypter.getHeader());
transferredThisCall += headerWritten;
this.transferred += headerWritten;
this.headerWritten = true;
}

while (bytesRead < bytesToRead) {
long readableBytes = getReadableBytes();
boolean lastSegment = readableBytes <= plaintextBuffer.capacity();
plaintextBuffer.clear();
int readLimit =
(int) Math.min(readableBytes, plaintextBuffer.capacity());
plaintextBuffer.limit(readLimit);
if (plaintextMessage instanceof ByteBuf byteBuf) {
byteBuf.readBytes(plaintextBuffer);
long inputBytesRead = readableBytes - byteBuf.readableBytes();
bytesRead += inputBytesRead;
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
} else if (plaintextMessage instanceof FileRegion fileRegion) {

ByteBufferWriteableChannel plaintextChannel =
new ByteBufferWriteableChannel(plaintextBuffer);
long transferred =
fileRegion.transferTo(plaintextChannel, fileRegion.transferred());
bytesRead += transferred;
if (transferred == 0) {
// File regions may return 0 if they are not ready to transfer
// more data. In that case, we'll return with the expectation
// that this transferTo() is called again.
return transferredThisCall;
}
}
} else if (msg instanceof AbstractFileRegion fileRegion) {
bytesToRead = (int) fileRegion.count();
// This is allocating a buffer that is the size of the input
inputBuffer = ByteBuffer.allocate(bytesToRead);
ByteBufferWriteableChannel writeableChannel =
new ByteBufferWriteableChannel(inputBuffer);
long transferred = 0;
// This will block while copying
while (transferred < bytesToRead) {
transferred +=
fileRegion.transferTo(writeableChannel, fileRegion.transferred());
plaintextBuffer.flip();
ciphertextBuffer.clear();
try {
encrypter.encryptSegment(plaintextBuffer, lastSegment, ciphertextBuffer);
} catch (GeneralSecurityException e) {
throw new RuntimeException(e);
}
ciphertextBuffer.flip();
int outputRemaining = ciphertextBuffer.remaining();
while (ciphertextBuffer.hasRemaining()) {
target.write(ciphertextBuffer);
}
transferredThisCall += outputRemaining;
transferred += outputRemaining;
}
return transferredThisCall;
}

private long getReadableBytes() {
if (plaintextMessage instanceof ByteBuf byteBuf) {
return byteBuf.readableBytes();
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
return fileRegion.count() - fileRegion.transferred();
} else {
throw new IllegalArgumentException("Unsupported message type: " + msg.getClass());
throw new IllegalArgumentException("Unsupported message type: " +
plaintextMessage.getClass().getName());
}
AesGcmJce cipher = new AesGcmJce(aesKey.getEncoded());
byte[] encrypted = cipher.encrypt(inputBuffer.array(), DEFAULT_AAD);
ByteBuf wrappedEncrypted = Unpooled.wrappedBuffer(encrypted);
ctx.write(wrappedEncrypted, promise);
}

@Override
protected void deallocate() {
if (plaintextMessage instanceof ReferenceCounted referenceCounted) {
referenceCounted.release();
}
plaintextBuffer.clear();
ciphertextBuffer.clear();
}
}

@VisibleForTesting
class DecryptionHandler extends ChannelInboundHandlerAdapter {
private final ByteBuffer ciphertextBuffer;
private final ByteBuffer plaintextBuffer;
private final AesGcmHkdfStreaming aesGcmHkdfStreaming;
private final StreamSegmentDecrypter decrypter;
private boolean decrypterInit = false;
private int segmentNumber = 0;
private long expectedLength = -1;
private long ciphertextRead = 0;

DecryptionHandler() throws GeneralSecurityException {
aesGcmHkdfStreaming = new AesGcmHkdfStreaming(
aesKey.getEncoded(),
"HmacSha256",
aesKey.getEncoded().length,
CIPHERTEXT_BUFFER_SIZE,
0);
plaintextBuffer =
ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize());
ciphertextBuffer =
ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize());
decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter();
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg)
public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
throws GeneralSecurityException {
if (msg instanceof ByteBuf byteBuf) {
// This is allocating a buffer that is the size of the input
byte[] encrypted = new byte[byteBuf.readableBytes()];
byteBuf.readBytes(encrypted);
AesGcmJce cipher = new AesGcmJce(aesKey.getEncoded());
byte[] decrypted = cipher.decrypt(encrypted, DEFAULT_AAD);
ByteBuf wrappedDecrypted = Unpooled.wrappedBuffer(decrypted);
ctx.fireChannelRead(wrappedDecrypted);
} else {
throw new IllegalArgumentException("Unsupported message type: " + msg.getClass());
Preconditions.checkArgument(ciphertextMessage instanceof ByteBuf,
"Unrecognized message type: %s",
ciphertextMessage.getClass().getName());
ByteBuf ciphertextNettyBuf = (ByteBuf) ciphertextMessage;
// The format of the output is:
// [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
try {
while (ciphertextNettyBuf.readableBytes() > 0) {
// Check if the expected ciphertext length has been read.
if (expectedLength < 0 &&
ciphertextNettyBuf.readableBytes() >= LENGTH_HEADER_BYTES) {
expectedLength = ciphertextNettyBuf.readLong();
if (expectedLength < 0) {
throw new IllegalStateException("Invalid expected ciphertext length.");
}
ciphertextRead += LENGTH_HEADER_BYTES;
}
int headerLength = aesGcmHkdfStreaming.getHeaderLength();
// Check if the ciphertext header has been read. This contains
// the IV and other internal metadata.
if (!decrypterInit &&
ciphertextNettyBuf.readableBytes() >= headerLength) {
ByteBuffer headerBuffer = ByteBuffer.allocate(headerLength);
ciphertextNettyBuf.readBytes(headerBuffer);
headerBuffer.flip();
decrypter.init(headerBuffer, DEFAULT_AAD);
decrypterInit = true;
ciphertextRead += headerLength;
}
// This may occur if there weren't enough readable bytes to read the header.
if (!decrypterInit) {
return;
}
// This may occur if the expected length is just the header.
if (expectedLength == ciphertextRead) {
return;
}
ciphertextBuffer.clear();
// Read the ciphertext into the local buffer
int readableBytes = Integer.min(
ciphertextNettyBuf.readableBytes(),
ciphertextBuffer.remaining());
if (readableBytes == 0) {
return;
}
// The smallest ciphertext size is 16 bytes for the auth tag
ciphertextBuffer.limit(readableBytes);
ciphertextNettyBuf.readBytes(ciphertextBuffer);
ciphertextRead += readableBytes;
// Check if this is the last segment
boolean lastSegment = false;
if (ciphertextRead == expectedLength) {
lastSegment = true;
} else if (ciphertextRead > expectedLength) {
throw new IllegalStateException("Read more ciphertext than expected.");
}
plaintextBuffer.clear();
ciphertextBuffer.flip();

decrypter.decryptSegment(
ciphertextBuffer,
segmentNumber,
lastSegment,
plaintextBuffer);
segmentNumber++;
plaintextBuffer.flip();
ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer));
}
} finally {
ciphertextNettyBuf.release();
}
}
}
Expand All @@ -125,8 +332,9 @@ public int write(ByteBuffer src) throws IOException {
throw new ClosedChannelException();
}
int bytesToWrite = Math.min(src.remaining(), destination.remaining());
// Destination buffer is full
if (bytesToWrite == 0) {
return 0; // Destination buffer is full
return 0;
}
ByteBuffer temp = src.slice().limit(bytesToWrite);
destination.put(temp);
Expand Down
Loading