diff --git a/api/src/main/java/io/grpc/ByteBufferBacked.java b/api/src/main/java/io/grpc/ByteBufferBacked.java new file mode 100644 index 00000000000..78925c6e52f --- /dev/null +++ b/api/src/main/java/io/grpc/ByteBufferBacked.java @@ -0,0 +1,45 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.nio.ByteBuffer; + +/** + * Extension to an {@link java.io.OutputStream} or alike by adding methods that + * allow writing directly to an underlying {@link ByteBuffer} + */ +public interface ByteBufferBacked { + + /** + * If available, returns a {@link ByteBuffer} backing this writable + * object whose position corresponds to this object's current + * writing position and with at least {@code size} remaining bytes. + * + * @param size minimum required size + * @return null if not supported or writable buffer of insufficient size + */ + ByteBuffer getWritableBuffer(int size); + + /** + * This must be called to notify that data has been written to the + * buffer previously returned from {@link #getWritableBuffer(int)}, + * prior to calling any other methods. + * + * @param written number of bytes written + */ + void bufferBytesWritten(int written); +} diff --git a/core/src/main/java/io/grpc/internal/MessageFramer.java b/core/src/main/java/io/grpc/internal/MessageFramer.java index ad86d450faa..31a9172f5aa 100644 --- a/core/src/main/java/io/grpc/internal/MessageFramer.java +++ b/core/src/main/java/io/grpc/internal/MessageFramer.java @@ -22,6 +22,7 @@ import static java.lang.Math.min; import com.google.common.io.ByteStreams; +import io.grpc.ByteBufferBacked; import io.grpc.Codec; import io.grpc.Compressor; import io.grpc.Drainable; @@ -277,17 +278,12 @@ private static int writeToOutputStream(InputStream message, OutputStream outputS } } - private void writeRaw(byte[] b, int off, int len) { + // package-private to avoid synthetic access from OutputStreamAdapter + void writeRaw(byte[] b, int off, int len) { while (len > 0) { - if (buffer != null && buffer.writableBytes() == 0) { - commitToSink(false, false); - } - if (buffer == null) { - // Request a buffer allocation using the message length as a hint. - buffer = bufferAllocator.allocate(len); - } - int toWrite = min(len, buffer.writableBytes()); - buffer.write(b, off, toWrite); + WritableBuffer buf = getWritableBuffer(len); + int toWrite = min(len, buf.writableBytes()); + buf.write(b, off, toWrite); off += toWrite; len -= toWrite; } @@ -303,6 +299,17 @@ public void flush() { } } + WritableBuffer getWritableBuffer(int len) { + if (buffer != null && buffer.writableBytes() == 0) { + commitToSink(false, false); + } + if (buffer == null) { + // Request a buffer allocation using the message length as a hint. + buffer = bufferAllocator.allocate(len); + } + return buffer; + } + /** * Indicates whether or not this framer has been closed via a call to either * {@link #close()} or {@link #dispose()}. @@ -360,7 +367,7 @@ private void verifyNotClosed() { } /** OutputStream whose write()s are passed to the framer. */ - private class OutputStreamAdapter extends OutputStream { + private class OutputStreamAdapter extends OutputStream implements ByteBufferBacked { /** * This is slow, don't call it. If you care about write overhead, use a BufferedOutputStream. * Better yet, you can use your own single byte buffer and call @@ -376,13 +383,29 @@ public void write(int b) { public void write(byte[] b, int off, int len) { writeRaw(b, off, len); } + + @Override + public ByteBuffer getWritableBuffer(int size) { + if (size > 0) { + WritableBuffer buf = MessageFramer.this.getWritableBuffer(size); + if (buf instanceof ByteBufferBacked) { + return ((ByteBufferBacked) buf).getWritableBuffer(size); + } + } + return null; + } + + @Override + public void bufferBytesWritten(int size) { + MessageFramer.bufferBytesWritten(buffer, size); + } } /** * Produce a collection of {@link WritableBuffer} instances from the data written to an * {@link OutputStream}. */ - private final class BufferChainOutputStream extends OutputStream { + private final class BufferChainOutputStream extends OutputStream implements ByteBufferBacked { private final List bufferList = new ArrayList<>(); private WritableBuffer current; @@ -403,7 +426,7 @@ public void write(int b) throws IOException { @Override public void write(byte[] b, int off, int len) { - if (current == null) { + if (current == null && len > 0) { // Request len bytes initially from the allocator, it may give us more. current = bufferAllocator.allocate(len); bufferList.add(current); @@ -431,5 +454,35 @@ private int readableBytes() { } return readable; } + + @Override + public ByteBuffer getWritableBuffer(int size) { + if (size > 0) { + if (current == null || current.writableBytes() == 0) { + bufferList.add(current = bufferAllocator.allocate(size)); + } + if (current instanceof ByteBufferBacked) { + return ((ByteBufferBacked) current).getWritableBuffer(size); + } + } + return null; + } + + @Override + public void bufferBytesWritten(int size) { + MessageFramer.bufferBytesWritten(current, size); + } + } + + static void bufferBytesWritten(WritableBuffer buffer, int size) { + try { + ((ByteBufferBacked) buffer).bufferBytesWritten(size); + return; + } catch (ClassCastException cce) { + // fall-through + } catch (NullPointerException npe) { + // fall-through + } + throw new IllegalStateException(); } } diff --git a/netty/src/main/java/io/grpc/netty/NettyWritableBuffer.java b/netty/src/main/java/io/grpc/netty/NettyWritableBuffer.java index b274057d648..db67a7b5b02 100644 --- a/netty/src/main/java/io/grpc/netty/NettyWritableBuffer.java +++ b/netty/src/main/java/io/grpc/netty/NettyWritableBuffer.java @@ -16,13 +16,16 @@ package io.grpc.netty; +import io.grpc.ByteBufferBacked; import io.grpc.internal.WritableBuffer; import io.netty.buffer.ByteBuf; +import java.nio.ByteBuffer; + /** * The {@link WritableBuffer} used by the Netty transport. */ -class NettyWritableBuffer implements WritableBuffer { +class NettyWritableBuffer implements WritableBuffer, ByteBufferBacked { private final ByteBuf bytebuf; @@ -50,6 +53,26 @@ public int readableBytes() { return bytebuf.readableBytes(); } + @Override + public ByteBuffer getWritableBuffer(int size) { + if (bytebuf.writableBytes() >= size && size > 0) { + try { + return bytebuf.internalNioBuffer(bytebuf.writerIndex(), size); + } catch (UnsupportedOperationException uoe) { + // fall-through + } + } + return null; + } + + @Override + public void bufferBytesWritten(int size) { + if (size < 0 || size > bytebuf.writableBytes()) { + throw new IllegalStateException(); + } + bytebuf.writerIndex(bytebuf.writerIndex() + size); + } + @Override public void release() { bytebuf.release(); diff --git a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoInputStream.java b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoInputStream.java index f885a1ece33..7be2a99be1d 100644 --- a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoInputStream.java +++ b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoInputStream.java @@ -19,12 +19,14 @@ import com.google.protobuf.CodedOutputStream; import com.google.protobuf.MessageLite; import com.google.protobuf.Parser; +import io.grpc.ByteBufferBacked; import io.grpc.Drainable; import io.grpc.KnownLength; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.nio.ByteBuffer; import javax.annotation.Nullable; /** @@ -49,7 +51,16 @@ public int drainTo(OutputStream target) throws IOException { int written; if (message != null) { written = message.getSerializedSize(); - message.writeTo(target); + ByteBuffer buffer; + if (target instanceof ByteBufferBacked + && (buffer = ((ByteBufferBacked) target).getWritableBuffer(written)) != null) { + CodedOutputStream coded = CodedOutputStream.newInstance(buffer); + message.writeTo(coded); + coded.flush(); + ((ByteBufferBacked) target).bufferBytesWritten(written); + } else { + message.writeTo(target); + } message = null; } else if (partial != null) { written = (int) ProtoLiteUtils.copy(partial, target); diff --git a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java index ddba5b8d5b1..a2a33e466b8 100644 --- a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java +++ b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java @@ -160,7 +160,7 @@ public T parse(InputStream stream) { if (protoStream.parser() == parser) { try { @SuppressWarnings("unchecked") - T message = (T) ((ProtoInputStream) stream).message(); + T message = (T) protoStream.message(); return message; } catch (IllegalStateException ignored) { // Stream must have been read from, which is a strange state. Since the point of this