94 changes: 25 additions & 69 deletions cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py
@@ -6,8 +6,9 @@
from cognee.exceptions import InvalidValueError
from cognee.shared.logging_utils import get_logger
from cognee.modules.storage.utils import get_own_properties
- from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.engine import DataPoint
+ from cognee.infrastructure.engine.utils import parse_id
+ from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult

from ..embeddings.EmbeddingEngine import EmbeddingEngine
@@ -108,9 +109,7 @@ async def embed_data(self, data: list[str]) -> list[list[float]]:
return await self.embedding_engine.embed_text(data)

async def has_collection(self, collection_name: str) -> bool:
- client = await self.get_connection()
- collections = await client.list_collections()
- # In ChromaDB v0.6.0, list_collections returns collection names directly
+ collections = await self.get_collection_names()
return collection_name in collections

async def create_collection(self, collection_name: str, payload_schema=None):
@@ -119,13 +118,17 @@ async def create_collection(self, collection_name: str, payload_schema=None):
if not await self.has_collection(collection_name):
await client.create_collection(name=collection_name, metadata={"hnsw:space": "cosine"})

- async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
+ async def get_collection(self, collection_name: str) -> AsyncHttpClient:
+     if not await self.has_collection(collection_name):
+         raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")

client = await self.get_connection()
+     return await client.get_collection(collection_name)

- if not await self.has_collection(collection_name):
-     await self.create_collection(collection_name)
+ async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
+     await self.create_collection(collection_name)

- collection = await client.get_collection(collection_name)
+ collection = await self.get_collection(collection_name)

texts = [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
embeddings = await self.embed_data(texts)
@@ -161,8 +164,7 @@ async def index_data_points(

async def retrieve(self, collection_name: str, data_point_ids: list[str]):
"""Retrieve data points by their IDs from a collection."""
- client = await self.get_connection()
- collection = await client.get_collection(collection_name)
+ collection = await self.get_collection(collection_name)
results = await collection.get(ids=data_point_ids, include=["metadatas"])

return [
@@ -174,62 +176,12 @@ async def retrieve(self, collection_name: str, data_point_ids: list[str]):
for id, metadata in zip(results["ids"], results["metadatas"])
]

- async def get_distance_from_collection_elements(
-     self, collection_name: str, query_text: str = None, query_vector: List[float] = None
- ):
-     """Calculate distance between query and all elements in a collection."""
-     if query_text is None and query_vector is None:
-         raise InvalidValueError(message="One of query_text or query_vector must be provided!")
-
-     if query_text and not query_vector:
-         query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
-
-     client = await self.get_connection()
-     try:
-         collection = await client.get_collection(collection_name)
-
-         collection_count = await collection.count()
-
-         results = await collection.query(
-             query_embeddings=[query_vector],
-             include=["metadatas", "distances"],
-             n_results=collection_count,
-         )
-
-         result_values = []
-         for i, (id, metadata, distance) in enumerate(
-             zip(results["ids"][0], results["metadatas"][0], results["distances"][0])
-         ):
-             result_values.append(
-                 {
-                     "id": parse_id(id),
-                     "payload": restore_data_from_chroma(metadata),
-                     "_distance": distance,
-                 }
-             )
-
-         normalized_values = normalize_distances(result_values)
-
-         scored_results = []
-         for i, result in enumerate(result_values):
-             scored_results.append(
-                 ScoredResult(
-                     id=result["id"],
-                     payload=result["payload"],
-                     score=normalized_values[i],
-                 )
-             )
-
-         return scored_results
-     except Exception:
-         return []
-
async def search(
self,
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
- limit: int = 5,
+ limit: int = 15,
with_vector: bool = False,
normalized: bool = True,
):
@@ -241,8 +193,10 @@ async def search(
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

try:
- client = await self.get_connection()
- collection = await client.get_collection(collection_name)
+ collection = await self.get_collection(collection_name)

+ if limit == 0:
+     limit = await collection.count()

results = await collection.query(
query_embeddings=[query_vector],
@@ -296,8 +250,7 @@ async def batch_search(
"""Perform multiple searches in a single request for efficiency."""
query_vectors = await self.embed_data(query_texts)

- client = await self.get_connection()
- collection = await client.get_collection(collection_name)
+ collection = await self.get_collection(collection_name)

results = await collection.query(
query_embeddings=query_vectors,
@@ -346,20 +299,23 @@ async def batch_search(

async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
"""Remove data points from a collection by their IDs."""
- client = await self.get_connection()
- collection = await client.get_collection(collection_name)
+ collection = await self.get_collection(collection_name)
await collection.delete(ids=data_point_ids)
return True

async def prune(self):
"""Delete all collections in the ChromaDB database."""
client = await self.get_connection()
- collections = await client.list_collections()
+ collections = await self.get_collection_names()
for collection_name in collections:
await client.delete_collection(collection_name)
return True

async def get_collection_names(self):
"""Get a list of all collection names in the database."""
client = await self.get_connection()
- return await client.list_collections()
+ collections = await client.list_collections()
+ return [
+     collection.name if hasattr(collection, "name") else collection["name"]
+     for collection in collections
+ ]
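
Taken together, the ChromaDB changes funnel every collection access through the new get_collection helper, so a missing collection now surfaces as the shared CollectionNotFoundError instead of a provider-specific exception or a silently empty result. A minimal sketch of what this means at a call site — the adapter instance, helper name, and error-handling policy are illustrative, not part of this diff:

from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError

async def safe_retrieve(adapter, collection_name: str, ids: list[str]):
    # Hypothetical caller: with the refactor above, a missing collection
    # raises a typed CollectionNotFoundError from get_collection, so the
    # caller decides what "missing" means instead of parsing
    # ChromaDB-specific exceptions.
    try:
        return await adapter.retrieve(collection_name, ids)
    except CollectionNotFoundError:
        return []  # treat a missing collection as "no results"
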
64 changes: 14 additions & 50 deletions cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
@@ -1,6 +1,5 @@
import asyncio
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints

import lancedb
from lancedb.pydantic import LanceModel, Vector
from pydantic import BaseModel
@@ -76,9 +75,14 @@ class LanceDataPoint(LanceModel):
exist_ok=True,
)

- async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
+ async def get_collection(self, collection_name: str):
+     if not await self.has_collection(collection_name):
+         raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")

connection = await self.get_connection()
+     return await connection.open_table(collection_name)

+ async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
payload_schema = type(data_points[0])

if not await self.has_collection(collection_name):
@@ -87,7 +91,7 @@ async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
payload_schema,
)

- collection = await connection.open_table(collection_name)
+ collection = await self.get_collection(collection_name)

data_vectors = await self.embed_data(
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
@@ -125,8 +129,7 @@ def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint:
)

async def retrieve(self, collection_name: str, data_point_ids: list[str]):
- connection = await self.get_connection()
- collection = await connection.open_table(collection_name)
+ collection = await self.get_collection(collection_name)

if len(data_point_ids) == 1:
results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
@@ -142,48 +145,12 @@ async def retrieve(self, collection_name: str, data_point_ids: list[str]):
for result in results.to_dict("index").values()
]

- async def get_distance_from_collection_elements(
-     self, collection_name: str, query_text: str = None, query_vector: List[float] = None
- ):
-     if query_text is None and query_vector is None:
-         raise InvalidValueError(message="One of query_text or query_vector must be provided!")
-
-     if query_text and not query_vector:
-         query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
-
-     connection = await self.get_connection()
-
-     try:
-         collection = await connection.open_table(collection_name)
-
-         collection_size = await collection.count_rows()
-
-         results = (
-             await collection.vector_search(query_vector).limit(collection_size).to_pandas()
-         )
-
-         result_values = list(results.to_dict("index").values())
-
-         normalized_values = normalize_distances(result_values)
-
-         return [
-             ScoredResult(
-                 id=parse_id(result["id"]),
-                 payload=result["payload"],
-                 score=normalized_values[value_index],
-             )
-             for value_index, result in enumerate(result_values)
-         ]
-     except ValueError:
-         # Ignore if collection doesn't exist
-         return []
-
async def search(
self,
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
- limit: int = 5,
+ limit: int = 15,
with_vector: bool = False,
normalized: bool = True,
):
@@ -193,12 +160,10 @@ async def search(
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

- connection = await self.get_connection()
+ collection = await self.get_collection(collection_name)

- try:
-     collection = await connection.open_table(collection_name)
- except ValueError:
-     raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
+ if limit == 0:
+     limit = await collection.count_rows()

results = await collection.vector_search(query_vector).limit(limit).to_pandas()

@@ -242,8 +207,7 @@ async def batch_search(
def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
async def _delete_data_points():
- connection = await self.get_connection()
- collection = await connection.open_table(collection_name)
+ collection = await self.get_collection(collection_name)

# Delete one at a time to avoid commit conflicts
for data_point_id in data_point_ids:
@@ -288,7 +252,7 @@ async def prune(self):
collection_names = await connection.table_names()

for collection_name in collection_names:
- collection = await connection.open_table(collection_name)
+ collection = await self.get_collection(collection_name)
await collection.delete("id IS NOT NULL")
await connection.drop_table(collection_name)

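
Both adapters also drop get_distance_from_collection_elements in favor of a limit=0 convention on search: a zero limit is replaced by the collection's size before querying, which reproduces the old "score every element" behavior through a single code path. A sketch of a hypothetical call site, assuming an already-constructed adapter and an existing collection:

from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError

async def score_whole_collection(adapter, collection_name: str, query: str):
    # limit=0 is the new convention: the adapter substitutes the
    # collection's size (count_rows() in LanceDB, count() in ChromaDB),
    # matching what get_distance_from_collection_elements used to return.
    try:
        return await adapter.search(
            collection_name=collection_name,
            query_text=query,
            limit=0,
        )
    except CollectionNotFoundError:
        return []
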
31 changes: 22 additions & 9 deletions cognee/infrastructure/databases/vector/milvus/MilvusAdapter.py
@@ -1,11 +1,12 @@
from __future__ import annotations

import asyncio
- from cognee.shared.logging_utils import get_logger
from uuid import UUID
from typing import List, Optional

+ from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
+ from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError

from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
@@ -96,7 +97,7 @@ async def create_collection(
raise e

async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
- from pymilvus import MilvusException
+ from pymilvus import MilvusException, exceptions

client = self.get_milvus_client()
data_vectors = await self.embed_data(
@@ -118,6 +119,10 @@ async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
f"Inserted {result.get('insert_count', 0)} data points into collection '{collection_name}'."
)
return result
+ except exceptions.CollectionNotExistException as error:
+     raise CollectionNotFoundError(
+         f"Collection '{collection_name}' does not exist!"
+     ) from error
except MilvusException as e:
logger.error(
f"Error inserting data points into collection '{collection_name}': {str(e)}"
@@ -140,8 +145,8 @@ async def index_data_points(
collection_name = f"{index_name}_{index_property_name}"
await self.create_data_points(collection_name, formatted_data_points)

- async def retrieve(self, collection_name: str, data_point_ids: list[str]):
-     from pymilvus import MilvusException
+ async def retrieve(self, collection_name: str, data_point_ids: list[UUID]):
+     from pymilvus import MilvusException, exceptions

client = self.get_milvus_client()
try:
@@ -153,6 +158,10 @@ async def retrieve(self, collection_name: str, data_point_ids: list[str]):
output_fields=["*"],
)
return results
+ except exceptions.CollectionNotExistException as error:
+     raise CollectionNotFoundError(
+         f"Collection '{collection_name}' does not exist!"
+     ) from error
except MilvusException as e:
logger.error(
f"Error retrieving data points from collection '{collection_name}': {str(e)}"
@@ -164,10 +173,10 @@ async def search(
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
- limit: int = 5,
+ limit: int = 15,
with_vector: bool = False,
):
- from pymilvus import MilvusException
+ from pymilvus import MilvusException, exceptions

client = self.get_milvus_client()
if query_text is None and query_vector is None:
@@ -184,7 +193,7 @@
collection_name=collection_name,
data=[query_vector],
anns_field="vector",
- limit=limit,
+ limit=limit if limit > 0 else None,
output_fields=output_fields,
search_params={
"metric_type": "COSINE",
@@ -199,6 +208,10 @@
)
for result in results[0]
]
+ except exceptions.CollectionNotExistException as error:
+     raise CollectionNotFoundError(
+         f"Collection '{collection_name}' does not exist!"
+     ) from error
except MilvusException as e:
logger.error(f"Error during search in collection '{collection_name}': {str(e)}")
raise e
@@ -220,7 +233,7 @@ async def batch_search(
]
)

- async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
+ async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
from pymilvus import MilvusException

client = self.get_milvus_client()
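
The Milvus adapter keeps its try/except structure but now translates pymilvus' CollectionNotExistException into the same shared CollectionNotFoundError the other adapters raise. The repeated except-clause could be factored into a small context manager; a sketch under that assumption — the helper name is hypothetical, not part of this diff:

from contextlib import contextmanager

from pymilvus import exceptions
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError

@contextmanager
def missing_collection_as_not_found(collection_name: str):
    # Hypothetical refactor of the except-clauses added above: translate
    # the provider-specific "collection missing" error into the shared
    # CollectionNotFoundError so callers handle one exception type across
    # every vector backend.
    try:
        yield
    except exceptions.CollectionNotExistException as error:
        raise CollectionNotFoundError(
            f"Collection '{collection_name}' does not exist!"
        ) from error

Each adapter method would then wrap its client call in "with missing_collection_as_not_found(collection_name):" instead of repeating the four-line except block.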