Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Added triangle method
  • Loading branch information
BadMachine authored May 30, 2021
commit 5434cc009cfa6439820cf8a24aad04e741c4540a
93 changes: 91 additions & 2 deletions tfjs-core/src/ops/image/threshold.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,24 @@ import { add } from '../add';
import { mul } from '../mul';
import { div } from '../div';
import { sub } from '../sub';
import {atan} from '../atan';
import {tan} from '../tan';
import {slice} from '../slice';
import {argMax} from '../arg_max';
import {max} from '../max';
import {argMin} from '../arg_min';
import {less} from '../less';
import { round } from '../round';
import { where } from '../where';
import {reshape} from '../reshape';
import {sqrt} from '../sqrt';
import {pow} from '../pow';
import {abs} from '../abs';
import {gatherND} from '../gather_nd';
import {clone} from '../clone';
import {reverse} from '../reverse';
import {zerosLike} from '../zeros_like';
import { fill } from '../fill';
import {slice} from '../slice';
import { range } from '../range';
import { tensor } from '../tensor';
import * as util from '../../util';
Expand All @@ -44,7 +58,7 @@ import { convertToTensor } from '../../tensor_util_env';
* @param image 3d tensor of shape [imageHeight,imageWidth, depth],
* where imageHeight and imageWidth must be positive.The image color
* range should be [0, 255].
* @param method Optional string from `'binary' | 'otsu'`
* @param method Optional string from `'binary' | 'otsu' | 'triangle' `
* which specifies the method for thresholding. Defaults to 'binary'.
* @param inverted Optional boolean whichspecifies
* if colours should be inverted. Defaults to false.
Expand Down Expand Up @@ -109,6 +123,12 @@ function threshold_(
256);
$threshold = otsu($histogram, totalPixelsInImage);
}
else if(method === 'triangle'){
const $histogram = bincount(cast(round(grayscale), 'int32') as Tensor1D,
tensor([]),
256);
$threshold = triangle($histogram);
}

const invCondition = inverted ?
lessEqual(grayscale, $threshold) : greater(grayscale, $threshold);
Expand Down Expand Up @@ -160,4 +180,73 @@ function otsu(histogram: Tensor1D, total: number):Tensor1D {
return bestThresh;
}

function triangle (histogram: Tensor1D){

const histogramTrimmed = trimZeros(histogram);

const maxIdx = + argMax(histogramTrimmed).toString().replace(/[^0-9]/g, '');

const increasing = lessEqual(histogramTrimmed.shape[0]/2, maxIdx)
.toString().includes('true');

const sliced = increasing ? slice(histogramTrimmed, 0, maxIdx+1) :
slice(histogramTrimmed, maxIdx);

const cathetusB = max(sliced);

const cathetusA = sliced.shape;

const aTan = atan(div(cathetusB, cathetusA));

let derivativeTriangle = increasing ? range(1, cathetusA[0]+1, 1 ,'float32')
:
range(cathetusA[0], 0, -1 ,'float32') ;

derivativeTriangle = mul(derivativeTriangle, tan(aTan));

let cathetusDerivativeB = sub(derivativeTriangle, sliced);

cathetusDerivativeB = where(less(cathetusDerivativeB,0),
zerosLike(cathetusDerivativeB), cathetusDerivativeB);

const cathetusDerivativeA = div(cathetusDerivativeB, tan(aTan));

const hiposDerivate = sqrt( add( pow(cathetusDerivativeA,2 ),
pow(cathetusDerivativeB,2) ) );

const heights = div(mul(cathetusDerivativeB, cathetusDerivativeA)
, hiposDerivate);

const maxheightIdx = argMax(heights);

const valueInHistoSliced = gatherND(sliced, reshape(maxheightIdx, [1]));

const bestThresh = argMin( abs(sub(histogram, valueInHistoSliced)) );

return bestThresh;

}

function trimZeros(histogram: Tensor1D){

const histogramCopy = clone(histogram);

const divideToZero = div(histogramCopy,0);

const leftPointer = + argMax(divideToZero)
.toString().replace(/[^0-9]/g, '');

const leftSideClean = slice(histogramCopy, leftPointer);

const divideToMZero = div(reverse(leftSideClean), -0);

const rightPointer = + argMin(divideToMZero)
.toString().replace(/[^0-9]/g, '');

const clean = slice(leftSideClean,
0, leftSideClean.shape[0] - rightPointer );

return clean;
}

export const threshold = op({ threshold_ });