diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts
index 336edd9f3a6..f62783371b5 100644
--- a/tfjs-backend-webgpu/src/backend_webgpu.ts
+++ b/tfjs-backend-webgpu/src/backend_webgpu.ts
@@ -77,6 +77,7 @@ export class WebGPUBackend extends KernelBackend {
   tensorMap: DataStorage<TensorBufferInfo>;
   fromPixelProgram: FromPixelsProgram;
   supportTimeQuery: boolean;
+  computePass: GPUComputePassEncoder;
 
   private static nextDataId = 0;
   private nextDataId(): number {
@@ -95,6 +96,7 @@ export class WebGPUBackend extends KernelBackend {
   private activeTimers: TimerNode[];
   private uploadWaitMs = 0;
   private downloadWaitMs = 0;
+  private dispatchNumberInEncoder = 0;
   private cpuBackend: KernelBackend;
   private querySet: GPUQuerySet;
 
@@ -259,6 +261,11 @@ export class WebGPUBackend extends KernelBackend {
       // Data is on the CPU.
       return info.values;
     }
+    if (this.dispatchNumberInEncoder !== 0) {
+      this.dispatchNumberInEncoder = 0;
+      this.computePass.endPass();
+      this.submitQueue();
+    }
     const staging = this.acquireBuffer(
         info.bufferInfo.byteSize,
         GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ);
@@ -474,26 +481,42 @@ export class WebGPUBackend extends KernelBackend {
     const bg = webgpu_program.makeBindGroup(
         this.device, bindGroupLayout, inputs.map(t => this.tensorToBinding(t)),
         this.tensorToBinding(output), uniforms);
-
-    const encoder = this.device.createCommandEncoder();
-    const pass = encoder.beginComputePass();
-    if (shouldTimeProgram) {
-      if (this.supportTimeQuery) {
-        pass.writeTimestamp(this.querySet, 0);
+    if (this.dispatchNumberInEncoder === 0) {
+      const encoder = this.device.createCommandEncoder();
+      this.computePass = encoder.beginComputePass();
+      if (shouldTimeProgram) {
+        if (this.supportTimeQuery) {
+          this.computePass.writeTimestamp(this.querySet, 0);
+        }
       }
-    }
-    pass.setPipeline(pipeline);
-    pass.setBindGroup(0, bg);
-    pass.dispatch(
-        program.dispatch[0], program.dispatch[1], program.dispatch[2]);
-    if (shouldTimeProgram) {
-      if (this.supportTimeQuery) {
-        pass.writeTimestamp(this.querySet, 1);
+      this.computePass.setPipeline(pipeline);
+      this.computePass.setBindGroup(0, bg);
+      this.computePass.dispatch(
+          program.dispatch[0], program.dispatch[1], program.dispatch[2]);
+      if (shouldTimeProgram) {
+        if (this.supportTimeQuery) {
+          this.computePass.writeTimestamp(this.querySet, 1);
+        }
       }
+      this.dispatchNumberInEncoder++;
+      this.commandQueue.push(encoder);
+    } else {
+      if (shouldTimeProgram) {
+        if (this.supportTimeQuery) {
+          this.computePass.writeTimestamp(this.querySet, 0);
+        }
+      }
+      this.computePass.setPipeline(pipeline);
+      this.computePass.setBindGroup(0, bg);
+      this.computePass.dispatch(
+          program.dispatch[0], program.dispatch[1], program.dispatch[2]);
+      if (shouldTimeProgram) {
+        if (this.supportTimeQuery) {
+          this.computePass.writeTimestamp(this.querySet, 1);
+        }
+      }
+      this.dispatchNumberInEncoder++;
     }
-    pass.endPass();
-
-    this.commandQueue.push(encoder);
 
     inputs.forEach(input => {
       this.commandQueueOwnedIds.add(input.dataId);
@@ -509,7 +532,10 @@ export class WebGPUBackend extends KernelBackend {
       this.uniformDisposalQueue.push(uniformInfo);
     }
 
-    if (env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED')) {
+    if (env().get('WEBGPU_BATCH_DISPATCHING_CALLS') as number ===
+        this.dispatchNumberInEncoder) {
+      this.dispatchNumberInEncoder = 0;
+      this.computePass.endPass();
       this.submitQueue();
     }
 
@@ -525,6 +551,12 @@ export class WebGPUBackend extends KernelBackend {
   }
 
   async getTimeFromQuerySet(querySet: GPUQuerySet) {
+    if (this.dispatchNumberInEncoder !== 0) {
+      this.dispatchNumberInEncoder = 0;
+      this.computePass.endPass();
+      this.submitQueue();
+    }
+
     const queryBuffer = this.acquireBuffer(
         16, GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE);
     const dst = this.acquireBuffer(
@@ -597,7 +629,9 @@ export class WebGPUBackend extends KernelBackend {
   }
 
   numDataIds() {
-    return this.tensorMap.numDataIds();
+    return this.tensorMap.numDataIds() +
+        (this.cpuBackend ? this.cpuBackend.numDataIds() : 0) -
+        this.tensorDisposalQueue.length;
   }
 
   dispose() {
diff --git a/tfjs-backend-webgpu/src/backend_webgpu_test.ts b/tfjs-backend-webgpu/src/backend_webgpu_test.ts
index ece7f772a04..433a111ce49 100644
--- a/tfjs-backend-webgpu/src/backend_webgpu_test.ts
+++ b/tfjs-backend-webgpu/src/backend_webgpu_test.ts
@@ -35,8 +35,8 @@ describeWebGPU('backend webgpu cpu forwarding turned on', () => {
   });
 
   it('should not allocate GPU memory when CPU forwarding', async () => {
-    const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', true);
+    const savedFlag = tf.env().get('WEBGPU_BATCH_DISPATCHING_CALLS');
+    tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 1);
 
     const a = tf.tensor2d([2, 4, 6, 8], [2, 2]);
     const b = tf.tensor2d([0.5, 0.5, 0.5, 0.5], [2, 2]);
@@ -65,14 +65,14 @@ describeWebGPU('backend webgpu cpu forwarding turned on', () => {
     tf.test_util.expectArraysClose(
         dData, new Float32Array([9, 12, 15, 19, 26, 33]));
 
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
+    tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', savedFlag);
   });
 });
 
 describeWebGPU('backend webgpu', () => {
   it('should not leak memory in delayed mode', async () => {
-    const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', false);
+    const savedFlag = tf.env().get('WEBGPU_BATCH_DISPATCHING_CALLS');
+    tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 4);
 
     const a = tf.tensor2d([2, 4, 6, 8], [2, 2]);
     const b = tf.tensor2d([0.5, 0.5, 0.5, 0.5], [2, 2]);
@@ -96,12 +96,12 @@ describeWebGPU('backend webgpu', () => {
     tf.test_util.expectArraysClose(
         dData, new Float32Array([9, 12, 15, 19, 26, 33]));
 
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
+    tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', savedFlag);
   });
 
   it('should not leak memory in immediate mode', async () => {
-    const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', true);
+    const savedFlag = tf.env().get('WEBGPU_BATCH_DISPATCHING_CALLS');
+    tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 1);
 
     const a = tf.tensor2d([2, 4, 6, 8], [2, 2]);
     const b = tf.tensor2d([0.5, 0.5, 0.5, 0.5], [2, 2]);
@@ -125,12 +125,12 @@ describeWebGPU('backend webgpu', () => {
     tf.test_util.expectArraysClose(
         dData, new Float32Array([9, 12, 15, 19, 26, 33]));
 
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
+    tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', savedFlag);
   });
 
   it('should recycle buffers in immediate mode', () => {
-    const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', true);
+    const savedFlag = tf.env().get('WEBGPU_BATCH_DISPATCHING_CALLS');
+    tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 1);
     const backend = tf.backend() as WebGPUBackend;
     const bufferManager = backend.getBufferManager();
     bufferManager.reset();
@@ -164,12 +164,12 @@ describeWebGPU('backend webgpu', () => {
     const usedBuffersAfterSecondMatMul = bufferManager.getNumUsedBuffers();
     expect(freeBuffersAfterSecondMatMul - freeBuffersAfterSecondMul).toEqual(0);
     expect(usedBuffersAfterSecondMatMul - usedBuffersAfterSecondMul).toEqual(2);
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
+    tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', savedFlag);
   });
 
   it('should not recycle buffers in delayed mode', async () => {
-    const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', false);
+    const savedFlag = tf.env().get('WEBGPU_BATCH_DISPATCHING_CALLS');
+    tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 4);
     const backend = tf.backend() as WebGPUBackend;
     const bufferManager = backend.getBufferManager();
     bufferManager.reset();
@@ -207,7 +207,7 @@ describeWebGPU('backend webgpu', () => {
     // Tests happen within a tidy so we need to read a tensor at the end of a
     // test in delayed mode in order to force flush the disposal queue.
     await c3.data();
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
+    tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', savedFlag);
   });
 
   it('readSync should throw if tensors are on the GPU', async () => {
diff --git a/tfjs-backend-webgpu/src/flags_webgpu.ts b/tfjs-backend-webgpu/src/flags_webgpu.ts
index 697877d44ff..a9da5a3b7ee 100644
--- a/tfjs-backend-webgpu/src/flags_webgpu.ts
+++ b/tfjs-backend-webgpu/src/flags_webgpu.ts
@@ -19,8 +19,12 @@ import {env} from '@tensorflow/tfjs-core';
 
 const ENV = env();
 
-/** Whether we submit commands to the device queue immediately. */
-ENV.registerFlag('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', () => true);
+/**
+ * The number of dispatch calls to batch into one command encoder. A negative
+ * value batches all dispatch calls into a single command encoder; any other
+ * value sets how many dispatches are recorded per command encoder.
+ */
+ENV.registerFlag('WEBGPU_BATCH_DISPATCHING_CALLS', () => 1);
 
 /**
  * Whether we forward execution to the CPU backend if tensors are small and
diff --git a/tfjs-backend-webgpu/src/matmul_test.ts b/tfjs-backend-webgpu/src/matmul_test.ts
index 129d4d20540..e9527638af4 100644
--- a/tfjs-backend-webgpu/src/matmul_test.ts
+++ b/tfjs-backend-webgpu/src/matmul_test.ts
@@ -21,8 +21,8 @@ import {describeWebGPU} from './test_util';
 
 describeWebGPU('matmul', () => {
   it('it works in delayed mode.', async () => {
-    const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', false);
+    const savedFlag = tf.env().get('WEBGPU_BATCH_DISPATCHING_CALLS');
+    tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 4);
 
     const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
     const b = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
@@ -34,12 +34,12 @@ describeWebGPU('matmul', () => {
     const dData = await d.data();
     test_util.expectArraysClose(
         dData, new Float32Array([0, 12, 7.5, 0, 6.5, 66]));
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
+    tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', savedFlag);
   });
 
   it('it works in immediate mode.', async () => {
-    const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', true);
+    const savedFlag = tf.env().get('WEBGPU_BATCH_DISPATCHING_CALLS');
+    tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 1);
 
     const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
     const b = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
@@ -51,7 +51,7 @@ describeWebGPU('matmul', () => {
     const dData = await d.data();
     test_util.expectArraysClose(
         dData, new Float32Array([0, 12, 7.5, 0, 6.5, 66]));
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
+    tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', savedFlag);
  });
 
   // tslint:disable-next-line:max-line-length
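
Note (reviewer sketch, not part of the patch): a minimal TypeScript usage example of the new flag under the semantics this diff implements; the op sequence and tensor values are illustrative only, and it assumes importing '@tensorflow/tfjs-backend-webgpu' registers the 'webgpu' backend. Setting WEBGPU_BATCH_DISPATCHING_CALLS to N submits the shared command encoder once N dispatches have been recorded (1 reproduces the old immediate mode), and any pending compute pass is ended and flushed whenever data is read back.

import '@tensorflow/tfjs-backend-webgpu';
import * as tf from '@tensorflow/tfjs-core';

async function main() {
  await tf.setBackend('webgpu');

  // Record up to 4 dispatch calls into one command encoder before
  // submitting to the device queue; 1 restores per-dispatch submission.
  const savedFlag = tf.env().get('WEBGPU_BATCH_DISPATCHING_CALLS');
  tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', 4);

  const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
  const b = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);

  // Each kernel records its dispatch into the shared compute pass;
  // nothing is submitted until the batch count is reached...
  const c = tf.matMul(a, b);
  const d = tf.relu(c);

  // ...or until a read ends the pass and flushes the queue (the
  // dispatchNumberInEncoder !== 0 path added in this change).
  console.log(await d.data());

  tf.env().set('WEBGPU_BATCH_DISPATCHING_CALLS', savedFlag);
}

main();

Because reads always force a flush, results are never stale; the flag only changes how much work is encoded per queue submission.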