diff --git a/tfjs-vis/demos/api/index.html b/tfjs-vis/demos/api/index.html index 24bc0483922..45ce84c037e 100644 --- a/tfjs-vis/demos/api/index.html +++ b/tfjs-vis/demos/api/index.html @@ -657,6 +657,44 @@
+ Repeated labels are supported +
+ + + diff --git a/tfjs-vis/src/render/confusion_matrix.ts b/tfjs-vis/src/render/confusion_matrix.ts index 1f64ebb4be6..0b01a58ac88 100644 --- a/tfjs-vis/src/render/confusion_matrix.ts +++ b/tfjs-vis/src/render/confusion_matrix.ts @@ -95,7 +95,7 @@ export async function confusionMatrix( values.push({ label, prediction, - diagCount: count, + count, noFill: true, }); } else { @@ -103,6 +103,7 @@ export async function confusionMatrix( label, prediction, count, + scaleCount: count, }); // When not shading the diagonal we want to check if there is a non // zero value. If all values are zero we will not color them as the @@ -122,7 +123,7 @@ export async function confusionMatrix( for (const val of values) { if (val.noFill === true) { val.noFill = false; - val.count = val.diagCount; + val.scaleCount = val.count; } } } @@ -171,45 +172,56 @@ export async function confusionMatrix( 'layer': [ { // The matrix + 'transform': [ + {'filter': 'datum.noFill != true'}, + ], 'mark': { 'type': 'rect', }, 'encoding': { - 'fill': { - 'condition': { - 'test': 'datum["noFill"] == true', - 'value': 'white', - }, - 'field': 'count', + 'color': { + 'field': 'scaleCount', 'type': 'quantitative', 'scale': {'range': ['#f7fbff', '#4292c6']}, }, - 'tooltip': { - 'condition': { - 'test': 'datum["noFill"] == true', - 'field': 'diagCount', - 'type': 'nominal', - }, - 'field': 'count', - 'type': 'nominal', - } + 'tooltip': [ + {'field': 'label', 'type': 'nominal'}, + {'field': 'prediction', 'type': 'nominal'}, + {'field': 'count', 'type': 'quantitative'}, + ] }, - }, ] }; + if (options.shadeDiagonal === false) { + spec.layer.push( + { + // render unfilled rects for the diagonal + 'transform': [ + {'filter': 'datum.noFill == true'}, + ], + 'mark': { + 'type': 'rect', + 'fill': 'white', + }, + 'encoding': { + 'tooltip': [ + {'field': 'label', 'type': 'nominal'}, + {'field': 'prediction', 'type': 'nominal'}, + {'field': 'count', 'type': 'quantitative'}, + ] + }, + }, + ); + } + if (options.showTextOverlay) { spec.layer.push({ // The text labels 'mark': {'type': 'text', 'baseline': 'middle'}, 'encoding': { 'text': { - 'condition': { - 'test': 'datum["noFill"] == true', - 'field': 'diagCount', - 'type': 'nominal', - }, 'field': 'count', 'type': 'nominal', }, @@ -218,7 +230,6 @@ export async function confusionMatrix( } await embed(drawArea, spec, embedOpts); - return Promise.resolve(); } const defaultOpts: ConfusionMatrixOptions = { @@ -235,7 +246,9 @@ const defaultOpts: ConfusionMatrixOptions = { interface MatrixEntry { label: string; prediction: string; - count?: number; - diagCount?: number; + // The displayed count + count: number; + // The count values used to compute the color scale + scaleCount?: number; noFill?: boolean; } diff --git a/tfjs-vis/src/render/heatmap.ts b/tfjs-vis/src/render/heatmap.ts index 893730d1813..f06777503c1 100644 --- a/tfjs-vis/src/render/heatmap.ts +++ b/tfjs-vis/src/render/heatmap.ts @@ -67,73 +67,53 @@ export async function heatmap( let inputValues = data.values; if (options.rowMajor) { - let originalShape: number[]; - let transposed: tf.Tensor2D; - if (inputValues instanceof tf.Tensor) { - originalShape = inputValues.shape; - transposed = inputValues.transpose(); - } else { - originalShape = [inputValues.length, inputValues[0].length]; - transposed = - tf.tidy(() => tf.tensor2d(inputValues as number[][]).transpose()); - } - - assert( - transposed.rank === 2, - 'Input to renderHeatmap must be a 2d array or Tensor2d'); + inputValues = await convertToRowMajor(data.values); + } - // Download the intermediate tensor values and - // dispose the transposed tensor. - inputValues = await transposed.array(); - transposed.dispose(); + // Data validation + const {xTickLabels, yTickLabels} = data; + if (xTickLabels != null) { + const dimension = 0; + assertLabelsMatchShape(inputValues, xTickLabels, dimension); + } - const transposedShape = [inputValues.length, inputValues[0].length]; - assert( - originalShape[0] === transposedShape[1] && - originalShape[1] === transposedShape[0], - `Unexpected transposed shape. Original ${originalShape} : Transposed ${ - transposedShape}`); + // Note that we will only do a check on the first element of the second + // dimension. We do not protect users against passing in a ragged array. + if (yTickLabels != null) { + const dimension = 1; + assertLabelsMatchShape(inputValues, yTickLabels, dimension); } + // // Format data for vega spec; an array of objects, one for for each cell // in the matrix. - const values: MatrixEntry[] = []; - const {xTickLabels, yTickLabels} = data; + // + // If custom labels are passed in for xTickLabels or yTickLabels we need + // to make sure they are 'unique' before mapping them to visual properties. + // We therefore append the index of the label to the datum that will be used + // for that label in the x or y axis. We could do this in all cases but choose + // not to to avoid unnecessary string operations. + // + // We use IDX_SEPARATOR to demarcate the added index + const IDX_SEPARATOR = '@tfidx@'; - // These two branches are very similar but we want to do the test once - // rather than on every element access + const values: MatrixEntry[] = []; if (inputValues instanceof tf.Tensor) { assert( inputValues.rank === 2, 'Input to renderHeatmap must be a 2d array or Tensor2d'); - const shape = inputValues.shape; - if (xTickLabels) { - assert( - shape[0] === xTickLabels.length, - `Length of xTickLabels (${ - xTickLabels.length}) must match number of rows - (${shape[0]})`); - } - - if (yTickLabels) { - assert( - shape[1] === yTickLabels.length, - `Length of yTickLabels (${ - yTickLabels.length}) must match number of columns - (${shape[1]})`); - } - // This is a slightly specialized version of TensorBuffer.get, inlining it // avoids the overhead of a function call per data element access and is // specialized to only deal with the 2d case. const inputArray = await inputValues.data(); - const [numRows, numCols] = shape; + const [numRows, numCols] = inputValues.shape; for (let row = 0; row < numRows; row++) { - const x = xTickLabels ? xTickLabels[row] : row; + const x = xTickLabels ? `${xTickLabels[row]}${IDX_SEPARATOR}${row}` : row; for (let col = 0; col < numCols; col++) { - const y = yTickLabels ? yTickLabels[col] : col; + const y = + yTickLabels ? `${yTickLabels[col]}${IDX_SEPARATOR}${col}` : col; const index = (row * numCols) + col; const value = inputArray[index]; @@ -142,24 +122,12 @@ export async function heatmap( } } } else { - if (xTickLabels) { - assert( - inputValues.length === xTickLabels.length, - `Number of rows (${inputValues.length}) must match - number of xTickLabels (${xTickLabels.length})`); - } - const inputArray = inputValues; for (let row = 0; row < inputArray.length; row++) { - const x = xTickLabels ? xTickLabels[row] : row; - if (yTickLabels) { - assert( - inputValues[row].length === yTickLabels.length, - `Number of columns in row ${row} (${inputValues[row].length}) - must match length of yTickLabels (${yTickLabels.length})`); - } + const x = xTickLabels ? `${xTickLabels[row]}${IDX_SEPARATOR}${row}` : row; for (let col = 0; col < inputArray[row].length; col++) { - const y = yTickLabels ? yTickLabels[col] : col; + const y = + yTickLabels ? `${yTickLabels[col]}${IDX_SEPARATOR}${col}` : col; const value = inputArray[row][col]; values.push({x, y, value}); } @@ -194,29 +162,57 @@ export async function heatmap( 'scale': {'bandPaddingInner': 0, 'bandPaddingOuter': 0}, }, 'data': {'values': values}, - 'mark': 'rect', + 'mark': {'type': 'rect', 'tooltip': true}, 'encoding': { 'x': { 'field': 'x', 'type': options.xType, - // Maintain sort order of the axis if labels is passed in - 'scale': {'domain': xTickLabels}, 'title': options.xLabel, + 'sort': 'x', }, 'y': { 'field': 'y', 'type': options.yType, - // Maintain sort order of the axis if labels is passed in - 'scale': {'domain': yTickLabels}, 'title': options.yLabel, + 'sort': 'y', }, 'fill': { 'field': 'value', 'type': 'quantitative', - }, + } } }; + // + // Format custom labels to remove the appended indices + // + const suffixPattern = `${IDX_SEPARATOR}\\d+$`; + const suffixRegex = new RegExp(suffixPattern); + if (xTickLabels) { + // @ts-ignore + spec.encoding.x.axis = { + 'labelExpr': `replace(datum.value, regexp(/${suffixPattern}/), '')`, + }; + } + + if (yTickLabels) { + // @ts-ignore + spec.encoding.y.axis = { + 'labelExpr': `replace(datum.value, regexp(/${suffixPattern}/), '')`, + }; + } + + // Customize tooltip formatting to remove the appended indices + if (xTickLabels || yTickLabels) { + //@ts-ignore + embedOpts.tooltip = { + sanitize: (value: string|number) => { + const valueString = String(value); + return valueString.replace(suffixRegex, ''); + } + }; + } + let colorRange: string[]|string; switch (options.colorMap) { case 'blues': @@ -252,6 +248,54 @@ export async function heatmap( await embed(drawArea, spec, embedOpts); } +async function convertToRowMajor(inputValues: number[][]| + tf.Tensor2D): Promise