Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tfjs-converter/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
"opn": "~5.1.0",
"protobufjs": "~6.11.3",
"ts-node": "~8.8.2",
"typescript": "3.5.3"
"typescript": "3.5.3",
"yalc": "~1.0.0-pre.50"
},
"scripts": {
"build": "bazel build :tfjs-converter_pkg",
Expand Down
6 changes: 3 additions & 3 deletions tfjs-converter/scripts/kernels_to_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,16 @@ function getKernelMappingForFile(source: SourceFile) {
const callExprs =
clausePart.getDescendantsOfKind(SyntaxKind.CallExpression);
const tfOpsCallExprs =
callExprs.filter(expr => expr.getText().match(/tfOps/));
callExprs.filter(expr => expr.getText().match(/ops/));
const tfSymbols: Set<string> = new Set();
for (const tfOpsCall of tfOpsCallExprs) {
const tfOpsCallStr = tfOpsCall.getText();
const functionCallMatcher = /(tfOps\.([\w\.]*)\()/g;
const functionCallMatcher = /(ops\.([\w\.]*)\()/g;
const matches = tfOpsCallStr.match(functionCallMatcher);
if (matches != null && matches.length > 0) {
for (const match of matches) {
// extract the method name (and any namespaces used to call it)
const symbolMatcher = /(tfOps\.([\w\.]*)\()/;
const symbolMatcher = /(ops\.([\w\.]*)\()/;
const symbol = match.match(symbolMatcher)[2];
tfSymbols.add(symbol);
}
Expand Down
1 change: 1 addition & 0 deletions tfjs-converter/src/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package(default_visibility = ["//visibility:public"])
TEST_SRCS = [
"**/*_test.ts",
"run_tests.ts",
"operations/executors/spy_ops.ts",
]

# Used for test-snippets
Expand Down
26 changes: 15 additions & 11 deletions tfjs-converter/src/executor/graph_model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
private initializer: GraphExecutor;
private resourceManager: ResourceManager;
private signature: tensorflow.ISignatureDef;
private readonly io: typeof io;

// Returns the version information for the tensorflow model GraphDef.
get modelVersion(): string {
Expand Down Expand Up @@ -93,7 +94,8 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
*/
constructor(
private modelUrl: ModelURL,
private loadOptions: io.LoadOptions = {}) {
private loadOptions: io.LoadOptions = {}, tfio = io) {
this.io = tfio;
if (loadOptions == null) {
this.loadOptions = {};
}
Expand All @@ -107,14 +109,16 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
// Path is an IO Handler.
this.handler = path as IOHandler;
} else if (this.loadOptions.requestInit != null) {
this.handler = io.browserHTTPRequest(path as string, this.loadOptions) as
IOHandler;
this.handler = this.io
.browserHTTPRequest(path as string, this.loadOptions) as IOHandler;
} else {
const handlers = io.getLoadHandlers(path as string, this.loadOptions);
const handlers =
this.io.getLoadHandlers(path as string, this.loadOptions);
if (handlers.length === 0) {
// For backward compatibility: if no load handler can be found,
// assume it is a relative http path.
handlers.push(io.browserHTTPRequest(path as string, this.loadOptions));
handlers.push(
this.io.browserHTTPRequest(path as string, this.loadOptions));
} else if (handlers.length > 1) {
throw new Error(
`Found more than one (${handlers.length}) load handlers for ` +
Expand Down Expand Up @@ -171,8 +175,8 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
this.signature = signature;

this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`;
const weightMap =
io.decodeWeights(this.artifacts.weightData, this.artifacts.weightSpecs);
const weightMap = this.io.decodeWeights(
this.artifacts.weightData, this.artifacts.weightSpecs);
this.executor = new GraphExecutor(
OperationMapper.Instance.transformGraph(graph, this.signature));
this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap);
Expand Down Expand Up @@ -243,7 +247,7 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
async save(handlerOrURL: io.IOHandler|string, config?: io.SaveConfig):
Promise<io.SaveResult> {
if (typeof handlerOrURL === 'string') {
const handlers = io.getSaveHandlers(handlerOrURL);
const handlers = this.io.getSaveHandlers(handlerOrURL);
if (handlers.length === 0) {
throw new Error(
`Cannot find any save handlers for URL '${handlerOrURL}'`);
Expand Down Expand Up @@ -452,8 +456,8 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
* @doc {heading: 'Models', subheading: 'Loading'}
*/
export async function loadGraphModel(
modelUrl: string|io.IOHandler,
options: io.LoadOptions = {}): Promise<GraphModel> {
modelUrl: string|io.IOHandler, options: io.LoadOptions = {},
tfio = io): Promise<GraphModel> {
if (modelUrl == null) {
throw new Error(
'modelUrl in loadGraphModel() cannot be null. Please provide a url ' +
Expand All @@ -466,7 +470,7 @@ export async function loadGraphModel(
if (options.fromTFHub && typeof modelUrl === 'string') {
modelUrl = getTFHubUrl(modelUrl);
}
const model = new GraphModel(modelUrl, options);
const model = new GraphModel(modelUrl, options, tfio);
await model.load();
return model;
}
Expand Down
57 changes: 31 additions & 26 deletions tfjs-converter/src/executor/graph_model_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {deregisterOp, registerOp} from '../operations/custom_op/register';
import {GraphNode} from '../operations/types';

import {GraphModel, loadGraphModel, loadGraphModelSync} from './graph_model';
import {RecursiveSpy, spyOnAllFunctions} from '../operations/executors/spy_ops';

const HOST = 'http://example.org';
const MODEL_URL = `${HOST}/model.json`;
Expand Down Expand Up @@ -368,6 +369,12 @@ describe('loadSync', () => {
});

describe('loadGraphModel', () => {
let spyIo: RecursiveSpy<typeof io>;

beforeEach(() => {
spyIo = spyOnAllFunctions(io);
});

it('Pass a custom io handler', async () => {
const customLoader: tfc.io.IOHandler = {
load: async () => {
Expand Down Expand Up @@ -397,11 +404,11 @@ describe('loadGraphModel', () => {

it('Pass a fetchFunc', async () => {
const fetchFunc = () => {};
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
CUSTOM_HTTP_MODEL_LOADER
]);
await loadGraphModel(MODEL_URL, {fetchFunc});
expect(tfc.io.getLoadHandlers).toHaveBeenCalledWith(MODEL_URL, {fetchFunc});
await loadGraphModel(MODEL_URL, {fetchFunc}, spyIo);
expect(spyIo.getLoadHandlers).toHaveBeenCalledWith(MODEL_URL, {fetchFunc});
});
});

Expand Down Expand Up @@ -436,13 +443,16 @@ describe('loadGraphModelSync', () => {
});

describe('Model', () => {
let spyIo: RecursiveSpy<typeof io>;

beforeEach(() => {
model = new GraphModel(MODEL_URL);
spyIo = spyOnAllFunctions(io);
model = new GraphModel(MODEL_URL, undefined, spyIo);
});

describe('custom model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
CUSTOM_HTTP_MODEL_LOADER
]);
registerOp('CustomOp', (nodeValue: GraphNode) => {
Expand Down Expand Up @@ -484,11 +494,10 @@ describe('Model', () => {

describe('simple model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
SIMPLE_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(SIMPLE_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(SIMPLE_HTTP_MODEL_LOADER);
});
it('load', async () => {
const loaded = await model.load();
Expand Down Expand Up @@ -621,7 +630,7 @@ describe('Model', () => {
describe('dispose', () => {
it('should dispose the weights', async () => {
const numOfTensors = tfc.memory().numTensors;
model = new GraphModel(MODEL_URL);
model = new GraphModel(MODEL_URL, undefined, spyIo);

await model.load();
model.dispose();
Expand All @@ -639,7 +648,7 @@ describe('Model', () => {

describe('relative path', () => {
beforeEach(() => {
model = new GraphModel(RELATIVE_MODEL_URL);
model = new GraphModel(RELATIVE_MODEL_URL, undefined, spyIo);
});

it('load', async () => {
Expand All @@ -649,22 +658,22 @@ describe('Model', () => {
});

it('should loadGraphModel', async () => {
const model = await loadGraphModel(MODEL_URL);
const model = await loadGraphModel(MODEL_URL, undefined, spyIo);
expect(model).not.toBeUndefined();
});

it('should loadGraphModel with request options', async () => {
const model = await loadGraphModel(
MODEL_URL, {requestInit: {credentials: 'include'}});
expect(tfc.io.browserHTTPRequest).toHaveBeenCalledWith(MODEL_URL, {
MODEL_URL, {requestInit: {credentials: 'include'}}, spyIo);
expect(spyIo.browserHTTPRequest).toHaveBeenCalledWith(MODEL_URL, {
requestInit: {credentials: 'include'}
});
expect(model).not.toBeUndefined();
});

it('should call loadGraphModel for TfHub Module', async () => {
const url = `${HOST}/model/1`;
const model = await loadGraphModel(url, {fromTFHub: true});
const model = await loadGraphModel(url, {fromTFHub: true}, spyIo);
expect(model).toBeDefined();
});

Expand All @@ -686,11 +695,10 @@ describe('Model', () => {

describe('control flow model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
CONTROL_FLOW_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(CONTROL_FLOW_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(CONTROL_FLOW_HTTP_MODEL_LOADER);
});

describe('save', () => {
Expand Down Expand Up @@ -777,11 +785,10 @@ describe('Model', () => {
};
describe('dynamic shape model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
DYNAMIC_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(DYNAMIC_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(DYNAMIC_HTTP_MODEL_LOADER);
});

it('should throw error if call predict directly', async () => {
Expand Down Expand Up @@ -822,11 +829,10 @@ describe('Model', () => {
});
describe('dynamic shape model with metadata', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
DYNAMIC_HTTP_MODEL_NEW_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(DYNAMIC_HTTP_MODEL_NEW_LOADER);
spyIo.browserHTTPRequest.and.returnValue(DYNAMIC_HTTP_MODEL_NEW_LOADER);
});

it('should be success if call executeAsync with signature key',
Expand All @@ -848,11 +854,10 @@ describe('Model', () => {

describe('Hashtable model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
HASHTABLE_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(HASHTABLE_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(HASHTABLE_HTTP_MODEL_LOADER);
});
it('should be successful if call executeAsync', async () => {
await model.load();
Expand Down
26 changes: 13 additions & 13 deletions tfjs-converter/src/operations/executors/arithmetic_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,66 +27,66 @@ import {getParamValue} from './utils';

export const executeOp: InternalOpExecutor =
(node: Node, tensorMap: NamedTensorsMap,
context: ExecutionContext): Tensor[] => {
context: ExecutionContext, ops = tfOps): Tensor[] => {
switch (node.op) {
case 'BiasAdd':
case 'AddV2':
case 'Add': {
return [tfOps.add(
return [ops.add(
(getParamValue('a', node, tensorMap, context) as Tensor),
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'AddN': {
return [tfOps.addN((
return [ops.addN((
getParamValue('tensors', node, tensorMap, context) as Tensor[]))];
}
case 'FloorMod':
case 'Mod':
return [tfOps.mod(
return [ops.mod(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
case 'Mul':
return [tfOps.mul(
return [ops.mul(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
case 'RealDiv':
case 'Div': {
return [tfOps.div(
return [ops.div(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'DivNoNan': {
return [tfOps.divNoNan(
return [ops.divNoNan(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'FloorDiv': {
return [tfOps.floorDiv(
return [ops.floorDiv(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'Sub': {
return [tfOps.sub(
return [ops.sub(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'Minimum': {
return [tfOps.minimum(
return [ops.minimum(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'Maximum': {
return [tfOps.maximum(
return [ops.maximum(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'Pow': {
return [tfOps.pow(
return [ops.pow(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'SquaredDifference': {
return [tfOps.squaredDifference(
return [ops.squaredDifference(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
Expand Down
Loading