diff --git a/tfjs-converter/package.json b/tfjs-converter/package.json index 70a79fcf99f..26230e2a477 100644 --- a/tfjs-converter/package.json +++ b/tfjs-converter/package.json @@ -31,7 +31,8 @@ "opn": "~5.1.0", "protobufjs": "~6.11.3", "ts-node": "~8.8.2", - "typescript": "3.5.3" + "typescript": "3.5.3", + "yalc": "~1.0.0-pre.50" }, "scripts": { "build": "bazel build :tfjs-converter_pkg", diff --git a/tfjs-converter/scripts/kernels_to_ops.ts b/tfjs-converter/scripts/kernels_to_ops.ts index bdbbce3a26f..ca689b6a535 100644 --- a/tfjs-converter/scripts/kernels_to_ops.ts +++ b/tfjs-converter/scripts/kernels_to_ops.ts @@ -98,16 +98,16 @@ function getKernelMappingForFile(source: SourceFile) { const callExprs = clausePart.getDescendantsOfKind(SyntaxKind.CallExpression); const tfOpsCallExprs = - callExprs.filter(expr => expr.getText().match(/tfOps/)); + callExprs.filter(expr => expr.getText().match(/ops/)); const tfSymbols: Set = new Set(); for (const tfOpsCall of tfOpsCallExprs) { const tfOpsCallStr = tfOpsCall.getText(); - const functionCallMatcher = /(tfOps\.([\w\.]*)\()/g; + const functionCallMatcher = /(ops\.([\w\.]*)\()/g; const matches = tfOpsCallStr.match(functionCallMatcher); if (matches != null && matches.length > 0) { for (const match of matches) { // extract the method name (and any namespaces used to call it) - const symbolMatcher = /(tfOps\.([\w\.]*)\()/; + const symbolMatcher = /(ops\.([\w\.]*)\()/; const symbol = match.match(symbolMatcher)[2]; tfSymbols.add(symbol); } diff --git a/tfjs-converter/src/BUILD.bazel b/tfjs-converter/src/BUILD.bazel index fd258aeaa4e..d6ced5f006f 100644 --- a/tfjs-converter/src/BUILD.bazel +++ b/tfjs-converter/src/BUILD.bazel @@ -21,6 +21,7 @@ package(default_visibility = ["//visibility:public"]) TEST_SRCS = [ "**/*_test.ts", "run_tests.ts", + "operations/executors/spy_ops.ts", ] # Used for test-snippets diff --git a/tfjs-converter/src/executor/graph_model.ts b/tfjs-converter/src/executor/graph_model.ts index 12bb32bf17f..be4a91b97df 100644 --- a/tfjs-converter/src/executor/graph_model.ts +++ b/tfjs-converter/src/executor/graph_model.ts @@ -48,6 +48,7 @@ export class GraphModel private initializer: GraphExecutor; private resourceManager: ResourceManager; private signature: tensorflow.ISignatureDef; + private readonly io: typeof io; // Returns the version information for the tensorflow model GraphDef. get modelVersion(): string { @@ -93,7 +94,8 @@ export class GraphModel */ constructor( private modelUrl: ModelURL, - private loadOptions: io.LoadOptions = {}) { + private loadOptions: io.LoadOptions = {}, tfio = io) { + this.io = tfio; if (loadOptions == null) { this.loadOptions = {}; } @@ -107,14 +109,16 @@ export class GraphModel // Path is an IO Handler. this.handler = path as IOHandler; } else if (this.loadOptions.requestInit != null) { - this.handler = io.browserHTTPRequest(path as string, this.loadOptions) as - IOHandler; + this.handler = this.io + .browserHTTPRequest(path as string, this.loadOptions) as IOHandler; } else { - const handlers = io.getLoadHandlers(path as string, this.loadOptions); + const handlers = + this.io.getLoadHandlers(path as string, this.loadOptions); if (handlers.length === 0) { // For backward compatibility: if no load handler can be found, // assume it is a relative http path. - handlers.push(io.browserHTTPRequest(path as string, this.loadOptions)); + handlers.push( + this.io.browserHTTPRequest(path as string, this.loadOptions)); } else if (handlers.length > 1) { throw new Error( `Found more than one (${handlers.length}) load handlers for ` + @@ -171,8 +175,8 @@ export class GraphModel this.signature = signature; this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`; - const weightMap = - io.decodeWeights(this.artifacts.weightData, this.artifacts.weightSpecs); + const weightMap = this.io.decodeWeights( + this.artifacts.weightData, this.artifacts.weightSpecs); this.executor = new GraphExecutor( OperationMapper.Instance.transformGraph(graph, this.signature)); this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap); @@ -243,7 +247,7 @@ export class GraphModel async save(handlerOrURL: io.IOHandler|string, config?: io.SaveConfig): Promise { if (typeof handlerOrURL === 'string') { - const handlers = io.getSaveHandlers(handlerOrURL); + const handlers = this.io.getSaveHandlers(handlerOrURL); if (handlers.length === 0) { throw new Error( `Cannot find any save handlers for URL '${handlerOrURL}'`); @@ -452,8 +456,8 @@ export class GraphModel * @doc {heading: 'Models', subheading: 'Loading'} */ export async function loadGraphModel( - modelUrl: string|io.IOHandler, - options: io.LoadOptions = {}): Promise { + modelUrl: string|io.IOHandler, options: io.LoadOptions = {}, + tfio = io): Promise { if (modelUrl == null) { throw new Error( 'modelUrl in loadGraphModel() cannot be null. Please provide a url ' + @@ -466,7 +470,7 @@ export async function loadGraphModel( if (options.fromTFHub && typeof modelUrl === 'string') { modelUrl = getTFHubUrl(modelUrl); } - const model = new GraphModel(modelUrl, options); + const model = new GraphModel(modelUrl, options, tfio); await model.load(); return model; } diff --git a/tfjs-converter/src/executor/graph_model_test.ts b/tfjs-converter/src/executor/graph_model_test.ts index 1fbb095c595..aa1d3966b39 100644 --- a/tfjs-converter/src/executor/graph_model_test.ts +++ b/tfjs-converter/src/executor/graph_model_test.ts @@ -23,6 +23,7 @@ import {deregisterOp, registerOp} from '../operations/custom_op/register'; import {GraphNode} from '../operations/types'; import {GraphModel, loadGraphModel, loadGraphModelSync} from './graph_model'; +import {RecursiveSpy, spyOnAllFunctions} from '../operations/executors/spy_ops'; const HOST = 'http://example.org'; const MODEL_URL = `${HOST}/model.json`; @@ -368,6 +369,12 @@ describe('loadSync', () => { }); describe('loadGraphModel', () => { + let spyIo: RecursiveSpy; + + beforeEach(() => { + spyIo = spyOnAllFunctions(io); + }); + it('Pass a custom io handler', async () => { const customLoader: tfc.io.IOHandler = { load: async () => { @@ -397,11 +404,11 @@ describe('loadGraphModel', () => { it('Pass a fetchFunc', async () => { const fetchFunc = () => {}; - spyOn(tfc.io, 'getLoadHandlers').and.returnValue([ + spyIo.getLoadHandlers.and.returnValue([ CUSTOM_HTTP_MODEL_LOADER ]); - await loadGraphModel(MODEL_URL, {fetchFunc}); - expect(tfc.io.getLoadHandlers).toHaveBeenCalledWith(MODEL_URL, {fetchFunc}); + await loadGraphModel(MODEL_URL, {fetchFunc}, spyIo); + expect(spyIo.getLoadHandlers).toHaveBeenCalledWith(MODEL_URL, {fetchFunc}); }); }); @@ -436,13 +443,16 @@ describe('loadGraphModelSync', () => { }); describe('Model', () => { + let spyIo: RecursiveSpy; + beforeEach(() => { - model = new GraphModel(MODEL_URL); + spyIo = spyOnAllFunctions(io); + model = new GraphModel(MODEL_URL, undefined, spyIo); }); describe('custom model', () => { beforeEach(() => { - spyOn(tfc.io, 'getLoadHandlers').and.returnValue([ + spyIo.getLoadHandlers.and.returnValue([ CUSTOM_HTTP_MODEL_LOADER ]); registerOp('CustomOp', (nodeValue: GraphNode) => { @@ -484,11 +494,10 @@ describe('Model', () => { describe('simple model', () => { beforeEach(() => { - spyOn(tfc.io, 'getLoadHandlers').and.returnValue([ + spyIo.getLoadHandlers.and.returnValue([ SIMPLE_HTTP_MODEL_LOADER ]); - spyOn(tfc.io, 'browserHTTPRequest') - .and.returnValue(SIMPLE_HTTP_MODEL_LOADER); + spyIo.browserHTTPRequest.and.returnValue(SIMPLE_HTTP_MODEL_LOADER); }); it('load', async () => { const loaded = await model.load(); @@ -621,7 +630,7 @@ describe('Model', () => { describe('dispose', () => { it('should dispose the weights', async () => { const numOfTensors = tfc.memory().numTensors; - model = new GraphModel(MODEL_URL); + model = new GraphModel(MODEL_URL, undefined, spyIo); await model.load(); model.dispose(); @@ -639,7 +648,7 @@ describe('Model', () => { describe('relative path', () => { beforeEach(() => { - model = new GraphModel(RELATIVE_MODEL_URL); + model = new GraphModel(RELATIVE_MODEL_URL, undefined, spyIo); }); it('load', async () => { @@ -649,14 +658,14 @@ describe('Model', () => { }); it('should loadGraphModel', async () => { - const model = await loadGraphModel(MODEL_URL); + const model = await loadGraphModel(MODEL_URL, undefined, spyIo); expect(model).not.toBeUndefined(); }); it('should loadGraphModel with request options', async () => { const model = await loadGraphModel( - MODEL_URL, {requestInit: {credentials: 'include'}}); - expect(tfc.io.browserHTTPRequest).toHaveBeenCalledWith(MODEL_URL, { + MODEL_URL, {requestInit: {credentials: 'include'}}, spyIo); + expect(spyIo.browserHTTPRequest).toHaveBeenCalledWith(MODEL_URL, { requestInit: {credentials: 'include'} }); expect(model).not.toBeUndefined(); @@ -664,7 +673,7 @@ describe('Model', () => { it('should call loadGraphModel for TfHub Module', async () => { const url = `${HOST}/model/1`; - const model = await loadGraphModel(url, {fromTFHub: true}); + const model = await loadGraphModel(url, {fromTFHub: true}, spyIo); expect(model).toBeDefined(); }); @@ -686,11 +695,10 @@ describe('Model', () => { describe('control flow model', () => { beforeEach(() => { - spyOn(tfc.io, 'getLoadHandlers').and.returnValue([ + spyIo.getLoadHandlers.and.returnValue([ CONTROL_FLOW_HTTP_MODEL_LOADER ]); - spyOn(tfc.io, 'browserHTTPRequest') - .and.returnValue(CONTROL_FLOW_HTTP_MODEL_LOADER); + spyIo.browserHTTPRequest.and.returnValue(CONTROL_FLOW_HTTP_MODEL_LOADER); }); describe('save', () => { @@ -777,11 +785,10 @@ describe('Model', () => { }; describe('dynamic shape model', () => { beforeEach(() => { - spyOn(tfc.io, 'getLoadHandlers').and.returnValue([ + spyIo.getLoadHandlers.and.returnValue([ DYNAMIC_HTTP_MODEL_LOADER ]); - spyOn(tfc.io, 'browserHTTPRequest') - .and.returnValue(DYNAMIC_HTTP_MODEL_LOADER); + spyIo.browserHTTPRequest.and.returnValue(DYNAMIC_HTTP_MODEL_LOADER); }); it('should throw error if call predict directly', async () => { @@ -822,11 +829,10 @@ describe('Model', () => { }); describe('dynamic shape model with metadata', () => { beforeEach(() => { - spyOn(tfc.io, 'getLoadHandlers').and.returnValue([ + spyIo.getLoadHandlers.and.returnValue([ DYNAMIC_HTTP_MODEL_NEW_LOADER ]); - spyOn(tfc.io, 'browserHTTPRequest') - .and.returnValue(DYNAMIC_HTTP_MODEL_NEW_LOADER); + spyIo.browserHTTPRequest.and.returnValue(DYNAMIC_HTTP_MODEL_NEW_LOADER); }); it('should be success if call executeAsync with signature key', @@ -848,11 +854,10 @@ describe('Model', () => { describe('Hashtable model', () => { beforeEach(() => { - spyOn(tfc.io, 'getLoadHandlers').and.returnValue([ + spyIo.getLoadHandlers.and.returnValue([ HASHTABLE_HTTP_MODEL_LOADER ]); - spyOn(tfc.io, 'browserHTTPRequest') - .and.returnValue(HASHTABLE_HTTP_MODEL_LOADER); + spyIo.browserHTTPRequest.and.returnValue(HASHTABLE_HTTP_MODEL_LOADER); }); it('should be successful if call executeAsync', async () => { await model.load(); diff --git a/tfjs-converter/src/operations/executors/arithmetic_executor.ts b/tfjs-converter/src/operations/executors/arithmetic_executor.ts index e69efeca892..35e8f590bfd 100644 --- a/tfjs-converter/src/operations/executors/arithmetic_executor.ts +++ b/tfjs-converter/src/operations/executors/arithmetic_executor.ts @@ -27,66 +27,66 @@ import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Tensor[] => { + context: ExecutionContext, ops = tfOps): Tensor[] => { switch (node.op) { case 'BiasAdd': case 'AddV2': case 'Add': { - return [tfOps.add( + return [ops.add( (getParamValue('a', node, tensorMap, context) as Tensor), getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'AddN': { - return [tfOps.addN(( + return [ops.addN(( getParamValue('tensors', node, tensorMap, context) as Tensor[]))]; } case 'FloorMod': case 'Mod': - return [tfOps.mod( + return [ops.mod( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; case 'Mul': - return [tfOps.mul( + return [ops.mul( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; case 'RealDiv': case 'Div': { - return [tfOps.div( + return [ops.div( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'DivNoNan': { - return [tfOps.divNoNan( + return [ops.divNoNan( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'FloorDiv': { - return [tfOps.floorDiv( + return [ops.floorDiv( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'Sub': { - return [tfOps.sub( + return [ops.sub( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'Minimum': { - return [tfOps.minimum( + return [ops.minimum( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'Maximum': { - return [tfOps.maximum( + return [ops.maximum( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'Pow': { - return [tfOps.pow( + return [ops.pow( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'SquaredDifference': { - return [tfOps.squaredDifference( + return [ops.squaredDifference( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } diff --git a/tfjs-converter/src/operations/executors/arithmetic_executor_test.ts b/tfjs-converter/src/operations/executors/arithmetic_executor_test.ts index eb23a1d2217..1ae5156a687 100644 --- a/tfjs-converter/src/operations/executors/arithmetic_executor_test.ts +++ b/tfjs-converter/src/operations/executors/arithmetic_executor_test.ts @@ -23,7 +23,8 @@ import {ExecutionContext} from '../../executor/execution_context'; import {Node} from '../types'; import {executeOp} from './arithmetic_executor'; -import {createTensorAttr, createTensorsAttr} from './test_helper'; +import {createTensorAttr, createTensorsAttr, uncapitalize} from './test_helper'; +import {RecursiveSpy, spyOnAllFunctions} from './spy_ops'; describe('arithmetic', () => { let node: Node; @@ -46,37 +47,47 @@ describe('arithmetic', () => { }); describe('executeOp', () => { - ['Add', 'Mul', 'Div', 'Sub', 'Maximum', 'Minimum', 'Pow', - 'SquaredDifference', 'Mod', 'FloorDiv', 'DivNoNan'] - .forEach((op => { - it('should call tfOps.' + op, () => { - const spy = - spyOn(tfOps, op.charAt(0).toLowerCase() + op.slice(1) as 'add'); - node.op = op; - executeOp(node, {input1, input2}, context); + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; - expect(spy).toHaveBeenCalledWith(input1[0], input2[0]); - }); - })); + beforeEach(() => { + spyOps = spyOnAllFunctions(tfOps); + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; + }); + + (['Add', 'Mul', 'Div', 'Sub', 'Maximum', 'Minimum', 'Pow', + 'SquaredDifference', 'Mod', 'FloorDiv', 'DivNoNan'] as const) + .forEach((op => { + it('should call tfOps.' + op, () => { + node.op = op; + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); + + // TODO(mattsoulanille): Remove type assertion after TS4 + expect(spyOps[uncapitalize(op) as keyof typeof spyOps]) + .toHaveBeenCalledWith(input1[0], input2[0]); + }); + })); it('AddV2', async () => { - const spy = spyOn(tfOps, 'add').and.callThrough(); + node.op = 'AddV2'; - const res = executeOp(node, {input1, input2}, context) as Tensor[]; - expect(spy).toHaveBeenCalledWith(input1[0], input2[0]); + const res = executeOp(node, {input1, input2}, context, + spyOpsAsTfOps) as Tensor[]; + expect(spyOps.add).toHaveBeenCalledWith(input1[0], input2[0]); expect(res[0].dtype).toBe('float32'); expect(res[0].shape).toEqual([]); test_util.expectArraysClose(await res[0].data(), 2); }); it('AddN', async () => { - const spy = spyOn(tfOps, 'addN').and.callThrough(); node.op = 'AddN'; node.inputParams = {tensors: createTensorsAttr(0, 0)}; node.inputNames = ['input1', 'input2', 'input3']; const res = - executeOp(node, {input1, input2, input3}, context) as Tensor[]; - expect(spy).toHaveBeenCalledWith([input1[0], input2[0], input3[0]]); + executeOp(node, {input1, input2, input3}, context, + spyOpsAsTfOps) as Tensor[]; + expect(spyOps.addN) + .toHaveBeenCalledWith([input1[0], input2[0], input3[0]]); expect(res[0].dtype).toBe('float32'); expect(res[0].shape).toEqual([]); test_util.expectArraysClose(await res[0].data(), [6]); diff --git a/tfjs-converter/src/operations/executors/basic_math_executor.ts b/tfjs-converter/src/operations/executors/basic_math_executor.ts index 0751d95a2aa..6b0db9aa173 100644 --- a/tfjs-converter/src/operations/executors/basic_math_executor.ts +++ b/tfjs-converter/src/operations/executors/basic_math_executor.ts @@ -27,153 +27,153 @@ import {getParamValue, getTensor} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Tensor[] => { + context: ExecutionContext, ops = tfOps): Tensor[] => { switch (node.op) { case 'Abs': case 'ComplexAbs': - return [tfOps.abs( + return [ops.abs( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Acos': - return [tfOps.acos( + return [ops.acos( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Acosh': - return [tfOps.acosh( + return [ops.acosh( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Asin': - return [tfOps.asin( + return [ops.asin( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Asinh': - return [tfOps.asinh( + return [ops.asinh( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Atan': - return [tfOps.atan( + return [ops.atan( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Atan2': - return [tfOps.atan2( + return [ops.atan2( getParamValue('x', node, tensorMap, context) as Tensor, getParamValue('y', node, tensorMap, context) as Tensor)]; case 'Atanh': - return [tfOps.atanh( + return [ops.atanh( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Ceil': - return [tfOps.ceil( + return [ops.ceil( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Complex': - return [tfOps.complex( + return [ops.complex( getParamValue('real', node, tensorMap, context) as Tensor, getParamValue('imag', node, tensorMap, context) as Tensor)]; case 'Cos': - return [tfOps.cos( + return [ops.cos( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Cosh': - return [tfOps.cosh( + return [ops.cosh( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Elu': - return [tfOps.elu( + return [ops.elu( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Erf': - return [tfOps.erf( + return [ops.erf( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Exp': - return [tfOps.exp( + return [ops.exp( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Expm1': { - return [tfOps.expm1( + return [ops.expm1( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'Floor': - return [tfOps.floor( + return [ops.floor( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Log': - return [tfOps.log( + return [ops.log( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Log1p': { - return [tfOps.log1p( + return [ops.log1p( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'Imag': - return [tfOps.imag( + return [ops.imag( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Neg': - return [tfOps.neg( + return [ops.neg( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Reciprocal': { - return [tfOps.reciprocal( + return [ops.reciprocal( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'Real': - return [tfOps.real( + return [ops.real( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Relu': - return [tfOps.relu( + return [ops.relu( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Round': { - return [tfOps.round( + return [ops.round( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'Selu': - return [tfOps.selu( + return [ops.selu( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Sigmoid': - return [tfOps.sigmoid( + return [ops.sigmoid( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Sin': - return [tfOps.sin( + return [ops.sin( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Sign': { - return [tfOps.sign( + return [ops.sign( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'Sinh': { - return [tfOps.sinh( + return [ops.sinh( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'Softplus': { - return [tfOps.softplus( + return [ops.softplus( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'Sqrt': { - return [tfOps.sqrt( + return [ops.sqrt( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'Square': { - return [tfOps.square( + return [ops.square( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'Tanh': { - return [tfOps.tanh( + return [ops.tanh( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'Tan': - return [tfOps.tan( + return [ops.tan( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'ClipByValue': - return [tfOps.clipByValue( + return [ops.clipByValue( getParamValue('x', node, tensorMap, context) as Tensor, getParamValue('clipValueMin', node, tensorMap, context) as number, getParamValue('clipValueMax', node, tensorMap, context) as number)]; case 'Relu6': - return [tfOps.relu6( + return [ops.relu6( getParamValue('x', node, tensorMap, context) as Tensor)]; case 'Rsqrt': - return [tfOps.rsqrt( + return [ops.rsqrt( getTensor(node.inputNames[0], tensorMap, context))]; case 'Prod': - return [tfOps.prod( + return [ops.prod( getParamValue('x', node, tensorMap, context) as Tensor, getParamValue('axes', node, tensorMap, context) as number[])]; case 'LeakyRelu': - return [tfOps.leakyRelu( + return [ops.leakyRelu( getParamValue('x', node, tensorMap, context) as Tensor, getParamValue('alpha', node, tensorMap, context) as number)]; case 'Prelu': - return [tfOps.prelu( + return [ops.prelu( getParamValue('x', node, tensorMap, context) as Tensor, getParamValue('alpha', node, tensorMap, context) as Tensor)]; case 'IsNan': - return [tfOps.isNaN( + return [ops.isNaN( getTensor(node.inputNames[0], tensorMap, context))]; default: throw TypeError(`Node type ${node.op} is not implemented`); diff --git a/tfjs-converter/src/operations/executors/basic_math_executor_test.ts b/tfjs-converter/src/operations/executors/basic_math_executor_test.ts index 124b804e8a0..4ce2e960247 100644 --- a/tfjs-converter/src/operations/executors/basic_math_executor_test.ts +++ b/tfjs-converter/src/operations/executors/basic_math_executor_test.ts @@ -23,7 +23,8 @@ import * as basic_math from '../op_list/basic_math'; import {Node} from '../types'; import {executeOp} from './basic_math_executor'; -import {createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createTensorAttr, validateParam} from './test_helper'; +import {RecursiveSpy, spyOnAllFunctions} from './spy_ops'; +import {createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createTensorAttr, uncapitalize, validateParam} from './test_helper'; describe('basic math', () => { let node: Node; @@ -44,18 +45,33 @@ describe('basic math', () => { }); describe('executeOp', () => { - ['Abs', 'Acos', 'Asin', 'Atan', 'Ceil', 'Cos', 'Cosh', 'Elu', 'Exp', - 'Floor', 'Log', 'Imag', 'Neg', 'Real', 'Relu', 'Selu', 'Sigmoid', 'Sin', - 'Sinh', 'Sqrt', 'Square', 'Tanh', 'Tan', 'Sign', 'Round', 'Expm1', 'Log1p', - 'Reciprocal', 'Softplus', 'Asinh', 'Acosh', 'Atanh', 'Erf'] + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; + + beforeEach(() => { + spyOps = spyOnAllFunctions(tfOps); + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; + }); + + ([ + 'Abs', 'Acos', 'Asin', 'Atan', 'Ceil', 'Cos', 'Cosh', + 'Elu', 'Exp', 'Floor', 'Log', 'Imag', 'Neg', 'Real', + 'Relu', 'Selu', 'Sigmoid', 'Sin', 'Sinh', 'Sqrt', 'Square', + 'Tanh', 'Tan', 'Sign', 'Round', 'Expm1', 'Log1p', 'Reciprocal', + 'Softplus', 'Asinh', 'Acosh', 'Atanh', 'Erf' + ] as const ) .forEach(op => { it('should call tfOps.' + op, () => { - const spy = - spyOn(tfOps, op.charAt(0).toLowerCase() + op.slice(1) as 'abs'); node.op = op; - executeOp(node, {input1}, context); - - expect(spy).toHaveBeenCalledWith(input1[0]); + // TODO(mattsoulanille): Remove type assertion after TS4 + // tslint:disable-next-line no-any + (spyOps[uncapitalize(op) as keyof typeof spyOps] as any) + .and.returnValue({}); + executeOp(node, {input1}, context, spyOpsAsTfOps); + + // TODO(mattsoulanille): Remove type assertion after TS4 + expect(spyOps[uncapitalize(op) as keyof typeof spyOps]) + .toHaveBeenCalledWith(input1[0]); }); it('should match op def', () => { node.op = op; @@ -65,12 +81,11 @@ describe('basic math', () => { }); describe('Relu6', () => { it('should call tfOps.relu6', () => { - spyOn(tfOps, 'relu6'); node.op = 'Relu6'; - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.relu6).toHaveBeenCalledWith(input1[0]); + expect(spyOps.relu6).toHaveBeenCalledWith(input1[0]); }); it('should match op def', () => { node.op = 'Relu6'; @@ -80,16 +95,15 @@ describe('basic math', () => { }); describe('ClipByValue', () => { it('should call tfOps.clipByValue', () => { - spyOn(tfOps, 'clipByValue'); node.op = 'ClipByValue'; node.inputNames = ['input1', 'input2', 'input3']; node.inputParams['clipValueMin'] = createNumberAttrFromIndex(1); node.inputParams['clipValueMax'] = createNumberAttrFromIndex(2); const input2 = [tfOps.scalar(2)]; const input3 = [tfOps.scalar(3)]; - executeOp(node, {input1, input2, input3}, context); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.clipByValue).toHaveBeenCalledWith(input1[0], 2, 3); + expect(spyOps.clipByValue).toHaveBeenCalledWith(input1[0], 2, 3); }); it('should match op def', () => { node.op = 'ClipByValue'; @@ -101,14 +115,14 @@ describe('basic math', () => { }); describe('Prod', () => { it('should call tfOps.prod', () => { - spyOn(tfOps, 'prod'); node.op = 'Prod'; node.inputParams['axes'] = createNumericArrayAttrFromIndex(1); node.inputNames = ['input1', 'input2']; const input2 = [tfOps.tensor1d([2])]; - executeOp(node, {input1, input2}, context); + spyOps.prod.and.returnValue({}); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.prod).toHaveBeenCalledWith(input1[0], [2]); + expect(spyOps.prod).toHaveBeenCalledWith(input1[0], [2]); }); it('should match op def', () => { node.op = 'Prod'; @@ -121,10 +135,9 @@ describe('basic math', () => { it('should call tfOps.rsqrt', () => { const input1 = [tfOps.scalar(1)]; node.op = 'Rsqrt'; - spyOn(tfOps, 'rsqrt').and.returnValue(input1); - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.rsqrt).toHaveBeenCalledWith(input1[0]); + expect(spyOps.rsqrt).toHaveBeenCalledWith(input1[0]); }); it('should match op def', () => { node.op = 'Rsqrt'; @@ -134,13 +147,12 @@ describe('basic math', () => { }); describe('LeakyRelu', () => { it('should call tfOps.leakyRelu', () => { - spyOn(tfOps, 'leakyRelu'); node.op = 'LeakyRelu'; node.attrParams['alpha'] = createNumberAttr(1); node.inputNames = ['input1']; - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.leakyRelu).toHaveBeenCalledWith(input1[0], 1); + expect(spyOps.leakyRelu).toHaveBeenCalledWith(input1[0], 1); }); it('should match op def', () => { node.op = 'LeakyRelu'; @@ -150,15 +162,14 @@ describe('basic math', () => { }); describe('Prelu', () => { it('should call tfOps.Prelu', () => { - spyOn(tfOps, 'prelu'); node.op = 'Prelu'; node.inputParams['x'] = createTensorAttr(0); node.inputParams['alpha'] = createTensorAttr(1); node.inputNames = ['input1', 'input2']; const input2 = [tfOps.scalar(1)]; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.prelu).toHaveBeenCalledWith(input1[0], input2[0]); + expect(spyOps.prelu).toHaveBeenCalledWith(input1[0], input2[0]); }); it('should match op def', () => { node.op = 'Prelu'; @@ -169,14 +180,13 @@ describe('basic math', () => { }); describe('Atan2', () => { it('should call tfOps.atan2', () => { - spyOn(tfOps, 'atan2'); node.op = 'Atan2'; node.inputParams['y'] = createTensorAttr(1); node.inputNames = ['input1', 'input2']; const input2 = [tfOps.scalar(2)]; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.atan2).toHaveBeenCalledWith(input1[0], input2[0]); + expect(spyOps.atan2).toHaveBeenCalledWith(input1[0], input2[0]); }); it('should match op def', () => { node.op = 'Atan2'; @@ -187,12 +197,11 @@ describe('basic math', () => { }); describe('ComplexAbs', () => { it('should call tfOps.abs', () => { - spyOn(tfOps, 'abs'); node.op = 'ComplexAbs'; node.inputNames = ['input1']; - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.abs).toHaveBeenCalledWith(input1[0]); + expect(spyOps.abs).toHaveBeenCalledWith(input1[0]); }); it('should match op def', () => { node.op = 'ComplexAbs'; @@ -202,7 +211,6 @@ describe('basic math', () => { }); describe('Complex', () => { it('should call tfOps.complex', () => { - spyOn(tfOps, 'complex'); node.op = 'Complex'; node.inputParams = { real: createTensorAttr(0), @@ -210,9 +218,9 @@ describe('basic math', () => { }; const input2 = [tfOps.scalar(2)]; node.inputNames = ['input1', 'input2']; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.complex).toHaveBeenCalledWith(input1[0], input2[0]); + expect(spyOps.complex).toHaveBeenCalledWith(input1[0], input2[0]); }); it('should match op def', () => { node.op = 'Complex'; @@ -226,12 +234,11 @@ describe('basic math', () => { }); describe('IsNan', () => { it('should call tfOps.isNaN', () => { - spyOn(tfOps, 'isNaN'); node.op = 'IsNan'; - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.isNaN).toHaveBeenCalledWith(input1[0]); + expect(spyOps.isNaN).toHaveBeenCalledWith(input1[0]); }); it('should match op def', () => { node.op = 'IsNan'; diff --git a/tfjs-converter/src/operations/executors/convolution_executor.ts b/tfjs-converter/src/operations/executors/convolution_executor.ts index 2d33ff14036..18658dc0636 100644 --- a/tfjs-converter/src/operations/executors/convolution_executor.ts +++ b/tfjs-converter/src/operations/executors/convolution_executor.ts @@ -83,7 +83,7 @@ function fusedConvAndDepthWiseParams( export const executeOp: InternalOpExecutor = (node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Tensor[] => { + context: ExecutionContext, ops = tfOps): Tensor[] => { switch (node.op) { case 'Conv1D': { const stride = @@ -94,7 +94,7 @@ export const executeOp: InternalOpExecutor = .toUpperCase(); const dilation = getParamValue('dilation', node, tensorMap, context) as number; - return [tfOps.conv1d( + return [ops.conv1d( getParamValue('x', node, tensorMap, context) as Tensor3D, getParamValue('filter', node, tensorMap, context) as Tensor3D, stride, pad as 'valid' | 'same', dataFormat as 'NWC' | 'NCW', @@ -109,7 +109,7 @@ export const executeOp: InternalOpExecutor = .toUpperCase(); const dilations = getParamValue('dilations', node, tensorMap, context) as number[]; - return [tfOps.conv2d( + return [ops.conv2d( getParamValue('x', node, tensorMap, context) as Tensor3D | Tensor4D, getParamValue('filter', node, tensorMap, context) as Tensor4D, @@ -128,7 +128,7 @@ export const executeOp: InternalOpExecutor = leakyreluAlpha } = fusedConvAndDepthWiseParams(node, tensorMap, context); - return [tfOps.fused.conv2d({ + return [ops.fused.conv2d({ x: getParamValue('x', node, tensorMap, context) as Tensor3D | Tensor4D, filter: getParamValue('filter', node, tensorMap, context) as @@ -156,7 +156,7 @@ export const executeOp: InternalOpExecutor = leakyreluAlpha, } = fusedConvAndDepthWiseParams(node, tensorMap, context); - return [tfOps.fused.depthwiseConv2d({ + return [ops.fused.depthwiseConv2d({ x: getParamValue('x', node, tensorMap, context) as Tensor3D | Tensor4D, filter: getParamValue('filter', node, tensorMap, context) as @@ -180,7 +180,7 @@ export const executeOp: InternalOpExecutor = const stride = getParamValue('strides', node, tensorMap, context) as number[]; const pad = getPadding(node, tensorMap, context); - return [tfOps.conv2dTranspose( + return [ops.conv2dTranspose( getParamValue('x', node, tensorMap, context) as Tensor3D | Tensor4D, getParamValue('filter', node, tensorMap, context) as Tensor4D, @@ -197,7 +197,7 @@ export const executeOp: InternalOpExecutor = (getParamValue('dataFormat', node, tensorMap, context) as string) .toUpperCase(); - return [tfOps.depthwiseConv2d( + return [ops.depthwiseConv2d( getParamValue('input', node, tensorMap, context) as Tensor3D | Tensor4D, getParamValue('filter', node, tensorMap, context) as Tensor4D, @@ -213,7 +213,7 @@ export const executeOp: InternalOpExecutor = .toUpperCase(); const dilations = getParamValue('dilations', node, tensorMap, context) as number[]; - return [tfOps.conv3d( + return [ops.conv3d( getParamValue('x', node, tensorMap, context) as Tensor4D | Tensor, getParamValue('filter', node, tensorMap, context) as @@ -229,7 +229,7 @@ export const executeOp: InternalOpExecutor = const kernelSize = getParamValue('kernelSize', node, tensorMap, context) as number[]; - return [tfOps.avgPool( + return [ops.avgPool( getParamValue('x', node, tensorMap, context) as Tensor3D | Tensor4D, [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], @@ -242,7 +242,7 @@ export const executeOp: InternalOpExecutor = const kernelSize = getParamValue('kernelSize', node, tensorMap, context) as number[]; - return [tfOps.maxPool( + return [ops.maxPool( getParamValue('x', node, tensorMap, context) as Tensor3D | Tensor4D, [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], @@ -257,7 +257,7 @@ export const executeOp: InternalOpExecutor = const includeBatchInIndex = getParamValue('includeBatchInIndex', node, tensorMap, context) as boolean; - const {result, indexes} = tfOps.maxPoolWithArgmax( + const {result, indexes} = ops.maxPoolWithArgmax( getParamValue('x', node, tensorMap, context) as Tensor4D, [kernelSize[1], kernelSize[2]], [stride[1], stride[2]], pad as 'valid' | 'same', includeBatchInIndex); @@ -270,7 +270,7 @@ export const executeOp: InternalOpExecutor = const kernelSize = getParamValue('kernelSize', node, tensorMap, context) as number[]; - return [tfOps.avgPool3d( + return [ops.avgPool3d( getParamValue('x', node, tensorMap, context) as Tensor5D, [kernelSize[1], kernelSize[2], kernelSize[3]], [stride[1], stride[2], stride[3]], pad as 'valid' | 'same')]; @@ -283,7 +283,7 @@ export const executeOp: InternalOpExecutor = const kernelSize = getParamValue('kernelSize', node, tensorMap, context) as number[]; - return [tfOps.maxPool3d( + return [ops.maxPool3d( getParamValue('x', node, tensorMap, context) as Tensor5D, [kernelSize[1], kernelSize[2], kernelSize[3]], [stride[1], stride[2], stride[3]], pad as 'valid' | 'same')]; @@ -304,7 +304,7 @@ export const executeOp: InternalOpExecutor = const dilationHeight = dilations[1]; const dilationWidth = dilations[2]; - return [tfOps.dilation2d( + return [ops.dilation2d( getParamValue('x', node, tensorMap, context) as Tensor3D | Tensor4D, getParamValue('filter', node, tensorMap, context) as Tensor3D, diff --git a/tfjs-converter/src/operations/executors/convolution_executor_test.ts b/tfjs-converter/src/operations/executors/convolution_executor_test.ts index 65553cd4542..ef3c3c8e3e6 100644 --- a/tfjs-converter/src/operations/executors/convolution_executor_test.ts +++ b/tfjs-converter/src/operations/executors/convolution_executor_test.ts @@ -21,6 +21,7 @@ import {ExecutionContext} from '../../executor/execution_context'; import {Node} from '../types'; import {executeOp} from './convolution_executor'; +import {RecursiveSpy} from './spy_ops'; import {createNumberAttr, createNumericArrayAttr, createStrArrayAttr, createStrAttr, createTensorAttr, createTensorsAttr} from './test_helper'; import {createBoolAttr} from './test_helper'; @@ -29,6 +30,9 @@ describe('convolution', () => { const input = [tfOps.scalar(1)]; const context = new ExecutionContext({}, {}, {}); + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; + beforeEach(() => { node = { name: 'test', @@ -40,41 +44,53 @@ describe('convolution', () => { attrParams: {}, children: [] }; + spyOps = + Object.fromEntries(Object.keys(tfOps).map((op: keyof typeof tfOps) => { + if (op === 'fused') { + return [ + op, { + conv2d: jasmine.createSpy(op), + depthwiseConv2d: jasmine.createSpy(op), + matMul: jasmine.createSpy(op), + } + ]; + } + const spy = jasmine.createSpy(op); + return [op, spy] as const ; + })) as unknown as typeof spyOps; + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; }); describe('executeOp', () => { describe('AvgPool', () => { it('should call tfOps.avgPool', () => { - spyOn(tfOps, 'avgPool'); node.op = 'AvgPool'; node.attrParams['strides'] = createNumericArrayAttr([1, 2, 2, 1]); node.attrParams['pad'] = createStrAttr('same'); node.attrParams['kernelSize'] = createNumericArrayAttr([1, 2, 2, 1]); - executeOp(node, {input}, context); + executeOp(node, {input}, context, spyOpsAsTfOps); - expect(tfOps.avgPool) + expect(spyOps.avgPool) .toHaveBeenCalledWith(input[0], [2, 2], [2, 2], 'same'); }); }); describe('maxPool', () => { it('should call tfOps.maxPool', () => { - spyOn(tfOps, 'maxPool'); node.op = 'MaxPool'; node.attrParams['strides'] = createNumericArrayAttr([1, 2, 2, 1]); node.attrParams['pad'] = createStrAttr('same'); node.attrParams['kernelSize'] = createNumericArrayAttr([1, 2, 2, 1]); - executeOp(node, {input}, context); + executeOp(node, {input}, context, spyOpsAsTfOps); - expect(tfOps.maxPool) + expect(spyOps.maxPool) .toHaveBeenCalledWith(input[0], [2, 2], [2, 2], 'same'); }); }); describe('Conv2d', () => { it('should call tfOps.conv2d', () => { - spyOn(tfOps, 'conv2d'); node.op = 'Conv2D'; node.inputParams['filter'] = createTensorAttr(1); node.attrParams['strides'] = createNumericArrayAttr([1, 2, 2, 1]); @@ -86,14 +102,13 @@ describe('convolution', () => { const input2 = [tfOps.scalar(1.0)]; node.inputNames = ['input1', 'input2']; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.conv2d) + expect(spyOps.conv2d) .toHaveBeenCalledWith( input1[0], input2[0], [2, 2], 'same', 'NHWC', [2, 2]); }); it('should support explicit padding', () => { - spyOn(tfOps, 'conv2d'); node.op = 'Conv2D'; node.inputParams['filter'] = createTensorAttr(1); node.attrParams['strides'] = createNumericArrayAttr([1, 2, 2, 1]); @@ -107,9 +122,9 @@ describe('convolution', () => { const input2 = [tfOps.scalar(1.0)]; node.inputNames = ['input1', 'input2']; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.conv2d) + expect(spyOps.conv2d) .toHaveBeenCalledWith( input1[0], input2[0], [2, 2], [[0, 0], [1, 1], [2, 2], [0, 0]], 'NHWC', [2, 2]); @@ -117,7 +132,6 @@ describe('convolution', () => { }); describe('Conv2DBackpropInput', () => { it('should call tfOps.conv2dTranspose', () => { - spyOn(tfOps, 'conv2dTranspose'); node.op = 'Conv2DBackpropInput'; node.attrParams['outputShape'] = createNumericArrayAttr([1, 2, 2, 2]); node.inputParams['filter'] = createTensorAttr(1); @@ -128,14 +142,13 @@ describe('convolution', () => { const input2 = [tfOps.scalar(1.0)]; node.inputNames = ['input1', 'input2']; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.conv2dTranspose) + expect(spyOps.conv2dTranspose) .toHaveBeenCalledWith( input1[0], input2[0], [1, 2, 2, 2], [2, 2], 'same'); }); it('should support explicit padding', () => { - spyOn(tfOps, 'conv2dTranspose'); node.op = 'Conv2DBackpropInput'; node.attrParams['outputShape'] = createNumericArrayAttr([1, 2, 2, 2]); node.inputParams['filter'] = createTensorAttr(1); @@ -148,9 +161,9 @@ describe('convolution', () => { const input2 = [tfOps.scalar(1.0)]; node.inputNames = ['input1', 'input2']; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.conv2dTranspose) + expect(spyOps.conv2dTranspose) .toHaveBeenCalledWith( input1[0], input2[0], @@ -162,7 +175,6 @@ describe('convolution', () => { }); describe('Conv1D', () => { it('should call tfOps.conv1d', () => { - spyOn(tfOps, 'conv1d'); node.op = 'Conv1D'; node.category = 'convolution'; node.inputParams['filter'] = createTensorAttr(1); @@ -175,16 +187,15 @@ describe('convolution', () => { const input2 = [tfOps.scalar(1.0)]; node.inputNames = ['input1', 'input2']; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.conv1d) + expect(spyOps.conv1d) .toHaveBeenCalledWith(input1[0], input2[0], 1, 'same', 'NWC', 1); }); }); describe('DepthwiseConv2d', () => { it('should call tfOps.depthwiseConv2d', () => { - spyOn(tfOps, 'depthwiseConv2d'); node.op = 'DepthwiseConv2d'; node.category = 'convolution'; node.inputParams['input'] = createTensorAttr(0); @@ -197,14 +208,13 @@ describe('convolution', () => { const input2 = [tfOps.scalar(1.0)]; node.inputNames = ['input1', 'input2']; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.depthwiseConv2d) + expect(spyOps.depthwiseConv2d) .toHaveBeenCalledWith( input1[0], input2[0], [2, 2], 'same', 'NHWC', [2, 2]); }); it('support explicit padding', () => { - spyOn(tfOps, 'depthwiseConv2d'); node.op = 'DepthwiseConv2d'; node.category = 'convolution'; node.inputParams['input'] = createTensorAttr(0); @@ -219,9 +229,9 @@ describe('convolution', () => { const input2 = [tfOps.scalar(1.0)]; node.inputNames = ['input1', 'input2']; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.depthwiseConv2d) + expect(spyOps.depthwiseConv2d) .toHaveBeenCalledWith( input1[0], input2[0], [2, 2], [[0, 0], [1, 1], [2, 2], [0, 0]], 'NHWC', [2, 2]); @@ -230,7 +240,6 @@ describe('convolution', () => { describe('Conv3d', () => { it('should call tfOps.conv3d', () => { - spyOn(tfOps, 'conv3d'); node.op = 'Conv3D'; node.category = 'convolution'; node.inputParams['filter'] = createTensorAttr(1); @@ -243,9 +252,9 @@ describe('convolution', () => { const input2 = [tfOps.scalar(1.0)]; node.inputNames = ['input1', 'input2']; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.conv3d) + expect(spyOps.conv3d) .toHaveBeenCalledWith( input1[0], input2[0], [2, 2, 2], 'same', 'NHWC', [2, 2, 2]); }); @@ -253,53 +262,52 @@ describe('convolution', () => { describe('AvgPool3D', () => { it('should call tfOps.avgPool3d', () => { - spyOn(tfOps, 'avgPool3d'); node.op = 'AvgPool3D'; node.attrParams['strides'] = createNumericArrayAttr([1, 2, 2, 2, 1]); node.attrParams['pad'] = createStrAttr('same'); node.attrParams['kernelSize'] = createNumericArrayAttr([1, 2, 2, 2, 1]); - executeOp(node, {input}, context); + executeOp(node, {input}, context, spyOpsAsTfOps); - expect(tfOps.avgPool3d) + expect(spyOps.avgPool3d) .toHaveBeenCalledWith(input[0], [2, 2, 2], [2, 2, 2], 'same'); }); }); describe('MaxPool3D', () => { it('should call tfOps.maxPool3d', () => { - spyOn(tfOps, 'maxPool3d'); node.op = 'MaxPool3D'; node.attrParams['strides'] = createNumericArrayAttr([1, 2, 2, 2, 1]); node.attrParams['pad'] = createStrAttr('same'); node.attrParams['kernelSize'] = createNumericArrayAttr([1, 2, 2, 2, 1]); - executeOp(node, {input}, context); + executeOp(node, {input}, context, spyOpsAsTfOps); - expect(tfOps.maxPool3d) + expect(spyOps.maxPool3d) .toHaveBeenCalledWith(input[0], [2, 2, 2], [2, 2, 2], 'same'); }); }); describe('MaxPoolWithArgmax', () => { it('should call tfOps.maxPoolWithArgmax', () => { - spyOn(tfOps, 'maxPoolWithArgmax').and.returnValue({}); node.op = 'MaxPoolWithArgmax'; node.attrParams['strides'] = createNumericArrayAttr([1, 2, 2, 1]); node.attrParams['pad'] = createStrAttr('same'); node.attrParams['kernelSize'] = createNumericArrayAttr([1, 2, 2, 1]); node.attrParams['dataFormat'] = createStrAttr('NDHWC'); node.attrParams['includeBatchInIndex'] = createBoolAttr(true); - executeOp(node, {input}, context); + spyOps.maxPoolWithArgmax.and.returnValue( + {result: 'fake', indexes: 'fake'}); + + executeOp(node, {input}, context, spyOpsAsTfOps); - expect(tfOps.maxPoolWithArgmax) + expect(spyOps.maxPoolWithArgmax) .toHaveBeenCalledWith(input[0], [2, 2], [2, 2], 'same', true); }); }); describe('_FusedConv2d', () => { it('with bias and activation func', () => { - spyOn(tfOps.fused, 'conv2d'); node.op = '_FusedConv2D'; node.inputParams['filter'] = createTensorAttr(1); node.inputParams['args'] = createTensorsAttr(2, 0); @@ -314,9 +322,9 @@ describe('convolution', () => { const input3 = [tfOps.scalar(3.0)]; node.inputNames = ['input1', 'input2', 'input3']; - executeOp(node, {input1, input2, input3}, context); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.fused.conv2d).toHaveBeenCalledWith({ + expect(spyOps.fused.conv2d).toHaveBeenCalledWith({ x: input1[0], filter: input2[0], strides: [2, 2], @@ -330,7 +338,6 @@ describe('convolution', () => { }); }); it('should support explicit padding', () => { - spyOn(tfOps.fused, 'conv2d'); node.op = '_FusedConv2D'; node.inputParams['filter'] = createTensorAttr(1); node.inputParams['args'] = createTensorsAttr(2, 0); @@ -347,9 +354,9 @@ describe('convolution', () => { const input3 = [tfOps.scalar(3.0)]; node.inputNames = ['input1', 'input2', 'input3']; - executeOp(node, {input1, input2, input3}, context); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.fused.conv2d).toHaveBeenCalledWith({ + expect(spyOps.fused.conv2d).toHaveBeenCalledWith({ x: input1[0], filter: input2[0], strides: [2, 2], @@ -363,7 +370,6 @@ describe('convolution', () => { }); }); it('with bias and prelu activation func', () => { - spyOn(tfOps.fused, 'conv2d'); node.op = '_FusedConv2D'; node.inputParams['filter'] = createTensorAttr(1); node.inputParams['args'] = createTensorsAttr(2, 0); @@ -378,9 +384,10 @@ describe('convolution', () => { const input3 = [tfOps.scalar(3.0)]; const input4 = [tfOps.scalar(4.0)]; node.inputNames = ['input1', 'input2', 'input3', 'input4']; - executeOp(node, {input1, input2, input3, input4}, context); + executeOp( + node, {input1, input2, input3, input4}, context, spyOpsAsTfOps); - expect(tfOps.fused.conv2d).toHaveBeenCalledWith({ + expect(spyOps.fused.conv2d).toHaveBeenCalledWith({ x: input1[0], filter: input2[0], strides: [2, 2], @@ -394,7 +401,6 @@ describe('convolution', () => { }); }); it('with bias and leakyrelu activation func', () => { - spyOn(tfOps.fused, 'conv2d'); node.op = '_FusedConv2D'; node.inputParams['filter'] = createTensorAttr(1); node.inputParams['args'] = createTensorsAttr(2, 0); @@ -410,9 +416,9 @@ describe('convolution', () => { const input2 = [tfOps.scalar(2.0)]; const input3 = [tfOps.scalar(3.0)]; node.inputNames = ['input1', 'input2', 'input3']; - executeOp(node, {input1, input2, input3}, context); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.fused.conv2d).toHaveBeenCalledWith({ + expect(spyOps.fused.conv2d).toHaveBeenCalledWith({ x: input1[0], filter: input2[0], strides: [2, 2], @@ -427,7 +433,6 @@ describe('convolution', () => { }); it('bias add', () => { - spyOn(tfOps.fused, 'conv2d'); node.op = '_FusedConv2D'; node.inputParams['filter'] = createTensorAttr(1); node.inputParams['args'] = createTensorsAttr(2, 0); @@ -442,9 +447,9 @@ describe('convolution', () => { const input3 = [tfOps.scalar(3.0)]; node.inputNames = ['input1', 'input2', 'input3']; - executeOp(node, {input1, input2, input3}, context); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.fused.conv2d).toHaveBeenCalledWith({ + expect(spyOps.fused.conv2d).toHaveBeenCalledWith({ x: input1[0], filter: input2[0], strides: [2, 2], @@ -458,7 +463,6 @@ describe('convolution', () => { }); }); it('fail with batchnorm', () => { - spyOn(tfOps.fused, 'conv2d'); node.op = '_FusedConv2D'; node.inputParams['filter'] = createTensorAttr(1); node.inputParams['args'] = createTensorsAttr(2, 0); @@ -473,14 +477,15 @@ describe('convolution', () => { const input3 = [tfOps.scalar(3.0)]; node.inputNames = ['input1', 'input2', 'input3']; - expect(() => executeOp(node, {input1, input2, input3}, context)) + expect( + () => executeOp( + node, {input1, input2, input3}, context, spyOpsAsTfOps)) .toThrow(); }); }); }); describe('FusedDepthwiseConv2d', () => { it('support explicit padding', () => { - spyOn(tfOps.fused, 'depthwiseConv2d'); node.op = 'FusedDepthwiseConv2dNative'; node.inputParams['filter'] = createTensorAttr(1); node.inputParams['args'] = createTensorsAttr(2, 0); @@ -497,9 +502,9 @@ describe('convolution', () => { const input3 = [tfOps.scalar(3.0)]; node.inputNames = ['input1', 'input2', 'input3']; - executeOp(node, {input1, input2, input3}, context); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.fused.depthwiseConv2d).toHaveBeenCalledWith({ + expect(spyOps.fused.depthwiseConv2d).toHaveBeenCalledWith({ x: input1[0], filter: input2[0], strides: [2, 2], @@ -513,7 +518,6 @@ describe('convolution', () => { }); }); it('with only activation func', () => { - spyOn(tfOps.fused, 'depthwiseConv2d'); node.op = 'FusedDepthwiseConv2dNative'; node.inputParams['filter'] = createTensorAttr(1); node.inputParams['args'] = createTensorsAttr(2, 0); @@ -527,9 +531,9 @@ describe('convolution', () => { const input2 = [tfOps.scalar(2.0)]; const input3 = [tfOps.scalar(3.0)]; node.inputNames = ['input1', 'input2', 'input3']; - executeOp(node, {input1, input2, input3}, context); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.fused.depthwiseConv2d).toHaveBeenCalledWith({ + expect(spyOps.fused.depthwiseConv2d).toHaveBeenCalledWith({ x: input1[0], filter: input2[0], strides: [2, 2], @@ -543,7 +547,6 @@ describe('convolution', () => { }); }); it('with bias and activation func', () => { - spyOn(tfOps.fused, 'depthwiseConv2d'); node.op = 'FusedDepthwiseConv2dNative'; node.inputParams['filter'] = createTensorAttr(1); node.inputParams['args'] = createTensorsAttr(2, 0); @@ -558,9 +561,9 @@ describe('convolution', () => { const input3 = [tfOps.scalar(3.0)]; node.inputNames = ['input1', 'input2', 'input3']; - executeOp(node, {input1, input2, input3}, context); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.fused.depthwiseConv2d).toHaveBeenCalledWith({ + expect(spyOps.fused.depthwiseConv2d).toHaveBeenCalledWith({ x: input1[0], filter: input2[0], strides: [2, 2], @@ -574,7 +577,6 @@ describe('convolution', () => { }); }); it('with bias and prelu activation func', () => { - spyOn(tfOps.fused, 'depthwiseConv2d'); node.op = 'FusedDepthwiseConv2dNative'; node.inputParams['filter'] = createTensorAttr(1); node.inputParams['args'] = createTensorsAttr(2, 0); @@ -589,9 +591,9 @@ describe('convolution', () => { const input3 = [tfOps.scalar(3.0)]; const input4 = [tfOps.scalar(4.0)]; node.inputNames = ['input1', 'input2', 'input3', 'input4']; - executeOp(node, {input1, input2, input3, input4}, context); + executeOp(node, {input1, input2, input3, input4}, context, spyOpsAsTfOps); - expect(tfOps.fused.depthwiseConv2d).toHaveBeenCalledWith({ + expect(spyOps.fused.depthwiseConv2d).toHaveBeenCalledWith({ x: input1[0], filter: input2[0], strides: [2, 2], @@ -605,7 +607,6 @@ describe('convolution', () => { }); }); it('with bias and leakyrelu activation func', () => { - spyOn(tfOps.fused, 'depthwiseConv2d'); node.op = 'FusedDepthwiseConv2dNative'; node.inputParams['filter'] = createTensorAttr(1); node.inputParams['args'] = createTensorsAttr(2, 0); @@ -621,9 +622,9 @@ describe('convolution', () => { const input2 = [tfOps.scalar(2.0)]; const input3 = [tfOps.scalar(3.0)]; node.inputNames = ['input1', 'input2', 'input3']; - executeOp(node, {input1, input2, input3}, context); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.fused.depthwiseConv2d).toHaveBeenCalledWith({ + expect(spyOps.fused.depthwiseConv2d).toHaveBeenCalledWith({ x: input1[0], filter: input2[0], strides: [2, 2], @@ -638,7 +639,6 @@ describe('convolution', () => { }); it('bias add', () => { - spyOn(tfOps.fused, 'depthwiseConv2d'); node.op = 'FusedDepthwiseConv2dNative'; node.inputParams['filter'] = createTensorAttr(1); node.inputParams['args'] = createTensorsAttr(2, 0); @@ -653,9 +653,9 @@ describe('convolution', () => { const input3 = [tfOps.scalar(3.0)]; node.inputNames = ['input1', 'input2', 'input3']; - executeOp(node, {input1, input2, input3}, context); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.fused.depthwiseConv2d).toHaveBeenCalledWith({ + expect(spyOps.fused.depthwiseConv2d).toHaveBeenCalledWith({ x: input1[0], filter: input2[0], strides: [2, 2], @@ -672,7 +672,6 @@ describe('convolution', () => { describe('dilation2d', () => { it('should call tfOps.dilation2d', () => { - spyOn(tfOps, 'dilation2d'); node.op = 'Dilation2D'; node.inputParams['filter'] = createTensorAttr(1); node.attrParams['strides'] = createNumericArrayAttr([1, 1, 1, 1]); @@ -683,9 +682,9 @@ describe('convolution', () => { const input2 = [tfOps.scalar(1.0)]; node.inputNames = ['input1', 'input2']; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.dilation2d) + expect(spyOps.dilation2d) .toHaveBeenCalledWith( input1[0], input2[0], [1, 1], 'same', [2, 2], 'NHWC'); }); diff --git a/tfjs-converter/src/operations/executors/creation_executor.ts b/tfjs-converter/src/operations/executors/creation_executor.ts index 596384b6f42..f4e83345578 100644 --- a/tfjs-converter/src/operations/executors/creation_executor.ts +++ b/tfjs-converter/src/operations/executors/creation_executor.ts @@ -27,7 +27,7 @@ import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Tensor[] => { + context: ExecutionContext, ops = tfOps): Tensor[] => { switch (node.op) { case 'Fill': { const shape = @@ -36,7 +36,7 @@ export const executeOp: InternalOpExecutor = getParamValue('dtype', node, tensorMap, context) as DataType; const value = getParamValue('value', node, tensorMap, context) as number; - return [tfOps.fill(shape, value, dtype)]; + return [ops.fill(shape, value, dtype)]; } case 'LinSpace': { const start = @@ -44,7 +44,7 @@ export const executeOp: InternalOpExecutor = const stop = getParamValue('stop', node, tensorMap, context) as number; const num = getParamValue('num', node, tensorMap, context) as number; - return [tfOps.linspace(start, stop, num)]; + return [ops.linspace(start, stop, num)]; } case 'Multinomial': { const logits = @@ -53,7 +53,7 @@ export const executeOp: InternalOpExecutor = getParamValue('numSamples', node, tensorMap, context) as number; const seed = getParamValue('seed', node, tensorMap, context) as number; - return [tfOps.multinomial(logits, numSamples, seed)]; + return [ops.multinomial(logits, numSamples, seed)]; } case 'OneHot': { const indices = @@ -64,26 +64,26 @@ export const executeOp: InternalOpExecutor = getParamValue('onValue', node, tensorMap, context) as number; const offValue = getParamValue('offValue', node, tensorMap, context) as number; - return [tfOps.oneHot(indices, depth, onValue, offValue)]; + return [ops.oneHot(indices, depth, onValue, offValue)]; } case 'Ones': { - return [tfOps.ones( + return [ops.ones( getParamValue('shape', node, tensorMap, context) as number[], getParamValue('dtype', node, tensorMap, context) as DataType)]; } case 'OnesLike': { - return [tfOps.onesLike( + return [ops.onesLike( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'RandomStandardNormal': { - return [tfOps.randomStandardNormal( + return [ops.randomStandardNormal( getParamValue('shape', node, tensorMap, context) as number[], getParamValue('dtype', node, tensorMap, context) as 'float32' | 'int32', getParamValue('seed', node, tensorMap, context) as number)]; } case 'RandomUniform': { - return [tfOps.randomUniform( + return [ops.randomUniform( // tslint:disable-next-line:no-any getParamValue('shape', node, tensorMap, context) as any, getParamValue('minval', node, tensorMap, context) as number, @@ -97,7 +97,7 @@ export const executeOp: InternalOpExecutor = getParamValue('stop', node, tensorMap, context) as number; const step = getParamValue('step', node, tensorMap, context) as number; - return [tfOps.range( + return [ops.range( start, stop, step, getParamValue('dtype', node, tensorMap, context) as 'float32' | 'int32')]; @@ -111,19 +111,19 @@ export const executeOp: InternalOpExecutor = getParamValue('stdDev', node, tensorMap, context) as number; const seed = getParamValue('seed', node, tensorMap, context) as number; - return [tfOps.truncatedNormal( + return [ops.truncatedNormal( shape, mean, stdDev, getParamValue('dtype', node, tensorMap, context) as 'float32' | 'int32', seed)]; } case 'Zeros': { - return [tfOps.zeros( + return [ops.zeros( getParamValue('shape', node, tensorMap, context) as number[], getParamValue('dtype', node, tensorMap, context) as DataType)]; } case 'ZerosLike': { - return [tfOps.zerosLike( + return [ops.zerosLike( getParamValue('x', node, tensorMap, context) as Tensor)]; } default: diff --git a/tfjs-converter/src/operations/executors/creation_executor_test.ts b/tfjs-converter/src/operations/executors/creation_executor_test.ts index 7e6b3d84667..b9ceefd4708 100644 --- a/tfjs-converter/src/operations/executors/creation_executor_test.ts +++ b/tfjs-converter/src/operations/executors/creation_executor_test.ts @@ -23,14 +23,20 @@ import {Node} from '../types'; import {executeOp} from './creation_executor'; import {createDtypeAttr, createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createTensorAttr, validateParam} from './test_helper'; +import {spyOnAllFunctions, RecursiveSpy} from './spy_ops'; describe('creation', () => { let node: Node; const input1 = [tfOps.tensor1d([1, 2, 3])]; const input2 = [tfOps.scalar(1)]; const context = new ExecutionContext({}, {}, {}); + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; beforeEach(() => { + spyOps = spyOnAllFunctions(tfOps); + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; + node = { name: 'test', op: '', @@ -46,15 +52,14 @@ describe('creation', () => { describe('executeOp', () => { describe('Fill', () => { it('should call tfOps.fill', () => { - spyOn(tfOps, 'fill'); node.op = 'Fill'; node.inputParams['shape'] = createNumericArrayAttrFromIndex(0); node.inputParams['value'] = createNumberAttrFromIndex(1); node.attrParams['dtype'] = createDtypeAttr('int32'); - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.fill).toHaveBeenCalledWith([1, 2, 3], 1, 'int32'); + expect(spyOps.fill).toHaveBeenCalledWith([1, 2, 3], 1, 'int32'); }); it('should match json def', () => { node.op = 'Fill'; @@ -67,7 +72,6 @@ describe('creation', () => { }); describe('LinSpace', () => { it('should call tfOps.linspace', () => { - spyOn(tfOps, 'linspace'); node.op = 'LinSpace'; node.inputParams['start'] = createNumberAttrFromIndex(0); node.inputParams['stop'] = createNumberAttrFromIndex(1); @@ -75,9 +79,9 @@ describe('creation', () => { node.inputNames = ['input', 'input2', 'input3']; const input = [tfOps.scalar(0)]; const input3 = [tfOps.scalar(2)]; - executeOp(node, {input, input2, input3}, context); + executeOp(node, {input, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.linspace).toHaveBeenCalledWith(0, 1, 2); + expect(spyOps.linspace).toHaveBeenCalledWith(0, 1, 2); }); it('should match json def', () => { node.op = 'LinSpace'; @@ -90,7 +94,6 @@ describe('creation', () => { }); describe('OneHot', () => { it('should call tfOps.oneHot', () => { - spyOn(tfOps, 'oneHot'); node.op = 'OneHot'; node.inputParams['indices'] = createTensorAttr(0); node.inputParams['depth'] = createNumberAttrFromIndex(1); @@ -100,9 +103,11 @@ describe('creation', () => { const input = [tfOps.tensor1d([0])]; const input3 = [tfOps.scalar(2)]; const input4 = [tfOps.scalar(3)]; - executeOp(node, {input, input2, input3, input4}, context); + spyOps.oneHot.and.returnValue({}); + executeOp(node, {input, input2, input3, input4}, context, + spyOpsAsTfOps); - expect(tfOps.oneHot).toHaveBeenCalledWith(input[0], 1, 2, 3); + expect(spyOps.oneHot).toHaveBeenCalledWith(input[0], 1, 2, 3); }); it('should match json def', () => { node.op = 'OneHot'; @@ -116,13 +121,12 @@ describe('creation', () => { }); describe('Ones', () => { it('should call tfOps.ones', () => { - spyOn(tfOps, 'ones'); node.op = 'Ones'; node.inputParams['shape'] = createNumericArrayAttrFromIndex(0); node.attrParams['dtype'] = createDtypeAttr('float32'); - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.ones).toHaveBeenCalledWith([1, 2, 3], 'float32'); + expect(spyOps.ones).toHaveBeenCalledWith([1, 2, 3], 'float32'); }); it('should match json def', () => { node.op = 'Ones'; @@ -134,12 +138,11 @@ describe('creation', () => { }); describe('OnesLike', () => { it('should call tfOps.onesLike', () => { - spyOn(tfOps, 'onesLike'); node.op = 'OnesLike'; node.inputParams['x'] = createTensorAttr(0); - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.onesLike).toHaveBeenCalledWith(input1[0]); + expect(spyOps.onesLike).toHaveBeenCalledWith(input1[0]); }); it('should match json def', () => { node.op = 'OnesLike'; @@ -150,7 +153,6 @@ describe('creation', () => { }); describe('Range', () => { it('should call tfOps.range', () => { - spyOn(tfOps, 'range'); node.op = 'Range'; node.inputParams['start'] = createNumberAttrFromIndex(0); node.inputParams['stop'] = createNumberAttrFromIndex(1); @@ -159,9 +161,9 @@ describe('creation', () => { node.inputNames = ['input', 'input2', 'input3']; const input = [tfOps.scalar(0)]; const input3 = [tfOps.scalar(2)]; - executeOp(node, {input, input2, input3}, context); + executeOp(node, {input, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.range).toHaveBeenCalledWith(0, 1, 2, 'float32'); + expect(spyOps.range).toHaveBeenCalledWith(0, 1, 2, 'float32'); }); it('should match json def', () => { node.op = 'Range'; @@ -175,16 +177,15 @@ describe('creation', () => { }); describe('RandomStandardNormal', () => { it('should call tfOps.randomStandardNormal', () => { - spyOn(tfOps, 'randomStandardNormal'); node.op = 'RandomStandardNormal'; node.inputParams['shape'] = createNumericArrayAttrFromIndex(0); node.inputNames = ['input1']; node.attrParams['dtype'] = createDtypeAttr('float32'); node.attrParams['seed'] = createNumberAttr(0); - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.randomStandardNormal) + expect(spyOps.randomStandardNormal) .toHaveBeenCalledWith([1, 2, 3], 'float32', 0); }); it('should match json def', () => { @@ -199,7 +200,6 @@ describe('creation', () => { }); describe('RandomUniform', () => { it('should call tfOps.randomUniform', () => { - spyOn(tfOps, 'randomUniform'); node.op = 'RandomUniform'; node.inputParams['shape'] = createNumericArrayAttrFromIndex(0); node.inputNames = ['input1']; @@ -208,9 +208,9 @@ describe('creation', () => { node.attrParams['dtype'] = createDtypeAttr('float32'); node.attrParams['seed'] = createNumberAttr(0); - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.randomUniform) + expect(spyOps.randomUniform) .toHaveBeenCalledWith([1, 2, 3], 0, 1, 'float32'); }); it('should match json def', () => { @@ -227,7 +227,6 @@ describe('creation', () => { }); describe('TruncatedNormal', () => { it('should call tfOps.truncatedNormal', () => { - spyOn(tfOps, 'truncatedNormal'); node.op = 'TruncatedNormal'; node.inputParams['shape'] = createNumericArrayAttrFromIndex(0); node.inputNames = ['input1']; @@ -236,9 +235,9 @@ describe('creation', () => { node.attrParams['dtype'] = createDtypeAttr('float32'); node.attrParams['seed'] = createNumberAttr(0); - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.truncatedNormal) + expect(spyOps.truncatedNormal) .toHaveBeenCalledWith([1, 2, 3], 0, 1, 'float32', 0); }); it('should match json def', () => { @@ -255,13 +254,12 @@ describe('creation', () => { }); describe('Zeros', () => { it('should call tfOps.zeros', () => { - spyOn(tfOps, 'zeros'); node.op = 'Zeros'; node.inputParams['shape'] = createNumericArrayAttrFromIndex(0); node.attrParams['dtype'] = createDtypeAttr('float32'); - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.zeros).toHaveBeenCalledWith([1, 2, 3], 'float32'); + expect(spyOps.zeros).toHaveBeenCalledWith([1, 2, 3], 'float32'); }); it('should match json def', () => { node.op = 'Zeros'; @@ -272,12 +270,11 @@ describe('creation', () => { }); describe('ZerosLike', () => { it('should call tfOps.zerosLike', () => { - spyOn(tfOps, 'zerosLike'); node.op = 'ZerosLike'; node.inputParams['x'] = createTensorAttr(0); - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.zerosLike).toHaveBeenCalledWith(input1[0]); + expect(spyOps.zerosLike).toHaveBeenCalledWith(input1[0]); }); it('should match json def', () => { node.op = 'ZerosLike'; @@ -287,14 +284,13 @@ describe('creation', () => { }); describe('Multinomial', () => { it('should call tfOps.multinomial', () => { - spyOn(tfOps, 'multinomial'); node.op = 'Multinomial'; node.inputParams['logits'] = createTensorAttr(0); node.inputParams['numSamples'] = createNumberAttrFromIndex(1); node.attrParams['seed'] = createNumberAttr(2); - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.multinomial).toHaveBeenCalledWith(input1[0], 1, 2); + expect(spyOps.multinomial).toHaveBeenCalledWith(input1[0], 1, 2); }); it('should match json def', () => { node.op = 'Multinomial'; diff --git a/tfjs-converter/src/operations/executors/dynamic_executor.ts b/tfjs-converter/src/operations/executors/dynamic_executor.ts index 083071dc47e..801ec2ebf8b 100644 --- a/tfjs-converter/src/operations/executors/dynamic_executor.ts +++ b/tfjs-converter/src/operations/executors/dynamic_executor.ts @@ -21,6 +21,7 @@ import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter'; import {NamedTensorsMap} from '../../data/types'; import {ExecutionContext} from '../../executor/execution_context'; +import { ResourceManager } from '../../executor/resource_manager'; import {InternalOpAsyncExecutor, Node} from '../types'; import {getParamValue} from './utils'; @@ -50,7 +51,8 @@ function nmsParams( export const executeOp: InternalOpAsyncExecutor = async( node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Promise => { + context: ExecutionContext, resourceManager: ResourceManager, + ops = tfOps): Promise => { switch (node.op) { case 'NonMaxSuppressionV5': { const { @@ -62,7 +64,7 @@ export const executeOp: InternalOpAsyncExecutor = async( softNmsSigma } = nmsParams(node, tensorMap, context); - const result = await tfOps.image.nonMaxSuppressionWithScoreAsync( + const result = await ops.image.nonMaxSuppressionWithScoreAsync( boxes as Tensor2D, scores as Tensor1D, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma); @@ -76,7 +78,7 @@ export const executeOp: InternalOpAsyncExecutor = async( getParamValue('padToMaxOutputSize', node, tensorMap, context) as boolean; - const result = await tfOps.image.nonMaxSuppressionPaddedAsync( + const result = await ops.image.nonMaxSuppressionPaddedAsync( boxes as Tensor2D, scores as Tensor1D, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize); @@ -87,20 +89,20 @@ export const executeOp: InternalOpAsyncExecutor = async( const {boxes, scores, maxOutputSize, iouThreshold, scoreThreshold} = nmsParams(node, tensorMap, context); - return [await tfOps.image.nonMaxSuppressionAsync( + return [await ops.image.nonMaxSuppressionAsync( boxes as Tensor2D, scores as Tensor1D, maxOutputSize, iouThreshold, scoreThreshold)]; } case 'Where': { - const condition = tfOps.cast( + const condition = ops.cast( (getParamValue('condition', node, tensorMap, context) as Tensor), 'bool'); - const result = [await tfOps.whereAsync(condition)]; + const result = [await ops.whereAsync(condition)]; condition.dispose(); return result; } case 'ListDiff': { - return tfOps.setdiff1dAsync( + return ops.setdiff1dAsync( getParamValue('x', node, tensorMap, context) as Tensor, getParamValue('y', node, tensorMap, context) as Tensor); } diff --git a/tfjs-converter/src/operations/executors/dynamic_executor_test.ts b/tfjs-converter/src/operations/executors/dynamic_executor_test.ts index b0bd30d6e7d..5e052d62943 100644 --- a/tfjs-converter/src/operations/executors/dynamic_executor_test.ts +++ b/tfjs-converter/src/operations/executors/dynamic_executor_test.ts @@ -23,14 +23,19 @@ import * as dynamic from '../op_list/dynamic'; import {Node} from '../types'; import {executeOp} from './dynamic_executor'; +import {RecursiveSpy, spyOnAllFunctions} from './spy_ops'; import {createBoolAttr, createNumberAttrFromIndex, createTensorAttr, validateParam} from './test_helper'; describe('dynamic', () => { let node: Node; const input1 = [tfOps.tensor1d([1])]; const context = new ExecutionContext({}, {}, {}); + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; beforeEach(() => { + spyOps = spyOnAllFunctions(tfOps); + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; node = { name: 'input1', op: '', @@ -57,10 +62,12 @@ describe('dynamic', () => { const input3 = [tfOps.tensor1d([1])]; const input4 = [tfOps.tensor1d([1])]; const input5 = [tfOps.tensor1d([1])]; - spyOn(tfOps.image, 'nonMaxSuppressionAsync'); - const result = - executeOp(node, {input1, input2, input3, input4, input5}, context); - expect(tfOps.image.nonMaxSuppressionAsync) + spyOps.image.nonMaxSuppressionAsync.and.returnValue({}); + + const result = executeOp( + node, {input1, input2, input3, input4, input5}, context, undefined, + spyOpsAsTfOps); + expect(spyOps.image.nonMaxSuppressionAsync) .toHaveBeenCalledWith(input1[0], input2[0], 1, 1, 1); expect(result instanceof Promise).toBeTruthy(); }); @@ -90,10 +97,12 @@ describe('dynamic', () => { const input3 = [tfOps.tensor1d([1])]; const input4 = [tfOps.tensor1d([1])]; const input5 = [tfOps.tensor1d([1])]; - spyOn(tfOps.image, 'nonMaxSuppressionAsync'); - const result = - executeOp(node, {input1, input2, input3, input4, input5}, context); - expect(tfOps.image.nonMaxSuppressionAsync) + spyOps.image.nonMaxSuppressionAsync.and.returnValue({}); + + const result = executeOp( + node, {input1, input2, input3, input4, input5}, context, undefined, + spyOpsAsTfOps); + expect(spyOps.image.nonMaxSuppressionAsync) .toHaveBeenCalledWith(input1[0], input2[0], 1, 1, 1); expect(result instanceof Promise).toBeTruthy(); }); @@ -125,10 +134,13 @@ describe('dynamic', () => { const input3 = [tfOps.tensor1d([1])]; const input4 = [tfOps.tensor1d([1])]; const input5 = [tfOps.tensor1d([1])]; - spyOn(tfOps.image, 'nonMaxSuppressionPaddedAsync').and.returnValue({}); - const result = - executeOp(node, {input1, input2, input3, input4, input5}, context); - expect(tfOps.image.nonMaxSuppressionPaddedAsync) + + spyOps.image.nonMaxSuppressionPaddedAsync.and.returnValue({}); + + const result = executeOp( + node, {input1, input2, input3, input4, input5}, context, undefined, + spyOpsAsTfOps); + expect(spyOps.image.nonMaxSuppressionPaddedAsync) .toHaveBeenCalledWith(input1[0], input2[0], 1, 1, 1, true); expect(result instanceof Promise).toBeTruthy(); }); @@ -163,11 +175,11 @@ describe('dynamic', () => { const input4 = [tfOps.tensor1d([1])]; const input5 = [tfOps.tensor1d([1])]; const input6 = [tfOps.tensor1d([1])]; - spyOn(tfOps.image, 'nonMaxSuppressionWithScoreAsync') - .and.returnValue({}); + spyOps.image.nonMaxSuppressionWithScoreAsync.and.returnValue({}); const result = executeOp( - node, {input1, input2, input3, input4, input5, input6}, context); - expect(tfOps.image.nonMaxSuppressionWithScoreAsync) + node, {input1, input2, input3, input4, input5, input6}, context, + undefined, spyOpsAsTfOps); + expect(spyOps.image.nonMaxSuppressionWithScoreAsync) .toHaveBeenCalledWith(input1[0], input2[0], 1, 1, 1, 1); expect(result instanceof Promise).toBeTruthy(); }); @@ -192,16 +204,13 @@ describe('dynamic', () => { node.op = 'Where'; node.inputParams = {'condition': createTensorAttr(0)}; const input1 = [tfOps.scalar(1)]; - spyOn(tfOps, 'whereAsync'); + // spyOn(tfOps, 'whereAsync'); - const result = executeOp(node, {input1}, context); - expect( - (tfOps.whereAsync as jasmine.Spy).calls.mostRecent().args[0].dtype) + const result = + executeOp(node, {input1}, context, undefined, spyOpsAsTfOps); + expect(spyOps.whereAsync.calls.mostRecent().args[0].dtype) .toEqual('bool'); - expect((tfOps.whereAsync as jasmine.Spy) - .calls.mostRecent() - .args[0] - .arraySync()) + expect(spyOps.whereAsync.calls.mostRecent().args[0].arraySync()) .toEqual(1); expect(result instanceof Promise).toBeTruthy(); }); @@ -215,7 +224,6 @@ describe('dynamic', () => { node.op = 'Where'; node.inputParams = {'condition': createTensorAttr(0)}; const input1 = [tfOps.scalar(1)]; - spyOn(tfOps, 'whereAsync').and.callThrough(); const prevCount = memory().numTensors; await executeOp(node, {input1}, context); @@ -231,10 +239,12 @@ describe('dynamic', () => { node.inputParams = {'x': createTensorAttr(0), 'y': createTensorAttr(1)}; const input1 = [tfOps.scalar(1)]; const input2 = [tfOps.scalar(1)]; - spyOn(tfOps, 'setdiff1dAsync'); + spyOps.setdiff1dAsync.and.returnValue({}); - const result = executeOp(node, {input1, input2}, context); - expect(tfOps.setdiff1dAsync).toHaveBeenCalledWith(input1[0], input2[0]); + const result = executeOp( + node, {input1, input2}, context, undefined, spyOpsAsTfOps); + expect(spyOps.setdiff1dAsync) + .toHaveBeenCalledWith(input1[0], input2[0]); expect(result instanceof Promise).toBeTruthy(); }); it('should match json def', () => { diff --git a/tfjs-converter/src/operations/executors/evaluation_executor.ts b/tfjs-converter/src/operations/executors/evaluation_executor.ts index 815e73df862..6020624ac29 100644 --- a/tfjs-converter/src/operations/executors/evaluation_executor.ts +++ b/tfjs-converter/src/operations/executors/evaluation_executor.ts @@ -26,7 +26,8 @@ import {InternalOpExecutor, Node} from '../types'; import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = - (node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext): + (node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext, + ops = tfOps): Tensor[] => { switch (node.op) { case 'LowerBound': { @@ -35,14 +36,14 @@ export const executeOp: InternalOpExecutor = Tensor; const values = getParamValue('values', node, tensorMap, context) as Tensor; - return [tfOps.lowerBound(sortedSequence, values)]; + return [ops.lowerBound(sortedSequence, values)]; } case 'TopKV2': { const x = getParamValue('x', node, tensorMap, context) as Tensor; const k = getParamValue('k', node, tensorMap, context) as number; const sorted = getParamValue('sorted', node, tensorMap, context) as boolean; - const result = tfOps.topk(x, k, sorted); + const result = ops.topk(x, k, sorted); return [result.values, result.indices]; } case 'UpperBound': { @@ -51,18 +52,18 @@ export const executeOp: InternalOpExecutor = Tensor; const values = getParamValue('values', node, tensorMap, context) as Tensor; - return [tfOps.upperBound(sortedSequence, values)]; + return [ops.upperBound(sortedSequence, values)]; } case 'Unique': { const x = getParamValue('x', node, tensorMap, context) as Tensor; - const result = tfOps.unique(x); + const result = ops.unique(x); return [result.values, result.indices]; } case 'UniqueV2': { const x = getParamValue('x', node, tensorMap, context) as Tensor; const axis = getParamValue('axis', node, tensorMap, context) as number; - const result = tfOps.unique(x, axis); + const result = ops.unique(x, axis); return [result.values, result.indices]; } default: diff --git a/tfjs-converter/src/operations/executors/evaluation_executor_test.ts b/tfjs-converter/src/operations/executors/evaluation_executor_test.ts index b9f2fdd6665..919fdf0a093 100644 --- a/tfjs-converter/src/operations/executors/evaluation_executor_test.ts +++ b/tfjs-converter/src/operations/executors/evaluation_executor_test.ts @@ -21,6 +21,7 @@ import {ExecutionContext} from '../../executor/execution_context'; import {Node} from '../types'; import {executeOp} from './evaluation_executor'; +import {RecursiveSpy, spyOnAllFunctions} from './spy_ops'; import {createBoolAttr, createNumberAttrFromIndex, createTensorAttr} from './test_helper'; describe('evaluation', () => { @@ -43,6 +44,14 @@ describe('evaluation', () => { }); describe('executeOp', () => { + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; + + beforeEach(() => { + spyOps = spyOnAllFunctions(tfOps); + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; + }); + describe('LowerBound', () => { it('should return input', () => { node.op = 'LowerBound'; @@ -50,7 +59,6 @@ describe('evaluation', () => { node.inputParams['values'] = createTensorAttr(1); node.inputNames = ['sortedSequence', 'values']; - spyOn(tfOps, 'lowerBound').and.callThrough(); const sortedSequence = [tfOps.tensor2d( [0., 3., 8., 9., 10., 1., 2., 3., 4., 5.], [2, 5], 'int32')]; const values = [tfOps.tensor2d( @@ -63,8 +71,8 @@ describe('evaluation', () => { 4.5, ], [2, 3], 'float32')]; - executeOp(node, {sortedSequence, values}, context); - expect(tfOps.lowerBound) + executeOp(node, {sortedSequence, values}, context, spyOpsAsTfOps); + expect(spyOps.lowerBound) .toHaveBeenCalledWith(sortedSequence[0], values[0]); }); }); @@ -75,9 +83,8 @@ describe('evaluation', () => { node.inputParams['x'] = createTensorAttr(0); node.inputParams['k'] = createNumberAttrFromIndex(1); node.attrParams['sorted'] = createBoolAttr(true); - spyOn(tfOps, 'topk').and.callThrough(); - executeOp(node, {input1, input2}, context); - expect(tfOps.topk).toHaveBeenCalledWith(input1[0], 1, true); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); + expect(spyOps.topk).toHaveBeenCalledWith(input1[0], 1, true); }); }); @@ -88,7 +95,6 @@ describe('evaluation', () => { node.inputParams['values'] = createTensorAttr(1); node.inputNames = ['sortedSequence', 'values']; - spyOn(tfOps, 'upperBound').and.callThrough(); const sortedSequence = [tfOps.tensor2d( [0., 3., 8., 9., 10., 1., 2., 3., 4., 5.], [2, 5], 'int32')]; const values = [tfOps.tensor2d( @@ -101,8 +107,8 @@ describe('evaluation', () => { 4.5, ], [2, 3], 'float32')]; - executeOp(node, {sortedSequence, values}, context); - expect(tfOps.upperBound) + executeOp(node, {sortedSequence, values}, context, spyOpsAsTfOps); + expect(spyOps.upperBound) .toHaveBeenCalledWith(sortedSequence[0], values[0]); }); }); @@ -111,9 +117,8 @@ describe('evaluation', () => { it('should get called correctly', () => { node.op = 'Unique'; node.inputParams['x'] = createTensorAttr(0); - spyOn(tfOps, 'unique').and.callThrough(); - executeOp(node, {input1}, context); - expect(tfOps.unique).toHaveBeenCalledWith(input1[0]); + executeOp(node, {input1}, context, spyOpsAsTfOps); + expect(spyOps.unique).toHaveBeenCalledWith(input1[0]); }); }); @@ -122,11 +127,12 @@ describe('evaluation', () => { node.op = 'UniqueV2'; node.inputParams['x'] = createTensorAttr(0); node.inputParams['axis'] = createNumberAttrFromIndex(1); - spyOn(tfOps, 'unique').and.callThrough(); const xInput = [tfOps.tensor2d([[1], [2]])]; const axisInput = [tfOps.scalar(1)]; - executeOp(node, {'input1': xInput, 'input2': axisInput}, context); - expect(tfOps.unique).toHaveBeenCalledWith(xInput[0], 1); + executeOp( + node, {'input1': xInput, 'input2': axisInput}, context, + spyOpsAsTfOps); + expect(spyOps.unique).toHaveBeenCalledWith(xInput[0], 1); }); }); }); diff --git a/tfjs-converter/src/operations/executors/graph_executor.ts b/tfjs-converter/src/operations/executors/graph_executor.ts index b2c53c78a5a..f6ba57e768a 100644 --- a/tfjs-converter/src/operations/executors/graph_executor.ts +++ b/tfjs-converter/src/operations/executors/graph_executor.ts @@ -27,7 +27,7 @@ import {cloneTensor, getParamValue, getTensor} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Tensor[] => { + context: ExecutionContext, ops = tfOps): Tensor[] => { switch (node.op) { case 'Const': { return tensorMap[node.name]; @@ -52,22 +52,22 @@ export const executeOp: InternalOpExecutor = (getParamValue('x', node, tensorMap, context) as Tensor); return [cloneTensor(snapshot)]; case 'Shape': - return [tfOps.tensor1d( + return [ops.tensor1d( (getParamValue('x', node, tensorMap, context) as Tensor).shape, 'int32')]; case 'ShapeN': return (getParamValue('x', node, tensorMap, context) as Tensor[]) - .map((t: Tensor) => tfOps.tensor1d(t.shape)); + .map((t: Tensor) => ops.tensor1d(t.shape)); case 'Size': - return [tfOps.scalar( + return [ops.scalar( (getParamValue('x', node, tensorMap, context) as Tensor).size, 'int32')]; case 'Rank': - return [tfOps.scalar( + return [ops.scalar( (getParamValue('x', node, tensorMap, context) as Tensor).rank, 'int32')]; case 'NoOp': - return [tfOps.scalar(1)]; + return [ops.scalar(1)]; case 'Print': const input = getParamValue('x', node, tensorMap, context) as Tensor; const data = diff --git a/tfjs-converter/src/operations/executors/image_executor.ts b/tfjs-converter/src/operations/executors/image_executor.ts index b598dd5f95c..87e29d40d9b 100644 --- a/tfjs-converter/src/operations/executors/image_executor.ts +++ b/tfjs-converter/src/operations/executors/image_executor.ts @@ -27,7 +27,7 @@ import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Tensor[] => { + context: ExecutionContext, ops = tfOps): Tensor[] => { switch (node.op) { case 'ResizeBilinear': { const images = @@ -40,7 +40,7 @@ export const executeOp: InternalOpExecutor = const halfPixelCenters = getParamValue('halfPixelCenters', node, tensorMap, context) as boolean; - return [tfOps.image.resizeBilinear( + return [ops.image.resizeBilinear( images as Tensor3D | Tensor4D, [size[0], size[1]], alignCorners, halfPixelCenters)]; } @@ -55,7 +55,7 @@ export const executeOp: InternalOpExecutor = const halfPixelCenters = getParamValue('halfPixelCenters', node, tensorMap, context) as boolean; - return [tfOps.image.resizeNearestNeighbor( + return [ops.image.resizeNearestNeighbor( images as Tensor3D | Tensor4D, [size[0], size[1]], alignCorners, halfPixelCenters)]; } @@ -73,7 +73,7 @@ export const executeOp: InternalOpExecutor = const extrapolationValue = getParamValue('extrapolationValue', node, tensorMap, context) as number; - return [tfOps.image.cropAndResize( + return [ops.image.cropAndResize( image as Tensor4D, boxes as Tensor2D, boxInd as Tensor1D, cropSize as [number, number], method as 'bilinear' | 'nearest', extrapolationValue)]; @@ -93,7 +93,7 @@ export const executeOp: InternalOpExecutor = string; const fillMode = getParamValue('fillMode', node, tensorMap, context) as string; - return [tfOps.image.transform( + return [ops.image.transform( images as Tensor4D, transforms as Tensor2D, interpolation.toLowerCase() as 'bilinear' | 'nearest', diff --git a/tfjs-converter/src/operations/executors/image_executor_test.ts b/tfjs-converter/src/operations/executors/image_executor_test.ts index 74d83f088e3..ef73eb97957 100644 --- a/tfjs-converter/src/operations/executors/image_executor_test.ts +++ b/tfjs-converter/src/operations/executors/image_executor_test.ts @@ -23,13 +23,18 @@ import {Node} from '../types'; import {executeOp} from './image_executor'; import {createBoolAttr, createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createStrAttr, createTensorAttr, validateParam} from './test_helper'; +import {spyOnAllFunctions, RecursiveSpy} from './spy_ops'; describe('image', () => { let node: Node; const input1 = [tfOps.tensor1d([1])]; const context = new ExecutionContext({}, {}, {}); + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; beforeEach(() => { + spyOps = spyOnAllFunctions(tfOps); + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; node = { name: 'input1', op: '', @@ -52,9 +57,10 @@ describe('image', () => { node.attrParams['halfPixelCenters'] = createBoolAttr(true); node.inputNames = ['input1', 'input2']; const input2 = [tfOps.tensor1d([1, 2])]; - spyOn(tfOps.image, 'resizeBilinear'); - executeOp(node, {input1, input2}, context); - expect(tfOps.image.resizeBilinear) + spyOps.image.resizeBilinear.and.returnValue({}); + + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); + expect(spyOps.image.resizeBilinear) .toHaveBeenCalledWith(input1[0], [1, 2], true, true); }); it('should match json def', () => { @@ -76,9 +82,10 @@ describe('image', () => { node.attrParams['halfPixelCenters'] = createBoolAttr(true); node.inputNames = ['input1', 'input2']; const input2 = [tfOps.tensor1d([1, 2])]; - spyOn(tfOps.image, 'resizeNearestNeighbor'); - executeOp(node, {input1, input2}, context); - expect(tfOps.image.resizeNearestNeighbor) + spyOps.image.resizeNearestNeighbor.and.returnValue({}); + + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); + expect(spyOps.image.resizeNearestNeighbor) .toHaveBeenCalledWith(input1[0], [1, 2], true, true); }); it('should match json def', () => { @@ -102,13 +109,14 @@ describe('image', () => { node.attrParams['extrapolationValue'] = createNumberAttr(0.5); node.inputNames = ['input1', 'input2', 'input3', 'input4']; - spyOn(tfOps.image, 'cropAndResize'); + spyOps.image.cropAndResize.and.returnValue({}); const input2 = [tfOps.tensor1d([2])]; const input3 = [tfOps.tensor1d([3])]; const input4 = [tfOps.tensor1d([4, 5])]; - executeOp(node, {input1, input2, input3, input4}, context); - expect(tfOps.image.cropAndResize) + executeOp(node, {input1, input2, input3, input4}, context, + spyOpsAsTfOps); + expect(spyOps.image.cropAndResize) .toHaveBeenCalledWith( input1[0], input2[0], input3[0], [4, 5], 'bilinear', 0.5); }); @@ -137,13 +145,14 @@ describe('image', () => { node.attrParams['fillMode'] = createStrAttr('constant'); node.inputNames = ['input1', 'input2', 'input3', 'input4']; - spyOn(tfOps.image, 'transform'); - const input2 = [tfOps.tensor1d([2])]; + const input1 = [tfOps.tensor4d([1], [1, 1, 1, 1])]; + const input2 = [tfOps.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [1, 8])]; const input3 = [tfOps.tensor1d([4, 5])]; const input4 = [tfOps.scalar(3)]; - executeOp(node, {input1, input2, input3, input4}, context); - expect(tfOps.image.transform) + executeOp(node, {input1, input2, input3, input4}, context, + spyOpsAsTfOps); + expect(spyOps.image.transform) .toHaveBeenCalledWith( input1[0], input2[0], 'bilinear', 'constant', 3, [4, 5]); }); diff --git a/tfjs-converter/src/operations/executors/logical_executor.ts b/tfjs-converter/src/operations/executors/logical_executor.ts index 2e999f34f2d..17d43b02941 100644 --- a/tfjs-converter/src/operations/executors/logical_executor.ts +++ b/tfjs-converter/src/operations/executors/logical_executor.ts @@ -27,55 +27,55 @@ import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Tensor[] => { + context: ExecutionContext, ops = tfOps): Tensor[] => { switch (node.op) { case 'Equal': { - return [tfOps.equal( + return [ops.equal( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'NotEqual': { - return [tfOps.notEqual( + return [ops.notEqual( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'Greater': { - return [tfOps.greater( + return [ops.greater( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'GreaterEqual': { - return [tfOps.greaterEqual( + return [ops.greaterEqual( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'Less': { - return [tfOps.less( + return [ops.less( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'LessEqual': { - return [tfOps.lessEqual( + return [ops.lessEqual( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'LogicalAnd': { - return [tfOps.logicalAnd( + return [ops.logicalAnd( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'LogicalNot': { - return [tfOps.logicalNot( + return [ops.logicalNot( getParamValue('a', node, tensorMap, context) as Tensor)]; } case 'LogicalOr': { - return [tfOps.logicalOr( + return [ops.logicalOr( getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; } case 'Select': case 'SelectV2': { - return [tfOps.where( + return [ops.where( getParamValue('condition', node, tensorMap, context) as Tensor, getParamValue('a', node, tensorMap, context) as Tensor, getParamValue('b', node, tensorMap, context) as Tensor)]; diff --git a/tfjs-converter/src/operations/executors/logical_executor_test.ts b/tfjs-converter/src/operations/executors/logical_executor_test.ts index 1840902e553..4d86dd0b896 100644 --- a/tfjs-converter/src/operations/executors/logical_executor_test.ts +++ b/tfjs-converter/src/operations/executors/logical_executor_test.ts @@ -21,7 +21,8 @@ import {ExecutionContext} from '../../executor/execution_context'; import {Node} from '../types'; import {executeOp} from './logical_executor'; -import {createTensorAttr} from './test_helper'; +import {RecursiveSpy, spyOnAllFunctions} from './spy_ops'; +import {createTensorAttr, uncapitalize} from './test_helper'; describe('logical', () => { let node: Node; @@ -43,52 +44,67 @@ describe('logical', () => { }); describe('executeOp', () => { - ['Equal', 'NotEqual', 'Greater', 'GreaterEqual', 'Less', 'LessEqual', - 'LogicalAnd', 'LogicalOr'] + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; + + beforeEach(() => { + spyOps = spyOnAllFunctions(tfOps); + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; + }); + + ([ + 'Equal', 'NotEqual', 'Greater', 'GreaterEqual', 'Less', 'LessEqual', + 'LogicalAnd', 'LogicalOr' + ] as const ) .forEach(op => { it('should call tfOps.' + op, () => { - const spy = spyOn( - tfOps, op.charAt(0).toLowerCase() + op.slice(1) as 'equal'); node.op = op; - executeOp(node, {input1, input2}, context); + // TODO(mattsoulanille): Remove type assertions after TS4 + // tslint:disable-next-line no-any + (spyOps[uncapitalize(op) as keyof typeof spyOps] as any) + .and.returnValue({}); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(spy).toHaveBeenCalledWith(input1[0], input2[0]); + // TODO(mattsoulanille): Remove type assertion after TS4 + expect(spyOps[uncapitalize(op) as keyof typeof spyOps]) + .toHaveBeenCalledWith(input1[0], input2[0]); }); }); describe('LogicalNot', () => { it('should call tfOps.logicalNot', () => { - spyOn(tfOps, 'logicalNot'); node.op = 'LogicalNot'; - executeOp(node, {input1}, context); + spyOps.logicalNot.and.returnValue({}); + + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.logicalNot).toHaveBeenCalledWith(input1[0]); + expect(spyOps.logicalNot).toHaveBeenCalledWith(input1[0]); }); }); describe('Select', () => { it('should call tfOps.where', () => { - spyOn(tfOps, 'where'); node.op = 'Select'; node.inputNames = ['input1', 'input2', 'input3']; node.inputParams.condition = createTensorAttr(2); const input3 = [tfOps.scalar(1)]; - executeOp(node, {input1, input2, input3}, context); + spyOps.where.and.returnValue({}); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.where) + expect(spyOps.where) .toHaveBeenCalledWith(input3[0], input1[0], input2[0]); }); }); describe('SelectV2', () => { it('should call tfOps.where', () => { - spyOn(tfOps, 'where'); node.op = 'SelectV2'; node.inputNames = ['input1', 'input2', 'input3']; node.inputParams.condition = createTensorAttr(2); const input3 = [tfOps.scalar(1)]; - executeOp(node, {input1, input2, input3}, context); + spyOps.where.and.returnValue({}); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.where) + expect(spyOps.where) .toHaveBeenCalledWith(input3[0], input1[0], input2[0]); }); }); diff --git a/tfjs-converter/src/operations/executors/matrices_executor.ts b/tfjs-converter/src/operations/executors/matrices_executor.ts index 1a0cdb66628..28a275e071d 100644 --- a/tfjs-converter/src/operations/executors/matrices_executor.ts +++ b/tfjs-converter/src/operations/executors/matrices_executor.ts @@ -27,12 +27,12 @@ import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Tensor[] => { + context: ExecutionContext, ops = tfOps): Tensor[] => { switch (node.op) { case 'BatchMatMul': case 'BatchMatMulV2': case 'MatMul': - return [tfOps.matMul( + return [ops.matMul( getParamValue('a', node, tensorMap, context) as Tensor2D, getParamValue('b', node, tensorMap, context) as Tensor2D, getParamValue('transposeA', node, tensorMap, context) as boolean, @@ -40,13 +40,13 @@ export const executeOp: InternalOpExecutor = boolean)]; case 'Einsum': - return [tfOps.einsum( + return [ops.einsum( getParamValue('equation', node, tensorMap, context) as string, ...getParamValue('tensors', node, tensorMap, context) as Tensor[])]; case 'Transpose': - return [tfOps.transpose( + return [ops.transpose( getParamValue('x', node, tensorMap, context) as Tensor, getParamValue('perm', node, tensorMap, context) as number[])]; @@ -76,7 +76,7 @@ export const executeOp: InternalOpExecutor = } const [biasArg, preluArg] = getParamValue('args', node, tensorMap, context) as Tensor[]; - return [tfOps.fused.matMul({ + return [ops.fused.matMul({ a: getParamValue('a', node, tensorMap, context) as Tensor2D, b: getParamValue('b', node, tensorMap, context) as Tensor2D, transposeA: getParamValue('transposeA', node, tensorMap, context) as diff --git a/tfjs-converter/src/operations/executors/matrices_executor_test.ts b/tfjs-converter/src/operations/executors/matrices_executor_test.ts index 5e0db88c67e..c55a6abfd37 100644 --- a/tfjs-converter/src/operations/executors/matrices_executor_test.ts +++ b/tfjs-converter/src/operations/executors/matrices_executor_test.ts @@ -24,6 +24,7 @@ import {ExecutionContext} from '../../executor/execution_context'; import {Node} from '../types'; import {executeOp} from './matrices_executor'; +import {RecursiveSpy, spyOnAllFunctions} from './spy_ops'; import {createBoolAttr, createNumberAttr, createNumericArrayAttr, createStrArrayAttr, createStrAttr, createTensorAttr, createTensorsAttr, validateParam} from './test_helper'; describe('matrices', () => { @@ -31,8 +32,12 @@ describe('matrices', () => { const input1 = [tfOps.scalar(1)]; const input2 = [tfOps.scalar(2)]; const context = new ExecutionContext({}, {}, {}); + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; beforeEach(() => { + spyOps = spyOnAllFunctions(tfOps); + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; node = { name: 'test', op: '', @@ -48,19 +53,18 @@ describe('matrices', () => { describe('executeOp', () => { describe('MatMul', () => { it('should call tfOps.matMul', () => { - spyOn(tfOps, 'matMul'); node.op = 'MatMul'; node.attrParams.transposeA = createBoolAttr(true); node.attrParams.transposeB = createBoolAttr(false); - executeOp(node, {input1, input2}, context); + spyOps.matMul.and.returnValue({}); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.matMul) + expect(spyOps.matMul) .toHaveBeenCalledWith(input1[0], input2[0], true, false); }); }); describe('_FusedMatMul', () => { it('should call tfOps.fused.matMul', () => { - spyOn(tfOps.fused, 'matMul'); node.op = '_FusedMatMul'; node.inputParams['args'] = createTensorsAttr(2, 0); node.attrParams['fusedOps'] = createStrArrayAttr(['biasadd', 'relu']); @@ -69,9 +73,10 @@ describe('matrices', () => { node.attrParams.transposeB = createBoolAttr(false); const input3 = [tfOps.scalar(3.0)]; node.inputNames = ['input1', 'input2', 'input3']; - executeOp(node, {input1, input2, input3}, context); + spyOps.fused.matMul.and.returnValue({}); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.fused.matMul).toHaveBeenCalledWith({ + expect(spyOps.fused.matMul).toHaveBeenCalledWith({ a: input1[0], b: input2[0], transposeA: true, @@ -83,7 +88,6 @@ describe('matrices', () => { }); }); it('should call tfOps.fused.matMul - prelu activation', () => { - spyOn(tfOps.fused, 'matMul'); node.op = '_FusedMatMul'; node.inputParams['args'] = createTensorsAttr(2, 0); node.attrParams['fusedOps'] = createStrArrayAttr(['biasadd', 'prelu']); @@ -93,9 +97,11 @@ describe('matrices', () => { const input3 = [tfOps.scalar(3.0)]; const input4 = [tfOps.scalar(4.0)]; node.inputNames = ['input1', 'input2', 'input3', 'input4']; - executeOp(node, {input1, input2, input3, input4}, context); + spyOps.fused.matMul.and.returnValue({}); + executeOp( + node, {input1, input2, input3, input4}, context, spyOpsAsTfOps); - expect(tfOps.fused.matMul).toHaveBeenCalledWith({ + expect(spyOps.fused.matMul).toHaveBeenCalledWith({ a: input1[0], b: input2[0], transposeA: true, @@ -107,7 +113,6 @@ describe('matrices', () => { }); }); it('should call tfOps.fused.matMul - leakyrelu activation', () => { - spyOn(tfOps.fused, 'matMul'); node.op = '_FusedMatMul'; node.inputParams['args'] = createTensorsAttr(2, 0); node.attrParams['fusedOps'] = @@ -118,9 +123,10 @@ describe('matrices', () => { node.attrParams.leakyreluAlpha = createNumberAttr(0.3); const input3 = [tfOps.scalar(3.0)]; node.inputNames = ['input1', 'input2', 'input3']; - executeOp(node, {input1, input2, input3}, context); + spyOps.fused.matMul.and.returnValue({}); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.fused.matMul).toHaveBeenCalledWith({ + expect(spyOps.fused.matMul).toHaveBeenCalledWith({ a: input1[0], b: input2[0], transposeA: true, @@ -146,52 +152,51 @@ describe('matrices', () => { }); describe('BatchMatMul', () => { it('should call tfOps.matMul', () => { - spyOn(tfOps, 'matMul'); node.op = 'BatchMatMul'; node.attrParams.transposeA = createBoolAttr(true); node.attrParams.transposeB = createBoolAttr(false); - executeOp(node, {input1, input2}, context); + spyOps.matMul.and.returnValue({}); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.matMul) + expect(spyOps.matMul) .toHaveBeenCalledWith(input1[0], input2[0], true, false); }); }); describe('BatchMatMulV2', () => { it('should call tfOps.matMul', () => { - spyOn(tfOps, 'matMul'); node.op = 'BatchMatMulV2'; node.attrParams.transposeA = createBoolAttr(true); node.attrParams.transposeB = createBoolAttr(false); - executeOp(node, {input1, input2}, context); + spyOps.matMul.and.returnValue({}); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.matMul) + expect(spyOps.matMul) .toHaveBeenCalledWith(input1[0], input2[0], true, false); }); }); describe('Einsum', () => { it('should call tfOps.einsum', () => { - const spy = spyOn(tfOps, 'einsum').and.callThrough(); node.op = 'Einsum'; node.inputParams = {tensors: createTensorsAttr(0, 0)}; node.inputNames = ['input1', 'input2']; node.attrParams.equation = createStrAttr(',->'); - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); const res = executeOp(node, {input1, input2}, context) as Tensor[]; - expect(spy).toHaveBeenCalledWith(',->', input1[0], input2[0]); + expect(spyOps.einsum).toHaveBeenCalledWith(',->', input1[0], input2[0]); expect(res[0].dtype).toBe('float32'); expect(res[0].shape).toEqual([]); }); }); describe('Transpose', () => { it('should call tfOps.transpose', () => { - spyOn(tfOps, 'transpose'); node.op = 'Transpose'; node.inputNames = ['input1', 'input2', 'input3']; node.inputParams.x = createTensorAttr(0); node.attrParams.perm = createNumericArrayAttr([1, 2]); - executeOp(node, {input1}, context); + spyOps.transpose.and.returnValue({}); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.transpose).toHaveBeenCalledWith(input1[0], [1, 2]); + expect(spyOps.transpose).toHaveBeenCalledWith(input1[0], [1, 2]); }); }); }); diff --git a/tfjs-converter/src/operations/executors/normalization_executor.ts b/tfjs-converter/src/operations/executors/normalization_executor.ts index 3164fd73759..58b5e24f247 100644 --- a/tfjs-converter/src/operations/executors/normalization_executor.ts +++ b/tfjs-converter/src/operations/executors/normalization_executor.ts @@ -27,16 +27,16 @@ import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Tensor[] => { + context: ExecutionContext, ops = tfOps): Tensor[] => { switch (node.op) { case 'EuclideanNorm': - return [tfOps.euclideanNorm( + return [ops.euclideanNorm( getParamValue('x', node, tensorMap, context) as Tensor, getParamValue('axis', node, tensorMap, context) as number[], getParamValue('keepDims', node, tensorMap, context) as boolean)]; case 'FusedBatchNorm': case 'FusedBatchNormV2': { - return [tfOps.batchNorm( + return [ops.batchNorm( getParamValue('x', node, tensorMap, context) as Tensor, getParamValue('mean', node, tensorMap, context) as Tensor, getParamValue('variance', node, tensorMap, context) as Tensor, @@ -45,7 +45,7 @@ export const executeOp: InternalOpExecutor = getParamValue('epsilon', node, tensorMap, context) as number)]; } case 'FusedBatchNormV3': { - return [tfOps.batchNorm( + return [ops.batchNorm( getParamValue('x', node, tensorMap, context) as Tensor, getParamValue('mean', node, tensorMap, context) as Tensor, getParamValue('variance', node, tensorMap, context) as Tensor, @@ -54,7 +54,7 @@ export const executeOp: InternalOpExecutor = getParamValue('epsilon', node, tensorMap, context) as number)]; } case 'LRN': { - return [tfOps.localResponseNormalization( + return [ops.localResponseNormalization( getParamValue('x', node, tensorMap, context) as Tensor3D | Tensor4D, getParamValue('radius', node, tensorMap, context) as number, @@ -63,15 +63,15 @@ export const executeOp: InternalOpExecutor = getParamValue('beta', node, tensorMap, context) as number)]; } case 'Softmax': { - return [tfOps.softmax( + return [ops.softmax( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'LogSoftmax': { - return [tfOps.logSoftmax( + return [ops.logSoftmax( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'SparseToDense': { - return [tfOps.sparseToDense( + return [ops.sparseToDense( getParamValue('sparseIndices', node, tensorMap, context) as Tensor, getParamValue('outputShape', node, tensorMap, context) as Tensor, diff --git a/tfjs-converter/src/operations/executors/normalization_executor_test.ts b/tfjs-converter/src/operations/executors/normalization_executor_test.ts index 552d1d05824..484836a99ec 100644 --- a/tfjs-converter/src/operations/executors/normalization_executor_test.ts +++ b/tfjs-converter/src/operations/executors/normalization_executor_test.ts @@ -23,13 +23,18 @@ import {Node} from '../types'; import {executeOp} from './normalization_executor'; import {createBoolAttr, createNumberAttr, createNumericArrayAttrFromIndex, createTensorAttr, validateParam} from './test_helper'; +import {spyOnAllFunctions, RecursiveSpy} from './spy_ops'; describe('normalization', () => { let node: Node; const input1 = [tfOps.scalar(1)]; const context = new ExecutionContext({}, {}, {}); + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; beforeEach(() => { + spyOps = spyOnAllFunctions(tfOps); + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; node = { name: 'test', op: '', @@ -45,15 +50,15 @@ describe('normalization', () => { describe('executeOp', () => { describe('EuclideanNorm', () => { it('should call tfOps.euclideanNorm', () => { - spyOn(tfOps, 'euclideanNorm'); node.op = 'EuclideanNorm'; node.inputParams['axis'] = createNumericArrayAttrFromIndex(1); node.attrParams.keepDims = createBoolAttr(false); node.inputNames = ['input1', 'input2']; const input2 = [tfOps.tensor1d([2])]; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.euclideanNorm).toHaveBeenCalledWith(input1[0], [2], false); + expect(spyOps.euclideanNorm).toHaveBeenCalledWith(input1[0], [2], + false); }); it('should match json def', () => { node.op = 'EuclideanNorm'; @@ -66,7 +71,6 @@ describe('normalization', () => { }); describe('FusedBatchNorm', () => { it('should call tfOps.batchNorm', () => { - spyOn(tfOps, 'batchNorm'); node.op = 'FusedBatchNorm'; node.inputParams.scale = createTensorAttr(1); node.inputParams.offset = createTensorAttr(2); @@ -78,16 +82,16 @@ describe('normalization', () => { const input3 = [tfOps.scalar(2)]; const input4 = [tfOps.scalar(3)]; const input5 = [tfOps.scalar(4)]; - executeOp(node, {input1, input2, input3, input4, input5}, context); + executeOp(node, {input1, input2, input3, input4, input5}, context, + spyOpsAsTfOps); - expect(tfOps.batchNorm) + expect(spyOps.batchNorm) .toHaveBeenCalledWith( input1[0], input4[0], input5[0], input3[0], input2[0], 5); }); }); describe('FusedBatchNormV2', () => { it('should call tfOps.batchNorm', () => { - spyOn(tfOps, 'batchNorm'); node.op = 'FusedBatchNormV2'; node.inputParams.scale = createTensorAttr(1); node.inputParams.offset = createTensorAttr(2); @@ -99,16 +103,16 @@ describe('normalization', () => { const input3 = [tfOps.scalar(2)]; const input4 = [tfOps.scalar(3)]; const input5 = [tfOps.scalar(4)]; - executeOp(node, {input1, input2, input3, input4, input5}, context); + executeOp(node, {input1, input2, input3, input4, input5}, context, + spyOpsAsTfOps); - expect(tfOps.batchNorm) + expect(spyOps.batchNorm) .toHaveBeenCalledWith( input1[0], input4[0], input5[0], input3[0], input2[0], 5); }); }); describe('FusedBatchNormV3', () => { it('should call tfOps.batchNorm', () => { - spyOn(tfOps, 'batchNorm'); node.op = 'FusedBatchNormV3'; node.inputParams.scale = createTensorAttr(1); node.inputParams.offset = createTensorAttr(2); @@ -120,25 +124,26 @@ describe('normalization', () => { const input3 = [tfOps.scalar(2)]; const input4 = [tfOps.scalar(3)]; const input5 = [tfOps.scalar(4)]; - executeOp(node, {input1, input2, input3, input4, input5}, context); + executeOp(node, {input1, input2, input3, input4, input5}, context, + spyOpsAsTfOps); - expect(tfOps.batchNorm) + expect(spyOps.batchNorm) .toHaveBeenCalledWith( input1[0], input4[0], input5[0], input3[0], input2[0], 5); }); }); describe('LRN', () => { it('should call tfOps.localResponseNormalization', () => { - spyOn(tfOps, 'localResponseNormalization'); node.op = 'LRN'; node.attrParams.radius = createNumberAttr(1); node.attrParams.bias = createNumberAttr(2); node.attrParams.alpha = createNumberAttr(3); node.attrParams.beta = createNumberAttr(4); + spyOps.localResponseNormalization.and.returnValue({}); - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.localResponseNormalization) + expect(spyOps.localResponseNormalization) .toHaveBeenCalledWith(input1[0], 1, 2, 3, 4); }); it('should match json def', () => { @@ -154,12 +159,12 @@ describe('normalization', () => { describe('Softmax', () => { it('should call tfOps.softmax', () => { - spyOn(tfOps, 'softmax'); node.op = 'Softmax'; + spyOps.softmax.and.returnValue({}); - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.softmax).toHaveBeenCalledWith(input1[0]); + expect(spyOps.softmax).toHaveBeenCalledWith(input1[0]); }); it('should match json def', () => { node.op = 'Softmax'; @@ -170,12 +175,12 @@ describe('normalization', () => { describe('LogSoftmax', () => { it('should call tfOps.logSoftmax', () => { - spyOn(tfOps, 'logSoftmax'); node.op = 'LogSoftmax'; + spyOps.logSoftmax.and.returnValue({}); - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.logSoftmax).toHaveBeenCalledWith(input1[0]); + expect(spyOps.logSoftmax).toHaveBeenCalledWith(input1[0]); }); it('should match json def', () => { node.op = 'LogSoftmax'; @@ -185,7 +190,6 @@ describe('normalization', () => { }); describe('SparseToDense', () => { it('should call tfOps.sparseToDense', () => { - spyOn(tfOps, 'sparseToDense'); node.op = 'SparseToDense'; node.inputParams.sparseIndices = createTensorAttr(0); node.inputParams.outputShape = createNumericArrayAttrFromIndex(1); @@ -195,9 +199,11 @@ describe('normalization', () => { const input2 = [tfOps.tensor1d([1], 'int32')]; const input3 = [tfOps.scalar(2)]; const input4 = [tfOps.scalar(3)]; - executeOp(node, {input1, input2, input3, input4}, context); + spyOps.sparseToDense.and.returnValue({}); + executeOp(node, {input1, input2, input3, input4}, context, + spyOpsAsTfOps); - expect(tfOps.sparseToDense) + expect(spyOps.sparseToDense) .toHaveBeenCalledWith(input1[0], [1], input3[0], input4[0]); }); it('should match json def', () => { diff --git a/tfjs-converter/src/operations/executors/reduction_executor.ts b/tfjs-converter/src/operations/executors/reduction_executor.ts index 912537dc862..f2d9e2c5c82 100644 --- a/tfjs-converter/src/operations/executors/reduction_executor.ts +++ b/tfjs-converter/src/operations/executors/reduction_executor.ts @@ -27,14 +27,14 @@ import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Tensor[] => { + context: ExecutionContext, ops = tfOps): Tensor[] => { switch (node.op) { case 'Max': { const axis = getParamValue('axis', node, tensorMap, context) as number[]; const keepDims = getParamValue('keepDims', node, tensorMap, context) as boolean; - return [tfOps.max( + return [ops.max( getParamValue('x', node, tensorMap, context) as Tensor, axis, keepDims)]; } @@ -43,7 +43,7 @@ export const executeOp: InternalOpExecutor = getParamValue('axis', node, tensorMap, context) as number[]; const keepDims = getParamValue('keepDims', node, tensorMap, context) as boolean; - return [tfOps.mean( + return [ops.mean( getParamValue('x', node, tensorMap, context) as Tensor, axis, keepDims)]; } @@ -52,7 +52,7 @@ export const executeOp: InternalOpExecutor = getParamValue('axis', node, tensorMap, context) as number[]; const keepDims = getParamValue('keepDims', node, tensorMap, context) as boolean; - return [tfOps.min( + return [ops.min( getParamValue('x', node, tensorMap, context) as Tensor, axis, keepDims)]; } @@ -61,7 +61,7 @@ export const executeOp: InternalOpExecutor = getParamValue('axis', node, tensorMap, context) as number[]; const keepDims = getParamValue('keepDims', node, tensorMap, context) as boolean; - return [tfOps.sum( + return [ops.sum( getParamValue('x', node, tensorMap, context) as Tensor, axis, keepDims)]; } @@ -70,7 +70,7 @@ export const executeOp: InternalOpExecutor = getParamValue('axis', node, tensorMap, context) as number[]; const keepDims = getParamValue('keepDims', node, tensorMap, context) as boolean; - return [tfOps.all( + return [ops.all( getParamValue('x', node, tensorMap, context) as Tensor, axis, keepDims)]; } @@ -79,20 +79,20 @@ export const executeOp: InternalOpExecutor = getParamValue('axis', node, tensorMap, context) as number[]; const keepDims = getParamValue('keepDims', node, tensorMap, context) as boolean; - return [tfOps.any( + return [ops.any( getParamValue('x', node, tensorMap, context) as Tensor, axis, keepDims)]; } case 'ArgMax': { const axis = getParamValue('axis', node, tensorMap, context) as number; - return [tfOps.argMax( + return [ops.argMax( getParamValue('x', node, tensorMap, context) as Tensor, axis)]; } case 'ArgMin': { const axis = getParamValue('axis', node, tensorMap, context) as number; - return [tfOps.argMin( + return [ops.argMin( getParamValue('x', node, tensorMap, context) as Tensor, axis)]; } case 'Prod': { @@ -100,7 +100,7 @@ export const executeOp: InternalOpExecutor = getParamValue('axis', node, tensorMap, context) as number[]; const keepDims = getParamValue('keepDims', node, tensorMap, context) as boolean; - return [tfOps.prod( + return [ops.prod( getParamValue('x', node, tensorMap, context) as Tensor, axis, keepDims)]; } @@ -111,7 +111,7 @@ export const executeOp: InternalOpExecutor = getParamValue('exclusive', node, tensorMap, context) as boolean; const reverse = getParamValue('reverse', node, tensorMap, context) as boolean; - return [tfOps.cumprod( + return [ops.cumprod( getParamValue('x', node, tensorMap, context) as Tensor, axis, exclusive, reverse)]; } @@ -122,7 +122,7 @@ export const executeOp: InternalOpExecutor = getParamValue('exclusive', node, tensorMap, context) as boolean; const reverse = getParamValue('reverse', node, tensorMap, context) as boolean; - return [tfOps.cumsum( + return [ops.cumsum( getParamValue('x', node, tensorMap, context) as Tensor, axis, exclusive, reverse)]; } @@ -133,7 +133,7 @@ export const executeOp: InternalOpExecutor = const size = getParamValue('size', node, tensorMap, context) as number; - return [tfOps.bincount(x, weights, size)]; + return [ops.bincount(x, weights, size)]; case 'DenseBincount': { const x = getParamValue('x', node, tensorMap, context) as Tensor1D | Tensor2D; @@ -147,7 +147,7 @@ export const executeOp: InternalOpExecutor = getParamValue('binaryOutput', node, tensorMap, context) as boolean; - return [tfOps.denseBincount(x, weights, size, binaryOutput)]; + return [ops.denseBincount(x, weights, size, binaryOutput)]; } default: throw TypeError(`Node type ${node.op} is not implemented`); diff --git a/tfjs-converter/src/operations/executors/reduction_executor_test.ts b/tfjs-converter/src/operations/executors/reduction_executor_test.ts index a2c31c11825..de9211c936c 100644 --- a/tfjs-converter/src/operations/executors/reduction_executor_test.ts +++ b/tfjs-converter/src/operations/executors/reduction_executor_test.ts @@ -22,14 +22,19 @@ import * as reduction from '../op_list/reduction'; import {Node} from '../types'; import {executeOp} from './reduction_executor'; -import {createBoolAttr, createNumberAttr, createNumberAttrFromIndex, createTensorAttr, validateParam} from './test_helper'; +import {RecursiveSpy, spyOnAllFunctions} from './spy_ops'; +import {createBoolAttr, createNumberAttr, createNumberAttrFromIndex, createTensorAttr, uncapitalize, validateParam} from './test_helper'; describe('reduction', () => { let node: Node; const input1 = [tfOps.scalar(1)]; const context = new ExecutionContext({}, {}, {}); + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; beforeEach(() => { + spyOps = spyOnAllFunctions(tfOps); + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; node = { name: 'test', op: '', @@ -43,71 +48,73 @@ describe('reduction', () => { }); describe('executeOp', () => { - ['Max', 'Mean', 'Min', 'Sum', 'All', 'Any', 'Prod'].forEach(op => { - it('should call tfOps.' + op, () => { - const spy = - spyOn(tfOps, op.charAt(0).toLowerCase() + op.slice(1) as 'max'); - node.op = op; - node.attrParams.keepDims = createBoolAttr(true); - node.attrParams.axis = createNumberAttr(1); - executeOp(node, {input1}, context); + (['Max', 'Mean', 'Min', 'Sum', 'All', 'Any', 'Prod'] as const ) + .forEach(op => { + it('should call tfOps.' + op, () => { + node.op = op; + node.attrParams.keepDims = createBoolAttr(true); + node.attrParams.axis = createNumberAttr(1); + // TODO(mattsoulanille): Remove type assertions after TS4 + // tslint:disable-next-line no-any + (spyOps[uncapitalize(op) as keyof typeof spyOps] as any) + .and.returnValue({}); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(spy).toHaveBeenCalledWith(input1[0], 1, true); - }); - }); + // TODO(mattsoulanille): Remove type assertion after TS4 + expect(spyOps[uncapitalize(op) as keyof typeof spyOps]) + .toHaveBeenCalledWith(input1[0], 1, true); + }); + }); describe('ArgMax', () => { it('should call tfOps.argMax', () => { - spyOn(tfOps, 'argMax'); node.op = 'ArgMax'; node.attrParams.keepDims = createBoolAttr(true); node.attrParams.axis = createNumberAttr(1); - executeOp(node, {input1}, context); + spyOps.argMax.and.returnValue({}); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.argMax).toHaveBeenCalledWith(input1[0], 1); + expect(spyOps.argMax).toHaveBeenCalledWith(input1[0], 1); }); }); describe('ArgMin', () => { it('should call tfOps.argMin', () => { - spyOn(tfOps, 'argMin'); node.op = 'ArgMin'; node.attrParams.keepDims = createBoolAttr(true); node.attrParams.axis = createNumberAttr(1); - executeOp(node, {input1}, context); + spyOps.argMin.and.returnValue({}); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.argMin).toHaveBeenCalledWith(input1[0], 1); + expect(spyOps.argMin).toHaveBeenCalledWith(input1[0], 1); }); }); describe('Cumprod', () => { it('should call tfOps.cumprod', () => { - spyOn(tfOps, 'cumprod'); node.op = 'Cumprod'; node.attrParams.exclusive = createBoolAttr(true); node.attrParams.reverse = createBoolAttr(false); node.inputNames = ['input1', 'input2']; node.inputParams.axis = createNumberAttrFromIndex(1); const input2 = [tfOps.scalar(2)]; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.cumprod).toHaveBeenCalledWith(input1[0], 2, true, false); + expect(spyOps.cumprod).toHaveBeenCalledWith(input1[0], 2, true, false); }); }); describe('Cumsum', () => { it('should call tfOps.cumsum', () => { - spyOn(tfOps, 'cumsum'); node.op = 'Cumsum'; node.attrParams.exclusive = createBoolAttr(true); node.attrParams.reverse = createBoolAttr(false); node.inputNames = ['input1', 'input2']; node.inputParams.axis = createNumberAttrFromIndex(1); const input2 = [tfOps.scalar(2)]; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.cumsum).toHaveBeenCalledWith(input1[0], 2, true, false); + expect(spyOps.cumsum).toHaveBeenCalledWith(input1[0], 2, true, false); }); }); describe('Bincount', () => { it('should call tfOps.bincount', () => { - spyOn(tfOps, 'bincount'); node.op = 'Bincount'; node.inputNames = ['input4', 'input3', 'input2']; node.inputParams.size = createNumberAttrFromIndex(1); @@ -115,9 +122,9 @@ describe('reduction', () => { const input4 = [tfOps.tensor1d([1, 1], 'int32')]; const input3 = [tfOps.scalar(2)]; const input2 = [tfOps.tensor1d([])]; - executeOp(node, {input4, input3, input2}, context); + executeOp(node, {input4, input3, input2}, context, spyOpsAsTfOps); - expect(tfOps.bincount).toHaveBeenCalledWith(input4[0], input2[0], 2); + expect(spyOps.bincount).toHaveBeenCalledWith(input4[0], input2[0], 2); }); it('should match json def for bincount.', () => { node.op = 'Bincount'; @@ -129,7 +136,6 @@ describe('reduction', () => { }); describe('DenseBincount', () => { it('should call tfOps.denseBincount', () => { - spyOn(tfOps, 'denseBincount'); node.op = 'DenseBincount'; node.inputNames = ['input4', 'input3', 'input2']; node.inputParams.x = createTensorAttr(0); @@ -139,9 +145,9 @@ describe('reduction', () => { const input4 = [tfOps.tensor1d([1, 1], 'int32')]; const input3 = [tfOps.scalar(2)]; const input2 = [tfOps.tensor1d([])]; - executeOp(node, {input4, input3, input2}, context); + executeOp(node, {input4, input3, input2}, context, spyOpsAsTfOps); - expect(tfOps.denseBincount) + expect(spyOps.denseBincount) .toHaveBeenCalledWith(input4[0], input2[0], 2, true); }); it('should match json def for denseBincount.', () => { diff --git a/tfjs-converter/src/operations/executors/slice_join_executor.ts b/tfjs-converter/src/operations/executors/slice_join_executor.ts index f2131d06432..e574d614147 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor.ts @@ -27,7 +27,7 @@ import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Tensor[] => { + context: ExecutionContext, ops = tfOps): Tensor[] => { switch (node.op) { case 'ConcatV2': case 'Concat': { @@ -37,13 +37,13 @@ export const executeOp: InternalOpExecutor = let inputs = getParamValue('tensors', node, tensorMap, context) as Tensor[]; inputs = inputs.slice(0, n); - return [tfOps.concat(inputs, axis)]; + return [ops.concat(inputs, axis)]; } case 'Gather': { const input = getParamValue('x', node, tensorMap, context) as Tensor; const indices = getParamValue('indices', node, tensorMap, context) as Tensor1D; - return [tfOps.gather(input, tfOps.cast(indices, 'int32'), 0)]; + return [ops.gather(input, ops.cast(indices, 'int32'), 0)]; } case 'GatherV2': { const axis = @@ -53,8 +53,8 @@ export const executeOp: InternalOpExecutor = const input = getParamValue('x', node, tensorMap, context) as Tensor; const indices = getParamValue('indices', node, tensorMap, context) as Tensor1D; - return [tfOps.gather( - input, tfOps.cast(indices, 'int32'), axis, batchDims)]; + return [ops.gather( + input, ops.cast(indices, 'int32'), axis, batchDims)]; } case 'Reverse': { const dims = @@ -66,20 +66,20 @@ export const executeOp: InternalOpExecutor = } } const input = getParamValue('x', node, tensorMap, context) as Tensor; - return [tfOps.reverse(input, axis)]; + return [ops.reverse(input, axis)]; } case 'ReverseV2': { const axis = getParamValue('axis', node, tensorMap, context) as number[]; const input = getParamValue('x', node, tensorMap, context) as Tensor; - return [tfOps.reverse(input, axis)]; + return [ops.reverse(input, axis)]; } case 'Slice': { // tslint:disable-next-line:no-any const begin = getParamValue('begin', node, tensorMap, context) as any; // tslint:disable-next-line:no-any const size = getParamValue('size', node, tensorMap, context) as any; - return [tfOps.slice( + return [ops.slice( getParamValue('x', node, tensorMap, context) as Tensor, begin, size)]; } @@ -103,7 +103,7 @@ export const executeOp: InternalOpExecutor = number; const tensor = getParamValue('x', node, tensorMap, context) as Tensor; - return [tfOps.stridedSlice( + return [ops.stridedSlice( tensor, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask)]; } @@ -116,17 +116,17 @@ export const executeOp: InternalOpExecutor = // Reshape the tensors to the first tensor's shape if they don't // match. const shape = tensors[0].shape; - const squeezedShape = tfOps.squeeze(tensors[0]).shape; + const squeezedShape = ops.squeeze(tensors[0]).shape; const mapped = tensors.map(tensor => { const sameShape = util.arraysEqual(tensor.shape, shape); if (!sameShape && !util.arraysEqual( - tfOps.squeeze(tensor).shape, squeezedShape)) { + ops.squeeze(tensor).shape, squeezedShape)) { throw new Error('the input tensors shape does not match'); } - return sameShape ? tensor : tfOps.reshape(tensor, shape); + return sameShape ? tensor : ops.reshape(tensor, shape); }); - return [tfOps.stack(mapped, axis)]; + return [ops.stack(mapped, axis)]; }); } case 'Unpack': { @@ -134,12 +134,12 @@ export const executeOp: InternalOpExecutor = getParamValue('axis', node, tensorMap, context) as number; const tensor = getParamValue('tensor', node, tensorMap, context) as Tensor; - return tfOps.unstack(tensor, axis); + return ops.unstack(tensor, axis); } case 'Tile': { const reps = getParamValue('reps', node, tensorMap, context) as number[]; - return [tfOps.tile( + return [ops.tile( getParamValue('x', node, tensorMap, context) as Tensor, reps)]; } case 'Split': @@ -152,7 +152,7 @@ export const executeOp: InternalOpExecutor = number[]; const tensor = getParamValue('x', node, tensorMap, context) as Tensor; - return tfOps.split(tensor, numOrSizeSplits, axis); + return ops.split(tensor, numOrSizeSplits, axis); } case 'ScatterNd': { const indices = @@ -161,13 +161,13 @@ export const executeOp: InternalOpExecutor = getParamValue('values', node, tensorMap, context) as Tensor; const shape = getParamValue('shape', node, tensorMap, context) as number[]; - return [tfOps.scatterND(indices, values, shape)]; + return [ops.scatterND(indices, values, shape)]; } case 'GatherNd': { const x = getParamValue('x', node, tensorMap, context) as Tensor; const indices = getParamValue('indices', node, tensorMap, context) as Tensor; - return [tfOps.gatherND(x, indices)]; + return [ops.gatherND(x, indices)]; } case 'SparseToDense': { const indices = @@ -180,11 +180,11 @@ export const executeOp: InternalOpExecutor = getParamValue('sparseValues', node, tensorMap, context) as Tensor; const defaultValue = getParamValue('defaultValue', node, tensorMap, context) as Scalar; - return [tfOps.sparseToDense( + return [ops.sparseToDense( indices, sparseValues, shape, sparseValues.dtype === defaultValue.dtype ? defaultValue : - tfOps.cast(defaultValue, sparseValues.dtype))]; + ops.cast(defaultValue, sparseValues.dtype))]; } default: throw TypeError(`Node type ${node.op} is not implemented`); diff --git a/tfjs-converter/src/operations/executors/slice_join_executor_test.ts b/tfjs-converter/src/operations/executors/slice_join_executor_test.ts index 6ea6249aab4..7545b251ab5 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor_test.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor_test.ts @@ -22,6 +22,7 @@ import * as slice_join from '../op_list/slice_join'; import {Node} from '../types'; import {executeOp} from './slice_join_executor'; +import {RecursiveSpy, spyOnAllFunctions} from './spy_ops'; import {createBooleanArrayAttrFromIndex, createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createTensorAttr, createTensorsAttr, validateParam} from './test_helper'; describe('slice join', () => { @@ -32,6 +33,13 @@ describe('slice join', () => { const input4 = [tfOps.tensor1d([3])]; const input5 = [tfOps.tensor1d([3, 4])]; const context = new ExecutionContext({}, {}, {}); + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; + + beforeEach(() => { + spyOps = spyOnAllFunctions(tfOps); + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; + }); describe('multi-tensor ops', () => { beforeEach(() => { @@ -48,24 +56,24 @@ describe('slice join', () => { }); describe('executeOp', () => { it('Concat', () => { - const spy = spyOn(tfOps, 'concat'); node.op = 'Concat'; node.inputParams.tensors = createTensorsAttr(1, 0); node.inputParams.axis = createNumberAttrFromIndex(0); node.attrParams.n = createNumberAttr(2); - executeOp(node, {input1, input2, input3}, context); + spyOps.concat.and.returnValue({}); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(spy).toHaveBeenCalledWith([input2[0], input3[0]], 1); + expect(spyOps.concat).toHaveBeenCalledWith([input2[0], input3[0]], 1); }); it('Concat when input length and n mismatch', () => { - const spy = spyOn(tfOps, 'concat'); node.op = 'Concat'; node.inputParams.tensors = createTensorsAttr(0, -1); node.inputParams.axis = createNumberAttrFromIndex(-1); node.attrParams.n = createNumberAttr(1); - executeOp(node, {input1, input2, input3}, context); + spyOps.concat.and.returnValue({}); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(spy).toHaveBeenCalledWith([input1[0]], 3); + expect(spyOps.concat).toHaveBeenCalledWith([input1[0]], 3); }); it('should match json def for Concat', () => { node.op = 'Concat'; @@ -76,24 +84,24 @@ describe('slice join', () => { expect(validateParam(node, slice_join.json, 'Concat')).toBeTruthy(); }); it('ConcatV2', () => { - const spy = spyOn(tfOps, 'concat'); node.op = 'ConcatV2'; node.inputParams.tensors = createTensorsAttr(0, -1); node.inputParams.axis = createNumberAttrFromIndex(-1); node.attrParams.n = createNumberAttr(2); - executeOp(node, {input1, input2, input3}, context); + spyOps.concat.and.returnValue({}); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(spy).toHaveBeenCalledWith([input1[0], input2[0]], 3); + expect(spyOps.concat).toHaveBeenCalledWith([input1[0], input2[0]], 3); }); it('ConcatV2 when input length and n mismatch', () => { - const spy = spyOn(tfOps, 'concat'); node.op = 'ConcatV2'; node.inputParams.tensors = createTensorsAttr(0, -1); node.inputParams.axis = createNumberAttrFromIndex(-1); node.attrParams.n = createNumberAttr(1); - executeOp(node, {input1, input2, input3}, context); + spyOps.concat.and.returnValue({}); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(spy).toHaveBeenCalledWith([input1[0]], 3); + expect(spyOps.concat).toHaveBeenCalledWith([input1[0]], 3); }); it('should match json def for ConcatV2', () => { node.op = 'ConcatV2'; @@ -104,13 +112,14 @@ describe('slice join', () => { expect(validateParam(node, slice_join.json, 'ConcatV2')).toBeTruthy(); }); it('should call tfOps.unstack', () => { - const spy = spyOn(tfOps, 'unstack'); node.op = 'Unpack'; node.inputParams.tensor = createTensorAttr(0); node.attrParams.axis = createNumberAttr(4); - executeOp(node, {input1}, context); + spyOps.unstack.and.returnValue({}); + + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(spy).toHaveBeenCalledWith(input1[0], 4); + expect(spyOps.unstack).toHaveBeenCalledWith(input1[0], 4); }); it('should match json def for unstack', () => { node.op = 'Unpack'; @@ -120,16 +129,16 @@ describe('slice join', () => { expect(validateParam(node, slice_join.json)).toBeTruthy(); }); it('should call tfOps.stack', () => { - const spy = spyOn(tfOps, 'stack'); node.op = 'Pack'; node.inputParams.tensors = createTensorsAttr(0, 0); node.attrParams.axis = createNumberAttr(4); - executeOp(node, {input1, input2, input3}, context); + spyOps.stack.and.returnValue({}); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(spy.calls.mostRecent().args[0][0]).toEqual(input1[0]); - expect(spy.calls.mostRecent().args[0][1]).toEqual(input2[0]); - expect(spy.calls.mostRecent().args[0][2]).toEqual(input3[0]); - expect(spy.calls.mostRecent().args[1]).toEqual(4); + expect(spyOps.stack.calls.mostRecent().args[0][0]).toEqual(input1[0]); + expect(spyOps.stack.calls.mostRecent().args[0][1]).toEqual(input2[0]); + expect(spyOps.stack.calls.mostRecent().args[0][2]).toEqual(input3[0]); + expect(spyOps.stack.calls.mostRecent().args[1]).toEqual(4); }); it('should match json def for unstack', () => { node.op = 'Pack'; @@ -139,18 +148,19 @@ describe('slice join', () => { expect(validateParam(node, slice_join.json)).toBeTruthy(); }); it('should reshape tensors for tfOps.stack', () => { - const spy = spyOn(tfOps, 'stack'); node.op = 'Pack'; node.inputNames = ['input1', 'input2', 'input3', 'input4']; node.inputParams.tensors = createTensorsAttr(0, 0); node.attrParams.axis = createNumberAttr(4); - executeOp(node, {input1, input2, input3, input4}, context); + spyOps.stack.and.returnValue({}); + executeOp( + node, {input1, input2, input3, input4}, context, spyOpsAsTfOps); - expect(spy.calls.mostRecent().args[0][0]).toEqual(input1[0]); - expect(spy.calls.mostRecent().args[0][1]).toEqual(input2[0]); - expect(spy.calls.mostRecent().args[0][2]).toEqual(input3[0]); - expect(spy.calls.mostRecent().args[0][3].shape).toEqual([]); - expect(spy.calls.mostRecent().args[1]).toEqual(4); + expect(spyOps.stack.calls.mostRecent().args[0][0]).toEqual(input1[0]); + expect(spyOps.stack.calls.mostRecent().args[0][1]).toEqual(input2[0]); + expect(spyOps.stack.calls.mostRecent().args[0][2]).toEqual(input3[0]); + expect(spyOps.stack.calls.mostRecent().args[0][3].shape).toEqual([]); + expect(spyOps.stack.calls.mostRecent().args[1]).toEqual(4); }); it('should raise error if tensors shape does not match for tfOps.stack', () => { @@ -179,14 +189,14 @@ describe('slice join', () => { }); describe('executeOp', () => { it('should call tfOps.reverse', () => { - spyOn(tfOps, 'reverse'); node.op = 'Reverse'; node.inputParams.dims = createBooleanArrayAttrFromIndex(1); node.inputNames = ['input1', 'input6']; const input6 = [tfOps.tensor1d([false, true], 'bool')]; - executeOp(node, {input1, input6}, context); + spyOps.reverse.and.returnValue({}); + executeOp(node, {input1, input6}, context, spyOpsAsTfOps); - expect(tfOps.reverse).toHaveBeenCalledWith(input1[0], [1]); + expect(spyOps.reverse).toHaveBeenCalledWith(input1[0], [1]); }); it('should match json def for reverse', () => { node.op = 'Reverse'; @@ -195,13 +205,13 @@ describe('slice join', () => { expect(validateParam(node, slice_join.json, 'Reverse')).toBeTruthy(); }); it('should call tfOps.reverse', () => { - spyOn(tfOps, 'reverse'); node.op = 'ReverseV2'; node.inputParams.axis = createNumericArrayAttrFromIndex(1); node.inputNames = ['input1', 'input4']; - executeOp(node, {input1, input4}, context); + spyOps.reverse.and.returnValue({}); + executeOp(node, {input1, input4}, context, spyOpsAsTfOps); - expect(tfOps.reverse).toHaveBeenCalledWith(input1[0], [3]); + expect(spyOps.reverse).toHaveBeenCalledWith(input1[0], [3]); }); it('should match json def for reverse', () => { node.op = 'ReverseV2'; @@ -210,13 +220,13 @@ describe('slice join', () => { expect(validateParam(node, slice_join.json, 'ReverseV2')).toBeTruthy(); }); it('should call tfOps.tile', () => { - spyOn(tfOps, 'tile'); node.op = 'Tile'; node.inputParams.reps = createNumericArrayAttrFromIndex(1); node.inputNames = ['input1', 'input4']; - executeOp(node, {input1, input4}, context); + spyOps.tile.and.returnValue({}); + executeOp(node, {input1, input4}, context, spyOpsAsTfOps); - expect(tfOps.tile).toHaveBeenCalledWith(input1[0], [3]); + expect(spyOps.tile).toHaveBeenCalledWith(input1[0], [3]); }); it('should match json def for tile', () => { node.op = 'Tile'; @@ -225,16 +235,16 @@ describe('slice join', () => { expect(validateParam(node, slice_join.json)).toBeTruthy(); }); it('should call tfOps.slice', () => { - spyOn(tfOps, 'slice'); node.op = 'Slice'; node.inputParams.begin = createNumericArrayAttrFromIndex(1); node.inputParams.size = createNumericArrayAttrFromIndex(2); const input6 = [tfOps.tensor1d([2], 'int32')]; node.inputNames = ['input1', 'input6', 'input4']; + spyOps.slice.and.returnValue({}); - executeOp(node, {input1, input6, input4}, context); + executeOp(node, {input1, input6, input4}, context, spyOpsAsTfOps); - expect(tfOps.slice).toHaveBeenCalledWith(input1[0], [2], [3]); + expect(spyOps.slice).toHaveBeenCalledWith(input1[0], [2], [3]); }); it('should match json def for slice', () => { node.op = 'Slice'; @@ -244,7 +254,6 @@ describe('slice join', () => { expect(validateParam(node, slice_join.json)).toBeTruthy(); }); it('should call tfOps.stridedSlice', () => { - spyOn(tfOps, 'stridedSlice'); node.op = 'StridedSlice'; node.inputParams.begin = createNumericArrayAttrFromIndex(1); node.inputParams.end = createNumericArrayAttrFromIndex(2); @@ -257,9 +266,10 @@ describe('slice join', () => { node.inputNames = ['input1', 'input6', 'input7', 'input4']; const input6 = [tfOps.tensor1d([2], 'int32')]; const input7 = [tfOps.tensor1d([3], 'int32')]; - executeOp(node, {input1, input6, input7, input4}, context); + executeOp( + node, {input1, input6, input7, input4}, context, spyOpsAsTfOps); - expect(tfOps.stridedSlice) + expect(spyOps.stridedSlice) .toHaveBeenCalledWith(input1[0], [2], [3], [3], 4, 5, 1, 2, 3); }); it('should match json def for stridedSlice', () => { @@ -276,14 +286,14 @@ describe('slice join', () => { expect(validateParam(node, slice_join.json)).toBeTruthy(); }); it('should call tfOps.gather', () => { - spyOn(tfOps, 'gather'); node.op = 'Gather'; node.inputParams.indices = createTensorAttr(1); const input5 = [tfOps.scalar(2, 'int32')]; node.inputNames = ['input1', 'input5']; - executeOp(node, {input1, input5, input3}, context); + spyOps.gather.and.returnValue({}); + executeOp(node, {input1, input5, input3}, context, spyOpsAsTfOps); - expect(tfOps.gather) + expect(spyOps.gather) .toHaveBeenCalledWith( input1[0], jasmine.objectContaining({dataId: input5[0].dataId}), 0); @@ -295,30 +305,30 @@ describe('slice join', () => { expect(validateParam(node, slice_join.json, 'Gather')).toBeTruthy(); }); it('should call tfOps.gather', () => { - spyOn(tfOps, 'gather'); node.op = 'GatherV2'; node.inputParams.indices = createTensorAttr(1); node.inputParams.axis = createNumberAttrFromIndex(2); node.attrParams.batchDims = createNumberAttr(1); const input5 = [tfOps.scalar(2, 'int32')]; node.inputNames = ['input1', 'input5', 'input3']; - executeOp(node, {input1, input5, input3}, context); + spyOps.gather.and.returnValue({}); + executeOp(node, {input1, input5, input3}, context, spyOpsAsTfOps); - expect(tfOps.gather) + expect(spyOps.gather) .toHaveBeenCalledWith( input1[0], jasmine.objectContaining({dataId: input5[0].dataId}), 3, 1); }); it('should make indices param of int32 dtype', () => { - spyOn(tfOps, 'gather'); node.op = 'Gather'; node.inputParams.indices = createTensorAttr(1); node.inputNames = ['input1', 'input5']; const input5 = [tfOps.scalar(2, 'float32')]; - executeOp(node, {input1, input5}, context); + spyOps.gather.and.returnValue({}); + executeOp(node, {input1, input5}, context, spyOpsAsTfOps); - expect(tfOps.gather) + expect(spyOps.gather) .toHaveBeenCalledWith( input1[0], jasmine.objectContaining({dtype: 'int32'}), 0); }); @@ -331,15 +341,15 @@ describe('slice join', () => { expect(validateParam(node, slice_join.json, 'GatherV2')).toBeTruthy(); }); it('should call tfOps.split', () => { - spyOn(tfOps, 'split'); node.op = 'Split'; node.inputParams.axis = createNumberAttrFromIndex(0); node.inputParams.x = createTensorAttr(1); node.attrParams.numOrSizeSplits = createNumberAttr(2); node.inputNames = ['input1', 'input2']; - executeOp(node, {input1, input2}, context); + spyOps.split.and.returnValue({}); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.split).toHaveBeenCalledWith(input2[0], 2, 1); + expect(spyOps.split).toHaveBeenCalledWith(input2[0], 2, 1); }); it('should match json def for split', () => { node.op = 'Split'; @@ -350,15 +360,15 @@ describe('slice join', () => { expect(validateParam(node, slice_join.json, 'Split')).toBeTruthy(); }); it('should call tfOps.split', () => { - spyOn(tfOps, 'split'); node.op = 'SplitV'; node.inputParams.x = createTensorAttr(0); node.inputParams.numOrSizeSplits = createNumericArrayAttrFromIndex(1); node.inputParams.axis = createNumberAttrFromIndex(2); node.inputNames = ['input1', 'input2', 'input3']; - executeOp(node, {input1, input2, input3}, context); + spyOps.split.and.returnValue({}); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.split).toHaveBeenCalledWith(input1[0], 2, 3); + expect(spyOps.split).toHaveBeenCalledWith(input1[0], 2, 3); }); it('should match json def for split', () => { node.op = 'SplitV'; @@ -369,15 +379,17 @@ describe('slice join', () => { expect(validateParam(node, slice_join.json, 'SplitV')).toBeTruthy(); }); it('should call tfOps.scatterND', () => { - spyOn(tfOps, 'scatterND'); node.op = 'ScatterNd'; node.inputParams.indices = createTensorAttr(0); node.inputParams.values = createTensorAttr(1); node.inputParams.shape = createNumericArrayAttrFromIndex(2); node.inputNames = ['input1', 'input2', 'input4']; - executeOp(node, {input1, input2, input4}, context); + spyOps.scatterND.and.returnValue({}); + executeOp(node, {input1, input2, input4}, context, spyOpsAsTfOps); - expect(tfOps.scatterND).toHaveBeenCalledWith(input1[0], input2[0], [3]); + expect(spyOps.scatterND).toHaveBeenCalledWith(input1[0], input2[0], [ + 3 + ]); }); it('should match json def for scatterND', () => { node.op = 'ScatterNd'; @@ -389,14 +401,14 @@ describe('slice join', () => { expect(validateParam(node, slice_join.json)).toBeTruthy(); }); it('should call tfOps.gatherND', () => { - spyOn(tfOps, 'gatherND'); node.op = 'GatherNd'; node.inputParams.x = createTensorAttr(0); node.inputParams.indices = createTensorAttr(1); node.inputNames = ['input1', 'input2']; - executeOp(node, {input1, input2}, context); + spyOps.gatherND.and.returnValue({}); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.gatherND).toHaveBeenCalledWith(input1[0], input2[0]); + expect(spyOps.gatherND).toHaveBeenCalledWith(input1[0], input2[0]); }); it('should match json def for gatherND', () => { node.op = 'GatherNd'; @@ -406,7 +418,6 @@ describe('slice join', () => { expect(validateParam(node, slice_join.json)).toBeTruthy(); }); it('should call tfOps.sparseToDense', () => { - spyOn(tfOps, 'sparseToDense'); node.op = 'SparseToDense'; node.inputParams.sparseIndices = createTensorAttr(0); node.inputParams.outputShape = createNumericArrayAttrFromIndex(1); @@ -414,13 +425,14 @@ describe('slice join', () => { node.inputParams.defaultValue = createTensorAttr(3); node.inputParams.indices = createTensorAttr(1); node.inputNames = ['input1', 'input4', 'input3', 'input2']; - executeOp(node, {input1, input2, input3, input4}, context); + spyOps.sparseToDense.and.returnValue({}); + executeOp( + node, {input1, input2, input3, input4}, context, spyOpsAsTfOps); - expect(tfOps.sparseToDense) + expect(spyOps.sparseToDense) .toHaveBeenCalledWith(input1[0], input3[0], [3], input2[0]); }); it('should make defaultValue of same dtype as sparseValues', () => { - spyOn(tfOps, 'sparseToDense'); node.op = 'SparseToDense'; node.inputParams.sparseIndices = createTensorAttr(0); node.inputParams.outputShape = createNumericArrayAttrFromIndex(1); @@ -429,9 +441,11 @@ describe('slice join', () => { node.inputParams.indices = createTensorAttr(1); const input5 = [tfOps.scalar(5, 'int32')]; node.inputNames = ['input1', 'input4', 'input3', 'input5']; - executeOp(node, {input1, input5, input3, input4}, context); + spyOps.sparseToDense.and.returnValue({}); + executeOp( + node, {input1, input5, input3, input4}, context, spyOpsAsTfOps); - expect(tfOps.sparseToDense) + expect(spyOps.sparseToDense) .toHaveBeenCalledWith( input1[0], input3[0], [3], jasmine.objectContaining({dtype: 'float32'})); diff --git a/tfjs-converter/src/operations/executors/sparse_executor.ts b/tfjs-converter/src/operations/executors/sparse_executor.ts index a751af82a7b..3cab1601d70 100644 --- a/tfjs-converter/src/operations/executors/sparse_executor.ts +++ b/tfjs-converter/src/operations/executors/sparse_executor.ts @@ -27,7 +27,7 @@ import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Tensor[] => { + context: ExecutionContext, ops = tfOps): Tensor[] => { switch (node.op) { case 'SparseFillEmptyRows': { const { @@ -36,7 +36,7 @@ export const executeOp: InternalOpExecutor = emptyRowIndicator, reverseIndexMap } = - tfOps.sparse.sparseFillEmptyRows( + ops.sparse.sparseFillEmptyRows( getParamValue('indices', node, tensorMap, context) as Tensor2D, getParamValue('values', node, tensorMap, context) as Tensor1D, @@ -49,7 +49,7 @@ export const executeOp: InternalOpExecutor = ]; } case 'SparseReshape': { - const {outputIndices, outputShape} = tfOps.sparse.sparseReshape( + const {outputIndices, outputShape} = ops.sparse.sparseReshape( getParamValue('inputIndices', node, tensorMap, context) as Tensor2D, getParamValue('inputShape', node, tensorMap, context) as Tensor1D, @@ -57,7 +57,7 @@ export const executeOp: InternalOpExecutor = return [outputIndices, outputShape]; } case 'SparseSegmentMean': { - const outputData = tfOps.sparse.sparseSegmentMean( + const outputData = ops.sparse.sparseSegmentMean( getParamValue('data', node, tensorMap, context) as Tensor, getParamValue('indices', node, tensorMap, context) as Tensor1D, getParamValue('segmentIds', node, tensorMap, context) as @@ -65,7 +65,7 @@ export const executeOp: InternalOpExecutor = return [outputData]; } case 'SparseSegmentSum': { - const outputData = tfOps.sparse.sparseSegmentSum( + const outputData = ops.sparse.sparseSegmentSum( getParamValue('data', node, tensorMap, context) as Tensor, getParamValue('indices', node, tensorMap, context) as Tensor1D, getParamValue('segmentIds', node, tensorMap, context) as diff --git a/tfjs-converter/src/operations/executors/sparse_executor_test.ts b/tfjs-converter/src/operations/executors/sparse_executor_test.ts index 8d2a3ce376d..f42a3fc07a1 100644 --- a/tfjs-converter/src/operations/executors/sparse_executor_test.ts +++ b/tfjs-converter/src/operations/executors/sparse_executor_test.ts @@ -24,12 +24,17 @@ import {Node} from '../types'; import {executeOp} from './sparse_executor'; import {createTensorAttr, validateParam} from './test_helper'; +import {RecursiveSpy, spyOnAllFunctions} from './spy_ops'; describe('sparse', () => { let node: Node; const context = new ExecutionContext({}, {}, {}); + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; beforeEach(() => { + spyOps = spyOnAllFunctions(tfOps); + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; node = { name: 'test', op: '', @@ -45,7 +50,6 @@ describe('sparse', () => { describe('executeOp', () => { describe('SparseFillEmptyRows', () => { it('should call tfOps.sparse.sparseFillEmptyRows', async () => { - spyOn(tfOps.sparse, 'sparseFillEmptyRows').and.callThrough(); node.op = 'SparseFillEmptyRows'; node.inputParams = { indices: createTensorAttr(0), @@ -62,9 +66,9 @@ describe('sparse', () => { const defaultValue = [tfOps.scalar(-1, 'int32')]; const result = executeOp( node, {indices, values, denseShape, defaultValue}, - context) as Tensor[]; + context, spyOpsAsTfOps) as Tensor[]; - expect(tfOps.sparse.sparseFillEmptyRows) + expect(spyOps.sparse.sparseFillEmptyRows) .toHaveBeenCalledWith( indices[0], values[0], denseShape[0], defaultValue[0]); test_util.expectArraysClose( @@ -89,7 +93,6 @@ describe('sparse', () => { }); describe('SparseReshape', () => { it('should call tfOps.sparse.sparseReshape', async () => { - spyOn(tfOps.sparse, 'sparseReshape').and.callThrough(); node.op = 'SparseReshape'; node.inputParams = { inputIndices: createTensorAttr(0), @@ -103,10 +106,10 @@ describe('sparse', () => { const inputShape = [tfOps.tensor1d([2, 3, 6], 'int32')]; const newShape = [tfOps.tensor1d([9, -1], 'int32')]; const result = - executeOp(node, {inputIndices, inputShape, newShape}, context) as - Tensor[]; + executeOp(node, {inputIndices, inputShape, newShape}, context, + spyOpsAsTfOps) as Tensor[]; - expect(tfOps.sparse.sparseReshape) + expect(spyOps.sparse.sparseReshape) .toHaveBeenCalledWith(inputIndices[0], inputShape[0], newShape[0]); test_util.expectArraysClose( await result[0].data(), [0, 0, 0, 1, 1, 2, 4, 2, 8, 1]); @@ -126,7 +129,6 @@ describe('sparse', () => { }); describe('SparseSegmentMean', () => { it('should call tfOps.sparse.sparseSegmentMean', async () => { - spyOn(tfOps.sparse, 'sparseSegmentMean').and.callThrough(); node.op = 'SparseSegmentMean'; node.inputParams = { data: createTensorAttr(0), @@ -140,9 +142,10 @@ describe('sparse', () => { const indices = [tfOps.tensor1d([0, 1, 2], 'int32')]; const segmentIds = [tfOps.tensor1d([0, 1, 1], 'int32')]; const result = - executeOp(node, {data, indices, segmentIds}, context) as Tensor[]; + executeOp(node, {data, indices, segmentIds}, context, + spyOpsAsTfOps) as Tensor[]; - expect(tfOps.sparse.sparseSegmentMean) + expect(spyOps.sparse.sparseSegmentMean) .toHaveBeenCalledWith(data[0], indices[0], segmentIds[0]); test_util.expectArraysClose( await result[0].data(), [1.0, 2.0, 3.0, 4.0, 2.5, 2.5, 2.5, 2.5]); @@ -160,7 +163,6 @@ describe('sparse', () => { }); describe('SparseSegmentSum', () => { it('should call tfOps.sparse.sparseSegmentSum', async () => { - spyOn(tfOps.sparse, 'sparseSegmentSum').and.callThrough(); node.op = 'SparseSegmentSum'; node.inputParams = { data: createTensorAttr(0), @@ -174,9 +176,10 @@ describe('sparse', () => { const indices = [tfOps.tensor1d([0, 1], 'int32')]; const segmentIds = [tfOps.tensor1d([0, 0], 'int32')]; const result = - executeOp(node, {data, indices, segmentIds}, context) as Tensor[]; + executeOp(node, {data, indices, segmentIds}, context, + spyOpsAsTfOps) as Tensor[]; - expect(tfOps.sparse.sparseSegmentSum) + expect(spyOps.sparse.sparseSegmentSum) .toHaveBeenCalledWith(data[0], indices[0], segmentIds[0]); test_util.expectArraysClose(await result[0].data(), [0, 0, 0, 0]); }); diff --git a/tfjs-converter/src/operations/executors/spectral_executor.ts b/tfjs-converter/src/operations/executors/spectral_executor.ts index c0f169772dc..9a1f49b93f4 100644 --- a/tfjs-converter/src/operations/executors/spectral_executor.ts +++ b/tfjs-converter/src/operations/executors/spectral_executor.ts @@ -26,23 +26,23 @@ import {InternalOpExecutor, Node} from '../types'; import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = - (node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext): - Tensor[] => { + (node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext, + ops = tfOps): Tensor[] => { switch (node.op) { case 'FFT': { - return [tfOps.fft( + return [ops.fft( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'IFFT': { - return [tfOps.ifft( + return [ops.ifft( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'RFFT': { - return [tfOps.rfft( + return [ops.rfft( getParamValue('x', node, tensorMap, context) as Tensor)]; } case 'IRFFT': { - return [tfOps.irfft( + return [ops.irfft( getParamValue('x', node, tensorMap, context) as Tensor)]; } default: diff --git a/tfjs-converter/src/operations/executors/spectral_executor_test.ts b/tfjs-converter/src/operations/executors/spectral_executor_test.ts index c9f5dccc0eb..5e4d0a7947e 100644 --- a/tfjs-converter/src/operations/executors/spectral_executor_test.ts +++ b/tfjs-converter/src/operations/executors/spectral_executor_test.ts @@ -23,13 +23,18 @@ import {Node} from '../types'; import {executeOp} from './spectral_executor'; import {createTensorAttr, validateParam} from './test_helper'; +import {RecursiveSpy, spyOnAllFunctions} from './spy_ops'; describe('spectral', () => { let node: Node; const input1 = [tfOps.scalar(1)]; const context = new ExecutionContext({}, {}, {}); + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; beforeEach(() => { + spyOps = spyOnAllFunctions(tfOps); + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; node = { name: 'test', op: '', @@ -45,11 +50,11 @@ describe('spectral', () => { describe('executeOp', () => { describe('FFT', () => { it('should call tfOps.fft', () => { - spyOn(tfOps, 'fft'); node.op = 'FFT'; - executeOp(node, {input1}, context); + spyOps.fft.and.returnValue({}); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.fft).toHaveBeenCalledWith(input1[0]); + expect(spyOps.fft).toHaveBeenCalledWith(input1[0]); }); it('should match json def', () => { node.op = 'FFT'; @@ -59,11 +64,11 @@ describe('spectral', () => { }); describe('IFFT', () => { it('should call tfOps.ifft', () => { - spyOn(tfOps, 'ifft'); node.op = 'IFFT'; - executeOp(node, {input1}, context); + spyOps.ifft.and.returnValue({}); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.ifft).toHaveBeenCalledWith(input1[0]); + expect(spyOps.ifft).toHaveBeenCalledWith(input1[0]); }); it('should match json def', () => { node.op = 'IFFT'; @@ -73,11 +78,11 @@ describe('spectral', () => { }); describe('RFFT', () => { it('should call tfOps.rfft', () => { - spyOn(tfOps, 'rfft'); node.op = 'RFFT'; - executeOp(node, {input1}, context); + spyOps.rfft.and.returnValue({}); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.rfft).toHaveBeenCalledWith(input1[0]); + expect(spyOps.rfft).toHaveBeenCalledWith(input1[0]); }); it('should match json def', () => { node.op = 'RFFT'; @@ -87,11 +92,11 @@ describe('spectral', () => { }); describe('IRFFT', () => { it('should call tfOps.irfft', () => { - spyOn(tfOps, 'irfft'); node.op = 'IRFFT'; - executeOp(node, {input1}, context); + spyOps.irfft.and.returnValue({}); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.irfft).toHaveBeenCalledWith(input1[0]); + expect(spyOps.irfft).toHaveBeenCalledWith(input1[0]); }); it('should match json def', () => { node.op = 'IRFFT'; diff --git a/tfjs-converter/src/operations/executors/spy_ops.ts b/tfjs-converter/src/operations/executors/spy_ops.ts new file mode 100644 index 00000000000..12f772533a7 --- /dev/null +++ b/tfjs-converter/src/operations/executors/spy_ops.ts @@ -0,0 +1,17 @@ +export type RecursiveSpy = T extends Function ? jasmine.Spy : { + [K in keyof T]: RecursiveSpy +}; + +export function spyOnAllFunctions(obj: T): RecursiveSpy { + return Object.fromEntries( + Object.entries(obj).map(([key, val]) => { + if (val instanceof Function) { + return [key, jasmine.createSpy(`${key} spy`, val).and.callThrough()]; + } else if (val instanceof Array) { + return [key, val]; + } else if (val instanceof Object) { + return [key, spyOnAllFunctions(val)]; + } + return [key, val]; + })) as RecursiveSpy; +} diff --git a/tfjs-converter/src/operations/executors/string_executor.ts b/tfjs-converter/src/operations/executors/string_executor.ts index 8fa5904c9f2..e57e2107047 100644 --- a/tfjs-converter/src/operations/executors/string_executor.ts +++ b/tfjs-converter/src/operations/executors/string_executor.ts @@ -27,10 +27,10 @@ import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Tensor[] => { + context: ExecutionContext, ops = tfOps): Tensor[] => { switch (node.op) { case 'StringNGrams': { - const {nGrams, nGramsSplits} = tfOps.string.stringNGrams( + const {nGrams, nGramsSplits} = ops.string.stringNGrams( getParamValue('data', node, tensorMap, context) as Tensor1D, getParamValue('dataSplits', node, tensorMap, context) as Tensor, getParamValue('separator', node, tensorMap, context) as string, @@ -45,14 +45,14 @@ export const executeOp: InternalOpExecutor = return [nGrams, nGramsSplits]; } case 'StringSplit': { - const {indices, values, shape} = tfOps.string.stringSplit( + const {indices, values, shape} = ops.string.stringSplit( getParamValue('input', node, tensorMap, context) as Tensor1D, getParamValue('delimiter', node, tensorMap, context) as Scalar, getParamValue('skipEmpty', node, tensorMap, context) as boolean); return [indices, values, shape]; } case 'StringToHashBucketFast': { - const output = tfOps.string.stringToHashBucketFast( + const output = ops.string.stringToHashBucketFast( getParamValue('input', node, tensorMap, context) as Tensor, getParamValue('numBuckets', node, tensorMap, context) as number); return [output]; diff --git a/tfjs-converter/src/operations/executors/string_executor_test.ts b/tfjs-converter/src/operations/executors/string_executor_test.ts index 5378216f19e..492d0165953 100644 --- a/tfjs-converter/src/operations/executors/string_executor_test.ts +++ b/tfjs-converter/src/operations/executors/string_executor_test.ts @@ -24,12 +24,17 @@ import {Node} from '../types'; import {executeOp} from './string_executor'; import {createBoolAttr, createNumberAttr, createNumericArrayAttr, createStrAttr, createTensorAttr, validateParam} from './test_helper'; +import {RecursiveSpy, spyOnAllFunctions} from './spy_ops'; describe('string', () => { let node: Node; const context = new ExecutionContext({}, {}, {}); + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; beforeEach(() => { + spyOps = spyOnAllFunctions(tfOps); + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; node = { name: 'test', op: '', @@ -46,7 +51,6 @@ describe('string', () => { describe('executeOp', () => { describe('StringNGrams', () => { it('should call tfOps.string.stringNGrams', async () => { - spyOn(tfOps.string, 'stringNGrams').and.callThrough(); node.op = 'StringNGrams'; node.inputParams = { data: createTensorAttr(0), @@ -65,9 +69,10 @@ describe('string', () => { const data = [tfOps.tensor1d(['a', 'b', 'c', 'd', 'e', 'f'], 'string')]; const dataSplits = [tfOps.tensor1d([0, 4, 6], 'int32')]; - const result = executeOp(node, {data, dataSplits}, context) as Tensor[]; + const result = executeOp(node, {data, dataSplits}, context, + spyOpsAsTfOps) as Tensor[]; - expect(tfOps.string.stringNGrams) + expect(spyOps.string.stringNGrams) .toHaveBeenCalledWith( data[0], dataSplits[0], '|', [3], 'LP', 'RP', -1, false); test_util.expectArraysEqual(await result[0].data(), [ @@ -88,7 +93,6 @@ describe('string', () => { }); describe('StringSplit', () => { it('should call tfOps.string.stringSplit', async () => { - spyOn(tfOps.string, 'stringSplit').and.callThrough(); node.op = 'StringSplit'; node.inputParams = { input: createTensorAttr(0), @@ -100,9 +104,10 @@ describe('string', () => { const input = [tfOps.tensor1d(['#a', 'b#', '#c#'], 'string')]; const delimiter = [tfOps.scalar('#', 'string')]; - const result = executeOp(node, {input, delimiter}, context) as Tensor[]; + const result = executeOp(node, {input, delimiter}, context, + spyOpsAsTfOps) as Tensor[]; - expect(tfOps.string.stringSplit) + expect(spyOps.string.stringSplit) .toHaveBeenCalledWith(input[0], delimiter[0], false); test_util.expectArraysEqual( await result[0].data(), [0, 0, 0, 1, 1, 0, 1, 1, 2, 0, 2, 1, 2, 2]); @@ -123,16 +128,16 @@ describe('string', () => { }); describe('StringToHashBucketFast', () => { it('should call tfOps.string.stringToHashBucketFast', async () => { - spyOn(tfOps.string, 'stringToHashBucketFast').and.callThrough(); node.op = 'StringToHashBucketFast'; node.inputParams = {input: createTensorAttr(0)}; node.attrParams = {numBuckets: createNumberAttr(10)}; node.inputNames = ['input']; const input = [tfOps.tensor1d(['a', 'b', 'c', 'd'], 'string')]; - const result = executeOp(node, {input}, context) as Tensor[]; + const result = executeOp(node, {input}, context, + spyOpsAsTfOps) as Tensor[]; - expect(tfOps.string.stringToHashBucketFast) + expect(spyOps.string.stringToHashBucketFast) .toHaveBeenCalledWith(input[0], 10); test_util.expectArraysClose(await result[0].data(), [9, 2, 2, 5]); }); diff --git a/tfjs-converter/src/operations/executors/test_helper.ts b/tfjs-converter/src/operations/executors/test_helper.ts index f4153e1d12c..f2deb049cbd 100644 --- a/tfjs-converter/src/operations/executors/test_helper.ts +++ b/tfjs-converter/src/operations/executors/test_helper.ts @@ -14,6 +14,7 @@ * limitations under the License. * ============================================================================= */ + import {InputParamValue, OpMapper, ParamValue} from '../types'; import {Node} from '../types'; @@ -94,3 +95,8 @@ export function validateParam( } return matched; } + +// TODO(mattsoulanille): Change the return type to Uncapitalize in TS4. +export function uncapitalize(name: Name): string { + return name.charAt(0).toLowerCase() + name.slice(1); +} diff --git a/tfjs-converter/src/operations/executors/transformation_executor.ts b/tfjs-converter/src/operations/executors/transformation_executor.ts index 0b0083bf062..80833468e3d 100644 --- a/tfjs-converter/src/operations/executors/transformation_executor.ts +++ b/tfjs-converter/src/operations/executors/transformation_executor.ts @@ -27,10 +27,10 @@ import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = (node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext): Tensor[] => { + context: ExecutionContext, ops = tfOps): Tensor[] => { switch (node.op) { case 'Cast': { - return [tfOps.cast( + return [ops.cast( getParamValue('x', node, tensorMap, context) as Tensor, getParamValue('dtype', node, tensorMap, context) as 'int32' | 'float32' | 'bool')]; @@ -38,23 +38,23 @@ export const executeOp: InternalOpExecutor = case 'ExpandDims': { const axis = getParamValue('axis', node, tensorMap, context) as number; - return [tfOps.expandDims( + return [ops.expandDims( getParamValue('x', node, tensorMap, context) as Tensor, axis)]; } case 'Squeeze': { const axis = getParamValue('axis', node, tensorMap, context) as number[]; - return [tfOps.squeeze( + return [ops.squeeze( getParamValue('x', node, tensorMap, context) as Tensor, axis)]; } case 'Reshape': { - return [tfOps.reshape( + return [ops.reshape( getParamValue('x', node, tensorMap, context) as Tensor, getParamValue('shape', node, tensorMap, context) as number[])]; } case 'MirrorPad': { - return [tfOps.mirrorPad( + return [ops.mirrorPad( getParamValue('x', node, tensorMap, context) as Tensor, getParamValue('padding', node, tensorMap, context) as Array<[number, number]>, @@ -63,7 +63,7 @@ export const executeOp: InternalOpExecutor = } case 'PadV2': case 'Pad': { - return [tfOps.pad( + return [ops.pad( getParamValue('x', node, tensorMap, context) as Tensor, getParamValue('padding', node, tensorMap, context) as Array<[number, number]>, @@ -75,7 +75,7 @@ export const executeOp: InternalOpExecutor = getParamValue('blockShape', node, tensorMap, context) as number[]; const paddings = getParamValue('paddings', node, tensorMap, context) as number[][]; - return [tfOps.spaceToBatchND( + return [ops.spaceToBatchND( getParamValue('x', node, tensorMap, context) as Tensor, blockShape, paddings)]; } @@ -84,7 +84,7 @@ export const executeOp: InternalOpExecutor = getParamValue('blockShape', node, tensorMap, context) as number[]; const crops = getParamValue('crops', node, tensorMap, context) as number[][]; - return [tfOps.batchToSpaceND( + return [ops.batchToSpaceND( getParamValue('x', node, tensorMap, context) as Tensor, blockShape, crops)]; } @@ -95,17 +95,17 @@ export const executeOp: InternalOpExecutor = (getParamValue('dataFormat', node, tensorMap, context) as string).toUpperCase() as 'NHWC' | 'NCHW'; - return [tfOps.depthToSpace( + return [ops.depthToSpace( getParamValue('x', node, tensorMap, context) as Tensor4D, blockSize, dataFormat)]; } case 'BroadcastTo': { - return [tfOps.broadcastTo( + return [ops.broadcastTo( getParamValue('x', node, tensorMap, context) as Tensor, getParamValue('shape', node, tensorMap, context) as number[])]; } case 'BroadcastArgs': { - return [tfOps.broadcastArgs( + return [ops.broadcastArgs( getParamValue('s0', node, tensorMap, context) as Tensor, getParamValue('s1', node, tensorMap, context) as Tensor)]; } diff --git a/tfjs-converter/src/operations/executors/transformation_executor_test.ts b/tfjs-converter/src/operations/executors/transformation_executor_test.ts index 3ff0fe67006..94e658f9cf0 100644 --- a/tfjs-converter/src/operations/executors/transformation_executor_test.ts +++ b/tfjs-converter/src/operations/executors/transformation_executor_test.ts @@ -21,6 +21,7 @@ import {ExecutionContext} from '../../executor/execution_context'; import {Node} from '../types'; import {createDtypeAttr, createNumberAttr, createNumericArrayAttrFromIndex, createStrAttr, createTensorAttr} from './test_helper'; import {executeOp} from './transformation_executor'; +import {RecursiveSpy, spyOnAllFunctions} from './spy_ops'; describe('transformation', () => { let node: Node; @@ -42,156 +43,161 @@ describe('transformation', () => { }); describe('executeOp', () => { + let spyOps: RecursiveSpy; + let spyOpsAsTfOps: typeof tfOps; + + beforeEach(() => { + spyOps = spyOnAllFunctions(tfOps); + spyOpsAsTfOps = spyOps as unknown as typeof tfOps; + }); + describe('Cast', () => { it('should call tfOps.cast', () => { - spyOn(tfOps, 'cast'); node.op = 'Cast'; node.attrParams.dtype = createDtypeAttr('float32'); - executeOp(node, {input1}, context); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.cast).toHaveBeenCalledWith(input1[0], 'float32'); + expect(spyOps.cast).toHaveBeenCalledWith(input1[0], 'float32'); }); }); - describe('expandDExpandDimsims', () => { + describe('ExpandDims', () => { it('should call tfOps.expandDims', () => { - spyOn(tfOps, 'expandDims'); node.op = 'ExpandDims'; node.attrParams.axis = createNumberAttr(1); - executeOp(node, {input1}, context); + spyOps.expandDims.and.returnValue({}); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.expandDims).toHaveBeenCalledWith(input1[0], 1); + expect(spyOps.expandDims).toHaveBeenCalledWith(input1[0], 1); }); }); describe('MirrorPad', () => { it('should call tfc.mirrorPad', () => { - spyOn(tfOps, 'mirrorPad'); node.op = 'MirrorPad'; node.inputParams.padding = createNumericArrayAttrFromIndex(1); node.attrParams.mode = createStrAttr('reflect'); node.inputNames = ['input1', 'input3']; const input3 = [tfOps.tensor2d([1, 1, 2, 2], [2, 2])]; - executeOp(node, {input1, input3}, context); + spyOps.mirrorPad.and.returnValue({}); + executeOp(node, {input1, input3}, context, spyOpsAsTfOps); - expect(tfOps.mirrorPad) + expect(spyOps.mirrorPad) .toHaveBeenCalledWith(input1[0], [[1, 1], [2, 2]], 'reflect'); }); }); describe('Pad', () => { it('should call tfOps.pad', () => { - spyOn(tfOps, 'pad'); node.op = 'Pad'; node.inputParams.padding = createNumericArrayAttrFromIndex(1); node.attrParams.constantValue = createNumberAttr(1); node.inputNames = ['input1', 'input3']; const input3 = [tfOps.tensor2d([1, 1, 2, 2], [2, 2])]; - executeOp(node, {input1, input3}, context); + spyOps.pad.and.returnValue({}); + executeOp(node, {input1, input3}, context, spyOpsAsTfOps); - expect(tfOps.pad).toHaveBeenCalledWith(input1[0], [[1, 1], [2, 2]], 1); + expect(spyOps.pad).toHaveBeenCalledWith(input1[0], [[1, 1], [2, 2]], 1); }); }); describe('PadV2', () => { it('should call tfOps.pad', () => { - spyOn(tfOps, 'pad'); node.op = 'PadV2'; node.inputParams.padding = createNumericArrayAttrFromIndex(1); node.attrParams.constantValue = createNumberAttr(1); node.inputNames = ['input1', 'input3']; const input3 = [tfOps.tensor2d([1, 1, 2, 2], [2, 2])]; - executeOp(node, {input1, input3}, context); + spyOps.pad.and.returnValue({}); + executeOp(node, {input1, input3}, context, spyOpsAsTfOps); - expect(tfOps.pad).toHaveBeenCalledWith(input1[0], [[1, 1], [2, 2]], 1); + expect(spyOps.pad).toHaveBeenCalledWith(input1[0], [[1, 1], [2, 2]], 1); }); }); describe('Reshape', () => { it('should call tfOps.reshape', () => { - spyOn(tfOps, 'reshape'); node.op = 'Reshape'; node.inputParams.shape = createNumericArrayAttrFromIndex(1); node.inputNames = ['input1', 'input2']; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.reshape).toHaveBeenCalledWith(input1[0], [1, 1]); + expect(spyOps.reshape).toHaveBeenCalledWith(input1[0], [1, 1]); }); }); describe('Squeeze', () => { it('should call tfOps.squeeze', () => { - spyOn(tfOps, 'squeeze'); node.op = 'Squeeze'; node.attrParams.axis = createNumberAttr(1); - executeOp(node, {input1}, context); + spyOps.squeeze.and.returnValue({}); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.squeeze).toHaveBeenCalledWith(input1[0], 1); + expect(spyOps.squeeze).toHaveBeenCalledWith(input1[0], 1); }); }); describe('SpaceToBatchND', () => { it('should call tfOps.spaceToBatchND', () => { - spyOn(tfOps, 'spaceToBatchND'); node.op = 'SpaceToBatchND'; node.inputParams.blockShape = createNumericArrayAttrFromIndex(1); node.inputParams.paddings = createNumericArrayAttrFromIndex(2); node.inputNames = ['input1', 'input2', 'input3']; const input2 = [tfOps.tensor1d([1, 1, 2, 2])]; const input3 = [tfOps.tensor2d([1, 2, 2, 3, 2, 3, 3, 4], [4, 2])]; - executeOp(node, {input1, input2, input3}, context); + spyOps.spaceToBatchND.and.returnValue({}); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.spaceToBatchND) + expect(spyOps.spaceToBatchND) .toHaveBeenCalledWith( input1[0], [1, 1, 2, 2], [[1, 2], [2, 3], [2, 3], [3, 4]]); }); }); describe('BatchToSpaceND', () => { it('should call tfOps.batchToSpaceND', () => { - spyOn(tfOps, 'batchToSpaceND'); node.op = 'BatchToSpaceND'; node.inputParams.blockShape = createNumericArrayAttrFromIndex(1); node.inputParams.crops = createNumericArrayAttrFromIndex(2); node.inputNames = ['input1', 'input2', 'input3']; const input2 = [tfOps.tensor1d([1, 1, 2, 2])]; const input3 = [tfOps.tensor2d([1, 2, 2, 3, 2, 3, 3, 4], [4, 2])]; - executeOp(node, {input1, input2, input3}, context); + spyOps.batchToSpaceND.and.returnValue({}); + executeOp(node, {input1, input2, input3}, context, spyOpsAsTfOps); - expect(tfOps.batchToSpaceND) + expect(spyOps.batchToSpaceND) .toHaveBeenCalledWith( input1[0], [1, 1, 2, 2], [[1, 2], [2, 3], [2, 3], [3, 4]]); }); }); describe('DepthToSpace', () => { it('should call tfOps.depthToSpace', () => { - spyOn(tfOps, 'depthToSpace'); node.op = 'DepthToSpace'; node.attrParams.blockSize = createNumberAttr(1); node.attrParams.dataFormat = createStrAttr('nhwc'); node.inputNames = ['input1']; - executeOp(node, {input1}, context); + spyOps.depthToSpace.and.returnValue({}); + executeOp(node, {input1}, context, spyOpsAsTfOps); - expect(tfOps.depthToSpace).toHaveBeenCalledWith(input1[0], 1, 'NHWC'); + expect(spyOps.depthToSpace).toHaveBeenCalledWith(input1[0], 1, 'NHWC'); }); }); describe('BroadcastTo', () => { it('should call tfOps.broadcastTo', () => { - spyOn(tfOps, 'broadcastTo'); node.op = 'BroadcastTo'; node.inputParams.shape = createNumericArrayAttrFromIndex(1); node.inputNames = ['input1', 'input2']; - executeOp(node, {input1, input2}, context); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.broadcastTo).toHaveBeenCalledWith(input1[0], [1, 1]); + expect(spyOps.broadcastTo).toHaveBeenCalledWith(input1[0], [1, 1]); }); }); describe('BroadcastArgs', () => { it('should call tfOps.broadcastArgs', () => { - spyOn(tfOps, 'broadcastArgs'); node.op = 'BroadcastArgs'; node.inputParams.s0 = createTensorAttr(0); node.inputParams.s1 = createTensorAttr(1); node.inputNames = ['input1', 'input2']; const input1 = [tfOps.tensor1d([1, 1])]; const input2 = [tfOps.tensor1d([1, 1])]; - executeOp(node, {input1, input2}, context); + spyOps.broadcastArgs.and.returnValue({}); + executeOp(node, {input1, input2}, context, spyOpsAsTfOps); - expect(tfOps.broadcastArgs).toHaveBeenCalledWith(input1[0], input2[0]); + expect(spyOps.broadcastArgs).toHaveBeenCalledWith(input1[0], input2[0]); }); }); }); diff --git a/tfjs-converter/src/operations/operation_executor.ts b/tfjs-converter/src/operations/operation_executor.ts index 6360d8d8d55..a2151ec4a1b 100644 --- a/tfjs-converter/src/operations/operation_executor.ts +++ b/tfjs-converter/src/operations/operation_executor.ts @@ -53,53 +53,48 @@ import {Node} from './types'; */ export function executeOp( node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext, - resourceManager?: ResourceManager): tfc.Tensor[]|Promise { + resourceManager?: ResourceManager, tidy = tfc.tidy): tfc.Tensor[]| + Promise { const value = ((node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext) => { switch (node.category) { case 'arithmetic': - return tfc.tidy( - () => arithmetic.executeOp(node, tensorMap, context)); + return tidy(() => arithmetic.executeOp(node, tensorMap, context)); case 'basic_math': - return tfc.tidy( - () => basicMath.executeOp(node, tensorMap, context)); + return tidy(() => basicMath.executeOp(node, tensorMap, context)); case 'control': return control.executeOp(node, tensorMap, context); case 'convolution': - return tfc.tidy( - () => convolution.executeOp(node, tensorMap, context)); + return tidy(() => convolution.executeOp(node, tensorMap, context)); case 'creation': - return tfc.tidy(() => creation.executeOp(node, tensorMap, context)); + return tidy(() => creation.executeOp(node, tensorMap, context)); case 'dynamic': return dynamic.executeOp(node, tensorMap, context); case 'evaluation': - return tfc.tidy( - () => evaluation.executeOp(node, tensorMap, context)); + return tidy(() => evaluation.executeOp(node, tensorMap, context)); case 'image': - return tfc.tidy(() => image.executeOp(node, tensorMap, context)); + return tidy(() => image.executeOp(node, tensorMap, context)); case 'graph': - return tfc.tidy(() => graph.executeOp(node, tensorMap, context)); + return tidy(() => graph.executeOp(node, tensorMap, context)); case 'logical': - return tfc.tidy(() => logical.executeOp(node, tensorMap, context)); + return tidy(() => logical.executeOp(node, tensorMap, context)); case 'matrices': - return tfc.tidy(() => matrices.executeOp(node, tensorMap, context)); + return tidy(() => matrices.executeOp(node, tensorMap, context)); case 'normalization': - return tfc.tidy( + return tidy( () => normalization.executeOp(node, tensorMap, context)); case 'reduction': - return tfc.tidy( - () => reduction.executeOp(node, tensorMap, context)); + return tidy(() => reduction.executeOp(node, tensorMap, context)); case 'slice_join': - return tfc.tidy( - () => sliceJoin.executeOp(node, tensorMap, context)); + return tidy(() => sliceJoin.executeOp(node, tensorMap, context)); case 'sparse': - return tfc.tidy(() => sparse.executeOp(node, tensorMap, context)); + return tidy(() => sparse.executeOp(node, tensorMap, context)); case 'spectral': - return tfc.tidy(() => spectral.executeOp(node, tensorMap, context)); + return tidy(() => spectral.executeOp(node, tensorMap, context)); case 'string': - return tfc.tidy(() => string.executeOp(node, tensorMap, context)); + return tidy(() => string.executeOp(node, tensorMap, context)); case 'transformation': - return tfc.tidy( + return tidy( () => transformation.executeOp(node, tensorMap, context)); case 'hash_table': return hashTable.executeOp( @@ -120,7 +115,7 @@ export function executeOp( } })(node, tensorMap, context); if (tfc.util.isPromise(value)) { - return (value as Promise).then((data) => [].concat(data)); + return value.then((data) => [].concat(data)); } return [].concat(value); } diff --git a/tfjs-converter/src/operations/operation_executor_test.ts b/tfjs-converter/src/operations/operation_executor_test.ts index 6617b267129..e826cf02b89 100644 --- a/tfjs-converter/src/operations/operation_executor_test.ts +++ b/tfjs-converter/src/operations/operation_executor_test.ts @@ -78,10 +78,11 @@ describe('OperationExecutor', () => { string, transformation] .forEach(category => { it('should call tidy around executor', () => { - spyOn(tfc, 'tidy'); + const tidySpy = jasmine.createSpy('tidy spy', tfc.tidy); + node.category = category.CATEGORY; - executeOp(node, {}, context); - expect(tfc.tidy).toHaveBeenCalled(); + executeOp(node, {}, context, undefined, tidySpy); + expect(tidySpy).toHaveBeenCalled(); }); }); diff --git a/tfjs-converter/src/operations/types.ts b/tfjs-converter/src/operations/types.ts index 04508504180..5a033e065d5 100644 --- a/tfjs-converter/src/operations/types.ts +++ b/tfjs-converter/src/operations/types.ts @@ -15,6 +15,8 @@ * ============================================================================= */ import {Tensor} from '@tensorflow/tfjs-core'; +// tslint:disable-next-line:no-imports-from-dist +import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter'; import * as tensorflow from '../data/compiled_api'; import {NamedTensorsMap} from '../data/types'; @@ -75,13 +77,13 @@ export declare interface AttrParamMapper extends ParamMapper { } export interface InternalOpExecutor { - (node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext): Tensor - |Tensor[]; + (node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext, + ops?: typeof tfOps): Tensor | Tensor[]; } export interface InternalOpAsyncExecutor { (node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext, - resourceManager?: ResourceManager): Promise; + resourceManager?: ResourceManager, ops?: typeof tfOps): Promise; } export declare interface OpMapper { diff --git a/tfjs-converter/yarn.lock b/tfjs-converter/yarn.lock index 3f8d5360ca2..b0970fa577d 100644 --- a/tfjs-converter/yarn.lock +++ b/tfjs-converter/yarn.lock @@ -100,6 +100,18 @@ resolved "https://registry.yarnpkg.com/@types/node/-/node-17.0.38.tgz#f8bb07c371ccb1903f3752872c89f44006132947" integrity sha512-5jY9RhV7c0Z4Jy09G+NIDTsCZ5G0L5n+Z+p+Y7t5VJHM30bgwzSjVtlcBxqAj+6L/swIlvtOSzr8rBk/aNyV2g== +ansi-regex@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" + integrity sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ== + +ansi-styles@^4.0.0, ansi-styles@^4.1.0: + version "4.3.0" + resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937" + integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg== + dependencies: + color-convert "^2.0.1" + arg@^4.1.0: version "4.1.3" resolved "https://registry.yarnpkg.com/arg/-/arg-4.1.3.tgz#269fc7ad5b8e42cb63c896d5666017261c144089" @@ -112,21 +124,166 @@ argparse@^1.0.10: dependencies: sprintf-js "~1.0.2" +balanced-match@^1.0.0: + version "1.0.2" + resolved "https://registry.yarnpkg.com/balanced-match/-/balanced-match-1.0.2.tgz#e83e3a7e3f300b34cb9d87f615fa0cbf357690ee" + integrity sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw== + +brace-expansion@^1.1.7: + version "1.1.11" + resolved "https://registry.yarnpkg.com/brace-expansion/-/brace-expansion-1.1.11.tgz#3c7fcbf529d87226f3d2f52b966ff5271eb441dd" + integrity sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA== + dependencies: + balanced-match "^1.0.0" + concat-map "0.0.1" + buffer-from@^1.0.0: version "1.1.1" resolved "https://registry.yarnpkg.com/buffer-from/-/buffer-from-1.1.1.tgz#32713bc028f75c02fdb710d7c7bcec1f2c6070ef" integrity sha512-MQcXEUbCKtEo7bhqEs6560Hyd4XaovZlO/k9V3hjVUF/zwW7KBVdSK4gIt/bzwS9MbR5qob+F5jusZsb0YQK2A== +chalk@^4.1.0: + version "4.1.2" + resolved "https://registry.yarnpkg.com/chalk/-/chalk-4.1.2.tgz#aac4e2b7734a740867aeb16bf02aad556a1e7a01" + integrity sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA== + dependencies: + ansi-styles "^4.1.0" + supports-color "^7.1.0" + +cliui@^7.0.2: + version "7.0.4" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.4.tgz#a0265ee655476fc807aea9df3df8df7783808b4f" + integrity sha512-OcRE68cOsVMXp1Yvonl/fzkQOyjLSu/8bhPDfQt0e0/Eb283TKP20Fs2MqoPsr9SwA595rRCA+QMzYc9nBP+JQ== + dependencies: + string-width "^4.2.0" + strip-ansi "^6.0.0" + wrap-ansi "^7.0.0" + +color-convert@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3" + integrity sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ== + dependencies: + color-name "~1.1.4" + +color-name@~1.1.4: + version "1.1.4" + resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2" + integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA== + +concat-map@0.0.1: + version "0.0.1" + resolved "https://registry.yarnpkg.com/concat-map/-/concat-map-0.0.1.tgz#d8a96bd77fd68df7793a73036a3ba0d5405d477b" + integrity sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg== + +detect-indent@^6.0.0: + version "6.1.0" + resolved "https://registry.yarnpkg.com/detect-indent/-/detect-indent-6.1.0.tgz#592485ebbbf6b3b1ab2be175c8393d04ca0d57e6" + integrity sha512-reYkTUJAZb9gUuZ2RvVCNhVHdg62RHnJ7WJl8ftMi4diZ6NWlciOzQN88pUhSELEwflJht4oQDv0F0BMlwaYtA== + diff@^4.0.1: version "4.0.2" resolved "https://registry.yarnpkg.com/diff/-/diff-4.0.2.tgz#60f3aecb89d5fae520c11aa19efc2bb982aade7d" integrity sha512-58lmxKSA4BNyLz+HHMUzlOEpg09FV+ev6ZMe3vJihgdxzgcwZ8VoEEPmALCZG9LmqfVoNMMKpttIYTVG6uDY7A== +emoji-regex@^8.0.0: + version "8.0.0" + resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-8.0.0.tgz#e818fd69ce5ccfcb404594f842963bf53164cc37" + integrity sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A== + +escalade@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40" + integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw== + +fs-extra@^8.0.1: + version "8.1.0" + resolved "https://registry.yarnpkg.com/fs-extra/-/fs-extra-8.1.0.tgz#49d43c45a88cd9677668cb7be1b46efdb8d2e1c0" + integrity sha512-yhlQgA6mnOJUKOsRUFsgJdQCvkKhcz8tlZG5HBQfReYZy46OwLcY+Zia0mtdHsOo9y/hP+CxMN0TU9QxoOtG4g== + dependencies: + graceful-fs "^4.2.0" + jsonfile "^4.0.0" + universalify "^0.1.0" + +fs.realpath@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/fs.realpath/-/fs.realpath-1.0.0.tgz#1504ad2523158caa40db4a2787cb01411994ea4f" + integrity sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw== + +get-caller-file@^2.0.5: + version "2.0.5" + resolved "https://registry.yarnpkg.com/get-caller-file/-/get-caller-file-2.0.5.tgz#4f94412a82db32f36e3b0b9741f8a97feb031f7e" + integrity sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg== + +glob@^7.1.4, glob@^7.1.6: + version "7.2.3" + resolved "https://registry.yarnpkg.com/glob/-/glob-7.2.3.tgz#b8df0fb802bbfa8e89bd1d938b4e16578ed44f2b" + integrity sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q== + dependencies: + fs.realpath "^1.0.0" + inflight "^1.0.4" + inherits "2" + minimatch "^3.1.1" + once "^1.3.0" + path-is-absolute "^1.0.0" + +graceful-fs@^4.1.6, graceful-fs@^4.2.0: + version "4.2.10" + resolved "https://registry.yarnpkg.com/graceful-fs/-/graceful-fs-4.2.10.tgz#147d3a006da4ca3ce14728c7aefc287c367d7a6c" + integrity sha512-9ByhssR2fPVsNZj478qUUbKfmL0+t5BDVyjShtyZZLiK7ZDAArFFfopyOTj0M05wE2tJPisA4iTnnXl2YoPvOA== + +has-flag@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-4.0.0.tgz#944771fd9c81c81265c4d6941860da06bb59479b" + integrity sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ== + +ignore-walk@^3.0.3: + version "3.0.4" + resolved "https://registry.yarnpkg.com/ignore-walk/-/ignore-walk-3.0.4.tgz#c9a09f69b7c7b479a5d74ac1a3c0d4236d2a6335" + integrity sha512-PY6Ii8o1jMRA1z4F2hRkH/xN59ox43DavKvD3oDpfurRlOJyAHpifIwpbdv1n4jt4ov0jSpw3kQ4GhJnpBL6WQ== + dependencies: + minimatch "^3.0.4" + +ignore@^5.0.4: + version "5.2.0" + resolved "https://registry.yarnpkg.com/ignore/-/ignore-5.2.0.tgz#6d3bac8fa7fe0d45d9f9be7bac2fc279577e345a" + integrity sha512-CmxgYGiEPCLhfLnpPp1MoRmifwEIOgjcHXxOBjv7mY96c+eWScsOP9c112ZyLdWHi0FxHjI+4uVhKYp/gcdRmQ== + +inflight@^1.0.4: + version "1.0.6" + resolved "https://registry.yarnpkg.com/inflight/-/inflight-1.0.6.tgz#49bd6331d7d02d0c09bc910a1075ba8165b56df9" + integrity sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA== + dependencies: + once "^1.3.0" + wrappy "1" + +inherits@2: + version "2.0.4" + resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.4.tgz#0fa2c64f932917c3433a0ded55363aae37416b7c" + integrity sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ== + +ini@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/ini/-/ini-2.0.0.tgz#e5fd556ecdd5726be978fa1001862eacb0a94bc5" + integrity sha512-7PnF4oN3CvZF23ADhA5wRaYEQpJ8qygSkbtTXWBeXWXmEVRXK+1ITciHWwHhsjv1TmW0MgacIv6hEi5pX5NQdA== + +is-fullwidth-code-point@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz#f116f8064fe90b3f7844a38997c0b75051269f1d" + integrity sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg== + is-wsl@^1.1.0: version "1.1.0" resolved "https://registry.yarnpkg.com/is-wsl/-/is-wsl-1.1.0.tgz#1f16e4aa22b04d1336b66188a66af3c600c3a66d" integrity sha1-HxbkqiKwTRM2tmGIpmrzxgDDpm0= +jsonfile@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/jsonfile/-/jsonfile-4.0.0.tgz#8771aae0799b64076b76640fca058f9c10e33ecb" + integrity sha512-m6F1R3z8jjlf2imQHS2Qez5sjKWQzbuuhuJ/FKYFRZvPE3PuHcSMVZzfsLhGVOkfd20obL5SWEBew5ShlquNxg== + optionalDependencies: + graceful-fs "^4.1.6" + long@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/long/-/long-4.0.0.tgz#9a7b71cfb7d361a194ea555241c92f7468d5bf28" @@ -137,6 +294,13 @@ make-error@^1.1.1: resolved "https://registry.yarnpkg.com/make-error/-/make-error-1.3.6.tgz#2eb2e37ea9b67c4891f684a1394799af484cf7a2" integrity sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw== +minimatch@^3.0.4, minimatch@^3.1.1: + version "3.1.2" + resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-3.1.2.tgz#19cd194bfd3e428f049a70817c038d89ab4be35b" + integrity sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw== + dependencies: + brace-expansion "^1.1.7" + minimist@1.2.6: version "1.2.6" resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.6.tgz#8637a5b759ea0d6e98702cfb3a9283323c93af44" @@ -149,6 +313,35 @@ node-fetch@~2.6.1: dependencies: whatwg-url "^5.0.0" +npm-bundled@^1.1.1: + version "1.1.2" + resolved "https://registry.yarnpkg.com/npm-bundled/-/npm-bundled-1.1.2.tgz#944c78789bd739035b70baa2ca5cc32b8d860bc1" + integrity sha512-x5DHup0SuyQcmL3s7Rx/YQ8sbw/Hzg0rj48eN0dV7hf5cmQq5PXIeioroH3raV1QC1yh3uTYuMThvEQF3iKgGQ== + dependencies: + npm-normalize-package-bin "^1.0.1" + +npm-normalize-package-bin@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/npm-normalize-package-bin/-/npm-normalize-package-bin-1.0.1.tgz#6e79a41f23fd235c0623218228da7d9c23b8f6e2" + integrity sha512-EPfafl6JL5/rU+ot6P3gRSCpPDW5VmIzX959Ob1+ySFUuuYHWHekXpwdUZcKP5C+DS4GEtdJluwBjnsNDl+fSA== + +npm-packlist@^2.1.5: + version "2.2.2" + resolved "https://registry.yarnpkg.com/npm-packlist/-/npm-packlist-2.2.2.tgz#076b97293fa620f632833186a7a8f65aaa6148c8" + integrity sha512-Jt01acDvJRhJGthnUJVF/w6gumWOZxO7IkpY/lsX9//zqQgnF7OJaxgQXcerd4uQOLu7W5bkb4mChL9mdfm+Zg== + dependencies: + glob "^7.1.6" + ignore-walk "^3.0.3" + npm-bundled "^1.1.1" + npm-normalize-package-bin "^1.0.1" + +once@^1.3.0: + version "1.4.0" + resolved "https://registry.yarnpkg.com/once/-/once-1.4.0.tgz#583b1aa775961d4b113ac17d9c50baef9dd76bd1" + integrity sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w== + dependencies: + wrappy "1" + opn@~5.1.0: version "5.1.0" resolved "https://registry.yarnpkg.com/opn/-/opn-5.1.0.tgz#72ce2306a17dbea58ff1041853352b4a8fc77519" @@ -156,6 +349,11 @@ opn@~5.1.0: dependencies: is-wsl "^1.1.0" +path-is-absolute@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/path-is-absolute/-/path-is-absolute-1.0.1.tgz#174b9268735534ffbc7ace6bf53a5a9e1b5c5f5f" + integrity sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg== + protobufjs@~6.11.3: version "6.11.3" resolved "https://registry.yarnpkg.com/protobufjs/-/protobufjs-6.11.3.tgz#637a527205a35caa4f3e2a9a4a13ddffe0e7af74" @@ -175,6 +373,11 @@ protobufjs@~6.11.3: "@types/node" ">=13.7.0" long "^4.0.0" +require-directory@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" + integrity sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q== + source-map-support@^0.5.6: version "0.5.19" resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.19.tgz#a98b62f86dcaf4f67399648c085291ab9e8fed61" @@ -193,6 +396,29 @@ sprintf-js@~1.0.2: resolved "https://registry.yarnpkg.com/sprintf-js/-/sprintf-js-1.0.3.tgz#04e6926f662895354f3dd015203633b857297e2c" integrity sha1-BOaSb2YolTVPPdAVIDYzuFcpfiw= +string-width@^4.1.0, string-width@^4.2.0: + version "4.2.3" + resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" + integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== + dependencies: + emoji-regex "^8.0.0" + is-fullwidth-code-point "^3.0.0" + strip-ansi "^6.0.1" + +strip-ansi@^6.0.0, strip-ansi@^6.0.1: + version "6.0.1" + resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9" + integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== + dependencies: + ansi-regex "^5.0.1" + +supports-color@^7.1.0: + version "7.2.0" + resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-7.2.0.tgz#1b7dcdcb32b8138801b3e478ba6a51caa89648da" + integrity sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw== + dependencies: + has-flag "^4.0.0" + tr46@~0.0.3: version "0.0.3" resolved "https://registry.yarnpkg.com/tr46/-/tr46-0.0.3.tgz#8184fd347dac9cdc185992f3a6622e14b9d9ab6a" @@ -214,6 +440,11 @@ typescript@3.5.3: resolved "https://registry.yarnpkg.com/typescript/-/typescript-3.5.3.tgz#c830f657f93f1ea846819e929092f5fe5983e977" integrity sha512-ACzBtm/PhXBDId6a6sDJfroT2pOWt/oOnk4/dElG5G33ZL776N3Y6/6bKZJBFpd+b05F3Ct9qDjMeJmRWtE2/g== +universalify@^0.1.0: + version "0.1.2" + resolved "https://registry.yarnpkg.com/universalify/-/universalify-0.1.2.tgz#b646f69be3942dabcecc9d6639c80dc105efaa66" + integrity sha512-rBJeI5CXAlmy1pV+617WB9J63U6XcazHHF2f2dbJix4XzpUF0RS3Zbj0FGIOCAva5P/d/GBOYaACQ1w+0azUkg== + webidl-conversions@^3.0.0: version "3.0.1" resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-3.0.1.tgz#24534275e2a7bc6be7bc86611cc16ae0a5654871" @@ -227,6 +458,57 @@ whatwg-url@^5.0.0: tr46 "~0.0.3" webidl-conversions "^3.0.0" +wrap-ansi@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43" + integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== + dependencies: + ansi-styles "^4.0.0" + string-width "^4.1.0" + strip-ansi "^6.0.0" + +wrappy@1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/wrappy/-/wrappy-1.0.2.tgz#b5243d8f3ec1aa35f1364605bc0d1036e30ab69f" + integrity sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ== + +y18n@^5.0.5: + version "5.0.8" + resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.8.tgz#7f4934d0f7ca8c56f95314939ddcd2dd91ce1d55" + integrity sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA== + +yalc@~1.0.0-pre.50: + version "1.0.0-pre.53" + resolved "https://registry.yarnpkg.com/yalc/-/yalc-1.0.0-pre.53.tgz#c51db2bb924a6908f4cb7e82af78f7e5606810bc" + integrity sha512-tpNqBCpTXplnduzw5XC+FF8zNJ9L/UXmvQyyQj7NKrDNavbJtHvzmZplL5ES/RCnjX7JR7W9wz5GVDXVP3dHUQ== + dependencies: + chalk "^4.1.0" + detect-indent "^6.0.0" + fs-extra "^8.0.1" + glob "^7.1.4" + ignore "^5.0.4" + ini "^2.0.0" + npm-packlist "^2.1.5" + yargs "^16.1.1" + +yargs-parser@^20.2.2: + version "20.2.9" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-20.2.9.tgz#2eb7dc3b0289718fc295f362753845c41a0c94ee" + integrity sha512-y11nGElTIV+CT3Zv9t7VKl+Q3hTQoT9a1Qzezhhl6Rp21gJ/IVTW7Z3y9EWXhuUBC2Shnf+DX0antecpAwSP8w== + +yargs@^16.1.1: + version "16.2.0" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-16.2.0.tgz#1c82bf0f6b6a66eafce7ef30e376f49a12477f66" + integrity sha512-D1mvvtDG0L5ft/jGWkLpG1+m0eQxOfaBvTNELraWj22wSVUMWxZUvYgJYcKh6jGGIkJFhH4IZPQhR4TKpc8mBw== + dependencies: + cliui "^7.0.2" + escalade "^3.1.1" + get-caller-file "^2.0.5" + require-directory "^2.1.1" + string-width "^4.2.0" + y18n "^5.0.5" + yargs-parser "^20.2.2" + yn@3.1.1: version "3.1.1" resolved "https://registry.yarnpkg.com/yn/-/yn-3.1.1.tgz#1e87401a09d767c1d5eab26a6e4c185182d2eb50" diff --git a/tsconfig.json b/tsconfig.json index ccdf3bd8b6b..6d8f9a84d16 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -11,7 +11,7 @@ "declaration": true, "target": "es2017", "lib": [ - "es2017", + "es2019", "dom" ], "outDir": "./dist",