Merged · Changes from 2 commits
3 changes: 2 additions & 1 deletion tfjs-backend-webgpu/src/backend_webgpu.ts
@@ -564,7 +564,8 @@ export class WebGPUBackend extends KernelBackend {
this.uniformDisposalQueue.push(uniformInfo);
}

-    if (env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED')) {
+    if (env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE') as
+            number <= this.commandQueue.length) {
this.submitQueue();
}

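One note on the new condition: `as` binds tighter than `<=`, so it parses as `(env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE') as number) <= this.commandQueue.length`, i.e. the backend submits once the pending command count reaches the batch size. Below is a minimal sketch of that deferred-submit pattern; `CommandBuffer`, `DeferredSubmitter`, and `submitFn` are illustrative stand-ins, not the real tfjs internals.

// A minimal sketch of the deferred-submit pattern (illustrative only;
// `CommandBuffer` and `submitFn` stand in for the real tfjs internals).
type CommandBuffer = unknown;

class DeferredSubmitter {
  private commandQueue: CommandBuffer[] = [];

  constructor(
      private submitFn: (commands: CommandBuffer[]) => void,
      private batchSize: number) {}

  enqueue(commands: CommandBuffer) {
    this.commandQueue.push(commands);
    // batchSize === 1 reproduces the old immediate mode; larger values
    // defer device-queue submission until enough work accumulates.
    if (this.batchSize <= this.commandQueue.length) {
      this.submitQueue();
    }
  }

  submitQueue() {
    this.submitFn(this.commandQueue);
    this.commandQueue = [];
  }
}

With a batch size of 1 this degenerates to the old immediate mode, which is exactly how the updated tests below exercise both paths.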
28 changes: 12 additions & 16 deletions tfjs-backend-webgpu/src/backend_webgpu_test.ts
@@ -35,9 +35,6 @@ describeWebGPU('backend webgpu cpu forwarding turned on', () => {
});

it('should not allocate GPU memory when CPU forwarding', async () => {
-    const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', true);
-
const a = tf.tensor2d([2, 4, 6, 8], [2, 2]);
const b = tf.tensor2d([0.5, 0.5, 0.5, 0.5], [2, 2]);

@@ -65,14 +62,13 @@ describeWebGPU('backend webgpu cpu forwarding turned on', () => {

tf.test_util.expectArraysClose(
dData, new Float32Array([9, 12, 15, 19, 26, 33]));
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
});
});

describeWebGPU('backend webgpu', () => {
it('should not leak memory in delayed mode', async () => {
-    const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', false);
+    const savedFlag = tf.env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE');
+    tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', 15);
const a = tf.tensor2d([2, 4, 6, 8], [2, 2]);
const b = tf.tensor2d([0.5, 0.5, 0.5, 0.5], [2, 2]);

@@ -96,12 +92,12 @@ describeWebGPU('backend webgpu', () => {

tf.test_util.expectArraysClose(
dData, new Float32Array([9, 12, 15, 19, 26, 33]));
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
+    tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', savedFlag);
});

it('should not leak memory in immediate mode', async () => {
-    const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', true);
+    const savedFlag = tf.env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE');
+    tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', 1);
const a = tf.tensor2d([2, 4, 6, 8], [2, 2]);
const b = tf.tensor2d([0.5, 0.5, 0.5, 0.5], [2, 2]);

@@ -125,12 +121,12 @@ describeWebGPU('backend webgpu', () => {

tf.test_util.expectArraysClose(
dData, new Float32Array([9, 12, 15, 19, 26, 33]));
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
+    tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', savedFlag);
});

it('should recycle buffers in immediate mode', () => {
-    const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', true);
+    const savedFlag = tf.env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE');
+    tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', 1);
const backend = tf.backend() as WebGPUBackend;
const bufferManager = backend.getBufferManager();
bufferManager.reset();
@@ -164,12 +160,12 @@ describeWebGPU('backend webgpu', () => {
const usedBuffersAfterSecondMatMul = bufferManager.getNumUsedBuffers();
expect(freeBuffersAfterSecondMatMul - freeBuffersAfterSecondMul).toEqual(0);
expect(usedBuffersAfterSecondMatMul - usedBuffersAfterSecondMul).toEqual(2);
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
+    tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', savedFlag);
});

it('should not recycle buffers in delayed mode', async () => {
-    const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', false);
+    const savedFlag = tf.env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE');
+    tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', 15);
const backend = tf.backend() as WebGPUBackend;
const bufferManager = backend.getBufferManager();
bufferManager.reset();
@@ -207,7 +203,7 @@ describeWebGPU('backend webgpu', () => {
// Tests happen within a tidy so we need to read a tensor at the end of a
// test in delayed mode in order to force flush the disposal queue.
await c3.data();
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
+    tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', savedFlag);
});

it('readSync should throw if tensors are on the GPU', async () => {
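Each test above repeats the same save/set/restore sequence around the flag. As a sketch only (this helper is hypothetical, not part of the PR), the pattern could be factored out like this:

import * as tf from '@tensorflow/tfjs-core';

// Hypothetical helper: run `fn` under a temporary
// WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE, then restore the saved value,
// mirroring the pattern the tests use inline.
async function withBatchSize<T>(
    size: number, fn: () => Promise<T>): Promise<T> {
  const saved = tf.env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE');
  tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', size);
  try {
    return await fn();
  } finally {
    tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', saved);
  }
}

// Usage: 1 forces immediate submission, 15 exercises delayed mode, e.g.
// await withBatchSize(15, async () => {/* run ops, read data */});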
4 changes: 2 additions & 2 deletions tfjs-backend-webgpu/src/flags_webgpu.ts
@@ -19,8 +19,8 @@ import {env} from '@tensorflow/tfjs-core';

const ENV = env();

-/** Whether we submit commands to the device queue immediately. */
-ENV.registerFlag('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', () => true);
+/** Number of command encoders to batch before a device-queue submit. */
+ENV.registerFlag('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', () => 15);

/**
* Whether we forward execution to the CPU backend if tensors are small and
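In effect the numeric flag subsumes the old boolean: the default of 15 defers submission until fifteen command encoders are queued, while 1 restores immediate execution. A hedged sketch of reading it from user code:

import {env} from '@tensorflow/tfjs-core';

// env().get() is typed loosely, hence the cast the backend also uses.
const batchSize = env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE') as number;
const immediateMode = batchSize === 1;  // submit after every program run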
12 changes: 6 additions & 6 deletions tfjs-backend-webgpu/src/matmul_test.ts
@@ -21,8 +21,8 @@ import {describeWebGPU} from './test_util';

describeWebGPU('matmul', () => {
it('it works in delayed mode.', async () => {
-    const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', false);
+    const savedFlag = tf.env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE');
+    tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', 15);
const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
const b = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);

@@ -34,12 +34,12 @@ describeWebGPU('matmul', () => {
const dData = await d.data();
test_util.expectArraysClose(
dData, new Float32Array([0, 12, 7.5, 0, 6.5, 66]));
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
+    tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', savedFlag);
});

it('it works in immediate mode.', async () => {
-    const savedFlag = tf.env().get('WEBGPU_IMMEDIATE_EXECUTION_ENABLED');
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', true);
+    const savedFlag = tf.env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE');
+    tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', 1);
const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
const b = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);

@@ -51,7 +51,7 @@ describeWebGPU('matmul', () => {
const dData = await d.data();
test_util.expectArraysClose(
dData, new Float32Array([0, 12, 7.5, 0, 6.5, 66]));
-    tf.env().set('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', savedFlag);
+    tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', savedFlag);
});

// tslint:disable-next-line:max-line-length