Skip to content

Commit be8e192

Browse files
author
Arjan Singh
committed
[feat] register UNETLoader, UpscaleModelLoader, StylemModelLoader, GLIGENLoader
Also added tests for modelToNodeStore
1 parent e1f2946 commit be8e192

File tree

2 files changed

+249
-0
lines changed

2 files changed

+249
-0
lines changed

src/stores/modelToNodeStore.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ export const useModelToNodeStore = defineStore('modelToNode', () => {
8383
quickRegister('loras', 'LoraLoaderModelOnly', 'lora_name')
8484
quickRegister('vae', 'VAELoader', 'vae_name')
8585
quickRegister('controlnet', 'ControlNetLoader', 'control_net_name')
86+
quickRegister('unet', 'UNETLoader', 'unet_name')
87+
quickRegister('upscale_models', 'UpscaleModelLoader', 'model_name')
88+
quickRegister('style_models', 'StyleModelLoader', 'style_model')
89+
quickRegister('gligen', 'GLIGENLoader', 'gligen_name')
8690
}
8791

8892
return {
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
import { createPinia, setActivePinia } from 'pinia'
2+
import { beforeEach, describe, expect, it } from 'vitest'
3+
4+
import {
5+
ModelNodeProvider,
6+
useModelToNodeStore
7+
} from '@/stores/modelToNodeStore'
8+
import { type ComfyNodeDefImpl, useNodeDefStore } from '@/stores/nodeDefStore'
9+
10+
describe('useModelToNodeStore', () => {
11+
let store: ReturnType<typeof useModelToNodeStore>
12+
let nodeDefStore: ReturnType<typeof useNodeDefStore>
13+
14+
// Create minimal mock for testing - only includes 'name' field since that's
15+
// the only property ModelNodeProvider constructor uses and tests verify
16+
const createMockNodeDef = (name: string): ComfyNodeDefImpl => {
17+
return { name } as ComfyNodeDefImpl
18+
}
19+
20+
beforeEach(() => {
21+
setActivePinia(createPinia())
22+
store = useModelToNodeStore()
23+
nodeDefStore = useNodeDefStore()
24+
25+
const mockNodeNames = [
26+
'CheckpointLoaderSimple',
27+
'ImageOnlyCheckpointLoader',
28+
'LoraLoader',
29+
'LoraLoaderModelOnly',
30+
'VAELoader',
31+
'ControlNetLoader',
32+
'UNETLoader',
33+
'UpscaleModelLoader',
34+
'StyleModelLoader',
35+
'GLIGENLoader'
36+
]
37+
38+
const mockNodeDefs: Record<string, ComfyNodeDefImpl> = Object.fromEntries(
39+
mockNodeNames.map((name) => [name, createMockNodeDef(name)])
40+
)
41+
42+
nodeDefStore.nodeDefsByName = mockNodeDefs
43+
})
44+
45+
describe('modelToNodeMap', () => {
46+
it('should initialize as empty', () => {
47+
expect(Object.keys(store.modelToNodeMap)).toHaveLength(0)
48+
})
49+
50+
it('should populate after registration', () => {
51+
store.registerDefaults()
52+
expect(Object.keys(store.modelToNodeMap)).toContain('checkpoints')
53+
expect(Object.keys(store.modelToNodeMap)).toContain('unet')
54+
})
55+
})
56+
57+
describe('getNodeProvider', () => {
58+
it('should return provider for registered model type', () => {
59+
store.registerDefaults()
60+
61+
const provider = store.getNodeProvider('checkpoints')
62+
expect(provider).toBeDefined()
63+
// Optional chaining used because getNodeProvider() can return undefined for unregistered types
64+
expect(provider?.nodeDef.name).toBe('CheckpointLoaderSimple')
65+
expect(provider?.key).toBe('ckpt_name')
66+
})
67+
68+
it('should return undefined for unregistered model type', () => {
69+
store.registerDefaults()
70+
expect(store.getNodeProvider('nonexistent')).toBeUndefined()
71+
})
72+
73+
it('should return first registered provider when multiple providers exist for same model type', () => {
74+
store.registerDefaults()
75+
76+
const provider = store.getNodeProvider('checkpoints')
77+
// Using optional chaining for safety since getNodeProvider() can return undefined
78+
expect(provider?.nodeDef.name).toBe('CheckpointLoaderSimple')
79+
})
80+
81+
it('should trigger lazy registration when called before registerDefaults', () => {
82+
const provider = store.getNodeProvider('checkpoints')
83+
expect(provider).toBeDefined()
84+
})
85+
})
86+
87+
describe('getAllNodeProviders', () => {
88+
it('should return all providers for model type with multiple nodes', () => {
89+
store.registerDefaults()
90+
91+
const checkpointProviders = store.getAllNodeProviders('checkpoints')
92+
expect(checkpointProviders).toHaveLength(2)
93+
expect(checkpointProviders.map((p) => p.nodeDef.name)).toContain(
94+
'CheckpointLoaderSimple'
95+
)
96+
expect(checkpointProviders.map((p) => p.nodeDef.name)).toContain(
97+
'ImageOnlyCheckpointLoader'
98+
)
99+
100+
const loraProviders = store.getAllNodeProviders('loras')
101+
expect(loraProviders).toHaveLength(2)
102+
})
103+
104+
it('should return single provider for model type with one node', () => {
105+
store.registerDefaults()
106+
107+
const unetProviders = store.getAllNodeProviders('unet')
108+
expect(unetProviders).toHaveLength(1)
109+
expect(unetProviders[0].nodeDef.name).toBe('UNETLoader')
110+
})
111+
112+
it('should return empty array for unregistered model type', () => {
113+
store.registerDefaults()
114+
expect(store.getAllNodeProviders('nonexistent')).toEqual([])
115+
})
116+
117+
it('should trigger lazy registration when called before registerDefaults', () => {
118+
const providers = store.getAllNodeProviders('checkpoints')
119+
expect(providers.length).toBeGreaterThan(0)
120+
})
121+
})
122+
123+
describe('registerNodeProvider', () => {
124+
it('should register provider directly', () => {
125+
const customProvider = new ModelNodeProvider(
126+
nodeDefStore.nodeDefsByName['UNETLoader'],
127+
'custom_key'
128+
)
129+
130+
store.registerNodeProvider('custom_type', customProvider)
131+
132+
const retrieved = store.getNodeProvider('custom_type')
133+
expect(retrieved).toStrictEqual(customProvider)
134+
// Optional chaining for consistency with getNodeProvider() return type
135+
expect(retrieved?.key).toBe('custom_key')
136+
})
137+
138+
it('should handle multiple providers for same model type and return first as primary', () => {
139+
const provider1 = new ModelNodeProvider(
140+
nodeDefStore.nodeDefsByName['UNETLoader'],
141+
'key1'
142+
)
143+
const provider2 = new ModelNodeProvider(
144+
nodeDefStore.nodeDefsByName['VAELoader'],
145+
'key2'
146+
)
147+
148+
store.registerNodeProvider('multi_type', provider1)
149+
store.registerNodeProvider('multi_type', provider2)
150+
151+
const allProviders = store.getAllNodeProviders('multi_type')
152+
expect(allProviders).toHaveLength(2)
153+
expect(store.getNodeProvider('multi_type')).toStrictEqual(provider1)
154+
})
155+
156+
it('should initialize new model type when first provider is registered', () => {
157+
expect(store.modelToNodeMap['new_type']).toBeUndefined()
158+
159+
const provider = new ModelNodeProvider(
160+
nodeDefStore.nodeDefsByName['UNETLoader'],
161+
'test_key'
162+
)
163+
store.registerNodeProvider('new_type', provider)
164+
165+
expect(store.modelToNodeMap['new_type']).toBeDefined()
166+
expect(store.modelToNodeMap['new_type']).toHaveLength(1)
167+
})
168+
})
169+
170+
describe('quickRegister', () => {
171+
it('should connect node class to model type with parameter mapping', () => {
172+
store.quickRegister('test_type', 'UNETLoader', 'test_param')
173+
174+
const provider = store.getNodeProvider('test_type')
175+
expect(provider).toBeDefined()
176+
// Using optional chaining since getNodeProvider() can return undefined
177+
expect(provider?.nodeDef.name).toBe('UNETLoader')
178+
expect(provider?.key).toBe('test_param')
179+
})
180+
181+
it('should handle registration of non-existent node classes gracefully', () => {
182+
expect(() => {
183+
store.quickRegister('test_type', 'NonExistentLoader', 'test_param')
184+
}).not.toThrow()
185+
186+
const provider = store.getNodeProvider('test_type')
187+
// Optional chaining needed since getNodeProvider() can return undefined
188+
expect(provider?.nodeDef).toBeUndefined()
189+
})
190+
191+
it('should allow multiple node classes for same model type', () => {
192+
store.quickRegister('multi_type', 'UNETLoader', 'param1')
193+
store.quickRegister('multi_type', 'VAELoader', 'param2')
194+
195+
const providers = store.getAllNodeProviders('multi_type')
196+
expect(providers).toHaveLength(2)
197+
})
198+
})
199+
200+
describe('registerDefaults integration', () => {
201+
it('should register all expected model types based on mock data', () => {
202+
store.registerDefaults()
203+
204+
const expectedTypes = [
205+
'checkpoints',
206+
'loras',
207+
'vae',
208+
'controlnet',
209+
'unet',
210+
'upscale_models',
211+
'style_models',
212+
'gligen'
213+
]
214+
215+
expectedTypes.forEach((modelType) => {
216+
expect(store.getNodeProvider(modelType)).toBeDefined()
217+
})
218+
})
219+
220+
it('should be idempotent', () => {
221+
store.registerDefaults()
222+
const firstCheckpointCount =
223+
store.getAllNodeProviders('checkpoints').length
224+
225+
store.registerDefaults() // Call again
226+
const secondCheckpointCount =
227+
store.getAllNodeProviders('checkpoints').length
228+
229+
expect(secondCheckpointCount).toBe(firstCheckpointCount)
230+
})
231+
232+
it('should not register when nodeDefStore is empty', () => {
233+
nodeDefStore.nodeDefsByName = {}
234+
store.registerDefaults()
235+
expect(store.getNodeProvider('checkpoints')).toBeUndefined()
236+
})
237+
})
238+
239+
describe('edge cases', () => {
240+
it('should handle empty string model type', () => {
241+
expect(store.getNodeProvider('')).toBeUndefined()
242+
expect(store.getAllNodeProviders('')).toEqual([])
243+
})
244+
})
245+
})

0 commit comments

Comments
 (0)