Skip to content

Commit 4d236f8

Browse files
committed
[webgpu] Add unpack
FEATURE
1 parent f21938f commit 4d236f8

File tree

4 files changed

+88
-2
lines changed

4 files changed

+88
-2
lines changed

tfjs-backend-webgpu/src/kernels/Pack.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ export function pack(
5454

5555
const result = concat({inputs: expandedTensors, backend, attrs: {axis}});
5656

57-
intermediateTensorInfos.forEach(t => backend.disposeData(t));
57+
intermediateTensorInfos.forEach(t => backend.disposeData(t.dataId));
5858

5959
return result;
6060
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/**
2+
* @license
3+
* Copyright 2021 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {KernelConfig, KernelFunc, TensorInfo, Unpack, UnpackAttrs, UnpackInputs} from '@tensorflow/tfjs-core';
19+
20+
import {WebGPUBackend} from '../backend_webgpu';
21+
22+
import {reshape} from './Reshape';
23+
import {slice} from './Slice';
24+
25+
export function unpack(
26+
args:
27+
{inputs: UnpackInputs, backend: WebGPUBackend, attrs: UnpackAttrs}):
28+
TensorInfo[] {
29+
const {inputs, backend, attrs} = args;
30+
const {value} = inputs;
31+
let {axis} = attrs;
32+
33+
if (axis < 0) {
34+
axis += value.shape.length;
35+
}
36+
37+
const x = value;
38+
const xRank = x.shape.length;
39+
40+
const num = value.shape[axis];
41+
const outShape: number[] = new Array(xRank - 1);
42+
let outIndex = 0;
43+
for (let i = 0; i < xRank; i++) {
44+
if (i !== axis) {
45+
outShape[outIndex++] = x.shape[i];
46+
}
47+
}
48+
49+
const toDispose = [];
50+
51+
const begin = new Array(xRank).fill(0);
52+
const size = x.shape.slice();
53+
size[axis] = 1;
54+
const res: TensorInfo[] = new Array(num);
55+
for (let i = 0; i < res.length; i++) {
56+
begin[axis] = i;
57+
const sliced = slice({inputs: {x}, backend, attrs: {begin, size}});
58+
const reshaped =
59+
reshape({inputs: {x: sliced}, backend, attrs: {shape: outShape}});
60+
res[i] = reshaped;
61+
62+
toDispose.push(sliced);
63+
}
64+
65+
toDispose.forEach(t => backend.disposeData(t.dataId));
66+
return res;
67+
}
68+
69+
export const unpackConfig: KernelConfig = {
70+
kernelName: Unpack,
71+
backendName: 'webgpu',
72+
kernelFunc: unpack as {} as KernelFunc
73+
};

tfjs-backend-webgpu/src/register_all_kernels.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ import {subConfig} from './kernels/Sub';
8484
import {sumConfig} from './kernels/Sum';
8585
import {tanhConfig} from './kernels/Tanh';
8686
import {transposeConfig} from './kernels/Transpose';
87+
import {unpackConfig} from './kernels/Unpack';
8788
import {zerosLikeConfig} from './kernels/ZerosLike';
8889

8990
// List all kernel configs here
@@ -156,6 +157,7 @@ const kernelConfigs: KernelConfig[] = [
156157
sumConfig,
157158
tanhConfig,
158159
transposeConfig,
160+
unpackConfig,
159161
zerosLikeConfig
160162
];
161163

tfjs-backend-webgpu/src/setup_test.ts

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,18 @@ const TEST_FILTERS: TestFilter[] = [
523523
include: 'stack',
524524
excludes: [
525525
'accepts string',
526-
'unstack',
526+
'grad of unstack axis=0', // Remove this when grad is fixed in unstack.
527+
'gradient with clones', // Remove this when grad is fixed in unstack.
528+
'grad of unstack axis=1', // Remove this when grad is fixed in unstack.
529+
]
530+
},
531+
{
532+
include: 'unstack',
533+
excludes: [
534+
'accepts string',
535+
'grad of unstack axis=0',
536+
'gradient with clones',
537+
'grad of unstack axis=1',
527538
]
528539
},
529540
{

0 commit comments

Comments
 (0)