feat: Add Support ChromaDB #622
Merged: borisarzentar merged 31 commits into topoteretes:dev from dm1tryG:feature/support-chromadb (Mar 13, 2025).

Commits (31):
8925442  Add Support ChromaDB (dm1tryG)
0973e42  Update lock file deps (dm1tryG)
3021f1a  Merge branch 'dev' into feature/support-chromadb (hajdul88)
1feee13  fix: fixes ruff format (hajdul88)
6a2ef72  feat: adds chroma_db_test to github actions (hajdul88)
9483b6e  deletes volume from yml (hajdul88)
6879b7a  updates health check (hajdul88)
26d5d7d  healthcheck update (hajdul88)
04e4860  deletes volume once again (hajdul88)
21ba967  updates chroma test (hajdul88)
c6532d0  updates image and adds volume (hajdul88)
6df6cc8  deletes healthcheck (hajdul88)
4bdee26  changes port (hajdul88)
570204d  removes hardcoded config from test (hajdul88)
bb469aa  adds hardcoded values for local test (hajdul88)
3dfc41a  Merge branch 'dev' into support_chromadb_tests_fix (Vasilije1990)
b56332b  Merge remote-tracking branch 'origin/dev' into support_chromadb_tests… (hajdul88)
2970b23  updates poetry lock file (hajdul88)
b28624c  Fix original_key splits (dm1tryG)
a299bd6  give back the distance from all the elements (dm1tryG)
645be10  uncomment chroma profile (dm1tryG)
eae2f76  rm print (dm1tryG)
dc73728  change batch_search with normalize_distances (dm1tryG)
8b874ad  update docker-compose.yml (dm1tryG)
3a682fe  rm test_local_file_deletion (dm1tryG)
53428c6  chromadb as extra in pyproject.toml (dm1tryG)
18d818f  rename get_table_names (dm1tryG)
3362e6d  Merge branch 'feature/support-chromadb' into support_chromadb_tests_fix (hajdul88)
9954a09  chore: updates naming based on comments (hajdul88)
e60c398  chore: updates naming based on comments (hajdul88)
990a04c  Merge branch 'dev' into feature/support-chromadb (hajdul88)
```diff
@@ -186,3 +186,6 @@ node_modules/
 
 # Evals
 SWE-bench_testsample/
+
+# ChromaDB Data
+.chromadb_data/
```
cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py (new file, 349 additions, 0 deletions)

```python
import logging
from typing import Dict, List, Optional, Any
import os
import json
from uuid import UUID

from chromadb import AsyncHttpClient, Settings

from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult

from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..vector_db_interface import VectorDBInterface
from ..utils import normalize_distances

logger = logging.getLogger("ChromaDBAdapter")


class IndexSchema(DataPoint):
    text: str

    metadata: dict = {"index_fields": ["text"]}

    def model_dump(self):
        data = super().model_dump()
        return process_data_for_chroma(data)


def process_data_for_chroma(data):
    """Convert complex data types to a format suitable for ChromaDB storage."""
    processed_data = {}
    for key, value in data.items():
        if isinstance(value, UUID):
            processed_data[key] = str(value)
        elif isinstance(value, dict):
            # Store dictionaries as JSON strings with a special key suffix
            processed_data[f"{key}__dict"] = json.dumps(value)
        elif isinstance(value, list):
            # Store lists as JSON strings with a special key suffix
            processed_data[f"{key}__list"] = json.dumps(value)
        elif isinstance(value, (str, int, float, bool)) or value is None:
            processed_data[key] = value
        else:
            processed_data[key] = str(value)
    return processed_data


def restore_data_from_chroma(data):
    """Restore original data structure from ChromaDB storage format."""
    restored_data = {}
    dict_keys = []
    list_keys = []

    # First, identify all special keys
    for key in data.keys():
        if key.endswith("__dict"):
            dict_keys.append(key)
        elif key.endswith("__list"):
            list_keys.append(key)
        else:
            restored_data[key] = data[key]

    # Process dictionary fields
    for key in dict_keys:
        original_key = key[:-6]  # Remove '__dict' suffix
        try:
            restored_data[original_key] = json.loads(data[key])
        except Exception as e:
            logger.debug(f"Error restoring dictionary from JSON: {e}")
            restored_data[key] = data[key]

    # Process list fields
    for key in list_keys:
        original_key = key[:-6]  # Remove '__list' suffix
        try:
            restored_data[original_key] = json.loads(data[key])
        except Exception as e:
            logger.debug(f"Error restoring list from JSON: {e}")
            restored_data[key] = data[key]

    return restored_data
```
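For reference, a small round-trip sketch of what these two helpers do. The input dict below is made up for illustration; only the helper functions above are from the PR:

```python
# Illustrative round trip through the Chroma-safe metadata format.
original = {
    "id": UUID("12345678-1234-5678-1234-567812345678"),
    "tags": ["vector", "db"],        # becomes "tags__list": '["vector", "db"]'
    "config": {"space": "cosine"},   # becomes "config__dict": '{"space": "cosine"}'
    "name": "example",               # scalars pass through unchanged
}

stored = process_data_for_chroma(original)
# {'id': '12345678-...', 'tags__list': '["vector", "db"]',
#  'config__dict': '{"space": "cosine"}', 'name': 'example'}

restored = restore_data_from_chroma(stored)
# {'id': '12345678-...', 'name': 'example',
#  'tags': ['vector', 'db'], 'config': {'space': 'cosine'}}
```

Note that UUIDs are stringified one way: restore_data_from_chroma leaves them as strings, which is presumably why callers such as retrieve() re-parse IDs with parse_id.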
```python
class ChromaDBAdapter(VectorDBInterface):
    name = "ChromaDB"
    url: str
    api_key: str
    connection: AsyncHttpClient = None

    def __init__(self, url: Optional[str], api_key: Optional[str], embedding_engine: EmbeddingEngine):
        self.embedding_engine = embedding_engine
        self.url = url
        self.api_key = api_key

    async def get_connection(self) -> AsyncHttpClient:
        if self.connection is None:
            settings = Settings(
                chroma_client_auth_provider="token",
                chroma_client_auth_credentials=self.api_key
            )
            self.connection = await AsyncHttpClient(host=self.url, settings=settings)

        return self.connection

    async def embed_data(self, data: list[str]) -> list[list[float]]:
        return await self.embedding_engine.embed_text(data)

    async def has_collection(self, collection_name: str) -> bool:
        client = await self.get_connection()
        collections = await client.list_collections()
        # In ChromaDB v0.6.0, list_collections returns collection names directly
        return collection_name in collections
```
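The membership check above relies on the ChromaDB v0.6.0 behavior, where list_collections returns plain collection names. If older clients (which returned Collection objects exposing a .name attribute) ever need to be supported, a hedged compatibility sketch could normalize both shapes; this helper is an assumption, not part of the PR:

```python
# Sketch only: normalize list_collections output across ChromaDB versions.
# Before v0.6.0 the client returned Collection objects with a .name
# attribute; from v0.6.0 onward it returns the names themselves.
async def collection_names(client) -> set[str]:
    collections = await client.list_collections()
    return {
        c if isinstance(c, str) else getattr(c, "name", str(c))
        for c in collections
    }
```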
```python
    async def create_collection(self, collection_name: str, payload_schema=None):
        client = await self.get_connection()

        if not await self.has_collection(collection_name):
            await client.create_collection(
                name=collection_name,
                metadata={"hnsw:space": "cosine"}
            )

    async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
        client = await self.get_connection()

        if not await self.has_collection(collection_name):
            await self.create_collection(collection_name)

        collection = await client.get_collection(collection_name)

        texts = [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
        embeddings = await self.embed_data(texts)
        ids = [str(data_point.id) for data_point in data_points]

        metadatas = []
        for data_point in data_points:
            metadata = data_point.model_dump()
            metadatas.append(process_data_for_chroma(metadata))

        await collection.upsert(
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=texts
        )

    async def create_vector_index(self, index_name: str, index_property_name: str):
        """Create a vector index as a ChromaDB collection."""
        await self.create_collection(f"{index_name}_{index_property_name}")

    async def index_data_points(
        self, index_name: str, index_property_name: str, data_points: list[DataPoint]
    ):
        """Index data points using the specified index property."""
        await self.create_data_points(
            f"{index_name}_{index_property_name}",
            [
                IndexSchema(
                    id=data_point.id,
                    text=getattr(data_point, data_point.metadata["index_fields"][0]),
                )
                for data_point in data_points
            ],
        )

    async def retrieve(self, collection_name: str, data_point_ids: list[str]):
        """Retrieve data points by their IDs from a collection."""
        client = await self.get_connection()
        collection = await client.get_collection(collection_name)
        results = await collection.get(ids=data_point_ids, include=["metadatas"])

        return [
            ScoredResult(
                id=parse_id(id),
                payload=restore_data_from_chroma(metadata),
                score=0,
            )
            for id, metadata in zip(results["ids"], results["metadatas"])
        ]

    async def get_distance_from_collection_elements(
        self, collection_name: str, query_text: str = None, query_vector: List[float] = None
    ):
        """Calculate distance between query and all elements in a collection."""
        if query_text is None and query_vector is None:
            raise InvalidValueError(message="One of query_text or query_vector must be provided!")

        if query_text and not query_vector:
            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

        client = await self.get_connection()
        try:
            collection = await client.get_collection(collection_name)

            collection_count = await collection.count()

            results = await collection.query(
                query_embeddings=[query_vector],
                include=["metadatas", "distances"],
                n_results=collection_count
            )

            result_values = []
            for id, metadata, distance in zip(
                results["ids"][0], results["metadatas"][0], results["distances"][0]
            ):
                result_values.append({
                    "id": parse_id(id),
                    "payload": restore_data_from_chroma(metadata),
                    "_distance": distance
                })

            normalized_values = normalize_distances(result_values)

            scored_results = []
            for i, result in enumerate(result_values):
                scored_results.append(
                    ScoredResult(
                        id=result["id"],
                        payload=result["payload"],
                        score=normalized_values[i],
                    )
                )

            return scored_results
        except Exception as e:
            logger.error(f"Error in get_distance_from_collection_elements: {str(e)}")
            return []
```
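normalize_distances is imported from the package's vector utils and its body is not part of this diff. Judging only from the call sites (it takes a list of dicts carrying a "_distance" key and returns one score per item), a minimal min-max style sketch could look like the following; the real cognee helper may well differ:

```python
# Hypothetical sketch of a normalize_distances-style helper; the actual
# implementation lives in cognee's vector utils and is not shown in this PR.
def normalize_distances_sketch(result_values: list[dict]) -> list[float]:
    if not result_values:
        return []
    distances = [row["_distance"] for row in result_values]
    min_d, max_d = min(distances), max(distances)
    if max_d == min_d:
        # All elements are equidistant from the query; treat them equally.
        return [1.0 for _ in distances]
    # Map the smallest distance -> 1.0 (best match), the largest -> 0.0.
    return [(max_d - d) / (max_d - min_d) for d in distances]
```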
```python
    async def search(
        self,
        collection_name: str,
        query_text: str = None,
        query_vector: List[float] = None,
        limit: int = 5,
        with_vector: bool = False,
        normalized: bool = True,
    ):
        """Search for similar items in a collection using text or vector query."""
        if query_text is None and query_vector is None:
            raise InvalidValueError(message="One of query_text or query_vector must be provided!")

        if query_text and not query_vector:
            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

        try:
            client = await self.get_connection()
            collection = await client.get_collection(collection_name)

            results = await collection.query(
                query_embeddings=[query_vector],
                include=["metadatas", "distances", "embeddings"] if with_vector else ["metadatas", "distances"],
                n_results=limit
            )

            vector_list = []
            for i, (id, metadata, distance) in enumerate(zip(
                results["ids"][0], results["metadatas"][0], results["distances"][0]
            )):
                item = {
                    "id": parse_id(id),
                    "payload": restore_data_from_chroma(metadata),
                    "_distance": distance
                }

                if with_vector and "embeddings" in results:
                    item["vector"] = results["embeddings"][0][i]

                vector_list.append(item)

            # Normalize vector distances into scores
            normalized_values = normalize_distances(vector_list)
            for i in range(len(normalized_values)):
                vector_list[i]["score"] = normalized_values[i]

            # Create and return ScoredResult objects
            return [
                ScoredResult(
                    id=row["id"],
                    payload=row["payload"],
                    score=row["score"],
                    vector=row.get("vector") if with_vector else None
                )
                for row in vector_list
            ]
        except Exception as e:
            logger.error(f"Error in search: {str(e)}")
            return []

    async def batch_search(
        self, collection_name: str, query_texts: List[str], limit: int = 5, with_vectors: bool = False
    ):
        """Perform multiple searches in a single request for efficiency."""
        query_vectors = await self.embed_data(query_texts)

        client = await self.get_connection()
        collection = await client.get_collection(collection_name)

        results = await collection.query(
            query_embeddings=query_vectors,
            include=["metadatas", "distances", "embeddings"] if with_vectors else ["metadatas", "distances"],
            n_results=limit
        )

        all_results = []
        for i in range(len(query_texts)):
            query_results = []

            for j, (id, metadata, distance) in enumerate(zip(
                results["ids"][i], results["metadatas"][i], results["distances"][i]
            )):
                similarity = 1.0 - min(distance, 2.0) / 2.0

                result = ScoredResult(
                    id=parse_id(id),
                    payload=metadata,
                    score=similarity,
                )

                if with_vectors and "embeddings" in results:
                    result.vector = results["embeddings"][i][j]

                query_results.append(result)

            all_results.append(query_results)

        return all_results
```
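batch_search converts distances to similarities inline rather than calling normalize_distances. Since collections are created with hnsw:space set to cosine, distances fall in [0, 2], and the formula maps them onto [0, 1]; a quick worked check:

```python
# Cosine distance d lies in [0, 2]; 1 - min(d, 2) / 2 maps it onto [0, 1].
for d in (0.0, 0.5, 1.0, 2.0):
    print(d, 1.0 - min(d, 2.0) / 2.0)
# 0.0 -> 1.0   (same direction)
# 0.5 -> 0.75
# 1.0 -> 0.5   (orthogonal)
# 2.0 -> 0.0   (opposite direction)
```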
```python
    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
        """Remove data points from a collection by their IDs."""
        client = await self.get_connection()
        collection = await client.get_collection(collection_name)
        await collection.delete(ids=data_point_ids)
        return True

    async def prune(self):
        """Delete all collections in the ChromaDB database."""
        client = await self.get_connection()
        collections = await client.list_collections()
        for collection_name in collections:
            await client.delete_collection(collection_name)
        return True

    async def get_table_names(self):
        """Get a list of all collection names in the database."""
        client = await self.get_connection()
        return await client.list_collections()
```
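A minimal end-to-end usage sketch of the adapter, assuming a Chroma server on localhost:8000 with token auth and a stub embedding engine. The stub class, URL, token, and sample texts are placeholders for illustration, not part of the PR:

```python
import asyncio
from uuid import uuid4

# Placeholder engine for illustration only; a real deployment would use one
# of the EmbeddingEngine implementations shipped with cognee.
class FakeEmbeddingEngine:
    async def embed_text(self, texts):
        return [[0.0, 0.1, 0.2] for _ in texts]  # fixed-size dummy vectors

async def main():
    adapter = ChromaDBAdapter(
        url="http://localhost:8000",  # assumed local Chroma server
        api_key="test-token",         # matches the token auth settings above
        embedding_engine=FakeEmbeddingEngine(),
    )

    await adapter.create_collection("documents")
    await adapter.create_data_points(
        "documents",
        [IndexSchema(id=uuid4(), text="ChromaDB support for cognee")],
    )

    results = await adapter.search("documents", query_text="vector database", limit=3)
    for result in results:
        print(result.id, result.score)

asyncio.run(main())
```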