Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions cognee/modules/search/utils/prepare_search_result.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, cast
from uuid import uuid5, NAMESPACE_OID
import json
from typing import Any, List, Tuple, cast
from uuid import NAMESPACE_OID, uuid5

from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
Expand All @@ -8,9 +9,33 @@
from cognee.modules.search.utils.transform_insights_to_graph import transform_insights_to_graph


def _normalize_tuple_rows(rows: List[Tuple[Any, ...]]) -> List[Tuple[Any, ...]]:
"""Convert tuple rows returned by graph queries into JSON dict tuples."""
normalized: List[Tuple[Any, ...]] = []
for row in rows:
normalized_row: list[Any] = []
for column in row:
if isinstance(column, dict):
normalized_row.append(column)
elif isinstance(column, str):
try:
normalized_row.append(json.loads(column))
except json.JSONDecodeError:
normalized_row.append({"value": column})
else:
normalized_row.append(column)
normalized.append(tuple(normalized_row))
return normalized


async def prepare_search_result(search_result):
results, context, datasets = search_result

if isinstance(context, list) and context and isinstance(context[0], tuple):
context = _normalize_tuple_rows(context)
if isinstance(results, list) and results and isinstance(results[0], tuple):
results = _normalize_tuple_rows(results)

graphs = None
result_graph = None
context_texts = {}
Expand All @@ -27,6 +52,8 @@ async def prepare_search_result(search_result):
isinstance(context, List)
and len(context) > 0
and isinstance(context[0], tuple)
and len(context[0]) > 1
and isinstance(context[0][1], dict)
and context[0][1].get("relationship_name")
):
context_graph = transform_insights_to_graph(context)
Expand All @@ -35,13 +62,16 @@ async def prepare_search_result(search_result):
}
results = None
elif isinstance(context, List) and len(context) > 0 and isinstance(context[0], Edge):
context_graph = transform_context_to_graph(context)
edge_context = cast(List[Edge], context)
context_graph = transform_context_to_graph(edge_context)

graphs = {
", ".join([dataset.name for dataset in datasets]): context_graph,
}
context_texts = {
", ".join([dataset.name for dataset in datasets]): await resolve_edges_to_text(context),
", ".join([dataset.name for dataset in datasets]): await resolve_edges_to_text(
edge_context
),
}
elif isinstance(context, str):
context_texts = {
Expand All @@ -53,7 +83,8 @@ async def prepare_search_result(search_result):
}

if isinstance(results, List) and len(results) > 0 and isinstance(results[0], Edge):
result_graph = transform_context_to_graph(results)
edge_results = cast(List[Edge], results)
result_graph = transform_context_to_graph(edge_results)

return {
"result": result_graph or results[0] if results and len(results) == 1 else results,
Expand Down