Merged

31 commits
8925442
Add Support ChromaDB
dm1tryG Mar 8, 2025
0973e42
Update lock file deps
dm1tryG Mar 9, 2025
3021f1a
Merge branch 'dev' into feature/support-chromadb
hajdul88 Mar 11, 2025
1feee13
fix: fixes ruff format
hajdul88 Mar 11, 2025
6a2ef72
feat: adds chroma_db_test to github actions
hajdul88 Mar 11, 2025
9483b6e
deletes volume from yml
hajdul88 Mar 11, 2025
6879b7a
updates health check
hajdul88 Mar 11, 2025
26d5d7d
healthcheck update
hajdul88 Mar 11, 2025
04e4860
deletes volume once again
hajdul88 Mar 11, 2025
21ba967
updates chroma test
hajdul88 Mar 11, 2025
c6532d0
updates image and adds volume
hajdul88 Mar 11, 2025
6df6cc8
deletes healthcheck
hajdul88 Mar 11, 2025
4bdee26
changes port
hajdul88 Mar 11, 2025
570204d
removes hardcoded config from test
hajdul88 Mar 11, 2025
bb469aa
adds hardcoded values for local test
hajdul88 Mar 11, 2025
3dfc41a
Merge branch 'dev' into support_chromadb_tests_fix
Vasilije1990 Mar 11, 2025
b56332b
Merge remote-tracking branch 'origin/dev' into support_chromadb_tests…
hajdul88 Mar 12, 2025
2970b23
updates poetry lock file
hajdul88 Mar 12, 2025
b28624c
Fix original_key splits
dm1tryG Mar 12, 2025
a299bd6
give back the distance from all the elements
dm1tryG Mar 12, 2025
645be10
uncomment chroma profile
dm1tryG Mar 12, 2025
eae2f76
rm print
dm1tryG Mar 12, 2025
dc73728
change batch_search with normalize_distances
dm1tryG Mar 12, 2025
8b874ad
update docker-compose.yml
dm1tryG Mar 12, 2025
3a682fe
rm test_local_file_deletion
dm1tryG Mar 12, 2025
53428c6
chromadb as extra in pyproject.toml
dm1tryG Mar 12, 2025
18d818f
rename get_table_names
dm1tryG Mar 12, 2025
3362e6d
Merge branch 'feature/support-chromadb' into support_chromadb_tests_fix
hajdul88 Mar 13, 2025
9954a09
chore: updates naming based on comments
hajdul88 Mar 13, 2025
e60c398
chore: updates naming based on comments
hajdul88 Mar 13, 2025
990a04c
Merge branch 'dev' into feature/support-chromadb
hajdul88 Mar 13, 2025
2 changes: 1 addition & 1 deletion .env.template
@@ -30,7 +30,7 @@ GRAPH_DATABASE_URL=
 GRAPH_DATABASE_USERNAME=
 GRAPH_DATABASE_PASSWORD=
 
-# "qdrant", "pgvector", "weaviate", "milvus" or "lancedb"
+# "qdrant", "pgvector", "weaviate", "milvus", "lancedb" or "chromadb"
 VECTOR_DB_PROVIDER="lancedb"
 # Not needed if using "lancedb" or "pgvector"
 VECTOR_DB_URL=
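For reference, a ChromaDB setup under this template could look like the lines below. All values are placeholders, and the VECTOR_DB_KEY variable name is an assumption since the template is truncated here:

# Illustrative ChromaDB settings; the URL is passed to the Chroma client as
# its host, and the key is used as the client auth token.
VECTOR_DB_PROVIDER="chromadb"
VECTOR_DB_URL="localhost"
VECTOR_DB_KEY="my-chroma-token"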
3 changes: 3 additions & 0 deletions .gitignore
@@ -186,3 +186,6 @@ node_modules/

 # Evals
 SWE-bench_testsample/
+
+# ChromaDB Data
+.chromadb_data/
2 changes: 1 addition & 1 deletion cognee/api/v1/settings/routers/get_settings_router.py
@@ -27,7 +27,7 @@ class LLMConfigInputDTO(InDTO):


 class VectorDBConfigInputDTO(InDTO):
-    provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"], Literal["pgvector"]]
+    provider: Union[Literal["lancedb"], Literal["chromadb"], Literal["qdrant"], Literal["weaviate"], Literal["pgvector"]]
     url: str
     api_key: str

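For illustration, constructing this DTO with the new provider value might look like the following; field names come from the class above, values are placeholders:

from cognee.api.v1.settings.routers.get_settings_router import VectorDBConfigInputDTO

# Illustrative only: "chromadb" is now an accepted provider literal.
config = VectorDBConfigInputDTO(
    provider="chromadb",
    url="localhost",           # Chroma server host
    api_key="my-chroma-token"  # client auth token
)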
349 changes: 349 additions & 0 deletions cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py
@@ -0,0 +1,349 @@
import logging
from typing import Dict, List, Optional, Any
import os
import json
from uuid import UUID

from chromadb import AsyncHttpClient, Settings

from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult

from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..vector_db_interface import VectorDBInterface
from ..utils import normalize_distances

logger = logging.getLogger("ChromaDBAdapter")


class IndexSchema(DataPoint):
    text: str

    metadata: dict = {"index_fields": ["text"]}

    def model_dump(self):
        data = super().model_dump()
        return process_data_for_chroma(data)


def process_data_for_chroma(data):
    """Convert complex data types to a format suitable for ChromaDB storage."""
    processed_data = {}
    for key, value in data.items():
        if isinstance(value, UUID):
            processed_data[key] = str(value)
        elif isinstance(value, dict):
            # Store dictionaries as JSON strings under a key with a special suffix
            processed_data[f"{key}__dict"] = json.dumps(value)
        elif isinstance(value, list):
            # Store lists as JSON strings under a key with a special suffix
            processed_data[f"{key}__list"] = json.dumps(value)
        elif isinstance(value, (str, int, float, bool)) or value is None:
            processed_data[key] = value
        else:
            processed_data[key] = str(value)
    return processed_data


def restore_data_from_chroma(data):
    """Restore original data structure from ChromaDB storage format."""
    restored_data = {}
    dict_keys = []
    list_keys = []

    # First, identify all special keys
    for key in data.keys():
        if key.endswith("__dict"):
            dict_keys.append(key)
        elif key.endswith("__list"):
            list_keys.append(key)
        else:
            restored_data[key] = data[key]

    # Process dictionary fields
    for key in dict_keys:
        original_key = key[:-6]  # Remove '__dict' suffix
        try:
            restored_data[original_key] = json.loads(data[key])
        except Exception as e:
            logger.debug(f"Error restoring dictionary from JSON: {e}")
            restored_data[key] = data[key]

    # Process list fields
    for key in list_keys:
        original_key = key[:-6]  # Remove '__list' suffix
        try:
            restored_data[original_key] = json.loads(data[key])
        except Exception as e:
            logger.debug(f"Error restoring list from JSON: {e}")
            restored_data[key] = data[key]

    return restored_data


class ChromaDBAdapter(VectorDBInterface):
    name = "ChromaDB"
    url: str
    api_key: str
    connection: AsyncHttpClient = None

    def __init__(self, url: Optional[str], api_key: Optional[str], embedding_engine: EmbeddingEngine):
        self.embedding_engine = embedding_engine
        self.url = url
        self.api_key = api_key

    async def get_connection(self) -> AsyncHttpClient:
        if self.connection is None:
            settings = Settings(
                chroma_client_auth_provider="token",
                chroma_client_auth_credentials=self.api_key
            )
            self.connection = await AsyncHttpClient(host=self.url, settings=settings)

        return self.connection

    async def embed_data(self, data: list[str]) -> list[list[float]]:
        return await self.embedding_engine.embed_text(data)

    async def has_collection(self, collection_name: str) -> bool:
        client = await self.get_connection()
        collections = await client.list_collections()
        # In ChromaDB v0.6.0, list_collections returns collection names directly
        return collection_name in collections

    async def create_collection(self, collection_name: str, payload_schema=None):
        client = await self.get_connection()

        if not await self.has_collection(collection_name):
            await client.create_collection(
                name=collection_name,
                metadata={"hnsw:space": "cosine"}
            )

    async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
        client = await self.get_connection()

        if not await self.has_collection(collection_name):
            await self.create_collection(collection_name)

        collection = await client.get_collection(collection_name)

        texts = [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
        embeddings = await self.embed_data(texts)
        ids = [str(data_point.id) for data_point in data_points]

        metadatas = []
        for data_point in data_points:
            metadata = data_point.model_dump()
            metadatas.append(process_data_for_chroma(metadata))

        await collection.upsert(
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=texts
        )

    async def create_vector_index(self, index_name: str, index_property_name: str):
        """Create a vector index as a ChromaDB collection."""
        await self.create_collection(f"{index_name}_{index_property_name}")

    async def index_data_points(
        self, index_name: str, index_property_name: str, data_points: list[DataPoint]
    ):
        """Index data points using the specified index property."""
        await self.create_data_points(
            f"{index_name}_{index_property_name}",
            [
                IndexSchema(
                    id=data_point.id,
                    text=getattr(data_point, data_point.metadata["index_fields"][0]),
                )
                for data_point in data_points
            ],
        )

    async def retrieve(self, collection_name: str, data_point_ids: list[str]):
        """Retrieve data points by their IDs from a collection."""
        client = await self.get_connection()
        collection = await client.get_collection(collection_name)
        results = await collection.get(ids=data_point_ids, include=["metadatas"])

        return [
            ScoredResult(
                id=parse_id(id),
                payload=restore_data_from_chroma(metadata),
                score=0,
            )
            for id, metadata in zip(results["ids"], results["metadatas"])
        ]

    async def get_distance_from_collection_elements(
        self, collection_name: str, query_text: str = None, query_vector: List[float] = None
    ):
        """Calculate distance between query and all elements in a collection."""
        if query_text is None and query_vector is None:
            raise InvalidValueError(message="One of query_text or query_vector must be provided!")

        if query_text and not query_vector:
            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

        client = await self.get_connection()
        try:
            collection = await client.get_collection(collection_name)

            collection_count = await collection.count()

            results = await collection.query(
                query_embeddings=[query_vector],
                include=["metadatas", "distances"],
                n_results=collection_count
            )

            result_values = []
            for id, metadata, distance in zip(
                results["ids"][0], results["metadatas"][0], results["distances"][0]
            ):
                result_values.append({
                    "id": parse_id(id),
                    "payload": restore_data_from_chroma(metadata),
                    "_distance": distance
                })

            normalized_values = normalize_distances(result_values)

            scored_results = []
            for i, result in enumerate(result_values):
                scored_results.append(
                    ScoredResult(
                        id=result["id"],
                        payload=result["payload"],
                        score=normalized_values[i],
                    )
                )

            return scored_results
        except Exception as e:
            logger.error(f"Error in get_distance_from_collection_elements: {str(e)}")
            return []

    async def search(
        self,
        collection_name: str,
        query_text: str = None,
        query_vector: List[float] = None,
        limit: int = 5,
        with_vector: bool = False,
        normalized: bool = True,
    ):
        """Search for similar items in a collection using text or vector query."""
        if query_text is None and query_vector is None:
            raise InvalidValueError(message="One of query_text or query_vector must be provided!")

        if query_text and not query_vector:
            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

        try:
            client = await self.get_connection()
            collection = await client.get_collection(collection_name)

            results = await collection.query(
                query_embeddings=[query_vector],
                include=["metadatas", "distances", "embeddings"] if with_vector else ["metadatas", "distances"],
                n_results=limit
            )

            vector_list = []
            for i, (id, metadata, distance) in enumerate(zip(
                results["ids"][0], results["metadatas"][0], results["distances"][0]
            )):
                item = {
                    "id": parse_id(id),
                    "payload": restore_data_from_chroma(metadata),
                    "_distance": distance
                }

                if with_vector and "embeddings" in results:
                    item["vector"] = results["embeddings"][0][i]

                vector_list.append(item)

            # Normalize distances into scores
            normalized_values = normalize_distances(vector_list)
            for i in range(len(normalized_values)):
                vector_list[i]["score"] = normalized_values[i]

            # Create and return ScoredResult objects
            return [
                ScoredResult(
                    id=row["id"],
                    payload=row["payload"],
                    score=row["score"],
                    vector=row.get("vector") if with_vector else None
                )
                for row in vector_list
            ]
        except Exception as e:
            logger.error(f"Error in search: {str(e)}")
            return []

    async def batch_search(
        self, collection_name: str, query_texts: List[str], limit: int = 5, with_vectors: bool = False
    ):
        """Perform multiple searches in a single request for efficiency."""
        query_vectors = await self.embed_data(query_texts)

        client = await self.get_connection()
        collection = await client.get_collection(collection_name)

        results = await collection.query(
            query_embeddings=query_vectors,
            include=["metadatas", "distances", "embeddings"] if with_vectors else ["metadatas", "distances"],
            n_results=limit
        )

        all_results = []
        for i in range(len(query_texts)):
            query_results = []

            for j, (id, metadata, distance) in enumerate(zip(
                results["ids"][i], results["metadatas"][i], results["distances"][i]
            )):
                # Cosine distance lies in [0, 2]; map it to a similarity in [0, 1]
                similarity = 1.0 - min(distance, 2.0) / 2.0

                result = ScoredResult(
                    id=parse_id(id),
                    # Rebuild suffixed JSON fields, matching search() and retrieve()
                    payload=restore_data_from_chroma(metadata),
                    score=similarity,
                )

                if with_vectors and "embeddings" in results:
                    result.vector = results["embeddings"][i][j]

                query_results.append(result)

            all_results.append(query_results)

        return all_results

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
        """Remove data points from a collection by their IDs."""
        client = await self.get_connection()
        collection = await client.get_collection(collection_name)
        await collection.delete(ids=data_point_ids)
        return True

    async def prune(self):
        """Delete all collections in the ChromaDB database."""
        client = await self.get_connection()
        collections = await client.list_collections()
        for collection_name in collections:
            await client.delete_collection(collection_name)
        return True

    async def get_table_names(self):
        """Get a list of all collection names in the database."""
        client = await self.get_connection()
        return await client.list_collections()
Empty file.
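The process/restore helpers in ChromaDBAdapter.py exist because Chroma metadata values must be flat scalars. A minimal round-trip sketch, assuming cognee and the chromadb client are installed so the module imports cleanly:

from uuid import uuid4

from cognee.infrastructure.databases.vector.chromadb.ChromaDBAdapter import (
    process_data_for_chroma,
    restore_data_from_chroma,
)

record = {
    "id": uuid4(),                    # UUID -> stored as a string
    "text": "hello world",            # scalars pass through unchanged
    "tags": ["alpha", "beta"],        # list -> JSON string under "tags__list"
    "meta": {"source": "unit-test"},  # dict -> JSON string under "meta__dict"
}

flat = process_data_for_chroma(record)
assert flat["tags__list"] == '["alpha", "beta"]'
assert flat["meta__dict"] == '{"source": "unit-test"}'

restored = restore_data_from_chroma(flat)
assert restored["tags"] == ["alpha", "beta"]
assert restored["meta"] == {"source": "unit-test"}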
16 changes: 16 additions & 0 deletions cognee/infrastructure/databases/vector/create_vector_engine.py
@@ -81,6 +81,22 @@ def create_vector_engine(
             database_port=vector_db_port,
             embedding_engine=embedding_engine,
         )
+
+    elif vector_db_provider == "chromadb":
+        try:
+            import chromadb
+        except ImportError:
+            raise ImportError(
+                "ChromaDB is not installed. Please install it with 'pip install chromadb'"
+            )
+
+        from .chromadb.ChromaDBAdapter import ChromaDBAdapter
+
+        return ChromaDBAdapter(
+            url=vector_db_url,
+            api_key=vector_db_key,
+            embedding_engine=embedding_engine,
+        )
 
     else:
         from .lancedb.LanceDBAdapter import LanceDBAdapter
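For context, application code would reach this new branch roughly as sketched below. The factory's full parameter list is only partially visible in this hunk, so the keyword names are assumptions, and the embedding engine construction is omitted:

from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine

embedding_engine = ...  # any EmbeddingEngine implementation; construction omitted

vector_engine = create_vector_engine(
    vector_db_provider="chromadb",
    vector_db_url="localhost",        # passed to AsyncHttpClient(host=...)
    vector_db_key="my-chroma-token",  # becomes chroma_client_auth_credentials
    embedding_engine=embedding_engine,
)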