Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions api/src/main/java/io/grpc/ByteBufferBacked.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.grpc;

import java.nio.ByteBuffer;

/**
* Extension to an {@link java.io.OutputStream} or alike by adding methods that
* allow writing directly to an underlying {@link ByteBuffer}
*/
public interface ByteBufferBacked {

/**
* If available, returns a {@link ByteBuffer} backing this writable
* object whose position corresponds to this object's current
* writing position and with at least {@code size} remaining bytes.
*
* @param size minimum required size
* @return null if not supported or writable buffer of insufficient size
*/
ByteBuffer getWritableBuffer(int size);

/**
* This must be called to notify that data has been written to the
* buffer previously returned from {@link #getWritableBuffer(int)},
* prior to calling any other methods.
*
* @param written number of bytes written
*/
void bufferBytesWritten(int written);
}
79 changes: 66 additions & 13 deletions core/src/main/java/io/grpc/internal/MessageFramer.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import static java.lang.Math.min;

import com.google.common.io.ByteStreams;
import io.grpc.ByteBufferBacked;
import io.grpc.Codec;
import io.grpc.Compressor;
import io.grpc.Drainable;
Expand Down Expand Up @@ -277,17 +278,12 @@ private static int writeToOutputStream(InputStream message, OutputStream outputS
}
}

private void writeRaw(byte[] b, int off, int len) {
// package-private to avoid synthetic access from OutputStreamAdapter
void writeRaw(byte[] b, int off, int len) {
while (len > 0) {
if (buffer != null && buffer.writableBytes() == 0) {
commitToSink(false, false);
}
if (buffer == null) {
// Request a buffer allocation using the message length as a hint.
buffer = bufferAllocator.allocate(len);
}
int toWrite = min(len, buffer.writableBytes());
buffer.write(b, off, toWrite);
WritableBuffer buf = getWritableBuffer(len);
int toWrite = min(len, buf.writableBytes());
buf.write(b, off, toWrite);
off += toWrite;
len -= toWrite;
}
Expand All @@ -303,6 +299,17 @@ public void flush() {
}
}

WritableBuffer getWritableBuffer(int len) {
if (buffer != null && buffer.writableBytes() == 0) {
commitToSink(false, false);
}
if (buffer == null) {
// Request a buffer allocation using the message length as a hint.
buffer = bufferAllocator.allocate(len);
}
return buffer;
}

/**
* Indicates whether or not this framer has been closed via a call to either
* {@link #close()} or {@link #dispose()}.
Expand Down Expand Up @@ -360,7 +367,7 @@ private void verifyNotClosed() {
}

/** OutputStream whose write()s are passed to the framer. */
private class OutputStreamAdapter extends OutputStream {
private class OutputStreamAdapter extends OutputStream implements ByteBufferBacked {
/**
* This is slow, don't call it. If you care about write overhead, use a BufferedOutputStream.
* Better yet, you can use your own single byte buffer and call
Expand All @@ -376,13 +383,29 @@ public void write(int b) {
public void write(byte[] b, int off, int len) {
writeRaw(b, off, len);
}

@Override
public ByteBuffer getWritableBuffer(int size) {
if (size > 0) {
WritableBuffer buf = MessageFramer.this.getWritableBuffer(size);
if (buf instanceof ByteBufferBacked) {
return ((ByteBufferBacked) buf).getWritableBuffer(size);
}
}
return null;
}

@Override
public void bufferBytesWritten(int size) {
MessageFramer.bufferBytesWritten(buffer, size);
}
}

/**
* Produce a collection of {@link WritableBuffer} instances from the data written to an
* {@link OutputStream}.
*/
private final class BufferChainOutputStream extends OutputStream {
private final class BufferChainOutputStream extends OutputStream implements ByteBufferBacked {
private final List<WritableBuffer> bufferList = new ArrayList<>();
private WritableBuffer current;

Expand All @@ -403,7 +426,7 @@ public void write(int b) throws IOException {

@Override
public void write(byte[] b, int off, int len) {
if (current == null) {
if (current == null && len > 0) {
// Request len bytes initially from the allocator, it may give us more.
current = bufferAllocator.allocate(len);
bufferList.add(current);
Expand Down Expand Up @@ -431,5 +454,35 @@ private int readableBytes() {
}
return readable;
}

@Override
public ByteBuffer getWritableBuffer(int size) {
if (size > 0) {
if (current == null || current.writableBytes() == 0) {
bufferList.add(current = bufferAllocator.allocate(size));
}
if (current instanceof ByteBufferBacked) {
return ((ByteBufferBacked) current).getWritableBuffer(size);
}
}
return null;
}

@Override
public void bufferBytesWritten(int size) {
MessageFramer.bufferBytesWritten(current, size);
}
}

static void bufferBytesWritten(WritableBuffer buffer, int size) {
try {
((ByteBufferBacked) buffer).bufferBytesWritten(size);
return;
} catch (ClassCastException cce) {
// fall-through
} catch (NullPointerException npe) {
// fall-through
}
throw new IllegalStateException();
}
}
25 changes: 24 additions & 1 deletion netty/src/main/java/io/grpc/netty/NettyWritableBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@

package io.grpc.netty;

import io.grpc.ByteBufferBacked;
import io.grpc.internal.WritableBuffer;
import io.netty.buffer.ByteBuf;

import java.nio.ByteBuffer;

/**
* The {@link WritableBuffer} used by the Netty transport.
*/
class NettyWritableBuffer implements WritableBuffer {
class NettyWritableBuffer implements WritableBuffer, ByteBufferBacked {

private final ByteBuf bytebuf;

Expand Down Expand Up @@ -50,6 +53,26 @@ public int readableBytes() {
return bytebuf.readableBytes();
}

@Override
public ByteBuffer getWritableBuffer(int size) {
if (bytebuf.writableBytes() >= size && size > 0) {
try {
return bytebuf.internalNioBuffer(bytebuf.writerIndex(), size);
} catch (UnsupportedOperationException uoe) {
// fall-through
}
}
return null;
}

@Override
public void bufferBytesWritten(int size) {
if (size < 0 || size > bytebuf.writableBytes()) {
throw new IllegalStateException();
}
bytebuf.writerIndex(bytebuf.writerIndex() + size);
}

@Override
public void release() {
bytebuf.release();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.MessageLite;
import com.google.protobuf.Parser;
import io.grpc.ByteBufferBacked;
import io.grpc.Drainable;
import io.grpc.KnownLength;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import javax.annotation.Nullable;

/**
Expand All @@ -49,7 +51,16 @@ public int drainTo(OutputStream target) throws IOException {
int written;
if (message != null) {
written = message.getSerializedSize();
message.writeTo(target);
ByteBuffer buffer;
if (target instanceof ByteBufferBacked
&& (buffer = ((ByteBufferBacked) target).getWritableBuffer(written)) != null) {
CodedOutputStream coded = CodedOutputStream.newInstance(buffer);
message.writeTo(coded);
coded.flush();
((ByteBufferBacked) target).bufferBytesWritten(written);
} else {
message.writeTo(target);
}
message = null;
} else if (partial != null) {
written = (int) ProtoLiteUtils.copy(partial, target);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ public T parse(InputStream stream) {
if (protoStream.parser() == parser) {
try {
@SuppressWarnings("unchecked")
T message = (T) ((ProtoInputStream) stream).message();
T message = (T) protoStream.message();
return message;
} catch (IllegalStateException ignored) {
// Stream must have been read from, which is a strange state. Since the point of this
Expand Down