Skip to content

Commit 4e8845c

Browse files
authored
chore: retriever test reorganization + adding new tests (integration) (STEP 1) (#1881)
<!-- .github/pull_request_template.md --> ## Description This PR restructures/adds integration and unit tests for the retrieval module. -Old integration tests were updated and moved under unit tests + fixtures added -Added missing unit tests for all core retrieval business logic -Covered 100% of the core retrievers with tests -Minor changes (dead code deletion, typo fixed) ## Type of Change <!-- Please check the relevant option --> - [ ] Bug fix (non-breaking change that fixes an issue) - [x] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) <!-- Add screenshots or videos to help explain your changes --> ## Pre-submission Checklist <!-- Please check all boxes that apply before submitting your PR --> - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Changes** * TripletRetriever now returns up to 5 results by default (was 1), providing richer context. 
* **Tests** * Reorganized test coverage: many unit tests removed and replaced with comprehensive integration tests across retrieval components (graph, chunks, RAG, summaries, temporal, triplets, structured output). * **Chores** * Simplified triplet formatting logic and removed debug output. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 78028b8 commit 4e8845c

23 files changed

+1889
-2304
lines changed

cognee/modules/retrieval/triplet_retriever.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(
3636
"""Initialize retriever with optional custom prompt paths."""
3737
self.user_prompt_path = user_prompt_path
3838
self.system_prompt_path = system_prompt_path
39-
self.top_k = top_k if top_k is not None else 1
39+
self.top_k = top_k if top_k is not None else 5
4040
self.system_prompt = system_prompt
4141

4242
async def get_context(self, query: str) -> str:

cognee/modules/retrieval/utils/brute_force_triplet_search.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,6 @@
1616

1717

1818
def format_triplets(edges):
19-
print("\n\n\n")
20-
21-
def filter_attributes(obj, attributes):
22-
"""Helper function to filter out non-None properties, including nested dicts."""
23-
result = {}
24-
for attr in attributes:
25-
value = getattr(obj, attr, None)
26-
if value is not None:
27-
# If the value is a dict, extract relevant keys from it
28-
if isinstance(value, dict):
29-
nested_values = {
30-
k: v for k, v in value.items() if k in attributes and v is not None
31-
}
32-
result[attr] = nested_values
33-
else:
34-
result[attr] = value
35-
return result
36-
3719
triplets = []
3820
for edge in edges:
3921
node1 = edge.node1
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
import os
2+
import pytest
3+
import pathlib
4+
import pytest_asyncio
5+
from typing import List
6+
import cognee
7+
8+
from cognee.low_level import setup
9+
from cognee.tasks.storage import add_data_points
10+
from cognee.infrastructure.databases.vector import get_vector_engine
11+
from cognee.modules.chunking.models import DocumentChunk
12+
from cognee.modules.data.processing.document_types import TextDocument
13+
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
14+
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
15+
from cognee.infrastructure.engine import DataPoint
16+
from cognee.modules.data.processing.document_types import Document
17+
from cognee.modules.engine.models import Entity
18+
19+
20+
# Data point model declaring document-chunk fields plus a ``contains`` entity
# list. Within this file it is only used as the ``payload_schema`` argument
# when the "DocumentChunk_text" collection is created.
class DocumentChunkWithEntities(DataPoint):
    text: str
    chunk_size: int
    chunk_index: int
    cut_type: str
    is_part_of: Document
    # NOTE(review): annotated as a list but defaults to None; confirm how the
    # schema consumers treat this before tightening the annotation.
    contains: List[Entity] = None

    # Declares "text" as the only entry under "index_fields".
    metadata: dict = {"index_fields": ["text"]}
29+
30+
31+
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_simple():
    """Provide a clean cognee environment seeded with three chunks of one document.

    Before the test: points cognee's system and data root directories at
    test-specific folders (resolved four directory levels above this file),
    prunes existing data and system state, runs ``setup()``, and stores three
    ``DocumentChunk`` objects that all belong to a single ``TextDocument``.

    After the test: prunes data and system state again. Cleanup is
    best-effort; a failure is logged rather than raised.
    """
    import logging

    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_simple")
    data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_simple")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    # Start from an empty state so leftovers from earlier runs cannot leak in.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    document = TextDocument(
        name="Steve Rodger's career",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )

    # One chunk per name; chunk_index follows the position in this tuple.
    chunk_texts = ("Steve Rodger", "Mike Broski", "Christina Mayer")
    chunks = [
        DocumentChunk(
            text=chunk_text,
            chunk_size=2,
            chunk_index=chunk_index,
            cut_type="sentence_end",
            is_part_of=document,
            contains=[],
        )
        for chunk_index, chunk_text in enumerate(chunk_texts)
    ]

    await add_data_points(chunks)

    yield

    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        # Keep teardown best-effort so it never fails a test that already ran,
        # but record the failure instead of discarding it silently.
        logging.getLogger(__name__).warning(
            "Failed to prune cognee state during fixture teardown", exc_info=True
        )
88+
89+
90+
@pytest_asyncio.fixture
async def setup_test_environment_with_chunks_complex():
    """Provide a clean cognee environment seeded with six chunks across two documents.

    Before the test: points cognee's system and data root directories at
    test-specific folders (resolved four directory levels above this file),
    prunes existing data and system state, runs ``setup()``, and stores three
    ``DocumentChunk`` objects for an "Employee List" document followed by
    three for a "Car List" document.

    After the test: prunes data and system state again. Cleanup is
    best-effort; a failure is logged rather than raised.
    """
    import logging

    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_complex")
    data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_complex")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    # Start from an empty state so leftovers from earlier runs cannot leak in.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    employee_document = TextDocument(
        name="Employee List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )

    car_document = TextDocument(
        name="Car List",
        raw_data_location="somewhere",
        external_metadata="",
        mime_type="text/plain",
    )

    # Chunk texts grouped by owning document; chunk_index restarts at 0 for
    # each document, and employee chunks are created before car chunks.
    texts_by_document = (
        (employee_document, ("Steve Rodger", "Mike Broski", "Christina Mayer")),
        (car_document, ("Range Rover", "Hyundai", "Chrysler")),
    )

    chunks = []
    for owning_document, chunk_texts in texts_by_document:
        for chunk_index, chunk_text in enumerate(chunk_texts):
            chunks.append(
                DocumentChunk(
                    text=chunk_text,
                    chunk_size=2,
                    chunk_index=chunk_index,
                    cut_type="sentence_end",
                    is_part_of=owning_document,
                    contains=[],
                )
            )

    await add_data_points(chunks)

    yield

    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        # Keep teardown best-effort so it never fails a test that already ran,
        # but record the failure instead of discarding it silently.
        logging.getLogger(__name__).warning(
            "Failed to prune cognee state during fixture teardown", exc_info=True
        )
179+
180+
181+
@pytest_asyncio.fixture
async def setup_test_environment_empty():
    """Provide a clean cognee environment that contains no chunks.

    Before the test: points cognee's system and data root directories at
    test-specific folders (resolved four directory levels above this file)
    and prunes existing data and system state. Unlike the seeded fixtures in
    this module, ``setup()`` is not called and no data points are added.

    After the test: prunes data and system state again. Cleanup is
    best-effort; a failure is logged rather than raised.
    """
    import logging

    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
    system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_empty")
    data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_empty")

    cognee.config.system_root_directory(system_directory_path)
    cognee.config.data_root_directory(data_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    yield

    try:
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
    except Exception:
        # Keep teardown best-effort so it never fails a test that already ran,
        # but record the failure instead of discarding it silently.
        logging.getLogger(__name__).warning(
            "Failed to prune cognee state during fixture teardown", exc_info=True
        )
201+
202+
203+
@pytest.mark.asyncio
async def test_chunks_retriever_context_multiple_chunks(setup_test_environment_with_chunks_simple):
    """Integration test: a default ChunksRetriever finds the "Steve Rodger" chunk.

    Runs against the three-chunk fixture and checks that the context returned
    for the query "Steve" is a non-empty list containing that chunk's text.
    """
    context = await ChunksRetriever().get_context("Steve")

    assert isinstance(context, list), "Context should be a list"
    assert len(context) > 0, "Context should not be empty"

    retrieved_texts = [chunk["text"] for chunk in context]
    assert "Steve Rodger" in retrieved_texts, "Failed to get Steve Rodger chunk"
215+
216+
217+
@pytest.mark.asyncio
async def test_chunks_retriever_top_k_limit(setup_test_environment_with_chunks_complex):
    """Integration test: ChunksRetriever returns no more than ``top_k`` chunks.

    Runs against the six-chunk fixture with ``top_k=2`` and checks that the
    context for the query "Employee" is a list of at most two items.
    """
    limited_retriever = ChunksRetriever(top_k=2)
    context = await limited_retriever.get_context("Employee")

    assert isinstance(context, list), "Context should be a list"
    # Upper bound only: an empty result also satisfies this check.
    assert len(context) <= 2, "Should respect top_k limit"
226+
227+
228+
@pytest.mark.asyncio
async def test_chunks_retriever_context_complex(setup_test_environment_with_chunks_complex):
    """Integration test: the first result for "Christina" is the matching chunk.

    Runs against the six-chunk fixture with ``top_k=20`` and checks that the
    first item of the returned context has the text "Christina Mayer".
    """
    retriever = ChunksRetriever(top_k=20)

    context = await retriever.get_context("Christina")

    # Fail with a readable message instead of an IndexError when nothing is returned.
    assert len(context) > 0, "Context should not be empty"
    assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer"
236+
237+
238+
@pytest.mark.asyncio
async def test_chunks_retriever_context_on_empty_graph(setup_test_environment_empty):
    """Integration test: ChunksRetriever behaviour when no chunks are stored.

    Two phases:
    1. In the pruned environment, ``get_context`` must raise ``NoDataError``.
    2. After the "DocumentChunk_text" collection is created (still holding no
       data), ``get_context`` must return an empty result instead of raising.
    """
    query = "Christina Mayer"
    retriever = ChunksRetriever()

    with pytest.raises(NoDataError):
        await retriever.get_context(query)

    # Create the collection without inserting anything into it.
    await get_vector_engine().create_collection(
        "DocumentChunk_text", payload_schema=DocumentChunkWithEntities
    )

    context = await retriever.get_context(query)
    assert len(context) == 0, "Found chunks when none should exist"

0 commit comments

Comments
 (0)