Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions cognee/tests/test_relational_db_migration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pathlib
import os
from typing import List
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import (
get_migration_relational_engine,
Expand All @@ -10,7 +9,7 @@
create_db_and_tables as create_pgvector_db_and_tables,
)
from cognee.tasks.ingestion import migrate_relational_database
from cognee.modules.search.types import SearchResult, SearchType
from cognee.modules.search.types import SearchType
import cognee


Expand Down Expand Up @@ -274,6 +273,55 @@ async def test_schema_only_migration():
print(f"Edge counts: {edge_counts}")


async def test_search_result_quality():
from cognee.infrastructure.databases.relational import (
get_migration_relational_engine,
)

# Get relational database with original data
migration_engine = get_migration_relational_engine()
from sqlalchemy import text

async with migration_engine.engine.connect() as conn:
result = await conn.execute(
text("""
SELECT
c.CustomerId,
c.FirstName,
c.LastName,
GROUP_CONCAT(i.InvoiceId, ',') AS invoice_ids
FROM Customer AS c
LEFT JOIN Invoice AS i ON c.CustomerId = i.CustomerId
GROUP BY c.CustomerId, c.FirstName, c.LastName
""")
)

for row in result:
# Get expected invoice IDs from relational DB for each Customer
customer_id = row.CustomerId
invoice_ids = row.invoice_ids.split(",") if row.invoice_ids else []
print(f"Relational DB Customer {customer_id}: {invoice_ids}")

# Use Cognee search to get invoice IDs for the same Customer but by providing Customer name
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text=f"List me all the invoices of Customer:{row.FirstName} {row.LastName}.",
top_k=50,
system_prompt="Just return me the invoiceID as a number without any text. This is an example output: ['1', '2', '3']. Where 1, 2, 3 are invoiceIDs of an invoice",
)
print(f"Cognee search result: {search_results}")

import ast

lst = ast.literal_eval(search_results[0]) # converts string -> Python list
# Transfrom both lists to int for comparison, sorting and type consistency
lst = sorted([int(x) for x in lst])
invoice_ids = sorted([int(x) for x in invoice_ids])
assert lst == invoice_ids, (
f"Search results {lst} do not match expected invoice IDs {invoice_ids} for Customer:{customer_id}"
)


async def test_migration_sqlite():
database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/")

Expand All @@ -286,6 +334,7 @@ async def test_migration_sqlite():
)

await relational_db_migration()
await test_search_result_quality()
await test_schema_only_migration()


Expand Down
Loading