Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[feat] register UNETLoader, UpscaleModelLoader, StylemModelLoader, GL…
…IGENLoader

Also added tests for modelToNodeStore
  • Loading branch information
arjansingh committed Sep 3, 2025
commit 99d486f423783c20dab3d7c802eaea52c85794d3
4 changes: 4 additions & 0 deletions src/stores/modelToNodeStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ export const useModelToNodeStore = defineStore('modelToNode', () => {
quickRegister('loras', 'LoraLoaderModelOnly', 'lora_name')
quickRegister('vae', 'VAELoader', 'vae_name')
quickRegister('controlnet', 'ControlNetLoader', 'control_net_name')
quickRegister('unet', 'UNETLoader', 'unet_name')
quickRegister('upscale_models', 'UpscaleModelLoader', 'model_name')
quickRegister('style_models', 'StyleModelLoader', 'style_model')
quickRegister('gligen', 'GLIGENLoader', 'gligen_name')
}

return {
Expand Down
245 changes: 245 additions & 0 deletions tests-ui/tests/store/modelToNodeStore.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
import { createPinia, setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it } from 'vitest'

import {
ModelNodeProvider,
useModelToNodeStore
} from '@/stores/modelToNodeStore'
import { type ComfyNodeDefImpl, useNodeDefStore } from '@/stores/nodeDefStore'

describe('useModelToNodeStore', () => {
let store: ReturnType<typeof useModelToNodeStore>
let nodeDefStore: ReturnType<typeof useNodeDefStore>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not completely necessary, but you might want to check out the comments I left in here.
I generally prefer to minimize how much mutable state there is ever, but especially in test suites.

The initializing of the stores feels like a pretty important part of the Arrange section for any given test, so it'd be nice to have it directly in the it() block.


// Create minimal mock for testing - only includes 'name' field since that's
// the only property ModelNodeProvider constructor uses and tests verify
const createMockNodeDef = (name: string): ComfyNodeDefImpl => {
return { name } as ComfyNodeDefImpl
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need a cast?

}

beforeEach(() => {
setActivePinia(createPinia())
store = useModelToNodeStore()
nodeDefStore = useNodeDefStore()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, why is one store and the other nodeDefStore? Feels like playing favorites 😛


const mockNodeNames = [
'CheckpointLoaderSimple',
'ImageOnlyCheckpointLoader',
'LoraLoader',
'LoraLoaderModelOnly',
'VAELoader',
'ControlNetLoader',
'UNETLoader',
'UpscaleModelLoader',
'StyleModelLoader',
'GLIGENLoader'
]

const mockNodeDefs: Record<string, ComfyNodeDefImpl> = Object.fromEntries(
mockNodeNames.map((name) => [name, createMockNodeDef(name)])
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is going to be the same every time, it could be worth extracting as a CONSTANT.
Or if it needs to be regenerated because it is mutable/mutated, an extracted method.


nodeDefStore.nodeDefsByName = mockNodeDefs
})

describe('modelToNodeMap', () => {
it('should initialize as empty', () => {
expect(Object.keys(store.modelToNodeMap)).toHaveLength(0)
})

it('should populate after registration', () => {
store.registerDefaults()
expect(Object.keys(store.modelToNodeMap)).toContain('checkpoints')
expect(Object.keys(store.modelToNodeMap)).toContain('unet')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://vitest.dev/api/expect.html#expect-arraycontaining
Might be clearer than having separate expectations, unless these are each separately important to check.

})
})

describe('getNodeProvider', () => {
it('should return provider for registered model type', () => {
store.registerDefaults()

const provider = store.getNodeProvider('checkpoints')
expect(provider).toBeDefined()
// Optional chaining used because getNodeProvider() can return undefined for unregistered types
expect(provider?.nodeDef.name).toBe('CheckpointLoaderSimple')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If one piece is possibly undefined, the chain should usually go to the terminus.
Otherwise you'll get that node isn't present on undefined.

expect(provider?.key).toBe('ckpt_name')
})

it('should return undefined for unregistered model type', () => {
store.registerDefaults()
expect(store.getNodeProvider('nonexistent')).toBeUndefined()
})

it('should return first registered provider when multiple providers exist for same model type', () => {
store.registerDefaults()

const provider = store.getNodeProvider('checkpoints')
// Using optional chaining for safety since getNodeProvider() can return undefined
expect(provider?.nodeDef.name).toBe('CheckpointLoaderSimple')
})

it('should trigger lazy registration when called before registerDefaults', () => {
const provider = store.getNodeProvider('checkpoints')
expect(provider).toBeDefined()
})
})

describe('getAllNodeProviders', () => {
it('should return all providers for model type with multiple nodes', () => {
store.registerDefaults()

const checkpointProviders = store.getAllNodeProviders('checkpoints')
expect(checkpointProviders).toHaveLength(2)
expect(checkpointProviders.map((p) => p.nodeDef.name)).toContain(
'CheckpointLoaderSimple'
)
expect(checkpointProviders.map((p) => p.nodeDef.name)).toContain(
'ImageOnlyCheckpointLoader'
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const loraProviders = store.getAllNodeProviders('loras')
expect(loraProviders).toHaveLength(2)
})

it('should return single provider for model type with one node', () => {
store.registerDefaults()

const unetProviders = store.getAllNodeProviders('unet')
expect(unetProviders).toHaveLength(1)
expect(unetProviders[0].nodeDef.name).toBe('UNETLoader')
})

it('should return empty array for unregistered model type', () => {
store.registerDefaults()
expect(store.getAllNodeProviders('nonexistent')).toEqual([])
})

it('should trigger lazy registration when called before registerDefaults', () => {
const providers = store.getAllNodeProviders('checkpoints')
expect(providers.length).toBeGreaterThan(0)
})
})

describe('registerNodeProvider', () => {
it('should register provider directly', () => {
const customProvider = new ModelNodeProvider(
nodeDefStore.nodeDefsByName['UNETLoader'],
'custom_key'
)

store.registerNodeProvider('custom_type', customProvider)

const retrieved = store.getNodeProvider('custom_type')
expect(retrieved).toStrictEqual(customProvider)
// Optional chaining for consistency with getNodeProvider() return type
expect(retrieved?.key).toBe('custom_key')
})

it('should handle multiple providers for same model type and return first as primary', () => {
const provider1 = new ModelNodeProvider(
nodeDefStore.nodeDefsByName['UNETLoader'],
'key1'
)
const provider2 = new ModelNodeProvider(
nodeDefStore.nodeDefsByName['VAELoader'],
'key2'
)

store.registerNodeProvider('multi_type', provider1)
store.registerNodeProvider('multi_type', provider2)

const allProviders = store.getAllNodeProviders('multi_type')
expect(allProviders).toHaveLength(2)
expect(store.getNodeProvider('multi_type')).toStrictEqual(provider1)
})

it('should initialize new model type when first provider is registered', () => {
expect(store.modelToNodeMap['new_type']).toBeUndefined()

const provider = new ModelNodeProvider(
nodeDefStore.nodeDefsByName['UNETLoader'],
'test_key'
)
store.registerNodeProvider('new_type', provider)

expect(store.modelToNodeMap['new_type']).toBeDefined()
expect(store.modelToNodeMap['new_type']).toHaveLength(1)
})
})

describe('quickRegister', () => {
it('should connect node class to model type with parameter mapping', () => {
store.quickRegister('test_type', 'UNETLoader', 'test_param')

const provider = store.getNodeProvider('test_type')
expect(provider).toBeDefined()
// Using optional chaining since getNodeProvider() can return undefined
expect(provider?.nodeDef.name).toBe('UNETLoader')
expect(provider?.key).toBe('test_param')
})

it('should handle registration of non-existent node classes gracefully', () => {
expect(() => {
store.quickRegister('test_type', 'NonExistentLoader', 'test_param')
}).not.toThrow()

const provider = store.getNodeProvider('test_type')
// Optional chaining needed since getNodeProvider() can return undefined
expect(provider?.nodeDef).toBeUndefined()
})

it('should allow multiple node classes for same model type', () => {
store.quickRegister('multi_type', 'UNETLoader', 'param1')
store.quickRegister('multi_type', 'VAELoader', 'param2')

const providers = store.getAllNodeProviders('multi_type')
expect(providers).toHaveLength(2)
})
})

describe('registerDefaults integration', () => {
it('should register all expected model types based on mock data', () => {
store.registerDefaults()

const expectedTypes = [
'checkpoints',
'loras',
'vae',
'controlnet',
'unet',
'upscale_models',
'style_models',
'gligen'
]

expectedTypes.forEach((modelType) => {
expect(store.getNodeProvider(modelType)).toBeDefined()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When you're looping like this, you usually want to add .soft so that you don't have to keep re-running it if there are multiple violations.

})
})

it('should be idempotent', () => {
store.registerDefaults()
const firstCheckpointCount =
store.getAllNodeProviders('checkpoints').length

store.registerDefaults() // Call again
const secondCheckpointCount =
store.getAllNodeProviders('checkpoints').length

expect(secondCheckpointCount).toBe(firstCheckpointCount)
})

it('should not register when nodeDefStore is empty', () => {
nodeDefStore.nodeDefsByName = {}
store.registerDefaults()
expect(store.getNodeProvider('checkpoints')).toBeUndefined()
})
})

describe('edge cases', () => {
it('should handle empty string model type', () => {
expect(store.getNodeProvider('')).toBeUndefined()
expect(store.getAllNodeProviders('')).toEqual([])
})
})
})