Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Support custom tickLabels with duplicate values.
Fixes #1201
  • Loading branch information
tafsiri committed Sep 10, 2019
commit 2c9229fa2c4373db234f5728932ac3e89a830712
38 changes: 38 additions & 0 deletions tfjs-vis/demos/api/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,44 @@ <h3>render.heatmap(data, container, opts)</h3>
});
}
</script>

<p>
Repeated labels are supported
</p>
<div id='heatmap-cont5'></div>

<script class='show-script'>
{
const data = {
values: tf.tensor2d([
[10, 2, 3, 4, 5, 6],
[1, 20, 3, 4, 5, 6],
[1, 2, 30, 4, 5, 6],
[1, 2, 3, 40, 5, 6],
[1, 2, 3, 4, 50, 6],
[1, 2, 3, 4, 5, 60],
]),
xTickLabels: ['a', 'token', 'in', 'a', 'sequence', 'repeat'],
yTickLabels: ['a', 'token', 'in', 'a', 'sequence', 'repeat'],
}

// Render to visor
const surface = tfvis.visor().surface({ name: 'Heatmap', tab: 'Tensor Charts' });
tfvis.render.heatmap(surface, data, {
xLabel: 'Thing',
yLabel: 'Property',
});

// Render to page
const container = document.getElementById('heatmap-cont5');
tfvis.render.heatmap(container, data, {
width: 500,
height: 500,
xLabel: 'Thing',
yLabel: 'Property',
});
}
</script>
</section>

</article>
Expand Down
182 changes: 113 additions & 69 deletions tfjs-vis/src/render/heatmap.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,73 +67,53 @@ export async function heatmap(

let inputValues = data.values;
if (options.rowMajor) {
let originalShape: number[];
let transposed: tf.Tensor2D;
if (inputValues instanceof tf.Tensor) {
originalShape = inputValues.shape;
transposed = inputValues.transpose();
} else {
originalShape = [inputValues.length, inputValues[0].length];
transposed =
tf.tidy(() => tf.tensor2d(inputValues as number[][]).transpose());
}

assert(
transposed.rank === 2,
'Input to renderHeatmap must be a 2d array or Tensor2d');
inputValues = await convertToRowMajor(data.values);
}

// Download the intermediate tensor values and
// dispose the transposed tensor.
inputValues = await transposed.array();
transposed.dispose();
// Data validation
const {xTickLabels, yTickLabels} = data;
if (xTickLabels) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

!= null

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

const dimension = 0;
assertLabelsMatchShape(inputValues, xTickLabels, dimension);
}

const transposedShape = [inputValues.length, inputValues[0].length];
assert(
originalShape[0] === transposedShape[1] &&
originalShape[1] === transposedShape[0],
`Unexpected transposed shape. Original ${originalShape} : Transposed ${
transposedShape}`);
// Note that we will only do a check on the first element of the second
// dimension. We do not protect users against passing in a ragged array.
if (yTickLabels) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

!= null

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

const dimension = 1;
assertLabelsMatchShape(inputValues, yTickLabels, dimension);
}

//
// Format data for vega spec; an array of objects, one for for each cell
// in the matrix.
const values: MatrixEntry[] = [];
const {xTickLabels, yTickLabels} = data;
//
// If custom labels are passed in for xTickLabels or yTickLabels we need
// to make sure they are 'unique' before mapping them to visual properties.
// We therefore append the index of the label to the datum that will be used
// for that label in the x or y axis. We could do this in all cases but choose
// not to to avoid unnecessary string operations.
//
// We use IDX_SEPARATOR to demarcate the added index
const IDX_SEPARATOR = '@tfidx@';

