Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
8afbecf
merge
Vasilije1990 Apr 18, 2025
7bdb2ab
Merge branch 'dev' of github.com:topoteretes/cognee into dev
Vasilije1990 Apr 18, 2025
b35e047
Merge branch 'dev' of github.com:topoteretes/cognee into dev
Vasilije1990 Apr 19, 2025
2a485f9
Merge branch 'dev' of github.com:topoteretes/cognee into dev
Vasilije1990 Apr 19, 2025
f072e8d
Merge branch 'dev' of github.com:topoteretes/cognee into dev
Vasilije1990 Apr 20, 2025
98a1b79
fix: run cognee in Docker [COG-1961] (#775)
dexters1 Apr 23, 2025
17a77c5
Merge remote-tracking branch 'origin/main' into dev
borisarzentar Apr 24, 2025
80e5edc
Merge branch 'dev' of github.com:topoteretes/cognee into dev
Vasilije1990 Apr 25, 2025
5aca3f0
fix: Doesn't drop entire PG database, just cleans public schema - Cog…
Vasilije1990 Apr 25, 2025
0a9e1a4
Merge branch 'dev' of github.com:topoteretes/cognee into dev
Vasilije1990 Apr 26, 2025
79921f8
Merge remote-tracking branch 'origin/main' into dev
Vasilije1990 Apr 26, 2025
6109bf5
feat: Add uv and poetry support to Cognee [COG-1572] (#780)
dexters1 Apr 28, 2025
a627841
fix: networkx id type change [COG-1876] (#786)
dexters1 Apr 28, 2025
c4915a4
Mcp SSE support [COG-1781] (#785)
dexters1 Apr 28, 2025
773752a
feat: Add detailed log handling options for Cognee exceptions [COG-19…
dexters1 Apr 28, 2025
66ecd35
fix: s3fs version fix [COG-2025] (#798)
dexters1 Apr 30, 2025
ad943d0
docs: add cognee UI (#799)
hande-k Apr 30, 2025
cd9c489
feat: remove get_distance_from_collection_names and adapt search (#766)
borisarzentar Apr 30, 2025
7db7422
docs: update colab demo (#795)
hande-k Apr 30, 2025
5970d96
feat: pass context argument to tasks that require it (#788)
borisarzentar Apr 30, 2025
9729547
feat: abstract logging tool integration (#787)
borisarzentar Apr 30, 2025
d417c71
merged
Vasilije1990 May 8, 2025
5d415dc
feat: Add Memgraph integration (#751)
matea16 May 10, 2025
34b95b6
refactor: Handle boto3 s3fs dependencies better (#809)
dexters1 May 10, 2025
a78fec3
fix: Fixes collection search limit in brute force triplet search (#814)
hajdul88 May 12, 2025
9c131f0
refactor: Update lanceDB and change delete to work async (#770)
dexters1 May 12, 2025
f93463e
fix: make onnxruntime flexible (#815)
borisarzentar May 13, 2025
8ea0097
fix: graphiti example (#816)
soobrosa May 13, 2025
13bb244
feat: Create notebook to show how to compute ranks from graph (#771)
diegoabt May 13, 2025
966e337
feat: add MCP check status tool [COG-1784] (#793)
dexters1 May 13, 2025
e3121f5
docs: Update log level of CollectionNotFoundError (#819)
dexters1 May 13, 2025
91f3cd9
fix: notebooks (#818)
soobrosa May 13, 2025
1e7b56f
Merge branch 'dev' of github.com:topoteretes/cognee into dev
Vasilije1990 May 13, 2025
0f3522e
fix: cognee docker image (#820)
borisarzentar May 15, 2025
badd73c
Merge branch 'dev' of github.com:topoteretes/cognee into dev
Vasilije1990 May 15, 2025
c058219
Clean up core cognee repo
Vasilije1990 May 15, 2025
729cb9b
Revert "Clean up core cognee repo"
Vasilije1990 May 15, 2025
ad0bb0c
version: v0.1.40 (#825)
borisarzentar May 15, 2025
7ac5761
Merge branch 'main' into dev
Vasilije1990 May 15, 2025
f9f18d1
feat: Add columns as nodes in relational db migration (#826)
dexters1 May 15, 2025
8178b72
fix: exclude files from build (#828)
borisarzentar May 15, 2025
1dd179b
feat: OpenAI compatible route /api/v1/responses (#792)
dm1tryG May 16, 2025
3b07f3c
feat: Test db examples (#817)
hande-k May 16, 2025
4371b9d
fix: 812 anthropic fix (#822)
Vasilije1990 May 16, 2025
5cf14eb
fix: Mcp small updates (#831)
Vasilije1990 May 16, 2025
86efeee
fix: pipeline run status migration (#836)
borisarzentar May 19, 2025
3ed9504
feat: Add developer rules (#833)
Vasilije1990 May 19, 2025
a874988
fix: Fixes pipeline run status migration (#838)
hajdul88 May 19, 2025
f8f7877
Fix: Fixes graph completion search limit (#839)
hajdul88 May 19, 2025
5c36a5d
feat: Adds modal parallel evaluation for retriever development (#844)
hajdul88 May 20, 2025
9d9ea63
fix: use default threading in Fastembed (#846)
lxobr May 20, 2025
4c52ef6
feat: added util logger OS (#841)
Vasilije1990 May 20, 2025
7eee769
Feat: Adds dashboard application to parallel modal evals (#847)
hajdul88 May 21, 2025
94c785d
fix: hotfix the file uploader in the delete router. (#842)
soobrosa May 21, 2025
08bc472
Feat: Removes hardcoded user prompts from adapters
hajdul88 May 21, 2025
e0798ff
Feat: Adds chain of thought retriever (#864)
hajdul88 May 22, 2025
d663921
Feat: Adds context extension search (#865)
hajdul88 May 22, 2025
b71b704
chore: Move files (#848)
Vasilije1990 May 22, 2025
4650c9c
chore: add neo4j to mcp dependencies (#867)
hande-k May 23, 2025
834d959
Readme local install (#872)
dexters1 May 26, 2025
965033e
Feat: Adds subgraph retriever to graph based completion searches (#874)
hajdul88 May 27, 2025
ec68e99
Fix: removes ontology resolver initialization at import (#876)
hajdul88 May 27, 2025
bb68d6a
Docstring tasks. (#878)
soobrosa May 27, 2025
ff997f4
Docstring modules. (#877)
soobrosa May 27, 2025
b5ebed1
Docstring infrastructure. (#880)
soobrosa May 28, 2025
b94c846
Fix: Disable faulty graph metrics calculation in demos (#888)
hajdul88 May 29, 2025
d8ef290
feat: removes unused properies from node and edge pydantic models (#884)
hajdul88 May 30, 2025
d91602e
0.1.41 Release fixes (#889)
borisarzentar May 30, 2025
5a04421
version: v0.1.41 (#890)
borisarzentar May 30, 2025
57b0e0e
Merge with main (#892)
borisarzentar May 30, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: remove get_distance_from_collection_names and adapt search (#766)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
  • Loading branch information
borisarzentar authored Apr 30, 2025
commit cd9c4897a4c2f56d7205460cc07787807a539ed3
94 changes: 25 additions & 69 deletions cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from cognee.exceptions import InvalidValueError
from cognee.shared.logging_utils import get_logger
from cognee.modules.storage.utils import get_own_properties
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult

from ..embeddings.EmbeddingEngine import EmbeddingEngine
Expand Down Expand Up @@ -108,9 +109,7 @@ async def embed_data(self, data: list[str]) -> list[list[float]]:
return await self.embedding_engine.embed_text(data)

async def has_collection(self, collection_name: str) -> bool:
    """Return whether a collection named ``collection_name`` exists.

    Membership is tested against the names returned by
    ``get_collection_names``. No separate client connection or second
    ``list_collections`` call is made here; the previous extra lookup was
    discarded immediately and only cost an additional round trip.
    """
    collection_names = await self.get_collection_names()
    return collection_name in collection_names

async def create_collection(self, collection_name: str, payload_schema=None):
Expand All @@ -119,13 +118,17 @@ async def create_collection(self, collection_name: str, payload_schema=None):
if not await self.has_collection(collection_name):
await client.create_collection(name=collection_name, metadata={"hnsw:space": "cosine"})

async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
async def get_collection(self, collection_name: str):
    """Return the collection handle named ``collection_name``.

    The value returned is whatever ``client.get_collection(...)`` yields,
    i.e. a collection object (callers in this adapter invoke ``get``,
    ``query``, ``count`` and ``delete`` on it), not the client itself.
    The earlier ``-> AsyncHttpClient`` annotation was therefore wrong and
    has been dropped.

    Raises:
        CollectionNotFoundError: if ``has_collection`` reports that the
            collection does not exist.
    """
    if not await self.has_collection(collection_name):
        raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")

    client = await self.get_connection()
    return await client.get_collection(collection_name)

if not await self.has_collection(collection_name):
await self.create_collection(collection_name)
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
await self.create_collection(collection_name)

collection = await client.get_collection(collection_name)
collection = await self.get_collection(collection_name)

texts = [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
embeddings = await self.embed_data(texts)
Expand Down Expand Up @@ -161,8 +164,7 @@ async def index_data_points(

async def retrieve(self, collection_name: str, data_point_ids: list[str]):
"""Retrieve data points by their IDs from a collection."""
client = await self.get_connection()
collection = await client.get_collection(collection_name)
collection = await self.get_collection(collection_name)
results = await collection.get(ids=data_point_ids, include=["metadatas"])

return [
Expand All @@ -174,62 +176,12 @@ async def retrieve(self, collection_name: str, data_point_ids: list[str]):
for id, metadata in zip(results["ids"], results["metadatas"])
]

async def get_distance_from_collection_elements(
self, collection_name: str, query_text: str = None, query_vector: List[float] = None
):
"""Calculate distance between query and all elements in a collection."""
if query_text is None and query_vector is None:
raise InvalidValueError(message="One of query_text or query_vector must be provided!")

if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

client = await self.get_connection()
try:
collection = await client.get_collection(collection_name)

collection_count = await collection.count()

results = await collection.query(
query_embeddings=[query_vector],
include=["metadatas", "distances"],
n_results=collection_count,
)

result_values = []
for i, (id, metadata, distance) in enumerate(
zip(results["ids"][0], results["metadatas"][0], results["distances"][0])
):
result_values.append(
{
"id": parse_id(id),
"payload": restore_data_from_chroma(metadata),
"_distance": distance,
}
)

normalized_values = normalize_distances(result_values)

scored_results = []
for i, result in enumerate(result_values):
scored_results.append(
ScoredResult(
id=result["id"],
payload=result["payload"],
score=normalized_values[i],
)
)

return scored_results
except Exception:
return []

async def search(
self,
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
limit: int = 5,
limit: int = 15,
with_vector: bool = False,
normalized: bool = True,
):
Expand All @@ -241,8 +193,10 @@ async def search(
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

try:
client = await self.get_connection()
collection = await client.get_collection(collection_name)
collection = await self.get_collection(collection_name)

if limit == 0:
limit = await collection.count()

results = await collection.query(
query_embeddings=[query_vector],
Expand Down Expand Up @@ -296,8 +250,7 @@ async def batch_search(
"""Perform multiple searches in a single request for efficiency."""
query_vectors = await self.embed_data(query_texts)

client = await self.get_connection()
collection = await client.get_collection(collection_name)
collection = await self.get_collection(collection_name)

results = await collection.query(
query_embeddings=query_vectors,
Expand Down Expand Up @@ -346,20 +299,23 @@ async def batch_search(

async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
"""Remove data points from a collection by their IDs."""
client = await self.get_connection()
collection = await client.get_collection(collection_name)
collection = await self.get_collection(collection_name)
await collection.delete(ids=data_point_ids)
return True

async def prune(self):
    """Delete all collections in the ChromaDB database.

    Returns:
        True once every listed collection has been deleted.
    """
    client = await self.get_connection()
    # Go through the adapter's own listing helper: it is the method this
    # adapter exposes for enumerating collections (there is no
    # ``self.list_collections`` among the visible adapter methods), and it
    # yields names suitable for ``delete_collection``.
    for collection_name in await self.get_collection_names():
        await client.delete_collection(collection_name)
    return True

async def get_collection_names(self):
    """Get a list of all collection names in the database.

    ``client.list_collections()`` may yield plain name strings, objects
    exposing a ``name`` attribute, or mappings with a ``"name"`` key; every
    entry is normalized to its name string.
    """
    client = await self.get_connection()
    collections = await client.list_collections()

    names = []
    for collection in collections:
        if isinstance(collection, str):
            # Already a name; indexing a str with "name" would raise TypeError.
            names.append(collection)
        elif hasattr(collection, "name"):
            names.append(collection.name)
        else:
            names.append(collection["name"])
    return names
64 changes: 14 additions & 50 deletions cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints

import lancedb
from lancedb.pydantic import LanceModel, Vector
from pydantic import BaseModel
Expand Down Expand Up @@ -76,9 +75,14 @@ class LanceDataPoint(LanceModel):
exist_ok=True,
)

async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
async def get_collection(self, collection_name: str):
    """Open and return the table backing ``collection_name``.

    Raises:
        CollectionNotFoundError: if ``has_collection`` reports that the
            collection does not exist.
    """
    collection_exists = await self.has_collection(collection_name)
    if collection_exists:
        db_connection = await self.get_connection()
        table = await db_connection.open_table(collection_name)
        return table

    raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")

async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
payload_schema = type(data_points[0])

if not await self.has_collection(collection_name):
Expand All @@ -87,7 +91,7 @@ async def create_data_points(self, collection_name: str, data_points: list[DataP
payload_schema,
)

collection = await connection.open_table(collection_name)
collection = await self.get_collection(collection_name)

data_vectors = await self.embed_data(
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
Expand Down Expand Up @@ -125,8 +129,7 @@ def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> Lance
)

async def retrieve(self, collection_name: str, data_point_ids: list[str]):
connection = await self.get_connection()
collection = await connection.open_table(collection_name)
collection = await self.get_collection(collection_name)

if len(data_point_ids) == 1:
results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
Expand All @@ -142,48 +145,12 @@ async def retrieve(self, collection_name: str, data_point_ids: list[str]):
for result in results.to_dict("index").values()
]

async def get_distance_from_collection_elements(
self, collection_name: str, query_text: str = None, query_vector: List[float] = None
):
if query_text is None and query_vector is None:
raise InvalidValueError(message="One of query_text or query_vector must be provided!")

if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

connection = await self.get_connection()

try:
collection = await connection.open_table(collection_name)

collection_size = await collection.count_rows()

results = (
await collection.vector_search(query_vector).limit(collection_size).to_pandas()
)

result_values = list(results.to_dict("index").values())

normalized_values = normalize_distances(result_values)

return [
ScoredResult(
id=parse_id(result["id"]),
payload=result["payload"],
score=normalized_values[value_index],
)
for value_index, result in enumerate(result_values)
]
except ValueError:
# Ignore if collection doesn't exist
return []

async def search(
self,
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
limit: int = 5,
limit: int = 15,
with_vector: bool = False,
normalized: bool = True,
):
Expand All @@ -193,12 +160,10 @@ async def search(
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

connection = await self.get_connection()
collection = await self.get_collection(collection_name)

try:
collection = await connection.open_table(collection_name)
except ValueError:
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
if limit == 0:
limit = await collection.count_rows()

results = await collection.vector_search(query_vector).limit(limit).to_pandas()

Expand Down Expand Up @@ -242,8 +207,7 @@ async def batch_search(
def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
async def _delete_data_points():
connection = await self.get_connection()
collection = await connection.open_table(collection_name)
collection = await self.get_collection(collection_name)

# Delete one at a time to avoid commit conflicts
for data_point_id in data_point_ids:
Expand Down Expand Up @@ -288,7 +252,7 @@ async def prune(self):
collection_names = await connection.table_names()

for collection_name in collection_names:
collection = await connection.open_table(collection_name)
collection = await self.get_collection(collection_name)
await collection.delete("id IS NOT NULL")
await connection.drop_table(collection_name)

Expand Down
31 changes: 22 additions & 9 deletions cognee/infrastructure/databases/vector/milvus/MilvusAdapter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import asyncio
from cognee.shared.logging_utils import get_logger
from uuid import UUID
from typing import List, Optional

from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError

from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
Expand Down Expand Up @@ -96,7 +97,7 @@ async def create_collection(
raise e

async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
from pymilvus import MilvusException
from pymilvus import MilvusException, exceptions

client = self.get_milvus_client()
data_vectors = await self.embed_data(
Expand All @@ -118,6 +119,10 @@ async def create_data_points(self, collection_name: str, data_points: List[DataP
f"Inserted {result.get('insert_count', 0)} data points into collection '{collection_name}'."
)
return result
except exceptions.CollectionNotExistException as error:
raise CollectionNotFoundError(
f"Collection '{collection_name}' does not exist!"
) from error
except MilvusException as e:
logger.error(
f"Error inserting data points into collection '{collection_name}': {str(e)}"
Expand All @@ -140,8 +145,8 @@ async def index_data_points(
collection_name = f"{index_name}_{index_property_name}"
await self.create_data_points(collection_name, formatted_data_points)

async def retrieve(self, collection_name: str, data_point_ids: list[str]):
from pymilvus import MilvusException
async def retrieve(self, collection_name: str, data_point_ids: list[UUID]):
from pymilvus import MilvusException, exceptions

client = self.get_milvus_client()
try:
Expand All @@ -153,6 +158,10 @@ async def retrieve(self, collection_name: str, data_point_ids: list[str]):
output_fields=["*"],
)
return results
except exceptions.CollectionNotExistException as error:
raise CollectionNotFoundError(
f"Collection '{collection_name}' does not exist!"
) from error
except MilvusException as e:
logger.error(
f"Error retrieving data points from collection '{collection_name}': {str(e)}"
Expand All @@ -164,10 +173,10 @@ async def search(
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 5,
limit: int = 15,
with_vector: bool = False,
):
from pymilvus import MilvusException
from pymilvus import MilvusException, exceptions

client = self.get_milvus_client()
if query_text is None and query_vector is None:
Expand All @@ -184,7 +193,7 @@ async def search(
collection_name=collection_name,
data=[query_vector],
anns_field="vector",
limit=limit,
limit=limit if limit > 0 else None,
output_fields=output_fields,
search_params={
"metric_type": "COSINE",
Expand All @@ -199,6 +208,10 @@ async def search(
)
for result in results[0]
]
except exceptions.CollectionNotExistException as error:
raise CollectionNotFoundError(
f"Collection '{collection_name}' does not exist!"
) from error
except MilvusException as e:
logger.error(f"Error during search in collection '{collection_name}': {str(e)}")
raise e
Expand All @@ -220,7 +233,7 @@ async def batch_search(
]
)

async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
from pymilvus import MilvusException

client = self.get_milvus_client()
Expand Down
Loading
Loading