diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py b/cognee/eval_framework/answer_generation/answer_generation_executor.py
index b4afc05b30..1b984d465e 100644
--- a/cognee/eval_framework/answer_generation/answer_generation_executor.py
+++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py
@@ -1,21 +1,17 @@
-import cognee
-from typing import List, Dict, Callable, Awaitable
-from cognee.api.v1.search import SearchType
-
-question_answering_engine_options: Dict[str, Callable[[str, str], Awaitable[List[str]]]] = {
-    "cognee_graph_completion": lambda query, system_prompt_path: cognee.search(
-        query_type=SearchType.GRAPH_COMPLETION,
-        query_text=query,
-        system_prompt_path=system_prompt_path,
-    ),
-    "cognee_completion": lambda query, system_prompt_path: cognee.search(
-        query_type=SearchType.COMPLETION, query_text=query, system_prompt_path=system_prompt_path
-    ),
-    "graph_summary_completion": lambda query, system_prompt_path: cognee.search(
-        query_type=SearchType.GRAPH_SUMMARY_COMPLETION,
-        query_text=query,
-        system_prompt_path=system_prompt_path,
-    ),
+from typing import List, Dict
+from cognee.modules.retrieval.completion_retriever import CompletionRetriever
+from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
+from cognee.modules.retrieval.graph_summary_completion_retriever import (
+    GraphSummaryCompletionRetriever,
+)
+
+from cognee.modules.retrieval.base_retriever import BaseRetriever
+
+
+retriever_options: Dict[str, BaseRetriever] = {
+    "cognee_graph_completion": GraphCompletionRetriever,
+    "cognee_completion": CompletionRetriever,
+    "graph_summary_completion": GraphSummaryCompletionRetriever,
 }
 
 
@@ -23,20 +19,22 @@ class AnswerGeneratorExecutor:
     async def question_answering_non_parallel(
         self,
         questions: List[Dict[str, str]],
-        answer_resolver: Callable[[str], Awaitable[List[str]]],
+        retriever: BaseRetriever,
     ) -> List[Dict[str, str]]:
         answers = []
         for instance in questions:
             query_text = instance["question"]
             correct_answer = instance["answer"]
 
-            search_results = await answer_resolver(query_text)
+            retrieval_context = await retriever.get_context(query_text)
+            search_results = await retriever.get_completion(query_text, retrieval_context)
 
             answers.append(
                 {
                     "question": query_text,
                     "answer": search_results[0],
                     "golden_answer": correct_answer,
+                    "retrieval_context": retrieval_context,
                 }
             )
 
diff --git a/cognee/eval_framework/answer_generation/run_question_answering_module.py b/cognee/eval_framework/answer_generation/run_question_answering_module.py
index 9caf71f6a2..1d3686efb1 100644
--- a/cognee/eval_framework/answer_generation/run_question_answering_module.py
+++ b/cognee/eval_framework/answer_generation/run_question_answering_module.py
@@ -3,7 +3,7 @@ from typing import List
 
 from cognee.eval_framework.answer_generation.answer_generation_executor import (
     AnswerGeneratorExecutor,
-    question_answering_engine_options,
+    retriever_options,
 )
 from cognee.infrastructure.files.storage import LocalStorage
 from cognee.infrastructure.databases.relational.get_relational_engine import (
@@ -48,9 +48,7 @@ async def run_question_answering(
     answer_generator = AnswerGeneratorExecutor()
     answers = await answer_generator.question_answering_non_parallel(
         questions=questions,
-        answer_resolver=lambda query: question_answering_engine_options[params["qa_engine"]](
-            query, system_prompt
-        ),
+        retriever=retriever_options[params["qa_engine"]](system_prompt_path=system_prompt),
     )
     with open(params["answers_path"], "w", encoding="utf-8") as f:
         json.dump(answers, f, ensure_ascii=False, indent=4)
diff --git a/cognee/eval_framework/eval_config.py b/cognee/eval_framework/eval_config.py
index 1ac72a1056..1ac399ffeb 100644
--- a/cognee/eval_framework/eval_config.py
+++ b/cognee/eval_framework/eval_config.py
@@ -18,6 +18,7 @@ class EvalConfig(BaseSettings):
 
     # Evaluation params
     evaluating_answers: bool = True
+    evaluating_contexts: bool = True
     evaluation_engine: str = "DeepEval"  # Options: 'DeepEval' (uses deepeval_model), 'DirectLLM' (uses default llm from .env)
     evaluation_metrics: List[str] = [
         "correctness",
@@ -51,6 +52,7 @@ def to_dict(self) -> dict:
             "answering_questions": self.answering_questions,
             "qa_engine": self.qa_engine,
             "evaluating_answers": self.evaluating_answers,
+            "evaluating_contexts": self.evaluating_contexts,  # Controls whether context evaluation should be performed
             "evaluation_engine": self.evaluation_engine,
             "evaluation_metrics": self.evaluation_metrics,
             "calculate_metrics": self.calculate_metrics,
diff --git a/cognee/eval_framework/evaluation/deep_eval_adapter.py b/cognee/eval_framework/evaluation/deep_eval_adapter.py
index 84ae79f706..11f33571b9 100644
--- a/cognee/eval_framework/evaluation/deep_eval_adapter.py
+++ b/cognee/eval_framework/evaluation/deep_eval_adapter.py
@@ -5,6 +5,7 @@
 from cognee.eval_framework.evaluation.metrics.exact_match import ExactMatchMetric
 from cognee.eval_framework.evaluation.metrics.f1 import F1ScoreMetric
 from typing import Any, Dict, List
+from deepeval.metrics import ContextualRelevancyMetric
 
 
 class DeepEvalAdapter(BaseEvalAdapter):
@@ -13,6 +14,7 @@ def __init__(self):
             "correctness": self.g_eval_correctness(),
             "EM": ExactMatchMetric(),
             "f1": F1ScoreMetric(),
+            "contextual_relevancy": ContextualRelevancyMetric(),
         }
 
     async def evaluate_answers(
@@ -29,6 +31,7 @@ async def evaluate_answers(
                 input=answer["question"],
                 actual_output=answer["answer"],
                 expected_output=answer["golden_answer"],
+                retrieval_context=[answer["retrieval_context"]],
             )
             metric_results = {}
             for metric in evaluator_metrics:
diff --git a/cognee/eval_framework/evaluation/evaluation_executor.py b/cognee/eval_framework/evaluation/evaluation_executor.py
index dcee2281e0..5e56b50c78 100644
--- a/cognee/eval_framework/evaluation/evaluation_executor.py
+++ b/cognee/eval_framework/evaluation/evaluation_executor.py
@@ -3,7 +3,11 @@
 
 
 class EvaluationExecutor:
-    def __init__(self, evaluator_engine: Union[str, EvaluatorAdapter, Any] = "DeepEval") -> None:
+    def __init__(
+        self,
+        evaluator_engine: Union[str, EvaluatorAdapter, Any] = "DeepEval",
+        evaluate_contexts: bool = False,
+    ) -> None:
         if isinstance(evaluator_engine, str):
             try:
                 adapter_enum = EvaluatorAdapter(evaluator_engine)
@@ -14,7 +18,10 @@ def __init__(self, evaluator_engine: Union[str, EvaluatorAdapter, Any] = "DeepEv
             self.eval_adapter = evaluator_engine.adapter_class()
         else:
             self.eval_adapter = evaluator_engine
+        self.evaluate_contexts = evaluate_contexts
 
     async def execute(self, answers: List[Dict[str, str]], evaluator_metrics: Any) -> Any:
+        if self.evaluate_contexts:
+            evaluator_metrics.append("contextual_relevancy")
         metrics = await self.eval_adapter.evaluate_answers(answers, evaluator_metrics)
         return metrics
diff --git a/cognee/eval_framework/evaluation/run_evaluation_module.py b/cognee/eval_framework/evaluation/run_evaluation_module.py
index 14230f2244..a344d154b7 100644
--- a/cognee/eval_framework/evaluation/run_evaluation_module.py
+++ b/cognee/eval_framework/evaluation/run_evaluation_module.py
@@ -42,7 +42,10 @@ async def execute_evaluation(params: dict) -> None:
         raise ValueError(f"Error decoding JSON from {params['answers_path']}: {e}")
 
     logging.info(f"Loaded {len(answers)} answers from {params['answers_path']}")
-    evaluator = EvaluationExecutor(evaluator_engine=params["evaluation_engine"])
+    evaluator = EvaluationExecutor(
+        evaluator_engine=params["evaluation_engine"],
+        evaluate_contexts=params["evaluating_contexts"],
+    )
     metrics = await evaluator.execute(
         answers=answers, evaluator_metrics=params["evaluation_metrics"]
     )
diff --git a/cognee/tests/unit/eval_framework/answer_generation_test.py b/cognee/tests/unit/eval_framework/answer_generation_test.py
index d02ffd27d5..aa45b16a81 100644
--- a/cognee/tests/unit/eval_framework/answer_generation_test.py
+++ b/cognee/tests/unit/eval_framework/answer_generation_test.py
@@ -11,14 +11,18 @@ async def test_answer_generation():
     limit = 1
     corpus_list, qa_pairs = DummyAdapter().load_corpus(limit=limit)
 
-    mock_answer_resolver = AsyncMock()
-    mock_answer_resolver.side_effect = lambda query: ["mock_answer"]
+    mock_retriever = AsyncMock()
+    mock_retriever.get_context = AsyncMock(return_value="Mocked retrieval context")
+    mock_retriever.get_completion = AsyncMock(return_value=["Mocked answer"])
 
     answer_generator = AnswerGeneratorExecutor()
     answers = await answer_generator.question_answering_non_parallel(
-        questions=qa_pairs, answer_resolver=mock_answer_resolver
+        questions=qa_pairs,
+        retriever=mock_retriever,
     )
 
+    mock_retriever.get_context.assert_any_await(qa_pairs[0]["question"])
+
    assert len(answers) == len(qa_pairs)
    assert answers[0]["question"] == qa_pairs[0]["question"], (
        "AnswerGeneratorExecutor is passing the question incorrectly"
@@ -26,6 +30,6 @@
    assert answers[0]["golden_answer"] == qa_pairs[0]["answer"], (
        "AnswerGeneratorExecutor is passing the golden answer incorrectly"
    )
-    assert answers[0]["answer"] == "mock_answer", (
+    assert answers[0]["answer"] == "Mocked answer", (
        "AnswerGeneratorExecutor is passing the generated answer incorrectly"
    )
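
For reference, a minimal end-to-end sketch of the retriever-based flow this diff introduces, mirroring run_question_answering_module.py. It assumes a corpus has already been added and cognified; the question dict, the qa_engine key, and the retriever API come from the diff above, while the prompt filename is an illustrative placeholder rather than a verified default.

# Sketch only: not part of the diff; shows how retriever_options and
# AnswerGeneratorExecutor fit together after this change.
import asyncio

from cognee.eval_framework.answer_generation.answer_generation_executor import (
    AnswerGeneratorExecutor,
    retriever_options,
)


async def main() -> None:
    questions = [{"question": "Who founded Acme Corp?", "answer": "Jane Doe"}]

    # Look up the retriever class by qa_engine name and instantiate it with a
    # system prompt path, as run_question_answering() now does.
    retriever = retriever_options["cognee_graph_completion"](
        system_prompt_path="answer_simple_question.txt"  # placeholder filename
    )

    answers = await AnswerGeneratorExecutor().question_answering_non_parallel(
        questions=questions,
        retriever=retriever,
    )

    # Each record now carries the retrieved context alongside the answer, which
    # the evaluation step scores with ContextualRelevancyMetric when
    # evaluating_contexts is enabled in EvalConfig.
    print(answers[0]["answer"], answers[0]["retrieval_context"])


if __name__ == "__main__":
    asyncio.run(main())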