diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java index cb68cfb5a0e8..8449a774a404 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java @@ -45,6 +45,8 @@ class AuthEngine implements Closeable { public static final byte[] INPUT_IV_INFO = "inputIv".getBytes(UTF_8); public static final byte[] OUTPUT_IV_INFO = "outputIv".getBytes(UTF_8); private static final String MAC_ALGORITHM = "HMACSHA256"; + private static final String LEGACY_CIPHER_ALGORITHM = "AES/CTR/NoPadding"; + private static final String CIPHER_ALGORITHM = "AES/GCM/NoPadding"; private static final int AES_GCM_KEY_SIZE_BYTES = 16; private static final byte[] EMPTY_TRANSCRIPT = new byte[0]; private static final int UNSAFE_SKIP_HKDF_VERSION = 1; @@ -227,12 +229,19 @@ private TransportCipher generateTransportCipher( OUTPUT_IV_INFO, // This is the HKDF info field used to differentiate IV values AES_GCM_KEY_SIZE_BYTES); SecretKeySpec sessionKey = new SecretKeySpec(derivedKey, "AES"); - return new TransportCipher( - cryptoConf, - conf.cipherTransformation(), - sessionKey, - isClient ? clientIv : serverIv, // If it's the client, use the client IV first - isClient ? serverIv : clientIv); + if (LEGACY_CIPHER_ALGORITHM.equalsIgnoreCase(conf.cipherTransformation())) { + return new CtrTransportCipher( + cryptoConf, + sessionKey, + isClient ? clientIv : serverIv, // If it's the client, use the client IV first + isClient ? serverIv : clientIv); + } else if (CIPHER_ALGORITHM.equalsIgnoreCase(conf.cipherTransformation())) { + return new GcmTransportCipher(sessionKey); + } else { + throw new IllegalArgumentException( + String.format("Unsupported cipher mode: %s. %s and %s are supported.", + conf.cipherTransformation(), CIPHER_ALGORITHM, LEGACY_CIPHER_ALGORITHM)); + } } private byte[] getTranscript(AuthMessage... encryptedPublicKeys) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java new file mode 100644 index 000000000000..85b893751b39 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.security.GeneralSecurityException; +import java.util.Properties; +import javax.crypto.spec.SecretKeySpec; +import javax.crypto.spec.IvParameterSpec; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.*; +import org.apache.commons.crypto.stream.CryptoInputStream; +import org.apache.commons.crypto.stream.CryptoOutputStream; + +import org.apache.spark.network.util.AbstractFileRegion; +import org.apache.spark.network.util.ByteArrayReadableChannel; +import org.apache.spark.network.util.ByteArrayWritableChannel; + +/** + * Cipher for encryption and decryption. + */ +public class CtrTransportCipher implements TransportCipher { + @VisibleForTesting + static final String ENCRYPTION_HANDLER_NAME = "CtrTransportEncryption"; + private static final String DECRYPTION_HANDLER_NAME = "CtrTransportDecryption"; + @VisibleForTesting + static final int STREAM_BUFFER_SIZE = 1024 * 32; + + private final Properties conf; + private static final String CIPHER_ALGORITHM = "AES/CTR/NoPadding"; + private final SecretKeySpec key; + private final byte[] inIv; + private final byte[] outIv; + + public CtrTransportCipher( + Properties conf, + SecretKeySpec key, + byte[] inIv, + byte[] outIv) { + this.conf = conf; + this.key = key; + this.inIv = inIv; + this.outIv = outIv; + } + + /* + * This method is for testing purposes only. + */ + @VisibleForTesting + public String getKeyId() throws GeneralSecurityException { + return TransportCipherUtil.getKeyId(key); + } + + @VisibleForTesting + SecretKeySpec getKey() { + return key; + } + + /** The IV for the input channel (i.e. output channel of the remote side). */ + public byte[] getInputIv() { + return inIv; + } + + /** The IV for the output channel (i.e. input channel of the remote side). */ + public byte[] getOutputIv() { + return outIv; + } + + @VisibleForTesting + CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException { + return new CryptoOutputStream(CIPHER_ALGORITHM, conf, ch, key, new IvParameterSpec(outIv)); + } + + @VisibleForTesting + CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException { + return new CryptoInputStream(CIPHER_ALGORITHM, conf, ch, key, new IvParameterSpec(inIv)); + } + + /** + * Add handlers to channel. + * + * @param ch the channel for adding handlers + * @throws IOException + */ + public void addToChannel(Channel ch) throws IOException { + ch.pipeline() + .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(this)) + .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this)); + } + + @VisibleForTesting + static class EncryptionHandler extends ChannelOutboundHandlerAdapter { + private final ByteArrayWritableChannel byteEncChannel; + private final CryptoOutputStream cos; + private final ByteArrayWritableChannel byteRawChannel; + private boolean isCipherValid; + + EncryptionHandler(CtrTransportCipher cipher) throws IOException { + byteEncChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); + cos = cipher.createOutputStream(byteEncChannel); + byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); + isCipherValid = true; + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + ctx.write(createEncryptedMessage(msg), promise); + } + + @VisibleForTesting + EncryptedMessage createEncryptedMessage(Object msg) { + return new EncryptedMessage(this, cos, msg, byteEncChannel, byteRawChannel); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + try { + if (isCipherValid) { + cos.close(); + } + } finally { + super.close(ctx, promise); + } + } + + /** + * SPARK-25535. Workaround for CRYPTO-141. Avoid further interaction with the underlying cipher + * after an error occurs. + */ + void reportError() { + this.isCipherValid = false; + } + + boolean isCipherValid() { + return isCipherValid; + } + } + + private static class DecryptionHandler extends ChannelInboundHandlerAdapter { + private final CryptoInputStream cis; + private final ByteArrayReadableChannel byteChannel; + private boolean isCipherValid; + + DecryptionHandler(CtrTransportCipher cipher) throws IOException { + byteChannel = new ByteArrayReadableChannel(); + cis = cipher.createInputStream(byteChannel); + isCipherValid = true; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { + ByteBuf buffer = (ByteBuf) data; + + try { + if (!isCipherValid) { + throw new IOException("Cipher is in invalid state."); + } + byte[] decryptedData = new byte[buffer.readableBytes()]; + byteChannel.feedData(buffer); + + int offset = 0; + while (offset < decryptedData.length) { + // SPARK-25535: workaround for CRYPTO-141. + try { + offset += cis.read(decryptedData, offset, decryptedData.length - offset); + } catch (InternalError ie) { + isCipherValid = false; + throw ie; + } + } + + ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length)); + } finally { + buffer.release(); + } + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + // We do the closing of the stream / channel in handlerRemoved(...) as + // this method will be called in all cases: + // + // - when the Channel becomes inactive + // - when the handler is removed from the ChannelPipeline + try { + if (isCipherValid) { + cis.close(); + } + } finally { + super.handlerRemoved(ctx); + } + } + } + + @VisibleForTesting + static class EncryptedMessage extends AbstractFileRegion { + private final boolean isByteBuf; + private final ByteBuf buf; + private final FileRegion region; + private final CryptoOutputStream cos; + private final EncryptionHandler handler; + private final long count; + private long transferred; + + // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has + // to utilize two helper ByteArrayWritableChannel for streaming. One is used to receive raw data + // from upper handler, another is used to store encrypted data. + private final ByteArrayWritableChannel byteEncChannel; + private final ByteArrayWritableChannel byteRawChannel; + + private ByteBuffer currentEncrypted; + + EncryptedMessage( + EncryptionHandler handler, + CryptoOutputStream cos, + Object msg, + ByteArrayWritableChannel byteEncChannel, + ByteArrayWritableChannel byteRawChannel) { + Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, + "Unrecognized message type: %s", msg.getClass().getName()); + this.handler = handler; + this.isByteBuf = msg instanceof ByteBuf; + this.buf = isByteBuf ? (ByteBuf) msg : null; + this.region = isByteBuf ? null : (FileRegion) msg; + this.transferred = 0; + this.cos = cos; + this.byteEncChannel = byteEncChannel; + this.byteRawChannel = byteRawChannel; + this.count = isByteBuf ? buf.readableBytes() : region.count(); + } + + @Override + public long count() { + return count; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transferred() { + return transferred; + } + + @Override + public EncryptedMessage touch(Object o) { + super.touch(o); + if (region != null) { + region.touch(o); + } + if (buf != null) { + buf.touch(o); + } + return this; + } + + @Override + public EncryptedMessage retain(int increment) { + super.retain(increment); + if (region != null) { + region.retain(increment); + } + if (buf != null) { + buf.retain(increment); + } + return this; + } + + @Override + public boolean release(int decrement) { + if (region != null) { + region.release(decrement); + } + if (buf != null) { + buf.release(decrement); + } + return super.release(decrement); + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + Preconditions.checkArgument(position == transferred(), "Invalid position."); + + if (transferred == count) { + return 0; + } + + long totalBytesWritten = 0L; + do { + if (currentEncrypted == null) { + encryptMore(); + } + + long remaining = currentEncrypted.remaining(); + if (remaining == 0) { + // Just for safety to avoid endless loop. It usually won't happen, but since the + // underlying `region.transferTo` is allowed to transfer 0 bytes, we should handle it for + // safety. + currentEncrypted = null; + byteEncChannel.reset(); + return totalBytesWritten; + } + + long bytesWritten = target.write(currentEncrypted); + totalBytesWritten += bytesWritten; + transferred += bytesWritten; + if (bytesWritten < remaining) { + // break as the underlying buffer in "target" is full + break; + } + currentEncrypted = null; + byteEncChannel.reset(); + } while (transferred < count); + + return totalBytesWritten; + } + + private void encryptMore() throws IOException { + if (!handler.isCipherValid()) { + throw new IOException("Cipher is in invalid state."); + } + byteRawChannel.reset(); + + if (isByteBuf) { + int copied = byteRawChannel.write(buf.nioBuffer()); + buf.skipBytes(copied); + } else { + region.transferTo(byteRawChannel, region.transferred()); + } + + try { + cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); + cos.flush(); + } catch (InternalError ie) { + handler.reportError(); + throw ie; + } + + currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(), + 0, byteEncChannel.length()); + } + + @Override + protected void deallocate() { + byteRawChannel.reset(); + byteEncChannel.reset(); + if (region != null) { + region.release(); + } + if (buf != null) { + buf.release(); + } + } + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java new file mode 100644 index 000000000000..c3540838bef0 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java @@ -0,0 +1,410 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.primitives.Longs; +import com.google.crypto.tink.subtle.*; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.*; +import io.netty.util.ReferenceCounted; +import org.apache.spark.network.util.AbstractFileRegion; +import org.apache.spark.network.util.ByteBufferWriteableChannel; + +import javax.crypto.spec.SecretKeySpec; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.security.GeneralSecurityException; +import java.security.InvalidAlgorithmParameterException; + +public class GcmTransportCipher implements TransportCipher { + private static final String HKDF_ALG = "HmacSha256"; + private static final int LENGTH_HEADER_BYTES = 8; + @VisibleForTesting + static final int CIPHERTEXT_BUFFER_SIZE = 32 * 1024; // 32KB + private final SecretKeySpec aesKey; + + public GcmTransportCipher(SecretKeySpec aesKey) { + this.aesKey = aesKey; + } + + AesGcmHkdfStreaming getAesGcmHkdfStreaming() throws InvalidAlgorithmParameterException { + return new AesGcmHkdfStreaming( + aesKey.getEncoded(), + HKDF_ALG, + aesKey.getEncoded().length, + CIPHERTEXT_BUFFER_SIZE, + 0); + } + + /* + * This method is for testing purposes only. + */ + @VisibleForTesting + public String getKeyId() throws GeneralSecurityException { + return TransportCipherUtil.getKeyId(aesKey); + } + + @VisibleForTesting + EncryptionHandler getEncryptionHandler() throws GeneralSecurityException { + return new EncryptionHandler(); + } + + @VisibleForTesting + DecryptionHandler getDecryptionHandler() throws GeneralSecurityException { + return new DecryptionHandler(); + } + + public void addToChannel(Channel ch) throws GeneralSecurityException { + ch.pipeline() + .addFirst("GcmTransportEncryption", getEncryptionHandler()) + .addFirst("GcmTransportDecryption", getDecryptionHandler()); + } + + @VisibleForTesting + class EncryptionHandler extends ChannelOutboundHandlerAdapter { + private final ByteBuffer plaintextBuffer; + private final ByteBuffer ciphertextBuffer; + private final AesGcmHkdfStreaming aesGcmHkdfStreaming; + + EncryptionHandler() throws InvalidAlgorithmParameterException { + aesGcmHkdfStreaming = getAesGcmHkdfStreaming(); + plaintextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize()); + ciphertextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + GcmEncryptedMessage encryptedMessage = new GcmEncryptedMessage( + aesGcmHkdfStreaming, + msg, + plaintextBuffer, + ciphertextBuffer); + ctx.write(encryptedMessage, promise); + } + } + + static class GcmEncryptedMessage extends AbstractFileRegion { + private final Object plaintextMessage; + private final ByteBuffer plaintextBuffer; + private final ByteBuffer ciphertextBuffer; + private final ByteBuffer headerByteBuffer; + private final long bytesToRead; + private long bytesRead = 0; + private final StreamSegmentEncrypter encrypter; + private long transferred = 0; + private final long encryptedCount; + + GcmEncryptedMessage(AesGcmHkdfStreaming aesGcmHkdfStreaming, + Object plaintextMessage, + ByteBuffer plaintextBuffer, + ByteBuffer ciphertextBuffer) throws GeneralSecurityException { + Preconditions.checkArgument( + plaintextMessage instanceof ByteBuf || plaintextMessage instanceof FileRegion, + "Unrecognized message type: %s", plaintextMessage.getClass().getName()); + this.plaintextMessage = plaintextMessage; + this.plaintextBuffer = plaintextBuffer; + this.ciphertextBuffer = ciphertextBuffer; + // If the ciphertext buffer cannot be fully written the target, transferTo may + // return with it containing some unwritten data. The initial call we'll explicitly + // set its limit to 0 to indicate the first call to transferTo. + this.ciphertextBuffer.limit(0); + + this.bytesToRead = getReadableBytes(); + this.encryptedCount = + LENGTH_HEADER_BYTES + aesGcmHkdfStreaming.expectedCiphertextSize(bytesToRead); + byte[] lengthAad = Longs.toByteArray(encryptedCount); + this.encrypter = aesGcmHkdfStreaming.newStreamSegmentEncrypter(lengthAad); + this.headerByteBuffer = createHeaderByteBuffer(); + } + + // The format of the output is: + // [8 byte length][Internal IV and header][Ciphertext][Auth Tag] + private ByteBuffer createHeaderByteBuffer() { + ByteBuffer encrypterHeader = encrypter.getHeader(); + return ByteBuffer + .allocate(encrypterHeader.remaining() + LENGTH_HEADER_BYTES) + .putLong(encryptedCount) + .put(encrypterHeader) + .flip(); + } + + @Override + public long position() { + return 0; + } + + @Override + public long transferred() { + return transferred; + } + + @Override + public long count() { + return encryptedCount; + } + + @Override + public GcmEncryptedMessage touch(Object o) { + super.touch(o); + if (plaintextMessage instanceof ByteBuf byteBuf) { + byteBuf.touch(o); + } else if (plaintextMessage instanceof FileRegion fileRegion) { + fileRegion.touch(o); + } + return this; + } + + @Override + public GcmEncryptedMessage retain(int increment) { + super.retain(increment); + if (plaintextMessage instanceof ByteBuf byteBuf) { + byteBuf.retain(increment); + } else if (plaintextMessage instanceof FileRegion fileRegion) { + fileRegion.retain(increment); + } + return this; + } + + @Override + public boolean release(int decrement) { + if (plaintextMessage instanceof ByteBuf byteBuf) { + byteBuf.release(decrement); + } else if (plaintextMessage instanceof FileRegion fileRegion) { + fileRegion.release(decrement); + } + return super.release(decrement); + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + int transferredThisCall = 0; + // If the header has is not empty, try to write it out to the target. + if (headerByteBuffer.hasRemaining()) { + int written = target.write(headerByteBuffer); + transferredThisCall += written; + this.transferred += written; + if (headerByteBuffer.hasRemaining()) { + return written; + } + } + // If the ciphertext buffer is not empty, try to write it to the target. + if (ciphertextBuffer.hasRemaining()) { + int written = target.write(ciphertextBuffer); + transferredThisCall += written; + this.transferred += written; + if (ciphertextBuffer.hasRemaining()) { + return transferredThisCall; + } + } + while (bytesRead < bytesToRead) { + long readableBytes = getReadableBytes(); + int readLimit = + (int) Math.min(readableBytes, plaintextBuffer.remaining()); + if (plaintextMessage instanceof ByteBuf byteBuf) { + Preconditions.checkState(0 == plaintextBuffer.position()); + plaintextBuffer.limit(readLimit); + byteBuf.readBytes(plaintextBuffer); + Preconditions.checkState(readLimit == plaintextBuffer.position()); + } else if (plaintextMessage instanceof FileRegion fileRegion) { + ByteBufferWriteableChannel plaintextChannel = + new ByteBufferWriteableChannel(plaintextBuffer); + long plaintextRead = + fileRegion.transferTo(plaintextChannel, fileRegion.transferred()); + if (plaintextRead < readLimit) { + // If we do not read a full plaintext buffer or all the available + // readable bytes, return what was transferred this call. + return transferredThisCall; + } + } + boolean lastSegment = getReadableBytes() == 0; + plaintextBuffer.flip(); + bytesRead += plaintextBuffer.remaining(); + ciphertextBuffer.clear(); + try { + encrypter.encryptSegment(plaintextBuffer, lastSegment, ciphertextBuffer); + } catch (GeneralSecurityException e) { + throw new IllegalStateException("GeneralSecurityException from encrypter", e); + } + plaintextBuffer.clear(); + ciphertextBuffer.flip(); + int written = target.write(ciphertextBuffer); + transferredThisCall += written; + this.transferred += written; + if (ciphertextBuffer.hasRemaining()) { + // In this case, upon calling transferTo again, it will try to write the + // remaining ciphertext buffer in the conditional before this loop. + return transferredThisCall; + } + } + return transferredThisCall; + } + + private long getReadableBytes() { + if (plaintextMessage instanceof ByteBuf byteBuf) { + return byteBuf.readableBytes(); + } else if (plaintextMessage instanceof FileRegion fileRegion) { + return fileRegion.count() - fileRegion.transferred(); + } else { + throw new IllegalArgumentException("Unsupported message type: " + + plaintextMessage.getClass().getName()); + } + } + + @Override + protected void deallocate() { + if (plaintextMessage instanceof ReferenceCounted referenceCounted) { + referenceCounted.release(); + } + plaintextBuffer.clear(); + ciphertextBuffer.clear(); + } + } + + @VisibleForTesting + class DecryptionHandler extends ChannelInboundHandlerAdapter { + private final ByteBuffer expectedLengthBuffer; + private final ByteBuffer headerBuffer; + private final ByteBuffer ciphertextBuffer; + private final AesGcmHkdfStreaming aesGcmHkdfStreaming; + private final StreamSegmentDecrypter decrypter; + private final int plaintextSegmentSize; + private boolean decrypterInit = false; + private boolean completed = false; + private int segmentNumber = 0; + private long expectedLength = -1; + private long ciphertextRead = 0; + + DecryptionHandler() throws GeneralSecurityException { + aesGcmHkdfStreaming = getAesGcmHkdfStreaming(); + expectedLengthBuffer = ByteBuffer.allocate(LENGTH_HEADER_BYTES); + headerBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getHeaderLength()); + ciphertextBuffer = + ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); + decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter(); + plaintextSegmentSize = aesGcmHkdfStreaming.getPlaintextSegmentSize(); + } + + private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) { + if (expectedLength < 0) { + ciphertextNettyBuf.readBytes(expectedLengthBuffer); + if (expectedLengthBuffer.hasRemaining()) { + // We did not read enough bytes to initialize the expected length. + return false; + } + expectedLengthBuffer.flip(); + expectedLength = expectedLengthBuffer.getLong(); + if (expectedLength < 0) { + throw new IllegalStateException("Invalid expected ciphertext length."); + } + ciphertextRead += LENGTH_HEADER_BYTES; + } + return true; + } + + private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf) + throws GeneralSecurityException { + // Check if the ciphertext header has been read. This contains + // the IV and other internal metadata. + if (!decrypterInit) { + ciphertextNettyBuf.readBytes(headerBuffer); + if (headerBuffer.hasRemaining()) { + // We did not read enough bytes to initialize the header. + return false; + } + headerBuffer.flip(); + byte[] lengthAad = Longs.toByteArray(expectedLength); + decrypter.init(headerBuffer, lengthAad); + decrypterInit = true; + ciphertextRead += aesGcmHkdfStreaming.getHeaderLength(); + if (expectedLength == ciphertextRead) { + // If the expected length is just the header, the ciphertext is 0 length. + completed = true; + } + } + return true; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) + throws GeneralSecurityException { + Preconditions.checkArgument(ciphertextMessage instanceof ByteBuf, + "Unrecognized message type: %s", + ciphertextMessage.getClass().getName()); + ByteBuf ciphertextNettyBuf = (ByteBuf) ciphertextMessage; + // The format of the output is: + // [8 byte length][Internal IV and header][Ciphertext][Auth Tag] + try { + if (!initalizeExpectedLength(ciphertextNettyBuf)) { + // We have not read enough bytes to initialize the expected length. + return; + } + if (!initalizeDecrypter(ciphertextNettyBuf)) { + // We have not read enough bytes to initialize a header, needed to + // initialize a decrypter. + return; + } + int nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); + while (nettyBufReadableBytes > 0 && !completed) { + // Read the ciphertext into the local buffer + int readableBytes = Integer.min( + nettyBufReadableBytes, + ciphertextBuffer.remaining()); + int expectedRemaining = (int) (expectedLength - ciphertextRead); + int bytesToRead = Integer.min(readableBytes, expectedRemaining); + // The smallest ciphertext size is 16 bytes for the auth tag + ciphertextBuffer.limit(ciphertextBuffer.position() + bytesToRead); + ciphertextNettyBuf.readBytes(ciphertextBuffer); + ciphertextRead += bytesToRead; + // Check if this is the last segment + if (ciphertextRead == expectedLength) { + completed = true; + } else if (ciphertextRead > expectedLength) { + throw new IllegalStateException("Read more ciphertext than expected."); + } + // If the ciphertext buffer is full, or this is the last segment, + // then decrypt it and fire a read. + if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) { + ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize); + ciphertextBuffer.flip(); + decrypter.decryptSegment( + ciphertextBuffer, + segmentNumber, + completed, + plaintextBuffer); + segmentNumber++; + // Clear the ciphertext buffer because it's been read + ciphertextBuffer.clear(); + plaintextBuffer.flip(); + ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer)); + } else { + // Set the ciphertext buffer up to read the next chunk + ciphertextBuffer.limit(ciphertextBuffer.capacity()); + } + nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); + } + } finally { + ciphertextNettyBuf.release(); + } + } + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index b507f911fe11..355c55272018 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -17,362 +17,32 @@ package org.apache.spark.network.crypto; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.ReadableByteChannel; -import java.nio.channels.WritableByteChannel; -import java.util.Properties; -import javax.crypto.spec.SecretKeySpec; -import javax.crypto.spec.IvParameterSpec; - import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.*; -import org.apache.commons.crypto.stream.CryptoInputStream; -import org.apache.commons.crypto.stream.CryptoOutputStream; - -import org.apache.spark.network.util.AbstractFileRegion; -import org.apache.spark.network.util.ByteArrayReadableChannel; -import org.apache.spark.network.util.ByteArrayWritableChannel; - -/** - * Cipher for encryption and decryption. - */ -public class TransportCipher { - @VisibleForTesting - static final String ENCRYPTION_HANDLER_NAME = "TransportEncryption"; - private static final String DECRYPTION_HANDLER_NAME = "TransportDecryption"; - @VisibleForTesting - static final int STREAM_BUFFER_SIZE = 1024 * 32; - - private final Properties conf; - private final String cipher; - private final SecretKeySpec key; - private final byte[] inIv; - private final byte[] outIv; - - public TransportCipher( - Properties conf, - String cipher, - SecretKeySpec key, - byte[] inIv, - byte[] outIv) { - this.conf = conf; - this.cipher = cipher; - this.key = key; - this.inIv = inIv; - this.outIv = outIv; - } - - public String getCipherTransformation() { - return cipher; - } - - @VisibleForTesting - SecretKeySpec getKey() { - return key; - } - - /** The IV for the input channel (i.e. output channel of the remote side). */ - public byte[] getInputIv() { - return inIv; - } - - /** The IV for the output channel (i.e. input channel of the remote side). */ - public byte[] getOutputIv() { - return outIv; - } - - @VisibleForTesting - CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException { - return new CryptoOutputStream(cipher, conf, ch, key, new IvParameterSpec(outIv)); - } - - @VisibleForTesting - CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException { - return new CryptoInputStream(cipher, conf, ch, key, new IvParameterSpec(inIv)); - } - - /** - * Add handlers to channel. - * - * @param ch the channel for adding handlers - * @throws IOException - */ - public void addToChannel(Channel ch) throws IOException { - ch.pipeline() - .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(this)) - .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this)); - } - - @VisibleForTesting - static class EncryptionHandler extends ChannelOutboundHandlerAdapter { - private final ByteArrayWritableChannel byteEncChannel; - private final CryptoOutputStream cos; - private final ByteArrayWritableChannel byteRawChannel; - private boolean isCipherValid; - - EncryptionHandler(TransportCipher cipher) throws IOException { - byteEncChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); - cos = cipher.createOutputStream(byteEncChannel); - byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); - isCipherValid = true; - } +import com.google.crypto.tink.subtle.Hex; +import com.google.crypto.tink.subtle.Hkdf; +import io.netty.channel.Channel; - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) - throws Exception { - ctx.write(createEncryptedMessage(msg), promise); - } - - @VisibleForTesting - EncryptedMessage createEncryptedMessage(Object msg) { - return new EncryptedMessage(this, cos, msg, byteEncChannel, byteRawChannel); - } +import javax.crypto.spec.SecretKeySpec; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; - @Override - public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { - try { - if (isCipherValid) { - cos.close(); - } - } finally { - super.close(ctx, promise); - } - } +interface TransportCipher { + String getKeyId() throws GeneralSecurityException; + void addToChannel(Channel channel) throws IOException, GeneralSecurityException; +} - /** - * SPARK-25535. Workaround for CRYPTO-141. Avoid further interaction with the underlying cipher - * after an error occurs. +class TransportCipherUtil { + /* + * This method is used for testing to verify key derivation. */ - void reportError() { - this.isCipherValid = false; - } - - boolean isCipherValid() { - return isCipherValid; - } - } - - private static class DecryptionHandler extends ChannelInboundHandlerAdapter { - private final CryptoInputStream cis; - private final ByteArrayReadableChannel byteChannel; - private boolean isCipherValid; - - DecryptionHandler(TransportCipher cipher) throws IOException { - byteChannel = new ByteArrayReadableChannel(); - cis = cipher.createInputStream(byteChannel); - isCipherValid = true; - } - - @Override - public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { - ByteBuf buffer = (ByteBuf) data; - - try { - if (!isCipherValid) { - throw new IOException("Cipher is in invalid state."); - } - byte[] decryptedData = new byte[buffer.readableBytes()]; - byteChannel.feedData(buffer); - - int offset = 0; - while (offset < decryptedData.length) { - // SPARK-25535: workaround for CRYPTO-141. - try { - offset += cis.read(decryptedData, offset, decryptedData.length - offset); - } catch (InternalError ie) { - isCipherValid = false; - throw ie; - } - } - - ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length)); - } finally { - buffer.release(); - } - } - - @Override - public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { - // We do the closing of the stream / channel in handlerRemoved(...) as - // this method will be called in all cases: - // - // - when the Channel becomes inactive - // - when the handler is removed from the ChannelPipeline - try { - if (isCipherValid) { - cis.close(); - } - } finally { - super.handlerRemoved(ctx); - } - } - } - - @VisibleForTesting - static class EncryptedMessage extends AbstractFileRegion { - private final boolean isByteBuf; - private final ByteBuf buf; - private final FileRegion region; - private final CryptoOutputStream cos; - private final EncryptionHandler handler; - private final long count; - private long transferred; - - // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has - // to utilize two helper ByteArrayWritableChannel for streaming. One is used to receive raw data - // from upper handler, another is used to store encrypted data. - private final ByteArrayWritableChannel byteEncChannel; - private final ByteArrayWritableChannel byteRawChannel; - - private ByteBuffer currentEncrypted; - - EncryptedMessage( - EncryptionHandler handler, - CryptoOutputStream cos, - Object msg, - ByteArrayWritableChannel byteEncChannel, - ByteArrayWritableChannel byteRawChannel) { - Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, - "Unrecognized message type: %s", msg.getClass().getName()); - this.handler = handler; - this.isByteBuf = msg instanceof ByteBuf; - this.buf = isByteBuf ? (ByteBuf) msg : null; - this.region = isByteBuf ? null : (FileRegion) msg; - this.transferred = 0; - this.cos = cos; - this.byteEncChannel = byteEncChannel; - this.byteRawChannel = byteRawChannel; - this.count = isByteBuf ? buf.readableBytes() : region.count(); - } - - @Override - public long count() { - return count; - } - - @Override - public long position() { - return 0; - } - - @Override - public long transferred() { - return transferred; - } - - @Override - public EncryptedMessage touch(Object o) { - super.touch(o); - if (region != null) { - region.touch(o); - } - if (buf != null) { - buf.touch(o); - } - return this; - } - - @Override - public EncryptedMessage retain(int increment) { - super.retain(increment); - if (region != null) { - region.retain(increment); - } - if (buf != null) { - buf.retain(increment); - } - return this; - } - - @Override - public boolean release(int decrement) { - if (region != null) { - region.release(decrement); - } - if (buf != null) { - buf.release(decrement); - } - return super.release(decrement); - } - - @Override - public long transferTo(WritableByteChannel target, long position) throws IOException { - Preconditions.checkArgument(position == transferred(), "Invalid position."); - - if (transferred == count) { - return 0; - } - - long totalBytesWritten = 0L; - do { - if (currentEncrypted == null) { - encryptMore(); - } - - long remaining = currentEncrypted.remaining(); - if (remaining == 0) { - // Just for safety to avoid endless loop. It usually won't happen, but since the - // underlying `region.transferTo` is allowed to transfer 0 bytes, we should handle it for - // safety. - currentEncrypted = null; - byteEncChannel.reset(); - return totalBytesWritten; - } - - long bytesWritten = target.write(currentEncrypted); - totalBytesWritten += bytesWritten; - transferred += bytesWritten; - if (bytesWritten < remaining) { - // break as the underlying buffer in "target" is full - break; - } - currentEncrypted = null; - byteEncChannel.reset(); - } while (transferred < count); - - return totalBytesWritten; - } - - private void encryptMore() throws IOException { - if (!handler.isCipherValid()) { - throw new IOException("Cipher is in invalid state."); - } - byteRawChannel.reset(); - - if (isByteBuf) { - int copied = byteRawChannel.write(buf.nioBuffer()); - buf.skipBytes(copied); - } else { - region.transferTo(byteRawChannel, region.transferred()); - } - - try { - cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); - cos.flush(); - } catch (InternalError ie) { - handler.reportError(); - throw ie; - } - - currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(), - 0, byteEncChannel.length()); - } - - @Override - protected void deallocate() { - byteRawChannel.reset(); - byteEncChannel.reset(); - if (region != null) { - region.release(); - } - if (buf != null) { - buf.release(); - } + @VisibleForTesting + static String getKeyId(SecretKeySpec key) throws GeneralSecurityException { + byte[] keyIdBytes = Hkdf.computeHkdf("HmacSha256", + key.getEncoded(), + null, + "keyID".getBytes(StandardCharsets.UTF_8), + 32); + return Hex.encode(keyIdBytes); } - } - } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java new file mode 100644 index 000000000000..b20240cfcaa6 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.WritableByteChannel; + +public class ByteBufferWriteableChannel implements WritableByteChannel { + private final ByteBuffer destination; + private boolean open; + + public ByteBufferWriteableChannel(ByteBuffer destination) { + this.destination = destination; + this.open = true; + } + + @Override + public int write(ByteBuffer src) throws IOException { + if (!isOpen()) { + throw new ClosedChannelException(); + } + int bytesToWrite = Math.min(src.remaining(), destination.remaining()); + // Destination buffer is full + if (bytesToWrite == 0) { + return 0; + } + ByteBuffer temp = src.slice().limit(bytesToWrite); + destination.put(temp); + src.position(src.position() + bytesToWrite); + return bytesToWrite; + } + + @Override + public boolean isOpen() { + return open; + } + + @Override + public void close() { + open = false; + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java index e9846be20c9b..628de9e78033 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java @@ -18,75 +18,76 @@ package org.apache.spark.network.crypto; import java.nio.ByteBuffer; -import java.nio.channels.WritableByteChannel; import java.security.GeneralSecurityException; -import java.util.Collections; -import java.util.Random; +import java.util.Map; +import com.google.common.collect.ImmutableMap; import com.google.crypto.tink.subtle.Hex; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.FileRegion; -import org.apache.spark.network.util.ByteArrayWritableChannel; -import org.apache.spark.network.util.ConfigProvider; -import org.apache.spark.network.util.MapConfigProvider; -import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.util.*; + import static org.junit.jupiter.api.Assertions.*; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; -import static org.mockito.Mockito.*; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - -public class AuthEngineSuite { - private static final String clientPrivate = - "efe6b68b3fce92158e3637f6ef9d937e75558928dd4b401de04b43d300a73186"; - private static final String clientChallengeHex = - "fb00000005617070496400000010890b6e960f48e998777267a7e4e623220000003c48ad7dc7ec9466da9" + - "3bda9f11488dc9404050e02c661d87d67c782444944c6e369b27e0a416c30845a2d9e64271511ca98b41d" + - "65f8c426e18ff380f6"; - private static final String serverResponseHex = - "fb00000005617070496400000010708451c9dd2792c97c1ca66e6df449ef0000003c64fe899ecdaf458d4" + - "e25e9d5c5a380b8e6d1a184692fac065ed84f8592c18e9629f9c636809dca2ffc041f20346eb53db78738" + - "08ecad08b46b5ee3ff"; - private static final String derivedKey = "2d6e7a9048c8265c33a8f3747bfcc84c"; +abstract class AuthEngineSuite { + static final String clientPrivate = + "efe6b68b3fce92158e3637f6ef9d937e75558928dd4b401de04b43d300a73186"; + static final String clientChallengeHex = + "fb00000005617070496400000010890b6e960f48e998777267a7e4e623220000003c48ad7dc7ec9466da9" + + "3bda9f11488dc9404050e02c661d87d67c782444944c6e369b27e0a416c30845a2d9e64271511ca98b41d" + + "65f8c426e18ff380f6"; + static final String serverResponseHex = + "fb00000005617070496400000010708451c9dd2792c97c1ca66e6df449ef0000003c64fe899ecdaf458d4" + + "e25e9d5c5a380b8e6d1a184692fac065ed84f8592c18e9629f9c636809dca2ffc041f20346eb53db78738" + + "08ecad08b46b5ee3ff"; + static final String derivedKeyId = + "de04fd52d71040ed9d260579dacfdf4f5695f991ce8ddb1dde05a7335880906e"; // This key would have been derived for version 1.0 protocol that did not run a final HKDF round. - private static final String unsafeDerivedKey = - "31963f15a320d5c90333f7ecf5cf3a31c7eaf151de07fef8494663a9f47cfd31"; - - private static final String inputIv = "fc6a5dc8b90a9dad8f54f08b51a59ed2"; - private static final String outputIv = "a72709baf00785cad6329ce09f631f71"; - private static TransportConf conf; - - @BeforeAll - public static void setUp() { - ConfigProvider v2Provider = new MapConfigProvider(Collections.singletonMap( - "spark.network.crypto.authEngineVersion", "2")); - conf = new TransportConf("rpc", v2Provider); + static final String unsafeDerivedKey = + "31963f15a320d5c90333f7ecf5cf3a31c7eaf151de07fef8494663a9f47cfd31"; + static TransportConf conf; + + static TransportConf getConf(int authEngineVerison, boolean useCtr) { + String authEngineVersion = (authEngineVerison == 1) ? "1" : "2"; + String mode = useCtr ? "AES/CTR/NoPadding" : "AES/GCM/NoPadding"; + Map confMap = ImmutableMap.of( + "spark.network.crypto.enabled", "true", + "spark.network.crypto.authEngineVersion", authEngineVersion, + "spark.network.crypto.cipher", mode + ); + ConfigProvider v2Provider = new MapConfigProvider(confMap); + return new TransportConf("rpc", v2Provider); } @Test public void testAuthEngine() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); AuthMessage serverResponse = server.response(clientChallenge); client.deriveSessionCipher(clientChallenge, serverResponse); - TransportCipher serverCipher = server.sessionCipher(); TransportCipher clientCipher = client.sessionCipher(); + assertEquals(clientCipher.getKeyId(), serverCipher.getKeyId()); + } + } - assertArrayEquals(serverCipher.getInputIv(), clientCipher.getOutputIv()); - assertArrayEquals(serverCipher.getOutputIv(), clientCipher.getInputIv()); - assertEquals(serverCipher.getKey(), clientCipher.getKey()); + @Test + public void testFixedChallengeResponse() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf)) { + byte[] clientPrivateKey = Hex.decode(clientPrivate); + client.setClientPrivateKey(clientPrivateKey); + AuthMessage clientChallenge = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); + AuthMessage serverResponse = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); + // Verify that the client will accept an old transcript. + client.deriveSessionCipher(clientChallenge, serverResponse); + assertEquals(client.sessionCipher().getKeyId(), derivedKeyId); } } @Test public void testCorruptChallengeAppId() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); @@ -98,7 +99,6 @@ public void testCorruptChallengeAppId() throws Exception { @Test public void testCorruptChallengeSalt() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); @@ -109,7 +109,6 @@ public void testCorruptChallengeSalt() throws Exception { @Test public void testCorruptChallengeCiphertext() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); @@ -120,7 +119,6 @@ public void testCorruptChallengeCiphertext() throws Exception { @Test public void testCorruptResponseAppId() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); @@ -134,20 +132,18 @@ public void testCorruptResponseAppId() throws Exception { @Test public void testCorruptResponseSalt() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); AuthMessage serverResponse = server.response(clientChallenge); serverResponse.salt()[0] ^= 1; assertThrows(GeneralSecurityException.class, - () -> client.deriveSessionCipher(clientChallenge, serverResponse)); + () -> client.deriveSessionCipher(clientChallenge, serverResponse)); } } @Test public void testCorruptServerCiphertext() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); @@ -169,45 +165,6 @@ public void testFixedChallenge() throws Exception { } } - @Test - public void testFixedChallengeResponse() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf)) { - byte[] clientPrivateKey = Hex.decode(clientPrivate); - client.setClientPrivateKey(clientPrivateKey); - AuthMessage clientChallenge = - AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); - AuthMessage serverResponse = - AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); - // Verify that the client will accept an old transcript. - client.deriveSessionCipher(clientChallenge, serverResponse); - TransportCipher clientCipher = client.sessionCipher(); - assertEquals(Hex.encode(clientCipher.getKey().getEncoded()), derivedKey); - assertEquals(Hex.encode(clientCipher.getInputIv()), inputIv); - assertEquals(Hex.encode(clientCipher.getOutputIv()), outputIv); - } - } - - @Test - public void testFixedChallengeResponseUnsafeVersion() throws Exception { - ConfigProvider v1Provider = new MapConfigProvider(Collections.singletonMap( - "spark.network.crypto.authEngineVersion", "1")); - TransportConf v1Conf = new TransportConf("rpc", v1Provider); - try (AuthEngine client = new AuthEngine("appId", "secret", v1Conf)) { - byte[] clientPrivateKey = Hex.decode(clientPrivate); - client.setClientPrivateKey(clientPrivateKey); - AuthMessage clientChallenge = - AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); - AuthMessage serverResponse = - AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); - // Verify that the client will accept an old transcript. - client.deriveSessionCipher(clientChallenge, serverResponse); - TransportCipher clientCipher = client.sessionCipher(); - assertEquals(Hex.encode(clientCipher.getKey().getEncoded()), unsafeDerivedKey); - assertEquals(Hex.encode(clientCipher.getInputIv()), inputIv); - assertEquals(Hex.encode(clientCipher.getOutputIv()), outputIv); - } - } - @Test public void testMismatchedSecret() throws Exception { try (AuthEngine client = new AuthEngine("appId", "secret", conf); @@ -216,70 +173,4 @@ public void testMismatchedSecret() throws Exception { assertThrows(GeneralSecurityException.class, () -> server.response(clientChallenge)); } } - - @Test - public void testEncryptedMessage() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); - AuthEngine server = new AuthEngine("appId", "secret", conf)) { - AuthMessage clientChallenge = client.challenge(); - AuthMessage serverResponse = server.response(clientChallenge); - client.deriveSessionCipher(clientChallenge, serverResponse); - - TransportCipher cipher = server.sessionCipher(); - TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher); - - byte[] data = new byte[TransportCipher.STREAM_BUFFER_SIZE + 1]; - new Random().nextBytes(data); - ByteBuf buf = Unpooled.wrappedBuffer(data); - - ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length); - TransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(buf); - while (emsg.transferred() < emsg.count()) { - emsg.transferTo(channel, emsg.transferred()); - } - assertEquals(data.length, channel.length()); - } - } - - @Test - public void testEncryptedMessageWhenTransferringZeroBytes() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); - AuthEngine server = new AuthEngine("appId", "secret", conf)) { - AuthMessage clientChallenge = client.challenge(); - AuthMessage serverResponse = server.response(clientChallenge); - client.deriveSessionCipher(clientChallenge, serverResponse); - - TransportCipher cipher = server.sessionCipher(); - TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher); - - int testDataLength = 4; - FileRegion region = mock(FileRegion.class); - when(region.count()).thenReturn((long) testDataLength); - // Make `region.transferTo` do nothing in first call and transfer 4 bytes in the second one. - when(region.transferTo(any(), anyLong())).thenAnswer(new Answer() { - - private boolean firstTime = true; - - @Override - public Long answer(InvocationOnMock invocationOnMock) throws Throwable { - if (firstTime) { - firstTime = false; - return 0L; - } else { - WritableByteChannel channel = invocationOnMock.getArgument(0); - channel.write(ByteBuffer.wrap(new byte[testDataLength])); - return (long) testDataLength; - } - } - }); - - TransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(region); - ByteArrayWritableChannel channel = new ByteArrayWritableChannel(testDataLength); - // "transferTo" should act correctly when the underlying FileRegion transfers 0 bytes. - assertEquals(0L, emsg.transferTo(channel, emsg.transferred())); - assertEquals(testDataLength, emsg.transferTo(channel, emsg.transferred())); - assertEquals(emsg.transferred(), emsg.count()); - assertEquals(4, channel.length()); - } - } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java index 90f6c874a6c8..cb5929f7c65b 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java @@ -49,7 +49,7 @@ public class AuthIntegrationSuite { private AuthTestCtx ctx; @AfterEach - public void cleanUp() throws Exception { + public void cleanUp() { if (ctx != null) { ctx.close(); } @@ -57,8 +57,8 @@ public void cleanUp() throws Exception { } @Test - public void testNewAuth() throws Exception { - ctx = new AuthTestCtx(); + public void testNewCtrAuth() throws Exception { + ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/CTR/NoPadding"); ctx.createServer("secret"); ctx.createClient("secret"); @@ -68,8 +68,28 @@ public void testNewAuth() throws Exception { } @Test - public void testAuthFailure() throws Exception { - ctx = new AuthTestCtx(); + public void testNewGcmAuth() throws Exception { + ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/GCM/NoPadding"); + ctx.createServer("secret"); + ctx.createClient("secret"); + ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); + assertEquals("Pong", JavaUtils.bytesToString(reply)); + assertNull(ctx.authRpcHandler.saslHandler); + } + + @Test + public void testCtrAuthFailure() throws Exception { + ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/CTR/NoPadding"); + ctx.createServer("server"); + + assertThrows(Exception.class, () -> ctx.createClient("client")); + assertFalse(ctx.authRpcHandler.isAuthenticated()); + assertFalse(ctx.serverChannel.isActive()); + } + + @Test + public void testGcmAuthFailure() throws Exception { + ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/GCM/NoPadding"); ctx.createServer("server"); assertThrows(Exception.class, () -> ctx.createClient("client")); @@ -100,7 +120,7 @@ public void testSaslClientFallback() throws Exception { } @Test - public void testAuthReplay() throws Exception { + public void testCtrAuthReplay() throws Exception { // This test covers the case where an attacker replays a challenge message sniffed from the // network, but doesn't know the actual secret. The server should close the connection as // soon as a message is sent after authentication is performed. This is emulated by removing @@ -110,16 +130,16 @@ public void testAuthReplay() throws Exception { ctx.createClient("secret"); assertNotNull(ctx.client.getChannel().pipeline() - .remove(TransportCipher.ENCRYPTION_HANDLER_NAME)); + .remove(CtrTransportCipher.ENCRYPTION_HANDLER_NAME)); assertThrows(Exception.class, () -> ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000)); assertTrue(ctx.authRpcHandler.isAuthenticated()); } @Test - public void testLargeMessageEncryption() throws Exception { + public void testLargeCtrMessageEncryption() throws Exception { // Use a big length to create a message that cannot be put into the encryption buffer completely - final int testErrorMessageLength = TransportCipher.STREAM_BUFFER_SIZE; + final int testErrorMessageLength = CtrTransportCipher.STREAM_BUFFER_SIZE; ctx = new AuthTestCtx(new RpcHandler() { @Override public void receive( @@ -157,6 +177,23 @@ public void testValidMergedBlockMetaReqHandler() throws Exception { assertNotNull(ctx.authRpcHandler.getMergedBlockMetaReqHandler()); } + private static class DummyRpcHandler extends RpcHandler { + @Override + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + String messageString = JavaUtils.bytesToString(message); + assertEquals("Ping", messageString); + callback.onSuccess(JavaUtils.stringToBytes("Pong")); + } + + @Override + public StreamManager getStreamManager() { + return null; + } + } + private static class AuthTestCtx { private final String appId = "testAppId"; @@ -169,25 +206,17 @@ private static class AuthTestCtx { volatile AuthRpcHandler authRpcHandler; AuthTestCtx() throws Exception { - this(new RpcHandler() { - @Override - public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { - assertEquals("Ping", JavaUtils.bytesToString(message)); - callback.onSuccess(JavaUtils.stringToBytes("Pong")); - } - - @Override - public StreamManager getStreamManager() { - return null; - } - }); + this(new DummyRpcHandler()); } AuthTestCtx(RpcHandler rpcHandler) throws Exception { - Map testConf = ImmutableMap.of("spark.network.crypto.enabled", "true"); + this(rpcHandler, "AES/CTR/NoPadding"); + } + + AuthTestCtx(RpcHandler rpcHandler, String mode) throws Exception { + Map testConf = ImmutableMap.of( + "spark.network.crypto.enabled", "true", + "spark.network.crypto.cipher", mode); this.conf = new TransportConf("rpc", new MapConfigProvider(testConf)); this.ctx = new TransportContext(conf, rpcHandler); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java new file mode 100644 index 000000000000..c353ee392ff4 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import com.google.crypto.tink.subtle.Hex; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.TransportConf; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +public class CtrAuthEngineSuite extends AuthEngineSuite { + private static final String inputIv = "fc6a5dc8b90a9dad8f54f08b51a59ed2"; + private static final String outputIv = "a72709baf00785cad6329ce09f631f71"; + + @BeforeAll + public static void setUp() { + conf = getConf(2, true); + } + + @Test + public void testAuthEngine() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + + TransportCipher serverCipher = server.sessionCipher(); + TransportCipher clientCipher = client.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + assert(serverCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrClient = (CtrTransportCipher) clientCipher; + CtrTransportCipher ctrServer = (CtrTransportCipher) serverCipher; + assertArrayEquals(ctrServer.getInputIv(), ctrClient.getOutputIv()); + assertArrayEquals(ctrServer.getOutputIv(), ctrClient.getInputIv()); + assertEquals(ctrServer.getKey(), ctrClient.getKey()); + } + } + + @Test + public void testCtrFixedChallengeIvResponse() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf)) { + byte[] clientPrivateKey = Hex.decode(clientPrivate); + client.setClientPrivateKey(clientPrivateKey); + AuthMessage clientChallenge = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); + AuthMessage serverResponse = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); + // Verify that the client will accept an old transcript. + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = client.sessionCipher(); + assertEquals(clientCipher.getKeyId(), derivedKeyId); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + assertEquals(Hex.encode(ctrTransportCipher.getInputIv()), inputIv); + assertEquals(Hex.encode(ctrTransportCipher.getOutputIv()), outputIv); + } + } + + @Test + public void testFixedChallengeResponseUnsafeVersion() throws Exception { + TransportConf v1Conf = getConf(1, true); + try (AuthEngine client = new AuthEngine("appId", "secret", v1Conf)) { + byte[] clientPrivateKey = Hex.decode(clientPrivate); + client.setClientPrivateKey(clientPrivateKey); + AuthMessage clientChallenge = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); + AuthMessage serverResponse = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); + // Verify that the client will accept an old transcript. + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = client.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + assertEquals(Hex.encode(ctrTransportCipher.getKey().getEncoded()), unsafeDerivedKey); + assertEquals(Hex.encode(ctrTransportCipher.getInputIv()), inputIv); + assertEquals(Hex.encode(ctrTransportCipher.getOutputIv()), outputIv); + } + } + + @Test + public void testCtrEncryptedMessage() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + + TransportCipher clientCipher = server.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + CtrTransportCipher.EncryptionHandler handler = + new CtrTransportCipher.EncryptionHandler(ctrTransportCipher); + + byte[] data = new byte[CtrTransportCipher.STREAM_BUFFER_SIZE + 1]; + new Random().nextBytes(data); + ByteBuf buf = Unpooled.wrappedBuffer(data); + + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length); + CtrTransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(buf); + while (emsg.transferred() < emsg.count()) { + emsg.transferTo(channel, emsg.transferred()); + } + assertEquals(data.length, channel.length()); + } + } + + @Test + public void testCtrEncryptedMessageWhenTransferringZeroBytes() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + CtrTransportCipher.EncryptionHandler handler = + new CtrTransportCipher.EncryptionHandler(ctrTransportCipher); + int testDataLength = 4; + FileRegion region = mock(FileRegion.class); + when(region.count()).thenReturn((long) testDataLength); + // Make `region.transferTo` do nothing in first call and transfer 4 bytes in the second one. + when(region.transferTo(any(), anyLong())).thenAnswer(new Answer() { + + private boolean firstTime = true; + + @Override + public Long answer(InvocationOnMock invocationOnMock) throws Throwable { + if (firstTime) { + firstTime = false; + return 0L; + } else { + WritableByteChannel channel = invocationOnMock.getArgument(0); + channel.write(ByteBuffer.wrap(new byte[testDataLength])); + return (long) testDataLength; + } + } + }); + + CtrTransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(region); + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(testDataLength); + // "transferTo" should act correctly when the underlying FileRegion transfers 0 bytes. + assertEquals(0L, emsg.transferTo(channel, emsg.transferred())); + assertEquals(testDataLength, emsg.transferTo(channel, emsg.transferred())); + assertEquals(emsg.transferred(), emsg.count()); + assertEquals(4, channel.length()); + } + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java new file mode 100644 index 000000000000..20efb8d57dcb --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import org.apache.spark.network.util.*; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import javax.crypto.AEADBadTagException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +public class GcmAuthEngineSuite extends AuthEngineSuite { + + @BeforeAll + public static void setUp() { + // Uses GCM mode + conf = getConf(2, false); + } + + @Test + public void testGcmEncryptedMessage() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + // Verify that it derives a GcmTransportCipher + assert (clientCipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + // Allocating 1.5x the buffer size to test multiple segments and a fractional segment. + int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + byte[] data = new byte[plaintextSegmentSize + (plaintextSegmentSize / 2)]; + // Just writing some bytes. + data[0] = 'a'; + data[data.length / 2] = 'b'; + data[data.length - 10] = 'c'; + ByteBuf buf = Unpooled.wrappedBuffer(data); + + // Mock the context and capture the arguments passed to it + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, buf, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + // Get the encrypted value and pass it to the decryption handler + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + encrypted.transferTo(channel, 0); + ciphertextBuffer.flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + // Capture the decrypted values and verify them + ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, ciphertext); + verify(ctx, times(2)) + .fireChannelRead(captorPlaintext.capture()); + ByteBuf lastPlaintextSegment = captorPlaintext.getValue(); + assertEquals(plaintextSegmentSize/2, + lastPlaintextSegment.readableBytes()); + assertEquals('c', + lastPlaintextSegment.getByte((plaintextSegmentSize/2) - 10)); + } + } + + static class FakeRegion extends AbstractFileRegion { + private final ByteBuffer[] source; + private int sourcePosition; + private final long count; + + FakeRegion(ByteBuffer... source) { + this.source = source; + sourcePosition = 0; + count = remaining(); + } + + private long remaining() { + long remaining = 0; + for (ByteBuffer buffer : source) { + remaining += buffer.remaining(); + } + return remaining; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transferred() { + return count - remaining(); + } + + @Override + public long count() { + return count; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + if (sourcePosition < source.length) { + ByteBuffer currentBuffer = source[sourcePosition]; + long written = target.write(currentBuffer); + if (!currentBuffer.hasRemaining()) { + sourcePosition++; + } + return written; + } else { + return 0; + } + } + + @Override + protected void deallocate() { + } + } + + private static ByteBuffer getTestByteBuf(int size, byte fill) { + byte[] data = new byte[size]; + Arrays.fill(data, fill); + return ByteBuffer.wrap(data); + } + + @Test + public void testGcmEncryptedMessageFileRegion() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + // Verify that it derives a GcmTransportCipher + assert (clientCipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + // Allocating 1.5x the buffer size to test multiple segments and a fractional segment. + int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + int halfSegmentSize = plaintextSegmentSize / 2; + int totalSize = plaintextSegmentSize + halfSegmentSize; + + // Set up some fragmented segments to test + ByteBuffer halfSegment = getTestByteBuf(halfSegmentSize, (byte) 'a'); + int smallFragmentSize = 128; + ByteBuffer smallFragment = getTestByteBuf(smallFragmentSize, (byte) 'b'); + int remainderSize = totalSize - halfSegmentSize - smallFragmentSize; + ByteBuffer remainder = getTestByteBuf(remainderSize, (byte) 'c'); + FakeRegion fakeRegion = new FakeRegion(halfSegment, smallFragment, remainder); + assertEquals(totalSize, fakeRegion.count()); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, fakeRegion, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + // Get the encrypted value and pass it to the decryption handler + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + + // We'll simulate the FileRegion only transferring half a segment. + // The encrypted message should buffer the partial segment plaintext. + long ciphertextTransferred = 0; + while (ciphertextTransferred < encrypted.count()) { + long chunkTransferred = encrypted.transferTo(channel, 0); + ciphertextTransferred += chunkTransferred; + } + assertEquals(encrypted.count(), ciphertextTransferred); + + ciphertextBuffer.flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + // Capture the decrypted values and verify them + ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, ciphertext); + verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture()); + ByteBuf plaintext = captorPlaintext.getValue(); + // We expect this to be the last partial plaintext segment + int expectedLength = totalSize % plaintextSegmentSize; + assertEquals(expectedLength, plaintext.readableBytes()); + // This will be the "remainder" segment that is filled to 'c' + assertEquals('c', plaintext.getByte(0)); + } + } + + + @Test + public void testGcmUnalignedDecryption() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + // Verify that it derives a GcmTransportCipher + assert (clientCipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + // Allocating 1.5x the buffer size to test multiple segments and a fractional segment. + int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + int plaintextSize = plaintextSegmentSize + (plaintextSegmentSize / 2); + byte[] data = new byte[plaintextSize]; + Arrays.fill(data, (byte) 'x'); + ByteBuf buf = Unpooled.wrappedBuffer(data); + + // Mock the context and capture the arguments passed to it + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, buf, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + // Get the encrypted value and pass it to the decryption handler + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + encrypted.transferTo(channel, 0); + ciphertextBuffer.flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + // Split up the ciphertext into some different sized chunks + int firstChunkSize = plaintextSize / 2; + ByteBuf mockCiphertext = spy(ciphertext); + when(mockCiphertext.readableBytes()) + .thenReturn(firstChunkSize, firstChunkSize).thenCallRealMethod(); + + // Capture the decrypted values and verify them + ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, mockCiphertext); + verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture()); + ByteBuf lastPlaintextSegment = captorPlaintext.getValue(); + assertEquals(plaintextSegmentSize/2, + lastPlaintextSegment.readableBytes()); + assertEquals('x', + lastPlaintextSegment.getByte((plaintextSegmentSize/2) - 10)); + } + } + + @Test + public void testCorruptGcmEncryptedMessage() throws Exception { + TransportConf gcmConf = getConf(2, false); + + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + + TransportCipher clientCipher = server.sessionCipher(); + assert (clientCipher instanceof GcmTransportCipher); + + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + byte[] zeroData = new byte[1024 * 32]; + // Just writing some bytes. + ByteBuf buf = Unpooled.wrappedBuffer(zeroData); + + // Mock the context and capture the arguments passed to it + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, buf, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + encrypted.transferTo(channel, 0); + ciphertextBuffer.flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + byte b = ciphertext.getByte(100); + // Inverting the bits of the 100th bit + ciphertext.setByte(100, ~b & 0xFF); + assertThrows(AEADBadTagException.class, () -> decryptionHandler.channelRead(ctx, ciphertext)); + } + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java index da62d3b2de31..8977f29034fe 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java @@ -41,10 +41,10 @@ public class TransportCipherSuite { @Test - public void testBufferNotLeaksOnInternalError() throws IOException { + public void testCtrBufferNotLeaksOnInternalError() throws IOException { String algorithm = "TestAlgorithm"; TransportConf conf = new TransportConf("Test", MapConfigProvider.EMPTY); - TransportCipher cipher = new TransportCipher(conf.cryptoConf(), conf.cipherTransformation(), + CtrTransportCipher cipher = new CtrTransportCipher(conf.cryptoConf(), new SecretKeySpec(new byte[256], algorithm), new byte[0], new byte[0]) { @Override diff --git a/docs/security.md b/docs/security.md index 455935fcffca..81b6bfc1adfe 100644 --- a/docs/security.md +++ b/docs/security.md @@ -207,6 +207,15 @@ The following table describes the different options available for configuring th 2.2.0 + + spark.network.crypto.cipher + AES/CTR/NoPadding + + Cipher mode to use. Defaults "AES/CTR/NoPadding" for backward compatibility, which is not authenticated. + Recommended to use "AES/GCM/NoPadding", which is an authenticated encryption mode. + + 4.0.0 + spark.network.crypto.authEngineVersion 1