Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix webgl.
  • Loading branch information
lina128 committed Mar 8, 2021
commit 2c9e30e966260f596e457b8752cac4c2b278d488
59 changes: 31 additions & 28 deletions tfjs-backend-webgl/src/transform_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ export class TransformProgram implements GPGPUProgram {
} else {
float sz2 = 2.0 * len;
if (inCoord < sz2) {
inCoord = sz2 * trunc(-inCoord / sz2) + inCoord;
inCoord = sz2 * float(int(float(-inCoord / sz2))) + inCoord;
}
inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1.0;
}
Expand All @@ -63,7 +63,7 @@ export class TransformProgram implements GPGPUProgram {
inCoord = 0.0;
} else {
float sz2 = 2.0 * len;
inCoord -= sz2 * trunc(inCoord / sz2);
inCoord -= sz2 * float(int(float(inCoord / sz2)));
if (inCoord >= len) {
inCoord = sz2 - inCoord - 1.0;
}
Expand All @@ -76,14 +76,14 @@ export class TransformProgram implements GPGPUProgram {
inCoord = 0.0;
} else {
float sz = len - 1.0;
inCoord += len * (trunc(-inCoord / sz) + 1.0);
inCoord += len * (float(int(float(-inCoord / sz))) + 1.0);
}
} else if (inCoord > len - 1.0) {
if (len <= 1.0) {
inCoord = 0.0;
} else {
float sz = len - 1.0;
inCoord -= len * trunc(inCoord / sz);
inCoord -= len * float(int(float(inCoord / sz)));
}
}
return clamp(inCoord, 0.0, len - 1.0);
Expand All @@ -109,10 +109,12 @@ export class TransformProgram implements GPGPUProgram {
void main() {
ivec4 coords = getOutputCoords();
float outputValue;
float batch = float(coords[0]);
float x = float(coords[2]);
float y = float(coords[1]);
float channel = float(coords[3]);
int batch = coords[0];
int x = coords[2];
int y = coords[1];
int channel = coords[3];
float xf = float(x);
float yf = float(y);
float a1 = getTransforms(batch, 0);
float a2 = getTransforms(batch, 1);
float a3 = getTransforms(batch, 2);
Expand All @@ -121,35 +123,36 @@ export class TransformProgram implements GPGPUProgram {
float b3 = getTransforms(batch, 5);
float c1 = getTransforms(batch, 6);
float c2 = getTransforms(batch, 7);
float projection = c1 * x + c2 * y + 1.0;
float projection = c1 * xf + c2 * yf + 1.0;
if (projection == 0.0) {
outputValue = float(${fillValue});
} else {
float inX = (a1 * x + a2 * y + a3) / projection;
float inY = (b1 * x + b2 * y + b3) / projection;
float mapX = mapCoord(inX, ${imageWidth});
float mapY = mapCoord(inY, ${imageHeight});
float inX = (a1 * xf + a2 * yf + a3) / projection;
float inY = (b1 * xf + b2 * yf + b3) / projection;
float mapX = mapCoord(inX, float(${imageWidth}));
float mapY = mapCoord(inY, float(${imageHeight}));

if (${interpolationModeId} == 1) {
int coordY = round(mapY);
int coordX = round(mapX);
int coordY = int(round(mapY));
int coordX = int(round(mapX));
outputValue = readWithFillValue(batch, coordY, coordX,
coords[3]);
channel);
} else {
int yFloor = floor(mapY);
int xFloor = floor(mapX);
int yFloor = int(floor(mapY));
int xFloor = int(floor(mapX));
int yCeil = yFloor + 1;
int xCeil = xFloor + 1;
float valueYFloor = (xCeil - x) / (xCeil - xFloor) *
readWithFillValue(batch, yFloor, xFloor, coords[3]) +
(x - xFloor) *
readWithFillValue(batch, yFloor, xCeil, coords[3]);
float valueYCeil = (xCeil - x) *
readWithFillValue(batch, yCeil, xFloor, coords[3]) +
(x - xFloor) *
readWithFillValue(batch, yCeil, xCeil, coords[3]);
outputValue = (yCeil - y) * valueYFloor +
(y - yFloor) * valueYCeil;
float valueYFloor = float((xCeil - x)) /
float((xCeil - xFloor)) *
readWithFillValue(batch, yFloor, xFloor, channel) +
float((x - xFloor)) *
readWithFillValue(batch, yFloor, xCeil, channel);
float valueYCeil = float((xCeil - x)) *
readWithFillValue(batch, yCeil, xFloor, channel) +
float((x - xFloor)) *
readWithFillValue(batch, yCeil, xCeil, channel);
outputValue = float((yCeil - y)) * valueYFloor +
float((y - yFloor)) * valueYCeil;
}
}
setOutput(outputValue);
Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/ops/image/transform_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {BROWSER_ENVS, describeWithFlags} from '../../jasmine_util';
import {expectArraysClose} from '../../test_util';

describeWithFlags('transform', BROWSER_ENVS, () => {
fit('extreme projective transform.', async () => {
it('extreme projective transform.', async () => {
const images = tf.tensor4d(
[1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1], [1, 4, 4, 1]);
const transform = tf.tensor2d([1, 0, 0, 0, 1, 0, -1, 0], [1, 8]);
Expand Down