// These two branches are very similar but we want to do the test once
// rather than on every element access
const values: MatrixEntry[] = [];
if (inputValues instanceof tf.Tensor) {
assert(
inputValues.rank === 2,
'Input to renderHeatmap must be a 2d array or Tensor2d');

const shape = inputValues.shape;
if (xTickLabels) {
assert(
shape[0] === xTickLabels.length,
`Length of xTickLabels (${
xTickLabels.length}) must match number of rows
(${shape[0]})`);
}

if (yTickLabels) {
assert(
shape[1] === yTickLabels.length,
`Length of yTickLabels (${
yTickLabels.length}) must match number of columns
(${shape[1]})`);
}

// This is a slightly specialized version of TensorBuffer.get, inlining it
// avoids the overhead of a function call per data element access and is
// specialized to only deal with the 2d case.
const inputArray = await inputValues.data();
const [numRows, numCols] = shape;
const [numRows, numCols] = inputValues.shape;

for (let row = 0; row < numRows; row++) {
const x = xTickLabels ? xTickLabels[row] : row;
const x = xTickLabels ? `${xTickLabels[row]}${IDX_SEPARATOR}${row}` : row;
for (let col = 0; col < numCols; col++) {
const y = yTickLabels ? yTickLabels[col] : col;
const y =
yTickLabels ? `${yTickLabels[col]}${IDX_SEPARATOR}${col}` : col;

const index = (row * numCols) + col;
const value = inputArray[index];
Expand All @@ -142,24 +122,12 @@ export async function heatmap(
}
}
} else {
if (xTickLabels) {
assert(
inputValues.length === xTickLabels.length,
`Number of rows (${inputValues.length}) must match
number of xTickLabels (${xTickLabels.length})`);
}

const inputArray = inputValues;
for (let row = 0; row < inputArray.length; row++) {
const x = xTickLabels ? xTickLabels[row] : row;
if (yTickLabels) {
assert(
inputValues[row].length === yTickLabels.length,
`Number of columns in row ${row} (${inputValues[row].length})
must match length of yTickLabels (${yTickLabels.length})`);
}
const x = xTickLabels ? `${xTickLabels[row]}${IDX_SEPARATOR}${row}` : row;
for (let col = 0; col < inputArray[row].length; col++) {
const y = yTickLabels ? yTickLabels[col] : col;
const y =
yTickLabels ? `${yTickLabels[col]}${IDX_SEPARATOR}${col}` : col;
const value = inputArray[row][col];
values.push({x, y, value});
}
Expand Down Expand Up @@ -194,29 +162,57 @@ export async function heatmap(
'scale': {'bandPaddingInner': 0, 'bandPaddingOuter': 0},
},
'data': {'values': values},
'mark': 'rect',
'mark': {'type': 'rect', 'tooltip': true},
'encoding': {
'x': {
'field': 'x',
'type': options.xType,
// Maintain sort order of the axis if labels is passed in
'scale': {'domain': xTickLabels},
'title': options.xLabel,
'sort': 'x',
},
'y': {
'field': 'y',
'type': options.yType,
// Maintain sort order of the axis if labels is passed in
'scale': {'domain': yTickLabels},
'title': options.yLabel,
'sort': 'y',
},
'fill': {
'field': 'value',
'type': 'quantitative',
},
}
}
};

//
// Format custom labels to remove the appended indices
//
const suffixPattern = `${IDX_SEPARATOR}\\d+$`;
const suffixRegex = new RegExp(suffixPattern);
if (xTickLabels) {
// @ts-ignore
spec.encoding.x.axis = {
'labelExpr': `replace(datum.value, regexp(/${suffixPattern}/), '')`,
};
}

if (yTickLabels) {
// @ts-ignore
spec.encoding.y.axis = {
'labelExpr': `replace(datum.value, regexp(/${suffixPattern}/), '')`,
};
}

// Customize tooltip formatting to remove the appended indices
if (xTickLabels || yTickLabels) {
//@ts-ignore
embedOpts.tooltip = {
sanitize: (value: string|number) => {
const valueString = String(value);
return valueString.replace(suffixRegex, '');
}
};
}

let colorRange: string[]|string;
switch (options.colorMap) {
case 'blues':
Expand Down Expand Up @@ -252,6 +248,54 @@ export async function heatmap(
await embed(drawArea, spec, embedOpts);
}

async function convertToRowMajor(inputValues: number[][]|
tf.Tensor2D): Promise<number[][]> {
let originalShape: number[];
let transposed: tf.Tensor2D;
if (inputValues instanceof tf.Tensor) {
originalShape = inputValues.shape;
transposed = inputValues.transpose();
} else {
originalShape = [inputValues.length, inputValues[0].length];
transposed = tf.tidy(() => tf.tensor2d(inputValues).transpose());
}

assert(
transposed.rank === 2,
'Input to renderHeatmap must be a 2d array or Tensor2d');

// Download the intermediate tensor values and
// dispose the transposed tensor.
const transposedValues = await transposed.array();
transposed.dispose();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you have memory unit tests for this stuff?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep!


const transposedShape = [transposedValues.length, transposedValues[0].length];
assert(
originalShape[0] === transposedShape[1] &&
originalShape[1] === transposedShape[0],
`Unexpected transposed shape. Original ${originalShape} : Transposed ${
transposedShape}`);
return transposedValues;
}

function assertLabelsMatchShape(
inputValues: number[][]|tf.Tensor2D, labels: string[], dimension: 0|1) {
const shape = inputValues instanceof tf.Tensor ?
inputValues.shape :
[inputValues.length, inputValues[0].length];
if (dimension === 0) {
assert(
shape[0] === labels.length,
`Length of xTickLabels (${labels.length}) must match number of rows` +
` (${shape[0]})`);
} else if (dimension === 1) {
assert(
shape[1] === labels.length,
`Length of yTickLabels (${
labels.length}) must match number of columns (${shape[1]})`);
}
}

const defaultOpts: HeatmapOptions = {
xLabel: null,
yLabel: null,
Expand Down
34 changes: 34 additions & 0 deletions tfjs-vis/src/render/heatmap_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -186,4 +186,38 @@ describe('renderHeatmap', () => {
expect(document.querySelectorAll('canvas').length).toBe(1);
expect(document.querySelector('canvas').height).toBe(200 * pixelRatio);
});

it('throws on wrong number of xTickLabels', async () => {
const data: HeatmapData = {
values: [[4, 2, 8], [1, 7, 2], [3, 3, 20], [8, 2, 8]],
xTickLabels: ['alpha'],
yTickLabels: ['first', 'second', 'third', 'fourth'],
};

const container = document.getElementById('container');
let threw = false;
try {
await heatmap(container, data, {height: 200});
} catch (e) {
threw = true;
}
expect(threw).toBe(true);
});

it('throws on wrong number of yTickLabels', async () => {
const data: HeatmapData = {
values: [[4, 2, 8], [1, 7, 2], [3, 3, 20], [8, 2, 8]],
xTickLabels: ['alpha', 'beta', 'gamma'],
yTickLabels: ['first'],
};

const container = document.getElementById('container');
let threw = false;
try {
await heatmap(container, data, {height: 200});
} catch (e) {
threw = true;
}
expect(threw).toBe(true);
});
});