Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 148 additions & 100 deletions tests-ui/tests/store/modelToNodeStore.test.ts
Original file line number Diff line number Diff line change
@@ -1,141 +1,184 @@
import { createPinia, setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it } from 'vitest'

import type { ComfyNodeDef as ComfyNodeDefV1 } from '@/schemas/nodeDefSchema'
import {
ModelNodeProvider,
useModelToNodeStore
} from '@/stores/modelToNodeStore'
import { type ComfyNodeDefImpl, useNodeDefStore } from '@/stores/nodeDefStore'
import { ComfyNodeDefImpl, useNodeDefStore } from '@/stores/nodeDefStore'

const MOCK_NODE_NAMES = [
'CheckpointLoaderSimple',
'ImageOnlyCheckpointLoader',
'LoraLoader',
'LoraLoaderModelOnly',
'VAELoader',
'ControlNetLoader',
'UNETLoader',
'UpscaleModelLoader',
'StyleModelLoader',
'GLIGENLoader'
] as const

const EXPECTED_DEFAULT_TYPES = [
'checkpoints',
'loras',
'vae',
'controlnet',
'unet',
'upscale_models',
'style_models',
'gligen'
] as const

describe('useModelToNodeStore', () => {
let store: ReturnType<typeof useModelToNodeStore>
let nodeDefStore: ReturnType<typeof useNodeDefStore>

// Create minimal mock for testing - only includes 'name' field since that's
// the only property ModelNodeProvider constructor uses and tests verify
const createMockNodeDef = (name: string): ComfyNodeDefImpl => {
return { name } as ComfyNodeDefImpl
}

beforeEach(() => {
setActivePinia(createPinia())
store = useModelToNodeStore()
nodeDefStore = useNodeDefStore()

const mockNodeNames = [
'CheckpointLoaderSimple',
'ImageOnlyCheckpointLoader',
'LoraLoader',
'LoraLoaderModelOnly',
'VAELoader',
'ControlNetLoader',
'UNETLoader',
'UpscaleModelLoader',
'StyleModelLoader',
'GLIGENLoader'
]

const mockNodeDefs: Record<string, ComfyNodeDefImpl> = Object.fromEntries(
mockNodeNames.map((name) => [name, createMockNodeDef(name)])
)

nodeDefStore.nodeDefsByName = mockNodeDefs
// Create minimal but valid ComfyNodeDefImpl for testing
const createMockNodeDef = (name: string): ComfyNodeDefImpl => {
const def: ComfyNodeDefV1 = {
name,
display_name: name,
category: 'test',
python_module: 'nodes',
description: '',
input: { required: {}, optional: {} },
output: [],
output_name: [],
output_is_list: [],
output_node: false
}
return new ComfyNodeDefImpl(def)
}

// Mock nodeDefStore dependency - modelToNodeStore relies on this for registration
// Most tests expect this to be populated; tests that need empty state can override
const nodeDefStore = useNodeDefStore()
nodeDefStore.nodeDefsByName = Object.fromEntries(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use vitest mocking

MOCK_NODE_NAMES.map((name) => [name, createMockNodeDef(name)])
)
})

describe('modelToNodeMap', () => {
it('should initialize as empty', () => {
expect(Object.keys(store.modelToNodeMap)).toHaveLength(0)
const modelToNodeStore = useModelToNodeStore()
expect(Object.keys(modelToNodeStore.modelToNodeMap)).toHaveLength(0)
})

it('should populate after registration', () => {
store.registerDefaults()
expect(Object.keys(store.modelToNodeMap)).toContain('checkpoints')
expect(Object.keys(store.modelToNodeMap)).toContain('unet')
const modelToNodeStore = useModelToNodeStore()
modelToNodeStore.registerDefaults()
expect(Object.keys(modelToNodeStore.modelToNodeMap)).toEqual(
expect.arrayContaining(['checkpoints', 'unet'])
)
})
})

describe('getNodeProvider', () => {
it('should return provider for registered model type', () => {
store.registerDefaults()
const modelToNodeStore = useModelToNodeStore()
modelToNodeStore.registerDefaults()

const provider = store.getNodeProvider('checkpoints')
const provider = modelToNodeStore.getNodeProvider('checkpoints')
expect(provider).toBeDefined()
// Optional chaining used because getNodeProvider() can return undefined for unregistered types
expect(provider?.nodeDef.name).toBe('CheckpointLoaderSimple')
// After asserting provider is defined, we can safely access its properties
expect(provider?.nodeDef?.name).toBe('CheckpointLoaderSimple')
expect(provider?.key).toBe('ckpt_name')
})

it('should return undefined for unregistered model type', () => {
store.registerDefaults()
expect(store.getNodeProvider('nonexistent')).toBeUndefined()
const modelToNodeStore = useModelToNodeStore()
modelToNodeStore.registerDefaults()
expect(modelToNodeStore.getNodeProvider('nonexistent')).toBeUndefined()
})

it('should return first registered provider when multiple providers exist for same model type', () => {
store.registerDefaults()
const modelToNodeStore = useModelToNodeStore()
modelToNodeStore.registerDefaults()

const provider = store.getNodeProvider('checkpoints')
const provider = modelToNodeStore.getNodeProvider('checkpoints')
// Using optional chaining for safety since getNodeProvider() can return undefined
expect(provider?.nodeDef.name).toBe('CheckpointLoaderSimple')
expect(provider?.nodeDef?.name).toBe('CheckpointLoaderSimple')
})

it('should trigger lazy registration when called before registerDefaults', () => {
const provider = store.getNodeProvider('checkpoints')
const modelToNodeStore = useModelToNodeStore()

const provider = modelToNodeStore.getNodeProvider('checkpoints')
expect(provider).toBeDefined()
})
})

describe('getAllNodeProviders', () => {
it('should return all providers for model type with multiple nodes', () => {
store.registerDefaults()
const modelToNodeStore = useModelToNodeStore()
modelToNodeStore.registerDefaults()

const checkpointProviders = store.getAllNodeProviders('checkpoints')
const checkpointProviders =
modelToNodeStore.getAllNodeProviders('checkpoints')
expect(checkpointProviders).toHaveLength(2)
expect(checkpointProviders.map((p) => p.nodeDef.name)).toContain(
'CheckpointLoaderSimple'
)
expect(checkpointProviders.map((p) => p.nodeDef.name)).toContain(
'ImageOnlyCheckpointLoader'
expect(checkpointProviders).toEqual(
expect.arrayContaining([
expect.objectContaining({
nodeDef: expect.objectContaining({ name: 'CheckpointLoaderSimple' })
}),
expect.objectContaining({
nodeDef: expect.objectContaining({
name: 'ImageOnlyCheckpointLoader'
})
})
])
)

const loraProviders = store.getAllNodeProviders('loras')
const loraProviders = modelToNodeStore.getAllNodeProviders('loras')
expect(loraProviders).toHaveLength(2)
})

it('should return single provider for model type with one node', () => {
store.registerDefaults()
const modelToNodeStore = useModelToNodeStore()
modelToNodeStore.registerDefaults()

const unetProviders = store.getAllNodeProviders('unet')
const unetProviders = modelToNodeStore.getAllNodeProviders('unet')
expect(unetProviders).toHaveLength(1)
expect(unetProviders[0].nodeDef.name).toBe('UNETLoader')
})

it('should return empty array for unregistered model type', () => {
store.registerDefaults()
expect(store.getAllNodeProviders('nonexistent')).toEqual([])
const modelToNodeStore = useModelToNodeStore()
modelToNodeStore.registerDefaults()
expect(modelToNodeStore.getAllNodeProviders('nonexistent')).toEqual([])
})

it('should trigger lazy registration when called before registerDefaults', () => {
const providers = store.getAllNodeProviders('checkpoints')
const modelToNodeStore = useModelToNodeStore()

const providers = modelToNodeStore.getAllNodeProviders('checkpoints')
expect(providers.length).toBeGreaterThan(0)
})
})

describe('registerNodeProvider', () => {
it('should register provider directly', () => {
const modelToNodeStore = useModelToNodeStore()
const nodeDefStore = useNodeDefStore()
const customProvider = new ModelNodeProvider(
nodeDefStore.nodeDefsByName['UNETLoader'],
'custom_key'
)

store.registerNodeProvider('custom_type', customProvider)
modelToNodeStore.registerNodeProvider('custom_type', customProvider)

const retrieved = store.getNodeProvider('custom_type')
const retrieved = modelToNodeStore.getNodeProvider('custom_type')
expect(retrieved).toStrictEqual(customProvider)
// Optional chaining for consistency with getNodeProvider() return type
expect(retrieved?.key).toBe('custom_key')
})

it('should handle multiple providers for same model type and return first as primary', () => {
const modelToNodeStore = useModelToNodeStore()
const nodeDefStore = useNodeDefStore()
const provider1 = new ModelNodeProvider(
nodeDefStore.nodeDefsByName['UNETLoader'],
'key1'
Expand All @@ -145,101 +188,106 @@ describe('useModelToNodeStore', () => {
'key2'
)

store.registerNodeProvider('multi_type', provider1)
store.registerNodeProvider('multi_type', provider2)
modelToNodeStore.registerNodeProvider('multi_type', provider1)
modelToNodeStore.registerNodeProvider('multi_type', provider2)

const allProviders = store.getAllNodeProviders('multi_type')
const allProviders = modelToNodeStore.getAllNodeProviders('multi_type')
expect(allProviders).toHaveLength(2)
expect(store.getNodeProvider('multi_type')).toStrictEqual(provider1)
expect(modelToNodeStore.getNodeProvider('multi_type')).toStrictEqual(
provider1
)
})

it('should initialize new model type when first provider is registered', () => {
expect(store.modelToNodeMap['new_type']).toBeUndefined()
const modelToNodeStore = useModelToNodeStore()
const nodeDefStore = useNodeDefStore()
expect(modelToNodeStore.modelToNodeMap['new_type']).toBeUndefined()

const provider = new ModelNodeProvider(
nodeDefStore.nodeDefsByName['UNETLoader'],
'test_key'
)
store.registerNodeProvider('new_type', provider)
modelToNodeStore.registerNodeProvider('new_type', provider)

expect(store.modelToNodeMap['new_type']).toBeDefined()
expect(store.modelToNodeMap['new_type']).toHaveLength(1)
expect(modelToNodeStore.modelToNodeMap['new_type']).toBeDefined()
expect(modelToNodeStore.modelToNodeMap['new_type']).toHaveLength(1)
})
})

describe('quickRegister', () => {
it('should connect node class to model type with parameter mapping', () => {
store.quickRegister('test_type', 'UNETLoader', 'test_param')
const modelToNodeStore = useModelToNodeStore()
modelToNodeStore.quickRegister('test_type', 'UNETLoader', 'test_param')

const provider = store.getNodeProvider('test_type')
const provider = modelToNodeStore.getNodeProvider('test_type')
expect(provider).toBeDefined()
// Using optional chaining since getNodeProvider() can return undefined
expect(provider?.nodeDef.name).toBe('UNETLoader')
expect(provider?.key).toBe('test_param')
// After asserting provider is defined, we can safely access its properties
expect(provider!.nodeDef.name).toBe('UNETLoader')
expect(provider!.key).toBe('test_param')
})

it('should handle registration of non-existent node classes gracefully', () => {
const modelToNodeStore = useModelToNodeStore()
expect(() => {
store.quickRegister('test_type', 'NonExistentLoader', 'test_param')
modelToNodeStore.quickRegister(
'test_type',
'NonExistentLoader',
'test_param'
)
}).not.toThrow()

const provider = store.getNodeProvider('test_type')
const provider = modelToNodeStore.getNodeProvider('test_type')
// Optional chaining needed since getNodeProvider() can return undefined
expect(provider?.nodeDef).toBeUndefined()
})

it('should allow multiple node classes for same model type', () => {
store.quickRegister('multi_type', 'UNETLoader', 'param1')
store.quickRegister('multi_type', 'VAELoader', 'param2')
const modelToNodeStore = useModelToNodeStore()
modelToNodeStore.quickRegister('multi_type', 'UNETLoader', 'param1')
modelToNodeStore.quickRegister('multi_type', 'VAELoader', 'param2')

const providers = store.getAllNodeProviders('multi_type')
const providers = modelToNodeStore.getAllNodeProviders('multi_type')
expect(providers).toHaveLength(2)
})
})

describe('registerDefaults integration', () => {
it('should register all expected model types based on mock data', () => {
store.registerDefaults()

const expectedTypes = [
'checkpoints',
'loras',
'vae',
'controlnet',
'unet',
'upscale_models',
'style_models',
'gligen'
]

expectedTypes.forEach((modelType) => {
expect(store.getNodeProvider(modelType)).toBeDefined()
const modelToNodeStore = useModelToNodeStore()
modelToNodeStore.registerDefaults()

EXPECTED_DEFAULT_TYPES.forEach((modelType) => {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for of

expect.soft(modelToNodeStore.getNodeProvider(modelType)).toBeDefined()
})
})

it('should be idempotent', () => {
store.registerDefaults()
const modelToNodeStore = useModelToNodeStore()
modelToNodeStore.registerDefaults()
const firstCheckpointCount =
store.getAllNodeProviders('checkpoints').length
modelToNodeStore.getAllNodeProviders('checkpoints').length

store.registerDefaults() // Call again
modelToNodeStore.registerDefaults() // Call again
const secondCheckpointCount =
store.getAllNodeProviders('checkpoints').length
modelToNodeStore.getAllNodeProviders('checkpoints').length

expect(secondCheckpointCount).toBe(firstCheckpointCount)
})

it('should not register when nodeDefStore is empty', () => {
const modelToNodeStore = useModelToNodeStore()
const nodeDefStore = useNodeDefStore()
nodeDefStore.nodeDefsByName = {}
store.registerDefaults()
expect(store.getNodeProvider('checkpoints')).toBeUndefined()
modelToNodeStore.registerDefaults()
expect(modelToNodeStore.getNodeProvider('checkpoints')).toBeUndefined()
})
})

describe('edge cases', () => {
it('should handle empty string model type', () => {
expect(store.getNodeProvider('')).toBeUndefined()
expect(store.getAllNodeProviders('')).toEqual([])
const modelToNodeStore = useModelToNodeStore()
expect(modelToNodeStore.getNodeProvider('')).toBeUndefined()
expect(modelToNodeStore.getAllNodeProviders('')).toEqual([])
})
})
})