Skip to content

Commit d80eaa4

Browse files
feat: Callable unstructs and disarrays (#1584)
1 parent dc4da0d commit d80eaa4

File tree

13 files changed

+258
-91
lines changed

13 files changed

+258
-91
lines changed

packages/typegpu/src/core/buffer/bufferUsage.ts

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import type { AnyData } from '../../data/dataTypes.ts';
2+
import { schemaCallWrapper } from '../../data/utils.ts';
23
import type { AnyWgslData, BaseData } from '../../data/wgslTypes.ts';
3-
import { isUsableAsStorage, type StorageFlag } from '../../extension.ts';
4+
import { IllegalBufferAccessError } from '../../errors.ts';
45
import { getExecMode, inCodegenMode, isInsideTgpuFn } from '../../execMode.ts';
6+
import { isUsableAsStorage, type StorageFlag } from '../../extension.ts';
57
import type { TgpuNamable } from '../../shared/meta.ts';
68
import { getName, setName } from '../../shared/meta.ts';
79
import type { Infer, InferGPU } from '../../shared/repr.ts';
@@ -12,6 +14,7 @@ import {
1214
$repr,
1315
$wgslDataType,
1416
} from '../../shared/symbols.ts';
17+
import { assertExhaustive } from '../../shared/utilityTypes.ts';
1518
import type { LayoutMembership } from '../../tgpuBindGroupLayout.ts';
1619
import type {
1720
BindableBufferUsage,
@@ -20,9 +23,6 @@ import type {
2023
} from '../../types.ts';
2124
import { valueProxyHandler } from '../valueProxyUtils.ts';
2225
import type { TgpuBuffer, UniformFlag } from './buffer.ts';
23-
import { schemaCloneWrapper, schemaDefaultWrapper } from '../../data/utils.ts';
24-
import { assertExhaustive } from '../../shared/utilityTypes.ts';
25-
import { IllegalBufferAccessError } from '../../errors.ts';
2626

2727
// ----------
2828
// Public API
@@ -166,8 +166,7 @@ class TgpuFixedBufferImpl<
166166
if (!mode.buffers.has(this.buffer)) { // Not initialized yet
167167
mode.buffers.set(
168168
this.buffer,
169-
schemaCloneWrapper(this.buffer.dataType, this.buffer.initial) ??
170-
schemaDefaultWrapper(this.buffer.dataType),
169+
schemaCallWrapper(this.buffer.dataType, this.buffer.initial),
171170
);
172171
}
173172
return mode.buffers.get(this.buffer) as InferGPU<TData>;

packages/typegpu/src/core/function/tgpuFn.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { type AnyData, UnknownData } from '../../data/dataTypes.ts';
2-
import { schemaCloneWrapper } from '../../data/utils.ts';
32
import { snip } from '../../data/snippet.ts';
3+
import { schemaCallWrapper } from '../../data/utils.ts';
44
import { Void } from '../../data/wgslTypes.ts';
55
import { ExecutionError } from '../../errors.ts';
66
import { provideInsideTgpuFn } from '../../execMode.ts';
@@ -32,6 +32,7 @@ import {
3232
type TgpuAccessor,
3333
type TgpuSlot,
3434
} from '../slot/slotTypes.ts';
35+
import { createDualImpl } from './dualImpl.ts';
3536
import { createFnCore, type FnCore } from './fnCore.ts';
3637
import type {
3738
AnyFn,
@@ -41,7 +42,6 @@ import type {
4142
InheritArgNames,
4243
} from './fnTypes.ts';
4344
import { stripTemplate } from './templateUtils.ts';
44-
import { createDualImpl } from './dualImpl.ts';
4545

4646
// ----------
4747
// Public API
@@ -223,7 +223,7 @@ function createFn<ImplSchema extends AnyFn>(
223223
}
224224

225225
const castAndCopiedArgs = args.map((arg, index) =>
226-
schemaCloneWrapper(shell.argTypes[index], arg)
226+
schemaCallWrapper(shell.argTypes[index] as unknown as AnyData, arg)
227227
) as InferArgs<Parameters<ImplSchema>>;
228228

229229
return implementation(...castAndCopiedArgs);

packages/typegpu/src/data/array.ts

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { $internal } from '../shared/symbols.ts';
22
import { sizeOf } from './sizeOf.ts';
3-
import { schemaCloneWrapper, schemaDefaultWrapper } from './utils.ts';
3+
import { schemaCallWrapper } from './utils.ts';
44
import type { AnyWgslData, WgslArray } from './wgslTypes.ts';
55

66
// ----------
@@ -27,16 +27,13 @@ export function arrayOf<TElement extends AnyWgslData>(
2727
const arraySchema = (elements?: TElement[]) => {
2828
if (elements && elements.length !== elementCount) {
2929
throw new Error(
30-
`Array schema of ${elementCount} elements of type ${elementType.type} called with ${elements.length} arguments.`,
30+
`Array schema of ${elementCount} elements of type ${elementType.type} called with ${elements.length} argument(s).`,
3131
);
3232
}
3333

3434
return Array.from(
3535
{ length: elementCount },
36-
(_, i) =>
37-
elements
38-
? schemaCloneWrapper(elementType, elements[i])
39-
: schemaDefaultWrapper(elementType),
36+
(_, i) => schemaCallWrapper(elementType, elements?.[i]),
4037
);
4138
};
4239
Object.setPrototypeOf(arraySchema, WgslArrayImpl);

packages/typegpu/src/data/dataTypes.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ export type TgpuDualFn<TImpl extends (...args: never[]) => unknown> =
4444
*/
4545
export interface Disarray<TElement extends wgsl.BaseData = wgsl.BaseData>
4646
extends wgsl.BaseData {
47+
<T extends TElement>(elements: Infer<T>[]): Infer<T>[];
48+
(): Infer<TElement>[];
4749
readonly type: 'disarray';
4850
readonly elementCount: number;
4951
readonly elementType: TElement;
@@ -72,6 +74,7 @@ export interface Unstruct<
7274
TProps extends Record<string, wgsl.BaseData> = any,
7375
> extends wgsl.BaseData, TgpuNamable {
7476
(props: Prettify<InferRecord<TProps>>): Prettify<InferRecord<TProps>>;
77+
(): Prettify<InferRecord<TProps>>;
7578
readonly type: 'unstruct';
7679
readonly propTypes: TProps;
7780

Lines changed: 36 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,6 @@
1-
import type {
2-
Infer,
3-
InferPartial,
4-
IsValidVertexSchema,
5-
} from '../shared/repr.ts';
61
import { $internal } from '../shared/symbols.ts';
7-
import type {
8-
$invalidSchemaReason,
9-
$repr,
10-
$reprPartial,
11-
$validVertexSchema,
12-
} from '../shared/symbols.ts';
132
import type { AnyData, Disarray } from './dataTypes.ts';
3+
import { schemaCallWrapper } from './utils.ts';
144

155
// ----------
166
// Public API
@@ -31,42 +21,49 @@ import type { AnyData, Disarray } from './dataTypes.ts';
3121
* const disarray = d.disarrayOf(d.align(16, d.vec3f), 3);
3222
*
3323
* @param elementType The type of elements in the array.
34-
* @param count The number of elements in the array.
24+
* @param elementCount The number of elements in the array.
3525
*/
3626
export function disarrayOf<TElement extends AnyData>(
3727
elementType: TElement,
38-
count: number,
28+
elementCount: number,
3929
): Disarray<TElement> {
40-
return new DisarrayImpl(elementType, count);
30+
// In the schema call, create and return a deep copy
31+
// by wrapping all the values in `elementType` schema calls.
32+
const disarraySchema = (elements?: TElement[]) => {
33+
if (elements && elements.length !== elementCount) {
34+
throw new Error(
35+
`Disarray schema of ${elementCount} elements of type ${elementType.type} called with ${elements.length} argument(s).`,
36+
);
37+
}
38+
39+
return Array.from(
40+
{ length: elementCount },
41+
(_, i) => schemaCallWrapper(elementType, elements?.[i]),
42+
);
43+
};
44+
Object.setPrototypeOf(disarraySchema, DisarrayImpl);
45+
46+
disarraySchema.elementType = elementType;
47+
48+
if (!Number.isInteger(elementCount) || elementCount < 0) {
49+
throw new Error(
50+
`Cannot create disarray schema with invalid element count: ${elementCount}.`,
51+
);
52+
}
53+
disarraySchema.elementCount = elementCount;
54+
55+
return disarraySchema as unknown as Disarray<TElement>;
4156
}
4257

4358
// --------------
4459
// Implementation
4560
// --------------
4661

47-
class DisarrayImpl<TElement extends AnyData> implements Disarray<TElement> {
48-
public readonly [$internal] = true;
49-
public readonly type = 'disarray';
50-
51-
// Type-tokens, not available at runtime
52-
declare readonly [$repr]: Infer<TElement>[];
53-
declare readonly [$reprPartial]: {
54-
idx: number;
55-
value: InferPartial<TElement>;
56-
}[];
57-
declare readonly [$validVertexSchema]: IsValidVertexSchema<TElement>;
58-
declare readonly [$invalidSchemaReason]:
59-
Disarray[typeof $invalidSchemaReason];
60-
// ---
62+
const DisarrayImpl = {
63+
[$internal]: true,
64+
type: 'disarray',
6165

62-
constructor(
63-
public readonly elementType: TElement,
64-
public readonly elementCount: number,
65-
) {
66-
if (!Number.isInteger(elementCount) || elementCount < 0) {
67-
throw new Error(
68-
`Cannot create disarray schema with invalid element count: ${elementCount}.`,
69-
);
70-
}
71-
}
72-
}
66+
toString(this: Disarray): string {
67+
return `disarrayOf(${this.elementType}, ${this.elementCount})`;
68+
},
69+
};

packages/typegpu/src/data/struct.ts

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { getName, setName } from '../shared/meta.ts';
22
import { $internal } from '../shared/symbols.ts';
3-
import { schemaCloneWrapper, schemaDefaultWrapper } from './utils.ts';
3+
import type { AnyData } from './dataTypes.ts';
4+
import { schemaCallWrapper } from './utils.ts';
45
import type { AnyWgslData, BaseData, WgslStruct } from './wgslTypes.ts';
56

67
// ----------
@@ -44,9 +45,7 @@ function INTERNAL_createStruct<TProps extends Record<string, BaseData>>(
4445
Object.fromEntries(
4546
Object.entries(props).map(([key, schema]) => [
4647
key,
47-
instanceProps
48-
? schemaCloneWrapper(schema, instanceProps[key])
49-
: schemaDefaultWrapper(schema),
48+
schemaCallWrapper(schema as AnyData, instanceProps?.[key]),
5049
]),
5150
);
5251

packages/typegpu/src/data/unstruct.ts

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { getName, setName } from '../shared/meta.ts';
22
import { $internal } from '../shared/symbols.ts';
3-
import type { Unstruct } from './dataTypes.ts';
3+
import type { AnyData, Unstruct } from './dataTypes.ts';
4+
import { schemaCallWrapper } from './utils.ts';
45
import type { BaseData } from './wgslTypes.ts';
56

67
// ----------
@@ -28,11 +29,19 @@ import type { BaseData } from './wgslTypes.ts';
2829
export function unstruct<TProps extends Record<string, BaseData>>(
2930
properties: TProps,
3031
): Unstruct<TProps> {
31-
const unstruct = <T>(props: T) => props;
32-
Object.setPrototypeOf(unstruct, UnstructImpl);
33-
unstruct.propTypes = properties;
32+
// In the schema call, create and return a deep copy
33+
// by wrapping all the values in corresponding schema calls.
34+
const unstructSchema = (instanceProps?: TProps) =>
35+
Object.fromEntries(
36+
Object.entries(properties).map(([key, schema]) => [
37+
key,
38+
schemaCallWrapper(schema as AnyData, instanceProps?.[key]),
39+
]),
40+
);
41+
Object.setPrototypeOf(unstructSchema, UnstructImpl);
42+
unstructSchema.propTypes = properties;
3443

35-
return unstruct as unknown as Unstruct<TProps>;
44+
return unstructSchema as unknown as Unstruct<TProps>;
3645
}
3746

3847
// --------------

packages/typegpu/src/data/utils.ts

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,28 @@
1+
import type { AnyData } from './index.ts';
2+
import { formatToWGSLType } from './vertexFormatData.ts';
3+
14
/**
2-
* A wrapper for `schema(item)` call.
5+
* A wrapper for `schema(item)` or `schema()` call.
6+
* If the schema is a TgpuVertexFormatData, it instead calls the corresponding constructible schema.
37
* Throws an error if the schema is not callable.
48
*/
5-
export function schemaCloneWrapper<T>(schema: unknown, item: T): T {
9+
export function schemaCallWrapper<T>(schema: AnyData, item?: T): T {
10+
const maybeType = (schema as { type: string })?.type;
11+
612
try {
7-
return (schema as unknown as ((item: T) => T))(item);
8-
} catch {
9-
const maybeType = (schema as { type: string })?.type;
13+
// TgpuVertexFormatData are not callable
14+
const callSchema = (maybeType in formatToWGSLType
15+
? formatToWGSLType[maybeType as keyof typeof formatToWGSLType]
16+
: schema) as unknown as ((item?: T) => T);
17+
if (item === undefined) {
18+
return callSchema();
19+
}
20+
return callSchema(item);
21+
} catch (e) {
1022
throw new Error(
1123
`Schema of type ${
1224
maybeType ?? '<unknown>'
1325
} is not callable or was called with invalid arguments.`,
1426
);
1527
}
1628
}
17-
18-
/**
19-
* A wrapper for `schema()` call.
20-
* Throws an error if the schema is not callable.
21-
*/
22-
export function schemaDefaultWrapper<T>(schema: unknown): T {
23-
try {
24-
return (schema as unknown as (() => T))();
25-
} catch {
26-
const maybeType = (schema as { type: string })?.type;
27-
throw new Error(
28-
`Schema of type ${maybeType ?? '<unknown>'} is not callable.`,
29-
);
30-
}
31-
}

packages/typegpu/src/data/vector.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { createDualImpl } from '../core/function/dualImpl.ts';
22
import { $repr } from '../shared/symbols.ts';
3-
import { snip } from './snippet.ts';
43
import { bool, f16, f32, i32, u32 } from './numeric.ts';
4+
import { snip } from './snippet.ts';
55
import {
66
Vec2bImpl,
77
Vec2fImpl,

packages/typegpu/tests/array.test.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ describe('array', () => {
9999
expectTypeOf(obj).toEqualTypeOf<number[]>();
100100
});
101101

102-
it('cannot be called with invalid properties', () => {
102+
it('cannot be called with invalid elements', () => {
103103
const ArraySchema = d.arrayOf(d.u32, 4);
104104

105105
// @ts-expect-error
@@ -131,10 +131,10 @@ describe('array', () => {
131131
const ArraySchema = d.arrayOf(d.u32, 2);
132132

133133
expect(() => ArraySchema([1])).toThrowErrorMatchingInlineSnapshot(
134-
'[Error: Array schema of 2 elements of type u32 called with 1 arguments.]',
134+
'[Error: Array schema of 2 elements of type u32 called with 1 argument(s).]',
135135
);
136136
expect(() => ArraySchema([1, 2, 3])).toThrowErrorMatchingInlineSnapshot(
137-
'[Error: Array schema of 2 elements of type u32 called with 3 arguments.]',
137+
'[Error: Array schema of 2 elements of type u32 called with 3 argument(s).]',
138138
);
139139
});
140140

0 commit comments

Comments
 (0)