72 changes: 53 additions & 19 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
@@ -77,6 +77,7 @@ export class WebGPUBackend extends KernelBackend {
tensorMap: DataStorage<TensorBufferInfo>;
fromPixelProgram: FromPixelsProgram;
supportTimeQuery: boolean;
computePass: GPUComputePassEncoder;

private static nextDataId = 0;
private nextDataId(): number {
@@ -95,6 +96,7 @@ export class WebGPUBackend extends KernelBackend {
private activeTimers: TimerNode[];
private uploadWaitMs = 0;
private downloadWaitMs = 0;
private dispatchNumberInEncoder = 0;
private cpuBackend: KernelBackend;
private querySet: GPUQuerySet;

@@ -259,6 +261,11 @@ export class WebGPUBackend extends KernelBackend {
// Data is on the CPU.
return info.values;
}
if (this.dispatchNumberInEncoder !== 0) {
this.dispatchNumberInEncoder = 0;
this.computePass.endPass();
this.submitQueue();
Contributor review comment: Could this submitQueue() be removed? That would give the same behavior as the old delayed mode when WEBGPU_BATCH_DISPATCHING_CALLS is set to a negative value.
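For reference, this is the flush-on-threshold check the comment refers to, copied from runWebGPUProgram below; the annotations are an editorial reading of the PR, not part of it:

```ts
// dispatchNumberInEncoder has already been incremented (it is >= 1 here),
// so a negative flag value can never satisfy this equality. With a negative
// value, nothing is ever submitted from the dispatch path; work only reaches
// the queue via the read path above, which is effectively the old delayed
// mode, apart from the submitQueue() the reviewer is asking about.
if (env().get('WEBGPU_BATCH_DISPATCHING_CALLS') as number ===
    this.dispatchNumberInEncoder) {
  this.dispatchNumberInEncoder = 0;
  this.computePass.endPass();
  this.submitQueue();
}
```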

}
const staging = this.acquireBuffer(
info.bufferInfo.byteSize,
GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ);
@@ -474,26 +481,42 @@ export class WebGPUBackend extends KernelBackend {
const bg = webgpu_program.makeBindGroup(
this.device, bindGroupLayout, inputs.map(t => this.tensorToBinding(t)),
this.tensorToBinding(output), uniforms);

const encoder = this.device.createCommandEncoder();
const pass = encoder.beginComputePass();
if (shouldTimeProgram) {
if (this.supportTimeQuery) {
pass.writeTimestamp(this.querySet, 0);
if (this.dispatchNumberInEncoder === 0) {
const encoder = this.device.createCommandEncoder();
this.computePass = encoder.beginComputePass();
if (shouldTimeProgram) {
if (this.supportTimeQuery) {
this.computePass.writeTimestamp(this.querySet, 0);
}
}
}
pass.setPipeline(pipeline);
pass.setBindGroup(0, bg);
pass.dispatch(
program.dispatch[0], program.dispatch[1], program.dispatch[2]);
if (shouldTimeProgram) {
if (this.supportTimeQuery) {
pass.writeTimestamp(this.querySet, 1);
this.computePass.setPipeline(pipeline);
this.computePass.setBindGroup(0, bg);
this.computePass.dispatch(
program.dispatch[0], program.dispatch[1], program.dispatch[2]);
if (shouldTimeProgram) {
if (this.supportTimeQuery) {
this.computePass.writeTimestamp(this.querySet, 1);
}
}
this.dispatchNumberInEncoder++;
this.commandQueue.push(encoder);
} else {
if (shouldTimeProgram) {
if (this.supportTimeQuery) {
this.computePass.writeTimestamp(this.querySet, 0);
}
}
this.computePass.setPipeline(pipeline);
this.computePass.setBindGroup(0, bg);
this.computePass.dispatch(
program.dispatch[0], program.dispatch[1], program.dispatch[2]);
if (shouldTimeProgram) {
if (this.supportTimeQuery) {
this.computePass.writeTimestamp(this.querySet, 1);
}
}
this.dispatchNumberInEncoder++;
}
pass.endPass();

this.commandQueue.push(encoder);

inputs.forEach(input => {
this.commandQueueOwnedIds.add(input.dataId);
@@ -509,7 +532,10 @@
this.uniformDisposalQueue.push(uniformInfo);
}

if (env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED')) {
if (env().get('WEBGPU_BATCH_DISPATCHING_CALLS') as number ===
this.dispatchNumberInEncoder) {
this.dispatchNumberInEncoder = 0;
this.computePass.endPass();
this.submitQueue();
}

@@ -525,6 +551,12 @@
}

async getTimeFromQuerySet(querySet: GPUQuerySet) {
if (this.dispatchNumberInEncoder !== 0) {
this.dispatchNumberInEncoder = 0;
this.computePass.endPass();
this.submitQueue();
}

const queryBuffer = this.acquireBuffer(
16, GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE);
const dst = this.acquireBuffer(
@@ -597,7 +629,9 @@
}

numDataIds() {
return this.tensorMap.numDataIds();
return this.tensorMap.numDataIds() +
(this.cpuBackend ? this.cpuBackend.numDataIds() : 0) -
this.tensorDisposalQueue.length;
}

dispose() {
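Taken together, the backend changes above reduce to a small state machine around dispatchNumberInEncoder. The toy model below is an editorial sketch, not code from this PR: the class name BatchModel is invented, and encoders, compute passes, uniforms, and disposal queues are collapsed into a simple submission counter.

```ts
// Toy model of the batching lifecycle: counts how many queue submissions a
// sequence of dispatches and a read produces for a given batch size.
class BatchModel {
  private dispatchNumberInEncoder = 0;
  submissions = 0;

  constructor(private batchSize: number) {}

  // Models runWebGPUProgram(): record one dispatch, flush when the batch
  // threshold is hit. A negative batchSize never matches the counter, so
  // everything batches until a read, like the old delayed mode.
  dispatch(): void {
    this.dispatchNumberInEncoder++;
    if (this.batchSize === this.dispatchNumberInEncoder) {
      this.flush();
    }
  }

  // Models getBufferData()/getTimeFromQuerySet(): pending work must be
  // submitted before results can be read back.
  read(): void {
    if (this.dispatchNumberInEncoder !== 0) {
      this.flush();
    }
  }

  // Models computePass.endPass() followed by submitQueue().
  private flush(): void {
    this.dispatchNumberInEncoder = 0;
    this.submissions++;
  }
}

const m = new BatchModel(4);
for (let i = 0; i < 10; i++) {
  m.dispatch();
}
m.read();
console.log(m.submissions);  // 3: two full batches of 4, plus the read flush
```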
30 changes: 15 additions & 15 deletions tfjs-backend-webgpu/src/backend_webgpu_test.ts
@@ -35,8 +35,8 @@ describeWebGPU('backend webgpu cpu forwarding turned on', () => {
});

it('should not allocate GPU memory when CPU forwarding', async () => {
const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', true);
const savedFlag = tf.env().get('WEBGPU_BATCH_DISPATCHING_CALLS');
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 1);

const a = tf.tensor2d([2, 4, 6, 8], [2, 2]);
const b = tf.tensor2d([0.5, 0.5, 0.5, 0.5], [2, 2]);
@@ -65,14 +65,14 @@ describeWebGPU('backend webgpu cpu forwarding turned on', () => {

tf.test_util.expectArraysClose(
dData, new Float32Array([9, 12, 15, 19, 26, 33]));
tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', savedFlag);
});
});

describeWebGPU('backend webgpu', () => {
it('should not leak memory in delayed mode', async () => {
const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', false);
const savedFlag = tf.env().get('WEBGPU_BATCH_DISPATCHING_CALLS');
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 4);
const a = tf.tensor2d([2, 4, 6, 8], [2, 2]);
const b = tf.tensor2d([0.5, 0.5, 0.5, 0.5], [2, 2]);

@@ -96,12 +96,12 @@ describeWebGPU('backend webgpu', () => {

tf.test_util.expectArraysClose(
dData, new Float32Array([9, 12, 15, 19, 26, 33]));
tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', savedFlag);
});

it('should not leak memory in immediate mode', async () => {
const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', true);
const savedFlag = tf.env().get('WEBGPU_BATCH_DISPATCHING_CALLS');
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 1);
const a = tf.tensor2d([2, 4, 6, 8], [2, 2]);
const b = tf.tensor2d([0.5, 0.5, 0.5, 0.5], [2, 2]);

@@ -125,12 +125,12 @@

tf.test_util.expectArraysClose(
dData, new Float32Array([9, 12, 15, 19, 26, 33]));
tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', savedFlag);
});

it('should recycle buffers in immediate mode', () => {
const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', true);
const savedFlag = tf.env().get('WEBGPU_BATCH_DISPATCHING_CALLS');
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 1);
const backend = tf.backend() as WebGPUBackend;
const bufferManager = backend.getBufferManager();
bufferManager.reset();
@@ -164,12 +164,12 @@ describeWebGPU('backend webgpu', () => {
const usedBuffersAfterSecondMatMul = bufferManager.getNumUsedBuffers();
expect(freeBuffersAfterSecondMatMul - freeBuffersAfterSecondMul).toEqual(0);
expect(usedBuffersAfterSecondMatMul - usedBuffersAfterSecondMul).toEqual(2);
tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', savedFlag);
});

it('should not recycle buffers in delayed mode', async () => {
const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', false);
const savedFlag = tf.env().get('WEBGPU_BATCH_DISPATCHING_CALLS');
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 4);
const backend = tf.backend() as WebGPUBackend;
const bufferManager = backend.getBufferManager();
bufferManager.reset();
@@ -207,7 +207,7 @@
    // Tests happen within a tidy, so in delayed mode we need to read a
    // tensor at the end of a test to force-flush the disposal queue.
await c3.data();
tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', savedFlag);
});

it('readSync should throw if tensors are on the GPU', async () => {
8 changes: 6 additions & 2 deletions tfjs-backend-webgpu/src/flags_webgpu.ts
@@ -19,8 +19,8 @@ import {env} from '@tensorflow/tfjs-core';

const ENV = env();

/** Whether we submit commands to the device queue immediately. */
ENV.registerFlag('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', () => true);
/**
 * Batch several dispatching calls into one command encoder. A negative value
 * means all dispatching calls are batched into a single command encoder; any
 * positive value is the number of dispatching calls per command encoder.
*/
ENV.registerFlag('WEBGPU_BATCH_DISPATCHING_CALLS', () => 1);

/**
* Whether we forward execution to the CPU backend if tensors are small and
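As an illustration of the regimes the new flag comment describes (the values 1 and 4 appear in the tests in this PR; -1 stands in for any negative value):

```ts
import * as tf from '@tensorflow/tfjs-core';

// Submit after every dispatch (the old immediate mode).
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 1);

// Submit once every 4 dispatches.
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 4);

// Never submit from dispatch: batch everything into one command encoder
// until a data read forces a flush (the old delayed mode).
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', -1);
```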
12 changes: 6 additions & 6 deletions tfjs-backend-webgpu/src/matmul_test.ts
@@ -21,8 +21,8 @@ import {describeWebGPU} from './test_util';

describeWebGPU('matmul', () => {
it('it works in delayed mode.', async () => {
const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', false);
const savedFlag = tf.env().get('WEBGPU_BATCH_DISPATCHING_CALLS');
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 4);
const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
const b = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);

@@ -34,12 +34,12 @@
const dData = await d.data();
test_util.expectArraysClose(
dData, new Float32Array([0, 12, 7.5, 0, 6.5, 66]));
tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', savedFlag);
});

it('it works in immediate mode.', async () => {
const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', true);
const savedFlag = tf.env().get('WEBGPU_BATCH_DISPATCHING_CALLS');
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 1);
const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
const b = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);

@@ -51,7 +51,7 @@
const dData = await d.data();
test_util.expectArraysClose(
dData, new Float32Array([0, 12, 7.5, 0, 6.5, 66]));
tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', savedFlag);
});

// tslint:disable-next-line:max-line-length