1919
2020import com .google .common .annotations .VisibleForTesting ;
2121import com .google .common .base .Preconditions ;
22+ import com .google .common .primitives .Longs ;
2223import com .google .crypto .tink .subtle .AesGcmHkdfStreaming ;
2324import com .google .crypto .tink .subtle .StreamSegmentDecrypter ;
2425import com .google .crypto .tink .subtle .StreamSegmentEncrypter ;
26+ import io .netty .buffer .ByteBuf ;
2527import io .netty .buffer .Unpooled ;
2628import io .netty .channel .*;
2729import io .netty .util .ReferenceCounted ;
2830import org .apache .spark .network .util .AbstractFileRegion ;
29- import io .netty .buffer .ByteBuf ;
3031
3132import javax .crypto .spec .SecretKeySpec ;
3233import java .io .IOException ;
3334import java .nio .ByteBuffer ;
3435import java .nio .channels .ClosedChannelException ;
3536import java .nio .channels .WritableByteChannel ;
3637import java .security .GeneralSecurityException ;
38+ import java .security .InvalidAlgorithmParameterException ;
3739
3840public class GcmTransportCipher implements TransportCipher {
39- private static final byte [] DEFAULT_AAD = new byte [ 0 ] ;
41+ private static final String HKDF_ALG = "HmacSha256" ;
4042 private static final int LENGTH_HEADER_BYTES = 8 ;
4143 @ VisibleForTesting
42- static final int CIPHERTEXT_BUFFER_SIZE = 1024 ;
44+ static final int CIPHERTEXT_BUFFER_SIZE = 32 * 1024 ; // 32KB
4345 private final SecretKeySpec aesKey ;
4446
4547 public GcmTransportCipher (SecretKeySpec aesKey ) {
4648 this .aesKey = aesKey ;
4749 }
4850
51+ AesGcmHkdfStreaming getAesGcmHkdfStreaming () throws InvalidAlgorithmParameterException {
52+ return new AesGcmHkdfStreaming (
53+ aesKey .getEncoded (),
54+ HKDF_ALG ,
55+ aesKey .getEncoded ().length ,
56+ CIPHERTEXT_BUFFER_SIZE ,
57+ 0 );
58+ }
59+
4960 @ VisibleForTesting
5061 EncryptionHandler getEncryptionHandler () throws GeneralSecurityException {
5162 return new EncryptionHandler ();
@@ -68,13 +79,8 @@ class EncryptionHandler extends ChannelOutboundHandlerAdapter {
6879 private final ByteBuffer ciphertextBuffer ;
6980 private final AesGcmHkdfStreaming aesGcmHkdfStreaming ;
7081
71- EncryptionHandler () throws GeneralSecurityException {
72- aesGcmHkdfStreaming = new AesGcmHkdfStreaming (
73- aesKey .getEncoded (),
74- "HmacSha256" ,
75- aesKey .getEncoded ().length ,
76- CIPHERTEXT_BUFFER_SIZE ,
77- 0 );
82+ EncryptionHandler () throws InvalidAlgorithmParameterException {
83+ aesGcmHkdfStreaming = getAesGcmHkdfStreaming ();
7884 plaintextBuffer = ByteBuffer .allocate (aesGcmHkdfStreaming .getPlaintextSegmentSize ());
7985 ciphertextBuffer = ByteBuffer .allocate (aesGcmHkdfStreaming .getCiphertextSegmentSize ());
8086 }
@@ -95,10 +101,10 @@ static class GcmEncryptedMessage extends AbstractFileRegion {
95101 private final Object plaintextMessage ;
96102 private final ByteBuffer plaintextBuffer ;
97103 private final ByteBuffer ciphertextBuffer ;
104+ private final ByteBuffer headerByteBuffer ;
98105 private final long bytesToRead ;
99106 private long bytesRead = 0 ;
100107 private final StreamSegmentEncrypter encrypter ;
101- private boolean headerWritten = false ;
102108 private long transferred = 0 ;
103109 private final long encryptedCount ;
104110
@@ -110,14 +116,30 @@ static class GcmEncryptedMessage extends AbstractFileRegion {
110116 plaintextMessage instanceof ByteBuf || plaintextMessage instanceof FileRegion ,
111117 "Unrecognized message type: %s" , plaintextMessage .getClass ().getName ());
112118 this .plaintextMessage = plaintextMessage ;
113- this .bytesToRead = getReadableBytes ();
114119 this .plaintextBuffer = plaintextBuffer ;
115- this .plaintextBuffer .clear ();
116120 this .ciphertextBuffer = ciphertextBuffer ;
117- this .ciphertextBuffer .clear ();
118- this .encrypter = aesGcmHkdfStreaming .newStreamSegmentEncrypter (DEFAULT_AAD );
121+ // If the ciphertext buffer cannot be fully written the target, transferTo may
122+ // return with it containing some unwritten data. The initial call we'll explicitly
123+ // set its limit to 0 to indicate the first call to transferTo.
124+ this .ciphertextBuffer .limit (0 );
125+
126+ this .bytesToRead = getReadableBytes ();
119127 this .encryptedCount =
120128 LENGTH_HEADER_BYTES + aesGcmHkdfStreaming .expectedCiphertextSize (bytesToRead );
129+ byte [] lengthAad = Longs .toByteArray (encryptedCount );
130+ this .encrypter = aesGcmHkdfStreaming .newStreamSegmentEncrypter (lengthAad );
131+ this .headerByteBuffer = createHeaderByteBuffer ();
132+ }
133+
134+ // The format of the output is:
135+ // [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
136+ private ByteBuffer createHeaderByteBuffer () {
137+ ByteBuffer encrypterHeader = encrypter .getHeader ();
138+ return ByteBuffer
139+ .allocate (encrypterHeader .remaining () + LENGTH_HEADER_BYTES )
140+ .putLong (encryptedCount )
141+ .put (encrypterHeader )
142+ .flip ();
121143 }
122144
123145 @ Override
@@ -135,25 +157,61 @@ public long count() {
135157 return encryptedCount ;
136158 }
137159
160+ @ Override
161+ public GcmEncryptedMessage touch (Object o ) {
162+ super .touch (o );
163+ if (plaintextMessage instanceof ByteBuf byteBuf ) {
164+ byteBuf .touch (o );
165+ } else if (plaintextMessage instanceof AbstractFileRegion fileRegion ) {
166+ fileRegion .touch (o );
167+ }
168+ return this ;
169+ }
170+
171+ @ Override
172+ public GcmEncryptedMessage retain (int increment ) {
173+ super .retain (increment );
174+ if (plaintextMessage instanceof ByteBuf byteBuf ) {
175+ byteBuf .retain (increment );
176+ } else if (plaintextMessage instanceof AbstractFileRegion fileRegion ) {
177+ fileRegion .retain (increment );
178+ }
179+ return this ;
180+ }
181+
182+ @ Override
183+ public boolean release (int decrement ) {
184+ if (plaintextMessage instanceof ByteBuf byteBuf ) {
185+ byteBuf .release (decrement );
186+ } else if (plaintextMessage instanceof AbstractFileRegion fileRegion ) {
187+ fileRegion .release (decrement );
188+ }
189+ return super .release (decrement );
190+ }
191+
138192 @ Override
139193 public long transferTo (WritableByteChannel target , long position ) throws IOException {
140194 Preconditions .checkArgument (position == transferred (),
141195 "Invalid position." );
142196 int transferredThisCall = 0 ;
143- // The format of the output is:
144- // [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
145- if (!headerWritten ) {
146- ByteBuffer expectedLength = ByteBuffer
147- .allocate (LENGTH_HEADER_BYTES )
148- .putLong (encryptedCount )
149- .flip ();
150- target .write (expectedLength );
151- int headerWritten = LENGTH_HEADER_BYTES + target .write (encrypter .getHeader ());
152- transferredThisCall += headerWritten ;
153- this .transferred += headerWritten ;
154- this .headerWritten = true ;
197+ // If the header has is not empty, try to write it out to the target.
198+ if (headerByteBuffer .hasRemaining ()) {
199+ int written = target .write (headerByteBuffer );
200+ transferredThisCall += written ;
201+ this .transferred += written ;
202+ if (headerByteBuffer .hasRemaining ()) {
203+ return written ;
204+ }
205+ }
206+ // If the ciphertext buffer is not empty, try to write it to the target.
207+ if (ciphertextBuffer .hasRemaining ()) {
208+ int written = target .write (ciphertextBuffer );
209+ transferredThisCall += written ;
210+ this .transferred += written ;
211+ if (ciphertextBuffer .hasRemaining ()) {
212+ return transferredThisCall ;
213+ }
155214 }
156-
157215 while (bytesRead < bytesToRead ) {
158216 long readableBytes = getReadableBytes ();
159217 boolean lastSegment = readableBytes <= plaintextBuffer .capacity ();
@@ -186,12 +244,14 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep
186244 throw new RuntimeException (e );
187245 }
188246 ciphertextBuffer .flip ();
189- int outputRemaining = ciphertextBuffer .remaining ();
190- while (ciphertextBuffer .hasRemaining ()) {
191- target .write (ciphertextBuffer );
247+ int written = target .write (ciphertextBuffer );
248+ transferredThisCall += written ;
249+ this .transferred += written ;
250+ if (ciphertextBuffer .hasRemaining ()) {
251+ // In this case, upon calling transferTo again, it will try to write the
252+ // remaining ciphertext buffer in the conditional before this loop.
253+ return transferredThisCall ;
192254 }
193- transferredThisCall += outputRemaining ;
194- transferred += outputRemaining ;
195255 }
196256 return transferredThisCall ;
197257 }
@@ -229,12 +289,7 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter {
229289 private long ciphertextRead = 0 ;
230290
231291 DecryptionHandler () throws GeneralSecurityException {
232- aesGcmHkdfStreaming = new AesGcmHkdfStreaming (
233- aesKey .getEncoded (),
234- "HmacSha256" ,
235- aesKey .getEncoded ().length ,
236- CIPHERTEXT_BUFFER_SIZE ,
237- 0 );
292+ aesGcmHkdfStreaming = getAesGcmHkdfStreaming ();
238293 plaintextBuffer =
239294 ByteBuffer .allocate (aesGcmHkdfStreaming .getPlaintextSegmentSize ());
240295 ciphertextBuffer =
@@ -270,7 +325,8 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
270325 ByteBuffer headerBuffer = ByteBuffer .allocate (headerLength );
271326 ciphertextNettyBuf .readBytes (headerBuffer );
272327 headerBuffer .flip ();
273- decrypter .init (headerBuffer , DEFAULT_AAD );
328+ byte [] lengthAad = Longs .toByteArray (expectedLength );
329+ decrypter .init (headerBuffer , lengthAad );
274330 decrypterInit = true ;
275331 ciphertextRead += headerLength ;
276332 }
@@ -290,10 +346,12 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
290346 if (readableBytes == 0 ) {
291347 return ;
292348 }
349+ int expectedRemaining = (int ) (expectedLength - ciphertextRead );
350+ int bytesToRead = Integer .min (readableBytes , expectedRemaining );
293351 // The smallest ciphertext size is 16 bytes for the auth tag
294- ciphertextBuffer .limit (readableBytes );
352+ ciphertextBuffer .limit (bytesToRead );
295353 ciphertextNettyBuf .readBytes (ciphertextBuffer );
296- ciphertextRead += readableBytes ;
354+ ciphertextRead += bytesToRead ;
297355 // Check if this is the last segment
298356 boolean lastSegment = false ;
299357 if (ciphertextRead == expectedLength ) {
0 commit comments