Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
8f6d3da
Upgrade typescript to 4.4.2
mattsoulanille Aug 27, 2021
9dc3322
[core] Upgrade typescript to 4.4.2
mattsoulanille Aug 27, 2021
24e9353
Install tslib
mattsoulanille Aug 27, 2021
ea630b5
[webgl] Upgrade typescript to 4.4.2
mattsoulanille Aug 27, 2021
add0110
[wasm] Upgrade typescript to 4.4.2
mattsoulanille Aug 27, 2021
973cb6b
[layers] Update typescript to 4.4.2
mattsoulanille Aug 27, 2021
3d95906
[data] Update typescript to 4.4.2
mattsoulanille Aug 30, 2021
f482727
[tfjs-automl] Update typescript to 4.4.2
mattsoulanille Aug 30, 2021
d6cc93a
[inference] Upgrade TypeScript to 4.4.2
mattsoulanille Aug 31, 2021
68b6898
[cpu] Update package.json with ts 4.4.2
mattsoulanille Aug 31, 2021
76fb265
[converter] Fix executor tests to work with ts 4
mattsoulanille Aug 31, 2021
a2ca980
[converter] Update kenrels_to_ops to work with new mocked tfOps
mattsoulanille Nov 3, 2021
ec8a585
[converter] Fix graph model tests by using spyOnAllFunctions
mattsoulanille Nov 3, 2021
e695e22
[converter] Fix executeOp tests by using a tfc.tidy spy
mattsoulanille Nov 3, 2021
1cf25af
Fix lint errors
mattsoulanille Nov 3, 2021
ea8fb41
[union] Upgrade typescript to 4.4.2
mattsoulanille Nov 4, 2021
a70df89
[node] Upgrade typescript to 4.4.2
mattsoulanille Nov 4, 2021
db2c0a2
[tflite] Upgrade typescript to 4.4.2
mattsoulanille Nov 4, 2021
a3dce77
[e2e] Upgrade typescript to 4.4.2
mattsoulanille Nov 4, 2021
2eee6ae
[vis] Set sort to null instead of false
mattsoulanille Nov 10, 2021
29e8a79
fixup! [converter] Fix executor tests to work with ts 4
mattsoulanille Nov 10, 2021
1722513
fixup! [data] Update typescript to 4.4.2
mattsoulanille Nov 11, 2021
cdfcc35
[converter] Fix Cumprod and ImageProjectiveTransformV3 tests
mattsoulanille Apr 22, 2022
295fac6
[react-native] Upgrade typescript to 4.4.2
mattsoulanille Apr 23, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[converter] Fix graph model tests by using spyOnAllFunctions
  • Loading branch information
