Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 137 additions & 59 deletions cognee-frontend/src/ui/elements/Notebook/Notebook.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -244,67 +244,122 @@ function CellResult({ content }: { content: [] }) {
for (const line of content) {
try {
if (Array.isArray(line)) {
// Insights search returns uncommon graph data structure
if (Array.from(line).length > 0 && Array.isArray(line[0]) && line[0][1]["relationship_name"]) {
parsedContent.push(
<div key={line[0][1]["relationship_name"]} className="w-full h-full bg-white">
<span className="text-sm pl-2 mb-4">reasoning graph</span>
<GraphVisualization
data={transformInsightsGraphData(line)}
ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
graphControls={graphControls}
className="min-h-48"
/>
</div>
);
continue;
}

// @ts-expect-error line can be Array or string
for (const item of line) {
if (typeof item === "string") {
if (
typeof item === "object" && item["search_result"] && (typeof(item["search_result"]) === "string"
|| (Array.isArray(item["search_result"]) && typeof(item["search_result"][0]) === "string"))
) {
parsedContent.push(
<pre key={item.slice(0, -10)}>
{item}
</pre>
);
}
if (typeof item === "object" && item["search_result"]) {
parsedContent.push(
<div className="w-full h-full bg-white">
<div key={String(item["search_result"])} className="w-full h-full bg-white">
<span className="text-sm pl-2 mb-4">query response (dataset: {item["dataset_name"]})</span>
<span className="block px-2 py-2">{item["search_result"]}</span>
<span className="block px-2 py-2 whitespace-normal">{item["search_result"]}</span>
</div>
);
}
if (typeof item === "object" && item["graph"] && typeof item["graph"] === "object") {
} else if (typeof(item) === "object" && item["search_result"] && typeof(item["search_result"]) === "object") {
parsedContent.push(
<div className="w-full h-full bg-white">
<span className="text-sm pl-2 mb-4">reasoning graph</span>
<GraphVisualization
data={transformToVisualizationData(item["graph"])}
ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
graphControls={graphControls}
className="min-h-48"
/>
</div>
<pre className="px-2 w-full h-full bg-white text-sm" key={String(item).slice(0, -10)}>
{JSON.stringify(item, null, 2)}
</pre>
)
} else if (typeof(item) === "string") {
parsedContent.push(
<pre className="px-2 w-full h-full bg-white text-sm whitespace-normal" key={item.slice(0, -10)}>
{item}
</pre>
);
} else if (typeof(item) === "object" && !(item["search_result"] || item["graphs"])) {
parsedContent.push(
<pre className="px-2 w-full h-full bg-white text-sm" key={String(item).slice(0, -10)}>
{JSON.stringify(item, null, 2)}
</pre>
)
}

if (typeof item === "object" && item["graphs"] && typeof item["graphs"] === "object") {
Object.entries<{ nodes: []; edges: []; }>(item["graphs"]).forEach(([datasetName, graph]) => {
parsedContent.push(
<div key={datasetName} className="w-full h-full bg-white">
<span className="text-sm pl-2 mb-4">reasoning graph (datasets: {datasetName})</span>
<GraphVisualization
data={transformToVisualizationData(graph)}
ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
graphControls={graphControls}
className="min-h-80"
/>
</div>
);
});
}
}
}
if (typeof(line) === "object" && line["result"]) {

if (typeof(line) === "object" && line["result"] && typeof(line["result"]) === "string") {
const datasets = Array.from(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
new Set(Object.values(line["datasets"]).map((dataset: any) => dataset.name))
).join(", ");

parsedContent.push(
<div className="w-full h-full bg-white">
<div key={line["result"]} className="w-full h-full bg-white">
<span className="text-sm pl-2 mb-4">query response (datasets: {datasets})</span>
<span className="block px-2 py-2">{line["result"]}</span>
<span className="block px-2 py-2 whitespace-normal">{line["result"]}</span>
</div>
);
if (line["graphs"]) {
}
if (typeof(line) === "object" && line["graphs"]) {
Object.entries<{ nodes: []; edges: []; }>(line["graphs"]).forEach(([datasetName, graph]) => {
parsedContent.push(
<div className="w-full h-full bg-white">
<span className="text-sm pl-2 mb-4">reasoning graph</span>
<div key={datasetName} className="w-full h-full bg-white">
<span className="text-sm pl-2 mb-4">reasoning graph (datasets: {datasetName})</span>
<GraphVisualization
data={transformToVisualizationData(line["graphs"]["*"])}
data={transformToVisualizationData(graph)}
ref={graphRef as MutableRefObject<GraphVisualizationAPI>}
graphControls={graphControls}
className="min-h-80"
/>
</div>
);
}
});
}

if (typeof(line) === "object" && line["result"] && typeof(line["result"]) === "object") {
parsedContent.push(
<pre className="px-2 w-full h-full bg-white text-sm" key={String(line).slice(0, -10)}>
{JSON.stringify(line["result"], null, 2)}
</pre>
)
}
if (typeof(line) === "string") {
parsedContent.push(
<pre className="px-2 w-full h-full bg-white text-sm whitespace-normal" key={String(line).slice(0, -10)}>
{line}
</pre>
)
}
} catch (error) {
console.error(error);
parsedContent.push(line);
parsedContent.push(
<pre className="px-2 w-full h-full bg-white text-sm whitespace-normal" key={String(line).slice(0, -10)}>
{line}
</pre>
);
}
}

Expand All @@ -317,38 +372,61 @@ function CellResult({ content }: { content: [] }) {
};

function transformToVisualizationData(graph: { nodes: [], edges: [] }) {
// Implementation to transform triplet to visualization data

return {
nodes: graph.nodes,
links: graph.edges,
};
}

type Triplet = [{
id: string,
name: string,
type: string,
}, {
relationship_name: string,
}, {
id: string,
name: string,
type: string,
}]

function transformInsightsGraphData(triplets: Triplet[]) {
const nodes: {
[key: string]: {
id: string,
label: string,
type: string,
}
} = {};
const links: {
[key: string]: {
source: string,
target: string,
label: string,
}
} = {};

// const nodes = {};
// const links = {};

// for (const triplet of triplets) {
// nodes[triplet.source.id] = {
// id: triplet.source.id,
// label: triplet.source.attributes.name,
// type: triplet.source.attributes.type,
// attributes: triplet.source.attributes,
// };
// nodes[triplet.destination.id] = {
// id: triplet.destination.id,
// label: triplet.destination.attributes.name,
// type: triplet.destination.attributes.type,
// attributes: triplet.destination.attributes,
// };
// links[`${triplet.source.id}_${triplet.attributes.relationship_name}_${triplet.destination.id}`] = {
// source: triplet.source.id,
// target: triplet.destination.id,
// label: triplet.attributes.relationship_name,
// }
// }

// return {
// nodes: Object.values(nodes),
// links: Object.values(links),
// };
for (const triplet of triplets) {
nodes[triplet[0].id] = {
id: triplet[0].id,
label: triplet[0].name || triplet[0].id,
type: triplet[0].type,
};
nodes[triplet[2].id] = {
id: triplet[2].id,
label: triplet[2].name || triplet[2].id,
type: triplet[2].type,
};
const linkKey = `${triplet[0]["id"]}_${triplet[1]["relationship_name"]}_${triplet[2]["id"]}`;
links[linkKey] = {
source: triplet[0].id,
target: triplet[2].id,
label: triplet[1]["relationship_name"],
};
}

return {
nodes: Object.values(nodes),
links: Object.values(links),
};
}
2 changes: 1 addition & 1 deletion cognee/api/v1/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async def search(
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
last_k: Optional[int] = None,
last_k: Optional[int] = 1,
only_context: bool = False,
use_combined_context: bool = False,
) -> Union[List[SearchResult], CombinedSearchResult]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,8 @@ async def _get_embedding(self, prompt: str) -> List[float]:
"""
Internal method to call the Ollama embeddings endpoint for a single prompt.
"""
payload = {
"model": self.model,
"prompt": prompt,
"input": prompt
}
payload = {"model": self.model, "prompt": prompt, "input": prompt}

headers = {}
api_key = os.getenv("LLM_API_KEY")
if api_key:
Expand Down
25 changes: 12 additions & 13 deletions cognee/modules/search/methods/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,19 @@ async def search(
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
return_value = []
for search_result in search_results:
result, context, datasets = search_result
prepared_search_results = await prepare_search_result(search_result)

result = prepared_search_results["result"]
graphs = prepared_search_results["graphs"]
context = prepared_search_results["context"]
datasets = prepared_search_results["datasets"]

return_value.append(
{
"search_result": result,
"search_result": [result] if result else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"graphs": graphs,
}
)
return return_value
Expand All @@ -155,14 +162,6 @@ async def search(
return return_value[0]
else:
return return_value
# return [
# SearchResult(
# search_result=result,
# dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
# dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
# )
# for index, (result, _, datasets) in enumerate(search_results)
# ]


async def authorized_search(
Expand Down Expand Up @@ -208,11 +207,11 @@ async def authorized_search(
context = {}
datasets: List[Dataset] = []

for _, search_context, datasets in search_responses:
for dataset in datasets:
for _, search_context, search_datasets in search_responses:
for dataset in search_datasets:
context[str(dataset.id)] = search_context

datasets.extend(datasets)
datasets.extend(search_datasets)

specific_search_tools = await get_search_type_tools(
query_type=query_type,
Expand Down
34 changes: 28 additions & 6 deletions cognee/modules/search/utils/prepare_search_result.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import List, cast
from uuid import uuid5, NAMESPACE_OID

from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.search.types.SearchResult import SearchResultDataset
from cognee.modules.search.utils.transform_context_to_graph import transform_context_to_graph
from cognee.modules.search.utils.transform_insights_to_graph import transform_insights_to_graph


async def prepare_search_result(search_result):
Expand All @@ -12,29 +15,48 @@ async def prepare_search_result(search_result):
result_graph = None
context_texts = {}

if isinstance(context, List) and len(context) > 0 and isinstance(context[0], Edge):
if isinstance(datasets, list) and len(datasets) == 0:
datasets = [
SearchResultDataset(
id=uuid5(NAMESPACE_OID, "*"),
name="all available datasets",
)
]

if (
isinstance(context, List)
and len(context) > 0
and isinstance(context[0], tuple)
and context[0][1].get("relationship_name")
):
context_graph = transform_insights_to_graph(context)
graphs = {
", ".join([dataset.name for dataset in datasets]): context_graph,
}
results = None
elif isinstance(context, List) and len(context) > 0 and isinstance(context[0], Edge):
context_graph = transform_context_to_graph(context)

graphs = {
"*": context_graph,
", ".join([dataset.name for dataset in datasets]): context_graph,
}
context_texts = {
"*": await resolve_edges_to_text(context),
", ".join([dataset.name for dataset in datasets]): await resolve_edges_to_text(context),
}
elif isinstance(context, str):
context_texts = {
"*": context,
", ".join([dataset.name for dataset in datasets]): context,
}
elif isinstance(context, List) and len(context) > 0 and isinstance(context[0], str):
context_texts = {
"*": "\n".join(cast(List[str], context)),
", ".join([dataset.name for dataset in datasets]): "\n".join(cast(List[str], context)),
}

if isinstance(results, List) and len(results) > 0 and isinstance(results[0], Edge):
result_graph = transform_context_to_graph(results)

return {
"result": result_graph or results[0] if len(results) == 1 else results,
"result": result_graph or results[0] if results and len(results) == 1 else results,
"graphs": graphs,
"context": context_texts,
"datasets": datasets,
Expand Down
2 changes: 1 addition & 1 deletion cognee/modules/search/utils/transform_context_to_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def transform_context_to_graph(context: List[Edge]):
if "name" in triplet.node1.attributes
else triplet.node1.id,
"type": triplet.node1.attributes["type"],
"attributes": triplet.node2.attributes,
"attributes": triplet.node1.attributes,
}
nodes[triplet.node2.id] = {
"id": triplet.node2.id,
Expand Down
Loading
Loading