Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async def test_build_model(self, mock_chat_ollama, component_class, default_kwar
mock_chat_ollama.assert_called_once_with(
base_url="http://localhost:11434",
model="ollama-model",
mirostat=0,
# mirostat is not included when disabled (set to None and filtered out)
format="json",
metadata={"keywords": ["model", "llm", "language model", "large language model"]},
num_ctx=2048,
Expand Down Expand Up @@ -83,6 +83,51 @@ async def test_build_model_missing_base_url(self, mock_chat_ollama, component_cl
with pytest.raises(ValueError, match=re.escape("Unable to connect to the Ollama API.")):
component.build_model()

@patch("lfx.components.ollama.ollama.ChatOllama")
async def test_build_model_with_mirostat_enabled(self, mock_chat_ollama, component_class):
"""Test that mirostat parameters are included when Mirostat is enabled."""
mock_instance = MagicMock()
mock_chat_ollama.return_value = mock_instance

component = component_class(
base_url="http://localhost:11434",
model_name="ollama-model",
mirostat="Mirostat", # Setting to Mirostat (value 1)
mirostat_eta=0.1,
mirostat_tau=5.0,
temperature=0.1,
)
model = component.build_model()

# Verify that mirostat and its related params ARE passed
call_kwargs = mock_chat_ollama.call_args[1]
assert call_kwargs["mirostat"] == 1
assert call_kwargs["mirostat_eta"] == 0.1
assert call_kwargs["mirostat_tau"] == 5.0
assert model == mock_instance

@patch("lfx.components.ollama.ollama.ChatOllama")
async def test_build_model_with_mirostat_2_enabled(self, mock_chat_ollama, component_class):
"""Test that mirostat parameters are included when Mirostat 2.0 is enabled."""
mock_instance = MagicMock()
mock_chat_ollama.return_value = mock_instance

component = component_class(
base_url="http://localhost:11434",
model_name="ollama-model",
mirostat="Mirostat 2.0", # Setting to Mirostat 2.0 (value 2)
mirostat_eta=0.2,
mirostat_tau=10.0,
temperature=0.1,
)
model = component.build_model()
# Verify that mirostat and its related params ARE passed
call_kwargs = mock_chat_ollama.call_args[1]
assert call_kwargs["mirostat"] == 2
assert call_kwargs["mirostat_eta"] == 0.2
assert call_kwargs["mirostat_tau"] == 10.0
assert model == mock_instance

@pytest.mark.asyncio
@patch("lfx.components.ollama.ollama.httpx.AsyncClient.post")
@patch("lfx.components.ollama.ollama.httpx.AsyncClient.get")
Expand Down
6 changes: 3 additions & 3 deletions src/lfx/src/lfx/components/ollama/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,11 @@ def build_model(self) -> LanguageModel: # type: ignore[type-var]
# Mapping mirostat settings to their corresponding values
mirostat_options = {"Mirostat": 1, "Mirostat 2.0": 2}

# Default to 0 for 'Disabled'
mirostat_value = mirostat_options.get(self.mirostat, 0)
# Default to None for 'Disabled'
mirostat_value = mirostat_options.get(self.mirostat, None)

# Set mirostat_eta and mirostat_tau to None if mirostat is disabled
if mirostat_value == 0:
if mirostat_value is None:
mirostat_eta = None
mirostat_tau = None
else:
Expand Down
Loading