-
Notifications
You must be signed in to change notification settings - Fork 966
Feat/cog 1365 unify retrievers #572
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
468de67
d789dd0
49c2355
7619df2
5a5eb5e
8f0cbee
beacdea
4b71081
7631b11
62f8ac3
58c7eaf
c07cf22
2ef174a
5910fb7
2f70de4
3d0b839
4903d7e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1 @@ | ||
| from .code_graph_retrieval import code_graph_retrieval | ||
| from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| from abc import ABC, abstractmethod | ||
| from typing import Any, Optional | ||
|
|
||
|
|
||
| class BaseRetriever(ABC): | ||
| """Base class for all retrieval operations.""" | ||
|
|
||
| @abstractmethod | ||
| async def get_context(self, query: str) -> Any: | ||
| """Retrieves context based on the query.""" | ||
| pass | ||
|
|
||
| @abstractmethod | ||
| async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: | ||
| """Generates a response using the query and optional context.""" | ||
| pass | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| from typing import Any, Optional | ||
|
|
||
| from cognee.infrastructure.databases.vector import get_vector_engine | ||
| from cognee.modules.retrieval.base_retriever import BaseRetriever | ||
|
|
||
|
|
||
| class ChunksRetriever(BaseRetriever): | ||
| """Retriever for handling document chunk-based searches.""" | ||
|
|
||
| async def get_context(self, query: str) -> Any: | ||
| """Retrieves document chunks context based on the query.""" | ||
| vector_engine = get_vector_engine() | ||
| found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=5) | ||
| return [result.payload for result in found_chunks] | ||
|
|
||
lxobr marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: | ||
| """Generates a completion using document chunks context.""" | ||
| if context is None: | ||
| context = await self.get_context(query) | ||
| return context | ||
lxobr marked this conversation as resolved.
Show resolved
Hide resolved
borisarzentar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,146 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Optional, List, Dict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Code retriever shouldn't be changed as we developed a new one, so be sure when resolving conflicts that this doesn't change that. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import aiofiles | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from pydantic import BaseModel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from cognee.low_level import DataPoint | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from cognee.modules.retrieval.base_retriever import BaseRetriever | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from cognee.infrastructure.databases.graph import get_graph_engine | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from cognee.infrastructure.databases.vector import get_vector_engine | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from cognee.infrastructure.llm.get_llm_client import get_llm_client | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from cognee.infrastructure.llm.prompts import read_query_prompt | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class CodeRetriever(BaseRetriever): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Retriever for handling code-based searches.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class CodeQueryInfo(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Response model for information extraction from the query""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| filenames: List[str] = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sourcecode: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, limit: int = 3): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Initialize retriever with search parameters.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.limit = limit | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.file_name_collections = ["CodeFile_name"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.classes_and_functions_collections = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "ClassDefinition_source_code", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "FunctionDefinition_source_code", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Process the query using LLM to extract file names and source code parts.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| system_prompt = read_query_prompt("codegraph_retriever_system.txt") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| llm_client = get_llm_client() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return await llm_client.acreate_structured_output( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| text_input=query, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| system_prompt=system_prompt, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| response_model=self.CodeQueryInfo, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise RuntimeError("Failed to retrieve structured output from LLM") from e | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def get_context(self, query: str) -> Any: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Find relevant code files based on the query.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not query or not isinstance(query, str): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("The query must be a non-empty string.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vector_engine = get_vector_engine() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| graph_engine = await get_graph_engine() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise RuntimeError("Database initialization error in code_graph_retriever, ") from e | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| files_and_codeparts = await self._process_query(query) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| similar_filenames = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| similar_codepieces = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for collection in self.file_name_collections: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| search_results_file = await vector_engine.search( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| collection, query, limit=self.limit | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for res in search_results_file: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| similar_filenames.append( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| {"id": res.id, "score": res.score, "payload": res.payload} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for collection in self.classes_and_functions_collections: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| search_results_code = await vector_engine.search( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| collection, query, limit=self.limit | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for res in search_results_code: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| similar_codepieces.append( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| {"id": res.id, "score": res.score, "payload": res.payload} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for collection in self.file_name_collections: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for file_from_query in files_and_codeparts.filenames: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| search_results_file = await vector_engine.search( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| collection, file_from_query, limit=self.limit | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for res in search_results_file: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| similar_filenames.append( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| {"id": res.id, "score": res.score, "payload": res.payload} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for collection in self.classes_and_functions_collections: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| search_results_code = await vector_engine.search( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| collection, files_and_codeparts.sourcecode, limit=self.limit | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for res in search_results_code: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| similar_codepieces.append( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| {"id": res.id, "score": res.score, "payload": res.payload} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+64
to
+99
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Refactor to reduce code duplication. There's significant code duplication between the two branches (when filenames/sourcecode are available vs. when they're not). Consider refactoring to avoid this duplication. - if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
- for collection in self.file_name_collections:
- search_results_file = await vector_engine.search(
- collection, query, limit=self.limit
- )
- for res in search_results_file:
- similar_filenames.append(
- {"id": res.id, "score": res.score, "payload": res.payload}
- )
-
- for collection in self.classes_and_functions_collections:
- search_results_code = await vector_engine.search(
- collection, query, limit=self.limit
- )
- for res in search_results_code:
- similar_codepieces.append(
- {"id": res.id, "score": res.score, "payload": res.payload}
- )
- else:
- for collection in self.file_name_collections:
- for file_from_query in files_and_codeparts.filenames:
- search_results_file = await vector_engine.search(
- collection, file_from_query, limit=self.limit
- )
- for res in search_results_file:
- similar_filenames.append(
- {"id": res.id, "score": res.score, "payload": res.payload}
- )
-
- for collection in self.classes_and_functions_collections:
- search_results_code = await vector_engine.search(
- collection, files_and_codeparts.sourcecode, limit=self.limit
- )
- for res in search_results_code:
- similar_codepieces.append(
- {"id": res.id, "score": res.score, "payload": res.payload}
- )
+ # Search for filenames
+ for collection in self.file_name_collections:
+ if files_and_codeparts.filenames:
+ for file_from_query in files_and_codeparts.filenames:
+ search_results_file = await vector_engine.search(
+ collection, file_from_query, limit=self.limit
+ )
+ for res in search_results_file:
+ similar_filenames.append(
+ {"id": res.id, "score": res.score, "payload": res.payload}
+ )
+ else:
+ search_results_file = await vector_engine.search(
+ collection, query, limit=self.limit
+ )
+ for res in search_results_file:
+ similar_filenames.append(
+ {"id": res.id, "score": res.score, "payload": res.payload}
+ )
+
+ # Search for code pieces
+ for collection in self.classes_and_functions_collections:
+ search_query = files_and_codeparts.sourcecode if files_and_codeparts.sourcecode else query
+ search_results_code = await vector_engine.search(
+ collection, search_query, limit=self.limit
+ )
+ for res in search_results_code:
+ similar_codepieces.append(
+ {"id": res.id, "score": res.score, "payload": res.payload}
+ )📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| file_ids = [str(item["id"]) for item in similar_filenames] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| code_ids = [str(item["id"]) for item in similar_codepieces] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| relevant_triplets = await asyncio.gather( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| *[graph_engine.get_connections(node_id) for node_id in code_ids + file_ids] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+104
to
+106
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainConsider limiting concurrent graph engine calls. Multiple graph engine calls are made in parallel without limiting their number, which could potentially overwhelm the database if there are many IDs. 🏁 Script executed: #!/bin/bash
# Check if there's any rate limiting or connection pooling for graph engine
rg -A 2 -B 2 "get_graph_engine" --type pyLength of output: 25960 Action Required: Limit Concurrent Graph Engine Calls In |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| paths = set() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for sublist in relevant_triplets: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for tpl in sublist: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(tpl, tuple) and len(tpl) >= 3: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "file_path" in tpl[0]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| paths.add(tpl[0]["file_path"]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "file_path" in tpl[2]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| paths.add(tpl[2]["file_path"]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| retrieved_files = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| read_tasks = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for file_path in paths: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def read_file(fp): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async with aiofiles.open(fp, "r", encoding="utf-8") as f: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| retrieved_files[fp] = await f.read() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"Error reading {fp}: {e}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| retrieved_files[fp] = "" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| read_tasks.append(read_file(file_path)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+121
to
+130
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Define read_file outside the loop to avoid recreation. The read_file function is redefined in every loop iteration. Define it once outside the loop to improve performance and readability. # Before the loop
+ async def read_file(fp):
+ try:
+ async with aiofiles.open(fp, "r", encoding="utf-8") as f:
+ retrieved_files[fp] = await f.read()
+ except Exception as e:
+ logger.error(f"Error reading {fp}: {e}")
+ retrieved_files[fp] = ""
+
retrieved_files = {}
read_tasks = []
for file_path in paths:
- async def read_file(fp):
- try:
- async with aiofiles.open(fp, "r", encoding="utf-8") as f:
- retrieved_files[fp] = await f.read()
- except Exception as e:
- print(f"Error reading {fp}: {e}")
- retrieved_files[fp] = ""
-
read_tasks.append(read_file(file_path))📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| await asyncio.gather(*read_tasks) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+117
to
+131
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Replace print with proper logging for file read errors. Using print for error reporting is not ideal for production code. Consider using a proper logging mechanism. + import logging
+
+ # At the top of the file, after imports
+ logger = logging.getLogger(__name__)
+
# Then in the read_file function
async def read_file(fp):
try:
async with aiofiles.open(fp, "r", encoding="utf-8") as f:
retrieved_files[fp] = await f.read()
except Exception as e:
- print(f"Error reading {fp}: {e}")
+ logger.error(f"Error reading {fp}: {e}")
retrieved_files[fp] = ""📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "name": file_path, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "description": file_path, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "content": retrieved_files[file_path], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for file_path in paths | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I dont know if completion makes sense here. @borisarzentar ? |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Returns the code files context.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if context is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context = await self.get_context(query) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return context | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| from typing import Any, Optional | ||
|
|
||
| from cognee.infrastructure.databases.vector import get_vector_engine | ||
| from cognee.modules.retrieval.base_retriever import BaseRetriever | ||
| from cognee.modules.retrieval.utils.completion import generate_completion | ||
| from cognee.tasks.completion.exceptions import NoRelevantDataFound | ||
|
|
||
|
|
||
| class CompletionRetriever(BaseRetriever): | ||
| """Retriever for handling LLM-based completion searches.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| user_prompt_path: str = "context_for_question.txt", | ||
| system_prompt_path: str = "answer_simple_question.txt", | ||
| ): | ||
| """Initialize retriever with optional custom prompt paths.""" | ||
| self.user_prompt_path = user_prompt_path | ||
| self.system_prompt_path = system_prompt_path | ||
|
|
||
| async def get_context(self, query: str) -> Any: | ||
| """Retrieves relevant document chunks as context.""" | ||
| vector_engine = get_vector_engine() | ||
| found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=1) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know now the limit is hardcoded so its just a theoretical question. Shouldn't we outsource these to the user? Maybe not just asking |
||
| if len(found_chunks) == 0: | ||
| raise NoRelevantDataFound | ||
| return found_chunks[0].payload["text"] | ||
|
|
||
| async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: | ||
| """Generates an LLM completion using the context.""" | ||
| if context is None: | ||
| context = await self.get_context(query) | ||
|
|
||
| completion = await generate_completion( | ||
| query=query, | ||
| context=context, | ||
| user_prompt_path=self.user_prompt_path, | ||
| system_prompt_path=self.system_prompt_path, | ||
| ) | ||
| return [completion] | ||
lxobr marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| from typing import Any, Optional | ||
|
|
||
| from cognee.infrastructure.engine import ExtendableDataPoint | ||
| from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses | ||
| from cognee.modules.retrieval.base_retriever import BaseRetriever | ||
| from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search | ||
| from cognee.modules.retrieval.utils.completion import generate_completion | ||
| from cognee.tasks.completion.exceptions import NoRelevantDataFound | ||
|
|
||
|
|
||
| class GraphCompletionRetriever(BaseRetriever): | ||
| """Retriever for handling graph-based completion searches.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| user_prompt_path: str = "graph_context_for_question.txt", | ||
| system_prompt_path: str = "answer_simple_question.txt", | ||
| top_k: int = 5, | ||
| ): | ||
| """Initialize retriever with prompt paths and search parameters.""" | ||
| self.user_prompt_path = user_prompt_path | ||
| self.system_prompt_path = system_prompt_path | ||
| self.top_k = top_k | ||
|
|
||
| async def resolve_edges_to_text(self, retrieved_edges: list) -> str: | ||
| """Converts retrieved graph edges into a human-readable string format.""" | ||
| edge_strings = [] | ||
| for edge in retrieved_edges: | ||
| node1_string = edge.node1.attributes.get("text") or edge.node1.attributes.get("name") | ||
| node2_string = edge.node2.attributes.get("text") or edge.node2.attributes.get("name") | ||
| edge_string = edge.attributes["relationship_type"] | ||
| edge_str = f"{node1_string} -- {edge_string} -- {node2_string}" | ||
| edge_strings.append(edge_str) | ||
| return "\n---\n".join(edge_strings) | ||
|
|
||
| async def get_triplets(self, query: str) -> list: | ||
| """Retrieves relevant graph triplets.""" | ||
| subclasses = get_all_subclasses(ExtendableDataPoint) | ||
| vector_index_collections = [] | ||
|
|
||
| for subclass in subclasses: | ||
| index_fields = subclass.model_fields["metadata"].default.get("index_fields", []) | ||
| for field_name in index_fields: | ||
| vector_index_collections.append(f"{subclass.__name__}_{field_name}") | ||
|
|
||
| found_triplets = await brute_force_triplet_search( | ||
| query, top_k=self.top_k, collections=vector_index_collections or None | ||
| ) | ||
|
|
||
| if len(found_triplets) == 0: | ||
| raise NoRelevantDataFound | ||
|
|
||
| return found_triplets | ||
|
|
||
| async def get_context(self, query: str) -> Any: | ||
| """Retrieves and resolves graph triplets into context.""" | ||
| triplets = await self.get_triplets(query) | ||
| return await self.resolve_edges_to_text(triplets) | ||
|
|
||
| async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: | ||
| """Generates a completion using graph connections context.""" | ||
| if context is None: | ||
| context = await self.get_context(query) | ||
|
|
||
| completion = await generate_completion( | ||
| query=query, | ||
| context=context, | ||
| user_prompt_path=self.user_prompt_path, | ||
| system_prompt_path=self.system_prompt_path, | ||
| ) | ||
| return [completion] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| from typing import Optional | ||
|
|
||
| from cognee.infrastructure.llm.get_llm_client import get_llm_client | ||
| from cognee.infrastructure.llm.prompts import read_query_prompt | ||
| from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever | ||
|
|
||
|
|
||
| class GraphSummaryCompletionRetriever(GraphCompletionRetriever): | ||
| """Retriever for handling graph-based completion searches with summarized context.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| user_prompt_path: str = "graph_context_for_question.txt", | ||
| system_prompt_path: str = "answer_simple_question.txt", | ||
| summarize_prompt_path: str = "summarize_search_results.txt", | ||
| top_k: int = 5, | ||
| ): | ||
| """Initialize retriever with default prompt paths and search parameters.""" | ||
| super().__init__( | ||
| user_prompt_path=user_prompt_path, | ||
| system_prompt_path=system_prompt_path, | ||
| top_k=top_k, | ||
| ) | ||
| self.summarize_prompt_path = summarize_prompt_path | ||
|
|
||
| async def resolve_edges_to_text(self, retrieved_edges: list) -> str: | ||
| """Converts retrieved graph edges into a summary without redundancies.""" | ||
| direct_text = await super().resolve_edges_to_text(retrieved_edges) | ||
| system_prompt = read_query_prompt(self.summarize_prompt_path) | ||
|
|
||
| llm_client = get_llm_client() | ||
| return await llm_client.acreate_structured_output( | ||
| text_input=direct_text, | ||
| system_prompt=system_prompt, | ||
| response_model=str, | ||
| ) | ||
lxobr marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Uh oh!
There was an error while loading. Please reload this page.