
Commit 2a3963f

Merge remote-tracking branch 'upstream/master' into tfjs-typescript-4
2 parents: 13faefd + 42dee16

32 files changed: +1480, -362 lines

e2e/integration_tests/constants.ts

Lines changed: 2 additions & 1 deletion
@@ -37,7 +37,8 @@ export const CONVERT_PREDICT_MODELS = {
     'saved_model_v1', 'saved_model_v2', 'saved_model_v2_with_control_flow',
     'saved_model_with_conv2d', 'saved_model_with_prelu',
     'saved_model_v2_complex64', 'saved_model_v2_with_control_flow_v2',
-    'saved_model_v2_with_tensorlist_ops', 'saved_model_v1_with_hashtable'
+    'saved_model_v2_with_tensorlist_ops', 'saved_model_v1_with_hashtable',
+    'saved_model_v2_with_hashtable'
   ],
   layers_model: ['mobilenet']
 };

e2e/integration_tests/convert_predict.py

Lines changed: 43 additions & 0 deletions
@@ -427,6 +427,47 @@ def _create_saved_model_v1_with_hashtable(save_dir):
       }
   }
 
+def _create_saved_model_v2_with_hashtable(save_dir):
+  """Test a TF V2 model with HashTable Ops.
+
+  Args:
+    save_dir: directory name of where the saved model will be stored.
+  """
+  class Table(tf.Module):
+    def __init__(self):
+      super(Table, self).__init__()
+      keys = tf.constant(['a', 'b'])
+      vals= tf.constant([0, 1])
+      init = tf.lookup.KeyValueTensorInitializer(keys, vals)
+      self.table = tf.lookup.StaticHashTable(init, -1)
+
+    def initializeTable(self):
+      @tf.function
+      def lookup(input):
+        return self.table.lookup(input)
+
+      return lookup
+
+  model = Table()
+  concrete_fn = model.initializeTable().get_concrete_function(
+      input=tf.TensorSpec([None], tf.string))
+
+  tf.saved_model.save(model, save_dir, signatures={"serving_default": concrete_fn})
+
+  return {
+      "async": False,
+      "inputs": {
+          "Placeholder:0": {
+              "value": ["a", "b", "c"], "shape": [3], "dtype": "string"
+          }
+      },
+      "outputs": {
+          "StatefulPartitionedCall/None_Lookup/LookupTableFindV2:0": {
+              "value": [0, 1, -1], "shape": [3], "dtype": "int32"
+          }
+      }
+  }
+
 def _layers_mobilenet():
   model = tf.keras.applications.MobileNetV2()
   model_path = 'mobilenet'

@@ -471,6 +512,8 @@ def main():
       'saved_model_v2_with_tensorlist_ops', control_flow_v2=True)
   _save_and_convert_model(_create_saved_model_v1_with_hashtable,
                           'saved_model_v1_with_hashtable')
+  _save_and_convert_model(_create_saved_model_v2_with_hashtable,
+                          'saved_model_v2_with_hashtable')
 
   _layers_mobilenet()
 if __name__ == '__main__':
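
Note: a minimal sketch (not part of this commit) of how the converted saved_model_v2_with_hashtable fixture could be exercised from TF.js. The model URL is hypothetical and the use of executeAsync is an assumption about how the hash-table lookup must be run; the e2e harness resolves the real artifact paths and tensor names from the fixture dictionary above.

import * as tf from '@tensorflow/tfjs';

// Hypothetical location of the converted artifact; the integration test
// serves the real converted model from its own test data directory.
const MODEL_URL = 'data/saved_model_v2_with_hashtable/model.json';

async function checkHashtableModel() {
  const model = await tf.loadGraphModel(MODEL_URL);
  // Input matches the fixture above: a string tensor of shape [3].
  const input = tf.tensor1d(['a', 'b', 'c'], 'string');
  // executeAsync is used on the assumption that the lookup op cannot be
  // executed synchronously.
  const result = await model.executeAsync({'Placeholder:0': input}) as tf.Tensor;
  // Expected per the fixture: [0, 1, -1] ('a' -> 0, 'b' -> 1, unknown 'c' ->
  // the default value -1).
  console.log(await result.data());
}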

e2e/yarn.lock

Lines changed: 0 additions & 5 deletions
@@ -1012,11 +1012,6 @@
   dependencies:
     detect-browser "*"
 
-"@types/emscripten@~0.0.34":
-  version "0.0.34"
-  resolved "https://registry.yarnpkg.com/@types/emscripten/-/emscripten-0.0.34.tgz#12b4a344274fb102ff2f6c877b37587bc3e46008"
-  integrity sha512-QSb9ojDincskc+uKMI0KXp8e1NALFINCrMlp8VGKGcTSxeEyRTTKyjWw75NYrCZHUsVEEEpr1tYHpbtaC++/sQ==
-
 "@types/jasmine@~3.0.0":
   version "3.0.0"
   resolved "https://registry.yarnpkg.com/@types/jasmine/-/jasmine-3.0.0.tgz#9a6b6755a02fcd6baa088a767557709c79728f98"

package.json

Lines changed: 1 addition & 0 deletions
@@ -74,6 +74,7 @@
     "lint": "tslint -p tsconfig_tslint.json",
     "test": "bazel test //:tests",
     "test-packages-ci": "yarn generate-cloudbuild-for-packages && ./scripts/run-build.sh",
+    "nightly-cloudbuild": "NIGHTLY=true yarn generate-cloudbuild-for-packages && gcloud builds submit . --config=cloudbuild_generated.yml --substitutions=_NIGHTLY=true",
     "generate-cloudbuild-for-packages": "ts-node -s ./scripts/generate_cloudbuild_for_packages.ts",
     "test-generate-cloudbuild": "cd scripts && node --require ts-node/register ../node_modules/jasmine/bin/jasmine.js run generate_cloudbuild_test.ts",
     "test-run-flaky": "jasmine run scripts/run_flaky_test.js",
tfjs-backend-webgpu/src/adapter_info.ts

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+/**
+ * @license
+ * Copyright 2022 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+// TODO: Remove it once webgpu/types is successfully upgraded.
+// https://github.com/tensorflow/tfjs/issues/6869
+export interface GPUAdapterInfo {
+  vendor: string;
+  architecture: string;
+}
+
+export class AdapterInfo {
+  private vendor: string;
+
+  constructor(adapterInfo: GPUAdapterInfo) {
+    if (adapterInfo) {
+      this.vendor = adapterInfo.vendor;
+    }
+  }
+
+  isIntel(): boolean {
+    return this.vendor === 'intel';
+  }
+}
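
Taken together with the base.ts and kernel changes below, the intended flow is: request the adapter info from WebGPU, wrap it in AdapterInfo, and let kernels branch on isIntel(). A minimal standalone sketch under the same assumption as base.ts (requestAdapterInfo() is not yet covered by the bundled @webgpu/types, hence the any-cast); the helper name here is illustrative only.

import {AdapterInfo} from './adapter_info';

async function wantsSequentialAccess(): Promise<boolean> {
  const adapter = await navigator.gpu.requestAdapter();
  if (adapter == null) {
    return false;
  }
  // tslint:disable-next-line:no-any
  const rawInfo = await (adapter as any).requestAdapterInfo();
  const adapterInfo = new AdapterInfo(rawInfo);
  // Per the kernel comments in this commit, sequential access by threads is
  // more friendly for Intel GPUs.
  return adapterInfo.isIntel();
}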

tfjs-backend-webgpu/src/backend_webgpu.ts

Lines changed: 4 additions & 1 deletion
@@ -19,6 +19,7 @@ import './flags_webgpu';
 
 import {backend_util, buffer, DataStorage, DataType, engine, env, GPUData, KernelBackend, Rank, RecursiveArray, ShapeMap, TensorBuffer, TensorInfo, TimingInfo, TypedArray, util} from '@tensorflow/tfjs-core';
 
+import {AdapterInfo, GPUAdapterInfo} from './adapter_info';
 import {BufferManager} from './buffer_manager';
 import {TextureManager} from './texture_manager';
 import * as webgpu_program from './webgpu_program';

@@ -107,6 +108,7 @@ const reshapeDispatch =
 
 export class WebGPUBackend extends KernelBackend {
   bufferManager: BufferManager;
+  adapterInfo: AdapterInfo;
   device: GPUDevice;
   queue: GPUQueue;
   tensorMap: DataStorage<TensorData>;

@@ -135,7 +137,7 @@ export class WebGPUBackend extends KernelBackend {
     return WebGPUBackend.nextDataId++;
   }
 
-  constructor(device: GPUDevice) {
+  constructor(device: GPUDevice, adapterInfo?: GPUAdapterInfo) {
     super();
     if (!webgpu_util.isWebGPUSupported()) {
       throw new Error('WebGPU is not supported on this device');

@@ -146,6 +148,7 @@ export class WebGPUBackend extends KernelBackend {
     this.currentCommandEncoder = null;
     this.currentComputePass = null;
     this.supportTimeQuery = device.features.has('timestamp-query');
+    this.adapterInfo = new AdapterInfo(adapterInfo);
 
     this.bufferManager = new BufferManager(this.device);
     this.textureManager = new TextureManager(this.device);

tfjs-backend-webgpu/src/base.ts

Lines changed: 3 additions & 1 deletion
@@ -50,7 +50,9 @@ if (isWebGPUSupported()) {
       deviceDescriptor.requiredFeatures = ['timestamp-query'];
     }
     const device: GPUDevice = await adapter.requestDevice(deviceDescriptor);
-    return new WebGPUBackend(device);
+    // tslint:disable-next-line:no-any
+    const adapterInfo = await (adapter as any).requestAdapterInfo();
+    return new WebGPUBackend(device, adapterInfo);
   }, 3 /*priority*/);
 }

tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts

Lines changed: 6 additions & 3 deletions
@@ -176,12 +176,13 @@ export class Conv2DMMProgram implements WebGPUProgram {
   tileInner: number;
   innerElementSize: number;
   isVec4?: boolean;
+  private sequentialAccessByThreads: boolean;
 
   constructor(
       convInfo: backend_util.Conv2DInfo, dimAOuter: number, dimBOuter: number,
       dimInner: number, addBias = false,
       activation: backend_util.Activation = null,
-      hasPreluActivationWeights = false) {
+      hasPreluActivationWeights = false, sequentialAccessByThreads = false) {
     this.outputShape = convInfo.outShape;
     this.isChannelsLast = convInfo.dataFormat === 'channelsLast';
     this.isVec4 =

@@ -229,6 +230,7 @@
       }
     }
 
+    this.sequentialAccessByThreads = sequentialAccessByThreads;
     this.addBias = addBias;
     this.activation = activation;
     this.hasPreluActivationWeights = hasPreluActivationWeights;

@@ -244,7 +246,8 @@
 
     this.shaderKey = `conv2DMM_${this.elementsPerThread}_${this.activation}}_${
         this.fitAOuter}_${this.fitBOuter}_${this.fitInner}_${this.isVec4}_${
-        this.innerElementSize}_${this.isChannelsLast}`;
+        this.innerElementSize}_${this.isChannelsLast}_${
+        this.sequentialAccessByThreads}`;
   }
 
   getUserCode(): string {

@@ -254,7 +257,7 @@
             this.tileInner) :
         makeMatMulPackedSource(
             this.elementsPerThread, this.workGroupSize, !this.isChannelsLast,
-            this.tileInner);
+            this.tileInner, false, null, this.sequentialAccessByThreads);
     const elementsSize =
         this.isVec4 ? [this.innerElementSize, 4, 4] : [1, 1, 1];
     const userCode = `

tfjs-backend-webgpu/src/kernels/BatchMatMul_impl.ts

Lines changed: 5 additions & 1 deletion
@@ -184,9 +184,13 @@ export function batchMatMulImpl({
           activation, preluActivationWeights);
       break;
     case MatMulProgramType.MatMulPackedProgram:
+      // Experiments show that sequential access is more friendly for Intel
+      // GPUs.
+      const sequentialAccessByThreads = backend.adapterInfo.isIntel();
       program = new MatMulPackedProgram(
           a3dShape, outputShape, batchAEqualOne, batchBEqualOne, transposeA,
-          transposeB, bias, activation, preluActivationWeights);
+          transposeB, bias, activation, preluActivationWeights,
+          sequentialAccessByThreads);
       break;
     default:
       throw new Error(`Unsupported MatMulProgramType ${matmulProgramType}.`);

tfjs-backend-webgpu/src/kernels/Conv2D_impl.ts

Lines changed: 3 additions & 1 deletion
@@ -229,9 +229,11 @@ export function conv2DImpl({
         {type: 'int32', data: [dimAOuter]}, {type: 'int32', data: [dimBOuter]},
         {type: 'int32', data: [dimInner]});
 
+    // Experiments show that sequential access is more friendly for Intel GPUs.
+    const sequentialAccessByThreads = backend.adapterInfo.isIntel();
     program = new Conv2DMMProgram(
         convInfo, dimAOuter, dimBOuter, dimInner, hasBias, activation,
-        hasPreluActivationWeights);
+        hasPreluActivationWeights, sequentialAccessByThreads);
   }
 
   const intermediates: TensorInfo[] = [];
