From 845db5a89e5b08b57799c2e3a3fb4493e839c889 Mon Sep 17 00:00:00 2001 From: Nikhil Thorat Date: Thu, 12 Sep 2019 17:05:21 -0400 Subject: [PATCH 1/3] Upstream exports={} for mutable exported objects. --- tfjs-converter/src/data/compiled_api.ts | 8 ++++++++ tfjs-core/src/environment.ts | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/tfjs-converter/src/data/compiled_api.ts b/tfjs-converter/src/data/compiled_api.ts index 6b423bd4b6e..6d1bd6ac74e 100644 --- a/tfjs-converter/src/data/compiled_api.ts +++ b/tfjs-converter/src/data/compiled_api.ts @@ -16,6 +16,14 @@ * ============================================================================= */ +// This incantation makes Closure think that exported symbols are mutable. +// Mutable file-level exports are disallowed per style and won't reliably +// work. This hack also has a cost in terms of code size, and is only used +// to preserve the preexisting behavior of this code. +// tslint:disable-next-line:ban-ts-ignore see above +// @ts-ignore +exports = {}; + /* tslint:disable */ /** Properties of an Any. */ diff --git a/tfjs-core/src/environment.ts b/tfjs-core/src/environment.ts index 37e29a4d610..990f89e879f 100644 --- a/tfjs-core/src/environment.ts +++ b/tfjs-core/src/environment.ts @@ -15,6 +15,14 @@ * ============================================================================= */ +// This incantation makes Closure think that exported symbols are mutable. +// Mutable file-level exports are disallowed per style and won't reliably +// work. This hack also has a cost in terms of code size, and is only used +// to preserve the preexisting behavior of this code. +// tslint:disable-next-line:ban-ts-ignore see above +// @ts-ignore +exports = {}; + import {Platform} from './platforms/platform'; // Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true. From 54b1e8cd62d3a46ef641be48326f771008f5fa13 Mon Sep 17 00:00:00 2001 From: Nikhil Thorat Date: Mon, 30 Sep 2019 16:29:40 -0400 Subject: [PATCH 2/3] save --- tfjs-core/benchmarks/index.html | 2 +- tfjs-core/src/backends/cpu/backend_cpu.ts | 19 +- tfjs-core/src/backends/webgl/backend_webgl.ts | 165 +++++++------- .../src/backends/webgl/backend_webgl_test.ts | 81 ++++--- .../src/backends/webgl/canvas_util_test.ts | 11 +- tfjs-core/src/backends/webgl/flags_webgl.ts | 5 +- .../src/backends/webgl/flags_webgl_test.ts | 211 +++++++++--------- tfjs-core/src/backends/webgl/glsl_version.ts | 5 +- tfjs-core/src/backends/webgl/gpgpu_context.ts | 35 +-- .../src/backends/webgl/gpgpu_context_test.ts | 10 +- tfjs-core/src/backends/webgl/gpgpu_math.ts | 7 +- .../src/backends/webgl/reshape_packed_test.ts | 12 +- tfjs-core/src/backends/webgl/tex_util.ts | 5 +- .../src/backends/webgl/texture_manager.ts | 5 +- .../backends/webgl/webgl_batchnorm_test.ts | 6 +- .../src/backends/webgl/webgl_ops_test.ts | 96 ++++---- tfjs-core/src/backends/webgl/webgl_util.ts | 11 +- .../src/backends/webgl/webgl_util_test.ts | 10 +- tfjs-core/src/debug_mode_test.ts | 6 +- tfjs-core/src/environment.ts | 12 +- tfjs-core/src/flags.ts | 4 +- tfjs-core/src/flags_test.ts | 52 ++--- tfjs-core/src/globals.ts | 13 +- tfjs-core/src/globals_test.ts | 7 +- tfjs-core/src/index.ts | 4 +- tfjs-core/src/io/browser_files.ts | 7 +- tfjs-core/src/io/http.ts | 5 +- tfjs-core/src/io/http_test.ts | 4 +- tfjs-core/src/io/indexed_db.ts | 9 +- tfjs-core/src/io/local_storage.ts | 13 +- tfjs-core/src/io/weights_loader.ts | 9 +- tfjs-core/src/io/weights_loader_test.ts | 43 ++-- tfjs-core/src/jasmine_util.ts | 14 +- tfjs-core/src/log.ts | 6 +- tfjs-core/src/ops/slice_test.ts | 2 +- tfjs-core/src/ops/tensor_ops.ts | 5 +- tfjs-core/src/platforms/platform_browser.ts | 8 +- tfjs-core/src/platforms/platform_node.ts | 10 +- tfjs-core/src/platforms/platform_node_test.ts | 61 ++--- tfjs-core/src/tensor_test.ts | 2 +- tfjs-core/src/tensor_util_env.ts | 9 +- tfjs-core/src/tensor_util_test.ts | 2 +- tfjs-core/src/util.ts | 11 +- tfjs-core/src/util_test.ts | 6 +- tfjs-core/src/webgl.ts | 4 +- 45 files changed, 548 insertions(+), 476 deletions(-) diff --git a/tfjs-core/benchmarks/index.html b/tfjs-core/benchmarks/index.html index 9ef899279fb..d1cda5a5251 100644 --- a/tfjs-core/benchmarks/index.html +++ b/tfjs-core/benchmarks/index.html @@ -128,7 +128,7 @@

TensorFlow.js Model Benchmark

