Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tfjs-converter/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
"opn": "~5.1.0",
"protobufjs": "~6.11.3",
"ts-node": "~8.8.2",
"typescript": "3.5.3"
"typescript": "3.5.3",
"yalc": "~1.0.0-pre.50"
},
"scripts": {
"build": "bazel build :tfjs-converter_pkg",
Expand Down
6 changes: 3 additions & 3 deletions tfjs-converter/scripts/kernels_to_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,16 @@ function getKernelMappingForFile(source: SourceFile) {
const callExprs =
clausePart.getDescendantsOfKind(SyntaxKind.CallExpression);
const tfOpsCallExprs =
callExprs.filter(expr => expr.getText().match(/tfOps/));
callExprs.filter(expr => expr.getText().match(/ops/));
const tfSymbols: Set<string> = new Set();
for (const tfOpsCall of tfOpsCallExprs) {
const tfOpsCallStr = tfOpsCall.getText();
const functionCallMatcher = /(tfOps\.([\w\.]*)\()/g;
const functionCallMatcher = /(ops\.([\w\.]*)\()/g;
const matches = tfOpsCallStr.match(functionCallMatcher);
if (matches != null && matches.length > 0) {
for (const match of matches) {
// extract the method name (and any namespaces used to call it)
const symbolMatcher = /(tfOps\.([\w\.]*)\()/;
const symbolMatcher = /(ops\.([\w\.]*)\()/;
const symbol = match.match(symbolMatcher)[2];
tfSymbols.add(symbol);
}
Expand Down
1 change: 1 addition & 0 deletions tfjs-converter/src/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package(default_visibility = ["//visibility:public"])
TEST_SRCS = [
"**/*_test.ts",
"run_tests.ts",
"operations/executors/spy_ops.ts",
]

# Used for test-snippets
Expand Down
26 changes: 15 additions & 11 deletions tfjs-converter/src/executor/graph_model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
private initializer: GraphExecutor;
private resourceManager: ResourceManager;
private signature: tensorflow.ISignatureDef;
private readonly io: typeof io;

// Returns the version information for the tensorflow model GraphDef.
get modelVersion(): string {
Expand Down Expand Up @@ -93,7 +94,8 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
*/
constructor(
private modelUrl: ModelURL,
private loadOptions: io.LoadOptions = {}) {
private loadOptions: io.LoadOptions = {}, tfio = io) {
this.io = tfio;
if (loadOptions == null) {
this.loadOptions = {};
}
Expand All @@ -107,14 +109,16 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
// Path is an IO Handler.
this.handler = path as IOHandler;
} else if (this.loadOptions.requestInit != null) {
this.handler = io.browserHTTPRequest(path as string, this.loadOptions) as
IOHandler;
this.handler = this.io
.browserHTTPRequest(path as string, this.loadOptions) as IOHandler;
} else {
const handlers = io.getLoadHandlers(path as string, this.loadOptions);
const handlers =
this.io.getLoadHandlers(path as string, this.loadOptions);
if (handlers.length === 0) {
// For backward compatibility: if no load handler can be found,
// assume it is a relative http path.
handlers.push(io.browserHTTPRequest(path as string, this.loadOptions));
handlers.push(
this.io.browserHTTPRequest(path as string, this.loadOptions));
} else if (handlers.length > 1) {
throw new Error(
`Found more than one (${handlers.length}) load handlers for ` +
Expand Down Expand Up @@ -171,8 +175,8 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
this.signature = signature;

this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`;
const weightMap =
io.decodeWeights(this.artifacts.weightData, this.artifacts.weightSpecs);
const weightMap = this.io.decodeWeights(
this.artifacts.weightData, this.artifacts.weightSpecs);
this.executor = new GraphExecutor(
OperationMapper.Instance.transformGraph(graph, this.signature));
this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap);
Expand Down Expand Up @@ -243,7 +247,7 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
async save(handlerOrURL: io.IOHandler|string, config?: io.SaveConfig):
Promise<io.SaveResult> {
if (typeof handlerOrURL === 'string') {
const handlers = io.getSaveHandlers(handlerOrURL);
const handlers = this.io.getSaveHandlers(handlerOrURL);
if (handlers.length === 0) {
throw new Error(
`Cannot find any save handlers for URL '${handlerOrURL}'`);
Expand Down Expand Up @@ -452,8 +456,8 @@ export class GraphModel<ModelURL extends Url = string|io.IOHandler>
* @doc {heading: 'Models', subheading: 'Loading'}
*/
export async function loadGraphModel(
modelUrl: string|io.IOHandler,
options: io.LoadOptions = {}): Promise<GraphModel> {
modelUrl: string|io.IOHandler, options: io.LoadOptions = {},
tfio = io): Promise<GraphModel> {
if (modelUrl == null) {
throw new Error(
'modelUrl in loadGraphModel() cannot be null. Please provide a url ' +
Expand All @@ -466,7 +470,7 @@ export async function loadGraphModel(
if (options.fromTFHub && typeof modelUrl === 'string') {
modelUrl = getTFHubUrl(modelUrl);
}
const model = new GraphModel(modelUrl, options);
const model = new GraphModel(modelUrl, options, tfio);
await model.load();
return model;
}
Expand Down
57 changes: 31 additions & 26 deletions tfjs-converter/src/executor/graph_model_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {deregisterOp, registerOp} from '../operations/custom_op/register';
import {GraphNode} from '../operations/types';

import {GraphModel, loadGraphModel, loadGraphModelSync} from './graph_model';
import {RecursiveSpy, spyOnAllFunctions} from '../operations/executors/spy_ops';

const HOST = 'http://example.org';
const MODEL_URL = `${HOST}/model.json`;
Expand Down Expand Up @@ -368,6 +369,12 @@ describe('loadSync', () => {
});

describe('loadGraphModel', () => {
let spyIo: RecursiveSpy<typeof io>;

beforeEach(() => {
spyIo = spyOnAllFunctions(io);
});

it('Pass a custom io handler', async () => {
const customLoader: tfc.io.IOHandler = {
load: async () => {
Expand Down Expand Up @@ -397,11 +404,11 @@ describe('loadGraphModel', () => {

it('Pass a fetchFunc', async () => {
const fetchFunc = () => {};
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
CUSTOM_HTTP_MODEL_LOADER
]);
await loadGraphModel(MODEL_URL, {fetchFunc});
expect(tfc.io.getLoadHandlers).toHaveBeenCalledWith(MODEL_URL, {fetchFunc});
await loadGraphModel(MODEL_URL, {fetchFunc}, spyIo);
expect(spyIo.getLoadHandlers).toHaveBeenCalledWith(MODEL_URL, {fetchFunc});
});
});

Expand Down Expand Up @@ -436,13 +443,16 @@ describe('loadGraphModelSync', () => {
});

describe('Model', () => {
let spyIo: RecursiveSpy<typeof io>;

beforeEach(() => {
model = new GraphModel(MODEL_URL);
spyIo = spyOnAllFunctions(io);
model = new GraphModel(MODEL_URL, undefined, spyIo);
});

describe('custom model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
CUSTOM_HTTP_MODEL_LOADER
]);
registerOp('CustomOp', (nodeValue: GraphNode) => {
Expand Down Expand Up @@ -484,11 +494,10 @@ describe('Model', () => {

describe('simple model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
SIMPLE_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(SIMPLE_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(SIMPLE_HTTP_MODEL_LOADER);
});
it('load', async () => {
const loaded = await model.load();
Expand Down Expand Up @@ -621,7 +630,7 @@ describe('Model', () => {
describe('dispose', () => {
it('should dispose the weights', async () => {
const numOfTensors = tfc.memory().numTensors;
model = new GraphModel(MODEL_URL);
model = new GraphModel(MODEL_URL, undefined, spyIo);

await model.load();
model.dispose();
Expand All @@ -639,7 +648,7 @@ describe('Model', () => {

describe('relative path', () => {
beforeEach(() => {
model = new GraphModel(RELATIVE_MODEL_URL);
model = new GraphModel(RELATIVE_MODEL_URL, undefined, spyIo);
});

it('load', async () => {
Expand All @@ -649,22 +658,22 @@ describe('Model', () => {
});

it('should loadGraphModel', async () => {
const model = await loadGraphModel(MODEL_URL);
const model = await loadGraphModel(MODEL_URL, undefined, spyIo);
expect(model).not.toBeUndefined();
});

it('should loadGraphModel with request options', async () => {
const model = await loadGraphModel(
MODEL_URL, {requestInit: {credentials: 'include'}});
expect(tfc.io.browserHTTPRequest).toHaveBeenCalledWith(MODEL_URL, {
MODEL_URL, {requestInit: {credentials: 'include'}}, spyIo);
expect(spyIo.browserHTTPRequest).toHaveBeenCalledWith(MODEL_URL, {
requestInit: {credentials: 'include'}
});
expect(model).not.toBeUndefined();
});

it('should call loadGraphModel for TfHub Module', async () => {
const url = `${HOST}/model/1`;
const model = await loadGraphModel(url, {fromTFHub: true});
const model = await loadGraphModel(url, {fromTFHub: true}, spyIo);
expect(model).toBeDefined();
});

Expand All @@ -686,11 +695,10 @@ describe('Model', () => {

describe('control flow model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
CONTROL_FLOW_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(CONTROL_FLOW_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(CONTROL_FLOW_HTTP_MODEL_LOADER);
});

describe('save', () => {
Expand Down Expand Up @@ -777,11 +785,10 @@ describe('Model', () => {
};
describe('dynamic shape model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
DYNAMIC_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(DYNAMIC_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(DYNAMIC_HTTP_MODEL_LOADER);
});

it('should throw error if call predict directly', async () => {
Expand Down Expand Up @@ -822,11 +829,10 @@ describe('Model', () => {
});
describe('dynamic shape model with metadata', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
DYNAMIC_HTTP_MODEL_NEW_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(DYNAMIC_HTTP_MODEL_NEW_LOADER);
spyIo.browserHTTPRequest.and.returnValue(DYNAMIC_HTTP_MODEL_NEW_LOADER);
});

it('should be success if call executeAsync with signature key',
Expand All @@ -848,11 +854,10 @@ describe('Model', () => {

describe('Hashtable model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
HASHTABLE_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(HASHTABLE_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(HASHTABLE_HTTP_MODEL_LOADER);
});
it('should be successful if call executeAsync', async () => {
await model.load();
Expand Down
26 changes: 13 additions & 13 deletions tfjs-converter/src/operations/executors/arithmetic_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,66 +27,66 @@ import {getParamValue} from './utils';

export const executeOp: InternalOpExecutor =
(node: Node, tensorMap: NamedTensorsMap,
context: ExecutionContext): Tensor[] => {
context: ExecutionContext, ops = tfOps): Tensor[] => {
switch (node.op) {
case 'BiasAdd':
case 'AddV2':
case 'Add': {
return [tfOps.add(
return [ops.add(
(getParamValue('a', node, tensorMap, context) as Tensor),
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'AddN': {
return [tfOps.addN((
return [ops.addN((
getParamValue('tensors', node, tensorMap, context) as Tensor[]))];
}
case 'FloorMod':
case 'Mod':
return [tfOps.mod(
return [ops.mod(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
case 'Mul':
return [tfOps.mul(
return [ops.mul(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
case 'RealDiv':
case 'Div': {
return [tfOps.div(
return [ops.div(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'DivNoNan': {
return [tfOps.divNoNan(
return [ops.divNoNan(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'FloorDiv': {
return [tfOps.floorDiv(
return [ops.floorDiv(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'Sub': {
return [tfOps.sub(
return [ops.sub(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'Minimum': {
return [tfOps.minimum(
return [ops.minimum(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'Maximum': {
return [tfOps.maximum(
return [ops.maximum(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'Pow': {
return [tfOps.pow(
return [ops.pow(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
case 'SquaredDifference': {
return [tfOps.squaredDifference(
return [ops.squaredDifference(
getParamValue('a', node, tensorMap, context) as Tensor,
getParamValue('b', node, tensorMap, context) as Tensor)];
}
Expand Down
Loading