-
Notifications
You must be signed in to change notification settings - Fork 448
[feat] register UNETLoader, UpscaleModelLoader, StylemModelLoader... #5324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
99d486f
[feat] register UNETLoader, UpscaleModelLoader, StylemModelLoader, GL…
arjansingh 68bd05e
[fix] code review feedback on tests
arjansingh b538e29
[fix] typescript bikeshedding
arjansingh 8ebe70a
[fix] remove unnecessary interface mocks
arjansingh File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,293 @@ | ||
| import { createPinia, setActivePinia } from 'pinia' | ||
| import { beforeEach, describe, expect, it } from 'vitest' | ||
|
|
||
| import type { ComfyNodeDef as ComfyNodeDefV1 } from '@/schemas/nodeDefSchema' | ||
| import { | ||
| ModelNodeProvider, | ||
| useModelToNodeStore | ||
| } from '@/stores/modelToNodeStore' | ||
| import { ComfyNodeDefImpl, useNodeDefStore } from '@/stores/nodeDefStore' | ||
|
|
||
| const MOCK_NODE_NAMES = [ | ||
| 'CheckpointLoaderSimple', | ||
| 'ImageOnlyCheckpointLoader', | ||
| 'LoraLoader', | ||
| 'LoraLoaderModelOnly', | ||
| 'VAELoader', | ||
| 'ControlNetLoader', | ||
| 'UNETLoader', | ||
| 'UpscaleModelLoader', | ||
| 'StyleModelLoader', | ||
| 'GLIGENLoader' | ||
| ] as const | ||
|
|
||
| const EXPECTED_DEFAULT_TYPES = [ | ||
| 'checkpoints', | ||
| 'loras', | ||
| 'vae', | ||
| 'controlnet', | ||
| 'unet', | ||
| 'upscale_models', | ||
| 'style_models', | ||
| 'gligen' | ||
| ] as const | ||
|
|
||
| describe('useModelToNodeStore', () => { | ||
| beforeEach(() => { | ||
| setActivePinia(createPinia()) | ||
|
|
||
| // Create minimal but valid ComfyNodeDefImpl for testing | ||
| const createMockNodeDef = (name: string): ComfyNodeDefImpl => { | ||
| const def: ComfyNodeDefV1 = { | ||
| name, | ||
| display_name: name, | ||
| category: 'test', | ||
| python_module: 'nodes', | ||
| description: '', | ||
| input: { required: {}, optional: {} }, | ||
| output: [], | ||
| output_name: [], | ||
| output_is_list: [], | ||
| output_node: false | ||
| } | ||
| return new ComfyNodeDefImpl(def) | ||
| } | ||
|
|
||
| // Mock nodeDefStore dependency - modelToNodeStore relies on this for registration | ||
| // Most tests expect this to be populated; tests that need empty state can override | ||
| const nodeDefStore = useNodeDefStore() | ||
| nodeDefStore.nodeDefsByName = Object.fromEntries( | ||
| MOCK_NODE_NAMES.map((name) => [name, createMockNodeDef(name)]) | ||
| ) | ||
| }) | ||
|
|
||
| describe('modelToNodeMap', () => { | ||
| it('should initialize as empty', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| expect(Object.keys(modelToNodeStore.modelToNodeMap)).toHaveLength(0) | ||
| }) | ||
|
|
||
| it('should populate after registration', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| modelToNodeStore.registerDefaults() | ||
| expect(Object.keys(modelToNodeStore.modelToNodeMap)).toEqual( | ||
| expect.arrayContaining(['checkpoints', 'unet']) | ||
| ) | ||
| }) | ||
| }) | ||
|
|
||
| describe('getNodeProvider', () => { | ||
| it('should return provider for registered model type', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| modelToNodeStore.registerDefaults() | ||
|
|
||
| const provider = modelToNodeStore.getNodeProvider('checkpoints') | ||
| expect(provider).toBeDefined() | ||
| // After asserting provider is defined, we can safely access its properties | ||
| expect(provider?.nodeDef?.name).toBe('CheckpointLoaderSimple') | ||
| expect(provider?.key).toBe('ckpt_name') | ||
| }) | ||
|
|
||
| it('should return undefined for unregistered model type', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| modelToNodeStore.registerDefaults() | ||
| expect(modelToNodeStore.getNodeProvider('nonexistent')).toBeUndefined() | ||
| }) | ||
|
|
||
| it('should return first registered provider when multiple providers exist for same model type', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| modelToNodeStore.registerDefaults() | ||
|
|
||
| const provider = modelToNodeStore.getNodeProvider('checkpoints') | ||
| // Using optional chaining for safety since getNodeProvider() can return undefined | ||
| expect(provider?.nodeDef?.name).toBe('CheckpointLoaderSimple') | ||
| }) | ||
|
|
||
| it('should trigger lazy registration when called before registerDefaults', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
|
|
||
| const provider = modelToNodeStore.getNodeProvider('checkpoints') | ||
| expect(provider).toBeDefined() | ||
| }) | ||
| }) | ||
|
|
||
| describe('getAllNodeProviders', () => { | ||
| it('should return all providers for model type with multiple nodes', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| modelToNodeStore.registerDefaults() | ||
|
|
||
| const checkpointProviders = | ||
| modelToNodeStore.getAllNodeProviders('checkpoints') | ||
| expect(checkpointProviders).toHaveLength(2) | ||
| expect(checkpointProviders).toEqual( | ||
| expect.arrayContaining([ | ||
| expect.objectContaining({ | ||
| nodeDef: expect.objectContaining({ name: 'CheckpointLoaderSimple' }) | ||
| }), | ||
| expect.objectContaining({ | ||
| nodeDef: expect.objectContaining({ | ||
| name: 'ImageOnlyCheckpointLoader' | ||
| }) | ||
| }) | ||
| ]) | ||
| ) | ||
|
|
||
| const loraProviders = modelToNodeStore.getAllNodeProviders('loras') | ||
| expect(loraProviders).toHaveLength(2) | ||
| }) | ||
|
|
||
| it('should return single provider for model type with one node', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| modelToNodeStore.registerDefaults() | ||
|
|
||
| const unetProviders = modelToNodeStore.getAllNodeProviders('unet') | ||
| expect(unetProviders).toHaveLength(1) | ||
| expect(unetProviders[0].nodeDef.name).toBe('UNETLoader') | ||
| }) | ||
|
|
||
| it('should return empty array for unregistered model type', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| modelToNodeStore.registerDefaults() | ||
| expect(modelToNodeStore.getAllNodeProviders('nonexistent')).toEqual([]) | ||
| }) | ||
|
|
||
| it('should trigger lazy registration when called before registerDefaults', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
|
|
||
| const providers = modelToNodeStore.getAllNodeProviders('checkpoints') | ||
| expect(providers.length).toBeGreaterThan(0) | ||
| }) | ||
| }) | ||
|
|
||
| describe('registerNodeProvider', () => { | ||
| it('should register provider directly', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| const nodeDefStore = useNodeDefStore() | ||
| const customProvider = new ModelNodeProvider( | ||
| nodeDefStore.nodeDefsByName['UNETLoader'], | ||
| 'custom_key' | ||
| ) | ||
|
|
||
| modelToNodeStore.registerNodeProvider('custom_type', customProvider) | ||
|
|
||
| const retrieved = modelToNodeStore.getNodeProvider('custom_type') | ||
| expect(retrieved).toStrictEqual(customProvider) | ||
| // Optional chaining for consistency with getNodeProvider() return type | ||
| expect(retrieved?.key).toBe('custom_key') | ||
| }) | ||
|
|
||
| it('should handle multiple providers for same model type and return first as primary', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| const nodeDefStore = useNodeDefStore() | ||
| const provider1 = new ModelNodeProvider( | ||
| nodeDefStore.nodeDefsByName['UNETLoader'], | ||
| 'key1' | ||
| ) | ||
| const provider2 = new ModelNodeProvider( | ||
| nodeDefStore.nodeDefsByName['VAELoader'], | ||
| 'key2' | ||
| ) | ||
|
|
||
| modelToNodeStore.registerNodeProvider('multi_type', provider1) | ||
| modelToNodeStore.registerNodeProvider('multi_type', provider2) | ||
|
|
||
| const allProviders = modelToNodeStore.getAllNodeProviders('multi_type') | ||
| expect(allProviders).toHaveLength(2) | ||
| expect(modelToNodeStore.getNodeProvider('multi_type')).toStrictEqual( | ||
| provider1 | ||
| ) | ||
| }) | ||
|
|
||
| it('should initialize new model type when first provider is registered', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| const nodeDefStore = useNodeDefStore() | ||
| expect(modelToNodeStore.modelToNodeMap['new_type']).toBeUndefined() | ||
|
|
||
| const provider = new ModelNodeProvider( | ||
| nodeDefStore.nodeDefsByName['UNETLoader'], | ||
| 'test_key' | ||
| ) | ||
| modelToNodeStore.registerNodeProvider('new_type', provider) | ||
|
|
||
| expect(modelToNodeStore.modelToNodeMap['new_type']).toBeDefined() | ||
| expect(modelToNodeStore.modelToNodeMap['new_type']).toHaveLength(1) | ||
| }) | ||
| }) | ||
|
|
||
| describe('quickRegister', () => { | ||
| it('should connect node class to model type with parameter mapping', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| modelToNodeStore.quickRegister('test_type', 'UNETLoader', 'test_param') | ||
|
|
||
| const provider = modelToNodeStore.getNodeProvider('test_type') | ||
| expect(provider).toBeDefined() | ||
| // After asserting provider is defined, we can safely access its properties | ||
| expect(provider!.nodeDef.name).toBe('UNETLoader') | ||
| expect(provider!.key).toBe('test_param') | ||
| }) | ||
|
|
||
| it('should handle registration of non-existent node classes gracefully', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| expect(() => { | ||
| modelToNodeStore.quickRegister( | ||
| 'test_type', | ||
| 'NonExistentLoader', | ||
| 'test_param' | ||
| ) | ||
| }).not.toThrow() | ||
|
|
||
| const provider = modelToNodeStore.getNodeProvider('test_type') | ||
| // Optional chaining needed since getNodeProvider() can return undefined | ||
| expect(provider?.nodeDef).toBeUndefined() | ||
| }) | ||
|
|
||
| it('should allow multiple node classes for same model type', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| modelToNodeStore.quickRegister('multi_type', 'UNETLoader', 'param1') | ||
| modelToNodeStore.quickRegister('multi_type', 'VAELoader', 'param2') | ||
|
|
||
| const providers = modelToNodeStore.getAllNodeProviders('multi_type') | ||
| expect(providers).toHaveLength(2) | ||
| }) | ||
| }) | ||
|
|
||
| describe('registerDefaults integration', () => { | ||
| it('should register all expected model types based on mock data', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| modelToNodeStore.registerDefaults() | ||
|
|
||
| EXPECTED_DEFAULT_TYPES.forEach((modelType) => { | ||
|
||
| expect.soft(modelToNodeStore.getNodeProvider(modelType)).toBeDefined() | ||
| }) | ||
| }) | ||
|
|
||
| it('should be idempotent', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| modelToNodeStore.registerDefaults() | ||
| const firstCheckpointCount = | ||
| modelToNodeStore.getAllNodeProviders('checkpoints').length | ||
|
|
||
| modelToNodeStore.registerDefaults() // Call again | ||
| const secondCheckpointCount = | ||
| modelToNodeStore.getAllNodeProviders('checkpoints').length | ||
|
|
||
| expect(secondCheckpointCount).toBe(firstCheckpointCount) | ||
| }) | ||
|
|
||
| it('should not register when nodeDefStore is empty', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| const nodeDefStore = useNodeDefStore() | ||
| nodeDefStore.nodeDefsByName = {} | ||
| modelToNodeStore.registerDefaults() | ||
| expect(modelToNodeStore.getNodeProvider('checkpoints')).toBeUndefined() | ||
| }) | ||
| }) | ||
|
|
||
| describe('edge cases', () => { | ||
| it('should handle empty string model type', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| expect(modelToNodeStore.getNodeProvider('')).toBeUndefined() | ||
| expect(modelToNodeStore.getAllNodeProviders('')).toEqual([]) | ||
| }) | ||
| }) | ||
| }) | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use vitest mocking