Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: create_enrichments.py
  • Loading branch information
lxobr committed Oct 20, 2025
commit 834cf8b11307f38a09b42660db493bdf2ddaa14c
145 changes: 145 additions & 0 deletions cognee/tasks/feedback/create_enrichments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from __future__ import annotations

from typing import Dict, List, Optional
from uuid import NAMESPACE_OID, uuid5

from cognee.infrastructure.llm import LLMGateway
from cognee.infrastructure.llm.prompts.read_query_prompt import read_query_prompt
from cognee.shared.logging_utils import get_logger
from cognee.modules.engine.models import NodeSet

from .models import FeedbackEnrichment


logger = get_logger("create_enrichments")


def _validate_improved_answers(improved_answers: List[Dict]) -> bool:
"""Validate that all items contain required fields for enrichment creation."""
required_fields = [
"question",
"answer", # This is the original answer field from feedback_interaction
"improved_answer",
"new_context",
"feedback_id",
"interaction_id",
]
return all(
all(item.get(field) is not None for field in required_fields) for item in improved_answers
)


def _validate_uuid_fields(improved_answers: List[Dict]) -> bool:
"""Validate that feedback_id and interaction_id are valid UUID objects."""
try:
for item in improved_answers:
feedback_id = item.get("feedback_id")
interaction_id = item.get("interaction_id")
if not isinstance(feedback_id, type(feedback_id)) or not isinstance(
interaction_id, type(interaction_id)
):
return False
return True
except Exception:
return False


async def _generate_enrichment_report(
question: str, improved_answer: str, new_context: str, report_prompt_location: str
) -> str:
"""Generate educational report using feedback report prompt."""
try:
prompt_template = read_query_prompt(report_prompt_location)
rendered_prompt = prompt_template.format(
question=question,
improved_answer=improved_answer,
new_context=new_context,
)
return await LLMGateway.acreate_structured_output(
text_input=rendered_prompt,
system_prompt="You are a helpful assistant that creates educational content.",
response_model=str,
)
except Exception as exc:
logger.warning("Failed to generate enrichment report", error=str(exc), question=question)
return f"Educational content for: {question} - {improved_answer}"


async def _create_enrichment_datapoint(
improved_answer_item: Dict,
report_text: str,
) -> Optional[FeedbackEnrichment]:
"""Create a single FeedbackEnrichment DataPoint with proper ID and nodeset assignment."""
try:
question = improved_answer_item["question"]
improved_answer = improved_answer_item["improved_answer"]

# Create nodeset following UserQAFeedback pattern
nodeset = NodeSet(
id=uuid5(NAMESPACE_OID, name="FeedbackEnrichment"), name="FeedbackEnrichment"
)

enrichment = FeedbackEnrichment(
id=str(uuid5(NAMESPACE_OID, f"{question}_{improved_answer}")),
text=report_text,
question=question,
original_answer=improved_answer_item["answer"], # Use "answer" field
improved_answer=improved_answer,
feedback_id=improved_answer_item["feedback_id"],
interaction_id=improved_answer_item["interaction_id"],
belongs_to_set=nodeset,
)

return enrichment
except Exception as exc:
logger.error(
"Failed to create enrichment datapoint",
error=str(exc),
question=improved_answer_item.get("question"),
)
return None


async def create_enrichments(
improved_answers: List[Dict],
report_prompt_location: str = "feedback_report_prompt.txt",
) -> List[FeedbackEnrichment]:
"""Create FeedbackEnrichment DataPoint instances from improved answers."""
if not improved_answers:
logger.info("No improved answers provided; returning empty list")
return []

if not _validate_improved_answers(improved_answers):
logger.error("Input validation failed; missing required fields")
return []

if not _validate_uuid_fields(improved_answers):
logger.error("UUID validation failed; invalid feedback_id or interaction_id")
return []

logger.info("Creating enrichments", count=len(improved_answers))

enrichments: List[FeedbackEnrichment] = []

for improved_answer_item in improved_answers:
question = improved_answer_item["question"]
improved_answer = improved_answer_item["improved_answer"]
new_context = improved_answer_item["new_context"]

report_text = await _generate_enrichment_report(
question, improved_answer, new_context, report_prompt_location
)

enrichment = await _create_enrichment_datapoint(improved_answer_item, report_text)

if enrichment:
enrichments.append(enrichment)
else:
logger.warning(
"Failed to create enrichment",
question=question,
interaction_id=improved_answer_item.get("interaction_id"),
)

logger.info("Created enrichments", successful=len(enrichments))
return enrichments
5 changes: 4 additions & 1 deletion cognee/tasks/feedback/generate_improved_answers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ async def _generate_improved_answer_for_single_interaction(

retrieved_context = await retriever.get_context(query_text)
completion = await retriever.get_structured_completion(
query=query_text, context=retrieved_context, response_model=ImprovedAnswerResponse
query=query_text,
context=retrieved_context,
response_model=ImprovedAnswerResponse,
max_iter=1,
)
new_context_text = await retriever.resolve_edges_to_text(retrieved_context)

Expand Down
3 changes: 2 additions & 1 deletion cognee/tasks/feedback/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID

from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models import Entity
from cognee.modules.engine.models import Entity, NodeSet
from cognee.tasks.temporal_graph.models import Event


Expand All @@ -18,3 +18,4 @@ class FeedbackEnrichment(DataPoint):
improved_answer: str
feedback_id: UUID
interaction_id: UUID
belongs_to_set: Optional[NodeSet] = None
12 changes: 7 additions & 5 deletions examples/python/feedback_enrichment_minimal_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from cognee.tasks.feedback.extract_feedback_interactions import extract_feedback_interactions
from cognee.tasks.feedback.generate_improved_answers import generate_improved_answers
from cognee.tasks.feedback.create_enrichments import create_enrichments


CONVERSATION = [
Expand Down Expand Up @@ -48,11 +49,12 @@ async def run_question_and_submit_feedback(question_text: str) -> bool:


async def run_feedback_enrichment_memify(last_n: int = 5):
"""Execute memify with extraction and answer improvement tasks."""
"""Execute memify with extraction, answer improvement, and enrichment creation tasks."""
# Instantiate tasks with their own kwargs
extraction_tasks = [Task(extract_feedback_interactions, last_n=last_n)]
enrichment_tasks = [
Task(generate_improved_answers, retriever_name="graph_completion_cot", top_k=20)
Task(generate_improved_answers, retriever_name="graph_completion_cot", top_k=20),
Task(create_enrichments),
]
await cognee.memify(
extraction_tasks=extraction_tasks,
Expand All @@ -63,9 +65,9 @@ async def run_feedback_enrichment_memify(last_n: int = 5):


async def main():
# await initialize_conversation_and_graph(CONVERSATION)
# is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?")
is_correct = False
await initialize_conversation_and_graph(CONVERSATION)
is_correct = await run_question_and_submit_feedback("Who told Bob to bring the donuts?")
# is_correct = False
if not is_correct:
await run_feedback_enrichment_memify(last_n=5)

Expand Down