diff --git a/packages/typegpu/src/core/buffer/bufferUsage.ts b/packages/typegpu/src/core/buffer/bufferUsage.ts index b139c1103a..dc372a9070 100644 --- a/packages/typegpu/src/core/buffer/bufferUsage.ts +++ b/packages/typegpu/src/core/buffer/bufferUsage.ts @@ -1,7 +1,9 @@ import type { AnyData } from '../../data/dataTypes.ts'; +import { schemaCallWrapper } from '../../data/utils.ts'; import type { AnyWgslData, BaseData } from '../../data/wgslTypes.ts'; -import { isUsableAsStorage, type StorageFlag } from '../../extension.ts'; +import { IllegalBufferAccessError } from '../../errors.ts'; import { getExecMode, inCodegenMode, isInsideTgpuFn } from '../../execMode.ts'; +import { isUsableAsStorage, type StorageFlag } from '../../extension.ts'; import type { TgpuNamable } from '../../shared/meta.ts'; import { getName, setName } from '../../shared/meta.ts'; import type { Infer, InferGPU } from '../../shared/repr.ts'; @@ -12,6 +14,7 @@ import { $repr, $wgslDataType, } from '../../shared/symbols.ts'; +import { assertExhaustive } from '../../shared/utilityTypes.ts'; import type { LayoutMembership } from '../../tgpuBindGroupLayout.ts'; import type { BindableBufferUsage, @@ -20,9 +23,6 @@ import type { } from '../../types.ts'; import { valueProxyHandler } from '../valueProxyUtils.ts'; import type { TgpuBuffer, UniformFlag } from './buffer.ts'; -import { schemaCloneWrapper, schemaDefaultWrapper } from '../../data/utils.ts'; -import { assertExhaustive } from '../../shared/utilityTypes.ts'; -import { IllegalBufferAccessError } from '../../errors.ts'; // ---------- // Public API @@ -166,8 +166,7 @@ class TgpuFixedBufferImpl< if (!mode.buffers.has(this.buffer)) { // Not initialized yet mode.buffers.set( this.buffer, - schemaCloneWrapper(this.buffer.dataType, this.buffer.initial) ?? - schemaDefaultWrapper(this.buffer.dataType), + schemaCallWrapper(this.buffer.dataType, this.buffer.initial), ); } return mode.buffers.get(this.buffer) as InferGPU; diff --git a/packages/typegpu/src/core/function/tgpuFn.ts b/packages/typegpu/src/core/function/tgpuFn.ts index b1bb979bd5..ed87cec167 100644 --- a/packages/typegpu/src/core/function/tgpuFn.ts +++ b/packages/typegpu/src/core/function/tgpuFn.ts @@ -1,6 +1,6 @@ import { type AnyData, UnknownData } from '../../data/dataTypes.ts'; -import { schemaCloneWrapper } from '../../data/utils.ts'; import { snip } from '../../data/snippet.ts'; +import { schemaCallWrapper } from '../../data/utils.ts'; import { Void } from '../../data/wgslTypes.ts'; import { ExecutionError } from '../../errors.ts'; import { provideInsideTgpuFn } from '../../execMode.ts'; @@ -32,6 +32,7 @@ import { type TgpuAccessor, type TgpuSlot, } from '../slot/slotTypes.ts'; +import { createDualImpl } from './dualImpl.ts'; import { createFnCore, type FnCore } from './fnCore.ts'; import type { AnyFn, @@ -41,7 +42,6 @@ import type { InheritArgNames, } from './fnTypes.ts'; import { stripTemplate } from './templateUtils.ts'; -import { createDualImpl } from './dualImpl.ts'; // ---------- // Public API @@ -223,7 +223,7 @@ function createFn( } const castAndCopiedArgs = args.map((arg, index) => - schemaCloneWrapper(shell.argTypes[index], arg) + schemaCallWrapper(shell.argTypes[index] as unknown as AnyData, arg) ) as InferArgs>; return implementation(...castAndCopiedArgs); diff --git a/packages/typegpu/src/data/array.ts b/packages/typegpu/src/data/array.ts index d1c14eebad..4e05235f44 100644 --- a/packages/typegpu/src/data/array.ts +++ b/packages/typegpu/src/data/array.ts @@ -1,6 +1,6 @@ import { $internal } from '../shared/symbols.ts'; import { sizeOf } from './sizeOf.ts'; -import { schemaCloneWrapper, schemaDefaultWrapper } from './utils.ts'; +import { schemaCallWrapper } from './utils.ts'; import type { AnyWgslData, WgslArray } from './wgslTypes.ts'; // ---------- @@ -27,16 +27,13 @@ export function arrayOf( const arraySchema = (elements?: TElement[]) => { if (elements && elements.length !== elementCount) { throw new Error( - `Array schema of ${elementCount} elements of type ${elementType.type} called with ${elements.length} arguments.`, + `Array schema of ${elementCount} elements of type ${elementType.type} called with ${elements.length} argument(s).`, ); } return Array.from( { length: elementCount }, - (_, i) => - elements - ? schemaCloneWrapper(elementType, elements[i]) - : schemaDefaultWrapper(elementType), + (_, i) => schemaCallWrapper(elementType, elements?.[i]), ); }; Object.setPrototypeOf(arraySchema, WgslArrayImpl); diff --git a/packages/typegpu/src/data/dataTypes.ts b/packages/typegpu/src/data/dataTypes.ts index 76949514bd..6d0f7cf05c 100644 --- a/packages/typegpu/src/data/dataTypes.ts +++ b/packages/typegpu/src/data/dataTypes.ts @@ -44,6 +44,8 @@ export type TgpuDualFn unknown> = */ export interface Disarray extends wgsl.BaseData { + (elements: Infer[]): Infer[]; + (): Infer[]; readonly type: 'disarray'; readonly elementCount: number; readonly elementType: TElement; @@ -72,6 +74,7 @@ export interface Unstruct< TProps extends Record = any, > extends wgsl.BaseData, TgpuNamable { (props: Prettify>): Prettify>; + (): Prettify>; readonly type: 'unstruct'; readonly propTypes: TProps; diff --git a/packages/typegpu/src/data/disarray.ts b/packages/typegpu/src/data/disarray.ts index a01bd72401..facbc69339 100644 --- a/packages/typegpu/src/data/disarray.ts +++ b/packages/typegpu/src/data/disarray.ts @@ -1,16 +1,6 @@ -import type { - Infer, - InferPartial, - IsValidVertexSchema, -} from '../shared/repr.ts'; import { $internal } from '../shared/symbols.ts'; -import type { - $invalidSchemaReason, - $repr, - $reprPartial, - $validVertexSchema, -} from '../shared/symbols.ts'; import type { AnyData, Disarray } from './dataTypes.ts'; +import { schemaCallWrapper } from './utils.ts'; // ---------- // Public API @@ -31,42 +21,49 @@ import type { AnyData, Disarray } from './dataTypes.ts'; * const disarray = d.disarrayOf(d.align(16, d.vec3f), 3); * * @param elementType The type of elements in the array. - * @param count The number of elements in the array. + * @param elementCount The number of elements in the array. */ export function disarrayOf( elementType: TElement, - count: number, + elementCount: number, ): Disarray { - return new DisarrayImpl(elementType, count); + // In the schema call, create and return a deep copy + // by wrapping all the values in `elementType` schema calls. + const disarraySchema = (elements?: TElement[]) => { + if (elements && elements.length !== elementCount) { + throw new Error( + `Disarray schema of ${elementCount} elements of type ${elementType.type} called with ${elements.length} argument(s).`, + ); + } + + return Array.from( + { length: elementCount }, + (_, i) => schemaCallWrapper(elementType, elements?.[i]), + ); + }; + Object.setPrototypeOf(disarraySchema, DisarrayImpl); + + disarraySchema.elementType = elementType; + + if (!Number.isInteger(elementCount) || elementCount < 0) { + throw new Error( + `Cannot create disarray schema with invalid element count: ${elementCount}.`, + ); + } + disarraySchema.elementCount = elementCount; + + return disarraySchema as unknown as Disarray; } // -------------- // Implementation // -------------- -class DisarrayImpl implements Disarray { - public readonly [$internal] = true; - public readonly type = 'disarray'; - - // Type-tokens, not available at runtime - declare readonly [$repr]: Infer[]; - declare readonly [$reprPartial]: { - idx: number; - value: InferPartial; - }[]; - declare readonly [$validVertexSchema]: IsValidVertexSchema; - declare readonly [$invalidSchemaReason]: - Disarray[typeof $invalidSchemaReason]; - // --- +const DisarrayImpl = { + [$internal]: true, + type: 'disarray', - constructor( - public readonly elementType: TElement, - public readonly elementCount: number, - ) { - if (!Number.isInteger(elementCount) || elementCount < 0) { - throw new Error( - `Cannot create disarray schema with invalid element count: ${elementCount}.`, - ); - } - } -} + toString(this: Disarray): string { + return `disarrayOf(${this.elementType}, ${this.elementCount})`; + }, +}; diff --git a/packages/typegpu/src/data/struct.ts b/packages/typegpu/src/data/struct.ts index 7c37172755..2bb2465ab2 100644 --- a/packages/typegpu/src/data/struct.ts +++ b/packages/typegpu/src/data/struct.ts @@ -1,6 +1,7 @@ import { getName, setName } from '../shared/meta.ts'; import { $internal } from '../shared/symbols.ts'; -import { schemaCloneWrapper, schemaDefaultWrapper } from './utils.ts'; +import type { AnyData } from './dataTypes.ts'; +import { schemaCallWrapper } from './utils.ts'; import type { AnyWgslData, BaseData, WgslStruct } from './wgslTypes.ts'; // ---------- @@ -44,9 +45,7 @@ function INTERNAL_createStruct>( Object.fromEntries( Object.entries(props).map(([key, schema]) => [ key, - instanceProps - ? schemaCloneWrapper(schema, instanceProps[key]) - : schemaDefaultWrapper(schema), + schemaCallWrapper(schema as AnyData, instanceProps?.[key]), ]), ); diff --git a/packages/typegpu/src/data/unstruct.ts b/packages/typegpu/src/data/unstruct.ts index e892b8b29f..2cede97c64 100644 --- a/packages/typegpu/src/data/unstruct.ts +++ b/packages/typegpu/src/data/unstruct.ts @@ -1,6 +1,7 @@ import { getName, setName } from '../shared/meta.ts'; import { $internal } from '../shared/symbols.ts'; -import type { Unstruct } from './dataTypes.ts'; +import type { AnyData, Unstruct } from './dataTypes.ts'; +import { schemaCallWrapper } from './utils.ts'; import type { BaseData } from './wgslTypes.ts'; // ---------- @@ -28,11 +29,19 @@ import type { BaseData } from './wgslTypes.ts'; export function unstruct>( properties: TProps, ): Unstruct { - const unstruct = (props: T) => props; - Object.setPrototypeOf(unstruct, UnstructImpl); - unstruct.propTypes = properties; + // In the schema call, create and return a deep copy + // by wrapping all the values in corresponding schema calls. + const unstructSchema = (instanceProps?: TProps) => + Object.fromEntries( + Object.entries(properties).map(([key, schema]) => [ + key, + schemaCallWrapper(schema as AnyData, instanceProps?.[key]), + ]), + ); + Object.setPrototypeOf(unstructSchema, UnstructImpl); + unstructSchema.propTypes = properties; - return unstruct as unknown as Unstruct; + return unstructSchema as unknown as Unstruct; } // -------------- diff --git a/packages/typegpu/src/data/utils.ts b/packages/typegpu/src/data/utils.ts index 4a50780aa8..a3e1d0a86a 100644 --- a/packages/typegpu/src/data/utils.ts +++ b/packages/typegpu/src/data/utils.ts @@ -1,12 +1,24 @@ +import type { AnyData } from './index.ts'; +import { formatToWGSLType } from './vertexFormatData.ts'; + /** - * A wrapper for `schema(item)` call. + * A wrapper for `schema(item)` or `schema()` call. + * If the schema is a TgpuVertexFormatData, it instead calls the corresponding constructible schema. * Throws an error if the schema is not callable. */ -export function schemaCloneWrapper(schema: unknown, item: T): T { +export function schemaCallWrapper(schema: AnyData, item?: T): T { + const maybeType = (schema as { type: string })?.type; + try { - return (schema as unknown as ((item: T) => T))(item); - } catch { - const maybeType = (schema as { type: string })?.type; + // TgpuVertexFormatData are not callable + const callSchema = (maybeType in formatToWGSLType + ? formatToWGSLType[maybeType as keyof typeof formatToWGSLType] + : schema) as unknown as ((item?: T) => T); + if (item === undefined) { + return callSchema(); + } + return callSchema(item); + } catch (e) { throw new Error( `Schema of type ${ maybeType ?? '' @@ -14,18 +26,3 @@ export function schemaCloneWrapper(schema: unknown, item: T): T { ); } } - -/** - * A wrapper for `schema()` call. - * Throws an error if the schema is not callable. - */ -export function schemaDefaultWrapper(schema: unknown): T { - try { - return (schema as unknown as (() => T))(); - } catch { - const maybeType = (schema as { type: string })?.type; - throw new Error( - `Schema of type ${maybeType ?? ''} is not callable.`, - ); - } -} diff --git a/packages/typegpu/src/data/vector.ts b/packages/typegpu/src/data/vector.ts index 3b83e37399..32acab24e8 100644 --- a/packages/typegpu/src/data/vector.ts +++ b/packages/typegpu/src/data/vector.ts @@ -1,7 +1,7 @@ import { createDualImpl } from '../core/function/dualImpl.ts'; import { $repr } from '../shared/symbols.ts'; -import { snip } from './snippet.ts'; import { bool, f16, f32, i32, u32 } from './numeric.ts'; +import { snip } from './snippet.ts'; import { Vec2bImpl, Vec2fImpl, diff --git a/packages/typegpu/tests/array.test.ts b/packages/typegpu/tests/array.test.ts index 7822f9eaae..203d5299ff 100644 --- a/packages/typegpu/tests/array.test.ts +++ b/packages/typegpu/tests/array.test.ts @@ -99,7 +99,7 @@ describe('array', () => { expectTypeOf(obj).toEqualTypeOf(); }); - it('cannot be called with invalid properties', () => { + it('cannot be called with invalid elements', () => { const ArraySchema = d.arrayOf(d.u32, 4); // @ts-expect-error @@ -131,10 +131,10 @@ describe('array', () => { const ArraySchema = d.arrayOf(d.u32, 2); expect(() => ArraySchema([1])).toThrowErrorMatchingInlineSnapshot( - '[Error: Array schema of 2 elements of type u32 called with 1 arguments.]', + '[Error: Array schema of 2 elements of type u32 called with 1 argument(s).]', ); expect(() => ArraySchema([1, 2, 3])).toThrowErrorMatchingInlineSnapshot( - '[Error: Array schema of 2 elements of type u32 called with 3 arguments.]', + '[Error: Array schema of 2 elements of type u32 called with 3 argument(s).]', ); }); diff --git a/packages/typegpu/tests/data/utils.test.ts b/packages/typegpu/tests/data/utils.test.ts new file mode 100644 index 0000000000..0157fbb3ac --- /dev/null +++ b/packages/typegpu/tests/data/utils.test.ts @@ -0,0 +1,34 @@ +import { describe, expect, it } from 'vitest'; +import * as d from '../../src/data/index.ts'; +import { schemaCallWrapper } from '../../src/data/utils.ts'; + +describe('schemaCallWrapper', () => { + it('throws when the schema is not callable', () => { + expect(() => schemaCallWrapper(d.Void)) + .toThrowErrorMatchingInlineSnapshot( + '[Error: Schema of type void is not callable or was called with invalid arguments.]', + ); + }); + + it('calls schema without arguments', () => { + const TestStruct = d.struct({ v: d.vec2f }); + + expect(schemaCallWrapper(TestStruct)).toStrictEqual({ v: d.vec2f() }); + }); + + it('calls schema with arguments', () => { + const TestStruct = d.struct({ v: d.vec2f }); + const testInstance = { v: d.vec2f(1, 2), u: d.vec3u() }; + + expect(schemaCallWrapper(TestStruct, testInstance)) + .toStrictEqual({ v: d.vec2f(1, 2) }); + }); + + it('works with loose data', () => { + const TestUnstruct = d.unstruct({ v: d.float32x3 }); + const testInstance = { v: d.vec3f(1, 2, 3), u: d.vec3u() }; + + expect(schemaCallWrapper(TestUnstruct, testInstance)) + .toStrictEqual({ v: d.vec3f(1, 2, 3) }); + }); +}); diff --git a/packages/typegpu/tests/disarray.test.ts b/packages/typegpu/tests/disarray.test.ts index 51f59cdf94..500fe44174 100644 --- a/packages/typegpu/tests/disarray.test.ts +++ b/packages/typegpu/tests/disarray.test.ts @@ -1,5 +1,5 @@ import { BufferReader, BufferWriter } from 'typed-binary'; -import { describe, expect } from 'vitest'; +import { describe, expect, expectTypeOf } from 'vitest'; import { readData, writeData } from '../src/data/dataIO.ts'; import * as d from '../src/data/index.ts'; import { it } from './utils/extendedIt.ts'; @@ -115,4 +115,70 @@ describe('disarray', () => { writeData(new BufferWriter(buffer), TestArray, value); expect(readData(new BufferReader(buffer), TestArray)).toStrictEqual(value); }); + + it('can be called to create a disarray', () => { + const DisarraySchema = d.disarrayOf(d.uint16x2, 2); + + const obj = DisarraySchema([d.vec2u(1, 2), d.vec2u(3, 4)]); + + expect(obj).toStrictEqual([d.vec2u(1, 2), d.vec2u(3, 4)]); + expectTypeOf(obj).toEqualTypeOf(); + }); + + it('cannot be called with invalid elements', () => { + const DisarraySchema = d.disarrayOf(d.unorm16x2, 2); + + // @ts-expect-error + (() => DisarraySchema([d.vec2f(), d.vec3f()])); + // @ts-expect-error + (() => DisarraySchema([d.vec3f(), d.vec3f()])); + }); + + it('can be called to create a deep copy of other disarray', () => { + const InnerSchema = d.disarrayOf(d.uint16x2, 2); + const OuterSchema = d.disarrayOf(InnerSchema, 1); + const instance = OuterSchema([InnerSchema([d.vec2u(1, 2), d.vec2u()])]); + + const clone = OuterSchema(instance); + + expect(clone).toStrictEqual(instance); + expect(clone).not.toBe(instance); + expect(clone[0]).not.toBe(instance[0]); + expect(clone[0]).not.toBe(clone[1]); + expect(clone[0]?.[0]).not.toBe(instance[0]?.[0]); + expect(clone[0]?.[0]).toStrictEqual(d.vec2u(1, 2)); + }); + + it('throws when invalid number of arguments', () => { + const DisarraySchema = d.disarrayOf(d.float32x2, 2); + + expect(() => DisarraySchema([d.vec2f()])) + .toThrowErrorMatchingInlineSnapshot( + '[Error: Disarray schema of 2 elements of type float32x2 called with 1 argument(s).]', + ); + expect(() => DisarraySchema([d.vec2f(), d.vec2f(), d.vec2f()])) + .toThrowErrorMatchingInlineSnapshot( + '[Error: Disarray schema of 2 elements of type float32x2 called with 3 argument(s).]', + ); + }); + + it('can be called to create a default value', () => { + const DisarraySchema = d.disarrayOf(d.float32x3, 2); + + const defaultDisarray = DisarraySchema(); + + expect(defaultDisarray).toStrictEqual([d.vec3f(), d.vec3f()]); + }); + + it('can be called to create a default value with nested unstruct', () => { + const UnstructSchema = d.unstruct({ vec: d.float32x3 }); + const DisarraySchema = d.disarrayOf(UnstructSchema, 2); + + const defaultDisarray = DisarraySchema(); + + expect(defaultDisarray).toStrictEqual([ + { vec: d.vec3f() }, + { vec: d.vec3f() }, + ]); + }); }); diff --git a/packages/typegpu/tests/unstruct.test.ts b/packages/typegpu/tests/unstruct.test.ts index 3ff3e543c3..b5ac309838 100644 --- a/packages/typegpu/tests/unstruct.test.ts +++ b/packages/typegpu/tests/unstruct.test.ts @@ -1,5 +1,5 @@ import { BufferReader, BufferWriter } from 'typed-binary'; -import { describe, expect, it } from 'vitest'; +import { describe, expect, expectTypeOf, it } from 'vitest'; import { readData, writeData } from '../src/data/dataIO.ts'; import * as d from '../src/data/index.ts'; @@ -133,7 +133,7 @@ describe('d.unstruct', () => { expect(data.c.z).toBeCloseTo(3.0); }); - it('properly writes and reads data with nested structs', () => { + it('properly writes and reads data with nested unstructs', () => { const s = d.unstruct({ a: d.unorm8x2, b: d.align(16, d.snorm16x2), @@ -181,7 +181,7 @@ describe('d.unstruct', () => { const a = d.disarrayOf(s, 8); expect(d.sizeOf(s)).toBe(12); - // since the struct is aligned to 16 bytes, the array stride should be 16 not 12 + // since the unstruct is aligned to 16 bytes, the array stride should be 16 not 12 expect(d.sizeOf(a)).toBe(16 * 8); const buffer = new ArrayBuffer(d.sizeOf(a)); @@ -234,4 +234,70 @@ describe('d.unstruct', () => { expect(data.c.x).toBeCloseTo(-0.25); expect(data.c.y).toBeCloseTo(0.25); }); + + it('can be called to create an object', () => { + const TestUnstruct = d.unstruct({ + x: d.u32, + y: d.uint32x3, + }); + + const obj = TestUnstruct({ x: 1, y: d.vec3u(1, 2, 3) }); + + expect(obj).toStrictEqual({ x: 1, y: d.vec3u(1, 2, 3) }); + expectTypeOf(obj).toEqualTypeOf<{ x: number; y: d.v3u }>(); + }); + + it('cannot be called with invalid properties', () => { + const TestUnstruct = d.unstruct({ + x: d.u32, + y: d.uint32x3, + }); + + // @ts-expect-error + (() => TestUnstruct({ x: 1, z: 2 })); + }); + + it('can be called to create a deep copy of other unstruct', () => { + const NestedUnstruct = d.unstruct({ prop1: d.float32x2, prop2: d.u32 }); + const TestUnstruct = d.unstruct({ nested: NestedUnstruct }); + const instance = TestUnstruct({ + nested: { prop1: d.vec2f(1, 2), prop2: 21 }, + }); + + const clone = TestUnstruct(instance); + + expect(clone).toStrictEqual(instance); + expect(clone).not.toBe(instance); + expect(clone.nested).not.toBe(instance.nested); + expect(clone.nested.prop1).not.toBe(instance.nested.prop1); + }); + + it('can be called to strip extra properties of a unstruct', () => { + const TestUnstruct = d.unstruct({ prop1: d.vec2f, prop2: d.u32 }); + const instance = { prop1: d.vec2f(1, 2), prop2: 21, prop3: 'extra' }; + + const clone = TestUnstruct(instance); + + expect(clone).toStrictEqual({ prop1: d.vec2f(1, 2), prop2: 21 }); + }); + + it('can be called to create a default value', () => { + const TestUnstruct = d.unstruct({ + nested: d.unstruct({ prop1: d.vec2f, prop2: d.u32 }), + }); + + const defaultStruct = TestUnstruct(); + + expect(defaultStruct).toStrictEqual({ + nested: { prop1: d.vec2f(), prop2: d.u32() }, + }); + }); + + it('can be called to create a default value with nested disarray', () => { + const TestUnstruct = d.unstruct({ arr: d.disarrayOf(d.uint16, 1) }); + + const defaultStruct = TestUnstruct(); + + expect(defaultStruct).toStrictEqual({ arr: [0] }); + }); });