mattsoulanille committed Apr 22, 2022
commit ec8a58571fd161d724e736cc7582bdd73f46ec09
19 changes: 11 additions & 8 deletions tfjs-converter/src/executor/graph_model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ export class GraphModel implements InferenceModel {
private initializer: GraphExecutor;
private resourceManager: ResourceManager;
private signature: tensorflow.ISignatureDef;
private readonly io: typeof io;

// Returns the version information for the tensorflow model GraphDef.
get modelVersion(): string {
Expand Down Expand Up @@ -89,7 +90,9 @@ export class GraphModel implements InferenceModel {
*/
constructor(
private modelUrl: string|io.IOHandler,
private loadOptions: io.LoadOptions = {}) {
private loadOptions: io.LoadOptions = {},
tfio = io) {
this.io = tfio;
if (loadOptions == null) {
this.loadOptions = {};
}
Expand All @@ -102,13 +105,13 @@ export class GraphModel implements InferenceModel {
// Path is an IO Handler.
this.handler = path as io.IOHandler;
} else if (this.loadOptions.requestInit != null) {
this.handler = io.browserHTTPRequest(path as string, this.loadOptions);
this.handler = this.io.browserHTTPRequest(path as string, this.loadOptions);
} else {
const handlers = io.getLoadHandlers(path as string, this.loadOptions);
const handlers = this.io.getLoadHandlers(path as string, this.loadOptions);
if (handlers.length === 0) {
// For backward compatibility: if no load handler can be found,
// assume it is a relative http path.
handlers.push(io.browserHTTPRequest(path as string, this.loadOptions));
handlers.push(this.io.browserHTTPRequest(path as string, this.loadOptions));
} else if (handlers.length > 1) {
throw new Error(
`Found more than one (${handlers.length}) load handlers for ` +
Expand Down Expand Up @@ -157,7 +160,7 @@ export class GraphModel implements InferenceModel {

this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`;
const weightMap =
io.decodeWeights(this.artifacts.weightData, this.artifacts.weightSpecs);
this.io.decodeWeights(this.artifacts.weightData, this.artifacts.weightSpecs);
this.executor = new GraphExecutor(
OperationMapper.Instance.transformGraph(graph, this.signature));
this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap);
Expand Down Expand Up @@ -228,7 +231,7 @@ export class GraphModel implements InferenceModel {
async save(handlerOrURL: io.IOHandler|string, config?: io.SaveConfig):
Promise<io.SaveResult> {
if (typeof handlerOrURL === 'string') {
const handlers = io.getSaveHandlers(handlerOrURL);
const handlers = this.io.getSaveHandlers(handlerOrURL);
if (handlers.length === 0) {
throw new Error(
`Cannot find any save handlers for URL '${handlerOrURL}'`);
Expand Down Expand Up @@ -438,7 +441,7 @@ export class GraphModel implements InferenceModel {
*/
export async function loadGraphModel(
modelUrl: string|io.IOHandler,
options: io.LoadOptions = {}): Promise<GraphModel> {
options: io.LoadOptions = {}, tfio = io): Promise<GraphModel> {
if (modelUrl == null) {
throw new Error(
'modelUrl in loadGraphModel() cannot be null. Please provide a url ' +
Expand All @@ -456,7 +459,7 @@ export async function loadGraphModel(
modelUrl = `${modelUrl}${DEFAULT_MODEL_NAME}${TFHUB_SEARCH_PARAM}`;
}
}
const model = new GraphModel(modelUrl, options);
const model = new GraphModel(modelUrl, options, tfio);
await model.load();
return model;
}
57 changes: 31 additions & 26 deletions tfjs-converter/src/executor/graph_model_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {deregisterOp, registerOp} from '../operations/custom_op/register';
import {GraphNode} from '../operations/types';

import {GraphModel, loadGraphModel} from './graph_model';
import {RecursiveSpy, spyOnAllFunctions} from '../operations/executors/spy_ops';

const HOST = 'http://example.org';
const MODEL_URL = `${HOST}/model.json`;
Expand Down Expand Up @@ -368,6 +369,12 @@ describe('loadSync', () => {
});

describe('loadGraphModel', () => {
let spyIo: RecursiveSpy<typeof io>;

beforeEach(() => {
spyIo = spyOnAllFunctions(io);
});

it('Pass a custom io handler', async () => {
const customLoader: tfc.io.IOHandler = {
load: async () => {
Expand Down Expand Up @@ -397,22 +404,25 @@ describe('loadGraphModel', () => {

it('Pass a fetchFunc', async () => {
const fetchFunc = () => {};
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
CUSTOM_HTTP_MODEL_LOADER
]);
await loadGraphModel(MODEL_URL, {fetchFunc});
expect(tfc.io.getLoadHandlers).toHaveBeenCalledWith(MODEL_URL, {fetchFunc});
await loadGraphModel(MODEL_URL, {fetchFunc}, spyIo);
expect(spyIo.getLoadHandlers).toHaveBeenCalledWith(MODEL_URL, {fetchFunc});
});
});

describe('Model', () => {
let spyIo: RecursiveSpy<typeof io>;

beforeEach(() => {
model = new GraphModel(MODEL_URL);
spyIo = spyOnAllFunctions(io);
model = new GraphModel(MODEL_URL, undefined, spyIo);
});

describe('custom model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
CUSTOM_HTTP_MODEL_LOADER
]);
registerOp('CustomOp', (nodeValue: GraphNode) => {
Expand Down Expand Up @@ -454,11 +464,10 @@ describe('Model', () => {

describe('simple model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
SIMPLE_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(SIMPLE_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(SIMPLE_HTTP_MODEL_LOADER);
});
it('load', async () => {
const loaded = await model.load();
Expand Down Expand Up @@ -591,7 +600,7 @@ describe('Model', () => {
describe('dispose', () => {
it('should dispose the weights', async () => {
const numOfTensors = tfc.memory().numTensors;
model = new GraphModel(MODEL_URL);
model = new GraphModel(MODEL_URL, undefined, spyIo);

await model.load();
model.dispose();
Expand All @@ -609,7 +618,7 @@ describe('Model', () => {

describe('relative path', () => {
beforeEach(() => {
model = new GraphModel(RELATIVE_MODEL_URL);
model = new GraphModel(RELATIVE_MODEL_URL, undefined, spyIo);
});

it('load', async () => {
Expand All @@ -619,22 +628,22 @@ describe('Model', () => {
});

it('should loadGraphModel', async () => {
const model = await loadGraphModel(MODEL_URL);
const model = await loadGraphModel(MODEL_URL, undefined, spyIo);
expect(model).not.toBeUndefined();
});

it('should loadGraphModel with request options', async () => {
const model = await loadGraphModel(
MODEL_URL, {requestInit: {credentials: 'include'}});
expect(tfc.io.browserHTTPRequest).toHaveBeenCalledWith(MODEL_URL, {
MODEL_URL, {requestInit: {credentials: 'include'}}, spyIo);
expect(spyIo.browserHTTPRequest).toHaveBeenCalledWith(MODEL_URL, {
requestInit: {credentials: 'include'}
});
expect(model).not.toBeUndefined();
});

it('should call loadGraphModel for TfHub Module', async () => {
const url = `${HOST}/model/1`;
const model = await loadGraphModel(url, {fromTFHub: true});
const model = await loadGraphModel(url, {fromTFHub: true}, spyIo);
expect(model).toBeDefined();
});

Expand All @@ -656,11 +665,10 @@ describe('Model', () => {

describe('control flow model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
CONTROL_FLOW_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(CONTROL_FLOW_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(CONTROL_FLOW_HTTP_MODEL_LOADER);
});

describe('save', () => {
Expand Down Expand Up @@ -747,11 +755,10 @@ describe('Model', () => {
};
describe('dynamic shape model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
DYNAMIC_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(DYNAMIC_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(DYNAMIC_HTTP_MODEL_LOADER);
});

it('should throw error if call predict directly', async () => {
Expand Down Expand Up @@ -792,11 +799,10 @@ describe('Model', () => {
});
describe('dynamic shape model with metadata', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
DYNAMIC_HTTP_MODEL_NEW_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(DYNAMIC_HTTP_MODEL_NEW_LOADER);
spyIo.browserHTTPRequest.and.returnValue(DYNAMIC_HTTP_MODEL_NEW_LOADER);
});

it('should be success if call executeAsync with signature key',
Expand All @@ -818,11 +824,10 @@ describe('Model', () => {

describe('Hashtable model', () => {
beforeEach(() => {
spyOn(tfc.io, 'getLoadHandlers').and.returnValue([
spyIo.getLoadHandlers.and.returnValue([
HASHTABLE_HTTP_MODEL_LOADER
]);
spyOn(tfc.io, 'browserHTTPRequest')
.and.returnValue(HASHTABLE_HTTP_MODEL_LOADER);
spyIo.browserHTTPRequest.and.returnValue(HASHTABLE_HTTP_MODEL_LOADER);
});
it('should be successful if call executeAsync', async () => {
await model.load();
Expand Down