async function showEnvironment() { await tf.time(() => tf.add(tf.tensor1d([1]), tf.tensor1d([1])).data()); - envDiv.innerHTML += `
${JSON.stringify(tf.ENV.features, null, 2) + envDiv.innerHTML += `
${JSON.stringify(tf.environment().features, null, 2) } `; } diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts index 2f7216111e0..b20b6d14a54 100644 --- a/tfjs-core/src/backends/cpu/backend_cpu.ts +++ b/tfjs-core/src/backends/cpu/backend_cpu.ts @@ -18,7 +18,8 @@ import * as seedrandom from 'seedrandom'; import {ENGINE} from '../../engine'; -import {ENV} from '../../environment'; +import {environment} from '../../environment'; + import {warn} from '../../log'; import * as array_ops_util from '../../ops/array_ops_util'; import * as axis_util from '../../ops/axis_util'; @@ -92,7 +93,7 @@ export class MathBackendCPU implements KernelBackend { private firstUse = true; constructor() { - if (ENV.get('IS_BROWSER')) { + if (environment().get('IS_BROWSER')) { const canvas = createCanvas(); if (canvas !== null) { this.fromPixels2DContext = @@ -105,7 +106,7 @@ export class MathBackendCPU implements KernelBackend { register(dataId: DataId, shape: number[], dtype: DataType): void { if (this.firstUse) { this.firstUse = false; - if (ENV.get('IS_NODE')) { + if (environment().get('IS_NODE')) { warn( '\n============================\n' + 'Hi there 👋. Looks like you are running TensorFlow.js in ' + @@ -154,7 +155,7 @@ export class MathBackendCPU implements KernelBackend { [pixels.width, pixels.height]; let vals: Uint8ClampedArray|Uint8Array; // tslint:disable-next-line:no-any - if (ENV.get('IS_NODE') && (pixels as any).getContext == null) { + if (environment().get('IS_NODE') && (pixels as any).getContext == null) { throw new Error( 'When running in node, pixels must be an HTMLCanvasElement ' + 'like the one returned by the `canvas` npm package'); @@ -828,7 +829,8 @@ export class MathBackendCPU implements KernelBackend { const newValues = this.readSync(result.dataId) as TypedArray; let index = 0; const offset = condition.rank === 0 || condition.rank > 1 || a.rank === 1 ? - 1 : util.sizeFromShape(a.shape.slice(1)); + 1 : + util.sizeFromShape(a.shape.slice(1)); for (let i = 0; i < values.length; i++) { for (let j = 0; j < offset; j++) { @@ -1522,9 +1524,10 @@ export class MathBackendCPU implements KernelBackend { const sign = Math.sign(values[i]); const v = Math.abs(values[i]); const t = 1.0 / (1.0 + p * v); - resultValues[i] = sign * (1.0 - - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * - Math.exp(-v * v)); + resultValues[i] = sign * + (1.0 - + (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * + Math.exp(-v * v)); } return Tensor.make(x.shape, {values: resultValues}); } diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index f1fe1425905..3d29161d149 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -20,7 +20,8 @@ import './flags_webgl'; import * as device_util from '../../device_util'; import {ENGINE, MemoryInfo, TimingInfo} from '../../engine'; -import {ENV} from '../../environment'; +import {environment} from '../../environment'; + import {tidy} from '../../globals'; import {warn} from '../../log'; import {buffer} from '../../ops/array_ops'; @@ -213,11 +214,11 @@ const CPU_HANDOFF_SIZE_THRESHOLD = 128; // * dpi / 1024 / 1024. const BEFORE_PAGING_CONSTANT = 600; function numMBBeforeWarning(): number { - if (ENV.global.screen == null) { + if (environment().global.screen == null) { return 1024; // 1 GB. } - return (ENV.global.screen.height * ENV.global.screen.width * - window.devicePixelRatio) * + return (environment().global.screen.height * + environment().global.screen.width * window.devicePixelRatio) * BEFORE_PAGING_CONSTANT / 1024 / 1024; } @@ -260,13 +261,14 @@ export class MathBackendWebGL implements KernelBackend { private warnedAboutMemory = false; constructor(private gpgpu?: GPGPUContext) { - if (!ENV.getBool('HAS_WEBGL')) { + if (!environment().getBool('HAS_WEBGL')) { throw new Error('WebGL is not supported on this device'); } if (gpgpu == null) { - const gl = getWebGLContext(ENV.getNumber('WEBGL_VERSION')); - this.binaryCache = getBinaryCache(ENV.getNumber('WEBGL_VERSION')); + const gl = getWebGLContext(environment().getNumber('WEBGL_VERSION')); + this.binaryCache = + getBinaryCache(environment().getNumber('WEBGL_VERSION')); this.gpgpu = new GPGPUContext(gl); this.canvas = gl.canvas; this.gpgpuCreatedLocally = true; @@ -331,7 +333,8 @@ export class MathBackendWebGL implements KernelBackend { if (this.fromPixels2DContext == null) { //@ts-ignore this.fromPixels2DContext = - createCanvas(ENV.getNumber('WEBGL_VERSION')).getContext('2d'); + createCanvas(environment().getNumber('WEBGL_VERSION')) + .getContext('2d'); } this.fromPixels2DContext.canvas.width = width; @@ -348,7 +351,7 @@ export class MathBackendWebGL implements KernelBackend { this.gpgpu.uploadPixelDataToTexture( this.getTexture(tempPixelHandle.dataId), pixels as ImageData); let program, res; - if (ENV.getBool('WEBGL_PACK')) { + if (environment().getBool('WEBGL_PACK')) { program = new FromPixelsPackedProgram(outShape); const packedOutput = this.makePackedTensor(program.outputShape, tempPixelHandle.dtype); @@ -374,15 +377,15 @@ export class MathBackendWebGL implements KernelBackend { throw new Error('MathBackendWebGL.write(): values can not be null'); } - if (ENV.getBool('DEBUG')) { + if (environment().getBool('DEBUG')) { for (let i = 0; i < values.length; i++) { const num = values[i] as number; if (!webgl_util.canBeRepresented(num)) { - if (ENV.getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) { + if (environment().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) { throw Error( `The value ${num} cannot be represented with your ` + `current settings. Consider enabling float32 rendering: ` + - `'tf.ENV.set('WEBGL_RENDER_FLOAT32_ENABLED', true);'`); + `'tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'`); } throw Error(`The value ${num} cannot be represented on this device.`); } @@ -469,8 +472,8 @@ export class MathBackendWebGL implements KernelBackend { return this.convertAndCacheOnCPU(dataId); } - if (!ENV.getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') && - ENV.getNumber('WEBGL_VERSION') === 2) { + if (!environment().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') && + environment().getNumber('WEBGL_VERSION') === 2) { throw new Error( `tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and ` + `WEBGL_VERSION=2 not yet supported.`); @@ -479,7 +482,7 @@ export class MathBackendWebGL implements KernelBackend { let buffer = null; let tmpDownloadTarget: TensorHandle; - if (dtype !== 'complex64' && ENV.get('WEBGL_BUFFER_SUPPORTED')) { + if (dtype !== 'complex64' && environment().get('WEBGL_BUFFER_SUPPORTED')) { // Possibly copy the texture into a buffer before inserting a fence. tmpDownloadTarget = this.decode(dataId); const tmpData = this.texData.get(tmpDownloadTarget.dataId); @@ -530,7 +533,7 @@ export class MathBackendWebGL implements KernelBackend { private getValuesFromTexture(dataId: DataId): Float32Array { const {shape, dtype, isPacked} = this.texData.get(dataId); const size = util.sizeFromShape(shape); - if (ENV.getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) { + if (environment().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) { const tmpTarget = this.decode(dataId); const tmpData = this.texData.get(tmpTarget.dataId); const vals = this.gpgpu @@ -544,7 +547,7 @@ export class MathBackendWebGL implements KernelBackend { } const shouldUsePackedProgram = - ENV.getBool('WEBGL_PACK') && isPacked === true; + environment().getBool('WEBGL_PACK') && isPacked === true; const outputShape = shouldUsePackedProgram ? webgl_util.getShapeAs3D(shape) : shape; const tmpTarget = @@ -625,14 +628,16 @@ export class MathBackendWebGL implements KernelBackend { } private startTimer(): WebGLQuery|CPUTimerQuery { - if (ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { + if (environment().getNumber( + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { return this.gpgpu.beginQuery(); } return {startMs: util.now(), endMs: null}; } private endTimer(query: WebGLQuery|CPUTimerQuery): WebGLQuery|CPUTimerQuery { - if (ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { + if (environment().getNumber( + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { this.gpgpu.endQuery(); return query; } @@ -641,7 +646,8 @@ export class MathBackendWebGL implements KernelBackend { } private async getQueryTime(query: WebGLQuery|CPUTimerQuery): Promise { - if (ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { + if (environment().getNumber( + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { return this.gpgpu.waitForQueryAndGetTime(query as WebGLQuery); } const timerQuery = query as CPUTimerQuery; @@ -705,7 +711,7 @@ export class MathBackendWebGL implements KernelBackend { } private getCPUBackend(): KernelBackend|null { - if (!ENV.getBool('WEBGL_CPU_FORWARD')) { + if (!environment().getBool('WEBGL_CPU_FORWARD')) { return null; } @@ -768,7 +774,7 @@ export class MathBackendWebGL implements KernelBackend { const {isPacked} = this.texData.get(x.dataId); const isContinous = slice_util.isSliceContinous(x.shape, begin, size); if (isPacked || !isContinous) { - const program = ENV.getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? + const program = environment().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new SlicePackedProgram(size) : new SliceProgram(size); const customSetup = program.getCustomSetupFunc(begin); @@ -822,7 +828,7 @@ export class MathBackendWebGL implements KernelBackend { } reverse(x: T, axis: number[]): T { - const program = ENV.getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? + const program = environment().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new ReversePackedProgram(x.shape, axis) : new ReverseProgram(x.shape, axis); return this.compileAndRun(program, [x]); @@ -841,13 +847,15 @@ export class MathBackendWebGL implements KernelBackend { if (tensors.length === 1) { return tensors[0]; } - if (tensors.length > ENV.getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) { + if (tensors.length > + environment().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) { const midIndex = Math.floor(tensors.length / 2); const leftSide = this.concat(tensors.slice(0, midIndex), axis); const rightSide = this.concat(tensors.slice(midIndex), axis); return this.concat([leftSide, rightSide], axis); } - if (ENV.getBool('WEBGL_PACK_ARRAY_OPERATIONS') && tensors[0].rank > 1) { + if (environment().getBool('WEBGL_PACK_ARRAY_OPERATIONS') && + tensors[0].rank > 1) { const program = new ConcatPackedProgram(tensors.map(t => t.shape), axis); return this.compileAndRun(program, tensors); } @@ -871,7 +879,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.neg(x); } - if (ENV.getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_op.NEG, x.dtype) as T; } const program = new UnaryOpProgram(x.shape, unary_op.NEG); @@ -966,7 +974,7 @@ export class MathBackendWebGL implements KernelBackend { if (this.shouldExecuteOnCPU([a, b])) { return this.cpuBackend.multiply(a, b); } - if (ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_gpu.MUL, a.dtype); } const program = new BinaryOpProgram(binaryop_gpu.MUL, a.shape, b.shape); @@ -992,7 +1000,7 @@ export class MathBackendWebGL implements KernelBackend { inputs.push(scale); } - if (ENV.getBool('WEBGL_PACK_NORMALIZATION')) { + if (environment().getBool('WEBGL_PACK_NORMALIZATION')) { const batchNormPackedProgram = new BatchNormPackedProgram( x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon); @@ -1008,7 +1016,7 @@ export class MathBackendWebGL implements KernelBackend { localResponseNormalization4D( x: Tensor4D, radius: number, bias: number, alpha: number, beta: number): Tensor4D { - const program = ENV.getBool('WEBGL_PACK_NORMALIZATION') ? + const program = environment().getBool('WEBGL_PACK_NORMALIZATION') ? new LRNPackedProgram(x.shape, radius, bias, alpha, beta) : new LRNProgram(x.shape, radius, bias, alpha, beta); return this.compileAndRun(program, [x]); @@ -1036,7 +1044,7 @@ export class MathBackendWebGL implements KernelBackend { pad( x: T, paddings: Array<[number, number]>, constantValue: number): T { - const program = ENV.getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? + const program = environment().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new PadPackedProgram(x.shape, paddings, constantValue) : new PadProgram(x.shape, paddings, constantValue); return this.compileAndRun(program, [x]); @@ -1046,7 +1054,7 @@ export class MathBackendWebGL implements KernelBackend { if (this.shouldExecuteOnCPU([x])) { return this.cpuBackend.transpose(x, perm); } - const program = ENV.getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? + const program = environment().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new TransposePackedProgram(x.shape, perm) : new TransposeProgram(x.shape, perm); return this.compileAndRun(program, [x]); @@ -1251,7 +1259,7 @@ export class MathBackendWebGL implements KernelBackend { axis_util.assertAxesAreInnerMostDims( 'arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes, x.rank); - if (!ENV.getBool('WEBGL_PACK_REDUCE') || x.rank <= 2) { + if (!environment().getBool('WEBGL_PACK_REDUCE') || x.rank <= 2) { const [outShape, reduceShape] = axis_util.computeOutAndReduceShapes(x.shape, axes); const inSize = util.sizeFromShape(reduceShape); @@ -1281,7 +1289,7 @@ export class MathBackendWebGL implements KernelBackend { } equal(a: Tensor, b: Tensor): Tensor { - if (ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_packed_gpu.EQUAL, 'bool'); } const program = new BinaryOpProgram(binaryop_gpu.EQUAL, a.shape, b.shape); @@ -1290,7 +1298,7 @@ export class MathBackendWebGL implements KernelBackend { } notEqual(a: Tensor, b: Tensor): Tensor { - if (ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_packed_gpu.NOT_EQUAL, 'bool'); } const program = @@ -1304,7 +1312,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.less(a, b); } - if (ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_packed_gpu.LESS, 'bool'); } @@ -1314,7 +1322,7 @@ export class MathBackendWebGL implements KernelBackend { } lessEqual(a: Tensor, b: Tensor): Tensor { - if (ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_packed_gpu.LESS_EQUAL, 'bool'); } const program = @@ -1328,7 +1336,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.greater(a, b); } - if (ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_packed_gpu.GREATER, 'bool'); } @@ -1338,7 +1346,7 @@ export class MathBackendWebGL implements KernelBackend { } greaterEqual(a: Tensor, b: Tensor): Tensor { - if (ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp( a, b, binaryop_packed_gpu.GREATER_EQUAL, 'bool'); } @@ -1354,7 +1362,7 @@ export class MathBackendWebGL implements KernelBackend { } logicalAnd(a: Tensor, b: Tensor): Tensor { - if (ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_packed_gpu.LOGICAL_AND, 'bool'); } const program = @@ -1364,7 +1372,7 @@ export class MathBackendWebGL implements KernelBackend { } logicalOr(a: Tensor, b: Tensor): Tensor { - if (ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_packed_gpu.LOGICAL_OR, 'bool'); } const program = @@ -1407,14 +1415,14 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.minimum(a, b); } - const program = ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + const program = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(binaryop_packed_gpu.MIN, a.shape, b.shape) : new BinaryOpProgram(binaryop_gpu.MIN, a.shape, b.shape); return this.compileAndRun(program, [a, b]); } mod(a: Tensor, b: Tensor): Tensor { - const program = ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + const program = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(binaryop_packed_gpu.MOD, a.shape, b.shape) : new BinaryOpProgram(binaryop_gpu.MOD, a.shape, b.shape); return this.compileAndRun(program, [a, b]); @@ -1438,7 +1446,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.maximum(a, b); } - const program = ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + const program = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(binaryop_packed_gpu.MAX, a.shape, b.shape) : new BinaryOpProgram(binaryop_gpu.MAX, a.shape, b.shape); return this.compileAndRun(program, [a, b]); @@ -1463,7 +1471,7 @@ export class MathBackendWebGL implements KernelBackend { } squaredDifference(a: Tensor, b: Tensor): Tensor { - const program = ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + const program = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram( binaryop_gpu.SQUARED_DIFFERENCE, a.shape, b.shape) : new BinaryOpProgram(binaryop_gpu.SQUARED_DIFFERENCE, a.shape, b.shape); @@ -1473,7 +1481,7 @@ export class MathBackendWebGL implements KernelBackend { realDivide(a: Tensor, b: Tensor): Tensor { const op = binaryop_gpu.DIV; const outputDtype = 'float32'; - if (ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { const checkOutOfBounds = true; return this.packedBinaryOp( a, b, binaryop_packed_gpu.DIV, outputDtype, checkOutOfBounds); @@ -1486,7 +1494,7 @@ export class MathBackendWebGL implements KernelBackend { floorDiv(a: Tensor, b: Tensor): Tensor { const op = binaryop_gpu.INT_DIV; const outputDtype = 'int32'; - if (ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp( a, b, binaryop_packed_gpu.INT_DIV, outputDtype); } @@ -1505,7 +1513,7 @@ export class MathBackendWebGL implements KernelBackend { } const dtype = upcastType(a.dtype, b.dtype); - if (ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_gpu.ADD, dtype); } const program = new BinaryOpProgram(binaryop_gpu.ADD, a.shape, b.shape); @@ -1576,7 +1584,7 @@ export class MathBackendWebGL implements KernelBackend { } // Limit the number of uploaded textures for optimization. - if (tensors.length > ENV.get('WEBGL_MAX_TEXTURES_IN_SHADER')) { + if (tensors.length > environment().get('WEBGL_MAX_TEXTURES_IN_SHADER')) { const midIndex = Math.floor(tensors.length / 2); const leftSide = this.addN(tensors.slice(0, midIndex)); const rightSide = this.addN(tensors.slice(midIndex)); @@ -1587,7 +1595,7 @@ export class MathBackendWebGL implements KernelBackend { tensors.map(t => t.dtype).reduce((d1, d2) => upcastType(d1, d2)); const shapes = tensors.map(t => t.shape); // We can make sure shapes are identical in op level. - const usePackedOp = ENV.getBool('WEBGL_PACK'); + const usePackedOp = environment().getBool('WEBGL_PACK'); const program = usePackedOp ? new AddNPackedProgram(tensors[0].shape, shapes) : new AddNProgram(tensors[0].shape, shapes); @@ -1606,7 +1614,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.subtract(a, b); } const dtype = upcastType(a.dtype, b.dtype); - if (ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_gpu.SUB, a.dtype); } const program = new BinaryOpProgram(binaryop_gpu.SUB, a.shape, b.shape); @@ -1615,7 +1623,7 @@ export class MathBackendWebGL implements KernelBackend { } pow(a: T, b: Tensor): T { - const usePackedOp = ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS'); + const usePackedOp = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); const program = usePackedOp ? new BinaryOpPackedProgram(binaryop_packed_gpu.POW, a.shape, b.shape) : new BinaryOpProgram(binaryop_gpu.POW, a.shape, b.shape); @@ -1631,7 +1639,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.ceil(x); } - if (ENV.getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_op.CEIL, x.dtype) as T; } @@ -1644,7 +1652,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.floor(x); } - if (ENV.getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_op.FLOOR, x.dtype) as T; } @@ -1683,7 +1691,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.exp(x); } - if (ENV.getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_op.EXP, x.dtype) as T; } @@ -1696,7 +1704,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.expm1(x); } - if (ENV.getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_op.EXPM1, x.dtype) as T; } @@ -1709,7 +1717,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.log(x); } - if (ENV.getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_packed_op.LOG, x.dtype) as T; } @@ -1747,7 +1755,7 @@ export class MathBackendWebGL implements KernelBackend { relu(x: T): T { let program: UnaryOpProgram|UnaryOpPackedProgram; - if (ENV.getBool('WEBGL_PACK')) { + if (environment().getBool('WEBGL_PACK')) { program = new UnaryOpPackedProgram(x.shape, unary_packed_op.RELU); } else { program = new UnaryOpProgram(x.shape, unary_op.RELU); @@ -1757,7 +1765,7 @@ export class MathBackendWebGL implements KernelBackend { relu6(x: T): T { let program: UnaryOpProgram|UnaryOpPackedProgram; - if (ENV.getBool('WEBGL_PACK')) { + if (environment().getBool('WEBGL_PACK')) { program = new UnaryOpPackedProgram(x.shape, unary_packed_op.RELU6); } else { program = new UnaryOpProgram(x.shape, unary_op.RELU6); @@ -1766,7 +1774,7 @@ export class MathBackendWebGL implements KernelBackend { } prelu(x: T, alpha: T): T { - const program = ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + const program = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram( binaryop_packed_gpu.PRELU, x.shape, alpha.shape) : new BinaryOpProgram(binaryop_gpu.PRELU, x.shape, alpha.shape); @@ -1774,7 +1782,7 @@ export class MathBackendWebGL implements KernelBackend { } elu(x: T): T { - if (ENV.getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_packed_op.ELU, x.dtype) as T; } const program = new UnaryOpProgram(x.shape, unary_op.ELU); @@ -1782,7 +1790,7 @@ export class MathBackendWebGL implements KernelBackend { } eluDer(dy: T, y: T): T { - const program = ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + const program = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram( binaryop_packed_gpu.ELU_DER, dy.shape, y.shape) : new BinaryOpProgram(binaryop_gpu.ELU_DER, dy.shape, y.shape); @@ -1802,7 +1810,7 @@ export class MathBackendWebGL implements KernelBackend { clip(x: T, min: number, max: number): T { let program; - if (ENV.getBool('WEBGL_PACK_CLIP')) { + if (environment().getBool('WEBGL_PACK_CLIP')) { program = new ClipPackedProgram(x.shape); } else { program = new ClipProgram(x.shape); @@ -1816,7 +1824,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.abs(x); } - if (ENV.getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_op.ABS, x.dtype) as T; } @@ -1877,7 +1885,7 @@ export class MathBackendWebGL implements KernelBackend { } atan2(a: T, b: T): T { - const program = ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + const program = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(binaryop_packed_gpu.ATAN2, a.shape, b.shape) : new BinaryOpProgram(binaryop_gpu.ATAN2, a.shape, b.shape); return this.compileAndRun(program, [a, b]); @@ -1944,8 +1952,9 @@ export class MathBackendWebGL implements KernelBackend { sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD; const reshapeWillBeExpensive = xShape[2] % 2 !== 0 && !!xTexData.isPacked; - if (batchMatMulWillBeUnpacked || !ENV.getBool('WEBGL_LAZILY_UNPACK') || - !ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS') || + if (batchMatMulWillBeUnpacked || + !environment().getBool('WEBGL_LAZILY_UNPACK') || + !environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') || !reshapeWillBeExpensive) { const targetShape = isChannelsLast ? xShape[0] * xShape[1] * xShape[2] : xShape[0] * xShape[2] * xShape[3]; @@ -2092,7 +2101,7 @@ export class MathBackendWebGL implements KernelBackend { return this.conv2dByMatMul( input, filter, convInfo, bias, activation, preluActivationWeights); } - if (ENV.getBool('WEBGL_CONV_IM2COL') && input.shape[0] === 1) { + if (environment().getBool('WEBGL_CONV_IM2COL') && input.shape[0] === 1) { return this.conv2dWithIm2Row( input, filter, convInfo, bias, activation, preluActivationWeights); } @@ -2121,7 +2130,7 @@ export class MathBackendWebGL implements KernelBackend { convInfo.padInfo.type === 'VALID')) { return this.conv2dByMatMul(x, filter, convInfo); } - if (ENV.getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) { + if (environment().getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) { return this.conv2dWithIm2Row(x, filter, convInfo); } const program = new Conv2DProgram(convInfo); @@ -2142,7 +2151,8 @@ export class MathBackendWebGL implements KernelBackend { fusedDepthwiseConv2D( {input, filter, convInfo, bias, activation, preluActivationWeights}: FusedConv2DConfig): Tensor4D { - const shouldPackDepthwiseConv = ENV.getBool('WEBGL_PACK_DEPTHWISECONV') && + const shouldPackDepthwiseConv = + environment().getBool('WEBGL_PACK_DEPTHWISECONV') && convInfo.strideWidth <= 2 && convInfo.outChannels / convInfo.inChannels === 1; const fusedActivation = activation ? @@ -2176,7 +2186,8 @@ export class MathBackendWebGL implements KernelBackend { depthwiseConv2D(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D { let program: DepthwiseConv2DProgram|DepthwiseConvPacked2DProgram; - if (ENV.getBool('WEBGL_PACK_DEPTHWISECONV') && convInfo.strideWidth <= 2 && + if (environment().getBool('WEBGL_PACK_DEPTHWISECONV') && + convInfo.strideWidth <= 2 && convInfo.outChannels / convInfo.inChannels === 1) { program = new DepthwiseConvPacked2DProgram(convInfo); return this.compileAndRun( @@ -2326,7 +2337,7 @@ export class MathBackendWebGL implements KernelBackend { resizeBilinear( x: Tensor4D, newHeight: number, newWidth: number, alignCorners: boolean): Tensor4D { - const program = ENV.getBool('WEBGL_PACK_IMAGE_OPERATIONS') ? + const program = environment().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ? new ResizeBilinearPackedProgram( x.shape, newHeight, newWidth, alignCorners) : new ResizeBilinearProgram(x.shape, newHeight, newWidth, alignCorners); @@ -2654,7 +2665,7 @@ export class MathBackendWebGL implements KernelBackend { if (texData.texture == null) { if (!program.usesPackedTextures && util.sizeFromShape(input.shape) <= - ENV.getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) { + environment().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) { // Upload small tensors that live on the CPU as uniforms, not as // textures. Do this only when the environment supports 32bit floats // due to problems when comparing 16bit floats with 32bit floats. @@ -2728,7 +2739,7 @@ export class MathBackendWebGL implements KernelBackend { {name: program.constructor.name, query: this.getQueryTime(query)}); } - if (!ENV.getBool('WEBGL_LAZILY_UNPACK') && + if (!environment().getBool('WEBGL_LAZILY_UNPACK') && this.texData.get(output.dataId).isPacked && preventEagerUnpackingOfOutput === false) { return this.unpackTensor(output as {} as Tensor) as {} as K; @@ -2776,13 +2787,13 @@ export class MathBackendWebGL implements KernelBackend { floatPrecision(): 16|32 { if (this.floatPrecisionValue == null) { this.floatPrecisionValue = tidy(() => { - if (!ENV.get('WEBGL_RENDER_FLOAT32_ENABLED')) { + if (!environment().get('WEBGL_RENDER_FLOAT32_ENABLED')) { // Momentarily switching DEBUG flag to false so we don't throw an // error trying to upload a small value. - const debugFlag = ENV.getBool('DEBUG'); - ENV.set('DEBUG', false); + const debugFlag = environment().getBool('DEBUG'); + environment().set('DEBUG', false); const underflowCheckValue = this.abs(scalar(1e-8)).dataSync()[0]; - ENV.set('DEBUG', debugFlag); + environment().set('DEBUG', debugFlag); if (underflowCheckValue > 0) { return 32; diff --git a/tfjs-core/src/backends/webgl/backend_webgl_test.ts b/tfjs-core/src/backends/webgl/backend_webgl_test.ts index f12d5a5c270..3884ebf4b0b 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl_test.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl_test.ts @@ -41,12 +41,12 @@ describeWithFlags('forced f16 render', RENDER_FLOAT32_ENVS, () => { beforeAll(() => { renderToF32FlagSaved = - tf.ENV.get('WEBGL_RENDER_FLOAT32_ENABLED') as boolean; - tf.ENV.set('WEBGL_RENDER_FLOAT32_ENABLED', false); + tf.environment().get('WEBGL_RENDER_FLOAT32_ENABLED') as boolean; + tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', false); }); afterAll(() => { - tf.ENV.set('WEBGL_RENDER_FLOAT32_ENABLED', renderToF32FlagSaved); + tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', renderToF32FlagSaved); }); it('should overflow if larger than 66k', async () => { @@ -56,11 +56,11 @@ describeWithFlags('forced f16 render', RENDER_FLOAT32_ENVS, () => { }); it('should error in debug mode', () => { - const savedDebugFlag = tf.ENV.getBool('DEBUG'); - tf.ENV.set('DEBUG', true); + const savedDebugFlag = tf.environment().getBool('DEBUG'); + tf.environment().set('DEBUG', true); const a = () => tf.tensor1d([2, Math.pow(2, 17)], 'float32'); expect(a).toThrowError(); - tf.ENV.set('DEBUG', savedDebugFlag); + tf.environment().set('DEBUG', savedDebugFlag); }); }); @@ -69,15 +69,16 @@ describeWithFlags('lazy packing and unpacking', WEBGL_ENVS, () => { let webglCpuForwardFlagSaved: boolean; beforeAll(() => { - webglLazilyUnpackFlagSaved = tf.ENV.getBool('WEBGL_LAZILY_UNPACK'); - webglCpuForwardFlagSaved = tf.ENV.getBool('WEBGL_CPU_FORWARD'); - tf.ENV.set('WEBGL_LAZILY_UNPACK', true); - tf.ENV.set('WEBGL_CPU_FORWARD', false); + webglLazilyUnpackFlagSaved = + tf.environment().getBool('WEBGL_LAZILY_UNPACK'); + webglCpuForwardFlagSaved = tf.environment().getBool('WEBGL_CPU_FORWARD'); + tf.environment().set('WEBGL_LAZILY_UNPACK', true); + tf.environment().set('WEBGL_CPU_FORWARD', false); }); afterAll(() => { - tf.ENV.set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); - tf.ENV.set('WEBGL_CPU_FORWARD', webglCpuForwardFlagSaved); + tf.environment().set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); + tf.environment().set('WEBGL_CPU_FORWARD', webglCpuForwardFlagSaved); }); it('should not leak memory when lazily unpacking', () => { @@ -93,11 +94,11 @@ describeWithFlags('lazy packing and unpacking', WEBGL_ENVS, () => { (tf.memory() as tf.webgl.WebGLMemoryInfo).numBytesInGPU; const webglPackBinaryOperationsFlagSaved = - tf.ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS', false); + tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', false); // Add will unpack c before the operation to 2 tf.add(c, 1); - tf.ENV.set( + tf.environment().set( 'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsFlagSaved); expect(tf.memory().numBytes - startNumBytes).toEqual(16); @@ -230,19 +231,20 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => { tf.registerBackend('test-storage', () => backend); tf.setBackend('test-storage'); - const webglPackFlagSaved = tf.ENV.getBool('WEBGL_PACK'); - tf.ENV.set('WEBGL_PACK', true); + const webglPackFlagSaved = tf.environment().getBool('WEBGL_PACK'); + tf.environment().set('WEBGL_PACK', true); const webglSizeUploadUniformSaved = - tf.ENV.getNumber('WEBGL_SIZE_UPLOAD_UNIFORM'); - tf.ENV.set('WEBGL_SIZE_UPLOAD_UNIFORM', 0); + tf.environment().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM'); + tf.environment().set('WEBGL_SIZE_UPLOAD_UNIFORM', 0); const a = tf.tensor2d([1, 2], [2, 1]); const b = tf.tensor2d([1], [1, 1]); const c = tf.matMul(a, b); backend.readSync(c.dataId); - tf.ENV.set('WEBGL_PACK', false); + tf.environment().set('WEBGL_PACK', false); const d = tf.add(c, 1); - tf.ENV.set('WEBGL_PACK', webglPackFlagSaved); - tf.ENV.set('WEBGL_SIZE_UPLOAD_UNIFORM', webglSizeUploadUniformSaved); + tf.environment().set('WEBGL_PACK', webglPackFlagSaved); + tf.environment().set( + 'WEBGL_SIZE_UPLOAD_UNIFORM', webglSizeUploadUniformSaved); expectArraysClose(await d.data(), [2, 3]); }); @@ -319,12 +321,13 @@ describeWithFlags('upload tensors as uniforms', FLOAT32_WEBGL_ENVS, () => { let savedUploadUniformValue: number; beforeAll(() => { - savedUploadUniformValue = tf.ENV.get('WEBGL_SIZE_UPLOAD_UNIFORM') as number; - tf.ENV.set('WEBGL_SIZE_UPLOAD_UNIFORM', SIZE_UPLOAD_UNIFORM); + savedUploadUniformValue = + tf.environment().get('WEBGL_SIZE_UPLOAD_UNIFORM') as number; + tf.environment().set('WEBGL_SIZE_UPLOAD_UNIFORM', SIZE_UPLOAD_UNIFORM); }); afterAll(() => { - tf.ENV.set('WEBGL_SIZE_UPLOAD_UNIFORM', savedUploadUniformValue); + tf.environment().set('WEBGL_SIZE_UPLOAD_UNIFORM', savedUploadUniformValue); }); it('small tensor gets uploaded as scalar', () => { @@ -384,29 +387,31 @@ describeWithFlags('debug on webgl', WEBGL_ENVS, () => { beforeAll(() => { // Silences debug warnings. spyOn(console, 'warn'); - tf.ENV.set('DEBUG', true); + tf.environment().set('DEBUG', true); }); afterAll(() => { - tf.ENV.set('DEBUG', false); + tf.environment().set('DEBUG', false); }); it('debug mode errors when overflow in tensor construction', () => { const savedRenderFloat32Flag = - tf.ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED'); - tf.ENV.set('WEBGL_RENDER_FLOAT32_ENABLED', false); + tf.environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED'); + tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', false); const a = () => tf.tensor1d([2, Math.pow(2, 17)], 'float32'); expect(a).toThrowError(); - tf.ENV.set('WEBGL_RENDER_FLOAT32_ENABLED', savedRenderFloat32Flag); + tf.environment().set( + 'WEBGL_RENDER_FLOAT32_ENABLED', savedRenderFloat32Flag); }); it('debug mode errors when underflow in tensor construction', () => { const savedRenderFloat32Flag = - tf.ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED'); - tf.ENV.set('WEBGL_RENDER_FLOAT32_ENABLED', false); + tf.environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED'); + tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', false); const a = () => tf.tensor1d([2, 1e-8], 'float32'); expect(a).toThrowError(); - tf.ENV.set('WEBGL_RENDER_FLOAT32_ENABLED', savedRenderFloat32Flag); + tf.environment().set( + 'WEBGL_RENDER_FLOAT32_ENABLED', savedRenderFloat32Flag); }); }); @@ -424,10 +429,11 @@ describeWithFlags('memory webgl', WEBGL_ENVS, () => { // We do not yet fully support half float backends. These tests are a starting // point. describeWithFlags('backend without render float32 support', WEBGL_ENVS, () => { - const savedRenderFloat32Flag = tf.ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED'); + const savedRenderFloat32Flag = + tf.environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED'); beforeAll(() => { - tf.ENV.set('WEBGL_RENDER_FLOAT32_ENABLED', false); + tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', false); }); beforeEach(() => { @@ -439,7 +445,8 @@ describeWithFlags('backend without render float32 support', WEBGL_ENVS, () => { }); afterAll(() => { - tf.ENV.set('WEBGL_RENDER_FLOAT32_ENABLED', savedRenderFloat32Flag); + tf.environment().set( + 'WEBGL_RENDER_FLOAT32_ENABLED', savedRenderFloat32Flag); }); it('basic usage', async () => { @@ -508,7 +515,7 @@ describeWithFlags('time webgl', WEBGL_ENVS, () => { describeWithFlags('caching on cpu', WEBGL_ENVS, () => { beforeAll(() => { - tf.ENV.set('WEBGL_CPU_FORWARD', false); + tf.environment().set('WEBGL_CPU_FORWARD', false); }); it('caches on cpu after async read', async () => { diff --git a/tfjs-core/src/backends/webgl/canvas_util_test.ts b/tfjs-core/src/backends/webgl/canvas_util_test.ts index e4e01ec0724..a13fc300468 100644 --- a/tfjs-core/src/backends/webgl/canvas_util_test.ts +++ b/tfjs-core/src/backends/webgl/canvas_util_test.ts @@ -14,16 +14,17 @@ * limitations under the License. * ============================================================================= */ - -import {ENV} from '../../environment'; +import * as tf from '../../index'; import {BROWSER_ENVS, describeWithFlags} from '../../jasmine_util'; import {getWebGLContext} from './canvas_util'; describeWithFlags('canvas_util', BROWSER_ENVS, () => { it('Returns a valid canvas', () => { - const canvas = getWebGLContext(ENV.getNumber('WEBGL_VERSION')).canvas as ( - HTMLCanvasElement | OffscreenCanvas); + const canvas = + getWebGLContext(tf.environment().getNumber('WEBGL_VERSION')).canvas as + // tslint:disable-next-line: no-any + any; expect( (canvas instanceof HTMLCanvasElement) || (canvas instanceof OffscreenCanvas)) @@ -31,7 +32,7 @@ describeWithFlags('canvas_util', BROWSER_ENVS, () => { }); it('Returns a valid gl context', () => { - const gl = getWebGLContext(ENV.getNumber('WEBGL_VERSION')); + const gl = getWebGLContext(tf.environment().getNumber('WEBGL_VERSION')); expect(gl.isContextLost()).toBe(false); }); }); diff --git a/tfjs-core/src/backends/webgl/flags_webgl.ts b/tfjs-core/src/backends/webgl/flags_webgl.ts index ecdb2803529..d3ddfc297c4 100644 --- a/tfjs-core/src/backends/webgl/flags_webgl.ts +++ b/tfjs-core/src/backends/webgl/flags_webgl.ts @@ -16,9 +16,12 @@ */ import * as device_util from '../../device_util'; -import {ENV} from '../../environment'; +import {environment} from '../../environment'; + import * as webgl_util from './webgl_util'; +const ENV = environment(); + /** * This file contains WebGL-specific flag registrations. */ diff --git a/tfjs-core/src/backends/webgl/flags_webgl_test.ts b/tfjs-core/src/backends/webgl/flags_webgl_test.ts index d1ff36c2fa9..4231234d1b7 100644 --- a/tfjs-core/src/backends/webgl/flags_webgl_test.ts +++ b/tfjs-core/src/backends/webgl/flags_webgl_test.ts @@ -16,7 +16,6 @@ */ import * as device_util from '../../device_util'; -import {ENV} from '../../environment'; import * as tf from '../../index'; import {describeWithFlags} from '../../jasmine_util'; import {webgl_util} from '../../webgl'; @@ -25,17 +24,18 @@ import {WEBGL_ENVS} from './backend_webgl_test_registry'; import * as canvas_util from './canvas_util'; describe('WEBGL_FORCE_F16_TEXTURES', () => { - afterAll(() => ENV.reset()); + afterAll(() => tf.environment().reset()); it('can be activated via forceHalfFloat utility', () => { tf.webgl.forceHalfFloat(); - expect(ENV.getBool('WEBGL_FORCE_F16_TEXTURES')).toBe(true); + expect(tf.environment().getBool('WEBGL_FORCE_F16_TEXTURES')).toBe(true); }); it('turns off WEBGL_RENDER_FLOAT32_ENABLED', () => { - ENV.reset(); + tf.environment().reset(); tf.webgl.forceHalfFloat(); - expect(ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED')).toBe(false); + expect(tf.environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) + .toBe(false); }); }); @@ -51,212 +51,215 @@ const RENDER_FLOAT16_ENVS = { describeWithFlags('WEBGL_RENDER_FLOAT32_CAPABLE', RENDER_FLOAT32_ENVS, () => { beforeEach(() => { - ENV.reset(); + tf.environment().reset(); }); - afterAll(() => ENV.reset()); + afterAll(() => tf.environment().reset()); it('should be independent of forcing f16 rendering', () => { tf.webgl.forceHalfFloat(); - expect(ENV.getBool('WEBGL_RENDER_FLOAT32_CAPABLE')).toBe(true); + expect(tf.environment().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')).toBe(true); }); it('if user is not forcing f16, device should render to f32', () => { - expect(ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED')).toBe(true); + expect(tf.environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED')).toBe(true); }); }); describeWithFlags('WEBGL_RENDER_FLOAT32_CAPABLE', RENDER_FLOAT16_ENVS, () => { beforeEach(() => { - ENV.reset(); + tf.environment().reset(); }); - afterAll(() => ENV.reset()); + afterAll(() => tf.environment().reset()); it('should be independent of forcing f16 rendering', () => { tf.webgl.forceHalfFloat(); - expect(ENV.getBool('WEBGL_RENDER_FLOAT32_CAPABLE')).toBe(false); + expect(tf.environment().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) + .toBe(false); }); it('should be reflected in WEBGL_RENDER_FLOAT32_ENABLED', () => { - expect(ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED')).toBe(false); + expect(tf.environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) + .toBe(false); }); }); describe('HAS_WEBGL', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('false when version is 0', () => { - ENV.set('WEBGL_VERSION', 0); - expect(ENV.getBool('HAS_WEBGL')).toBe(false); + tf.environment().set('WEBGL_VERSION', 0); + expect(tf.environment().getBool('HAS_WEBGL')).toBe(false); }); it('true when version is 1', () => { - ENV.set('WEBGL_VERSION', 1); - expect(ENV.getBool('HAS_WEBGL')).toBe(true); + tf.environment().set('WEBGL_VERSION', 1); + expect(tf.environment().getBool('HAS_WEBGL')).toBe(true); }); it('true when version is 2', () => { - ENV.set('WEBGL_VERSION', 2); - expect(ENV.getBool('HAS_WEBGL')).toBe(true); + tf.environment().set('WEBGL_VERSION', 2); + expect(tf.environment().getBool('HAS_WEBGL')).toBe(true); }); }); describe('WEBGL_PACK', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('true when HAS_WEBGL is true', () => { - ENV.set('HAS_WEBGL', true); - expect(ENV.getBool('WEBGL_PACK')).toBe(true); + tf.environment().set('HAS_WEBGL', true); + expect(tf.environment().getBool('WEBGL_PACK')).toBe(true); }); it('false when HAS_WEBGL is false', () => { - ENV.set('HAS_WEBGL', false); - expect(ENV.getBool('WEBGL_PACK')).toBe(false); + tf.environment().set('HAS_WEBGL', false); + expect(tf.environment().getBool('WEBGL_PACK')).toBe(false); }); }); describe('WEBGL_PACK_NORMALIZATION', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('true when WEBGL_PACK is true', () => { - ENV.set('WEBGL_PACK', true); - expect(ENV.getBool('WEBGL_PACK_NORMALIZATION')).toBe(true); + tf.environment().set('WEBGL_PACK', true); + expect(tf.environment().getBool('WEBGL_PACK_NORMALIZATION')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - ENV.set('WEBGL_PACK', false); - expect(ENV.getBool('WEBGL_PACK_NORMALIZATION')).toBe(false); + tf.environment().set('WEBGL_PACK', false); + expect(tf.environment().getBool('WEBGL_PACK_NORMALIZATION')).toBe(false); }); }); describe('WEBGL_PACK_CLIP', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('true when WEBGL_PACK is true', () => { - ENV.set('WEBGL_PACK', true); - expect(ENV.getBool('WEBGL_PACK_CLIP')).toBe(true); + tf.environment().set('WEBGL_PACK', true); + expect(tf.environment().getBool('WEBGL_PACK_CLIP')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - ENV.set('WEBGL_PACK', false); - expect(ENV.getBool('WEBGL_PACK_CLIP')).toBe(false); + tf.environment().set('WEBGL_PACK', false); + expect(tf.environment().getBool('WEBGL_PACK_CLIP')).toBe(false); }); }); // TODO: https://github.com/tensorflow/tfjs/issues/1679 // describe('WEBGL_PACK_DEPTHWISECONV', () => { -// beforeEach(() => ENV.reset()); -// afterAll(() => ENV.reset()); +// beforeEach(() => tf.environment().reset()); +// afterAll(() => tf.environment().reset()); // it('true when WEBGL_PACK is true', () => { -// ENV.set('WEBGL_PACK', true); -// expect(ENV.getBool('WEBGL_PACK_DEPTHWISECONV')).toBe(true); +// tf.environment().set('WEBGL_PACK', true); +// expect(tf.environment().getBool('WEBGL_PACK_DEPTHWISECONV')).toBe(true); // }); // it('false when WEBGL_PACK is false', () => { -// ENV.set('WEBGL_PACK', false); -// expect(ENV.getBool('WEBGL_PACK_DEPTHWISECONV')).toBe(false); +// tf.environment().set('WEBGL_PACK', false); +// expect(tf.environment().getBool('WEBGL_PACK_DEPTHWISECONV')).toBe(false); // }); // }); describe('WEBGL_PACK_BINARY_OPERATIONS', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('true when WEBGL_PACK is true', () => { - ENV.set('WEBGL_PACK', true); - expect(ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS')).toBe(true); + tf.environment().set('WEBGL_PACK', true); + expect(tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - ENV.set('WEBGL_PACK', false); - expect(ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS')).toBe(false); + tf.environment().set('WEBGL_PACK', false); + expect(tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) + .toBe(false); }); }); describe('WEBGL_PACK_ARRAY_OPERATIONS', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('true when WEBGL_PACK is true', () => { - ENV.set('WEBGL_PACK', true); - expect(ENV.getBool('WEBGL_PACK_ARRAY_OPERATIONS')).toBe(true); + tf.environment().set('WEBGL_PACK', true); + expect(tf.environment().getBool('WEBGL_PACK_ARRAY_OPERATIONS')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - ENV.set('WEBGL_PACK', false); - expect(ENV.getBool('WEBGL_PACK_ARRAY_OPERATIONS')).toBe(false); + tf.environment().set('WEBGL_PACK', false); + expect(tf.environment().getBool('WEBGL_PACK_ARRAY_OPERATIONS')).toBe(false); }); }); describe('WEBGL_PACK_IMAGE_OPERATIONS', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('true when WEBGL_PACK is true', () => { - ENV.set('WEBGL_PACK', true); - expect(ENV.getBool('WEBGL_PACK_IMAGE_OPERATIONS')).toBe(true); + tf.environment().set('WEBGL_PACK', true); + expect(tf.environment().getBool('WEBGL_PACK_IMAGE_OPERATIONS')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - ENV.set('WEBGL_PACK', false); - expect(ENV.getBool('WEBGL_PACK_IMAGE_OPERATIONS')).toBe(false); + tf.environment().set('WEBGL_PACK', false); + expect(tf.environment().getBool('WEBGL_PACK_IMAGE_OPERATIONS')).toBe(false); }); }); describe('WEBGL_PACK_REDUCE', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('true when WEBGL_PACK is true', () => { - ENV.set('WEBGL_PACK', true); - expect(ENV.getBool('WEBGL_PACK_REDUCE')).toBe(true); + tf.environment().set('WEBGL_PACK', true); + expect(tf.environment().getBool('WEBGL_PACK_REDUCE')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - ENV.set('WEBGL_PACK', false); - expect(ENV.getBool('WEBGL_PACK_REDUCE')).toBe(false); + tf.environment().set('WEBGL_PACK', false); + expect(tf.environment().getBool('WEBGL_PACK_REDUCE')).toBe(false); }); }); describe('WEBGL_LAZILY_UNPACK', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('true when WEBGL_PACK is true', () => { - ENV.set('WEBGL_PACK', true); - expect(ENV.getBool('WEBGL_LAZILY_UNPACK')).toBe(true); + tf.environment().set('WEBGL_PACK', true); + expect(tf.environment().getBool('WEBGL_LAZILY_UNPACK')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - ENV.set('WEBGL_PACK', false); - expect(ENV.getBool('WEBGL_LAZILY_UNPACK')).toBe(false); + tf.environment().set('WEBGL_PACK', false); + expect(tf.environment().getBool('WEBGL_LAZILY_UNPACK')).toBe(false); }); }); describe('WEBGL_CONV_IM2COL', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('true when WEBGL_PACK is true', () => { - ENV.set('WEBGL_PACK', true); - expect(ENV.getBool('WEBGL_CONV_IM2COL')).toBe(true); + tf.environment().set('WEBGL_PACK', true); + expect(tf.environment().getBool('WEBGL_CONV_IM2COL')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - ENV.set('WEBGL_PACK', false); - expect(ENV.getBool('WEBGL_CONV_IM2COL')).toBe(false); + tf.environment().set('WEBGL_PACK', false); + expect(tf.environment().getBool('WEBGL_CONV_IM2COL')).toBe(false); }); }); describe('WEBGL_MAX_TEXTURE_SIZE', () => { beforeEach(() => { - ENV.reset(); + tf.environment().reset(); webgl_util.resetMaxTextureSize(); spyOn(canvas_util, 'getWebGLContext').and.returnValue({ @@ -270,19 +273,19 @@ describe('WEBGL_MAX_TEXTURE_SIZE', () => { }); }); afterAll(() => { - ENV.reset(); + tf.environment().reset(); webgl_util.resetMaxTextureSize(); }); it('is a function of gl.getParameter(MAX_TEXTURE_SIZE)', () => { - expect(ENV.getNumber('WEBGL_MAX_TEXTURE_SIZE')).toBe(50); + expect(tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE')).toBe(50); }); }); describe('WEBGL_MAX_TEXTURES_IN_SHADER', () => { let maxTextures: number; beforeEach(() => { - ENV.reset(); + tf.environment().reset(); webgl_util.resetMaxTexturesInShader(); spyOn(canvas_util, 'getWebGLContext').and.callFake(() => { @@ -298,61 +301,65 @@ describe('WEBGL_MAX_TEXTURES_IN_SHADER', () => { }); }); afterAll(() => { - ENV.reset(); + tf.environment().reset(); webgl_util.resetMaxTexturesInShader(); }); it('is a function of gl.getParameter(MAX_TEXTURE_IMAGE_UNITS)', () => { maxTextures = 10; - expect(ENV.getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')).toBe(10); + expect(tf.environment().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')).toBe(10); }); it('is capped at 16', () => { maxTextures = 20; - expect(ENV.getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')).toBe(16); + expect(tf.environment().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')).toBe(16); }); }); describe('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('disjoint query timer disabled', () => { - ENV.set('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', 0); + tf.environment().set('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', 0); - expect(ENV.getBool('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) + expect(tf.environment().getBool( + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) .toBe(false); }); it('disjoint query timer enabled, mobile', () => { - ENV.set('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', 1); + tf.environment().set('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', 1); spyOn(device_util, 'isMobile').and.returnValue(true); - expect(ENV.getBool('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) + expect(tf.environment().getBool( + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) .toBe(false); }); it('disjoint query timer enabled, not mobile', () => { - ENV.set('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', 1); + tf.environment().set('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', 1); spyOn(device_util, 'isMobile').and.returnValue(false); - expect(ENV.getBool('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) + expect(tf.environment().getBool( + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) .toBe(true); }); }); describe('WEBGL_SIZE_UPLOAD_UNIFORM', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('is 0 when there is no float32 bit support', () => { - ENV.set('WEBGL_RENDER_FLOAT32_ENABLED', false); - expect(ENV.getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')).toBe(0); + tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', false); + expect(tf.environment().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')).toBe(0); }); it('is > 0 when there is float32 bit support', () => { - ENV.set('WEBGL_RENDER_FLOAT32_ENABLED', true); - expect(ENV.getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')).toBeGreaterThan(0); + tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', true); + expect(tf.environment().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) + .toBeGreaterThan(0); }); }); diff --git a/tfjs-core/src/backends/webgl/glsl_version.ts b/tfjs-core/src/backends/webgl/glsl_version.ts index 9209258007d..b074c6ce08e 100644 --- a/tfjs-core/src/backends/webgl/glsl_version.ts +++ b/tfjs-core/src/backends/webgl/glsl_version.ts @@ -14,8 +14,7 @@ * limitations under the License. * ============================================================================= */ - -import {ENV} from '../../environment'; +import {environment} from '../../environment'; export type GLSL = { version: string, @@ -42,7 +41,7 @@ export function getGlslDifferences(): GLSL { let defineSpecialInf: string; let defineRound: string; - if (ENV.getNumber('WEBGL_VERSION') === 2) { + if (environment().getNumber('WEBGL_VERSION') === 2) { version = '#version 300 es'; attribute = 'in'; varyingVs = 'out'; diff --git a/tfjs-core/src/backends/webgl/gpgpu_context.ts b/tfjs-core/src/backends/webgl/gpgpu_context.ts index 33a1a931503..6e39e875e04 100644 --- a/tfjs-core/src/backends/webgl/gpgpu_context.ts +++ b/tfjs-core/src/backends/webgl/gpgpu_context.ts @@ -15,7 +15,8 @@ * ============================================================================= */ -import {ENV} from '../../environment'; +import {environment} from '../../environment'; + import {PixelData, TypedArray} from '../../types'; import * as util from '../../util'; @@ -49,7 +50,7 @@ export class GPGPUContext { private textureConfig: TextureConfig; constructor(gl?: WebGLRenderingContext) { - const glVersion = ENV.getNumber('WEBGL_VERSION'); + const glVersion = environment().getNumber('WEBGL_VERSION'); if (gl != null) { this.gl = gl; setWebGLContext(glVersion, gl); @@ -57,7 +58,7 @@ export class GPGPUContext { this.gl = getWebGLContext(glVersion); } // WebGL 2.0 enables texture floats without an extension. - if (ENV.getNumber('WEBGL_VERSION') === 1) { + if (environment().getNumber('WEBGL_VERSION') === 1) { this.textureFloatExtension = webgl_util.getExtensionOrThrow( this.gl, this.debug, 'OES_texture_float'); this.colorBufferFloatExtension = @@ -90,7 +91,7 @@ export class GPGPUContext { } private get debug(): boolean { - return ENV.getBool('DEBUG'); + return environment().getBool('DEBUG'); } public dispose() { @@ -225,7 +226,7 @@ export class GPGPUContext { let query: WebGLQuery|WebGLSync; let isFencePassed: () => boolean; - if (ENV.getBool('WEBGL_FENCE_API_ENABLED')) { + if (environment().getBool('WEBGL_FENCE_API_ENABLED')) { const gl2 = gl as WebGL2RenderingContext; const sync = gl2.fenceSync(gl2.SYNC_GPU_COMMANDS_COMPLETE, 0); @@ -239,11 +240,14 @@ export class GPGPUContext { query = sync; } else if ( - ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { + environment().getNumber( + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { query = this.beginQuery(); this.endQuery(); isFencePassed = () => this.isQueryAvailable( - query, ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); + query, + environment().getNumber( + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); } else { // If we have no way to fence, return true immediately. This will fire in // WebGL 1.0 when there is no disjoint query timer. In this case, because @@ -407,8 +411,8 @@ export class GPGPUContext { this.disjointQueryTimerExtension = webgl_util.getExtensionOrThrow( this.gl, this.debug, - ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === - 2 ? + environment().getNumber( + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ? 'EXT_disjoint_timer_query_webgl2' : 'EXT_disjoint_timer_query') as WebGL1DisjointQueryTimerExtension | @@ -426,7 +430,8 @@ export class GPGPUContext { } beginQuery(): WebGLQuery { - if (ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { + if (environment().getNumber( + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { const gl2 = this.gl as WebGL2RenderingContext; const ext = this.getQueryTimerExtensionWebGL2(); @@ -441,7 +446,8 @@ export class GPGPUContext { } endQuery() { - if (ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { + if (environment().getNumber( + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { const gl2 = this.gl as WebGL2RenderingContext; const ext = this.getQueryTimerExtensionWebGL2(); gl2.endQuery(ext.TIME_ELAPSED_EXT); @@ -458,9 +464,12 @@ export class GPGPUContext { // may poll for the query timer indefinitely this.isQueryAvailable( query, - ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'))); + environment().getNumber( + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'))); return this.getQueryTime( - query, ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); + query, + environment().getNumber( + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); } private getQueryTime(query: WebGLQuery, queryTimerVersion: number): number { diff --git a/tfjs-core/src/backends/webgl/gpgpu_context_test.ts b/tfjs-core/src/backends/webgl/gpgpu_context_test.ts index d03bae8f324..580d097d60f 100644 --- a/tfjs-core/src/backends/webgl/gpgpu_context_test.ts +++ b/tfjs-core/src/backends/webgl/gpgpu_context_test.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {ENV} from '../../environment'; +import * as tf from '../../index'; import {describeWithFlags} from '../../jasmine_util'; import {WEBGL_ENVS} from './backend_webgl_test_registry'; @@ -37,7 +37,7 @@ describeWithFlags( gpgpu = new GPGPUContext(); // Silences debug warnings. spyOn(console, 'warn'); - ENV.set('DEBUG', true); + tf.environment().set('DEBUG', true); texture = gpgpu.createFloat32MatrixTexture(1, 1); }); @@ -71,7 +71,7 @@ describeWithFlags( gpgpu = new GPGPUContext(); // Silences debug warnings. spyOn(console, 'warn'); - ENV.set('DEBUG', true); + tf.environment().set('DEBUG', true); }); afterEach(() => { @@ -109,7 +109,7 @@ describeWithFlags( gpgpu = new GPGPUContext(); // Silences debug warnings. spyOn(console, 'warn'); - ENV.set('DEBUG', true); + tf.environment().set('DEBUG', true); const glsl = getGlslDifferences(); const src = `${glsl.version} precision highp float; @@ -148,7 +148,7 @@ describeWithFlags('GPGPUContext', DOWNLOAD_FLOAT_ENVS, () => { gpgpu = new GPGPUContext(); // Silences debug warnings. spyOn(console, 'warn'); - ENV.set('DEBUG', true); + tf.environment().set('DEBUG', true); }); afterEach(() => { diff --git a/tfjs-core/src/backends/webgl/gpgpu_math.ts b/tfjs-core/src/backends/webgl/gpgpu_math.ts index a284da25500..71cfa384639 100644 --- a/tfjs-core/src/backends/webgl/gpgpu_math.ts +++ b/tfjs-core/src/backends/webgl/gpgpu_math.ts @@ -15,7 +15,8 @@ * ============================================================================= */ -import {ENV} from '../../environment'; +import {environment} from '../../environment'; + import {Tensor} from '../../tensor'; import {TypedArray} from '../../types'; import * as util from '../../util'; @@ -85,7 +86,7 @@ export function compileProgram( // Add special uniforms (NAN, INFINITY) let infLoc: WebGLUniformLocation = null; const nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false); - if (ENV.getNumber('WEBGL_VERSION') === 1) { + if (environment().getNumber('WEBGL_VERSION') === 1) { infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false); } @@ -163,7 +164,7 @@ export function runProgram( gpgpu.setProgram(binary.webGLProgram); // Set special uniforms (NAN, INFINITY) - if (ENV.getNumber('WEBGL_VERSION') === 1) { + if (environment().getNumber('WEBGL_VERSION') === 1) { if (binary.infLoc !== null) { gpgpu.gl.uniform1f(binary.infLoc, Infinity); } diff --git a/tfjs-core/src/backends/webgl/reshape_packed_test.ts b/tfjs-core/src/backends/webgl/reshape_packed_test.ts index 8117ce93fe8..2645461f185 100644 --- a/tfjs-core/src/backends/webgl/reshape_packed_test.ts +++ b/tfjs-core/src/backends/webgl/reshape_packed_test.ts @@ -65,14 +65,14 @@ describeWithFlags('expensive reshape', PACKED_ENVS, () => { describeWithFlags('expensive reshape with even columns', PACKED_ENVS, () => { it('2 --> 4 columns', async () => { - const maxTextureSize = tf.ENV.getNumber('WEBGL_MAX_TEXTURE_SIZE'); + const maxTextureSize = tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); let values: number[] = new Array(16).fill(0); values = values.map((d, i) => i + 1); const a = tf.tensor2d(values, [8, 2]); const b = tf.tensor2d([1, 2, 3, 4], [2, 2]); - tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', 2); + tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', 2); // Setting WEBGL_MAX_TEXTURE_SIZE to 2 makes that [8, 2] tensor is packed // to texture of width 2 by height 2. Indices are packed as: // ------------- @@ -82,14 +82,14 @@ describeWithFlags('expensive reshape with even columns', PACKED_ENVS, () => { // ... const c = tf.matMul(a, b); let cAs4D = c.reshape([2, 1, 2, 4]); - tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); + tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); // Execute non-packed operations to unpack tensor. - const webglPackFlagSaved = tf.ENV.getBool('WEBGL_PACK'); - tf.ENV.set('WEBGL_PACK', false); + const webglPackFlagSaved = tf.environment().getBool('WEBGL_PACK'); + tf.environment().set('WEBGL_PACK', false); cAs4D = cAs4D.add(1); cAs4D = cAs4D.add(-1); - tf.ENV.set('WEBGL_PACK', webglPackFlagSaved); + tf.environment().set('WEBGL_PACK', webglPackFlagSaved); const result = [7, 10, 15, 22, 23, 34, 31, 46, 39, 58, 47, 70, 55, 82, 63, 94]; diff --git a/tfjs-core/src/backends/webgl/tex_util.ts b/tfjs-core/src/backends/webgl/tex_util.ts index f86d568e74c..e9a43e40607 100644 --- a/tfjs-core/src/backends/webgl/tex_util.ts +++ b/tfjs-core/src/backends/webgl/tex_util.ts @@ -15,7 +15,8 @@ * ============================================================================= */ -import {ENV} from '../../environment'; +import {environment} from '../../environment'; + import {DataId, Tensor} from '../../tensor'; import {BackendValues, DataType} from '../../types'; import * as util from '../../util'; @@ -159,7 +160,7 @@ export function getTextureConfig( let textureTypeHalfFloat: number; let textureTypeFloat: number; - if (ENV.getNumber('WEBGL_VERSION') === 2) { + if (environment().getNumber('WEBGL_VERSION') === 2) { internalFormatFloat = glany.R32F; internalFormatHalfFloat = glany.R16F; internalFormatPackedHalfFloat = glany.RGBA16F; diff --git a/tfjs-core/src/backends/webgl/texture_manager.ts b/tfjs-core/src/backends/webgl/texture_manager.ts index db1ae889c7e..2512957fb27 100644 --- a/tfjs-core/src/backends/webgl/texture_manager.ts +++ b/tfjs-core/src/backends/webgl/texture_manager.ts @@ -15,7 +15,8 @@ * ============================================================================= */ -import {ENV} from '../../environment'; +import {environment} from '../../environment'; + import {GPGPUContext} from './gpgpu_context'; import {PhysicalTextureType, TextureUsage} from './tex_util'; @@ -144,7 +145,7 @@ export class TextureManager { function getPhysicalTextureForRendering(isPacked: boolean): PhysicalTextureType { - if (ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED')) { + if (environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) { if (isPacked) { return PhysicalTextureType.PACKED_2X2_FLOAT32; } diff --git a/tfjs-core/src/backends/webgl/webgl_batchnorm_test.ts b/tfjs-core/src/backends/webgl/webgl_batchnorm_test.ts index 9396129047d..078999b46d5 100644 --- a/tfjs-core/src/backends/webgl/webgl_batchnorm_test.ts +++ b/tfjs-core/src/backends/webgl/webgl_batchnorm_test.ts @@ -32,8 +32,8 @@ describeWithFlags('batchNorm', WEBGL_ENVS, () => { }); it('should work when squarification results in zero padding', async () => { - const maxTextureSize = tf.ENV.getNumber('WEBGL_MAX_TEXTURE_SIZE'); - tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', 5); + const maxTextureSize = tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', 5); const x = tf.tensor3d( [ @@ -52,7 +52,7 @@ describeWithFlags('batchNorm', WEBGL_ENVS, () => { const result = tf.batchNorm3d(x, mean, variance, offset, scale, varianceEpsilon); - tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); + tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); expectArraysClose(await result.data(), [ 0.59352049, -0.66135202, 0.5610874, -0.92077015, -1.45341019, 1.52106473, diff --git a/tfjs-core/src/backends/webgl/webgl_ops_test.ts b/tfjs-core/src/backends/webgl/webgl_ops_test.ts index 66af03335e6..673477dd0c2 100644 --- a/tfjs-core/src/backends/webgl/webgl_ops_test.ts +++ b/tfjs-core/src/backends/webgl/webgl_ops_test.ts @@ -281,15 +281,15 @@ describeWithFlags('depthToSpace', WEBGL_ENVS, () => { describeWithFlags('maximum', WEBGL_ENVS, () => { it('works with squarification for large dimension', async () => { - const maxTextureSize = tf.ENV.getNumber('WEBGL_MAX_TEXTURE_SIZE'); - tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', 5); + const maxTextureSize = tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', 5); const a = tf.tensor2d([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [2, 7]); const b = tf.tensor2d([-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 7]); const result = tf.maximum(a, b); - tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); + tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); expectArraysClose( await result.data(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]); }); @@ -334,19 +334,19 @@ describeWithFlags('conv2d webgl', WEBGL_ENVS, () => { const x = tf.tensor3d([1, 2, 3, 4], inputShape); const w = tf.tensor4d([1, 2, 3, 4], [fSize, fSize, 2, 2]); - const webglLazilyUnpackFlagSaved = tf.ENV.getBool('WEBGL_LAZILY_UNPACK'); - tf.ENV.set('WEBGL_LAZILY_UNPACK', true); + const webglLazilyUnpackFlagSaved = tf.environment().getBool('WEBGL_LAZILY_UNPACK'); + tf.environment().set('WEBGL_LAZILY_UNPACK', true); const webglPackBinaryOperationsFlagSaved = - tf.ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS', true); + tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', true); // First conv2D tests conv2D with non-packed input |x|, and the second uses // packed input |result|. const result = tf.conv2d(x, w, stride, pad); const result1 = tf.conv2d(result, w, stride, pad); - tf.ENV.set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); - tf.ENV.set( + tf.environment().set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); + tf.environment().set( 'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsFlagSaved); expectArraysClose(await result.data(), [7, 10, 15, 22]); @@ -363,17 +363,17 @@ describeWithFlags('conv2d webgl', WEBGL_ENVS, () => { const xInit = tf.tensor4d([0, 1], inputShape); const w = tf.tensor4d([1, 2, 3, 4], [fSize, fSize, 2, 2]); - const webglLazilyUnpackFlagSaved = tf.ENV.getBool('WEBGL_LAZILY_UNPACK'); - tf.ENV.set('WEBGL_LAZILY_UNPACK', true); + const webglLazilyUnpackFlagSaved = tf.environment().getBool('WEBGL_LAZILY_UNPACK'); + tf.environment().set('WEBGL_LAZILY_UNPACK', true); const webglPackBinaryOperationsFlagSaved = - tf.ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS', true); + tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', true); const x = xInit.add(1); const result = tf.conv2d(x, w, stride, pad); - tf.ENV.set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); - tf.ENV.set( + tf.environment().set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); + tf.environment().set( 'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsFlagSaved); expectArraysClose(await result.data(), [7, 10]); @@ -504,28 +504,28 @@ describeWithFlags('matmul', PACKED_ENVS, () => { }); it('should work when input texture shapes != physical shape', async () => { - const maxTextureSize = tf.ENV.getNumber('WEBGL_MAX_TEXTURE_SIZE'); - tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', 5); + const maxTextureSize = tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', 5); const a = tf.tensor2d([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [1, 12]); const b = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [12, 1]); const c = tf.matMul(a, b); - tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); + tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); expectArraysClose(await c.data(), [572]); }); it('should work when squarification results in zero padding', async () => { - const maxTextureSize = tf.ENV.getNumber('WEBGL_MAX_TEXTURE_SIZE'); - tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', 3); + const maxTextureSize = tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', 3); const a = tf.tensor2d([1, 2], [1, 2]); const b = tf.tensor2d( [[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16, 17]]); const c = tf.matMul(a, b); - tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); + tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); expectArraysClose(await c.data(), [18, 21, 24, 27, 30, 33, 36, 39, 42]); }); @@ -582,10 +582,10 @@ describeWithFlags('matmul', PACKED_ENVS, () => { const c = tf.matMul(a, b); - const webglPackBinarySaved = tf.ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS', false); + const webglPackBinarySaved = tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', false); const d = tf.add(c, 1); - tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS', webglPackBinarySaved); + tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', webglPackBinarySaved); expectArraysClose(await d.data(), [1, 9, -2, 21]); }); @@ -600,10 +600,10 @@ describeWithFlags('matmul', PACKED_ENVS, () => { const d = tf.reshape(c, [1, 3, 3, 1]); const webglPackBinarySaved = - tf.ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS', false); + tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', false); const e = tf.add(d, 1); - tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS', webglPackBinarySaved); + tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', webglPackBinarySaved); expectArraysClose(await e.data(), [2, 3, 4, 5, 6, 7, 8, 9, 10]); }); @@ -621,16 +621,16 @@ describeWithFlags('matmul', PACKED_ENVS, () => { describeWithFlags('Reduction: webgl packed input', WEBGL_ENVS, () => { it('argmax 3D, odd number of rows, axis = -1', async () => { - const webglLazilyUnpackFlagSaved = tf.ENV.getBool('WEBGL_LAZILY_UNPACK'); - tf.ENV.set('WEBGL_LAZILY_UNPACK', true); + const webglLazilyUnpackFlagSaved = tf.environment().getBool('WEBGL_LAZILY_UNPACK'); + tf.environment().set('WEBGL_LAZILY_UNPACK', true); const webglPackBinaryOperationsFlagSaved = - tf.ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS', true); + tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', true); const a = tf.tensor3d([3, 2, 5, 100, -7, 2], [2, 1, 3]).add(1); const r = tf.argMax(a, -1); - tf.ENV.set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); - tf.ENV.set( + tf.environment().set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); + tf.environment().set( 'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsFlagSaved); expect(r.dtype).toBe('int32'); @@ -638,11 +638,11 @@ describeWithFlags('Reduction: webgl packed input', WEBGL_ENVS, () => { }); it('argmin 4D, odd number of rows, axis = -1', async () => { - const webglLazilyUnpackFlagSaved = tf.ENV.getBool('WEBGL_LAZILY_UNPACK'); - tf.ENV.set('WEBGL_LAZILY_UNPACK', true); + const webglLazilyUnpackFlagSaved = tf.environment().getBool('WEBGL_LAZILY_UNPACK'); + tf.environment().set('WEBGL_LAZILY_UNPACK', true); const webglPackBinaryOperationsFlagSaved = - tf.ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS', true); + tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', true); const a = tf.tensor4d( @@ -650,8 +650,8 @@ describeWithFlags('Reduction: webgl packed input', WEBGL_ENVS, () => { [1, 2, 3, 3]) .add(1); const r = tf.argMin(a, -1); - tf.ENV.set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); - tf.ENV.set( + tf.environment().set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); + tf.environment().set( 'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsFlagSaved); expect(r.dtype).toBe('int32'); @@ -660,8 +660,8 @@ describeWithFlags('Reduction: webgl packed input', WEBGL_ENVS, () => { it('should not leak memory when called after unpacked op', async () => { const webglPackBinaryOperationsFlagSaved = - tf.ENV.getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.ENV.set('WEBGL_PACK_BINARY_OPERATIONS', false); + tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', false); const a = tf.tensor5d( @@ -671,7 +671,7 @@ describeWithFlags('Reduction: webgl packed input', WEBGL_ENVS, () => { const startNumBytes = tf.memory().numBytes; const startNumTensors = tf.memory().numTensors; const r = tf.argMin(a, -1); - tf.ENV.set( + tf.environment().set( 'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsFlagSaved); const endNumBytes = tf.memory().numBytes; const endNumTensors = tf.memory().numTensors; @@ -684,8 +684,8 @@ describeWithFlags('Reduction: webgl packed input', WEBGL_ENVS, () => { describeWithFlags('slice and memory usage', WEBGL_ENVS, () => { beforeAll(() => { - tf.ENV.set('WEBGL_CPU_FORWARD', false); - tf.ENV.set('WEBGL_SIZE_UPLOAD_UNIFORM', 0); + tf.environment().set('WEBGL_CPU_FORWARD', false); + tf.environment().set('WEBGL_SIZE_UPLOAD_UNIFORM', 0); }); it('slice a tensor, read it and check memory', async () => { @@ -722,7 +722,7 @@ describeWithFlags('slice and memory usage', WEBGL_ENVS, () => { describeWithFlags('slice a packed texture', WEBGL_ENVS, () => { beforeAll(() => { - tf.ENV.set('WEBGL_PACK', true); + tf.environment().set('WEBGL_PACK', true); }); it('slice after a matmul', async () => { @@ -741,12 +741,12 @@ describeWithFlags('slice a packed texture', WEBGL_ENVS, () => { describeWithFlags('relu', WEBGL_ENVS, () => { it('works with squarification for prime number length vector', async () => { - const maxTextureSize = tf.ENV.getNumber('WEBGL_MAX_TEXTURE_SIZE'); - tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', 5); + const maxTextureSize = tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', 5); const a = tf.tensor1d([1, -2, 5, -3, -1, 4, 7]); const result = tf.relu(a); - tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); + tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); expectArraysClose(await result.data(), [1, 0, 5, 0, 0, 4, 7]); }); }); diff --git a/tfjs-core/src/backends/webgl/webgl_util.ts b/tfjs-core/src/backends/webgl/webgl_util.ts index ea5a27d445a..2c51b052ea9 100644 --- a/tfjs-core/src/backends/webgl/webgl_util.ts +++ b/tfjs-core/src/backends/webgl/webgl_util.ts @@ -15,7 +15,8 @@ * ============================================================================= */ -import {ENV} from '../../environment'; +import {environment} from '../../environment'; + import * as util from '../../util'; import {getWebGLContext} from './canvas_util'; @@ -42,7 +43,7 @@ const MIN_FLOAT16 = 5.96e-8; const MAX_FLOAT16 = 65504; export function canBeRepresented(num: number): boolean { - if (ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 || + if (environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 || (MIN_FLOAT16 < Math.abs(num) && Math.abs(num) < MAX_FLOAT16)) { return true; } @@ -192,7 +193,7 @@ export function createStaticIndexBuffer( } export function getNumChannels(): number { - if (ENV.getNumber('WEBGL_VERSION') === 2) { + if (environment().getNumber('WEBGL_VERSION') === 2) { return 1; } return 4; @@ -205,7 +206,7 @@ export function createTexture( } export function validateTextureSize(width: number, height: number) { - const maxTextureSize = ENV.getNumber('WEBGL_MAX_TEXTURE_SIZE'); + const maxTextureSize = environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); if ((width <= 0) || (height <= 0)) { const requested = `[${width}x${height}]`; throw new Error('Requested texture size ' + requested + ' is invalid.'); @@ -384,7 +385,7 @@ export function getShapeAs3D(shape: number[]): [number, number, number] { export function getTextureShapeFromLogicalShape( logShape: number[], isPacked = false): [number, number] { - let maxTexSize = ENV.getNumber('WEBGL_MAX_TEXTURE_SIZE'); + let maxTexSize = environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); if (isPacked) { maxTexSize = maxTexSize * 2; diff --git a/tfjs-core/src/backends/webgl/webgl_util_test.ts b/tfjs-core/src/backends/webgl/webgl_util_test.ts index 2e039e26ba1..893245b0689 100644 --- a/tfjs-core/src/backends/webgl/webgl_util_test.ts +++ b/tfjs-core/src/backends/webgl/webgl_util_test.ts @@ -84,7 +84,9 @@ describeWithFlags('getTextureShapeFromLogicalShape packed', WEBGL_ENVS, () => { it('textures less than 2x max size of platform preserve their shapes', () => { const isPacked = true; const logicalShape = [ - 2, util.nearestLargerEven(tf.ENV.getNumber('WEBGL_MAX_TEXTURE_SIZE') + 1) + 2, + util.nearestLargerEven( + tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE') + 1) ]; const texShape = webgl_util.getTextureShapeFromLogicalShape(logicalShape, isPacked); @@ -101,14 +103,14 @@ describeWithFlags('getTextureShapeFromLogicalShape packed', WEBGL_ENVS, () => { it('squarified texture shapes account for packing constraints', () => { const isPacked = true; - const max = tf.ENV.getNumber('WEBGL_MAX_TEXTURE_SIZE'); + const max = tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); - tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', 5); + tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', 5); const logicalShape = [1, 12]; const texShape = webgl_util.getTextureShapeFromLogicalShape(logicalShape, isPacked); - tf.ENV.set('WEBGL_MAX_TEXTURE_SIZE', max); + tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', max); expect(texShape).toEqual([6, 4]); }); }); diff --git a/tfjs-core/src/debug_mode_test.ts b/tfjs-core/src/debug_mode_test.ts index 605d31ff931..5b9fdf6ea06 100644 --- a/tfjs-core/src/debug_mode_test.ts +++ b/tfjs-core/src/debug_mode_test.ts @@ -22,11 +22,11 @@ import {expectArraysClose} from './test_util'; describeWithFlags('debug on', SYNC_BACKEND_ENVS, () => { beforeAll(() => { - tf.ENV.set('DEBUG', true); + tf.environment().set('DEBUG', true); }); afterAll(() => { - tf.ENV.set('DEBUG', false); + tf.environment().set('DEBUG', false); }); it('debug mode does not error when no nans', async () => { @@ -117,7 +117,7 @@ describeWithFlags('debug on', SYNC_BACKEND_ENVS, () => { describeWithFlags('debug off', ALL_ENVS, () => { beforeAll(() => { - tf.ENV.set('DEBUG', false); + tf.environment().set('DEBUG', false); }); it('no errors where there are nans, and debug mode is disabled', async () => { diff --git a/tfjs-core/src/environment.ts b/tfjs-core/src/environment.ts index 990f89e879f..a926599519f 100644 --- a/tfjs-core/src/environment.ts +++ b/tfjs-core/src/environment.ts @@ -15,14 +15,6 @@ * ============================================================================= */ -// This incantation makes Closure think that exported symbols are mutable. -// Mutable file-level exports are disallowed per style and won't reliably -// work. This hack also has a cost in terms of code size, and is only used -// to preserve the preexisting behavior of this code. -// tslint:disable-next-line:ban-ts-ignore see above -// @ts-ignore -exports = {}; - import {Platform} from './platforms/platform'; // Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true. @@ -174,6 +166,10 @@ function parseValue(flagName: string, value: string): FlagValue { `Could not parse value flag value ${value} for flag ${flagName}.`); } +export function environment() { + return ENV; +} + export let ENV: Environment = null; export function setEnvironmentGlobal(environment: Environment) { ENV = environment; diff --git a/tfjs-core/src/flags.ts b/tfjs-core/src/flags.ts index 1de5d473f17..99af479a6ee 100644 --- a/tfjs-core/src/flags.ts +++ b/tfjs-core/src/flags.ts @@ -15,7 +15,9 @@ * ============================================================================= */ import * as device_util from './device_util'; -import {ENV} from './environment'; +import {environment} from './environment'; + +const ENV = environment(); /** * This file contains environment-related flag registrations. diff --git a/tfjs-core/src/flags_test.ts b/tfjs-core/src/flags_test.ts index f62e2dea24e..47964f13929 100644 --- a/tfjs-core/src/flags_test.ts +++ b/tfjs-core/src/flags_test.ts @@ -16,28 +16,28 @@ */ import * as device_util from './device_util'; -import {ENV} from './environment'; +import * as tf from './index'; describe('DEBUG', () => { beforeEach(() => { - ENV.reset(); + tf.environment().reset(); spyOn(console, 'warn').and.callFake((msg: string) => {}); }); - afterAll(() => ENV.reset()); + afterAll(() => tf.environment().reset()); it('disabled by default', () => { - expect(ENV.getBool('DEBUG')).toBe(false); + expect(tf.environment().getBool('DEBUG')).toBe(false); }); it('warns when enabled', () => { const consoleWarnSpy = console.warn as jasmine.Spy; - ENV.set('DEBUG', true); + tf.environment().set('DEBUG', true); expect(consoleWarnSpy.calls.count()).toBe(1); expect((consoleWarnSpy.calls.first().args[0] as string) .startsWith('Debugging mode is ON. ')) .toBe(true); - expect(ENV.getBool('DEBUG')).toBe(true); + expect(tf.environment().getBool('DEBUG')).toBe(true); expect(consoleWarnSpy.calls.count()).toBe(1); }); }); @@ -45,60 +45,62 @@ describe('DEBUG', () => { describe('IS_BROWSER', () => { let isBrowser: boolean; beforeEach(() => { - ENV.reset(); + tf.environment().reset(); spyOn(device_util, 'isBrowser').and.callFake(() => isBrowser); }); - afterAll(() => ENV.reset()); + afterAll(() => tf.environment().reset()); it('isBrowser: true', () => { isBrowser = true; - expect(ENV.getBool('IS_BROWSER')).toBe(true); + expect(tf.environment().getBool('IS_BROWSER')).toBe(true); }); it('isBrowser: false', () => { isBrowser = false; - expect(ENV.getBool('IS_BROWSER')).toBe(false); + expect(tf.environment().getBool('IS_BROWSER')).toBe(false); }); }); describe('PROD', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('disabled by default', () => { - expect(ENV.getBool('PROD')).toBe(false); + expect(tf.environment().getBool('PROD')).toBe(false); }); }); describe('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('disabled when debug is disabled', () => { - ENV.set('DEBUG', false); - expect(ENV.getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')).toBe(false); + tf.environment().set('DEBUG', false); + expect(tf.environment().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) + .toBe(false); }); it('enabled when debug is enabled', () => { - ENV.set('DEBUG', true); - expect(ENV.getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')).toBe(true); + tf.environment().set('DEBUG', true); + expect(tf.environment().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) + .toBe(true); }); }); describe('DEPRECATION_WARNINGS_ENABLED', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('enabled by default', () => { - expect(ENV.getBool('DEPRECATION_WARNINGS_ENABLED')).toBe(true); + expect(tf.environment().getBool('DEPRECATION_WARNINGS_ENABLED')).toBe(true); }); }); describe('IS_TEST', () => { - beforeEach(() => ENV.reset()); - afterAll(() => ENV.reset()); + beforeEach(() => tf.environment().reset()); + afterAll(() => tf.environment().reset()); it('disabled by default', () => { - expect(ENV.getBool('IS_TEST')).toBe(false); + expect(tf.environment().getBool('IS_TEST')).toBe(false); }); }); diff --git a/tfjs-core/src/globals.ts b/tfjs-core/src/globals.ts index b7a26185b06..207b6c9c173 100644 --- a/tfjs-core/src/globals.ts +++ b/tfjs-core/src/globals.ts @@ -17,7 +17,8 @@ import {KernelBackend} from './backends/backend'; import {ENGINE, Engine, MemoryInfo, ProfileInfo, ScopeFn, TimingInfo} from './engine'; -import {ENV} from './environment'; +import {environment} from './environment'; + import {Platform} from './platforms/platform'; import {setDeprecationWarningFn, Tensor} from './tensor'; import {TensorContainer} from './tensor_types'; @@ -29,7 +30,7 @@ import {getTensorsInContainer} from './tensor_util'; */ /** @doc {heading: 'Environment'} */ export function enableProdMode(): void { - ENV.set('PROD', true); + environment().set('PROD', true); } /** @@ -46,18 +47,18 @@ export function enableProdMode(): void { */ /** @doc {heading: 'Environment'} */ export function enableDebugMode(): void { - ENV.set('DEBUG', true); + environment().set('DEBUG', true); } /** Globally disables deprecation warnings */ export function disableDeprecationWarnings(): void { - ENV.set('DEPRECATION_WARNINGS_ENABLED', false); + environment().set('DEPRECATION_WARNINGS_ENABLED', false); console.warn(`TensorFlow.js deprecation warnings have been disabled.`); } /** Warn users about deprecated functionality. */ export function deprecationWarn(msg: string) { - if (ENV.getBool('DEPRECATION_WARNINGS_ENABLED')) { + if (environment().getBool('DEPRECATION_WARNINGS_ENABLED')) { console.warn( msg + ' You can disable deprecation warnings with ' + 'tf.disableDeprecationWarnings().'); @@ -354,5 +355,5 @@ export function backend(): KernelBackend { * @param platform A platform implementation. */ export function setPlatform(platformName: string, platform: Platform) { - ENV.setPlatform(platformName, platform); + environment().setPlatform(platformName, platform); } diff --git a/tfjs-core/src/globals_test.ts b/tfjs-core/src/globals_test.ts index 5e7d9f9f7d2..dfa67ff9b69 100644 --- a/tfjs-core/src/globals_test.ts +++ b/tfjs-core/src/globals_test.ts @@ -15,7 +15,6 @@ * ============================================================================= */ -import {ENV} from './environment'; import * as tf from './index'; import {ALL_ENVS, describeWithFlags, NODE_ENVS} from './jasmine_util'; import {expectArraysClose} from './test_util'; @@ -49,17 +48,17 @@ describe('deprecation warnings', () => { describe('Flag flipping methods', () => { beforeEach(() => { - ENV.reset(); + tf.environment().reset(); }); it('tf.enableProdMode', () => { tf.enableProdMode(); - expect(ENV.getBool('PROD')).toBe(true); + expect(tf.environment().getBool('PROD')).toBe(true); }); it('tf.enableDebugMode', () => { tf.enableDebugMode(); - expect(ENV.getBool('DEBUG')).toBe(true); + expect(tf.environment().getBool('DEBUG')).toBe(true); }); }); diff --git a/tfjs-core/src/index.ts b/tfjs-core/src/index.ts index 4a67212d042..c39af44d3f5 100644 --- a/tfjs-core/src/index.ts +++ b/tfjs-core/src/index.ts @@ -30,7 +30,6 @@ import './platforms/platform_browser'; import './platforms/platform_node'; import * as backend_util from './backends/backend_util'; -import * as environment from './environment'; // Serialization. import * as io from './io/io'; import * as math from './math'; @@ -66,7 +65,7 @@ export * from './globals'; export {customGrad, grad, grads, valueAndGrad, valueAndGrads, variableGrads} from './gradients'; export {TimingInfo, MemoryInfo} from './engine'; -export {ENV, Environment} from './environment'; +export {Environment, environment, ENV} from './environment'; export {Platform} from './platforms/platform'; export {version as version_core}; @@ -77,7 +76,6 @@ export {nextFrame} from './browser_util'; // Second level exports. export { browser, - environment, io, math, serialization, diff --git a/tfjs-core/src/io/browser_files.ts b/tfjs-core/src/io/browser_files.ts index 9912fa74f32..a5d6dbcd383 100644 --- a/tfjs-core/src/io/browser_files.ts +++ b/tfjs-core/src/io/browser_files.ts @@ -20,7 +20,8 @@ * user-selected files in browser. */ -import {ENV} from '../environment'; +import {environment} from '../environment'; + import {basename, concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils'; import {IORouter, IORouterRegistry} from './router_registry'; import {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types'; @@ -42,7 +43,7 @@ export class BrowserDownloads implements IOHandler { static readonly URL_SCHEME = 'downloads://'; constructor(fileNamePrefix?: string) { - if (!ENV.getBool('IS_BROWSER')) { + if (!environment().getBool('IS_BROWSER')) { // TODO(cais): Provide info on what IOHandlers are available under the // current environment. throw new Error( @@ -244,7 +245,7 @@ class BrowserFiles implements IOHandler { } export const browserDownloadsRouter: IORouter = (url: string|string[]) => { - if (!ENV.getBool('IS_BROWSER')) { + if (!environment().getBool('IS_BROWSER')) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) { diff --git a/tfjs-core/src/io/http.ts b/tfjs-core/src/io/http.ts index 83ce0d9a3d1..8266d5650e2 100644 --- a/tfjs-core/src/io/http.ts +++ b/tfjs-core/src/io/http.ts @@ -21,7 +21,8 @@ * Uses [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API). */ -import {ENV} from '../environment'; +import {environment} from '../environment'; + import {assert} from '../util'; import {concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils'; import {IORouter, IORouterRegistry} from './router_registry'; @@ -58,7 +59,7 @@ export class HTTPRequest implements IOHandler { 'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)'); this.fetch = loadOptions.fetchFunc; } else { - this.fetch = ENV.platform.fetch; + this.fetch = environment().platform.fetch; } assert( diff --git a/tfjs-core/src/io/http_test.ts b/tfjs-core/src/io/http_test.ts index a0e6fc84019..6d7c051896e 100644 --- a/tfjs-core/src/io/http_test.ts +++ b/tfjs-core/src/io/http_test.ts @@ -81,7 +81,7 @@ const setupFakeWeightFiles = } }, requestInits: {[key: string]: RequestInit}) => { - fetchSpy = spyOn(tf.ENV.platform, 'fetch') + fetchSpy = spyOn(tf.environment().platform, 'fetch') .and.callFake((path: string, init: RequestInit) => { if (fileBufferMap[path]) { requestInits[path] = init; @@ -191,7 +191,7 @@ describeWithFlags('http-save', CHROME_ENVS, () => { beforeEach(() => { requestInits = []; - spyOn(tf.ENV.platform, 'fetch') + spyOn(tf.environment().platform, 'fetch') .and.callFake((path: string, init: RequestInit) => { if (path === 'model-upload-test' || path === 'http://model-upload-test') { diff --git a/tfjs-core/src/io/indexed_db.ts b/tfjs-core/src/io/indexed_db.ts index a2ec568f644..6e4936931c5 100644 --- a/tfjs-core/src/io/indexed_db.ts +++ b/tfjs-core/src/io/indexed_db.ts @@ -15,7 +15,8 @@ * ============================================================================= */ -import {ENV} from '../environment'; +import {environment} from '../environment'; + import {getModelArtifactsInfoForJSON} from './io_utils'; import {ModelStoreManagerRegistry} from './model_management'; import {IORouter, IORouterRegistry} from './router_registry'; @@ -47,7 +48,7 @@ export async function deleteDatabase(): Promise { } function getIndexedDBFactory(): IDBFactory { - if (!ENV.getBool('IS_BROWSER')) { + if (!environment().getBool('IS_BROWSER')) { // TODO(cais): Add more info about what IOHandler subtypes are available. // Maybe point to a doc page on the web and/or automatically determine // the available IOHandlers and print them in the error message. @@ -207,7 +208,7 @@ export class BrowserIndexedDB implements IOHandler { } export const indexedDBRouter: IORouter = (url: string|string[]) => { - if (!ENV.getBool('IS_BROWSER')) { + if (!environment().getBool('IS_BROWSER')) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) { @@ -351,7 +352,7 @@ export class BrowserIndexedDBManager implements ModelStoreManager { } } -if (ENV.getBool('IS_BROWSER')) { +if (environment().getBool('IS_BROWSER')) { // Wrap the construction and registration, to guard against browsers that // don't support Local Storage. try { diff --git a/tfjs-core/src/io/local_storage.ts b/tfjs-core/src/io/local_storage.ts index 67e7c97e60c..fb4ce766f18 100644 --- a/tfjs-core/src/io/local_storage.ts +++ b/tfjs-core/src/io/local_storage.ts @@ -15,7 +15,8 @@ * ============================================================================= */ -import {ENV} from '../environment'; +import {environment} from '../environment'; + import {assert} from '../util'; import {arrayBufferToBase64String, base64StringToArrayBuffer, getModelArtifactsInfoForJSON} from './io_utils'; import {ModelStoreManagerRegistry} from './model_management'; @@ -36,7 +37,7 @@ const MODEL_METADATA_SUFFIX = 'model_metadata'; * @returns Paths of the models purged. */ export function purgeLocalStorageArtifacts(): string[] { - if (!ENV.getBool('IS_BROWSER') || + if (!environment().getBool('IS_BROWSER') || typeof window.localStorage === 'undefined') { throw new Error( 'purgeLocalStorageModels() cannot proceed because local storage is ' + @@ -117,7 +118,7 @@ export class BrowserLocalStorage implements IOHandler { static readonly URL_SCHEME = 'localstorage://'; constructor(modelPath: string) { - if (!ENV.getBool('IS_BROWSER') || + if (!environment().getBool('IS_BROWSER') || typeof window.localStorage === 'undefined') { // TODO(cais): Add more info about what IOHandler subtypes are // available. @@ -255,7 +256,7 @@ export class BrowserLocalStorage implements IOHandler { } export const localStorageRouter: IORouter = (url: string|string[]) => { - if (!ENV.getBool('IS_BROWSER')) { + if (!environment().getBool('IS_BROWSER')) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) { @@ -302,7 +303,7 @@ export class BrowserLocalStorageManager implements ModelStoreManager { constructor() { assert( - ENV.getBool('IS_BROWSER'), + environment().getBool('IS_BROWSER'), () => 'Current environment is not a web browser'); assert( typeof window.localStorage !== 'undefined', @@ -340,7 +341,7 @@ export class BrowserLocalStorageManager implements ModelStoreManager { } } -if (ENV.getBool('IS_BROWSER')) { +if (environment().getBool('IS_BROWSER')) { // Wrap the construction and registration, to guard against browsers that // don't support Local Storage. try { diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index ae5c3b0888b..51c61cab883 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -15,10 +15,10 @@ * ============================================================================= */ -import {ENV} from '../environment'; +import {environment} from '../environment'; + import {NamedTensorMap} from '../tensor_types'; import * as util from '../util'; - import {decodeWeights} from './io_utils'; import {monitorPromisesProgress} from './progress'; import {DTYPE_VALUE_SIZE_MAP, LoadOptions, WeightsManifestConfig, WeightsManifestEntry} from './types'; @@ -40,8 +40,9 @@ export async function loadWeightsAsArrayBuffer( loadOptions = {}; } - const fetchFunc = loadOptions.fetchFunc == null ? ENV.platform.fetch : - loadOptions.fetchFunc; + const fetchFunc = loadOptions.fetchFunc == null ? + environment().platform.fetch : + loadOptions.fetchFunc; // Create the requests for all of the weights in parallel. const requests = fetchURLs.map( diff --git a/tfjs-core/src/io/weights_loader_test.ts b/tfjs-core/src/io/weights_loader_test.ts index 4ec56c79df6..ed268377f2e 100644 --- a/tfjs-core/src/io/weights_loader_test.ts +++ b/tfjs-core/src/io/weights_loader_test.ts @@ -24,7 +24,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { [filename: string]: Float32Array|Int32Array|ArrayBuffer|Uint8Array| Uint16Array }) => { - spyOn(tf.ENV.platform, 'fetch').and.callFake((path: string) => { + spyOn(tf.environment().platform, 'fetch').and.callFake((path: string) => { return new Response( fileBufferMap[path], {headers: {'Content-type': 'application/octet-stream'}}); @@ -42,7 +42,8 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { const weightsNamesToFetch = ['weight0']; const weights = await tf.io.loadWeights(manifest, './', weightsNamesToFetch); - expect((tf.ENV.platform.fetch as jasmine.Spy).calls.count()).toBe(1); + expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) + .toBe(1); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(weightsNamesToFetch.length); @@ -66,7 +67,8 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { // Load the first weight. const weights = await tf.io.loadWeights(manifest, './', ['weight0']); - expect((tf.ENV.platform.fetch as jasmine.Spy).calls.count()).toBe(1); + expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) + .toBe(1); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(1); @@ -90,7 +92,8 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { // Load the second weight. const weights = await tf.io.loadWeights(manifest, './', ['weight1']); - expect((tf.ENV.platform.fetch as jasmine.Spy).calls.count()).toBe(1); + expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) + .toBe(1); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(1); @@ -115,7 +118,8 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { // Load all weights. const weights = await tf.io.loadWeights(manifest, './', ['weight0', 'weight1']); - expect((tf.ENV.platform.fetch as jasmine.Spy).calls.count()).toBe(1); + expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) + .toBe(1); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(2); @@ -154,7 +158,8 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { // Load all weights. const weights = await tf.io.loadWeights( manifest, './', ['weight0', 'weight1', 'weight2']); - expect((tf.ENV.platform.fetch as jasmine.Spy).calls.count()).toBe(1); + expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) + .toBe(1); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(3); @@ -192,7 +197,8 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { }]; const weights = await tf.io.loadWeights(manifest, './', ['weight0']); - expect((tf.ENV.platform.fetch as jasmine.Spy).calls.count()).toBe(3); + expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) + .toBe(3); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(1); @@ -232,7 +238,8 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { const weights = await tf.io.loadWeights(manifest, './', ['weight0', 'weight1']); - expect((tf.ENV.platform.fetch as jasmine.Spy).calls.count()).toBe(3); + expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) + .toBe(3); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(2); @@ -274,7 +281,8 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { const weights = await tf.io.loadWeights(manifest, './', ['weight0', 'weight1']); // Only the first group should be fetched. - expect((tf.ENV.platform.fetch as jasmine.Spy).calls.count()).toBe(1); + expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) + .toBe(1); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(2); @@ -316,7 +324,8 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { const weights = await tf.io.loadWeights(manifest, './', ['weight0', 'weight2']); // Both groups need to be fetched. - expect((tf.ENV.platform.fetch as jasmine.Spy).calls.count()).toBe(2); + expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) + .toBe(2); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(2); @@ -358,7 +367,8 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { // Don't pass a third argument to loadWeights to load all weights. const weights = await tf.io.loadWeights(manifest, './'); // Both groups need to be fetched. - expect((tf.ENV.platform.fetch as jasmine.Spy).calls.count()).toBe(2); + expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) + .toBe(2); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(4); @@ -434,8 +444,9 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { const weightsNamesToFetch = ['weight0']; await tf.io.loadWeights( manifest, './', weightsNamesToFetch, {credentials: 'include'}); - expect((tf.ENV.platform.fetch as jasmine.Spy).calls.count()).toBe(1); - expect(tf.ENV.platform.fetch) + expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) + .toBe(1); + expect(tf.environment().platform.fetch) .toHaveBeenCalledWith( './weightfile0', {credentials: 'include'}, {isBinary: true}); }); @@ -466,7 +477,8 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { const weightsNamesToFetch = ['weight0', 'weight1']; const weights = await tf.io.loadWeights(manifest, './', weightsNamesToFetch); - expect((tf.ENV.platform.fetch as jasmine.Spy).calls.count()).toBe(1); + expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) + .toBe(1); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(weightsNamesToFetch.length); @@ -526,7 +538,8 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { const weights = await tf.io.loadWeights(manifest, './', ['weight0', 'weight2']); // Both groups need to be fetched. - expect((tf.ENV.platform.fetch as jasmine.Spy).calls.count()).toBe(2); + expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) + .toBe(2); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(2); diff --git a/tfjs-core/src/jasmine_util.ts b/tfjs-core/src/jasmine_util.ts index 522b2d2522a..9264fe8b6ec 100644 --- a/tfjs-core/src/jasmine_util.ts +++ b/tfjs-core/src/jasmine_util.ts @@ -16,7 +16,7 @@ */ import {KernelBackend} from './backends/backend'; import {ENGINE} from './engine'; -import {ENV, Environment, Flags} from './environment'; +import {Environment, environment, Flags} from './environment'; Error.stackTraceLimit = Infinity; @@ -26,13 +26,13 @@ export type Constraints = { }; export const NODE_ENVS: Constraints = { - predicate: () => ENV.platformName === 'node' + predicate: () => environment().platformName === 'node' }; export const CHROME_ENVS: Constraints = { flags: {'IS_CHROME': true} }; export const BROWSER_ENVS: Constraints = { - predicate: () => ENV.platformName === 'browser' + predicate: () => environment().platformName === 'browser' }; export const SYNC_BACKEND_ENVS: Constraints = { @@ -131,8 +131,8 @@ export function describeWithFlags( } TEST_ENVS.forEach(testEnv => { - ENV.setFlags(testEnv.flags); - if (envSatisfiesConstraints(ENV, testEnv, constraints)) { + environment().setFlags(testEnv.flags); + if (envSatisfiesConstraints(environment(), testEnv, constraints)) { const testName = name + ' ' + testEnv.name + ' ' + JSON.stringify(testEnv.flags); executeTests(testName, tests, testEnv); @@ -174,9 +174,9 @@ function executeTests( beforeAll(async () => { ENGINE.reset(); if (testEnv.flags != null) { - ENV.setFlags(testEnv.flags); + environment().setFlags(testEnv.flags); } - ENV.set('IS_TEST', true); + environment().set('IS_TEST', true); // Await setting the new backend since it can have async init. await ENGINE.setBackend(testEnv.backendName); }); diff --git a/tfjs-core/src/log.ts b/tfjs-core/src/log.ts index 964778d140b..4621a77828e 100644 --- a/tfjs-core/src/log.ts +++ b/tfjs-core/src/log.ts @@ -15,16 +15,16 @@ * ============================================================================= */ -import {ENV} from './environment'; +import {environment} from './environment'; export function warn(...msg: Array<{}>): void { - if (!ENV.getBool('IS_TEST')) { + if (!environment().getBool('IS_TEST')) { console.warn(...msg); } } export function log(...msg: Array<{}>): void { - if (!ENV.getBool('IS_TEST')) { + if (!environment().getBool('IS_TEST')) { console.log(...msg); } } diff --git a/tfjs-core/src/ops/slice_test.ts b/tfjs-core/src/ops/slice_test.ts index 2f71811b3c2..cdb8e5248b0 100644 --- a/tfjs-core/src/ops/slice_test.ts +++ b/tfjs-core/src/ops/slice_test.ts @@ -524,7 +524,7 @@ describeWithFlags('slice ergonomics', ALL_ENVS, () => { describeWithFlags('shallow slicing', ALL_ENVS, () => { beforeAll(() => { - tf.ENV.set('WEBGL_CPU_FORWARD', false); + tf.environment().set('WEBGL_CPU_FORWARD', false); }); it('shallow slice an input that was cast', async () => { diff --git a/tfjs-core/src/ops/tensor_ops.ts b/tfjs-core/src/ops/tensor_ops.ts index 2e2a7948dc6..b0c8f3504ba 100644 --- a/tfjs-core/src/ops/tensor_ops.ts +++ b/tfjs-core/src/ops/tensor_ops.ts @@ -16,7 +16,8 @@ */ import {ENGINE} from '../engine'; -import {ENV} from '../environment'; +import {environment} from '../environment'; + import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, Tensor6D} from '../tensor'; import {convertToTensor, inferShape} from '../tensor_util_env'; import {TensorLike, TensorLike1D, TensorLike2D, TensorLike3D, TensorLike4D, TensorLike5D, TensorLike6D, TypedArray} from '../types'; @@ -107,7 +108,7 @@ function makeTensor( shape = shape || inferredShape; values = dtype !== 'string' ? - toTypedArray(values, dtype, ENV.getBool('DEBUG')) : + toTypedArray(values, dtype, environment().getBool('DEBUG')) : flatten(values as string[], [], true) as string[]; return Tensor.make(shape, {values: values as TypedArray}, dtype); } diff --git a/tfjs-core/src/platforms/platform_browser.ts b/tfjs-core/src/platforms/platform_browser.ts index 2d69feaf12e..7b70f1c252d 100644 --- a/tfjs-core/src/platforms/platform_browser.ts +++ b/tfjs-core/src/platforms/platform_browser.ts @@ -15,11 +15,11 @@ * ============================================================================= */ -import {ENV} from '../environment'; +import {environment} from '../environment'; + import {Platform} from './platform'; export class PlatformBrowser implements Platform { - // According to the spec, the built-in encoder can do only UTF-8 encoding. // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder private textEncoder: TextEncoder; @@ -47,6 +47,6 @@ export class PlatformBrowser implements Platform { } } -if (ENV.get('IS_BROWSER')) { - ENV.setPlatform('browser', new PlatformBrowser()); +if (environment().get('IS_BROWSER')) { + environment().setPlatform('browser', new PlatformBrowser()); } diff --git a/tfjs-core/src/platforms/platform_node.ts b/tfjs-core/src/platforms/platform_node.ts index b32f8dc40ca..0e3da418b32 100644 --- a/tfjs-core/src/platforms/platform_node.ts +++ b/tfjs-core/src/platforms/platform_node.ts @@ -14,7 +14,7 @@ * limitations under the License. * ============================================================================= */ -import {ENV} from '../environment'; +import {environment} from '../environment'; import {Platform} from './platform'; @@ -52,8 +52,8 @@ export class PlatformNode implements Platform { } fetch(path: string, requestInits?: RequestInit): Promise { - if (ENV.global.fetch != null) { - return ENV.global.fetch(path, requestInits); + if (environment().global.fetch != null) { + return environment().global.fetch(path, requestInits); } if (systemFetch == null) { @@ -82,6 +82,6 @@ export class PlatformNode implements Platform { } } -if (ENV.get('IS_NODE')) { - ENV.setPlatform('node', new PlatformNode()); +if (environment().get('IS_NODE')) { + environment().setPlatform('node', new PlatformNode()); } diff --git a/tfjs-core/src/platforms/platform_node_test.ts b/tfjs-core/src/platforms/platform_node_test.ts index a58d11f5306..f81b1e61134 100644 --- a/tfjs-core/src/platforms/platform_node_test.ts +++ b/tfjs-core/src/platforms/platform_node_test.ts @@ -15,61 +15,66 @@ * ============================================================================= */ -import {ENV} from '../environment'; +import * as tf from '../index'; import {describeWithFlags, NODE_ENVS} from '../jasmine_util'; - import * as platform_node from './platform_node'; import {PlatformNode} from './platform_node'; describeWithFlags('PlatformNode', NODE_ENVS, () => { it('fetch should use global.fetch if defined', async () => { - const globalFetch = ENV.global.fetch; + const globalFetch = tf.environment().global.fetch; - spyOn(ENV.global, 'fetch').and.returnValue(() => {}); + spyOn(tf.environment().global, 'fetch').and.returnValue(() => {}); const platform = new PlatformNode(); await platform.fetch('test/url', {method: 'GET'}); - expect(ENV.global.fetch).toHaveBeenCalledWith('test/url', {method: 'GET'}); + expect(tf.environment().global.fetch).toHaveBeenCalledWith('test/url', { + method: 'GET' + }); - ENV.global.fetch = globalFetch; + tf.environment().global.fetch = globalFetch; }); - it('fetch should use node-fetch with ENV.global.fetch is null', async () => { - const globalFetch = ENV.global.fetch; - ENV.global.fetch = null; + it('fetch should use node-fetch with tf.environment().global.fetch is null', + async () => { + const globalFetch = tf.environment().global.fetch; + tf.environment().global.fetch = null; - const platform = new PlatformNode(); + const platform = new PlatformNode(); - const savedFetch = platform_node.getSystemFetch(); + const savedFetch = platform_node.getSystemFetch(); - // Null out the system fetch so we force it to require node-fetch. - platform_node.resetSystemFetch(); + // Null out the system fetch so we force it to require node-fetch. + platform_node.resetSystemFetch(); - const testFetch = {fetch: (url: string, init: RequestInit) => {}}; + const testFetch = {fetch: (url: string, init: RequestInit) => {}}; - // Mock the actual fetch call. - spyOn(testFetch, 'fetch').and.returnValue(() => {}); - // Mock the import to override the real require of node-fetch. - spyOn(platform_node.getNodeFetch, 'importFetch') - .and.callFake( - () => (url: string, init: RequestInit) => - testFetch.fetch(url, init)); + // Mock the actual fetch call. + spyOn(testFetch, 'fetch').and.returnValue(() => {}); + // Mock the import to override the real require of node-fetch. + spyOn(platform_node.getNodeFetch, 'importFetch') + .and.callFake( + () => (url: string, init: RequestInit) => + testFetch.fetch(url, init)); - await platform.fetch('test/url', {method: 'GET'}); + await platform.fetch('test/url', {method: 'GET'}); - expect(platform_node.getNodeFetch.importFetch).toHaveBeenCalled(); - expect(testFetch.fetch).toHaveBeenCalledWith('test/url', {method: 'GET'}); + expect(platform_node.getNodeFetch.importFetch).toHaveBeenCalled(); + expect(testFetch.fetch).toHaveBeenCalledWith('test/url', { + method: 'GET' + }); - platform_node.setSystemFetch(savedFetch); - ENV.global.fetch = globalFetch; - }); + platform_node.setSystemFetch(savedFetch); + tf.environment().global.fetch = globalFetch; + }); it('now should use process.hrtime', async () => { const time = [100, 200]; spyOn(process, 'hrtime').and.returnValue(time); - expect(ENV.platform.now()).toEqual(time[0] * 1000 + time[1] / 1000000); + expect(tf.environment().platform.now()) + .toEqual(time[0] * 1000 + time[1] / 1000000); }); it('encodeUTF8 single string', () => { diff --git a/tfjs-core/src/tensor_test.ts b/tfjs-core/src/tensor_test.ts index f51069f5dd2..e034bf3a2c8 100644 --- a/tfjs-core/src/tensor_test.ts +++ b/tfjs-core/src/tensor_test.ts @@ -1532,7 +1532,7 @@ describeWithFlags('tensor', ALL_ENVS, () => { describeWithFlags('tensor debug mode', ALL_ENVS, () => { beforeAll(() => { - tf.ENV.set('DEBUG', true); + tf.environment().set('DEBUG', true); }); it('tf.tensor() from TypedArray + number[] fails due to wrong shape', () => { diff --git a/tfjs-core/src/tensor_util_env.ts b/tfjs-core/src/tensor_util_env.ts index 10c9fe9e1a9..92f035ec5c4 100644 --- a/tfjs-core/src/tensor_util_env.ts +++ b/tfjs-core/src/tensor_util_env.ts @@ -15,7 +15,8 @@ * ============================================================================= */ -import {ENV} from './environment'; +import {environment} from './environment'; + import {Tensor} from './tensor'; import {DataType, TensorLike} from './types'; import {assert, flatten, inferDtype, isTypedArray, toTypedArray} from './util'; @@ -36,7 +37,8 @@ export function inferShape(val: TensorLike, dtype?: DataType): number[] { shape.push(firstElem.length); firstElem = firstElem[0]; } - if (Array.isArray(val) && ENV.getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) { + if (Array.isArray(val) && + environment().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) { deepAssertShapeConsistency(val, shape, []); } @@ -111,7 +113,8 @@ export function convertToTensor( } const skipTypedArray = true; const values = inferredDtype !== 'string' ? - toTypedArray(x, inferredDtype as DataType, ENV.getBool('DEBUG')) : + toTypedArray( + x, inferredDtype as DataType, environment().getBool('DEBUG')) : flatten(x as string[], [], skipTypedArray) as string[]; return Tensor.make(inferredShape, {values}, inferredDtype); } diff --git a/tfjs-core/src/tensor_util_test.ts b/tfjs-core/src/tensor_util_test.ts index 53cdb1aaea7..2a4c5c6f708 100644 --- a/tfjs-core/src/tensor_util_test.ts +++ b/tfjs-core/src/tensor_util_test.ts @@ -211,7 +211,7 @@ describeWithFlags('convertToTensor', ALL_ENVS, () => { describeWithFlags('convertToTensor debug mode', ALL_ENVS, () => { beforeAll(() => { - tf.ENV.set('DEBUG', true); + tf.environment().set('DEBUG', true); }); it('fails to convert a non-valid shape array to tensor', () => { diff --git a/tfjs-core/src/util.ts b/tfjs-core/src/util.ts index 622a899a2c4..a3d6681286a 100644 --- a/tfjs-core/src/util.ts +++ b/tfjs-core/src/util.ts @@ -15,7 +15,8 @@ * ============================================================================= */ -import {ENV} from './environment'; +import {environment} from './environment'; + import {DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray} from './types'; /** @@ -653,7 +654,7 @@ export function makeZerosTypedArray( */ /** @doc {heading: 'Util', namespace: 'util'} */ export function now(): number { - return ENV.platform.now(); + return environment().platform.now(); } export function assertNonNegativeIntegerDimensions(shape: number[]) { @@ -683,7 +684,7 @@ export function assertNonNegativeIntegerDimensions(shape: number[]) { /** @doc {heading: 'Util'} */ export function fetch( path: string, requestInits?: RequestInit): Promise { - return ENV.platform.fetch(path, requestInits); + return environment().platform.fetch(path, requestInits); } /** @@ -696,7 +697,7 @@ export function fetch( /** @doc {heading: 'Util'} */ export function encodeString(s: string, encoding = 'utf-8'): Uint8Array { encoding = encoding || 'utf-8'; - return ENV.platform.encode(s, encoding); + return environment().platform.encode(s, encoding); } /** @@ -708,5 +709,5 @@ export function encodeString(s: string, encoding = 'utf-8'): Uint8Array { /** @doc {heading: 'Util'} */ export function decodeString(bytes: Uint8Array, encoding = 'utf-8'): string { encoding = encoding || 'utf-8'; - return ENV.platform.decode(bytes, encoding); + return environment().platform.decode(bytes, encoding); } diff --git a/tfjs-core/src/util_test.ts b/tfjs-core/src/util_test.ts index 7b3dce6cc4d..04d0479c11a 100644 --- a/tfjs-core/src/util_test.ts +++ b/tfjs-core/src/util_test.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {ENV} from './environment'; +import * as tf from './index'; import {ALL_ENVS, describeWithFlags} from './jasmine_util'; import {scalar, tensor2d} from './ops/ops'; import {inferShape} from './tensor_util_env'; @@ -516,11 +516,11 @@ describeWithFlags('util.toNestedArray', ALL_ENVS, () => { describe('util.fetch', () => { it('should call the platform fetch', () => { - spyOn(ENV.platform, 'fetch').and.callFake(() => {}); + spyOn(tf.environment().platform, 'fetch').and.callFake(() => {}); util.fetch('test/path', {method: 'GET'}); - expect(ENV.platform.fetch).toHaveBeenCalledWith('test/path', { + expect(tf.environment().platform.fetch).toHaveBeenCalledWith('test/path', { method: 'GET' }); }); diff --git a/tfjs-core/src/webgl.ts b/tfjs-core/src/webgl.ts index 422a8b8a780..3981bc87599 100644 --- a/tfjs-core/src/webgl.ts +++ b/tfjs-core/src/webgl.ts @@ -17,8 +17,8 @@ import * as gpgpu_util from './backends/webgl/gpgpu_util'; import * as webgl_util from './backends/webgl/webgl_util'; +import {environment} from './environment'; -import {ENV} from './environment'; export {MathBackendWebGL, WebGLMemoryInfo, WebGLTimingInfo} from './backends/webgl/backend_webgl'; export {setWebGLContext} from './backends/webgl/canvas_util'; export {GPGPUContext} from './backends/webgl/gpgpu_context'; @@ -31,5 +31,5 @@ export {gpgpu_util, webgl_util}; */ /** @doc {heading: 'Environment', namespace: 'webgl'} */ export function forceHalfFloat(): void { - ENV.set('WEBGL_FORCE_F16_TEXTURES', true); + environment().set('WEBGL_FORCE_F16_TEXTURES', true); } From cb7e9d4e1a5c02abbc7033ddc67bed9688830b1f Mon Sep 17 00:00:00 2001 From: Nikhil Thorat Date: Tue, 1 Oct 2019 11:08:48 -0400 Subject: [PATCH 3/3] save --- tfjs-converter/src/data/compiled_api.ts | 8 - tfjs-core/benchmarks/index.html | 2 +- tfjs-core/src/backends/cpu/backend_cpu.ts | 8 +- tfjs-core/src/backends/webgl/backend_webgl.ts | 163 +++++++------- .../src/backends/webgl/backend_webgl_test.ts | 79 +++---- .../src/backends/webgl/canvas_util_test.ts | 4 +- tfjs-core/src/backends/webgl/flags_webgl.ts | 4 +- .../src/backends/webgl/flags_webgl_test.ts | 210 +++++++++--------- tfjs-core/src/backends/webgl/glsl_version.ts | 4 +- tfjs-core/src/backends/webgl/gpgpu_context.ts | 30 +-- .../src/backends/webgl/gpgpu_context_test.ts | 8 +- tfjs-core/src/backends/webgl/gpgpu_math.ts | 6 +- .../src/backends/webgl/reshape_packed_test.ts | 12 +- tfjs-core/src/backends/webgl/tex_util.ts | 4 +- .../src/backends/webgl/texture_manager.ts | 4 +- .../backends/webgl/webgl_batchnorm_test.ts | 6 +- .../src/backends/webgl/webgl_ops_test.ts | 98 ++++---- tfjs-core/src/backends/webgl/webgl_util.ts | 10 +- .../src/backends/webgl/webgl_util_test.ts | 9 +- tfjs-core/src/debug_mode_test.ts | 6 +- tfjs-core/src/environment.ts | 15 +- tfjs-core/src/flags.ts | 4 +- tfjs-core/src/flags_test.ts | 50 ++--- tfjs-core/src/globals.ts | 12 +- tfjs-core/src/globals_test.ts | 6 +- tfjs-core/src/index.ts | 2 +- tfjs-core/src/io/browser_files.ts | 6 +- tfjs-core/src/io/http.ts | 4 +- tfjs-core/src/io/http_test.ts | 4 +- tfjs-core/src/io/indexed_db.ts | 8 +- tfjs-core/src/io/local_storage.ts | 12 +- tfjs-core/src/io/weights_loader.ts | 7 +- tfjs-core/src/io/weights_loader_test.ts | 43 ++-- tfjs-core/src/jasmine_util.ts | 14 +- tfjs-core/src/log.ts | 6 +- tfjs-core/src/ops/slice_test.ts | 2 +- tfjs-core/src/ops/tensor_ops.ts | 4 +- tfjs-core/src/platforms/platform_browser.ts | 6 +- tfjs-core/src/platforms/platform_node.ts | 10 +- tfjs-core/src/platforms/platform_node_test.ts | 19 +- tfjs-core/src/tensor_test.ts | 2 +- tfjs-core/src/tensor_util_env.ts | 7 +- tfjs-core/src/tensor_util_test.ts | 2 +- tfjs-core/src/util.ts | 10 +- tfjs-core/src/util_test.ts | 4 +- tfjs-core/src/webgl.ts | 4 +- 46 files changed, 449 insertions(+), 489 deletions(-) diff --git a/tfjs-converter/src/data/compiled_api.ts b/tfjs-converter/src/data/compiled_api.ts index 46db4f6e405..90b9c053e2d 100644 --- a/tfjs-converter/src/data/compiled_api.ts +++ b/tfjs-converter/src/data/compiled_api.ts @@ -16,14 +16,6 @@ * ============================================================================= */ -// This incantation makes Closure think that exported symbols are mutable. -// Mutable file-level exports are disallowed per style and won't reliably -// work. This hack also has a cost in terms of code size, and is only used -// to preserve the preexisting behavior of this code. -// tslint:disable-next-line:ban-ts-ignore see above -// @ts-ignore -exports = {}; - /* tslint:disable */ /** Properties of an Any. */ diff --git a/tfjs-core/benchmarks/index.html b/tfjs-core/benchmarks/index.html index d1cda5a5251..dc6f78dd63d 100644 --- a/tfjs-core/benchmarks/index.html +++ b/tfjs-core/benchmarks/index.html @@ -128,7 +128,7 @@

TensorFlow.js Model Benchmark

async function showEnvironment() { await tf.time(() => tf.add(tf.tensor1d([1]), tf.tensor1d([1])).data()); - envDiv.innerHTML += `
${JSON.stringify(tf.environment().features, null, 2) + envDiv.innerHTML += `
${JSON.stringify(tf.env().features, null, 2) } `; } diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts index b20b6d14a54..3708bd74d11 100644 --- a/tfjs-core/src/backends/cpu/backend_cpu.ts +++ b/tfjs-core/src/backends/cpu/backend_cpu.ts @@ -18,7 +18,7 @@ import * as seedrandom from 'seedrandom'; import {ENGINE} from '../../engine'; -import {environment} from '../../environment'; +import {env} from '../../environment'; import {warn} from '../../log'; import * as array_ops_util from '../../ops/array_ops_util'; @@ -93,7 +93,7 @@ export class MathBackendCPU implements KernelBackend { private firstUse = true; constructor() { - if (environment().get('IS_BROWSER')) { + if (env().get('IS_BROWSER')) { const canvas = createCanvas(); if (canvas !== null) { this.fromPixels2DContext = @@ -106,7 +106,7 @@ export class MathBackendCPU implements KernelBackend { register(dataId: DataId, shape: number[], dtype: DataType): void { if (this.firstUse) { this.firstUse = false; - if (environment().get('IS_NODE')) { + if (env().get('IS_NODE')) { warn( '\n============================\n' + 'Hi there 👋. Looks like you are running TensorFlow.js in ' + @@ -155,7 +155,7 @@ export class MathBackendCPU implements KernelBackend { [pixels.width, pixels.height]; let vals: Uint8ClampedArray|Uint8Array; // tslint:disable-next-line:no-any - if (environment().get('IS_NODE') && (pixels as any).getContext == null) { + if (env().get('IS_NODE') && (pixels as any).getContext == null) { throw new Error( 'When running in node, pixels must be an HTMLCanvasElement ' + 'like the one returned by the `canvas` npm package'); diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index 3d29161d149..f870aaa0a2a 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -20,7 +20,7 @@ import './flags_webgl'; import * as device_util from '../../device_util'; import {ENGINE, MemoryInfo, TimingInfo} from '../../engine'; -import {environment} from '../../environment'; +import {env} from '../../environment'; import {tidy} from '../../globals'; import {warn} from '../../log'; @@ -214,11 +214,11 @@ const CPU_HANDOFF_SIZE_THRESHOLD = 128; // * dpi / 1024 / 1024. const BEFORE_PAGING_CONSTANT = 600; function numMBBeforeWarning(): number { - if (environment().global.screen == null) { + if (env().global.screen == null) { return 1024; // 1 GB. } - return (environment().global.screen.height * - environment().global.screen.width * window.devicePixelRatio) * + return (env().global.screen.height * env().global.screen.width * + window.devicePixelRatio) * BEFORE_PAGING_CONSTANT / 1024 / 1024; } @@ -261,14 +261,13 @@ export class MathBackendWebGL implements KernelBackend { private warnedAboutMemory = false; constructor(private gpgpu?: GPGPUContext) { - if (!environment().getBool('HAS_WEBGL')) { + if (!env().getBool('HAS_WEBGL')) { throw new Error('WebGL is not supported on this device'); } if (gpgpu == null) { - const gl = getWebGLContext(environment().getNumber('WEBGL_VERSION')); - this.binaryCache = - getBinaryCache(environment().getNumber('WEBGL_VERSION')); + const gl = getWebGLContext(env().getNumber('WEBGL_VERSION')); + this.binaryCache = getBinaryCache(env().getNumber('WEBGL_VERSION')); this.gpgpu = new GPGPUContext(gl); this.canvas = gl.canvas; this.gpgpuCreatedLocally = true; @@ -333,8 +332,7 @@ export class MathBackendWebGL implements KernelBackend { if (this.fromPixels2DContext == null) { //@ts-ignore this.fromPixels2DContext = - createCanvas(environment().getNumber('WEBGL_VERSION')) - .getContext('2d'); + createCanvas(env().getNumber('WEBGL_VERSION')).getContext('2d'); } this.fromPixels2DContext.canvas.width = width; @@ -351,7 +349,7 @@ export class MathBackendWebGL implements KernelBackend { this.gpgpu.uploadPixelDataToTexture( this.getTexture(tempPixelHandle.dataId), pixels as ImageData); let program, res; - if (environment().getBool('WEBGL_PACK')) { + if (env().getBool('WEBGL_PACK')) { program = new FromPixelsPackedProgram(outShape); const packedOutput = this.makePackedTensor(program.outputShape, tempPixelHandle.dtype); @@ -377,15 +375,15 @@ export class MathBackendWebGL implements KernelBackend { throw new Error('MathBackendWebGL.write(): values can not be null'); } - if (environment().getBool('DEBUG')) { + if (env().getBool('DEBUG')) { for (let i = 0; i < values.length; i++) { const num = values[i] as number; if (!webgl_util.canBeRepresented(num)) { - if (environment().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) { + if (env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) { throw Error( `The value ${num} cannot be represented with your ` + `current settings. Consider enabling float32 rendering: ` + - `'tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'`); + `'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'`); } throw Error(`The value ${num} cannot be represented on this device.`); } @@ -472,8 +470,8 @@ export class MathBackendWebGL implements KernelBackend { return this.convertAndCacheOnCPU(dataId); } - if (!environment().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') && - environment().getNumber('WEBGL_VERSION') === 2) { + if (!env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') && + env().getNumber('WEBGL_VERSION') === 2) { throw new Error( `tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and ` + `WEBGL_VERSION=2 not yet supported.`); @@ -482,7 +480,7 @@ export class MathBackendWebGL implements KernelBackend { let buffer = null; let tmpDownloadTarget: TensorHandle; - if (dtype !== 'complex64' && environment().get('WEBGL_BUFFER_SUPPORTED')) { + if (dtype !== 'complex64' && env().get('WEBGL_BUFFER_SUPPORTED')) { // Possibly copy the texture into a buffer before inserting a fence. tmpDownloadTarget = this.decode(dataId); const tmpData = this.texData.get(tmpDownloadTarget.dataId); @@ -533,7 +531,7 @@ export class MathBackendWebGL implements KernelBackend { private getValuesFromTexture(dataId: DataId): Float32Array { const {shape, dtype, isPacked} = this.texData.get(dataId); const size = util.sizeFromShape(shape); - if (environment().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) { + if (env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) { const tmpTarget = this.decode(dataId); const tmpData = this.texData.get(tmpTarget.dataId); const vals = this.gpgpu @@ -547,7 +545,7 @@ export class MathBackendWebGL implements KernelBackend { } const shouldUsePackedProgram = - environment().getBool('WEBGL_PACK') && isPacked === true; + env().getBool('WEBGL_PACK') && isPacked === true; const outputShape = shouldUsePackedProgram ? webgl_util.getShapeAs3D(shape) : shape; const tmpTarget = @@ -628,16 +626,14 @@ export class MathBackendWebGL implements KernelBackend { } private startTimer(): WebGLQuery|CPUTimerQuery { - if (environment().getNumber( - 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { + if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { return this.gpgpu.beginQuery(); } return {startMs: util.now(), endMs: null}; } private endTimer(query: WebGLQuery|CPUTimerQuery): WebGLQuery|CPUTimerQuery { - if (environment().getNumber( - 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { + if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { this.gpgpu.endQuery(); return query; } @@ -646,8 +642,7 @@ export class MathBackendWebGL implements KernelBackend { } private async getQueryTime(query: WebGLQuery|CPUTimerQuery): Promise { - if (environment().getNumber( - 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { + if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { return this.gpgpu.waitForQueryAndGetTime(query as WebGLQuery); } const timerQuery = query as CPUTimerQuery; @@ -711,7 +706,7 @@ export class MathBackendWebGL implements KernelBackend { } private getCPUBackend(): KernelBackend|null { - if (!environment().getBool('WEBGL_CPU_FORWARD')) { + if (!env().getBool('WEBGL_CPU_FORWARD')) { return null; } @@ -774,7 +769,7 @@ export class MathBackendWebGL implements KernelBackend { const {isPacked} = this.texData.get(x.dataId); const isContinous = slice_util.isSliceContinous(x.shape, begin, size); if (isPacked || !isContinous) { - const program = environment().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? + const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new SlicePackedProgram(size) : new SliceProgram(size); const customSetup = program.getCustomSetupFunc(begin); @@ -828,7 +823,7 @@ export class MathBackendWebGL implements KernelBackend { } reverse(x: T, axis: number[]): T { - const program = environment().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? + const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new ReversePackedProgram(x.shape, axis) : new ReverseProgram(x.shape, axis); return this.compileAndRun(program, [x]); @@ -847,15 +842,13 @@ export class MathBackendWebGL implements KernelBackend { if (tensors.length === 1) { return tensors[0]; } - if (tensors.length > - environment().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) { + if (tensors.length > env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) { const midIndex = Math.floor(tensors.length / 2); const leftSide = this.concat(tensors.slice(0, midIndex), axis); const rightSide = this.concat(tensors.slice(midIndex), axis); return this.concat([leftSide, rightSide], axis); } - if (environment().getBool('WEBGL_PACK_ARRAY_OPERATIONS') && - tensors[0].rank > 1) { + if (env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') && tensors[0].rank > 1) { const program = new ConcatPackedProgram(tensors.map(t => t.shape), axis); return this.compileAndRun(program, tensors); } @@ -879,7 +872,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.neg(x); } - if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_op.NEG, x.dtype) as T; } const program = new UnaryOpProgram(x.shape, unary_op.NEG); @@ -974,7 +967,7 @@ export class MathBackendWebGL implements KernelBackend { if (this.shouldExecuteOnCPU([a, b])) { return this.cpuBackend.multiply(a, b); } - if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_gpu.MUL, a.dtype); } const program = new BinaryOpProgram(binaryop_gpu.MUL, a.shape, b.shape); @@ -1000,7 +993,7 @@ export class MathBackendWebGL implements KernelBackend { inputs.push(scale); } - if (environment().getBool('WEBGL_PACK_NORMALIZATION')) { + if (env().getBool('WEBGL_PACK_NORMALIZATION')) { const batchNormPackedProgram = new BatchNormPackedProgram( x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon); @@ -1016,7 +1009,7 @@ export class MathBackendWebGL implements KernelBackend { localResponseNormalization4D( x: Tensor4D, radius: number, bias: number, alpha: number, beta: number): Tensor4D { - const program = environment().getBool('WEBGL_PACK_NORMALIZATION') ? + const program = env().getBool('WEBGL_PACK_NORMALIZATION') ? new LRNPackedProgram(x.shape, radius, bias, alpha, beta) : new LRNProgram(x.shape, radius, bias, alpha, beta); return this.compileAndRun(program, [x]); @@ -1044,7 +1037,7 @@ export class MathBackendWebGL implements KernelBackend { pad( x: T, paddings: Array<[number, number]>, constantValue: number): T { - const program = environment().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? + const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new PadPackedProgram(x.shape, paddings, constantValue) : new PadProgram(x.shape, paddings, constantValue); return this.compileAndRun(program, [x]); @@ -1054,7 +1047,7 @@ export class MathBackendWebGL implements KernelBackend { if (this.shouldExecuteOnCPU([x])) { return this.cpuBackend.transpose(x, perm); } - const program = environment().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? + const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new TransposePackedProgram(x.shape, perm) : new TransposeProgram(x.shape, perm); return this.compileAndRun(program, [x]); @@ -1259,7 +1252,7 @@ export class MathBackendWebGL implements KernelBackend { axis_util.assertAxesAreInnerMostDims( 'arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes, x.rank); - if (!environment().getBool('WEBGL_PACK_REDUCE') || x.rank <= 2) { + if (!env().getBool('WEBGL_PACK_REDUCE') || x.rank <= 2) { const [outShape, reduceShape] = axis_util.computeOutAndReduceShapes(x.shape, axes); const inSize = util.sizeFromShape(reduceShape); @@ -1289,7 +1282,7 @@ export class MathBackendWebGL implements KernelBackend { } equal(a: Tensor, b: Tensor): Tensor { - if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_packed_gpu.EQUAL, 'bool'); } const program = new BinaryOpProgram(binaryop_gpu.EQUAL, a.shape, b.shape); @@ -1298,7 +1291,7 @@ export class MathBackendWebGL implements KernelBackend { } notEqual(a: Tensor, b: Tensor): Tensor { - if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_packed_gpu.NOT_EQUAL, 'bool'); } const program = @@ -1312,7 +1305,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.less(a, b); } - if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_packed_gpu.LESS, 'bool'); } @@ -1322,7 +1315,7 @@ export class MathBackendWebGL implements KernelBackend { } lessEqual(a: Tensor, b: Tensor): Tensor { - if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_packed_gpu.LESS_EQUAL, 'bool'); } const program = @@ -1336,7 +1329,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.greater(a, b); } - if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_packed_gpu.GREATER, 'bool'); } @@ -1346,7 +1339,7 @@ export class MathBackendWebGL implements KernelBackend { } greaterEqual(a: Tensor, b: Tensor): Tensor { - if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp( a, b, binaryop_packed_gpu.GREATER_EQUAL, 'bool'); } @@ -1362,7 +1355,7 @@ export class MathBackendWebGL implements KernelBackend { } logicalAnd(a: Tensor, b: Tensor): Tensor { - if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_packed_gpu.LOGICAL_AND, 'bool'); } const program = @@ -1372,7 +1365,7 @@ export class MathBackendWebGL implements KernelBackend { } logicalOr(a: Tensor, b: Tensor): Tensor { - if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_packed_gpu.LOGICAL_OR, 'bool'); } const program = @@ -1415,14 +1408,14 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.minimum(a, b); } - const program = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(binaryop_packed_gpu.MIN, a.shape, b.shape) : new BinaryOpProgram(binaryop_gpu.MIN, a.shape, b.shape); return this.compileAndRun(program, [a, b]); } mod(a: Tensor, b: Tensor): Tensor { - const program = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(binaryop_packed_gpu.MOD, a.shape, b.shape) : new BinaryOpProgram(binaryop_gpu.MOD, a.shape, b.shape); return this.compileAndRun(program, [a, b]); @@ -1446,7 +1439,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.maximum(a, b); } - const program = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(binaryop_packed_gpu.MAX, a.shape, b.shape) : new BinaryOpProgram(binaryop_gpu.MAX, a.shape, b.shape); return this.compileAndRun(program, [a, b]); @@ -1471,7 +1464,7 @@ export class MathBackendWebGL implements KernelBackend { } squaredDifference(a: Tensor, b: Tensor): Tensor { - const program = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram( binaryop_gpu.SQUARED_DIFFERENCE, a.shape, b.shape) : new BinaryOpProgram(binaryop_gpu.SQUARED_DIFFERENCE, a.shape, b.shape); @@ -1481,7 +1474,7 @@ export class MathBackendWebGL implements KernelBackend { realDivide(a: Tensor, b: Tensor): Tensor { const op = binaryop_gpu.DIV; const outputDtype = 'float32'; - if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { const checkOutOfBounds = true; return this.packedBinaryOp( a, b, binaryop_packed_gpu.DIV, outputDtype, checkOutOfBounds); @@ -1494,7 +1487,7 @@ export class MathBackendWebGL implements KernelBackend { floorDiv(a: Tensor, b: Tensor): Tensor { const op = binaryop_gpu.INT_DIV; const outputDtype = 'int32'; - if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp( a, b, binaryop_packed_gpu.INT_DIV, outputDtype); } @@ -1513,7 +1506,7 @@ export class MathBackendWebGL implements KernelBackend { } const dtype = upcastType(a.dtype, b.dtype); - if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_gpu.ADD, dtype); } const program = new BinaryOpProgram(binaryop_gpu.ADD, a.shape, b.shape); @@ -1584,7 +1577,7 @@ export class MathBackendWebGL implements KernelBackend { } // Limit the number of uploaded textures for optimization. - if (tensors.length > environment().get('WEBGL_MAX_TEXTURES_IN_SHADER')) { + if (tensors.length > env().get('WEBGL_MAX_TEXTURES_IN_SHADER')) { const midIndex = Math.floor(tensors.length / 2); const leftSide = this.addN(tensors.slice(0, midIndex)); const rightSide = this.addN(tensors.slice(midIndex)); @@ -1595,7 +1588,7 @@ export class MathBackendWebGL implements KernelBackend { tensors.map(t => t.dtype).reduce((d1, d2) => upcastType(d1, d2)); const shapes = tensors.map(t => t.shape); // We can make sure shapes are identical in op level. - const usePackedOp = environment().getBool('WEBGL_PACK'); + const usePackedOp = env().getBool('WEBGL_PACK'); const program = usePackedOp ? new AddNPackedProgram(tensors[0].shape, shapes) : new AddNProgram(tensors[0].shape, shapes); @@ -1614,7 +1607,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.subtract(a, b); } const dtype = upcastType(a.dtype, b.dtype); - if (environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { return this.packedBinaryOp(a, b, binaryop_gpu.SUB, a.dtype); } const program = new BinaryOpProgram(binaryop_gpu.SUB, a.shape, b.shape); @@ -1623,7 +1616,7 @@ export class MathBackendWebGL implements KernelBackend { } pow(a: T, b: Tensor): T { - const usePackedOp = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + const usePackedOp = env().getBool('WEBGL_PACK_BINARY_OPERATIONS'); const program = usePackedOp ? new BinaryOpPackedProgram(binaryop_packed_gpu.POW, a.shape, b.shape) : new BinaryOpProgram(binaryop_gpu.POW, a.shape, b.shape); @@ -1639,7 +1632,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.ceil(x); } - if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_op.CEIL, x.dtype) as T; } @@ -1652,7 +1645,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.floor(x); } - if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_op.FLOOR, x.dtype) as T; } @@ -1691,7 +1684,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.exp(x); } - if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_op.EXP, x.dtype) as T; } @@ -1704,7 +1697,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.expm1(x); } - if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_op.EXPM1, x.dtype) as T; } @@ -1717,7 +1710,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.log(x); } - if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_packed_op.LOG, x.dtype) as T; } @@ -1755,7 +1748,7 @@ export class MathBackendWebGL implements KernelBackend { relu(x: T): T { let program: UnaryOpProgram|UnaryOpPackedProgram; - if (environment().getBool('WEBGL_PACK')) { + if (env().getBool('WEBGL_PACK')) { program = new UnaryOpPackedProgram(x.shape, unary_packed_op.RELU); } else { program = new UnaryOpProgram(x.shape, unary_op.RELU); @@ -1765,7 +1758,7 @@ export class MathBackendWebGL implements KernelBackend { relu6(x: T): T { let program: UnaryOpProgram|UnaryOpPackedProgram; - if (environment().getBool('WEBGL_PACK')) { + if (env().getBool('WEBGL_PACK')) { program = new UnaryOpPackedProgram(x.shape, unary_packed_op.RELU6); } else { program = new UnaryOpProgram(x.shape, unary_op.RELU6); @@ -1774,7 +1767,7 @@ export class MathBackendWebGL implements KernelBackend { } prelu(x: T, alpha: T): T { - const program = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram( binaryop_packed_gpu.PRELU, x.shape, alpha.shape) : new BinaryOpProgram(binaryop_gpu.PRELU, x.shape, alpha.shape); @@ -1782,7 +1775,7 @@ export class MathBackendWebGL implements KernelBackend { } elu(x: T): T { - if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_packed_op.ELU, x.dtype) as T; } const program = new UnaryOpProgram(x.shape, unary_op.ELU); @@ -1790,7 +1783,7 @@ export class MathBackendWebGL implements KernelBackend { } eluDer(dy: T, y: T): T { - const program = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram( binaryop_packed_gpu.ELU_DER, dy.shape, y.shape) : new BinaryOpProgram(binaryop_gpu.ELU_DER, dy.shape, y.shape); @@ -1810,7 +1803,7 @@ export class MathBackendWebGL implements KernelBackend { clip(x: T, min: number, max: number): T { let program; - if (environment().getBool('WEBGL_PACK_CLIP')) { + if (env().getBool('WEBGL_PACK_CLIP')) { program = new ClipPackedProgram(x.shape); } else { program = new ClipProgram(x.shape); @@ -1824,7 +1817,7 @@ export class MathBackendWebGL implements KernelBackend { return this.cpuBackend.abs(x); } - if (environment().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, unary_op.ABS, x.dtype) as T; } @@ -1885,7 +1878,7 @@ export class MathBackendWebGL implements KernelBackend { } atan2(a: T, b: T): T { - const program = environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(binaryop_packed_gpu.ATAN2, a.shape, b.shape) : new BinaryOpProgram(binaryop_gpu.ATAN2, a.shape, b.shape); return this.compileAndRun(program, [a, b]); @@ -1952,9 +1945,8 @@ export class MathBackendWebGL implements KernelBackend { sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD; const reshapeWillBeExpensive = xShape[2] % 2 !== 0 && !!xTexData.isPacked; - if (batchMatMulWillBeUnpacked || - !environment().getBool('WEBGL_LAZILY_UNPACK') || - !environment().getBool('WEBGL_PACK_BINARY_OPERATIONS') || + if (batchMatMulWillBeUnpacked || !env().getBool('WEBGL_LAZILY_UNPACK') || + !env().getBool('WEBGL_PACK_BINARY_OPERATIONS') || !reshapeWillBeExpensive) { const targetShape = isChannelsLast ? xShape[0] * xShape[1] * xShape[2] : xShape[0] * xShape[2] * xShape[3]; @@ -2101,7 +2093,7 @@ export class MathBackendWebGL implements KernelBackend { return this.conv2dByMatMul( input, filter, convInfo, bias, activation, preluActivationWeights); } - if (environment().getBool('WEBGL_CONV_IM2COL') && input.shape[0] === 1) { + if (env().getBool('WEBGL_CONV_IM2COL') && input.shape[0] === 1) { return this.conv2dWithIm2Row( input, filter, convInfo, bias, activation, preluActivationWeights); } @@ -2130,7 +2122,7 @@ export class MathBackendWebGL implements KernelBackend { convInfo.padInfo.type === 'VALID')) { return this.conv2dByMatMul(x, filter, convInfo); } - if (environment().getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) { + if (env().getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) { return this.conv2dWithIm2Row(x, filter, convInfo); } const program = new Conv2DProgram(convInfo); @@ -2151,8 +2143,7 @@ export class MathBackendWebGL implements KernelBackend { fusedDepthwiseConv2D( {input, filter, convInfo, bias, activation, preluActivationWeights}: FusedConv2DConfig): Tensor4D { - const shouldPackDepthwiseConv = - environment().getBool('WEBGL_PACK_DEPTHWISECONV') && + const shouldPackDepthwiseConv = env().getBool('WEBGL_PACK_DEPTHWISECONV') && convInfo.strideWidth <= 2 && convInfo.outChannels / convInfo.inChannels === 1; const fusedActivation = activation ? @@ -2186,7 +2177,7 @@ export class MathBackendWebGL implements KernelBackend { depthwiseConv2D(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D { let program: DepthwiseConv2DProgram|DepthwiseConvPacked2DProgram; - if (environment().getBool('WEBGL_PACK_DEPTHWISECONV') && + if (env().getBool('WEBGL_PACK_DEPTHWISECONV') && convInfo.strideWidth <= 2 && convInfo.outChannels / convInfo.inChannels === 1) { program = new DepthwiseConvPacked2DProgram(convInfo); @@ -2337,7 +2328,7 @@ export class MathBackendWebGL implements KernelBackend { resizeBilinear( x: Tensor4D, newHeight: number, newWidth: number, alignCorners: boolean): Tensor4D { - const program = environment().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ? + const program = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ? new ResizeBilinearPackedProgram( x.shape, newHeight, newWidth, alignCorners) : new ResizeBilinearProgram(x.shape, newHeight, newWidth, alignCorners); @@ -2665,7 +2656,7 @@ export class MathBackendWebGL implements KernelBackend { if (texData.texture == null) { if (!program.usesPackedTextures && util.sizeFromShape(input.shape) <= - environment().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) { + env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) { // Upload small tensors that live on the CPU as uniforms, not as // textures. Do this only when the environment supports 32bit floats // due to problems when comparing 16bit floats with 32bit floats. @@ -2739,7 +2730,7 @@ export class MathBackendWebGL implements KernelBackend { {name: program.constructor.name, query: this.getQueryTime(query)}); } - if (!environment().getBool('WEBGL_LAZILY_UNPACK') && + if (!env().getBool('WEBGL_LAZILY_UNPACK') && this.texData.get(output.dataId).isPacked && preventEagerUnpackingOfOutput === false) { return this.unpackTensor(output as {} as Tensor) as {} as K; @@ -2787,13 +2778,13 @@ export class MathBackendWebGL implements KernelBackend { floatPrecision(): 16|32 { if (this.floatPrecisionValue == null) { this.floatPrecisionValue = tidy(() => { - if (!environment().get('WEBGL_RENDER_FLOAT32_ENABLED')) { + if (!env().get('WEBGL_RENDER_FLOAT32_ENABLED')) { // Momentarily switching DEBUG flag to false so we don't throw an // error trying to upload a small value. - const debugFlag = environment().getBool('DEBUG'); - environment().set('DEBUG', false); + const debugFlag = env().getBool('DEBUG'); + env().set('DEBUG', false); const underflowCheckValue = this.abs(scalar(1e-8)).dataSync()[0]; - environment().set('DEBUG', debugFlag); + env().set('DEBUG', debugFlag); if (underflowCheckValue > 0) { return 32; diff --git a/tfjs-core/src/backends/webgl/backend_webgl_test.ts b/tfjs-core/src/backends/webgl/backend_webgl_test.ts index 3884ebf4b0b..1436ca0c93c 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl_test.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl_test.ts @@ -41,12 +41,12 @@ describeWithFlags('forced f16 render', RENDER_FLOAT32_ENVS, () => { beforeAll(() => { renderToF32FlagSaved = - tf.environment().get('WEBGL_RENDER_FLOAT32_ENABLED') as boolean; - tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', false); + tf.env().get('WEBGL_RENDER_FLOAT32_ENABLED') as boolean; + tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', false); }); afterAll(() => { - tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', renderToF32FlagSaved); + tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', renderToF32FlagSaved); }); it('should overflow if larger than 66k', async () => { @@ -56,11 +56,11 @@ describeWithFlags('forced f16 render', RENDER_FLOAT32_ENVS, () => { }); it('should error in debug mode', () => { - const savedDebugFlag = tf.environment().getBool('DEBUG'); - tf.environment().set('DEBUG', true); + const savedDebugFlag = tf.env().getBool('DEBUG'); + tf.env().set('DEBUG', true); const a = () => tf.tensor1d([2, Math.pow(2, 17)], 'float32'); expect(a).toThrowError(); - tf.environment().set('DEBUG', savedDebugFlag); + tf.env().set('DEBUG', savedDebugFlag); }); }); @@ -69,16 +69,15 @@ describeWithFlags('lazy packing and unpacking', WEBGL_ENVS, () => { let webglCpuForwardFlagSaved: boolean; beforeAll(() => { - webglLazilyUnpackFlagSaved = - tf.environment().getBool('WEBGL_LAZILY_UNPACK'); - webglCpuForwardFlagSaved = tf.environment().getBool('WEBGL_CPU_FORWARD'); - tf.environment().set('WEBGL_LAZILY_UNPACK', true); - tf.environment().set('WEBGL_CPU_FORWARD', false); + webglLazilyUnpackFlagSaved = tf.env().getBool('WEBGL_LAZILY_UNPACK'); + webglCpuForwardFlagSaved = tf.env().getBool('WEBGL_CPU_FORWARD'); + tf.env().set('WEBGL_LAZILY_UNPACK', true); + tf.env().set('WEBGL_CPU_FORWARD', false); }); afterAll(() => { - tf.environment().set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); - tf.environment().set('WEBGL_CPU_FORWARD', webglCpuForwardFlagSaved); + tf.env().set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); + tf.env().set('WEBGL_CPU_FORWARD', webglCpuForwardFlagSaved); }); it('should not leak memory when lazily unpacking', () => { @@ -94,11 +93,11 @@ describeWithFlags('lazy packing and unpacking', WEBGL_ENVS, () => { (tf.memory() as tf.webgl.WebGLMemoryInfo).numBytesInGPU; const webglPackBinaryOperationsFlagSaved = - tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', false); + tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.env().set('WEBGL_PACK_BINARY_OPERATIONS', false); // Add will unpack c before the operation to 2 tf.add(c, 1); - tf.environment().set( + tf.env().set( 'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsFlagSaved); expect(tf.memory().numBytes - startNumBytes).toEqual(16); @@ -231,20 +230,19 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => { tf.registerBackend('test-storage', () => backend); tf.setBackend('test-storage'); - const webglPackFlagSaved = tf.environment().getBool('WEBGL_PACK'); - tf.environment().set('WEBGL_PACK', true); + const webglPackFlagSaved = tf.env().getBool('WEBGL_PACK'); + tf.env().set('WEBGL_PACK', true); const webglSizeUploadUniformSaved = - tf.environment().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM'); - tf.environment().set('WEBGL_SIZE_UPLOAD_UNIFORM', 0); + tf.env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM'); + tf.env().set('WEBGL_SIZE_UPLOAD_UNIFORM', 0); const a = tf.tensor2d([1, 2], [2, 1]); const b = tf.tensor2d([1], [1, 1]); const c = tf.matMul(a, b); backend.readSync(c.dataId); - tf.environment().set('WEBGL_PACK', false); + tf.env().set('WEBGL_PACK', false); const d = tf.add(c, 1); - tf.environment().set('WEBGL_PACK', webglPackFlagSaved); - tf.environment().set( - 'WEBGL_SIZE_UPLOAD_UNIFORM', webglSizeUploadUniformSaved); + tf.env().set('WEBGL_PACK', webglPackFlagSaved); + tf.env().set('WEBGL_SIZE_UPLOAD_UNIFORM', webglSizeUploadUniformSaved); expectArraysClose(await d.data(), [2, 3]); }); @@ -322,12 +320,12 @@ describeWithFlags('upload tensors as uniforms', FLOAT32_WEBGL_ENVS, () => { beforeAll(() => { savedUploadUniformValue = - tf.environment().get('WEBGL_SIZE_UPLOAD_UNIFORM') as number; - tf.environment().set('WEBGL_SIZE_UPLOAD_UNIFORM', SIZE_UPLOAD_UNIFORM); + tf.env().get('WEBGL_SIZE_UPLOAD_UNIFORM') as number; + tf.env().set('WEBGL_SIZE_UPLOAD_UNIFORM', SIZE_UPLOAD_UNIFORM); }); afterAll(() => { - tf.environment().set('WEBGL_SIZE_UPLOAD_UNIFORM', savedUploadUniformValue); + tf.env().set('WEBGL_SIZE_UPLOAD_UNIFORM', savedUploadUniformValue); }); it('small tensor gets uploaded as scalar', () => { @@ -387,31 +385,29 @@ describeWithFlags('debug on webgl', WEBGL_ENVS, () => { beforeAll(() => { // Silences debug warnings. spyOn(console, 'warn'); - tf.environment().set('DEBUG', true); + tf.env().set('DEBUG', true); }); afterAll(() => { - tf.environment().set('DEBUG', false); + tf.env().set('DEBUG', false); }); it('debug mode errors when overflow in tensor construction', () => { const savedRenderFloat32Flag = - tf.environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED'); - tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', false); + tf.env().getBool('WEBGL_RENDER_FLOAT32_ENABLED'); + tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', false); const a = () => tf.tensor1d([2, Math.pow(2, 17)], 'float32'); expect(a).toThrowError(); - tf.environment().set( - 'WEBGL_RENDER_FLOAT32_ENABLED', savedRenderFloat32Flag); + tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', savedRenderFloat32Flag); }); it('debug mode errors when underflow in tensor construction', () => { const savedRenderFloat32Flag = - tf.environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED'); - tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', false); + tf.env().getBool('WEBGL_RENDER_FLOAT32_ENABLED'); + tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', false); const a = () => tf.tensor1d([2, 1e-8], 'float32'); expect(a).toThrowError(); - tf.environment().set( - 'WEBGL_RENDER_FLOAT32_ENABLED', savedRenderFloat32Flag); + tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', savedRenderFloat32Flag); }); }); @@ -430,10 +426,10 @@ describeWithFlags('memory webgl', WEBGL_ENVS, () => { // point. describeWithFlags('backend without render float32 support', WEBGL_ENVS, () => { const savedRenderFloat32Flag = - tf.environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED'); + tf.env().getBool('WEBGL_RENDER_FLOAT32_ENABLED'); beforeAll(() => { - tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', false); + tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', false); }); beforeEach(() => { @@ -445,8 +441,7 @@ describeWithFlags('backend without render float32 support', WEBGL_ENVS, () => { }); afterAll(() => { - tf.environment().set( - 'WEBGL_RENDER_FLOAT32_ENABLED', savedRenderFloat32Flag); + tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', savedRenderFloat32Flag); }); it('basic usage', async () => { @@ -515,7 +510,7 @@ describeWithFlags('time webgl', WEBGL_ENVS, () => { describeWithFlags('caching on cpu', WEBGL_ENVS, () => { beforeAll(() => { - tf.environment().set('WEBGL_CPU_FORWARD', false); + tf.env().set('WEBGL_CPU_FORWARD', false); }); it('caches on cpu after async read', async () => { diff --git a/tfjs-core/src/backends/webgl/canvas_util_test.ts b/tfjs-core/src/backends/webgl/canvas_util_test.ts index a13fc300468..f9ec81dfdab 100644 --- a/tfjs-core/src/backends/webgl/canvas_util_test.ts +++ b/tfjs-core/src/backends/webgl/canvas_util_test.ts @@ -22,7 +22,7 @@ import {getWebGLContext} from './canvas_util'; describeWithFlags('canvas_util', BROWSER_ENVS, () => { it('Returns a valid canvas', () => { const canvas = - getWebGLContext(tf.environment().getNumber('WEBGL_VERSION')).canvas as + getWebGLContext(tf.env().getNumber('WEBGL_VERSION')).canvas as // tslint:disable-next-line: no-any any; expect( @@ -32,7 +32,7 @@ describeWithFlags('canvas_util', BROWSER_ENVS, () => { }); it('Returns a valid gl context', () => { - const gl = getWebGLContext(tf.environment().getNumber('WEBGL_VERSION')); + const gl = getWebGLContext(tf.env().getNumber('WEBGL_VERSION')); expect(gl.isContextLost()).toBe(false); }); }); diff --git a/tfjs-core/src/backends/webgl/flags_webgl.ts b/tfjs-core/src/backends/webgl/flags_webgl.ts index d3ddfc297c4..25899b2c59b 100644 --- a/tfjs-core/src/backends/webgl/flags_webgl.ts +++ b/tfjs-core/src/backends/webgl/flags_webgl.ts @@ -16,11 +16,11 @@ */ import * as device_util from '../../device_util'; -import {environment} from '../../environment'; +import {env} from '../../environment'; import * as webgl_util from './webgl_util'; -const ENV = environment(); +const ENV = env(); /** * This file contains WebGL-specific flag registrations. diff --git a/tfjs-core/src/backends/webgl/flags_webgl_test.ts b/tfjs-core/src/backends/webgl/flags_webgl_test.ts index 4231234d1b7..53be90dd685 100644 --- a/tfjs-core/src/backends/webgl/flags_webgl_test.ts +++ b/tfjs-core/src/backends/webgl/flags_webgl_test.ts @@ -24,18 +24,17 @@ import {WEBGL_ENVS} from './backend_webgl_test_registry'; import * as canvas_util from './canvas_util'; describe('WEBGL_FORCE_F16_TEXTURES', () => { - afterAll(() => tf.environment().reset()); + afterAll(() => tf.env().reset()); it('can be activated via forceHalfFloat utility', () => { tf.webgl.forceHalfFloat(); - expect(tf.environment().getBool('WEBGL_FORCE_F16_TEXTURES')).toBe(true); + expect(tf.env().getBool('WEBGL_FORCE_F16_TEXTURES')).toBe(true); }); it('turns off WEBGL_RENDER_FLOAT32_ENABLED', () => { - tf.environment().reset(); + tf.env().reset(); tf.webgl.forceHalfFloat(); - expect(tf.environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) - .toBe(false); + expect(tf.env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')).toBe(false); }); }); @@ -51,215 +50,212 @@ const RENDER_FLOAT16_ENVS = { describeWithFlags('WEBGL_RENDER_FLOAT32_CAPABLE', RENDER_FLOAT32_ENVS, () => { beforeEach(() => { - tf.environment().reset(); + tf.env().reset(); }); - afterAll(() => tf.environment().reset()); + afterAll(() => tf.env().reset()); it('should be independent of forcing f16 rendering', () => { tf.webgl.forceHalfFloat(); - expect(tf.environment().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')).toBe(true); + expect(tf.env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')).toBe(true); }); it('if user is not forcing f16, device should render to f32', () => { - expect(tf.environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED')).toBe(true); + expect(tf.env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')).toBe(true); }); }); describeWithFlags('WEBGL_RENDER_FLOAT32_CAPABLE', RENDER_FLOAT16_ENVS, () => { beforeEach(() => { - tf.environment().reset(); + tf.env().reset(); }); - afterAll(() => tf.environment().reset()); + afterAll(() => tf.env().reset()); it('should be independent of forcing f16 rendering', () => { tf.webgl.forceHalfFloat(); - expect(tf.environment().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) - .toBe(false); + expect(tf.env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')).toBe(false); }); it('should be reflected in WEBGL_RENDER_FLOAT32_ENABLED', () => { - expect(tf.environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) - .toBe(false); + expect(tf.env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')).toBe(false); }); }); describe('HAS_WEBGL', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('false when version is 0', () => { - tf.environment().set('WEBGL_VERSION', 0); - expect(tf.environment().getBool('HAS_WEBGL')).toBe(false); + tf.env().set('WEBGL_VERSION', 0); + expect(tf.env().getBool('HAS_WEBGL')).toBe(false); }); it('true when version is 1', () => { - tf.environment().set('WEBGL_VERSION', 1); - expect(tf.environment().getBool('HAS_WEBGL')).toBe(true); + tf.env().set('WEBGL_VERSION', 1); + expect(tf.env().getBool('HAS_WEBGL')).toBe(true); }); it('true when version is 2', () => { - tf.environment().set('WEBGL_VERSION', 2); - expect(tf.environment().getBool('HAS_WEBGL')).toBe(true); + tf.env().set('WEBGL_VERSION', 2); + expect(tf.env().getBool('HAS_WEBGL')).toBe(true); }); }); describe('WEBGL_PACK', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('true when HAS_WEBGL is true', () => { - tf.environment().set('HAS_WEBGL', true); - expect(tf.environment().getBool('WEBGL_PACK')).toBe(true); + tf.env().set('HAS_WEBGL', true); + expect(tf.env().getBool('WEBGL_PACK')).toBe(true); }); it('false when HAS_WEBGL is false', () => { - tf.environment().set('HAS_WEBGL', false); - expect(tf.environment().getBool('WEBGL_PACK')).toBe(false); + tf.env().set('HAS_WEBGL', false); + expect(tf.env().getBool('WEBGL_PACK')).toBe(false); }); }); describe('WEBGL_PACK_NORMALIZATION', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('true when WEBGL_PACK is true', () => { - tf.environment().set('WEBGL_PACK', true); - expect(tf.environment().getBool('WEBGL_PACK_NORMALIZATION')).toBe(true); + tf.env().set('WEBGL_PACK', true); + expect(tf.env().getBool('WEBGL_PACK_NORMALIZATION')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - tf.environment().set('WEBGL_PACK', false); - expect(tf.environment().getBool('WEBGL_PACK_NORMALIZATION')).toBe(false); + tf.env().set('WEBGL_PACK', false); + expect(tf.env().getBool('WEBGL_PACK_NORMALIZATION')).toBe(false); }); }); describe('WEBGL_PACK_CLIP', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('true when WEBGL_PACK is true', () => { - tf.environment().set('WEBGL_PACK', true); - expect(tf.environment().getBool('WEBGL_PACK_CLIP')).toBe(true); + tf.env().set('WEBGL_PACK', true); + expect(tf.env().getBool('WEBGL_PACK_CLIP')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - tf.environment().set('WEBGL_PACK', false); - expect(tf.environment().getBool('WEBGL_PACK_CLIP')).toBe(false); + tf.env().set('WEBGL_PACK', false); + expect(tf.env().getBool('WEBGL_PACK_CLIP')).toBe(false); }); }); // TODO: https://github.com/tensorflow/tfjs/issues/1679 // describe('WEBGL_PACK_DEPTHWISECONV', () => { -// beforeEach(() => tf.environment().reset()); -// afterAll(() => tf.environment().reset()); +// beforeEach(() => tf.env().reset()); +// afterAll(() => tf.env().reset()); // it('true when WEBGL_PACK is true', () => { -// tf.environment().set('WEBGL_PACK', true); -// expect(tf.environment().getBool('WEBGL_PACK_DEPTHWISECONV')).toBe(true); +// tf.env().set('WEBGL_PACK', true); +// expect(tf.env().getBool('WEBGL_PACK_DEPTHWISECONV')).toBe(true); // }); // it('false when WEBGL_PACK is false', () => { -// tf.environment().set('WEBGL_PACK', false); -// expect(tf.environment().getBool('WEBGL_PACK_DEPTHWISECONV')).toBe(false); +// tf.env().set('WEBGL_PACK', false); +// expect(tf.env().getBool('WEBGL_PACK_DEPTHWISECONV')).toBe(false); // }); // }); describe('WEBGL_PACK_BINARY_OPERATIONS', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('true when WEBGL_PACK is true', () => { - tf.environment().set('WEBGL_PACK', true); - expect(tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')).toBe(true); + tf.env().set('WEBGL_PACK', true); + expect(tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - tf.environment().set('WEBGL_PACK', false); - expect(tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS')) - .toBe(false); + tf.env().set('WEBGL_PACK', false); + expect(tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS')).toBe(false); }); }); describe('WEBGL_PACK_ARRAY_OPERATIONS', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('true when WEBGL_PACK is true', () => { - tf.environment().set('WEBGL_PACK', true); - expect(tf.environment().getBool('WEBGL_PACK_ARRAY_OPERATIONS')).toBe(true); + tf.env().set('WEBGL_PACK', true); + expect(tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - tf.environment().set('WEBGL_PACK', false); - expect(tf.environment().getBool('WEBGL_PACK_ARRAY_OPERATIONS')).toBe(false); + tf.env().set('WEBGL_PACK', false); + expect(tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS')).toBe(false); }); }); describe('WEBGL_PACK_IMAGE_OPERATIONS', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('true when WEBGL_PACK is true', () => { - tf.environment().set('WEBGL_PACK', true); - expect(tf.environment().getBool('WEBGL_PACK_IMAGE_OPERATIONS')).toBe(true); + tf.env().set('WEBGL_PACK', true); + expect(tf.env().getBool('WEBGL_PACK_IMAGE_OPERATIONS')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - tf.environment().set('WEBGL_PACK', false); - expect(tf.environment().getBool('WEBGL_PACK_IMAGE_OPERATIONS')).toBe(false); + tf.env().set('WEBGL_PACK', false); + expect(tf.env().getBool('WEBGL_PACK_IMAGE_OPERATIONS')).toBe(false); }); }); describe('WEBGL_PACK_REDUCE', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('true when WEBGL_PACK is true', () => { - tf.environment().set('WEBGL_PACK', true); - expect(tf.environment().getBool('WEBGL_PACK_REDUCE')).toBe(true); + tf.env().set('WEBGL_PACK', true); + expect(tf.env().getBool('WEBGL_PACK_REDUCE')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - tf.environment().set('WEBGL_PACK', false); - expect(tf.environment().getBool('WEBGL_PACK_REDUCE')).toBe(false); + tf.env().set('WEBGL_PACK', false); + expect(tf.env().getBool('WEBGL_PACK_REDUCE')).toBe(false); }); }); describe('WEBGL_LAZILY_UNPACK', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('true when WEBGL_PACK is true', () => { - tf.environment().set('WEBGL_PACK', true); - expect(tf.environment().getBool('WEBGL_LAZILY_UNPACK')).toBe(true); + tf.env().set('WEBGL_PACK', true); + expect(tf.env().getBool('WEBGL_LAZILY_UNPACK')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - tf.environment().set('WEBGL_PACK', false); - expect(tf.environment().getBool('WEBGL_LAZILY_UNPACK')).toBe(false); + tf.env().set('WEBGL_PACK', false); + expect(tf.env().getBool('WEBGL_LAZILY_UNPACK')).toBe(false); }); }); describe('WEBGL_CONV_IM2COL', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('true when WEBGL_PACK is true', () => { - tf.environment().set('WEBGL_PACK', true); - expect(tf.environment().getBool('WEBGL_CONV_IM2COL')).toBe(true); + tf.env().set('WEBGL_PACK', true); + expect(tf.env().getBool('WEBGL_CONV_IM2COL')).toBe(true); }); it('false when WEBGL_PACK is false', () => { - tf.environment().set('WEBGL_PACK', false); - expect(tf.environment().getBool('WEBGL_CONV_IM2COL')).toBe(false); + tf.env().set('WEBGL_PACK', false); + expect(tf.env().getBool('WEBGL_CONV_IM2COL')).toBe(false); }); }); describe('WEBGL_MAX_TEXTURE_SIZE', () => { beforeEach(() => { - tf.environment().reset(); + tf.env().reset(); webgl_util.resetMaxTextureSize(); spyOn(canvas_util, 'getWebGLContext').and.returnValue({ @@ -273,19 +269,19 @@ describe('WEBGL_MAX_TEXTURE_SIZE', () => { }); }); afterAll(() => { - tf.environment().reset(); + tf.env().reset(); webgl_util.resetMaxTextureSize(); }); it('is a function of gl.getParameter(MAX_TEXTURE_SIZE)', () => { - expect(tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE')).toBe(50); + expect(tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE')).toBe(50); }); }); describe('WEBGL_MAX_TEXTURES_IN_SHADER', () => { let maxTextures: number; beforeEach(() => { - tf.environment().reset(); + tf.env().reset(); webgl_util.resetMaxTexturesInShader(); spyOn(canvas_util, 'getWebGLContext').and.callFake(() => { @@ -301,65 +297,61 @@ describe('WEBGL_MAX_TEXTURES_IN_SHADER', () => { }); }); afterAll(() => { - tf.environment().reset(); + tf.env().reset(); webgl_util.resetMaxTexturesInShader(); }); it('is a function of gl.getParameter(MAX_TEXTURE_IMAGE_UNITS)', () => { maxTextures = 10; - expect(tf.environment().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')).toBe(10); + expect(tf.env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')).toBe(10); }); it('is capped at 16', () => { maxTextures = 20; - expect(tf.environment().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')).toBe(16); + expect(tf.env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')).toBe(16); }); }); describe('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('disjoint query timer disabled', () => { - tf.environment().set('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', 0); + tf.env().set('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', 0); - expect(tf.environment().getBool( - 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) + expect(tf.env().getBool('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) .toBe(false); }); it('disjoint query timer enabled, mobile', () => { - tf.environment().set('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', 1); + tf.env().set('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', 1); spyOn(device_util, 'isMobile').and.returnValue(true); - expect(tf.environment().getBool( - 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) + expect(tf.env().getBool('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) .toBe(false); }); it('disjoint query timer enabled, not mobile', () => { - tf.environment().set('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', 1); + tf.env().set('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', 1); spyOn(device_util, 'isMobile').and.returnValue(false); - expect(tf.environment().getBool( - 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) + expect(tf.env().getBool('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) .toBe(true); }); }); describe('WEBGL_SIZE_UPLOAD_UNIFORM', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('is 0 when there is no float32 bit support', () => { - tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', false); - expect(tf.environment().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')).toBe(0); + tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', false); + expect(tf.env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')).toBe(0); }); it('is > 0 when there is float32 bit support', () => { - tf.environment().set('WEBGL_RENDER_FLOAT32_ENABLED', true); - expect(tf.environment().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) - .toBeGreaterThan(0); + tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true); + expect(tf.env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')).toBeGreaterThan(0); }); }); diff --git a/tfjs-core/src/backends/webgl/glsl_version.ts b/tfjs-core/src/backends/webgl/glsl_version.ts index b074c6ce08e..d99e885a31f 100644 --- a/tfjs-core/src/backends/webgl/glsl_version.ts +++ b/tfjs-core/src/backends/webgl/glsl_version.ts @@ -14,7 +14,7 @@ * limitations under the License. * ============================================================================= */ -import {environment} from '../../environment'; +import {env} from '../../environment'; export type GLSL = { version: string, @@ -41,7 +41,7 @@ export function getGlslDifferences(): GLSL { let defineSpecialInf: string; let defineRound: string; - if (environment().getNumber('WEBGL_VERSION') === 2) { + if (env().getNumber('WEBGL_VERSION') === 2) { version = '#version 300 es'; attribute = 'in'; varyingVs = 'out'; diff --git a/tfjs-core/src/backends/webgl/gpgpu_context.ts b/tfjs-core/src/backends/webgl/gpgpu_context.ts index 6e39e875e04..3510ba4b5aa 100644 --- a/tfjs-core/src/backends/webgl/gpgpu_context.ts +++ b/tfjs-core/src/backends/webgl/gpgpu_context.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {environment} from '../../environment'; +import {env} from '../../environment'; import {PixelData, TypedArray} from '../../types'; import * as util from '../../util'; @@ -50,7 +50,7 @@ export class GPGPUContext { private textureConfig: TextureConfig; constructor(gl?: WebGLRenderingContext) { - const glVersion = environment().getNumber('WEBGL_VERSION'); + const glVersion = env().getNumber('WEBGL_VERSION'); if (gl != null) { this.gl = gl; setWebGLContext(glVersion, gl); @@ -58,7 +58,7 @@ export class GPGPUContext { this.gl = getWebGLContext(glVersion); } // WebGL 2.0 enables texture floats without an extension. - if (environment().getNumber('WEBGL_VERSION') === 1) { + if (env().getNumber('WEBGL_VERSION') === 1) { this.textureFloatExtension = webgl_util.getExtensionOrThrow( this.gl, this.debug, 'OES_texture_float'); this.colorBufferFloatExtension = @@ -91,7 +91,7 @@ export class GPGPUContext { } private get debug(): boolean { - return environment().getBool('DEBUG'); + return env().getBool('DEBUG'); } public dispose() { @@ -226,7 +226,7 @@ export class GPGPUContext { let query: WebGLQuery|WebGLSync; let isFencePassed: () => boolean; - if (environment().getBool('WEBGL_FENCE_API_ENABLED')) { + if (env().getBool('WEBGL_FENCE_API_ENABLED')) { const gl2 = gl as WebGL2RenderingContext; const sync = gl2.fenceSync(gl2.SYNC_GPU_COMMANDS_COMPLETE, 0); @@ -240,14 +240,12 @@ export class GPGPUContext { query = sync; } else if ( - environment().getNumber( - 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { + env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { query = this.beginQuery(); this.endQuery(); isFencePassed = () => this.isQueryAvailable( query, - environment().getNumber( - 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); + env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); } else { // If we have no way to fence, return true immediately. This will fire in // WebGL 1.0 when there is no disjoint query timer. In this case, because @@ -411,7 +409,7 @@ export class GPGPUContext { this.disjointQueryTimerExtension = webgl_util.getExtensionOrThrow( this.gl, this.debug, - environment().getNumber( + env().getNumber( 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ? 'EXT_disjoint_timer_query_webgl2' : 'EXT_disjoint_timer_query') as @@ -430,8 +428,7 @@ export class GPGPUContext { } beginQuery(): WebGLQuery { - if (environment().getNumber( - 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { + if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { const gl2 = this.gl as WebGL2RenderingContext; const ext = this.getQueryTimerExtensionWebGL2(); @@ -446,8 +443,7 @@ export class GPGPUContext { } endQuery() { - if (environment().getNumber( - 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { + if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { const gl2 = this.gl as WebGL2RenderingContext; const ext = this.getQueryTimerExtensionWebGL2(); gl2.endQuery(ext.TIME_ELAPSED_EXT); @@ -464,12 +460,10 @@ export class GPGPUContext { // may poll for the query timer indefinitely this.isQueryAvailable( query, - environment().getNumber( + env().getNumber( 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'))); return this.getQueryTime( - query, - environment().getNumber( - 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); + query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); } private getQueryTime(query: WebGLQuery, queryTimerVersion: number): number { diff --git a/tfjs-core/src/backends/webgl/gpgpu_context_test.ts b/tfjs-core/src/backends/webgl/gpgpu_context_test.ts index 580d097d60f..4e4b0445d1a 100644 --- a/tfjs-core/src/backends/webgl/gpgpu_context_test.ts +++ b/tfjs-core/src/backends/webgl/gpgpu_context_test.ts @@ -37,7 +37,7 @@ describeWithFlags( gpgpu = new GPGPUContext(); // Silences debug warnings. spyOn(console, 'warn'); - tf.environment().set('DEBUG', true); + tf.env().set('DEBUG', true); texture = gpgpu.createFloat32MatrixTexture(1, 1); }); @@ -71,7 +71,7 @@ describeWithFlags( gpgpu = new GPGPUContext(); // Silences debug warnings. spyOn(console, 'warn'); - tf.environment().set('DEBUG', true); + tf.env().set('DEBUG', true); }); afterEach(() => { @@ -109,7 +109,7 @@ describeWithFlags( gpgpu = new GPGPUContext(); // Silences debug warnings. spyOn(console, 'warn'); - tf.environment().set('DEBUG', true); + tf.env().set('DEBUG', true); const glsl = getGlslDifferences(); const src = `${glsl.version} precision highp float; @@ -148,7 +148,7 @@ describeWithFlags('GPGPUContext', DOWNLOAD_FLOAT_ENVS, () => { gpgpu = new GPGPUContext(); // Silences debug warnings. spyOn(console, 'warn'); - tf.environment().set('DEBUG', true); + tf.env().set('DEBUG', true); }); afterEach(() => { diff --git a/tfjs-core/src/backends/webgl/gpgpu_math.ts b/tfjs-core/src/backends/webgl/gpgpu_math.ts index 71cfa384639..a67f4d399d5 100644 --- a/tfjs-core/src/backends/webgl/gpgpu_math.ts +++ b/tfjs-core/src/backends/webgl/gpgpu_math.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {environment} from '../../environment'; +import {env} from '../../environment'; import {Tensor} from '../../tensor'; import {TypedArray} from '../../types'; @@ -86,7 +86,7 @@ export function compileProgram( // Add special uniforms (NAN, INFINITY) let infLoc: WebGLUniformLocation = null; const nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false); - if (environment().getNumber('WEBGL_VERSION') === 1) { + if (env().getNumber('WEBGL_VERSION') === 1) { infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false); } @@ -164,7 +164,7 @@ export function runProgram( gpgpu.setProgram(binary.webGLProgram); // Set special uniforms (NAN, INFINITY) - if (environment().getNumber('WEBGL_VERSION') === 1) { + if (env().getNumber('WEBGL_VERSION') === 1) { if (binary.infLoc !== null) { gpgpu.gl.uniform1f(binary.infLoc, Infinity); } diff --git a/tfjs-core/src/backends/webgl/reshape_packed_test.ts b/tfjs-core/src/backends/webgl/reshape_packed_test.ts index 2645461f185..660da0e500e 100644 --- a/tfjs-core/src/backends/webgl/reshape_packed_test.ts +++ b/tfjs-core/src/backends/webgl/reshape_packed_test.ts @@ -65,14 +65,14 @@ describeWithFlags('expensive reshape', PACKED_ENVS, () => { describeWithFlags('expensive reshape with even columns', PACKED_ENVS, () => { it('2 --> 4 columns', async () => { - const maxTextureSize = tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + const maxTextureSize = tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); let values: number[] = new Array(16).fill(0); values = values.map((d, i) => i + 1); const a = tf.tensor2d(values, [8, 2]); const b = tf.tensor2d([1, 2, 3, 4], [2, 2]); - tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', 2); + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', 2); // Setting WEBGL_MAX_TEXTURE_SIZE to 2 makes that [8, 2] tensor is packed // to texture of width 2 by height 2. Indices are packed as: // ------------- @@ -82,14 +82,14 @@ describeWithFlags('expensive reshape with even columns', PACKED_ENVS, () => { // ... const c = tf.matMul(a, b); let cAs4D = c.reshape([2, 1, 2, 4]); - tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); // Execute non-packed operations to unpack tensor. - const webglPackFlagSaved = tf.environment().getBool('WEBGL_PACK'); - tf.environment().set('WEBGL_PACK', false); + const webglPackFlagSaved = tf.env().getBool('WEBGL_PACK'); + tf.env().set('WEBGL_PACK', false); cAs4D = cAs4D.add(1); cAs4D = cAs4D.add(-1); - tf.environment().set('WEBGL_PACK', webglPackFlagSaved); + tf.env().set('WEBGL_PACK', webglPackFlagSaved); const result = [7, 10, 15, 22, 23, 34, 31, 46, 39, 58, 47, 70, 55, 82, 63, 94]; diff --git a/tfjs-core/src/backends/webgl/tex_util.ts b/tfjs-core/src/backends/webgl/tex_util.ts index e9a43e40607..08f02153ab5 100644 --- a/tfjs-core/src/backends/webgl/tex_util.ts +++ b/tfjs-core/src/backends/webgl/tex_util.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {environment} from '../../environment'; +import {env} from '../../environment'; import {DataId, Tensor} from '../../tensor'; import {BackendValues, DataType} from '../../types'; @@ -160,7 +160,7 @@ export function getTextureConfig( let textureTypeHalfFloat: number; let textureTypeFloat: number; - if (environment().getNumber('WEBGL_VERSION') === 2) { + if (env().getNumber('WEBGL_VERSION') === 2) { internalFormatFloat = glany.R32F; internalFormatHalfFloat = glany.R16F; internalFormatPackedHalfFloat = glany.RGBA16F; diff --git a/tfjs-core/src/backends/webgl/texture_manager.ts b/tfjs-core/src/backends/webgl/texture_manager.ts index 2512957fb27..e2377a955ba 100644 --- a/tfjs-core/src/backends/webgl/texture_manager.ts +++ b/tfjs-core/src/backends/webgl/texture_manager.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {environment} from '../../environment'; +import {env} from '../../environment'; import {GPGPUContext} from './gpgpu_context'; import {PhysicalTextureType, TextureUsage} from './tex_util'; @@ -145,7 +145,7 @@ export class TextureManager { function getPhysicalTextureForRendering(isPacked: boolean): PhysicalTextureType { - if (environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) { + if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) { if (isPacked) { return PhysicalTextureType.PACKED_2X2_FLOAT32; } diff --git a/tfjs-core/src/backends/webgl/webgl_batchnorm_test.ts b/tfjs-core/src/backends/webgl/webgl_batchnorm_test.ts index 078999b46d5..e49eaa2a23d 100644 --- a/tfjs-core/src/backends/webgl/webgl_batchnorm_test.ts +++ b/tfjs-core/src/backends/webgl/webgl_batchnorm_test.ts @@ -32,8 +32,8 @@ describeWithFlags('batchNorm', WEBGL_ENVS, () => { }); it('should work when squarification results in zero padding', async () => { - const maxTextureSize = tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); - tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', 5); + const maxTextureSize = tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', 5); const x = tf.tensor3d( [ @@ -52,7 +52,7 @@ describeWithFlags('batchNorm', WEBGL_ENVS, () => { const result = tf.batchNorm3d(x, mean, variance, offset, scale, varianceEpsilon); - tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); expectArraysClose(await result.data(), [ 0.59352049, -0.66135202, 0.5610874, -0.92077015, -1.45341019, 1.52106473, diff --git a/tfjs-core/src/backends/webgl/webgl_ops_test.ts b/tfjs-core/src/backends/webgl/webgl_ops_test.ts index 673477dd0c2..399ce2b0e9c 100644 --- a/tfjs-core/src/backends/webgl/webgl_ops_test.ts +++ b/tfjs-core/src/backends/webgl/webgl_ops_test.ts @@ -281,15 +281,15 @@ describeWithFlags('depthToSpace', WEBGL_ENVS, () => { describeWithFlags('maximum', WEBGL_ENVS, () => { it('works with squarification for large dimension', async () => { - const maxTextureSize = tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); - tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', 5); + const maxTextureSize = tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', 5); const a = tf.tensor2d([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [2, 7]); const b = tf.tensor2d([-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 7]); const result = tf.maximum(a, b); - tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); expectArraysClose( await result.data(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]); }); @@ -334,19 +334,19 @@ describeWithFlags('conv2d webgl', WEBGL_ENVS, () => { const x = tf.tensor3d([1, 2, 3, 4], inputShape); const w = tf.tensor4d([1, 2, 3, 4], [fSize, fSize, 2, 2]); - const webglLazilyUnpackFlagSaved = tf.environment().getBool('WEBGL_LAZILY_UNPACK'); - tf.environment().set('WEBGL_LAZILY_UNPACK', true); + const webglLazilyUnpackFlagSaved = tf.env().getBool('WEBGL_LAZILY_UNPACK'); + tf.env().set('WEBGL_LAZILY_UNPACK', true); const webglPackBinaryOperationsFlagSaved = - tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', true); + tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.env().set('WEBGL_PACK_BINARY_OPERATIONS', true); // First conv2D tests conv2D with non-packed input |x|, and the second uses // packed input |result|. const result = tf.conv2d(x, w, stride, pad); const result1 = tf.conv2d(result, w, stride, pad); - tf.environment().set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); - tf.environment().set( + tf.env().set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); + tf.env().set( 'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsFlagSaved); expectArraysClose(await result.data(), [7, 10, 15, 22]); @@ -363,17 +363,18 @@ describeWithFlags('conv2d webgl', WEBGL_ENVS, () => { const xInit = tf.tensor4d([0, 1], inputShape); const w = tf.tensor4d([1, 2, 3, 4], [fSize, fSize, 2, 2]); - const webglLazilyUnpackFlagSaved = tf.environment().getBool('WEBGL_LAZILY_UNPACK'); - tf.environment().set('WEBGL_LAZILY_UNPACK', true); + const webglLazilyUnpackFlagSaved = + tf.env().getBool('WEBGL_LAZILY_UNPACK'); + tf.env().set('WEBGL_LAZILY_UNPACK', true); const webglPackBinaryOperationsFlagSaved = - tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', true); + tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.env().set('WEBGL_PACK_BINARY_OPERATIONS', true); const x = xInit.add(1); const result = tf.conv2d(x, w, stride, pad); - tf.environment().set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); - tf.environment().set( + tf.env().set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); + tf.env().set( 'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsFlagSaved); expectArraysClose(await result.data(), [7, 10]); @@ -504,28 +505,28 @@ describeWithFlags('matmul', PACKED_ENVS, () => { }); it('should work when input texture shapes != physical shape', async () => { - const maxTextureSize = tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); - tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', 5); + const maxTextureSize = tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', 5); const a = tf.tensor2d([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [1, 12]); const b = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [12, 1]); const c = tf.matMul(a, b); - tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); expectArraysClose(await c.data(), [572]); }); it('should work when squarification results in zero padding', async () => { - const maxTextureSize = tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); - tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', 3); + const maxTextureSize = tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', 3); const a = tf.tensor2d([1, 2], [1, 2]); const b = tf.tensor2d( [[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16, 17]]); const c = tf.matMul(a, b); - tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); expectArraysClose(await c.data(), [18, 21, 24, 27, 30, 33, 36, 39, 42]); }); @@ -582,10 +583,11 @@ describeWithFlags('matmul', PACKED_ENVS, () => { const c = tf.matMul(a, b); - const webglPackBinarySaved = tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', false); + const webglPackBinarySaved = + tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.env().set('WEBGL_PACK_BINARY_OPERATIONS', false); const d = tf.add(c, 1); - tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', webglPackBinarySaved); + tf.env().set('WEBGL_PACK_BINARY_OPERATIONS', webglPackBinarySaved); expectArraysClose(await d.data(), [1, 9, -2, 21]); }); @@ -600,10 +602,10 @@ describeWithFlags('matmul', PACKED_ENVS, () => { const d = tf.reshape(c, [1, 3, 3, 1]); const webglPackBinarySaved = - tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', false); + tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.env().set('WEBGL_PACK_BINARY_OPERATIONS', false); const e = tf.add(d, 1); - tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', webglPackBinarySaved); + tf.env().set('WEBGL_PACK_BINARY_OPERATIONS', webglPackBinarySaved); expectArraysClose(await e.data(), [2, 3, 4, 5, 6, 7, 8, 9, 10]); }); @@ -621,16 +623,16 @@ describeWithFlags('matmul', PACKED_ENVS, () => { describeWithFlags('Reduction: webgl packed input', WEBGL_ENVS, () => { it('argmax 3D, odd number of rows, axis = -1', async () => { - const webglLazilyUnpackFlagSaved = tf.environment().getBool('WEBGL_LAZILY_UNPACK'); - tf.environment().set('WEBGL_LAZILY_UNPACK', true); + const webglLazilyUnpackFlagSaved = tf.env().getBool('WEBGL_LAZILY_UNPACK'); + tf.env().set('WEBGL_LAZILY_UNPACK', true); const webglPackBinaryOperationsFlagSaved = - tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', true); + tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.env().set('WEBGL_PACK_BINARY_OPERATIONS', true); const a = tf.tensor3d([3, 2, 5, 100, -7, 2], [2, 1, 3]).add(1); const r = tf.argMax(a, -1); - tf.environment().set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); - tf.environment().set( + tf.env().set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); + tf.env().set( 'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsFlagSaved); expect(r.dtype).toBe('int32'); @@ -638,11 +640,11 @@ describeWithFlags('Reduction: webgl packed input', WEBGL_ENVS, () => { }); it('argmin 4D, odd number of rows, axis = -1', async () => { - const webglLazilyUnpackFlagSaved = tf.environment().getBool('WEBGL_LAZILY_UNPACK'); - tf.environment().set('WEBGL_LAZILY_UNPACK', true); + const webglLazilyUnpackFlagSaved = tf.env().getBool('WEBGL_LAZILY_UNPACK'); + tf.env().set('WEBGL_LAZILY_UNPACK', true); const webglPackBinaryOperationsFlagSaved = - tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', true); + tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.env().set('WEBGL_PACK_BINARY_OPERATIONS', true); const a = tf.tensor4d( @@ -650,8 +652,8 @@ describeWithFlags('Reduction: webgl packed input', WEBGL_ENVS, () => { [1, 2, 3, 3]) .add(1); const r = tf.argMin(a, -1); - tf.environment().set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); - tf.environment().set( + tf.env().set('WEBGL_LAZILY_UNPACK', webglLazilyUnpackFlagSaved); + tf.env().set( 'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsFlagSaved); expect(r.dtype).toBe('int32'); @@ -660,8 +662,8 @@ describeWithFlags('Reduction: webgl packed input', WEBGL_ENVS, () => { it('should not leak memory when called after unpacked op', async () => { const webglPackBinaryOperationsFlagSaved = - tf.environment().getBool('WEBGL_PACK_BINARY_OPERATIONS'); - tf.environment().set('WEBGL_PACK_BINARY_OPERATIONS', false); + tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + tf.env().set('WEBGL_PACK_BINARY_OPERATIONS', false); const a = tf.tensor5d( @@ -671,7 +673,7 @@ describeWithFlags('Reduction: webgl packed input', WEBGL_ENVS, () => { const startNumBytes = tf.memory().numBytes; const startNumTensors = tf.memory().numTensors; const r = tf.argMin(a, -1); - tf.environment().set( + tf.env().set( 'WEBGL_PACK_BINARY_OPERATIONS', webglPackBinaryOperationsFlagSaved); const endNumBytes = tf.memory().numBytes; const endNumTensors = tf.memory().numTensors; @@ -684,8 +686,8 @@ describeWithFlags('Reduction: webgl packed input', WEBGL_ENVS, () => { describeWithFlags('slice and memory usage', WEBGL_ENVS, () => { beforeAll(() => { - tf.environment().set('WEBGL_CPU_FORWARD', false); - tf.environment().set('WEBGL_SIZE_UPLOAD_UNIFORM', 0); + tf.env().set('WEBGL_CPU_FORWARD', false); + tf.env().set('WEBGL_SIZE_UPLOAD_UNIFORM', 0); }); it('slice a tensor, read it and check memory', async () => { @@ -722,7 +724,7 @@ describeWithFlags('slice and memory usage', WEBGL_ENVS, () => { describeWithFlags('slice a packed texture', WEBGL_ENVS, () => { beforeAll(() => { - tf.environment().set('WEBGL_PACK', true); + tf.env().set('WEBGL_PACK', true); }); it('slice after a matmul', async () => { @@ -741,12 +743,12 @@ describeWithFlags('slice a packed texture', WEBGL_ENVS, () => { describeWithFlags('relu', WEBGL_ENVS, () => { it('works with squarification for prime number length vector', async () => { - const maxTextureSize = tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); - tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', 5); + const maxTextureSize = tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', 5); const a = tf.tensor1d([1, -2, 5, -3, -1, 4, 7]); const result = tf.relu(a); - tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', maxTextureSize); expectArraysClose(await result.data(), [1, 0, 5, 0, 0, 4, 7]); }); }); diff --git a/tfjs-core/src/backends/webgl/webgl_util.ts b/tfjs-core/src/backends/webgl/webgl_util.ts index 2c51b052ea9..ad3b11acca0 100644 --- a/tfjs-core/src/backends/webgl/webgl_util.ts +++ b/tfjs-core/src/backends/webgl/webgl_util.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {environment} from '../../environment'; +import {env} from '../../environment'; import * as util from '../../util'; @@ -43,7 +43,7 @@ const MIN_FLOAT16 = 5.96e-8; const MAX_FLOAT16 = 65504; export function canBeRepresented(num: number): boolean { - if (environment().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 || + if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 || (MIN_FLOAT16 < Math.abs(num) && Math.abs(num) < MAX_FLOAT16)) { return true; } @@ -193,7 +193,7 @@ export function createStaticIndexBuffer( } export function getNumChannels(): number { - if (environment().getNumber('WEBGL_VERSION') === 2) { + if (env().getNumber('WEBGL_VERSION') === 2) { return 1; } return 4; @@ -206,7 +206,7 @@ export function createTexture( } export function validateTextureSize(width: number, height: number) { - const maxTextureSize = environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + const maxTextureSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); if ((width <= 0) || (height <= 0)) { const requested = `[${width}x${height}]`; throw new Error('Requested texture size ' + requested + ' is invalid.'); @@ -385,7 +385,7 @@ export function getShapeAs3D(shape: number[]): [number, number, number] { export function getTextureShapeFromLogicalShape( logShape: number[], isPacked = false): [number, number] { - let maxTexSize = environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + let maxTexSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); if (isPacked) { maxTexSize = maxTexSize * 2; diff --git a/tfjs-core/src/backends/webgl/webgl_util_test.ts b/tfjs-core/src/backends/webgl/webgl_util_test.ts index 893245b0689..9d0e8dd1c5f 100644 --- a/tfjs-core/src/backends/webgl/webgl_util_test.ts +++ b/tfjs-core/src/backends/webgl/webgl_util_test.ts @@ -85,8 +85,7 @@ describeWithFlags('getTextureShapeFromLogicalShape packed', WEBGL_ENVS, () => { const isPacked = true; const logicalShape = [ 2, - util.nearestLargerEven( - tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE') + 1) + util.nearestLargerEven(tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE') + 1) ]; const texShape = webgl_util.getTextureShapeFromLogicalShape(logicalShape, isPacked); @@ -103,14 +102,14 @@ describeWithFlags('getTextureShapeFromLogicalShape packed', WEBGL_ENVS, () => { it('squarified texture shapes account for packing constraints', () => { const isPacked = true; - const max = tf.environment().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + const max = tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); - tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', 5); + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', 5); const logicalShape = [1, 12]; const texShape = webgl_util.getTextureShapeFromLogicalShape(logicalShape, isPacked); - tf.environment().set('WEBGL_MAX_TEXTURE_SIZE', max); + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', max); expect(texShape).toEqual([6, 4]); }); }); diff --git a/tfjs-core/src/debug_mode_test.ts b/tfjs-core/src/debug_mode_test.ts index 5b9fdf6ea06..970c7680777 100644 --- a/tfjs-core/src/debug_mode_test.ts +++ b/tfjs-core/src/debug_mode_test.ts @@ -22,11 +22,11 @@ import {expectArraysClose} from './test_util'; describeWithFlags('debug on', SYNC_BACKEND_ENVS, () => { beforeAll(() => { - tf.environment().set('DEBUG', true); + tf.env().set('DEBUG', true); }); afterAll(() => { - tf.environment().set('DEBUG', false); + tf.env().set('DEBUG', false); }); it('debug mode does not error when no nans', async () => { @@ -117,7 +117,7 @@ describeWithFlags('debug on', SYNC_BACKEND_ENVS, () => { describeWithFlags('debug off', ALL_ENVS, () => { beforeAll(() => { - tf.environment().set('DEBUG', false); + tf.env().set('DEBUG', false); }); it('no errors where there are nans, and debug mode is disabled', async () => { diff --git a/tfjs-core/src/environment.ts b/tfjs-core/src/environment.ts index a926599519f..abd8f993877 100644 --- a/tfjs-core/src/environment.ts +++ b/tfjs-core/src/environment.ts @@ -29,6 +29,12 @@ export type FlagRegistryEntry = { setHook?: (value: FlagValue) => void; }; +/** + * The environment contains evaluated flags as well as the registered platform. + * This is always used as a global singleton and can be retrieved with + * `tf.env()`. + */ +/** @doc {heading: 'Environment'} */ export class Environment { private flags: Flags = {}; private flagRegistry: {[flagName: string]: FlagRegistryEntry} = {}; @@ -166,7 +172,14 @@ function parseValue(flagName: string, value: string): FlagValue { `Could not parse value flag value ${value} for flag ${flagName}.`); } -export function environment() { +/** + * Returns the current environment (a global singleton). + * + * The environment object contains the evaluated feature values as well as the + * active platform. + */ +/** @doc {heading: 'Environment'} */ +export function env() { return ENV; } diff --git a/tfjs-core/src/flags.ts b/tfjs-core/src/flags.ts index 99af479a6ee..47fa9f0ada7 100644 --- a/tfjs-core/src/flags.ts +++ b/tfjs-core/src/flags.ts @@ -15,9 +15,9 @@ * ============================================================================= */ import * as device_util from './device_util'; -import {environment} from './environment'; +import {env} from './environment'; -const ENV = environment(); +const ENV = env(); /** * This file contains environment-related flag registrations. diff --git a/tfjs-core/src/flags_test.ts b/tfjs-core/src/flags_test.ts index 47964f13929..bd494fde400 100644 --- a/tfjs-core/src/flags_test.ts +++ b/tfjs-core/src/flags_test.ts @@ -20,24 +20,24 @@ import * as tf from './index'; describe('DEBUG', () => { beforeEach(() => { - tf.environment().reset(); + tf.env().reset(); spyOn(console, 'warn').and.callFake((msg: string) => {}); }); - afterAll(() => tf.environment().reset()); + afterAll(() => tf.env().reset()); it('disabled by default', () => { - expect(tf.environment().getBool('DEBUG')).toBe(false); + expect(tf.env().getBool('DEBUG')).toBe(false); }); it('warns when enabled', () => { const consoleWarnSpy = console.warn as jasmine.Spy; - tf.environment().set('DEBUG', true); + tf.env().set('DEBUG', true); expect(consoleWarnSpy.calls.count()).toBe(1); expect((consoleWarnSpy.calls.first().args[0] as string) .startsWith('Debugging mode is ON. ')) .toBe(true); - expect(tf.environment().getBool('DEBUG')).toBe(true); + expect(tf.env().getBool('DEBUG')).toBe(true); expect(consoleWarnSpy.calls.count()).toBe(1); }); }); @@ -45,62 +45,60 @@ describe('DEBUG', () => { describe('IS_BROWSER', () => { let isBrowser: boolean; beforeEach(() => { - tf.environment().reset(); + tf.env().reset(); spyOn(device_util, 'isBrowser').and.callFake(() => isBrowser); }); - afterAll(() => tf.environment().reset()); + afterAll(() => tf.env().reset()); it('isBrowser: true', () => { isBrowser = true; - expect(tf.environment().getBool('IS_BROWSER')).toBe(true); + expect(tf.env().getBool('IS_BROWSER')).toBe(true); }); it('isBrowser: false', () => { isBrowser = false; - expect(tf.environment().getBool('IS_BROWSER')).toBe(false); + expect(tf.env().getBool('IS_BROWSER')).toBe(false); }); }); describe('PROD', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('disabled by default', () => { - expect(tf.environment().getBool('PROD')).toBe(false); + expect(tf.env().getBool('PROD')).toBe(false); }); }); describe('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('disabled when debug is disabled', () => { - tf.environment().set('DEBUG', false); - expect(tf.environment().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) - .toBe(false); + tf.env().set('DEBUG', false); + expect(tf.env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')).toBe(false); }); it('enabled when debug is enabled', () => { - tf.environment().set('DEBUG', true); - expect(tf.environment().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) - .toBe(true); + tf.env().set('DEBUG', true); + expect(tf.env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')).toBe(true); }); }); describe('DEPRECATION_WARNINGS_ENABLED', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('enabled by default', () => { - expect(tf.environment().getBool('DEPRECATION_WARNINGS_ENABLED')).toBe(true); + expect(tf.env().getBool('DEPRECATION_WARNINGS_ENABLED')).toBe(true); }); }); describe('IS_TEST', () => { - beforeEach(() => tf.environment().reset()); - afterAll(() => tf.environment().reset()); + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); it('disabled by default', () => { - expect(tf.environment().getBool('IS_TEST')).toBe(false); + expect(tf.env().getBool('IS_TEST')).toBe(false); }); }); diff --git a/tfjs-core/src/globals.ts b/tfjs-core/src/globals.ts index 207b6c9c173..352d891e52c 100644 --- a/tfjs-core/src/globals.ts +++ b/tfjs-core/src/globals.ts @@ -17,7 +17,7 @@ import {KernelBackend} from './backends/backend'; import {ENGINE, Engine, MemoryInfo, ProfileInfo, ScopeFn, TimingInfo} from './engine'; -import {environment} from './environment'; +import {env} from './environment'; import {Platform} from './platforms/platform'; import {setDeprecationWarningFn, Tensor} from './tensor'; @@ -30,7 +30,7 @@ import {getTensorsInContainer} from './tensor_util'; */ /** @doc {heading: 'Environment'} */ export function enableProdMode(): void { - environment().set('PROD', true); + env().set('PROD', true); } /** @@ -47,18 +47,18 @@ export function enableProdMode(): void { */ /** @doc {heading: 'Environment'} */ export function enableDebugMode(): void { - environment().set('DEBUG', true); + env().set('DEBUG', true); } /** Globally disables deprecation warnings */ export function disableDeprecationWarnings(): void { - environment().set('DEPRECATION_WARNINGS_ENABLED', false); + env().set('DEPRECATION_WARNINGS_ENABLED', false); console.warn(`TensorFlow.js deprecation warnings have been disabled.`); } /** Warn users about deprecated functionality. */ export function deprecationWarn(msg: string) { - if (environment().getBool('DEPRECATION_WARNINGS_ENABLED')) { + if (env().getBool('DEPRECATION_WARNINGS_ENABLED')) { console.warn( msg + ' You can disable deprecation warnings with ' + 'tf.disableDeprecationWarnings().'); @@ -355,5 +355,5 @@ export function backend(): KernelBackend { * @param platform A platform implementation. */ export function setPlatform(platformName: string, platform: Platform) { - environment().setPlatform(platformName, platform); + env().setPlatform(platformName, platform); } diff --git a/tfjs-core/src/globals_test.ts b/tfjs-core/src/globals_test.ts index dfa67ff9b69..53b8770b3dd 100644 --- a/tfjs-core/src/globals_test.ts +++ b/tfjs-core/src/globals_test.ts @@ -48,17 +48,17 @@ describe('deprecation warnings', () => { describe('Flag flipping methods', () => { beforeEach(() => { - tf.environment().reset(); + tf.env().reset(); }); it('tf.enableProdMode', () => { tf.enableProdMode(); - expect(tf.environment().getBool('PROD')).toBe(true); + expect(tf.env().getBool('PROD')).toBe(true); }); it('tf.enableDebugMode', () => { tf.enableDebugMode(); - expect(tf.environment().getBool('DEBUG')).toBe(true); + expect(tf.env().getBool('DEBUG')).toBe(true); }); }); diff --git a/tfjs-core/src/index.ts b/tfjs-core/src/index.ts index c39af44d3f5..2aadfaca3eb 100644 --- a/tfjs-core/src/index.ts +++ b/tfjs-core/src/index.ts @@ -65,7 +65,7 @@ export * from './globals'; export {customGrad, grad, grads, valueAndGrad, valueAndGrads, variableGrads} from './gradients'; export {TimingInfo, MemoryInfo} from './engine'; -export {Environment, environment, ENV} from './environment'; +export {Environment, env, ENV} from './environment'; export {Platform} from './platforms/platform'; export {version as version_core}; diff --git a/tfjs-core/src/io/browser_files.ts b/tfjs-core/src/io/browser_files.ts index a5d6dbcd383..9c3facb1da5 100644 --- a/tfjs-core/src/io/browser_files.ts +++ b/tfjs-core/src/io/browser_files.ts @@ -20,7 +20,7 @@ * user-selected files in browser. */ -import {environment} from '../environment'; +import {env} from '../environment'; import {basename, concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils'; import {IORouter, IORouterRegistry} from './router_registry'; @@ -43,7 +43,7 @@ export class BrowserDownloads implements IOHandler { static readonly URL_SCHEME = 'downloads://'; constructor(fileNamePrefix?: string) { - if (!environment().getBool('IS_BROWSER')) { + if (!env().getBool('IS_BROWSER')) { // TODO(cais): Provide info on what IOHandlers are available under the // current environment. throw new Error( @@ -245,7 +245,7 @@ class BrowserFiles implements IOHandler { } export const browserDownloadsRouter: IORouter = (url: string|string[]) => { - if (!environment().getBool('IS_BROWSER')) { + if (!env().getBool('IS_BROWSER')) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) { diff --git a/tfjs-core/src/io/http.ts b/tfjs-core/src/io/http.ts index 8266d5650e2..5b5b3646c1d 100644 --- a/tfjs-core/src/io/http.ts +++ b/tfjs-core/src/io/http.ts @@ -21,7 +21,7 @@ * Uses [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API). */ -import {environment} from '../environment'; +import {env} from '../environment'; import {assert} from '../util'; import {concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils'; @@ -59,7 +59,7 @@ export class HTTPRequest implements IOHandler { 'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)'); this.fetch = loadOptions.fetchFunc; } else { - this.fetch = environment().platform.fetch; + this.fetch = env().platform.fetch; } assert( diff --git a/tfjs-core/src/io/http_test.ts b/tfjs-core/src/io/http_test.ts index 6d7c051896e..137e98d2628 100644 --- a/tfjs-core/src/io/http_test.ts +++ b/tfjs-core/src/io/http_test.ts @@ -81,7 +81,7 @@ const setupFakeWeightFiles = } }, requestInits: {[key: string]: RequestInit}) => { - fetchSpy = spyOn(tf.environment().platform, 'fetch') + fetchSpy = spyOn(tf.env().platform, 'fetch') .and.callFake((path: string, init: RequestInit) => { if (fileBufferMap[path]) { requestInits[path] = init; @@ -191,7 +191,7 @@ describeWithFlags('http-save', CHROME_ENVS, () => { beforeEach(() => { requestInits = []; - spyOn(tf.environment().platform, 'fetch') + spyOn(tf.env().platform, 'fetch') .and.callFake((path: string, init: RequestInit) => { if (path === 'model-upload-test' || path === 'http://model-upload-test') { diff --git a/tfjs-core/src/io/indexed_db.ts b/tfjs-core/src/io/indexed_db.ts index 6e4936931c5..2ce5a8ce0f8 100644 --- a/tfjs-core/src/io/indexed_db.ts +++ b/tfjs-core/src/io/indexed_db.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {environment} from '../environment'; +import {env} from '../environment'; import {getModelArtifactsInfoForJSON} from './io_utils'; import {ModelStoreManagerRegistry} from './model_management'; @@ -48,7 +48,7 @@ export async function deleteDatabase(): Promise { } function getIndexedDBFactory(): IDBFactory { - if (!environment().getBool('IS_BROWSER')) { + if (!env().getBool('IS_BROWSER')) { // TODO(cais): Add more info about what IOHandler subtypes are available. // Maybe point to a doc page on the web and/or automatically determine // the available IOHandlers and print them in the error message. @@ -208,7 +208,7 @@ export class BrowserIndexedDB implements IOHandler { } export const indexedDBRouter: IORouter = (url: string|string[]) => { - if (!environment().getBool('IS_BROWSER')) { + if (!env().getBool('IS_BROWSER')) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) { @@ -352,7 +352,7 @@ export class BrowserIndexedDBManager implements ModelStoreManager { } } -if (environment().getBool('IS_BROWSER')) { +if (env().getBool('IS_BROWSER')) { // Wrap the construction and registration, to guard against browsers that // don't support Local Storage. try { diff --git a/tfjs-core/src/io/local_storage.ts b/tfjs-core/src/io/local_storage.ts index fb4ce766f18..3cbfa809e03 100644 --- a/tfjs-core/src/io/local_storage.ts +++ b/tfjs-core/src/io/local_storage.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {environment} from '../environment'; +import {env} from '../environment'; import {assert} from '../util'; import {arrayBufferToBase64String, base64StringToArrayBuffer, getModelArtifactsInfoForJSON} from './io_utils'; @@ -37,7 +37,7 @@ const MODEL_METADATA_SUFFIX = 'model_metadata'; * @returns Paths of the models purged. */ export function purgeLocalStorageArtifacts(): string[] { - if (!environment().getBool('IS_BROWSER') || + if (!env().getBool('IS_BROWSER') || typeof window.localStorage === 'undefined') { throw new Error( 'purgeLocalStorageModels() cannot proceed because local storage is ' + @@ -118,7 +118,7 @@ export class BrowserLocalStorage implements IOHandler { static readonly URL_SCHEME = 'localstorage://'; constructor(modelPath: string) { - if (!environment().getBool('IS_BROWSER') || + if (!env().getBool('IS_BROWSER') || typeof window.localStorage === 'undefined') { // TODO(cais): Add more info about what IOHandler subtypes are // available. @@ -256,7 +256,7 @@ export class BrowserLocalStorage implements IOHandler { } export const localStorageRouter: IORouter = (url: string|string[]) => { - if (!environment().getBool('IS_BROWSER')) { + if (!env().getBool('IS_BROWSER')) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) { @@ -303,7 +303,7 @@ export class BrowserLocalStorageManager implements ModelStoreManager { constructor() { assert( - environment().getBool('IS_BROWSER'), + env().getBool('IS_BROWSER'), () => 'Current environment is not a web browser'); assert( typeof window.localStorage !== 'undefined', @@ -341,7 +341,7 @@ export class BrowserLocalStorageManager implements ModelStoreManager { } } -if (environment().getBool('IS_BROWSER')) { +if (env().getBool('IS_BROWSER')) { // Wrap the construction and registration, to guard against browsers that // don't support Local Storage. try { diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index 51c61cab883..8b0f923fd8a 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {environment} from '../environment'; +import {env} from '../environment'; import {NamedTensorMap} from '../tensor_types'; import * as util from '../util'; @@ -40,9 +40,8 @@ export async function loadWeightsAsArrayBuffer( loadOptions = {}; } - const fetchFunc = loadOptions.fetchFunc == null ? - environment().platform.fetch : - loadOptions.fetchFunc; + const fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch : + loadOptions.fetchFunc; // Create the requests for all of the weights in parallel. const requests = fetchURLs.map( diff --git a/tfjs-core/src/io/weights_loader_test.ts b/tfjs-core/src/io/weights_loader_test.ts index ed268377f2e..a4394adcdb2 100644 --- a/tfjs-core/src/io/weights_loader_test.ts +++ b/tfjs-core/src/io/weights_loader_test.ts @@ -24,7 +24,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { [filename: string]: Float32Array|Int32Array|ArrayBuffer|Uint8Array| Uint16Array }) => { - spyOn(tf.environment().platform, 'fetch').and.callFake((path: string) => { + spyOn(tf.env().platform, 'fetch').and.callFake((path: string) => { return new Response( fileBufferMap[path], {headers: {'Content-type': 'application/octet-stream'}}); @@ -42,8 +42,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { const weightsNamesToFetch = ['weight0']; const weights = await tf.io.loadWeights(manifest, './', weightsNamesToFetch); - expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) - .toBe(1); + expect((tf.env().platform.fetch as jasmine.Spy).calls.count()).toBe(1); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(weightsNamesToFetch.length); @@ -67,8 +66,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { // Load the first weight. const weights = await tf.io.loadWeights(manifest, './', ['weight0']); - expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) - .toBe(1); + expect((tf.env().platform.fetch as jasmine.Spy).calls.count()).toBe(1); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(1); @@ -92,8 +90,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { // Load the second weight. const weights = await tf.io.loadWeights(manifest, './', ['weight1']); - expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) - .toBe(1); + expect((tf.env().platform.fetch as jasmine.Spy).calls.count()).toBe(1); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(1); @@ -118,8 +115,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { // Load all weights. const weights = await tf.io.loadWeights(manifest, './', ['weight0', 'weight1']); - expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) - .toBe(1); + expect((tf.env().platform.fetch as jasmine.Spy).calls.count()).toBe(1); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(2); @@ -158,8 +154,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { // Load all weights. const weights = await tf.io.loadWeights( manifest, './', ['weight0', 'weight1', 'weight2']); - expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) - .toBe(1); + expect((tf.env().platform.fetch as jasmine.Spy).calls.count()).toBe(1); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(3); @@ -197,8 +192,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { }]; const weights = await tf.io.loadWeights(manifest, './', ['weight0']); - expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) - .toBe(3); + expect((tf.env().platform.fetch as jasmine.Spy).calls.count()).toBe(3); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(1); @@ -238,8 +232,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { const weights = await tf.io.loadWeights(manifest, './', ['weight0', 'weight1']); - expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) - .toBe(3); + expect((tf.env().platform.fetch as jasmine.Spy).calls.count()).toBe(3); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(2); @@ -281,8 +274,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { const weights = await tf.io.loadWeights(manifest, './', ['weight0', 'weight1']); // Only the first group should be fetched. - expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) - .toBe(1); + expect((tf.env().platform.fetch as jasmine.Spy).calls.count()).toBe(1); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(2); @@ -324,8 +316,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { const weights = await tf.io.loadWeights(manifest, './', ['weight0', 'weight2']); // Both groups need to be fetched. - expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) - .toBe(2); + expect((tf.env().platform.fetch as jasmine.Spy).calls.count()).toBe(2); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(2); @@ -367,8 +358,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { // Don't pass a third argument to loadWeights to load all weights. const weights = await tf.io.loadWeights(manifest, './'); // Both groups need to be fetched. - expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) - .toBe(2); + expect((tf.env().platform.fetch as jasmine.Spy).calls.count()).toBe(2); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(4); @@ -444,9 +434,8 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { const weightsNamesToFetch = ['weight0']; await tf.io.loadWeights( manifest, './', weightsNamesToFetch, {credentials: 'include'}); - expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) - .toBe(1); - expect(tf.environment().platform.fetch) + expect((tf.env().platform.fetch as jasmine.Spy).calls.count()).toBe(1); + expect(tf.env().platform.fetch) .toHaveBeenCalledWith( './weightfile0', {credentials: 'include'}, {isBinary: true}); }); @@ -477,8 +466,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { const weightsNamesToFetch = ['weight0', 'weight1']; const weights = await tf.io.loadWeights(manifest, './', weightsNamesToFetch); - expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) - .toBe(1); + expect((tf.env().platform.fetch as jasmine.Spy).calls.count()).toBe(1); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(weightsNamesToFetch.length); @@ -538,8 +526,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { const weights = await tf.io.loadWeights(manifest, './', ['weight0', 'weight2']); // Both groups need to be fetched. - expect((tf.environment().platform.fetch as jasmine.Spy).calls.count()) - .toBe(2); + expect((tf.env().platform.fetch as jasmine.Spy).calls.count()).toBe(2); const weightNames = Object.keys(weights); expect(weightNames.length).toEqual(2); diff --git a/tfjs-core/src/jasmine_util.ts b/tfjs-core/src/jasmine_util.ts index 9264fe8b6ec..ce3f9b8c143 100644 --- a/tfjs-core/src/jasmine_util.ts +++ b/tfjs-core/src/jasmine_util.ts @@ -16,7 +16,7 @@ */ import {KernelBackend} from './backends/backend'; import {ENGINE} from './engine'; -import {Environment, environment, Flags} from './environment'; +import {env, Environment, Flags} from './environment'; Error.stackTraceLimit = Infinity; @@ -26,13 +26,13 @@ export type Constraints = { }; export const NODE_ENVS: Constraints = { - predicate: () => environment().platformName === 'node' + predicate: () => env().platformName === 'node' }; export const CHROME_ENVS: Constraints = { flags: {'IS_CHROME': true} }; export const BROWSER_ENVS: Constraints = { - predicate: () => environment().platformName === 'browser' + predicate: () => env().platformName === 'browser' }; export const SYNC_BACKEND_ENVS: Constraints = { @@ -131,8 +131,8 @@ export function describeWithFlags( } TEST_ENVS.forEach(testEnv => { - environment().setFlags(testEnv.flags); - if (envSatisfiesConstraints(environment(), testEnv, constraints)) { + env().setFlags(testEnv.flags); + if (envSatisfiesConstraints(env(), testEnv, constraints)) { const testName = name + ' ' + testEnv.name + ' ' + JSON.stringify(testEnv.flags); executeTests(testName, tests, testEnv); @@ -174,9 +174,9 @@ function executeTests( beforeAll(async () => { ENGINE.reset(); if (testEnv.flags != null) { - environment().setFlags(testEnv.flags); + env().setFlags(testEnv.flags); } - environment().set('IS_TEST', true); + env().set('IS_TEST', true); // Await setting the new backend since it can have async init. await ENGINE.setBackend(testEnv.backendName); }); diff --git a/tfjs-core/src/log.ts b/tfjs-core/src/log.ts index 4621a77828e..609fbdbcc0d 100644 --- a/tfjs-core/src/log.ts +++ b/tfjs-core/src/log.ts @@ -15,16 +15,16 @@ * ============================================================================= */ -import {environment} from './environment'; +import {env} from './environment'; export function warn(...msg: Array<{}>): void { - if (!environment().getBool('IS_TEST')) { + if (!env().getBool('IS_TEST')) { console.warn(...msg); } } export function log(...msg: Array<{}>): void { - if (!environment().getBool('IS_TEST')) { + if (!env().getBool('IS_TEST')) { console.log(...msg); } } diff --git a/tfjs-core/src/ops/slice_test.ts b/tfjs-core/src/ops/slice_test.ts index cdb8e5248b0..344ead78073 100644 --- a/tfjs-core/src/ops/slice_test.ts +++ b/tfjs-core/src/ops/slice_test.ts @@ -524,7 +524,7 @@ describeWithFlags('slice ergonomics', ALL_ENVS, () => { describeWithFlags('shallow slicing', ALL_ENVS, () => { beforeAll(() => { - tf.environment().set('WEBGL_CPU_FORWARD', false); + tf.env().set('WEBGL_CPU_FORWARD', false); }); it('shallow slice an input that was cast', async () => { diff --git a/tfjs-core/src/ops/tensor_ops.ts b/tfjs-core/src/ops/tensor_ops.ts index b0c8f3504ba..f52b6f066f1 100644 --- a/tfjs-core/src/ops/tensor_ops.ts +++ b/tfjs-core/src/ops/tensor_ops.ts @@ -16,7 +16,7 @@ */ import {ENGINE} from '../engine'; -import {environment} from '../environment'; +import {env} from '../environment'; import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, Tensor6D} from '../tensor'; import {convertToTensor, inferShape} from '../tensor_util_env'; @@ -108,7 +108,7 @@ function makeTensor( shape = shape || inferredShape; values = dtype !== 'string' ? - toTypedArray(values, dtype, environment().getBool('DEBUG')) : + toTypedArray(values, dtype, env().getBool('DEBUG')) : flatten(values as string[], [], true) as string[]; return Tensor.make(shape, {values: values as TypedArray}, dtype); } diff --git a/tfjs-core/src/platforms/platform_browser.ts b/tfjs-core/src/platforms/platform_browser.ts index 7b70f1c252d..3f08e8f68b7 100644 --- a/tfjs-core/src/platforms/platform_browser.ts +++ b/tfjs-core/src/platforms/platform_browser.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {environment} from '../environment'; +import {env} from '../environment'; import {Platform} from './platform'; @@ -47,6 +47,6 @@ export class PlatformBrowser implements Platform { } } -if (environment().get('IS_BROWSER')) { - environment().setPlatform('browser', new PlatformBrowser()); +if (env().get('IS_BROWSER')) { + env().setPlatform('browser', new PlatformBrowser()); } diff --git a/tfjs-core/src/platforms/platform_node.ts b/tfjs-core/src/platforms/platform_node.ts index 0e3da418b32..8581af80aa6 100644 --- a/tfjs-core/src/platforms/platform_node.ts +++ b/tfjs-core/src/platforms/platform_node.ts @@ -14,7 +14,7 @@ * limitations under the License. * ============================================================================= */ -import {environment} from '../environment'; +import {env} from '../environment'; import {Platform} from './platform'; @@ -52,8 +52,8 @@ export class PlatformNode implements Platform { } fetch(path: string, requestInits?: RequestInit): Promise { - if (environment().global.fetch != null) { - return environment().global.fetch(path, requestInits); + if (env().global.fetch != null) { + return env().global.fetch(path, requestInits); } if (systemFetch == null) { @@ -82,6 +82,6 @@ export class PlatformNode implements Platform { } } -if (environment().get('IS_NODE')) { - environment().setPlatform('node', new PlatformNode()); +if (env().get('IS_NODE')) { + env().setPlatform('node', new PlatformNode()); } diff --git a/tfjs-core/src/platforms/platform_node_test.ts b/tfjs-core/src/platforms/platform_node_test.ts index f81b1e61134..89bcbaf2ed8 100644 --- a/tfjs-core/src/platforms/platform_node_test.ts +++ b/tfjs-core/src/platforms/platform_node_test.ts @@ -22,25 +22,25 @@ import {PlatformNode} from './platform_node'; describeWithFlags('PlatformNode', NODE_ENVS, () => { it('fetch should use global.fetch if defined', async () => { - const globalFetch = tf.environment().global.fetch; + const globalFetch = tf.env().global.fetch; - spyOn(tf.environment().global, 'fetch').and.returnValue(() => {}); + spyOn(tf.env().global, 'fetch').and.returnValue(() => {}); const platform = new PlatformNode(); await platform.fetch('test/url', {method: 'GET'}); - expect(tf.environment().global.fetch).toHaveBeenCalledWith('test/url', { + expect(tf.env().global.fetch).toHaveBeenCalledWith('test/url', { method: 'GET' }); - tf.environment().global.fetch = globalFetch; + tf.env().global.fetch = globalFetch; }); - it('fetch should use node-fetch with tf.environment().global.fetch is null', + it('fetch should use node-fetch with tf.env().global.fetch is null', async () => { - const globalFetch = tf.environment().global.fetch; - tf.environment().global.fetch = null; + const globalFetch = tf.env().global.fetch; + tf.env().global.fetch = null; const platform = new PlatformNode(); @@ -67,14 +67,13 @@ describeWithFlags('PlatformNode', NODE_ENVS, () => { }); platform_node.setSystemFetch(savedFetch); - tf.environment().global.fetch = globalFetch; + tf.env().global.fetch = globalFetch; }); it('now should use process.hrtime', async () => { const time = [100, 200]; spyOn(process, 'hrtime').and.returnValue(time); - expect(tf.environment().platform.now()) - .toEqual(time[0] * 1000 + time[1] / 1000000); + expect(tf.env().platform.now()).toEqual(time[0] * 1000 + time[1] / 1000000); }); it('encodeUTF8 single string', () => { diff --git a/tfjs-core/src/tensor_test.ts b/tfjs-core/src/tensor_test.ts index e034bf3a2c8..974bfdb486f 100644 --- a/tfjs-core/src/tensor_test.ts +++ b/tfjs-core/src/tensor_test.ts @@ -1532,7 +1532,7 @@ describeWithFlags('tensor', ALL_ENVS, () => { describeWithFlags('tensor debug mode', ALL_ENVS, () => { beforeAll(() => { - tf.environment().set('DEBUG', true); + tf.env().set('DEBUG', true); }); it('tf.tensor() from TypedArray + number[] fails due to wrong shape', () => { diff --git a/tfjs-core/src/tensor_util_env.ts b/tfjs-core/src/tensor_util_env.ts index 92f035ec5c4..6367e14c100 100644 --- a/tfjs-core/src/tensor_util_env.ts +++ b/tfjs-core/src/tensor_util_env.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {environment} from './environment'; +import {env} from './environment'; import {Tensor} from './tensor'; import {DataType, TensorLike} from './types'; @@ -38,7 +38,7 @@ export function inferShape(val: TensorLike, dtype?: DataType): number[] { firstElem = firstElem[0]; } if (Array.isArray(val) && - environment().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) { + env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) { deepAssertShapeConsistency(val, shape, []); } @@ -113,8 +113,7 @@ export function convertToTensor( } const skipTypedArray = true; const values = inferredDtype !== 'string' ? - toTypedArray( - x, inferredDtype as DataType, environment().getBool('DEBUG')) : + toTypedArray(x, inferredDtype as DataType, env().getBool('DEBUG')) : flatten(x as string[], [], skipTypedArray) as string[]; return Tensor.make(inferredShape, {values}, inferredDtype); } diff --git a/tfjs-core/src/tensor_util_test.ts b/tfjs-core/src/tensor_util_test.ts index 2a4c5c6f708..98d227cb8af 100644 --- a/tfjs-core/src/tensor_util_test.ts +++ b/tfjs-core/src/tensor_util_test.ts @@ -211,7 +211,7 @@ describeWithFlags('convertToTensor', ALL_ENVS, () => { describeWithFlags('convertToTensor debug mode', ALL_ENVS, () => { beforeAll(() => { - tf.environment().set('DEBUG', true); + tf.env().set('DEBUG', true); }); it('fails to convert a non-valid shape array to tensor', () => { diff --git a/tfjs-core/src/util.ts b/tfjs-core/src/util.ts index a3d6681286a..8edc2013427 100644 --- a/tfjs-core/src/util.ts +++ b/tfjs-core/src/util.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {environment} from './environment'; +import {env} from './environment'; import {DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray} from './types'; @@ -654,7 +654,7 @@ export function makeZerosTypedArray( */ /** @doc {heading: 'Util', namespace: 'util'} */ export function now(): number { - return environment().platform.now(); + return env().platform.now(); } export function assertNonNegativeIntegerDimensions(shape: number[]) { @@ -684,7 +684,7 @@ export function assertNonNegativeIntegerDimensions(shape: number[]) { /** @doc {heading: 'Util'} */ export function fetch( path: string, requestInits?: RequestInit): Promise { - return environment().platform.fetch(path, requestInits); + return env().platform.fetch(path, requestInits); } /** @@ -697,7 +697,7 @@ export function fetch( /** @doc {heading: 'Util'} */ export function encodeString(s: string, encoding = 'utf-8'): Uint8Array { encoding = encoding || 'utf-8'; - return environment().platform.encode(s, encoding); + return env().platform.encode(s, encoding); } /** @@ -709,5 +709,5 @@ export function encodeString(s: string, encoding = 'utf-8'): Uint8Array { /** @doc {heading: 'Util'} */ export function decodeString(bytes: Uint8Array, encoding = 'utf-8'): string { encoding = encoding || 'utf-8'; - return environment().platform.decode(bytes, encoding); + return env().platform.decode(bytes, encoding); } diff --git a/tfjs-core/src/util_test.ts b/tfjs-core/src/util_test.ts index 04d0479c11a..72744d5a589 100644 --- a/tfjs-core/src/util_test.ts +++ b/tfjs-core/src/util_test.ts @@ -516,11 +516,11 @@ describeWithFlags('util.toNestedArray', ALL_ENVS, () => { describe('util.fetch', () => { it('should call the platform fetch', () => { - spyOn(tf.environment().platform, 'fetch').and.callFake(() => {}); + spyOn(tf.env().platform, 'fetch').and.callFake(() => {}); util.fetch('test/path', {method: 'GET'}); - expect(tf.environment().platform.fetch).toHaveBeenCalledWith('test/path', { + expect(tf.env().platform.fetch).toHaveBeenCalledWith('test/path', { method: 'GET' }); }); diff --git a/tfjs-core/src/webgl.ts b/tfjs-core/src/webgl.ts index 3981bc87599..0b962007347 100644 --- a/tfjs-core/src/webgl.ts +++ b/tfjs-core/src/webgl.ts @@ -17,7 +17,7 @@ import * as gpgpu_util from './backends/webgl/gpgpu_util'; import * as webgl_util from './backends/webgl/webgl_util'; -import {environment} from './environment'; +import {env} from './environment'; export {MathBackendWebGL, WebGLMemoryInfo, WebGLTimingInfo} from './backends/webgl/backend_webgl'; export {setWebGLContext} from './backends/webgl/canvas_util'; @@ -31,5 +31,5 @@ export {gpgpu_util, webgl_util}; */ /** @doc {heading: 'Environment', namespace: 'webgl'} */ export function forceHalfFloat(): void { - environment().set('WEBGL_FORCE_F16_TEXTURES', true); + env().set('WEBGL_FORCE_F16_TEXTURES', true); }