Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions packages/typegpu/src/data/dataTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,9 @@ export class InfixDispatch {
readonly operator: (lhs: Snippet, rhs: Snippet) => Snippet,
) {}
}

export class MatrixColumnsAccess {
constructor(
readonly matrix: Snippet,
) {}
}
36 changes: 25 additions & 11 deletions packages/typegpu/src/tgsl/wgslGenerator.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
import * as tinyest from 'tinyest';
import { stitch } from '../core/resolve/stitch.ts';
import { arrayOf } from '../data/array.ts';
import {
type AnyData,
InfixDispatch,
isData,
isLooseData,
MatrixColumnsAccess,
UnknownData,
} from '../data/dataTypes.ts';
import { isSnippet, snip, type Snippet } from '../data/snippet.ts';
import { abstractInt, bool, u32 } from '../data/numeric.ts';
import { isSnippet, snip, type Snippet } from '../data/snippet.ts';
import * as wgsl from '../data/wgslTypes.ts';
import { ResolutionError, WgslTypeError } from '../errors.ts';
import { getName } from '../shared/meta.ts';
import { $internal } from '../shared/symbols.ts';
import { add, div, mul, sub } from '../std/operators.ts';
import { type FnArgsConversionHint, isMarkedInternal } from '../types.ts';
import {
convertStructValues,
convertToCommonType,
tryConvertSnippet,
} from './conversion.ts';
import {
coerceToSnippet,
concretize,
Expand All @@ -22,13 +30,6 @@ import {
getTypeForPropAccess,
numericLiteralToSnippet,
} from './generationHelpers.ts';
import {
convertStructValues,
convertToCommonType,
tryConvertSnippet,
} from './conversion.ts';
import { add, div, mul, sub } from '../std/operators.ts';
import { stitch } from '../core/resolve/stitch.ts';

const { NodeTypeCatalog: NODE } = tinyest;

Expand Down Expand Up @@ -309,7 +310,7 @@ export function generateExpression(
}

if (wgsl.isMat(target.dataType) && property === 'columns') {
return snip(target.value, target.dataType);
return snip(new MatrixColumnsAccess(target), UnknownData);
}

if (
Expand All @@ -329,11 +330,18 @@ export function generateExpression(
if (expression[0] === NODE.indexAccess) {
// Index Access
const [_, targetNode, propertyNode] = expression;
const target = generateExpression(ctx, targetNode);
const property = generateExpression(ctx, propertyNode);
const targetStr = ctx.resolve(target.value, target.dataType);
const propertyStr = ctx.resolve(property.value, property.dataType);

const target = generateExpression(ctx, targetNode);
if (target.value instanceof MatrixColumnsAccess) {
return snip(
stitch`${target.value.matrix}[${propertyStr}]`,
getTypeForIndexAccess(target.value.matrix.dataType as AnyData),
);
}
const targetStr = ctx.resolve(target.value, target.dataType);

if (target.dataType.type === 'unknown') {
// No idea what the type is, so we act on the snippet's value and try to guess

Expand All @@ -351,6 +359,12 @@ export function generateExpression(
);
}

if (wgsl.isMat(target.dataType)) {
throw new Error(
"The only way of accessing matrix elements in TGSL is through the 'columns' property.",
);
}

if (wgsl.isPtr(target.dataType)) {
return snip(
`(*${targetStr})[${propertyStr}]`,
Expand Down
53 changes: 52 additions & 1 deletion packages/typegpu/tests/tgsl/wgslGenerator.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import * as tinyest from 'tinyest';
import { beforeEach, describe, expect } from 'vitest';
import { snip } from '../../src/data/snippet.ts';
import * as d from '../../src/data/index.ts';
import { abstractFloat, abstractInt } from '../../src/data/numeric.ts';
import { snip } from '../../src/data/snippet.ts';
import { Void, type WgslArray } from '../../src/data/wgslTypes.ts';
import { provideCtx } from '../../src/execMode.ts';
import tgpu from '../../src/index.ts';
Expand Down Expand Up @@ -941,4 +941,55 @@ describe('wgslGenerator', () => {
}`),
);
});

it('throws error when accessing matrix elements directly', () => {
const testFn = tgpu.fn([])(() => {
const matrix = d.mat4x4f();
const element = matrix[4];
});

expect(() => parseResolved({ testFn }))
.toThrowErrorMatchingInlineSnapshot(`
[Error: Resolution of the following tree failed:
- <root>
- fn:testFn: The only way of accessing matrix elements in TGSL is through the 'columns' property.]
`);
});

it('generates correct code when accessing matrix elements through .columns', () => {
const testFn = tgpu.fn([])(() => {
const matrix = d.mat4x4f();
const column = matrix.columns[1];
const element = column[0];
const directElement = matrix.columns[1][0];
});

expect(tgpu.resolve({ externals: { testFn } })).toMatchInlineSnapshot(`
"fn testFn_0() {
var matrix = mat4x4f();
var column = matrix[1];
var element = column[0];
var directElement = matrix[1][0];
}"
`);
});

it('resolves when accessing matrix elements through .columns', () => {
const matrix = tgpu['~unstable'].workgroupVar(d.mat4x4f);
const index = tgpu['~unstable'].workgroupVar(d.u32);

const testFn = tgpu.fn([])(() => {
const element = matrix.$.columns[index.$];
});

expect(tgpu.resolve({ externals: { testFn } })).toMatchInlineSnapshot(`
"var<workgroup> index_1: u32;

var<workgroup> matrix_2: mat4x4f;

fn testFn_0() {
var element = matrix_2[index_1];
}"
`);
});
});