Skip to content
3 changes: 2 additions & 1 deletion tfjs-core/src/gradients/FusedBatchNorm_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ export const fusedBatchNormGradConfig: GradConfig = {
return reshape(
mul(mul(dy,
tile(
reshape(oneOverSqrtVariance, [1, 1, 1, mean.shape[0]]),
reshape(oneOverSqrtVariance,
[1, 1, 1, 1, mean.shape[0]]),
tileShape)),
scaleValue),
x.shape);
Expand Down
8 changes: 4 additions & 4 deletions tfjs-core/src/ops/batchnorm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
import {ENGINE} from '../engine';
import {FusedBatchNorm, FusedBatchNormAttrs, FusedBatchNormInputs} from '../kernel_names';
import {NamedAttrMap} from '../kernel_registry';
import {Tensor, Tensor1D, Tensor4D} from '../tensor';
import {Tensor, Tensor1D, Tensor5D} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {convertToTensor} from '../tensor_util_env';
import {Rank, TensorLike} from '../types';
import * as util from '../util';

import {xAs4D} from './batchnorm_util';
import {xAs5D} from './batchnorm_util';
import {op} from './operation';
import {reshape} from './reshape';

Expand Down Expand Up @@ -88,10 +88,10 @@ function batchNorm_<R extends Rank>(
() => 'Batch normalization gradient requires mean and scale to have ' +
'equal ranks.');

const x4D: Tensor4D = xAs4D($x);
const x5D: Tensor5D = xAs5D($x);

const inputs: FusedBatchNormInputs = {
x: x4D,
x: x5D,
scale: $scale,
offset: $offset,
mean: $mean,
Expand Down
79 changes: 79 additions & 0 deletions tfjs-core/src/ops/batchnorm5d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {Tensor1D, Tensor5D} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import * as util from '../util';

import {batchNorm} from './batchnorm';
import {op} from './operation';

/**
 * Batch normalization restricted to rank-5 inputs. For the more relaxed
 * version, see `tf.batchNorm`.
 *
 * The arguments are converted to tensors (the optional ones only when they
 * are provided), their ranks are validated, and the computation itself is
 * delegated to `batchNorm`.
 *
 * @param x The input Tensor; must be rank 5.
 * @param mean A mean Tensor; must be rank 5 or rank 1.
 * @param variance A variance Tensor; must be rank 5 or rank 1.
 * @param offset An offset Tensor; rank 5 or rank 1 when provided.
 * @param scale A scale Tensor; rank 5 or rank 1 when provided.
 * @param varianceEpsilon A small float number to avoid dividing by 0.
 */
function batchNorm5d_(
    x: Tensor5D|TensorLike, mean: Tensor5D|Tensor1D|TensorLike,
    variance: Tensor5D|Tensor1D|TensorLike,
    offset?: Tensor5D|Tensor1D|TensorLike, scale?: Tensor5D|Tensor1D|TensorLike,
    varianceEpsilon?: number): Tensor5D {
  // Shared rank check for mean, variance, scale and offset: each of them may
  // be either rank 5 or rank 1.
  const requireRank5Or1 = (label: string, rank: number): void => {
    util.assert(
        rank === 1 || rank === 5,
        () => `Error in batchNorm5D: ${label} must be rank 5 or rank 1 ` +
            `but got rank ${rank}.`);
  };

  const xTensor = convertToTensor(x, 'x', 'batchNorm');
  const meanTensor = convertToTensor(mean, 'mean', 'batchNorm');
  const varianceTensor = convertToTensor(variance, 'variance', 'batchNorm');

  // Optional inputs are left unset when the caller did not supply them.
  let scaleTensor: Tensor5D|Tensor1D;
  if (scale != null) {
    scaleTensor = convertToTensor(scale, 'scale', 'batchNorm');
  }
  let offsetTensor: Tensor5D|Tensor1D;
  if (offset != null) {
    offsetTensor = convertToTensor(offset, 'offset', 'batchNorm');
  }

  // Validation runs only after every conversion above has succeeded.
  util.assert(
      xTensor.rank === 5,
      () => `Error in batchNorm5D: x must be rank 5 but got rank ` +
          `${xTensor.rank}.`);
  requireRank5Or1('mean', meanTensor.rank);
  requireRank5Or1('variance', varianceTensor.rank);
  if (scaleTensor != null) {
    requireRank5Or1('scale', scaleTensor.rank);
  }
  if (offsetTensor != null) {
    requireRank5Or1('offset', offsetTensor.rank);
  }

  return batchNorm(
      xTensor, meanTensor, varianceTensor, offsetTensor, scaleTensor,
      varianceEpsilon);
}

export const batchNorm5d = op({batchNorm5d_});
218 changes: 218 additions & 0 deletions tfjs-core/src/ops/batchnorm_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,224 @@ import * as tf from '../index';
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
import {expectArraysClose} from '../test_util';

// Tests for `tf.batchNorm5d`. Every case uses a 2x1x1x1x2 input, so the last
// axis (size 2) lines up with the rank-1 mean/variance/offset/scale tensors.
// Expected forward values are computed inline as
//   offset + (x - mean) * scale / sqrt(variance + varianceEpsilon)
// with offset = 0 and scale = 1 when the corresponding argument is omitted.
//
// NOTE(review): the sibling batchNorm4D suite in this file is declared with
// `describeWithFlags(..., ALL_ENVS, ...)`, while this suite uses plain
// `describe` -- confirm whether that restriction is intentional.
describe('batchNorm5D', () => {
  it('simple batchnorm5D, no offset or scale, 2x1x1x1x2', async () => {
    const xT = tf.tensor5d([2, 4, 9, 23], [2, 1, 1, 1, 2]);
    const meanT = tf.tensor1d([1, 2]);
    const varianceT = tf.tensor1d([2, 3]);
    const varianceEpsilon = .001;

    const result = tf.batchNorm5d(
        xT, meanT, varianceT, undefined, undefined, varianceEpsilon);

    const x = await xT.array();
    const mean = await meanT.array();
    const variance = await varianceT.array();
    expectArraysClose(await result.data(), [
      (x[0][0][0][0][0] - mean[0]) * 1 /
          Math.sqrt(variance[0] + varianceEpsilon),
      (x[0][0][0][0][1] - mean[1]) * 1 /
          Math.sqrt(variance[1] + varianceEpsilon),
      (x[1][0][0][0][0] - mean[0]) * 1 /
          Math.sqrt(variance[0] + varianceEpsilon),
      (x[1][0][0][0][1] - mean[1]) * 1 /
          Math.sqrt(variance[1] + varianceEpsilon)
    ]);
  });

  it('simple batchnorm5D, no offset, 2x1x1x1x2', async () => {
    const xT = tf.tensor5d([2, 4, 9, 23], [2, 1, 1, 1, 2]);
    const meanT = tf.tensor1d([1, 2]);
    const varianceT = tf.tensor1d([2, 3]);
    const scaleT = tf.tensor1d([4, 5]);
    const varianceEpsilon = .001;

    const result = tf.batchNorm5d(
        xT, meanT, varianceT, undefined, scaleT, varianceEpsilon);
    const x = await xT.buffer();
    const mean = await meanT.buffer();
    const variance = await varianceT.buffer();
    const scale = await scaleT.buffer();

    expectArraysClose(await result.data(), [
      (x.get(0, 0, 0, 0, 0) - mean.get(0)) * scale.get(0) /
          Math.sqrt(variance.get(0) + varianceEpsilon),
      (x.get(0, 0, 0, 0, 1) - mean.get(1)) * scale.get(1) /
          Math.sqrt(variance.get(1) + varianceEpsilon),
      (x.get(1, 0, 0, 0, 0) - mean.get(0)) * scale.get(0) /
          Math.sqrt(variance.get(0) + varianceEpsilon),
      (x.get(1, 0, 0, 0, 1) - mean.get(1)) * scale.get(1) /
          Math.sqrt(variance.get(1) + varianceEpsilon)
    ]);
  });

  it('simple batchnorm5D, no scale, 2x1x1x1x2', async () => {
    const xT = tf.tensor5d([2, 4, 9, 23], [2, 1, 1, 1, 2]);
    const meanT = tf.tensor1d([1, 2]);
    const varianceT = tf.tensor1d([2, 3]);
    const offsetT = tf.tensor1d([4, 5]);

    const varianceEpsilon = .001;

    const result = tf.batchNorm5d(
        xT, meanT, varianceT, offsetT, undefined, varianceEpsilon);
    const x = await xT.buffer();
    const mean = await meanT.buffer();
    const variance = await varianceT.buffer();
    const offset = await offsetT.buffer();

    expectArraysClose(await result.data(), [
      offset.get(0) +
          (x.get(0, 0, 0, 0, 0) - mean.get(0)) * 1 /
              Math.sqrt(variance.get(0) + varianceEpsilon),
      offset.get(1) +
          (x.get(0, 0, 0, 0, 1) - mean.get(1)) * 1 /
              Math.sqrt(variance.get(1) + varianceEpsilon),
      offset.get(0) +
          (x.get(1, 0, 0, 0, 0) - mean.get(0)) * 1 /
              Math.sqrt(variance.get(0) + varianceEpsilon),
      offset.get(1) +
          (x.get(1, 0, 0, 0, 1) - mean.get(1)) * 1 /
              Math.sqrt(variance.get(1) + varianceEpsilon)
    ]);
  });

  it('simple batchnorm5D, 2x1x1x1x2', async () => {
    const xT = tf.tensor5d([2, 4, 9, 23], [2, 1, 1, 1, 2]);
    const meanT = tf.tensor1d([1, 2]);
    const varianceT = tf.tensor1d([2, 3]);
    const offsetT = tf.tensor1d([3, 4]);
    const scaleT = tf.tensor1d([4, 5]);

    const varianceEpsilon = .001;

    const result =
        tf.batchNorm5d(xT, meanT, varianceT, offsetT, scaleT, varianceEpsilon);
    const x = await xT.buffer();
    const mean = await meanT.buffer();
    const variance = await varianceT.buffer();
    const scale = await scaleT.buffer();
    const offset = await offsetT.buffer();

    expectArraysClose(await result.data(), [
      offset.get(0) +
          (x.get(0, 0, 0, 0, 0) - mean.get(0)) * scale.get(0) /
              Math.sqrt(variance.get(0) + varianceEpsilon),
      offset.get(1) +
          (x.get(0, 0, 0, 0, 1) - mean.get(1)) * scale.get(1) /
              Math.sqrt(variance.get(1) + varianceEpsilon),
      offset.get(0) +
          (x.get(1, 0, 0, 0, 0) - mean.get(0)) * scale.get(0) /
              Math.sqrt(variance.get(0) + varianceEpsilon),
      offset.get(1) +
          (x.get(1, 0, 0, 0, 1) - mean.get(1)) * scale.get(1) /
              Math.sqrt(variance.get(1) + varianceEpsilon)
    ]);
  });

  it('accepts a tensor-like object', async () => {
    const x = [[[[[2, 4]]]], [[[[9, 23]]]]];  // 2x1x1x1x2
    const mean = [1, 2];
    const variance = [2, 3];
    const offset = [3, 4];
    const scale = [4, 5];

    const varianceEpsilon = .001;

    const result =
        tf.batchNorm5d(x, mean, variance, offset, scale, varianceEpsilon);

    expectArraysClose(await result.data(), [
      offset[0] +
          (x[0][0][0][0][0] - mean[0]) * scale[0] /
              Math.sqrt(variance[0] + varianceEpsilon),
      offset[1] +
          (x[0][0][0][0][1] - mean[1]) * scale[1] /
              Math.sqrt(variance[1] + varianceEpsilon),
      offset[0] +
          (x[1][0][0][0][0] - mean[0]) * scale[0] /
              Math.sqrt(variance[0] + varianceEpsilon),
      offset[1] +
          (x[1][0][0][0][1] - mean[1]) * scale[1] /
              Math.sqrt(variance[1] + varianceEpsilon)
    ]);
  });

  it('simple batchnorm5D gradients, 2x1x1x1x2', async () => {
    const x = tf.tensor5d([2, 4, 9, 23], [2, 1, 1, 1, 2]);
    const mean = tf.tensor1d([1, 2]);
    const variance = tf.tensor1d([2, 3]);
    const offset = tf.tensor1d([3, 4]);
    const scale = tf.tensor1d([2, 5]);

    const varianceEpsilon = .001;

    // Upstream gradient, same shape as the batchNorm5d output.
    const dy = tf.tensor5d([-1, -1, -1, -1], [2, 1, 1, 1, 2]);
    const gradX = tf.grad(
        (x: tf.Tensor5D) => tf.batchNorm5d(
            x, mean, variance, offset, scale, varianceEpsilon))(x, dy);
    expectArraysClose(await gradX.data(), [-1.414, -2.887, -1.414, -2.887]);
    expect(gradX.shape).toEqual([2, 1, 1, 1, 2]);
    const gradMean = tf.grad(
        (mean: tf.Tensor1D) => tf.batchNorm5d(
            x, mean, variance, offset, scale, varianceEpsilon))(mean, dy);
    expectArraysClose(await gradMean.data(), [2.828, 5.773]);
    expect(gradMean.shape).toEqual([2]);
    const gradVariance = tf.grad(
        (variance: tf.Tensor1D) => tf.batchNorm5d(
            x, mean, variance, offset, scale, varianceEpsilon))(variance, dy);
    expectArraysClose(await gradVariance.data(), [3.180, 11.060]);
    expect(gradVariance.shape).toEqual([2]);
    const gradOffset = tf.grad(
        (offset: tf.Tensor1D) => tf.batchNorm5d(
            x, mean, variance, offset, scale, varianceEpsilon))(offset, dy);
    // The rank-1 offset gradient is dy reduced over every axis except the
    // last one; for a rank-5 dy those are axes 0 through 3.
    expectArraysClose(
        await gradOffset.data(), await dy.sum([0, 1, 2, 3]).data());
    expect(gradOffset.shape).toEqual([2]);
    const gradScale = tf.grad(
        (scale: tf.Tensor1D) => tf.batchNorm5d(
            x, mean, variance, offset, scale, varianceEpsilon))(scale, dy);
    expectArraysClose(await gradScale.data(), [-6.362, -13.277]);
    expect(gradScale.shape).toEqual([2]);
  });

  it('batchnorm5D gradients, same shapes in x, mean and variance', async () => {
    const x = tf.tensor5d([10, 20, 30, 40], [2, 1, 1, 1, 2]);
    const mean = tf.tensor5d([0, 5, 10, 15], [2, 1, 1, 1, 2]);
    const variance = tf.tensor5d([2, 4, 6, 8], [2, 1, 1, 1, 2]);
    const scale = tf.tensor5d([2, 5, 2, 5], [2, 1, 1, 1, 2]);
    const offset = tf.tensor5d([0, 0, 0, 0], [2, 1, 1, 1, 2]);

    const varianceEpsilon = .001;

    // Upstream gradient, same shape as the batchNorm5d output.
    const dy = tf.tensor5d([-1, -1, -1, -1], [2, 1, 1, 1, 2]);
    const gradX = tf.grad(
        (x: tf.Tensor5D) => tf.batchNorm5d(
            x, mean, variance, offset, scale, varianceEpsilon))(x, dy);
    expectArraysClose(await gradX.data(), [-1.414, -2.500, -0.816, -1.768]);
    expect(gradX.shape).toEqual([2, 1, 1, 1, 2]);
    const gradMean = tf.grad(
        (mean: tf.Tensor5D) => tf.batchNorm5d(
            x, mean, variance, offset, scale, varianceEpsilon))(mean, dy);
    expectArraysClose(await gradMean.data(), [1.414, 2.500, 0.816, 1.768]);
    expect(gradMean.shape).toEqual([2, 1, 1, 1, 2]);
    const gradVariance = tf.grad(
        (variance: tf.Tensor5D) => tf.batchNorm5d(
            x, mean, variance, offset, scale, varianceEpsilon))(variance, dy);
    expectArraysClose(await gradVariance.data(), [3.533, 4.686, 1.360, 2.762]);
    expect(gradVariance.shape).toEqual([2, 1, 1, 1, 2]);
    const gradOffset = tf.grad(
        (offset: tf.Tensor5D) => tf.batchNorm5d(
            x, mean, variance, offset, scale, varianceEpsilon))(offset, dy);
    // Offset has the same shape as x here, so no reduction of dy is expected.
    expectArraysClose(await gradOffset.data(), await dy.data());
    expect(gradOffset.shape).toEqual([2, 1, 1, 1, 2]);
    const gradScale = tf.grad(
        (scale: tf.Tensor5D) => tf.batchNorm5d(
            x, mean, variance, offset, scale, varianceEpsilon))(scale, dy);
    expectArraysClose(await gradScale.data(), [-7.069, -7.499, -8.164, -8.838]);
    expect(gradScale.shape).toEqual([2, 1, 1, 1, 2]);
  });
});

describeWithFlags('batchNorm4D', ALL_ENVS, () => {
it('simple batchnorm4D, no offset or scale, 2x1x1x2', async () => {
const xT = tf.tensor4d([2, 4, 9, 23], [2, 1, 1, 2]);
Expand Down
18 changes: 10 additions & 8 deletions tfjs-core/src/ops/batchnorm_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,23 @@
* limitations under the License.
* =============================================================================
*/
import {Tensor, Tensor4D} from '../tensor';
import {Tensor, Tensor5D} from '../tensor';
import {Rank} from '../types';
import {reshape} from './reshape';

export function xAs4D<R extends Rank>(x: Tensor<R>) {
let x4D: Tensor4D;
export function xAs5D<R extends Rank>(x: Tensor<R>) {
let x5D: Tensor5D;
if (x.rank === 0 || x.rank === 1) {
x4D = reshape(x, [1, 1, 1, x.size]);
x5D = reshape(x, [1, 1, 1, 1, x.size]);
} else if (x.rank === 2) {
x4D = reshape(x, [1, 1, x.shape[0], x.shape[1]]);
x5D = reshape(x, [1, 1, 1, x.shape[0], x.shape[1]]);
} else if (x.rank === 3) {
x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
x5D = reshape(x, [1, 1, x.shape[0], x.shape[1], x.shape[2]]);
} else if (x.rank === 4) {
x5D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]]);
} else {
x4D = x as Tensor4D;
x5D = x as Tensor5D;
}

return x4D;
return x5D;
}
1 change: 1 addition & 0 deletions tfjs-core/src/ops/ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export {batchNorm} from './batchnorm';
export {batchNorm2d} from './batchnorm2d';
export {batchNorm3d} from './batchnorm3d';
export {batchNorm4d} from './batchnorm4d';
export {batchNorm5d} from './batchnorm5d';
export {bincount} from './bincount';
export {bitwiseAnd} from './bitwise_and';
export {broadcastArgs} from './broadcast_args';
Expand Down