Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix lint.
  • Loading branch information
lina128 committed Mar 8, 2021
commit 6ef07e98adeab546d66852e1cc891978738703ad
7 changes: 5 additions & 2 deletions tfjs-backend-cpu/src/kernels/Transform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
*/

import {KernelConfig, KernelFunc, NumericDataType, TensorInfo, Transform, TransformAttrs, TransformInputs, TypedArray, util} from '@tensorflow/tfjs-core';
import {clamp} from '@tensorflow/tfjs-core/dist/util';

import {MathBackendCPU} from '../backend_cpu';

Expand Down Expand Up @@ -87,6 +86,10 @@ export function transform(args: {
imageVals, imageHeight, imageWidth, batchStride, rowStride,
colStride, b, y, x, channel, fillValue);
break;
default:
throw new Error(
`Error in Transform: Expect 'nearest' or ` +
`'bilinear', but got ${interpolation}`);
}

const ind =
Expand Down Expand Up @@ -183,7 +186,7 @@ function mapCoordConstant(outCoord: number, len: number): number {
}

function mapCoordNearest(outCoord: number, len: number): number {
return clamp(0, outCoord, len - 1);
return util.clamp(0, outCoord, len - 1);
}

function readWithFillValue(
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-webgl/src/kernels/Transform.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
Expand Down
6 changes: 5 additions & 1 deletion tfjs-backend-webgl/src/transform_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ export class TransformProgram implements GPGPUProgram {
case 'nearest':
fillModeId = 4;
break;
default:
fillModeId = 1;
break;
}
this.userCode = `
float mapCoord(float outCoord, float len) {
Expand All @@ -54,7 +57,8 @@ export class TransformProgram implements GPGPUProgram {
} else {
float sz2 = 2.0 * len;
if (inCoord < sz2) {
inCoord = sz2 * float(int(float(-inCoord / sz2))) + inCoord;
inCoord = sz2 * float(int(float(-inCoord / sz2))) +
inCoord;
}
inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1.0;
}
Expand Down
37 changes: 25 additions & 12 deletions tfjs-core/src/ops/image/transform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,32 @@ import * as util from '../../util';
import {op} from '../operation';

/**
* Rotates the input image tensor counter-clockwise with an optional offset
* center of rotation. Currently available in the CPU, WebGL, and WASM backends.
* Applies the given transform(s) to the image(s).
*
* @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
* @param radians The amount of rotation.
* @param fillValue The value to fill in the empty space leftover
* after rotation. Can be either a single grayscale value (0-255), or an
* array of three numbers `[red, green, blue]` specifying the red, green,
* and blue channels. Defaults to `0` (black).
* @param center The center of rotation. Can be either a single value (0-1), or
* an array of two numbers `[centerX, centerY]`. Defaults to `0.5` (rotates
* the image around its center).
* @param transforms Projective transform matrix/matrices. A vector of length
* 8 or tensor of size N x 8. If one row of transforms is [a0, a1, a2, b0
* b1, b2, c0, c1], then it maps the output point (x, y) to a transformed
* input point (x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k),
* where k = c0 x + c1 y + 1. The transforms are inverted compared to the
* transform mapping input points to output points.
* @param interpolation Interpolation mode.
* Supported values: 'nearest', 'bilinear'. Default to 'nearest'.
* @param fillMode Points outside the boundaries of the input are filled
* according to the given mode, one of 'constant', 'reflect', 'wrap',
* 'nearest'. Default to 'constant'.
* 'reflect': (d c b a | a b c d | d c b a ) The input is extended by
* reflecting about the edge of the last pixel.
* 'constant': (k k k k | a b c d | k k k k) The input is extended by
* filling all values beyond the edge with the same constant value k.
* 'wrap': (a b c d | a b c d | a b c d) The input is extended by
* wrapping around to the opposite edge.
* 'nearest': (a a a a | a b c d | d d d d) The input is extended by
* the nearest pixel.
* @param fillValue A float represents the value to be filled outside the
* boundaries when fillMode is 'constant'.
* @param Output dimension after the transform, [height, width]. If undefined,
* output is the same size as input image.
*
* @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
*/
Expand Down Expand Up @@ -73,8 +87,7 @@ function transform_(
TransformAttrs = {interpolation, fillMode, fillValue, outputShape};

return ENGINE.runKernel(
Transform, inputs as {} as NamedTensorMap,
attrs as {} as NamedAttrMap) as Tensor4D;
Transform, inputs as {} as NamedTensorMap, attrs as {} as NamedAttrMap);
}

export const transform = op({transform_});
10 changes: 5 additions & 5 deletions tfjs-core/src/ops/image/transform_test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
Expand All @@ -18,7 +18,7 @@ import * as tf from '../../index';
import {BROWSER_ENVS, describeWithFlags} from '../../jasmine_util';
import {expectArraysClose} from '../../test_util';

describeWithFlags('transform', BROWSER_ENVS, () => {
describeWithFlags('image.transform', BROWSER_ENVS, () => {
it('extreme projective transform.', async () => {
const images = tf.tensor4d(
[1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1], [1, 4, 4, 1]);
Expand All @@ -33,9 +33,9 @@ describeWithFlags('transform', BROWSER_ENVS, () => {

it('static output shape.', async () => {
const images = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
const transform = tf.randomUniform([1, 8], -1, 1) as tf.Tensor2D;
const transformedImages =
tf.image.transform(images, transform, 'nearest', 'constant', 0, [3, 5]);
const transform = tf.randomUniform([1, 8], -1, 1);
const transformedImages = tf.image.transform(
images, transform as tf.Tensor2D, 'nearest', 'constant', 0, [3, 5]);

expectArraysClose(transformedImages.shape, [1, 3, 5, 1]);
});
Expand Down
4 changes: 3 additions & 1 deletion tfjs-node/src/run_tests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ const IGNORE_LIST: string[] = [
// tslint:disable-next-line:max-line-length
'pool test-tensorflow {} avg x=[2,2,3] f=[1,1] s=2 p=1 fractional outputs default rounding',
// not available in tf yet.
'denseBincount'
'denseBincount',
// only available in tf addon.
'image.transform'
];

if (process.platform === 'win32') {
Expand Down