diff --git a/tfjs-core/src/gradients/FusedBatchNorm_grad.ts b/tfjs-core/src/gradients/FusedBatchNorm_grad.ts
index a58f535a4c5..b471fec4d5d 100644
--- a/tfjs-core/src/gradients/FusedBatchNorm_grad.ts
+++ b/tfjs-core/src/gradients/FusedBatchNorm_grad.ts
@@ -58,7 +58,8 @@ export const fusedBatchNormGradConfig: GradConfig = {
         return reshape(
             mul(mul(dy,
                     tile(
-                        reshape(oneOverSqrtVariance, [1, 1, 1, mean.shape[0]]),
+                        reshape(oneOverSqrtVariance,
+                          [1, 1, 1, 1, mean.shape[0]]),
                         tileShape)),
                 scaleValue),
             x.shape);
diff --git a/tfjs-core/src/ops/batchnorm.ts b/tfjs-core/src/ops/batchnorm.ts
index 5b3f1db27d3..34773c2caf6 100644
--- a/tfjs-core/src/ops/batchnorm.ts
+++ b/tfjs-core/src/ops/batchnorm.ts
@@ -18,13 +18,13 @@
 import {ENGINE} from '../engine';
 import {FusedBatchNorm, FusedBatchNormAttrs, FusedBatchNormInputs} from '../kernel_names';
 import {NamedAttrMap} from '../kernel_registry';
-import {Tensor, Tensor1D, Tensor4D} from '../tensor';
+import {Tensor, Tensor1D, Tensor5D} from '../tensor';
 import {NamedTensorMap} from '../tensor_types';
 import {convertToTensor} from '../tensor_util_env';
 import {Rank, TensorLike} from '../types';
 import * as util from '../util';
 
-import {xAs4D} from './batchnorm_util';
+import {xAs5D} from './batchnorm_util';
 import {op} from './operation';
 import {reshape} from './reshape';
 
@@ -88,10 +88,10 @@ function batchNorm_(
       () => 'Batch normalization gradient requires mean and scale to have ' +
           'equal ranks.');
 
-  const x4D: Tensor4D = xAs4D($x);
+  const x5D: Tensor5D = xAs5D($x);
 
   const inputs: FusedBatchNormInputs = {
-    x: x4D,
+    x: x5D,
     scale: $scale,
     offset: $offset,
     mean: $mean,
diff --git a/tfjs-core/src/ops/batchnorm5d.ts b/tfjs-core/src/ops/batchnorm5d.ts
new file mode 100644
index 00000000000..26ab1a0005f
--- /dev/null
+++ b/tfjs-core/src/ops/batchnorm5d.ts
@@ -0,0 +1,79 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+import {Tensor1D, Tensor5D} from '../tensor';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
+import * as util from '../util';
+
+import {batchNorm} from './batchnorm';
+import {op} from './operation';
+
+/**
+ * Batch normalization, strictly for 5D: asserts the ranks documented below,
+ * then delegates to `tf.batchNorm` (the more relaxed version).
+ *
+ * @param x The input Tensor. Must be rank 5.
+ * @param mean A mean Tensor. Must be rank 5 or rank 1.
+ * @param variance A variance Tensor. Must be rank 5 or rank 1.
+ * @param offset An optional offset Tensor. Must be rank 5 or rank 1.
+ * @param scale An optional scale Tensor. Must be rank 5 or rank 1.
+ * @param varianceEpsilon A small float number to avoid dividing by 0.
+ */
+function batchNorm5d_(
+    x: Tensor5D|TensorLike, mean: Tensor5D|Tensor1D|TensorLike,
+    variance: Tensor5D|Tensor1D|TensorLike,
+    offset?: Tensor5D|Tensor1D|TensorLike, scale?: Tensor5D|Tensor1D|TensorLike,
+    varianceEpsilon?: number): Tensor5D {
+  const $x = convertToTensor(x, 'x', 'batchNorm');
+  const $mean = convertToTensor(mean, 'mean', 'batchNorm');
+  const $variance = convertToTensor(variance, 'variance', 'batchNorm');
+  let $scale: Tensor5D|Tensor1D;
+  if (scale != null) {
+    $scale = convertToTensor(scale, 'scale', 'batchNorm');
+  }
+  let $offset: Tensor5D|Tensor1D;
+  if (offset != null) {
+    $offset = convertToTensor(offset, 'offset', 'batchNorm');
+  }
+  util.assert(
+      $x.rank === 5,
+      () => `Error in batchNorm5D: x must be rank 5 but got rank ` +
+          `${$x.rank}.`);
+  util.assert(
+      $mean.rank === 5 || $mean.rank === 1,
+      () => `Error in batchNorm5D: mean must be rank 5 or rank 1 but ` +
+          `got rank ${$mean.rank}.`);
+  util.assert(
+      $variance.rank === 5 || $variance.rank === 1,
+      () => `Error in batchNorm5D: variance must be rank 5 or rank 1 ` +
+          `but got rank ${$variance.rank}.`);
+  if ($scale != null) {
+    util.assert(
+        $scale.rank === 5 || $scale.rank === 1,
+        () => `Error in batchNorm5D: scale must be rank 5 or rank 1 ` +
+            `but got rank ${$scale.rank}.`);
+  }
+  if ($offset != null) {
+    util.assert(
+        $offset.rank === 5 || $offset.rank === 1,
+        () => `Error in batchNorm5D: offset must be rank 5 or rank 1 ` +
+            `but got rank ${$offset.rank}.`);
+  }
+  return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
+}
+
+export const batchNorm5d = op({batchNorm5d_});
diff --git a/tfjs-core/src/ops/batchnorm_test.ts b/tfjs-core/src/ops/batchnorm_test.ts
index 97da7f4d538..1b9b67bc826 100644
--- a/tfjs-core/src/ops/batchnorm_test.ts
+++ b/tfjs-core/src/ops/batchnorm_test.ts
@@ -19,6 +19,224 @@ import * as tf from '../index';
 import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
 import {expectArraysClose} from '../test_util';
 
+describeWithFlags('batchNorm5D', ALL_ENVS, () => {
+  it('simple batchnorm5D, no offset or scale, 2x1x1x1x2', async () => {
+    const xT = tf.tensor5d([2, 4, 9, 23], [2, 1, 1, 1, 2]);
+    const meanT = tf.tensor1d([1, 2]);
+    const varianceT = tf.tensor1d([2, 3]);
+    const varianceEpsilon = .001;
+
+    const result = tf.batchNorm5d(
+        xT, meanT, varianceT, undefined, undefined, varianceEpsilon);
+
+    const x = await xT.array();
+    const mean = await meanT.array();
+    const variance = await varianceT.array();
+    expectArraysClose(await result.data(), [
+      (x[0][0][0][0][0] - mean[0]) * 1 /
+          Math.sqrt(variance[0] + varianceEpsilon),
+      (x[0][0][0][0][1] - mean[1]) * 1 /
+          Math.sqrt(variance[1] + varianceEpsilon),
+      (x[1][0][0][0][0] - mean[0]) * 1 /
+          Math.sqrt(variance[0] + varianceEpsilon),
+      (x[1][0][0][0][1] - mean[1]) * 1 /
+          Math.sqrt(variance[1] + varianceEpsilon)
+    ]);
+  });
+
+  it('simple batchnorm5D, no offset, 2x1x1x1x2', async () => {
+    const xT = tf.tensor5d([2, 4, 9, 23], [2, 1, 1, 1, 2]);
+    const meanT = tf.tensor1d([1, 2]);
+    const varianceT = tf.tensor1d([2, 3]);
+    const scaleT = tf.tensor1d([4, 5]);
+    const varianceEpsilon = .001;
+
+    const result = tf.batchNorm5d(
+        xT, meanT, varianceT, undefined, scaleT, varianceEpsilon);
+    const x = await xT.buffer();
+    const mean = await meanT.buffer();
+    const variance = await varianceT.buffer();
+    const scale = await scaleT.buffer();
+
+    expectArraysClose(await result.data(), [
+      (x.get(0, 0, 0, 0, 0) - mean.get(0)) * scale.get(0) /
+          Math.sqrt(variance.get(0) + varianceEpsilon),
+      (x.get(0, 0, 0, 0, 1) - mean.get(1)) * scale.get(1) /
+          Math.sqrt(variance.get(1) + varianceEpsilon),
+      (x.get(1, 0, 0, 0, 0) - mean.get(0)) * scale.get(0) /
+          Math.sqrt(variance.get(0) + varianceEpsilon),
+      (x.get(1, 0, 0, 0, 1) - mean.get(1)) * scale.get(1) /
+          Math.sqrt(variance.get(1) + varianceEpsilon)
+    ]);
+  });
+
+  it('simple batchnorm5D, no scale, 2x1x1x1x2', async () => {
+    const xT = tf.tensor5d([2, 4, 9, 23], [2, 1, 1, 1, 2]);
+    const meanT = tf.tensor1d([1, 2]);
+    const varianceT = tf.tensor1d([2, 3]);
+    const offsetT = tf.tensor1d([4, 5]);
+
+    const varianceEpsilon = .001;
+
+    const result = tf.batchNorm5d(
+        xT, meanT, varianceT, offsetT, undefined, varianceEpsilon);
+    const x = await xT.buffer();
+    const mean = await meanT.buffer();
+    const variance = await varianceT.buffer();
+    const offset = await offsetT.buffer();
+
+    expectArraysClose(await result.data(), [
+      offset.get(0) +
+          (x.get(0, 0, 0, 0, 0) - mean.get(0)) * 1 /
+              Math.sqrt(variance.get(0) + varianceEpsilon),
+      offset.get(1) +
+          (x.get(0, 0, 0, 0, 1) - mean.get(1)) * 1 /
+              Math.sqrt(variance.get(1) + varianceEpsilon),
+      offset.get(0) +
+          (x.get(1, 0, 0, 0, 0) - mean.get(0)) * 1 /
+              Math.sqrt(variance.get(0) + varianceEpsilon),
+      offset.get(1) +
+          (x.get(1, 0, 0, 0, 1) - mean.get(1)) * 1 /
+              Math.sqrt(variance.get(1) + varianceEpsilon)
+    ]);
+  });
+
+  it('simple batchnorm5D, 2x1x1x1x2', async () => {
+    const xT = tf.tensor5d([2, 4, 9, 23], [2, 1, 1, 1, 2]);
+    const meanT = tf.tensor1d([1, 2]);
+    const varianceT = tf.tensor1d([2, 3]);
+    const offsetT = tf.tensor1d([3, 4]);
+    const scaleT = tf.tensor1d([4, 5]);
+
+    const varianceEpsilon = .001;
+
+    const result =
+        tf.batchNorm5d(xT, meanT, varianceT, offsetT, scaleT, varianceEpsilon);
+    const x = await xT.buffer();
+    const mean = await meanT.buffer();
+    const variance = await varianceT.buffer();
+    const scale = await scaleT.buffer();
+    const offset = await offsetT.buffer();
+
+    expectArraysClose(await result.data(), [
+      offset.get(0) +
+          (x.get(0, 0, 0, 0, 0) - mean.get(0)) * scale.get(0) /
+              Math.sqrt(variance.get(0) + varianceEpsilon),
+      offset.get(1) +
+          (x.get(0, 0, 0, 0, 1) - mean.get(1)) * scale.get(1) /
+              Math.sqrt(variance.get(1) + varianceEpsilon),
+      offset.get(0) +
+          (x.get(1, 0, 0, 0, 0) - mean.get(0)) * scale.get(0) /
+              Math.sqrt(variance.get(0) + varianceEpsilon),
+      offset.get(1) +
+          (x.get(1, 0, 0, 0, 1) - mean.get(1)) * scale.get(1) /
+              Math.sqrt(variance.get(1) + varianceEpsilon)
+    ]);
+  });
+
+  it('accepts a tensor-like object', async () => {
+    const x = [[[[[2, 4]]]], [[[[9, 23]]]]];  // 2x1x1x1x2
+    const mean = [1, 2];
+    const variance = [2, 3];
+    const offset = [3, 4];
+    const scale = [4, 5];
+
+    const varianceEpsilon = .001;
+
+    const result =
+        tf.batchNorm5d(x, mean, variance, offset, scale, varianceEpsilon);
+
+    expectArraysClose(await result.data(), [
+      offset[0] +
+          (x[0][0][0][0][0] - mean[0]) * scale[0] /
+              Math.sqrt(variance[0] + varianceEpsilon),
+      offset[1] +
+          (x[0][0][0][0][1] - mean[1]) * scale[1] /
+              Math.sqrt(variance[1] + varianceEpsilon),
+      offset[0] +
+          (x[1][0][0][0][0] - mean[0]) * scale[0] /
+              Math.sqrt(variance[0] + varianceEpsilon),
+      offset[1] +
+          (x[1][0][0][0][1] - mean[1]) * scale[1] /
+              Math.sqrt(variance[1] + varianceEpsilon)
+    ]);
+  });
+
+  it('simple batchnorm5D gradients, 2x1x1x1x2', async () => {
+    const x = tf.tensor5d([2, 4, 9, 23], [2, 1, 1, 1, 2]);
+    const mean = tf.tensor1d([1, 2]);
+    const variance = tf.tensor1d([2, 3]);
+    const offset = tf.tensor1d([3, 4]);
+    const scale = tf.tensor1d([2, 5]);
+
+    const varianceEpsilon = .001;
+    const dy = tf.tensor5d([-1, -1, -1, -1], [2, 1, 1, 1, 2]);
+    const gradX = tf.grad(
+        (x: tf.Tensor5D) => tf.batchNorm5d(
+            x, mean, variance, offset, scale, varianceEpsilon))(x, dy);
+    expectArraysClose(await gradX.data(), [-1.414, -2.887, -1.414, -2.887]);
+    expect(gradX.shape).toEqual([2, 1, 1, 1, 2]);
+    const gradMean = tf.grad(
+        (mean: tf.Tensor1D) => tf.batchNorm5d(
+            x, mean, variance, offset, scale, varianceEpsilon))(mean, dy);
+    expectArraysClose(await gradMean.data(), [2.828, 5.773]);
+    expect(gradMean.shape).toEqual([2]);
+    const gradVariance = tf.grad(
+        (variance: tf.Tensor1D) => tf.batchNorm5d(
+            x, mean, variance, offset, scale, varianceEpsilon))(variance, dy);
+    expectArraysClose(await gradVariance.data(), [3.180, 11.060]);
+    expect(gradVariance.shape).toEqual([2]);
+    const gradOffset = tf.grad(
+        (offset: tf.Tensor1D) => tf.batchNorm5d(
+            x, mean, variance, offset, scale, varianceEpsilon))(offset, dy);
+    expectArraysClose(
+        await gradOffset.data(), await dy.sum([0, 1, 2, 3]).data());
+    expect(gradOffset.shape).toEqual([2]);
+    const gradScale = tf.grad(
+        (scale: tf.Tensor1D) => tf.batchNorm5d(
+            x, mean, variance, offset, scale, varianceEpsilon))(scale, dy);
+    expectArraysClose(await gradScale.data(), [-6.362, -13.277]);
+    expect(gradScale.shape).toEqual([2]);
+  });
+
+  it('batchnorm5D gradients, same shapes in x, mean and variance', async () => {
+    const x = tf.tensor5d([10, 20, 30, 40], [2, 1, 1, 1, 2]);
+    const mean = tf.tensor5d([0, 5, 10, 15], [2, 1, 1, 1, 2]);
+    const variance = tf.tensor5d([2, 4, 6, 8], [2, 1, 1, 1, 2]);
+    const scale = tf.tensor5d([2, 5, 2, 5], [2, 1, 1, 1, 2]);
+    const offset = tf.tensor5d([0, 0, 0, 0], [2, 1, 1, 1, 2]);
+
+    const varianceEpsilon = .001;
+
+    const dy = tf.tensor5d([-1, -1, -1, -1], [2, 1, 1, 1, 2]);
+    const gradX = tf.grad(
+        (x: tf.Tensor5D) => tf.batchNorm5d(
+            x, mean, variance, offset, scale, varianceEpsilon))(x, dy);
+    expectArraysClose(await gradX.data(), [-1.414, -2.500, -0.816, -1.768]);
+    expect(gradX.shape).toEqual([2, 1, 1, 1, 2]);
+    const gradMean = tf.grad(
+        (mean: tf.Tensor5D) => tf.batchNorm5d(
+            x, mean, variance, offset, scale, varianceEpsilon))(mean, dy);
+    expectArraysClose(await gradMean.data(), [1.414, 2.500, 0.816, 1.768]);
+    expect(gradMean.shape).toEqual([2, 1, 1, 1, 2]);
+    const gradVariance = tf.grad(
+        (variance: tf.Tensor5D) => tf.batchNorm5d(
+            x, mean, variance, offset, scale, varianceEpsilon))(variance, dy);
+    expectArraysClose(await gradVariance.data(), [3.533, 4.686, 1.360, 2.762]);
+    expect(gradVariance.shape).toEqual([2, 1, 1, 1, 2]);
+    const gradOffset = tf.grad(
+        (offset: tf.Tensor5D) => tf.batchNorm5d(
+            x, mean, variance, offset, scale, varianceEpsilon))(offset, dy);
+    expectArraysClose(await gradOffset.data(), await dy.data());
+    expect(gradOffset.shape).toEqual([2, 1, 1, 1, 2]);
+    const gradScale = tf.grad(
+        (scale: tf.Tensor5D) => tf.batchNorm5d(
+            x, mean, variance, offset, scale, varianceEpsilon))(scale, dy);
+    expectArraysClose(await gradScale.data(), [-7.069, -7.499, -8.164, -8.838]);
+    expect(gradScale.shape).toEqual([2, 1, 1, 1, 2]);
+  });
+});
+
 describeWithFlags('batchNorm4D', ALL_ENVS, () => {
   it('simple batchnorm4D, no offset or scale, 2x1x1x2', async () => {
     const xT = tf.tensor4d([2, 4, 9, 23], [2, 1, 1, 2]);
diff --git a/tfjs-core/src/ops/batchnorm_util.ts b/tfjs-core/src/ops/batchnorm_util.ts
index f85b4c0975d..04c76f56a9b 100644
--- a/tfjs-core/src/ops/batchnorm_util.ts
+++ b/tfjs-core/src/ops/batchnorm_util.ts
@@ -14,21 +14,23 @@
  * limitations under the License.
  * =============================================================================
  */
-import {Tensor, Tensor4D} from '../tensor';
+import {Tensor, Tensor5D} from '../tensor';
 import {Rank} from '../types';
 
 import {reshape} from './reshape';
 
-export function xAs4D(x: Tensor) {
-  let x4D: Tensor4D;
+export function xAs5D(x: Tensor) {
+  let x5D: Tensor5D;
   if (x.rank === 0 || x.rank === 1) {
-    x4D = reshape(x, [1, 1, 1, x.size]);
+    x5D = reshape(x, [1, 1, 1, 1, x.size]);
   } else if (x.rank === 2) {
-    x4D = reshape(x, [1, 1, x.shape[0], x.shape[1]]);
+    x5D = reshape(x, [1, 1, 1, x.shape[0], x.shape[1]]);
   } else if (x.rank === 3) {
-    x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
+    x5D = reshape(x, [1, 1, x.shape[0], x.shape[1], x.shape[2]]);
+  } else if (x.rank === 4) {
+    x5D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]]);
   } else {
-    x4D = x as Tensor4D;
+    x5D = x as Tensor5D;
   }
 
-  return x4D;
+  return x5D;
 }
diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts
index 0fbf5dcee1c..24c6eb1de7e 100644
--- a/tfjs-core/src/ops/ops.ts
+++ b/tfjs-core/src/ops/ops.ts
@@ -38,6 +38,7 @@ export {batchNorm} from './batchnorm';
 export {batchNorm2d} from './batchnorm2d';
 export {batchNorm3d} from './batchnorm3d';
 export {batchNorm4d} from './batchnorm4d';
+export {batchNorm5d} from './batchnorm5d';
 export {bincount} from './bincount';
 export {bitwiseAnd} from './bitwise_and';
 export {broadcastArgs} from './broadcast_args';