-
Notifications
You must be signed in to change notification settings - Fork 2k
[tfjs-vis] Support custom tickLabels with duplicate values in heatmaps #2012
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
2c9229f
0bb0ef7
811b4f7
7044eb2
03def42
8145a76
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
Fixes #1201
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,73 +67,53 @@ export async function heatmap( | |
|
|
||
| let inputValues = data.values; | ||
| if (options.rowMajor) { | ||
| let originalShape: number[]; | ||
| let transposed: tf.Tensor2D; | ||
| if (inputValues instanceof tf.Tensor) { | ||
| originalShape = inputValues.shape; | ||
| transposed = inputValues.transpose(); | ||
| } else { | ||
| originalShape = [inputValues.length, inputValues[0].length]; | ||
| transposed = | ||
| tf.tidy(() => tf.tensor2d(inputValues as number[][]).transpose()); | ||
| } | ||
|
|
||
| assert( | ||
| transposed.rank === 2, | ||
| 'Input to renderHeatmap must be a 2d array or Tensor2d'); | ||
| inputValues = await convertToRowMajor(data.values); | ||
| } | ||
|
|
||
| // Download the intermediate tensor values and | ||
| // dispose the transposed tensor. | ||
| inputValues = await transposed.array(); | ||
| transposed.dispose(); | ||
| // Data validation | ||
| const {xTickLabels, yTickLabels} = data; | ||
| if (xTickLabels) { | ||
| const dimension = 0; | ||
| assertLabelsMatchShape(inputValues, xTickLabels, dimension); | ||
| } | ||
|
|
||
| const transposedShape = [inputValues.length, inputValues[0].length]; | ||
| assert( | ||
| originalShape[0] === transposedShape[1] && | ||
| originalShape[1] === transposedShape[0], | ||
| `Unexpected transposed shape. Original ${originalShape} : Transposed ${ | ||
| transposedShape}`); | ||
| // Note that we will only do a check on the first element of the second | ||
| // dimension. We do not protect users against passing in a ragged array. | ||
| if (yTickLabels) { | ||
|
||
| const dimension = 1; | ||
| assertLabelsMatchShape(inputValues, yTickLabels, dimension); | ||
| } | ||
|
|
||
| // | ||
| // Format data for vega spec; an array of objects, one for for each cell | ||
| // in the matrix. | ||
| const values: MatrixEntry[] = []; | ||
| const {xTickLabels, yTickLabels} = data; | ||
| // | ||
| // If custom labels are passed in for xTickLabels or yTickLabels we need | ||
| // to make sure they are 'unique' before mapping them to visual properties. | ||
| // We therefore append the index of the label to the datum that will be used | ||
| // for that label in the x or y axis. We could do this in all cases but choose | ||
| // not to to avoid unnecessary string operations. | ||
| // | ||
| // We use IDX_SEPARATOR to demarcate the added index | ||
| const IDX_SEPARATOR = '@tfidx@'; | ||
|
|
||
| // These two branches are very similar but we want to do the test once | ||
| // rather than on every element access | ||
| const values: MatrixEntry[] = []; | ||
| if (inputValues instanceof tf.Tensor) { | ||
| assert( | ||
| inputValues.rank === 2, | ||
| 'Input to renderHeatmap must be a 2d array or Tensor2d'); | ||
|
|
||
| const shape = inputValues.shape; | ||
| if (xTickLabels) { | ||
| assert( | ||
| shape[0] === xTickLabels.length, | ||
| `Length of xTickLabels (${ | ||
| xTickLabels.length}) must match number of rows | ||
| (${shape[0]})`); | ||
| } | ||
|
|
||
| if (yTickLabels) { | ||
| assert( | ||
| shape[1] === yTickLabels.length, | ||
| `Length of yTickLabels (${ | ||
| yTickLabels.length}) must match number of columns | ||
| (${shape[1]})`); | ||
| } | ||
|
|
||
| // This is a slightly specialized version of TensorBuffer.get, inlining it | ||
| // avoids the overhead of a function call per data element access and is | ||
| // specialized to only deal with the 2d case. | ||
| const inputArray = await inputValues.data(); | ||
| const [numRows, numCols] = shape; | ||
| const [numRows, numCols] = inputValues.shape; | ||
|
|
||
| for (let row = 0; row < numRows; row++) { | ||
| const x = xTickLabels ? xTickLabels[row] : row; | ||
| const x = xTickLabels ? `${xTickLabels[row]}${IDX_SEPARATOR}${row}` : row; | ||
| for (let col = 0; col < numCols; col++) { | ||
| const y = yTickLabels ? yTickLabels[col] : col; | ||
| const y = | ||
| yTickLabels ? `${yTickLabels[col]}${IDX_SEPARATOR}${col}` : col; | ||
|
|
||
| const index = (row * numCols) + col; | ||
| const value = inputArray[index]; | ||
|
|
@@ -142,24 +122,12 @@ export async function heatmap( | |
| } | ||
| } | ||
| } else { | ||
| if (xTickLabels) { | ||
| assert( | ||
| inputValues.length === xTickLabels.length, | ||
| `Number of rows (${inputValues.length}) must match | ||
| number of xTickLabels (${xTickLabels.length})`); | ||
| } | ||
|
|
||
| const inputArray = inputValues; | ||
| for (let row = 0; row < inputArray.length; row++) { | ||
| const x = xTickLabels ? xTickLabels[row] : row; | ||
| if (yTickLabels) { | ||
| assert( | ||
| inputValues[row].length === yTickLabels.length, | ||
| `Number of columns in row ${row} (${inputValues[row].length}) | ||
| must match length of yTickLabels (${yTickLabels.length})`); | ||
| } | ||
| const x = xTickLabels ? `${xTickLabels[row]}${IDX_SEPARATOR}${row}` : row; | ||
| for (let col = 0; col < inputArray[row].length; col++) { | ||
| const y = yTickLabels ? yTickLabels[col] : col; | ||
| const y = | ||
| yTickLabels ? `${yTickLabels[col]}${IDX_SEPARATOR}${col}` : col; | ||
| const value = inputArray[row][col]; | ||
| values.push({x, y, value}); | ||
| } | ||
|
|
@@ -194,29 +162,57 @@ export async function heatmap( | |
| 'scale': {'bandPaddingInner': 0, 'bandPaddingOuter': 0}, | ||
| }, | ||
| 'data': {'values': values}, | ||
| 'mark': 'rect', | ||
| 'mark': {'type': 'rect', 'tooltip': true}, | ||
| 'encoding': { | ||
| 'x': { | ||
| 'field': 'x', | ||
| 'type': options.xType, | ||
| // Maintain sort order of the axis if labels is passed in | ||
| 'scale': {'domain': xTickLabels}, | ||
| 'title': options.xLabel, | ||
| 'sort': 'x', | ||
| }, | ||
| 'y': { | ||
| 'field': 'y', | ||
| 'type': options.yType, | ||
| // Maintain sort order of the axis if labels is passed in | ||
| 'scale': {'domain': yTickLabels}, | ||
| 'title': options.yLabel, | ||
| 'sort': 'y', | ||
| }, | ||
| 'fill': { | ||
| 'field': 'value', | ||
| 'type': 'quantitative', | ||
| }, | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| // | ||
| // Format custom labels to remove the appended indices | ||
| // | ||
| const suffixPattern = `${IDX_SEPARATOR}\\d+$`; | ||
| const suffixRegex = new RegExp(suffixPattern); | ||
| if (xTickLabels) { | ||
| // @ts-ignore | ||
| spec.encoding.x.axis = { | ||
| 'labelExpr': `replace(datum.value, regexp(/${suffixPattern}/), '')`, | ||
| }; | ||
| } | ||
|
|
||
| if (yTickLabels) { | ||
| // @ts-ignore | ||
| spec.encoding.y.axis = { | ||
| 'labelExpr': `replace(datum.value, regexp(/${suffixPattern}/), '')`, | ||
| }; | ||
| } | ||
|
|
||
| // Customize tooltip formatting to remove the appended indices | ||
| if (xTickLabels || yTickLabels) { | ||
| //@ts-ignore | ||
| embedOpts.tooltip = { | ||
| sanitize: (value: string|number) => { | ||
| const valueString = String(value); | ||
| return valueString.replace(suffixRegex, ''); | ||
| } | ||
| }; | ||
| } | ||
|
|
||
| let colorRange: string[]|string; | ||
| switch (options.colorMap) { | ||
| case 'blues': | ||
|
|
@@ -252,6 +248,54 @@ export async function heatmap( | |
| await embed(drawArea, spec, embedOpts); | ||
| } | ||
|
|
||
| async function convertToRowMajor(inputValues: number[][]| | ||
| tf.Tensor2D): Promise<number[][]> { | ||
| let originalShape: number[]; | ||
| let transposed: tf.Tensor2D; | ||
| if (inputValues instanceof tf.Tensor) { | ||
| originalShape = inputValues.shape; | ||
| transposed = inputValues.transpose(); | ||
| } else { | ||
| originalShape = [inputValues.length, inputValues[0].length]; | ||
| transposed = tf.tidy(() => tf.tensor2d(inputValues).transpose()); | ||
| } | ||
|
|
||
| assert( | ||
| transposed.rank === 2, | ||
| 'Input to renderHeatmap must be a 2d array or Tensor2d'); | ||
|
|
||
| // Download the intermediate tensor values and | ||
| // dispose the transposed tensor. | ||
| const transposedValues = await transposed.array(); | ||
| transposed.dispose(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you have memory unit tests for this stuff?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep! |
||
|
|
||
| const transposedShape = [transposedValues.length, transposedValues[0].length]; | ||
| assert( | ||
| originalShape[0] === transposedShape[1] && | ||
| originalShape[1] === transposedShape[0], | ||
| `Unexpected transposed shape. Original ${originalShape} : Transposed ${ | ||
| transposedShape}`); | ||
| return transposedValues; | ||
| } | ||
|
|
||
| function assertLabelsMatchShape( | ||
| inputValues: number[][]|tf.Tensor2D, labels: string[], dimension: 0|1) { | ||
| const shape = inputValues instanceof tf.Tensor ? | ||
| inputValues.shape : | ||
| [inputValues.length, inputValues[0].length]; | ||
| if (dimension === 0) { | ||
| assert( | ||
| shape[0] === labels.length, | ||
| `Length of xTickLabels (${labels.length}) must match number of rows` + | ||
| ` (${shape[0]})`); | ||
| } else if (dimension === 1) { | ||
| assert( | ||
| shape[1] === labels.length, | ||
| `Length of yTickLabels (${ | ||
| labels.length}) must match number of columns (${shape[1]})`); | ||
| } | ||
| } | ||
|
|
||
| const defaultOpts: HeatmapOptions = { | ||
| xLabel: null, | ||
| yLabel: null, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
!= null
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done