Fix Ollama embedding response key handling #1809
Changes from 2 commits
@@ -0,0 +1,97 @@
```python
#!/usr/bin/env python3
"""
Test script to verify OllamaEmbeddingEngine fix with real Ollama server.
Tests that the fix correctly handles Ollama's API response format.
"""
import asyncio
import sys

from cognee.infrastructure.databases.vector.embeddings.OllamaEmbeddingEngine import (
    OllamaEmbeddingEngine,
)


async def test_ollama_embedding():
    """Test OllamaEmbeddingEngine with real Ollama server."""
    print("=" * 80)
    print("Testing OllamaEmbeddingEngine Fix")
    print("=" * 80)

    # Configure for your Ollama server
    ollama_endpoint = "http://10.0.10.9:11434/api/embeddings"
    ollama_model = "nomic-embed-text"

    print(f"\nConfiguration:")
    print(f"  Endpoint: {ollama_endpoint}")
    print(f"  Model: {ollama_model}")
    print(f"  Expected dimensions: 768")

    # Initialize the embedding engine
    print("\n1. Initializing OllamaEmbeddingEngine...")
    try:
        engine = OllamaEmbeddingEngine(
            model=ollama_model,
            dimensions=768,
            endpoint=ollama_endpoint,
            huggingface_tokenizer="bert-base-uncased",
        )
        print(" ✅ Engine initialized successfully")
    except Exception as e:
        print(f" ❌ Failed to initialize engine: {e}")
        sys.exit(1)
```
Comment on lines +38 to +40
Contributor

🛠️ Refactor suggestion | 🟠 Major

**Replace `sys.exit()` with proper test assertions.**

Using `sys.exit()` is not appropriate for test code. If keeping this as an integration test script, at least replace `sys.exit()` with exceptions:

```diff
     except Exception as e:
         print(f" ❌ Failed to initialize engine: {e}")
-        sys.exit(1)
+        raise RuntimeError(f"Engine initialization failed: {e}") from e
```

Better yet, use the proper unit test approach suggested in the previous comment. Also applies to: 59-64, 81-83.

🧰 Tools
🪛 Ruff (0.14.5)
39-39: Do not catch blind exception (BLE001)
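If the script is kept as an opt-in integration check rather than converted to mocked unit tests, one way to follow this suggestion is to let exceptions propagate and run the function under pytest. A minimal sketch under the assumption that pytest and pytest-asyncio are available in the project (not confirmed by this PR); the file and environment-variable names are hypothetical:

```python
# test_ollama_embedding_integration.py -- hypothetical file name
import os

import pytest

from cognee.infrastructure.databases.vector.embeddings.OllamaEmbeddingEngine import (
    OllamaEmbeddingEngine,
)


@pytest.mark.asyncio
async def test_ollama_embedding_integration():
    """Opt-in integration test; skipped unless an Ollama endpoint is configured."""
    endpoint = os.environ.get("OLLAMA_TEST_ENDPOINT")  # assumed variable name
    if not endpoint:
        pytest.skip("OLLAMA_TEST_ENDPOINT not set; skipping live Ollama test")

    engine = OllamaEmbeddingEngine(
        model="nomic-embed-text",
        dimensions=768,
        endpoint=endpoint,
        huggingface_tokenizer="bert-base-uncased",
    )

    embeddings = await engine.embed_text(["The sky is blue."])
    # Failures surface as assertion errors instead of sys.exit(1).
    assert len(embeddings) == 1
    assert len(embeddings[0]) == 768
```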
```python
    # Test single text embedding
    print("\n2. Testing single text embedding...")
    test_texts = ["The sky is blue and the grass is green."]

    try:
        embeddings = await engine.embed_text(test_texts)
        print(f" ✅ Embedding generated successfully")
        print(f" 📊 Embedding shape: {len(embeddings)} texts, {len(embeddings[0])} dimensions")
        print(f" 📊 First 5 values: {embeddings[0][:5]}")

        # Verify dimensions
        if len(embeddings[0]) == 768:
            print(f" ✅ Dimensions match expected (768)")
        else:
            print(f" ⚠️ Dimensions mismatch: got {len(embeddings[0])}, expected 768")

    except KeyError as e:
        print(f" ❌ KeyError (this is the bug we're fixing): {e}")
        sys.exit(1)
    except Exception as e:
        print(f" ❌ Failed to generate embedding: {type(e).__name__}: {e}")
        sys.exit(1)

    # Test multiple texts
    print("\n3. Testing multiple text embeddings...")
    test_texts_multiple = [
        "Hello world",
        "Machine learning is fascinating",
        "Ollama embeddings work great"
    ]

    try:
        embeddings = await engine.embed_text(test_texts_multiple)
        print(f" ✅ Multiple embeddings generated successfully")
        print(f" 📊 Generated {len(embeddings)} embeddings")
        for i, emb in enumerate(embeddings):
            print(f" 📊 Text {i+1}: {len(emb)} dimensions, first 3 values: {emb[:3]}")

    except Exception as e:
        print(f" ❌ Failed to generate embeddings: {type(e).__name__}: {e}")
        sys.exit(1)

    # Success!
    print("\n" + "=" * 80)
    print("✅ ALL TESTS PASSED!")
    print("=" * 80)
    print("\nThe OllamaEmbeddingEngine fix is working correctly!")
    print("- Handles 'embedding' (singular) response from Ollama API")
    print("- Generates embeddings successfully")
    print("- Correct dimensions (768 for nomic-embed-text)")
    print("\n✅ Ready to submit PR!")


if __name__ == "__main__":
    asyncio.run(test_ollama_embedding())
```
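For context on what the script exercises: Ollama's `/api/embeddings` endpoint returns a single vector under the singular `"embedding"` key, while the newer `/api/embed` endpoint returns a list of vectors under the plural `"embeddings"` key. A minimal sketch of key handling along those lines follows; it is an illustration only, not the actual `OllamaEmbeddingEngine` code changed in this PR, and the helper name is hypothetical:

```python
from typing import Any, Dict, List


def extract_embedding(response: Dict[str, Any]) -> List[float]:
    """Return one embedding vector from an Ollama response, whichever key it uses."""
    if "embedding" in response:
        # /api/embeddings style: {"embedding": [0.1, 0.2, ...]}
        return response["embedding"]
    if "embeddings" in response:
        # /api/embed style: {"embeddings": [[0.1, 0.2, ...], ...]}
        return response["embeddings"][0]
    raise KeyError(f"No 'embedding' or 'embeddings' key found in response: {list(response)}")
```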
Comment on lines 12 to 96
Contributor

🛠️ Refactor suggestion | 🟠 Major

**Convert integration test to proper unit test with mocking.**

This script requires a live Ollama server, making it unsuitable for automated testing in CI/CD pipelines. Consider converting it to a proper unit test using mocking.

Create a proper unit test that mocks the HTTP response:

```python
import unittest
from unittest.mock import AsyncMock, patch, MagicMock

from cognee.infrastructure.databases.vector.embeddings.OllamaEmbeddingEngine import OllamaEmbeddingEngine


class TestOllamaEmbeddingEngine(unittest.TestCase):
    @patch('aiohttp.ClientSession')
    async def test_embedding_singular_key(self, mock_session):
        """Test that engine handles 'embedding' (singular) response key."""
        # Mock the API response with singular 'embedding' key
        mock_response = AsyncMock()
        mock_response.json = AsyncMock(return_value={"embedding": [0.1] * 768})
        mock_response.__aenter__ = AsyncMock(return_value=mock_response)
        mock_response.__aexit__ = AsyncMock()
        mock_session.return_value.__aenter__.return_value.post.return_value = mock_response

        engine = OllamaEmbeddingEngine(
            model="nomic-embed-text",
            dimensions=768,
            endpoint="http://localhost:11434/api/embeddings",
            huggingface_tokenizer="bert-base-uncased",
        )

        embeddings = await engine.embed_text(["test text"])
        self.assertEqual(len(embeddings), 1)
        self.assertEqual(len(embeddings[0]), 768)

    @patch('aiohttp.ClientSession')
    async def test_embedding_plural_key(self, mock_session):
        """Test backward compatibility with 'embeddings' (plural) response key."""
        # Mock the API response with plural 'embeddings' key
        mock_response = AsyncMock()
        mock_response.json = AsyncMock(return_value={"embeddings": [[0.1] * 768]})
        mock_response.__aenter__ = AsyncMock(return_value=mock_response)
        mock_response.__aexit__ = AsyncMock()
        mock_session.return_value.__aenter__.return_value.post.return_value = mock_response

        engine = OllamaEmbeddingEngine(
            model="nomic-embed-text",
            dimensions=768,
            endpoint="http://localhost:11434/api/embeddings",
            huggingface_tokenizer="bert-base-uncased",
        )

        embeddings = await engine.embed_text(["test text"])
        self.assertEqual(len(embeddings), 1)
        self.assertEqual(len(embeddings[0]), 768)

    @patch('aiohttp.ClientSession')
    async def test_missing_embedding_keys_raises_error(self, mock_session):
        """Test that missing keys raise descriptive KeyError."""
        mock_response = AsyncMock()
        mock_response.json = AsyncMock(return_value={"error": "Invalid model"})
        mock_response.__aenter__ = AsyncMock(return_value=mock_response)
        mock_response.__aexit__ = AsyncMock()
        mock_session.return_value.__aenter__.return_value.post.return_value = mock_response

        engine = OllamaEmbeddingEngine(
            model="nomic-embed-text",
            dimensions=768,
            endpoint="http://localhost:11434/api/embeddings",
            huggingface_tokenizer="bert-base-uncased",
        )

        with self.assertRaises(KeyError) as context:
            await engine.embed_text(["test text"])
        self.assertIn("No 'embedding' or 'embeddings' key found", str(context.exception))
```

This approach runs without a live Ollama server, covers both response-key formats plus the error path, and can run in CI/CD pipelines.

🧰 Tools
🪛 Pylint (4.0.3)
[refactor] 13-13: Too many statements (53/50) (R0915)
🪛 Ruff (0.14.5)
24-24: f-string without any placeholders. Remove extraneous `f` prefix (F541)
27-27: f-string without any placeholders. Remove extraneous `f` prefix (F541)
39-39: Do not catch blind exception (BLE001)
49-49: f-string without any placeholders. Remove extraneous `f` prefix (F541)
55-55: f-string without any placeholders. Remove extraneous `f` prefix (F541)
62-62: Do not catch blind exception (BLE001)
76-76: f-string without any placeholders. Remove extraneous `f` prefix (F541)
81-81: Do not catch blind exception (BLE001)
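One caveat with the suggestion above: plain `unittest.TestCase` does not await `async def` test methods, so as written those coroutines would never execute. Deriving from the standard-library `unittest.IsolatedAsyncioTestCase` (Python 3.8+) fixes that. A tiny self-contained sketch of the pattern, independent of the cognee codebase:

```python
import unittest
from unittest.mock import AsyncMock


class AsyncTestExample(unittest.IsolatedAsyncioTestCase):
    """Demonstrates that async test methods are actually awaited by this runner."""

    async def test_async_mock_is_awaited(self):
        # Awaiting an AsyncMock call yields the configured return value.
        fake_embed = AsyncMock(return_value=[[0.1] * 768])
        embeddings = await fake_embed(["test text"])
        self.assertEqual(len(embeddings), 1)
        self.assertEqual(len(embeddings[0]), 768)


if __name__ == "__main__":
    unittest.main()  # IsolatedAsyncioTestCase drives each coroutine on an event loop
```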
Contributor

**Critical: Hardcoded private IP address will break CI/CD and other environments.**

The endpoint `"http://10.0.10.9:11434/api/embeddings"` is a private network address specific to your local setup. This test will fail in CI/CD pipelines and in any environment without access to that private network.

Make the endpoint configurable via an environment variable. Add this import at the top:

```diff
 import asyncio
 import sys
+import os

 from cognee.infrastructure.databases.vector.embeddings.OllamaEmbeddingEngine import (
```

This allows customization via environment variables while defaulting to localhost.
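The endpoint override this comment describes might look roughly like the following; the reviewer's actual diff is not shown above, and the environment-variable names here are assumptions:

```python
import os

# Hypothetical variable names; defaults fall back to a local Ollama instance.
ollama_endpoint = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434/api/embeddings")
ollama_model = os.environ.get("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text")
```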