Skip to content
Merged
11 changes: 5 additions & 6 deletions packages/typegpu/src/core/buffer/bufferUsage.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import type { AnyData } from '../../data/dataTypes.ts';
import { schemaCallWrapper } from '../../data/utils.ts';
import type { AnyWgslData, BaseData } from '../../data/wgslTypes.ts';
import { isUsableAsStorage, type StorageFlag } from '../../extension.ts';
import { IllegalBufferAccessError } from '../../errors.ts';
import { getExecMode, inCodegenMode, isInsideTgpuFn } from '../../execMode.ts';
import { isUsableAsStorage, type StorageFlag } from '../../extension.ts';
import type { TgpuNamable } from '../../shared/meta.ts';
import { getName, setName } from '../../shared/meta.ts';
import type { Infer, InferGPU } from '../../shared/repr.ts';
Expand All @@ -12,6 +14,7 @@ import {
$repr,
$wgslDataType,
} from '../../shared/symbols.ts';
import { assertExhaustive } from '../../shared/utilityTypes.ts';
import type { LayoutMembership } from '../../tgpuBindGroupLayout.ts';
import type {
BindableBufferUsage,
Expand All @@ -20,9 +23,6 @@ import type {
} from '../../types.ts';
import { valueProxyHandler } from '../valueProxyUtils.ts';
import type { TgpuBuffer, UniformFlag } from './buffer.ts';
import { schemaCloneWrapper, schemaDefaultWrapper } from '../../data/utils.ts';
import { assertExhaustive } from '../../shared/utilityTypes.ts';
import { IllegalBufferAccessError } from '../../errors.ts';

// ----------
// Public API
Expand Down Expand Up @@ -163,8 +163,7 @@ class TgpuFixedBufferImpl<
if (!mode.buffers.has(this.buffer)) { // Not initialized yet
mode.buffers.set(
this.buffer,
schemaCloneWrapper(this.buffer.dataType, this.buffer.initial) ??
schemaDefaultWrapper(this.buffer.dataType),
schemaCallWrapper(this.buffer.dataType, this.buffer.initial),
);
}
return mode.buffers.get(this.buffer) as InferGPU<TData>;
Expand Down
6 changes: 3 additions & 3 deletions packages/typegpu/src/core/function/tgpuFn.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { type AnyData, UnknownData } from '../../data/dataTypes.ts';
import { schemaCloneWrapper } from '../../data/utils.ts';
import { snip } from '../../data/snippet.ts';
import { schemaCallWrapper } from '../../data/utils.ts';
import { Void } from '../../data/wgslTypes.ts';
import { ExecutionError } from '../../errors.ts';
import { provideInsideTgpuFn } from '../../execMode.ts';
Expand Down Expand Up @@ -32,6 +32,7 @@ import {
type TgpuAccessor,
type TgpuSlot,
} from '../slot/slotTypes.ts';
import { createDualImpl } from './dualImpl.ts';
import { createFnCore, type FnCore } from './fnCore.ts';
import type {
AnyFn,
Expand All @@ -41,7 +42,6 @@ import type {
InheritArgNames,
} from './fnTypes.ts';
import { stripTemplate } from './templateUtils.ts';
import { createDualImpl } from './dualImpl.ts';

// ----------
// Public API
Expand Down Expand Up @@ -223,7 +223,7 @@ function createFn<ImplSchema extends AnyFn>(
}

const castAndCopiedArgs = args.map((arg, index) =>
schemaCloneWrapper(shell.argTypes[index], arg)
schemaCallWrapper(shell.argTypes[index], arg)
) as InferArgs<Parameters<ImplSchema>>;

return implementation(...castAndCopiedArgs);
Expand Down
9 changes: 3 additions & 6 deletions packages/typegpu/src/data/array.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { $internal } from '../shared/symbols.ts';
import { sizeOf } from './sizeOf.ts';
import { schemaCloneWrapper, schemaDefaultWrapper } from './utils.ts';
import { schemaCallWrapper } from './utils.ts';
import type { AnyWgslData, WgslArray } from './wgslTypes.ts';

// ----------
Expand All @@ -27,16 +27,13 @@ export function arrayOf<TElement extends AnyWgslData>(
const arraySchema = (elements?: TElement[]) => {
if (elements && elements.length !== elementCount) {
throw new Error(
`Array schema of ${elementCount} elements of type ${elementType.type} called with ${elements.length} arguments.`,
`Array schema of ${elementCount} elements of type ${elementType.type} called with ${elements.length} argument(s).`,
);
}

return Array.from(
{ length: elementCount },
(_, i) =>
elements
? schemaCloneWrapper(elementType, elements[i])
: schemaDefaultWrapper(elementType),
(_, i) => schemaCallWrapper(elementType, elements?.[i]),
);
};
Object.setPrototypeOf(arraySchema, WgslArrayImpl);
Expand Down
3 changes: 3 additions & 0 deletions packages/typegpu/src/data/dataTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ export type TgpuDualFn<TImpl extends (...args: never[]) => unknown> =
*/
export interface Disarray<TElement extends wgsl.BaseData = wgsl.BaseData>
extends wgsl.BaseData {
<T extends TElement>(elements: Infer<T>[]): Infer<T>[];
(): Infer<TElement>[];
readonly type: 'disarray';
readonly elementCount: number;
readonly elementType: TElement;
Expand Down Expand Up @@ -72,6 +74,7 @@ export interface Unstruct<
TProps extends Record<string, wgsl.BaseData> = any,
> extends wgsl.BaseData, TgpuNamable {
(props: Prettify<InferRecord<TProps>>): Prettify<InferRecord<TProps>>;
(): Prettify<InferRecord<TProps>>;
readonly type: 'unstruct';
readonly propTypes: TProps;

Expand Down
75 changes: 36 additions & 39 deletions packages/typegpu/src/data/disarray.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
import type {
Infer,
InferPartial,
IsValidVertexSchema,
} from '../shared/repr.ts';
import { $internal } from '../shared/symbols.ts';
import type {
$invalidSchemaReason,
$repr,
$reprPartial,
$validVertexSchema,
} from '../shared/symbols.ts';
import type { AnyData, Disarray } from './dataTypes.ts';
import { schemaCallWrapper } from './utils.ts';

// ----------
// Public API
Expand All @@ -31,42 +21,49 @@ import type { AnyData, Disarray } from './dataTypes.ts';
* const disarray = d.disarrayOf(d.align(16, d.vec3f), 3);
*
* @param elementType The type of elements in the array.
* @param count The number of elements in the array.
* @param elementCount The number of elements in the array.
*/
export function disarrayOf<TElement extends AnyData>(
elementType: TElement,
count: number,
elementCount: number,
): Disarray<TElement> {
return new DisarrayImpl(elementType, count);
// In the schema call, create and return a deep copy
// by wrapping all the values in `elementType` schema calls.
const disarraySchema = (elements?: TElement[]) => {
if (elements && elements.length !== elementCount) {
throw new Error(
`Disarray schema of ${elementCount} elements of type ${elementType.type} called with ${elements.length} argument(s).`,
);
}

return Array.from(
{ length: elementCount },
(_, i) => schemaCallWrapper(elementType, elements?.[i]),
);
};
Object.setPrototypeOf(disarraySchema, DisarrayImpl);

disarraySchema.elementType = elementType;

if (!Number.isInteger(elementCount) || elementCount < 0) {
throw new Error(
`Cannot create array schema with invalid element count: ${elementCount}.`,
);
}
disarraySchema.elementCount = elementCount;

return disarraySchema as unknown as Disarray<TElement>;
}

// --------------
// Implementation
// --------------

class DisarrayImpl<TElement extends AnyData> implements Disarray<TElement> {
public readonly [$internal] = true;
public readonly type = 'disarray';

// Type-tokens, not available at runtime
declare readonly [$repr]: Infer<TElement>[];
declare readonly [$reprPartial]: {
idx: number;
value: InferPartial<TElement>;
}[];
declare readonly [$validVertexSchema]: IsValidVertexSchema<TElement>;
declare readonly [$invalidSchemaReason]:
Disarray[typeof $invalidSchemaReason];
// ---
const DisarrayImpl = {
[$internal]: true,
type: 'disarray',

constructor(
public readonly elementType: TElement,
public readonly elementCount: number,
) {
if (!Number.isInteger(elementCount) || elementCount < 0) {
throw new Error(
`Cannot create disarray schema with invalid element count: ${elementCount}.`,
);
}
}
}
toString(this: Disarray): string {
return `disarrayOf(${this.elementType}, ${this.elementCount})`;
},
};
6 changes: 2 additions & 4 deletions packages/typegpu/src/data/struct.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { getName, setName } from '../shared/meta.ts';
import { $internal } from '../shared/symbols.ts';
import { schemaCloneWrapper, schemaDefaultWrapper } from './utils.ts';
import { schemaCallWrapper } from './utils.ts';
import type { AnyWgslData, WgslStruct } from './wgslTypes.ts';

// ----------
Expand All @@ -27,9 +27,7 @@ export function struct<TProps extends Record<string, AnyWgslData>>(
Object.fromEntries(
Object.entries(props).map(([key, schema]) => [
key,
instanceProps
? schemaCloneWrapper(schema, instanceProps[key])
: schemaDefaultWrapper(schema),
schemaCallWrapper(schema, instanceProps?.[key]),
]),
);
Object.setPrototypeOf(structSchema, WgslStructImpl);
Expand Down
17 changes: 13 additions & 4 deletions packages/typegpu/src/data/unstruct.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { getName, setName } from '../shared/meta.ts';
import { $internal } from '../shared/symbols.ts';
import type { Unstruct } from './dataTypes.ts';
import { schemaCallWrapper } from './utils.ts';
import type { BaseData } from './wgslTypes.ts';

// ----------
Expand Down Expand Up @@ -28,11 +29,19 @@ import type { BaseData } from './wgslTypes.ts';
export function unstruct<TProps extends Record<string, BaseData>>(
properties: TProps,
): Unstruct<TProps> {
const unstruct = <T>(props: T) => props;
Object.setPrototypeOf(unstruct, UnstructImpl);
unstruct.propTypes = properties;
// In the schema call, create and return a deep copy
// by wrapping all the values in corresponding schema calls.
const unstructSchema = (instanceProps?: TProps) =>
Object.fromEntries(
Object.entries(properties).map(([key, schema]) => [
key,
schemaCallWrapper(schema, instanceProps?.[key]),
]),
);
Object.setPrototypeOf(unstructSchema, UnstructImpl);
unstructSchema.propTypes = properties;

return unstruct as unknown as Unstruct<TProps>;
return unstructSchema as unknown as Unstruct<TProps>;
}

// --------------
Expand Down
36 changes: 16 additions & 20 deletions packages/typegpu/src/data/utils.ts
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
import { formatToWGSLType } from './vertexFormatData.ts';

/**
* A wrapper for `schema(item)` call.
* A wrapper for `schema(item)` or `schema()` call.
* If the schema is a TgpuVertexFormatData, it instead calls the corresponding constructible schema.
* Throws an error if the schema is not callable.
*/
export function schemaCloneWrapper<T>(schema: unknown, item: T): T {
export function schemaCallWrapper<T>(schema: unknown, item?: T): T {
const maybeType = (schema as { type: string })?.type;

try {
return (schema as unknown as ((item: T) => T))(item);
} catch {
const maybeType = (schema as { type: string })?.type;
// TgpuVertexFormatData are not callable
const callSchema = (maybeType in formatToWGSLType
? formatToWGSLType[maybeType as keyof typeof formatToWGSLType]
: schema) as unknown as ((item?: T) => T);
if (item === undefined) {
return callSchema();
}
Comment on lines +17 to +19
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This if is necessary because vector constructors actually distinguish between vec2f() and vec2f(undefined) (the second throws an error, I'm not sure, do we consider this an issue?)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say the error is correct. Calling vec2f(undefined) is not valid imo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just that in most cases, f() is equivalent to f(undefined).
image
Though, now that I researched some more, I see that there are more exceptions to this, like destructuring a default parameter, so I guess we may just conclude vec2f(undefined) as invalid.

return callSchema(item);
} catch (e) {
throw new Error(
`Schema of type ${
maybeType ?? '<unknown>'
} is not callable or was called with invalid arguments.`,
);
}
}

/**
* A wrapper for `schema()` call.
* Throws an error if the schema is not callable.
*/
export function schemaDefaultWrapper<T>(schema: unknown): T {
try {
return (schema as unknown as (() => T))();
} catch {
const maybeType = (schema as { type: string })?.type;
throw new Error(
`Schema of type ${maybeType ?? '<unknown>'} is not callable.`,
);
}
}
2 changes: 1 addition & 1 deletion packages/typegpu/src/data/vector.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { createDualImpl } from '../core/function/dualImpl.ts';
import { $repr } from '../shared/symbols.ts';
import { snip } from './snippet.ts';
import { bool, f16, f32, i32, u32 } from './numeric.ts';
import { snip } from './snippet.ts';
import {
Vec2bImpl,
Vec2fImpl,
Expand Down
6 changes: 3 additions & 3 deletions packages/typegpu/tests/array.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ describe('array', () => {
expectTypeOf(obj).toEqualTypeOf<number[]>();
});

it('cannot be called with invalid properties', () => {
it('cannot be called with invalid elements', () => {
const ArraySchema = d.arrayOf(d.u32, 4);

// @ts-expect-error
Expand Down Expand Up @@ -131,10 +131,10 @@ describe('array', () => {
const ArraySchema = d.arrayOf(d.u32, 2);

expect(() => ArraySchema([1])).toThrowErrorMatchingInlineSnapshot(
'[Error: Array schema of 2 elements of type u32 called with 1 arguments.]',
'[Error: Array schema of 2 elements of type u32 called with 1 argument(s).]',
);
expect(() => ArraySchema([1, 2, 3])).toThrowErrorMatchingInlineSnapshot(
'[Error: Array schema of 2 elements of type u32 called with 3 arguments.]',
'[Error: Array schema of 2 elements of type u32 called with 3 argument(s).]',
);
});

Expand Down
Loading