diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py index a521b316b0..b5c7a5a00b 100644 --- a/cognee/api/v1/add/add.py +++ b/cognee/api/v1/add/add.py @@ -26,6 +26,7 @@ async def add( preferred_loaders: Optional[List[Union[str, dict[str, dict[str, Any]]]]] = None, incremental_loading: bool = True, data_per_batch: Optional[int] = 20, + importance_weight: float = 0.5, ): """ Add data to Cognee for knowledge graph processing. @@ -85,6 +86,9 @@ async def add( extraction_rules: Optional dictionary of rules (e.g., CSS selectors, XPath) for extracting specific content from web pages using BeautifulSoup tavily_config: Optional configuration for Tavily API, including API key and extraction settings soup_crawler_config: Optional configuration for BeautifulSoup crawler, specifying concurrency, crawl delay, and extraction rules. + importance_weight: A float between 0.0 and 1.0 representing the importance of the + ingested data. This weight will influence search result ranking. + Defaults to 0.5. Returns: PipelineRunInfo: Information about the ingestion pipeline execution including: @@ -164,6 +168,9 @@ async def add( - TAVILY_API_KEY: YOUR_TAVILY_API_KEY """ + if not 0.0 <= importance_weight <= 1.0: + raise ValueError("importance_weight must be a float between 0.0 and 1.0") + if preferred_loaders is not None: transformed = {} for item in preferred_loaders: @@ -182,6 +189,7 @@ async def add( node_set, dataset_id, preferred_loaders, + importance_weight, ), ] diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index 812380eaa5..77dc3e9c5f 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -1,6 +1,6 @@ import pickle from uuid import UUID, uuid4 -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, Field, ConfigDict, field_validator from datetime import datetime, timezone from typing_extensions import TypedDict from typing import Optional, Any, Dict, List @@ -49,6 +49,14 @@ class DataPoint(BaseModel): metadata: Optional[MetaData] = {"index_fields": []} type: str = Field(default_factory=lambda: DataPoint.__name__) belongs_to_set: Optional[List["DataPoint"]] = None + importance_weight: Optional[float] = Field(default=0.5, ge=0.0, le=1.0) + + @field_validator('importance_weight', mode='before') + @classmethod + def set_default_weight_on_none(cls, v): + if v is None: + return 0.5 + return v def __init__(self, **data): super().__init__(**data) diff --git a/cognee/modules/chunking/LangchainChunker.py b/cognee/modules/chunking/LangchainChunker.py index 849e51fd99..87f1ae0eae 100644 --- a/cognee/modules/chunking/LangchainChunker.py +++ b/cognee/modules/chunking/LangchainChunker.py @@ -48,6 +48,7 @@ async def read(self): chunk_index=self.chunk_index, cut_type="missing", contains=[], + importance_weight=self.document.importance_weight, metadata={ "index_fields": ["text"], }, diff --git a/cognee/modules/chunking/TextChunker.py b/cognee/modules/chunking/TextChunker.py index f7b2032545..a434f3f759 100644 --- a/cognee/modules/chunking/TextChunker.py +++ b/cognee/modules/chunking/TextChunker.py @@ -30,6 +30,7 @@ async def read(self): chunk_index=self.chunk_index, cut_type=chunk_data["cut_type"], contains=[], + importance_weight=self.document.importance_weight, metadata={ "index_fields": ["text"], }, @@ -49,6 +50,7 @@ async def read(self): chunk_index=self.chunk_index, cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"], contains=[], + 
importance_weight=self.document.importance_weight, metadata={ "index_fields": ["text"], }, @@ -71,6 +73,7 @@ async def read(self): chunk_index=self.chunk_index, cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"], contains=[], + importance_weight=self.document.importance_weight, metadata={"index_fields": ["text"]}, ) except Exception as e: diff --git a/cognee/modules/data/models/Data.py b/cognee/modules/data/models/Data.py index ef228f2e10..fe6d43b3b3 100644 --- a/cognee/modules/data/models/Data.py +++ b/cognee/modules/data/models/Data.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone from uuid import uuid4 -from sqlalchemy import UUID, Column, DateTime, String, JSON, Integer +from sqlalchemy import UUID, Column, DateTime, String, JSON, Integer, Float from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import relationship @@ -34,9 +34,11 @@ class Data(Base): pipeline_status = Column(MutableDict.as_mutable(JSON)) token_count = Column(Integer) data_size = Column(Integer, nullable=True) # File size in bytes + importance_weight = Column(Float, nullable=False, default=0.5) created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)) + datasets = relationship( "Dataset", secondary=DatasetData.__tablename__, @@ -55,5 +57,6 @@ def to_json(self) -> dict: "createdAt": self.created_at.isoformat(), "updatedAt": self.updated_at.isoformat() if self.updated_at else None, "nodeSet": self.node_set, + "importanceWeight": self.importance_weight, # "datasets": [dataset.to_json() for dataset in self.datasets] } diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 2e0b82e8d8..5737529123 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -205,11 +205,24 @@ async def map_vector_distances_to_graph_nodes(self, node_distances) -> None: for category, scored_results in node_distances.items(): for scored_result in scored_results: node_id = str(scored_result.id) - score = scored_result.score node = self.get_node(node_id) - if node: - node.add_attribute("vector_distance", score) - mapped_nodes += 1 + if not node: + continue + + # convert vector_distance to a similarity score + vector_distance = scored_result.score + vector_score = 1 - vector_distance + + # if importance_weight is missing, fall back to the 0.5 default + importance_weight = node.attributes.get("importance_weight", 0.5) + + final_score = vector_score * importance_weight + + node.add_attribute("vector_distance", vector_distance) + node.add_attribute("importance_weight", importance_weight) + node.add_attribute("importance_score", final_score) + + mapped_nodes += 1 async def map_vector_distances_to_graph_edges( self, vector_engine, query_vector, edge_distances @@ -238,17 +251,32 @@ async def map_vector_distances_to_graph_edges( ) distance = embedding_map.get(edge_key, None) if distance is not None: - edge.attributes["vector_distance"] = distance + vector_score = 1 - distance + else: + vector_score = 0 + + # edges carry no ingested weight here, so fall back to a neutral 1.0 + importance_weight = edge.attributes.get("importance_weight", 1.0) + + final_score = vector_score * importance_weight + + edge.add_attribute("vector_distance", distance) + edge.add_attribute("importance_weight", importance_weight) + edge.add_attribute("importance_score", final_score) except Exception as ex: logger.error(f"Error mapping vector distances to edges: {str(ex)}") raise ex
async def calculate_top_triplet_importances(self, k: int) -> List[Edge]: - def score(edge): - n1 = edge.node1.attributes.get("vector_distance", 1) - n2 = edge.node2.attributes.get("vector_distance", 1) - e = edge.attributes.get("vector_distance", 1) + """ + Rank triplets using merged importance_score: + importance_score = vector_similarity * importance_weight + """ + def final_score(edge: Edge): + n1 = edge.node1.attributes.get("importance_score", 0) + n2 = edge.node2.attributes.get("importance_score", 0) + e = edge.attributes.get("importance_score", 0) return n1 + n2 + e - return heapq.nsmallest(k, self.edges, key=score) + return heapq.nlargest(k, self.edges, key=final_score) diff --git a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py index c68eb494d4..8a79cfacf3 100644 --- a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py +++ b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py @@ -47,6 +47,7 @@ def _process_ontology_nodes( name=ont_node_name, description=ont_node_name, ontology_valid=True, + importance_weight=data_chunk.importance_weight, ) elif ontology_node.category == "individuals": @@ -58,11 +59,12 @@ def _process_ontology_nodes( description=ont_node_name, ontology_valid=True, belongs_to_set=data_chunk.belongs_to_set, + importance_weight=data_chunk.importance_weight, ) def _process_ontology_edges( - ontology_edges: list, existing_edges_map: dict, ontology_relationships: list + ontology_edges: list, existing_edges_map: dict, ontology_relationships: list,data_chunk: DocumentChunk, ) -> None: """Process ontology edges and add them if new""" for source, relation, target in ontology_edges: @@ -82,6 +84,7 @@ def _process_ontology_edges( "source_node_id": source_node_id, "target_node_id": target_node_id, "ontology_valid": True, + "importance_weight": data_chunk.importance_weight, }, ) ) @@ -132,13 +135,14 @@ def _create_type_node( type=node_name, description=node_name, ontology_valid=ontology_validated, + importance_weight=data_chunk.importance_weight, ) added_nodes_map[type_node_key] = type_node # Process ontology nodes and edges _process_ontology_nodes(ontology_nodes, data_chunk, added_nodes_map, added_ontology_nodes_map) - _process_ontology_edges(ontology_edges, existing_edges_map, ontology_relationships) + _process_ontology_edges(ontology_edges, existing_edges_map, ontology_relationships,data_chunk) return type_node @@ -191,13 +195,14 @@ def _create_entity_node( description=node_description, ontology_valid=ontology_validated, belongs_to_set=data_chunk.belongs_to_set, + importance_weight=data_chunk.importance_weight, ) added_nodes_map[entity_node_key] = entity_node # Process ontology nodes and edges _process_ontology_nodes(ontology_nodes, data_chunk, added_nodes_map, added_ontology_nodes_map) - _process_ontology_edges(ontology_edges, existing_edges_map, ontology_relationships) + _process_ontology_edges(ontology_edges, existing_edges_map, ontology_relationships,data_chunk) return entity_node @@ -267,8 +272,8 @@ def _process_graph_nodes( def _process_graph_edges( - graph: KnowledgeGraph, name_mapping: dict, existing_edges_map: dict, relationships: list -) -> None: + graph: KnowledgeGraph, name_mapping: dict, existing_edges_map: dict, relationships: list, + data_chunk: DocumentChunk) -> None: """Process edges in a knowledge graph""" for edge in graph.edges: # Apply name mapping if exists @@ -291,6 +296,7 @@ def _process_graph_edges( "source_node_id": source_node_id, "target_node_id": target_node_id, 
"ontology_valid": False, + "importance_weight": data_chunk.importance_weight, }, ) ) @@ -379,7 +385,7 @@ def expand_with_nodes_and_edges( ) # Then process edges - _process_graph_edges(graph, name_mapping, existing_edges_map, relationships) + _process_graph_edges(graph, name_mapping, existing_edges_map, relationships,data_chunk) # Return combined results graph_nodes = data_chunks + list(added_ontology_nodes_map.values()) diff --git a/cognee/modules/memify/memify.py b/cognee/modules/memify/memify.py index 2d9b32a1ba..d3fd128d83 100644 --- a/cognee/modules/memify/memify.py +++ b/cognee/modules/memify/memify.py @@ -21,6 +21,7 @@ from cognee.tasks.codingagents.coding_rule_associations import ( add_rule_associations, ) +from cognee.tasks.memify.propagate_importance_weights import propagate_importance_weights logger = get_logger("memify") @@ -69,13 +70,15 @@ async def memify( if not extraction_tasks: extraction_tasks = [Task(extract_subgraph_chunks)] if not enrichment_tasks: - enrichment_tasks = [ + default_enrichment_tasks = [ + Task(propagate_importance_weights), Task( add_rule_associations, rules_nodeset_name="coding_agent_rules", task_config={"batch_size": 1}, - ) + ), ] + enrichment_tasks = default_enrichment_tasks await setup() diff --git a/cognee/modules/retrieval/chunks_retriever.py b/cognee/modules/retrieval/chunks_retriever.py index 94b9d3fb9e..9a3b2bacdc 100644 --- a/cognee/modules/retrieval/chunks_retriever.py +++ b/cognee/modules/retrieval/chunks_retriever.py @@ -24,8 +24,12 @@ class ChunksRetriever(BaseRetriever): def __init__( self, top_k: Optional[int] = 5, + default_importance_weight: float = 0.5, ): self.top_k = top_k + self.default_importance_weight = default_importance_weight + self.candidate = top_k * 10 + self.vector_engine = get_vector_engine() async def get_context(self, query: str) -> Any: """ @@ -48,18 +52,36 @@ async def get_context(self, query: str) -> Any: f"Starting chunk retrieval for query: '{query[:100]}{'...' 
if len(query) > 100 else ''}'" ) - vector_engine = get_vector_engine() + vector_engine = self.vector_engine try: - found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) + found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.candidate) logger.info(f"Found {len(found_chunks)} chunks from vector search") except CollectionNotFoundError as error: logger.error("DocumentChunk_text collection not found in vector database") raise NoDataError("No data found in the system, please add data first.") from error - chunk_payloads = [result.payload for result in found_chunks] - logger.info(f"Returning {len(chunk_payloads)} chunk payloads") - return chunk_payloads + rescored = [] + for item in found_chunks: + payload = item.payload or {} + importance_weight = payload.get("importance_weight", self.default_importance_weight) + + distance_score = item.score if hasattr(item, "score") and item.score is not None else 0.0 + similarity_score = 1 / (1 + distance_score) + final_score = similarity_score * importance_weight + text_preview = payload.get('text', '')[:20] + logger.debug( + f"Chunk: {text_preview:<20} | VecScore: {distance_score:.4f} | Weight: {importance_weight} | Final: {final_score:.4f}") + rescored.append((final_score, payload)) + + # sort descending by final_score + rescored.sort(key=lambda x: x[0], reverse=True) + + # take top_k after re-ranking + top_payloads = [p for (_, p) in rescored[: self.top_k]] + + logger.info(f"Returning {len(top_payloads)} re-ranked chunk payloads") + return top_payloads async def get_completion( self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None diff --git a/cognee/modules/retrieval/lexical_retriever.py b/cognee/modules/retrieval/lexical_retriever.py index 71b50a0b32..ae74259fc3 100644 --- a/cognee/modules/retrieval/lexical_retriever.py +++ b/cognee/modules/retrieval/lexical_retriever.py @@ -7,13 +7,13 @@ from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.shared.logging_utils import get_logger - logger = get_logger("LexicalRetriever") class LexicalRetriever(BaseRetriever): def __init__( - self, tokenizer: Callable, scorer: Callable, top_k: int = 10, with_scores: bool = False + self, tokenizer: Callable, scorer: Callable, top_k: int = 10, with_scores: bool = False, + default_importance_weight: float = 0.5 ): if not callable(tokenizer) or not callable(scorer): raise TypeError("tokenizer and scorer must be callables") @@ -24,6 +24,7 @@ def __init__( self.scorer = scorer self.top_k = top_k self.with_scores = bool(with_scores) + self.default_importance_weight = default_importance_weight # Cache keyed by dataset context self.chunks: dict[str, Any] = {} # {chunk_id: tokens} @@ -31,6 +32,9 @@ def __init__( self._initialized = False self._init_lock = asyncio.Lock() + def add(self, item_id: str, payload: dict): + self.payloads[item_id] = payload + async def initialize(self): """Initialize retriever by reading all DocumentChunks from graph_engine.""" async with self._init_lock: @@ -98,10 +102,16 @@ async def get_context(self, query: str) -> Any: if not isinstance(score, (int, float)): logger.warning("Non-numeric score for chunk %s → treated as 0.0", chunk_id) score = 0.0 + payload = self.payloads.get(chunk_id, {}) + weight = payload.get("importance_weight", self.default_importance_weight) + + if not isinstance(weight, (int, float)): + weight = self.default_importance_weight + final_score = score * weight except Exception as e: logger.error("Scorer failed for 
chunk %s: %s", chunk_id, str(e)) - score = 0.0 - results.append((chunk_id, score)) + final_score = 0.0 + results.append((chunk_id, final_score)) top_results = nlargest(self.top_k, results, key=lambda x: x[1]) logger.info( @@ -117,7 +127,7 @@ async def get_context(self, query: str) -> Any: return [self.payloads[chunk_id] for chunk_id, _ in top_results] async def get_completion( - self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None + self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None ) -> Any: """ Returns context for the given query (retrieves if not provided). diff --git a/cognee/modules/retrieval/summaries_retriever.py b/cognee/modules/retrieval/summaries_retriever.py index 87b2249466..492e950b4c 100644 --- a/cognee/modules/retrieval/summaries_retriever.py +++ b/cognee/modules/retrieval/summaries_retriever.py @@ -22,9 +22,11 @@ class SummariesRetriever(BaseRetriever): - top_k: int - Number of top summaries to retrieve. """ - def __init__(self, top_k: int = 5): + def __init__(self, top_k: int = 5, default_importance_weight: float = 0.5): """Initialize retriever with search parameters.""" self.top_k = top_k + self.candidate = top_k * 10 + self.default_importance_weight = default_importance_weight async def get_context(self, query: str) -> Any: """ @@ -51,19 +53,33 @@ async def get_context(self, query: str) -> Any: try: summaries_results = await vector_engine.search( - "TextSummary_text", query, limit=self.top_k + "TextSummary_text", query, limit=self.candidate ) logger.info(f"Found {len(summaries_results)} summaries from vector search") except CollectionNotFoundError as error: logger.error("TextSummary_text collection not found in vector database") raise NoDataError("No data found in the system, please add data first.") from error - summary_payloads = [summary.payload for summary in summaries_results] - logger.info(f"Returning {len(summary_payloads)} summary payloads") - return summary_payloads + rescored = [] + for item in summaries_results: + payload = item.payload or {} + importance_weight = payload.get("importance_weight", self.default_importance_weight) + + vector_score = item.score if hasattr(item, "score") else 1.0 + final_score = vector_score * importance_weight + + rescored.append((final_score, payload)) + + # sort descending + rescored.sort(key=lambda x: x[0], reverse=True) + + top_payloads = [p for (_, p) in rescored[: self.top_k]] + + logger.info(f"Returning {len(top_payloads)} re-ranked summary payloads") + return top_payloads async def get_completion( - self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs + self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs ) -> Any: """ Generates a completion using summaries context. 
diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 2f8a545f73..44e7c30071 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -63,7 +63,7 @@ async def get_memory_fragment( ) -> CogneeGraph: """Creates and initializes a CogneeGraph memory fragment with optional property projections.""" if properties_to_project is None: - properties_to_project = ["id", "description", "name", "type", "text"] + properties_to_project = ["id", "description", "name", "type", "text","importance_weight"] memory_fragment = CogneeGraph() @@ -124,10 +124,15 @@ async def brute_force_triplet_search( if top_k <= 0: raise ValueError("top_k must be a positive integer.") - # Setting wide search limit based on the parameters - non_global_search = node_name is None + if properties_to_project is None: + properties_to_project = ["id", "description", "name", "type", "text", "importance_weight"] + elif "importance_weight" not in properties_to_project: + properties_to_project.append("importance_weight") - wide_search_limit = wide_search_top_k if non_global_search else None + if memory_fragment is None: + memory_fragment = await get_memory_fragment( + properties_to_project, node_type=node_type, node_name=node_name + ) if collections is None: collections = [ @@ -201,12 +206,36 @@ async def search_in_collection(collection_name: str): vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances ) - results = await memory_fragment.calculate_top_triplet_importances(k=top_k) + expansion_factor = 3 + candidate_k = top_k * expansion_factor + + initial_results = await memory_fragment.calculate_top_triplet_importances(k=candidate_k) + + if not initial_results: + return [] + + scored_candidates = [] + + for index, edge in enumerate(initial_results): + similarity_score = 1.0 / (index + 1) - return results + node1_weight = edge.node1.attributes.get("importance_weight", 0.5) + node2_weight = edge.node2.attributes.get("importance_weight", 0.5) + + importance_score = (node1_weight + node2_weight) / 2 + + final_score = similarity_score * importance_score + + scored_candidates.append((final_score, edge)) + scored_candidates.sort(key=lambda x: x[0], reverse=True) + + final_triplets = [item[1] for item in scored_candidates[:top_k]] + + return final_triplets except CollectionNotFoundError: return [] + except Exception as error: logger.error( "Error during brute force search for query: %s. 
Error: %s", diff --git a/cognee/tasks/ingestion/ingest_data.py b/cognee/tasks/ingestion/ingest_data.py index 5987f38d58..4bc6452c2d 100644 --- a/cognee/tasks/ingestion/ingest_data.py +++ b/cognee/tasks/ingestion/ingest_data.py @@ -20,6 +20,7 @@ from .save_data_item_to_storage import save_data_item_to_storage from .data_item_to_text_file import data_item_to_text_file +from ...shared.data_models import Document async def ingest_data( @@ -29,6 +30,7 @@ async def ingest_data( node_set: Optional[List[str]] = None, dataset_id: UUID = None, preferred_loaders: dict[str, dict[str, Any]] = None, + importance_weight: float = 0.5, ): if not user: user = await get_default_user() @@ -46,6 +48,7 @@ async def store_data_to_dataset( node_set: Optional[List[str]] = None, dataset_id: UUID = None, preferred_loaders: dict[str, dict[str, Any]] = None, + importance_weight: float = 0.5, ): new_datapoints = [] existing_data_points = [] @@ -139,6 +142,7 @@ async def store_data_to_dataset( data_point.external_metadata = ext_metadata data_point.node_set = json.dumps(node_set) if node_set else None data_point.tenant_id = user.tenant_id if user.tenant_id else None + data_point.importance_weight = importance_weight # Check if data is already in dataset if str(data_point.id) in dataset_data_map: @@ -169,6 +173,7 @@ async def store_data_to_dataset( tenant_id=user.tenant_id if user.tenant_id else None, pipeline_status={}, token_count=-1, + importance_weight = importance_weight ) new_datapoints.append(data_point) @@ -195,5 +200,5 @@ async def store_data_to_dataset( return existing_data_points + dataset_new_data_points + new_datapoints return await store_data_to_dataset( - data, dataset_name, user, node_set, dataset_id, preferred_loaders + data, dataset_name, user, node_set, dataset_id, preferred_loaders,importance_weight ) diff --git a/cognee/tasks/memify/propagate_importance_weights.py b/cognee/tasks/memify/propagate_importance_weights.py new file mode 100644 index 0000000000..356ff1f27f --- /dev/null +++ b/cognee/tasks/memify/propagate_importance_weights.py @@ -0,0 +1,90 @@ +from typing import List, Dict, Any +from collections import defaultdict +from cognee.modules.pipelines.tasks.task import Task +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge +from cognee.shared.logging_utils import get_logger + +logger = get_logger("WeightPropagationTask") +DEFAULT_WEIGHT = 0.5 + +async def propagate_importance_weights(data: List[CogneeGraph], task_config: Dict[str, Any] = None) -> List[ + CogneeGraph]: + """ + Propagates and fuses importance weights from initial nodes (e.g., DocumentChunks) + to their neighboring nodes and edges in the graph memory fragment using an + Average Aggregation strategy. + + Args: + data: A list containing the CogneeGraph memory fragment to be processed. + task_config: Configuration dictionary (optional). + + Returns: + The updated CogneeGraph memory fragment list. 
+ """ + if not data or not isinstance(data[0], CogneeGraph): + logger.warning("No CogneeGraph memory fragment provided for weight propagation.") + return data + + memory_fragment: CogneeGraph = data[0] + + logger.info("Starting importance weight propagation and fusion.") + + node_weight_contributions = defaultdict(list) + + all_nodes: List[Node] = list(memory_fragment.nodes.values()) + + source_nodes: List[Node] = [ + node for node in all_nodes + if node.attributes.get("importance_weight") is not None and 0.0 <= node.attributes.get("importance_weight", + -1) <= 1.0 + ] + + for source_node in source_nodes: + initial_weight = source_node.attributes["importance_weight"] + + node_weight_contributions[source_node.id].append(initial_weight) + + for neighbor in source_node.get_skeleton_neighbours(): + node_weight_contributions[neighbor.id].append(initial_weight) + + updated_node_count = 0 + for node_id, weights in node_weight_contributions.items(): + if weights: + avg_weight = sum(weights) / len(weights) + target_node = memory_fragment.get_node(node_id) + + if target_node: + target_node.add_attribute("importance_weight", round(avg_weight, 4)) + updated_node_count += 1 + else: + logger.error(f"Target Node ID {node_id} unexpectedly not found in fragment during weight update.") + + logger.info(f"Propagation Phase 1 completed: Updated {updated_node_count} nodes via Average Aggregation.") + + all_edges: List[Edge] = memory_fragment.get_edges() + + updated_edge_count = 0 + for edge in all_edges: + node1_weight = edge.node1.attributes.get("importance_weight", DEFAULT_WEIGHT) + node2_weight = edge.node2.attributes.get("importance_weight", DEFAULT_WEIGHT) + + edge_weight = (node1_weight + node2_weight) / 2 + edge.add_attribute("importance_weight", round(edge_weight, 4)) + updated_edge_count += 1 + + logger.info(f"Propagation Phase 2 completed: Updated {updated_edge_count} edges.") + + return data + + +class PropagateImportanceWeights(Task): + """ + Cognee Task wrapper for propagating importance weights across the graph. 
+    """ + + def __init__(self, **kwargs): + super().__init__(propagate_importance_weights, **kwargs) + + async def __call__(self, data: List[CogneeGraph], task_config: Dict[str, Any] = None) -> List[CogneeGraph]: + return await self.func(data, task_config=task_config) \ No newline at end of file diff --git a/cognee/tests/integration/documents/TextDocument_test.py b/cognee/tests/integration/documents/TextDocument_test.py index 241ffe07fc..6759616daf 100644 --- a/cognee/tests/integration/documents/TextDocument_test.py +++ b/cognee/tests/integration/documents/TextDocument_test.py @@ -26,14 +26,17 @@ @pytest.mark.parametrize( - "input_file,chunk_size", - [("code.txt", 256), ("Natural_language_processing.txt", 128)], + "input_file,chunk_size,importance_weight,expected_weight", + [ + ("code.txt", 256, 0.9, 0.9), + ("Natural_language_processing.txt", 128, None, 0.5), + ], ) @patch.object( chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine ) @pytest.mark.asyncio -async def test_TextDocument(mock_engine, input_file, chunk_size): +async def test_TextDocument(mock_engine, input_file, chunk_size, importance_weight, expected_weight): test_file_path = os.path.join( pathlib.Path(__file__).parent.parent.parent, "test_data", input_file ) @@ -43,11 +46,12 @@ async def test_TextDocument(mock_engine, input_file, chunk_size): raw_data_location=test_file_path, external_metadata="", mime_type="", + importance_weight=importance_weight, ) async for ground_truth, paragraph_data in async_gen_zip( - GROUND_TRUTH[input_file], - document.read(chunker_cls=TextChunker, max_chunk_size=chunk_size), + GROUND_TRUTH[input_file], + document.read(chunker_cls=TextChunker, max_chunk_size=chunk_size), ): assert ground_truth["word_count"] == paragraph_data.chunk_size, ( f'{ground_truth["word_count"] = } != {paragraph_data.chunk_size = }' ) @@ -58,3 +62,11 @@ async def test_TextDocument(mock_engine, input_file, chunk_size): assert ground_truth["cut_type"] == paragraph_data.cut_type, ( f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }' ) + + assert hasattr(paragraph_data, "importance_weight"), ( + "DocumentChunk object is missing the 'importance_weight' attribute." + ) + assert paragraph_data.importance_weight == expected_weight, ( + f"Chunk importance_weight failed for Document {input_file}. " + f"Expected {expected_weight}, but got {paragraph_data.importance_weight}." 
+    ) \ No newline at end of file diff --git a/cognee/tests/test_importance_weight.py b/cognee/tests/test_importance_weight.py new file mode 100644 index 0000000000..d74da3542a --- /dev/null +++ b/cognee/tests/test_importance_weight.py @@ -0,0 +1,197 @@ +import pytest +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.databases.graph import get_graph_engine + + +class MockEdge: + def __init__(self, node1, node2, relationship, score=None): + self.node1 = node1 + self.node2 = node2 + self.relationship = relationship + self.attributes = {} + if score is not None: + self.score = score + + +class MockNode: + def __init__(self, node_id, node_type, importance_weight=0.5): + self.id = node_id + self.type = node_type + self.attributes = {"importance_weight": importance_weight} + + +@pytest.fixture +def mock_graph_engine(): + with patch('cognee.infrastructure.databases.graph.get_graph_engine') as mock: + mock.return_value = AsyncMock() + mock.return_value.is_empty.return_value = False + yield mock.return_value + + +@pytest.fixture +def mock_vector_engine(): + with patch('cognee.infrastructure.databases.vector.get_vector_engine') as mock: + mock.return_value = MagicMock() + mock.return_value.embedding_engine.embed_text.return_value = [[0.1] * 768] # mock embedding vector + mock.return_value.search.return_value = [] + yield mock.return_value + + +@pytest.fixture +def mock_memory_fragment(): + node1 = MockNode("node1", "Entity", importance_weight=0.8) + node2 = MockNode("node2", "Entity", importance_weight=0.9) + node3 = MockNode("node3", "Entity", importance_weight=0.3) + node4 = MockNode("node4", "Entity", importance_weight=0.7) + + edge1 = MockEdge(node1, node2, "related_to") + edge2 = MockEdge(node3, node4, "related_to") + + class MockMemoryFragment: + async def calculate_top_triplet_importances(self, k): + return [edge1, edge2] + + async def map_vector_distances_to_graph_nodes(self, node_distances): + pass + + async def map_vector_distances_to_graph_edges(self, vector_engine, query_vector, edge_distances): + pass + + return MockMemoryFragment() + + +@pytest.mark.asyncio +async def test_importance_weight_in_scoring(mock_vector_engine, mock_graph_engine, mock_memory_fragment): + with patch('cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment', + return_value=mock_memory_fragment): + query = "query test" + + retriever = GraphCompletionRetriever(top_k=2) + triplets = await retriever.get_triplets(query) + assert len(triplets) == 2 + + first_edge = triplets[0] + second_edge = triplets[1] + + first_avg_weight = (first_edge.node1.attributes["importance_weight"] + + first_edge.node2.attributes["importance_weight"]) / 2 + + assert abs(first_avg_weight - 0.85) < 0.01 + + second_avg_weight = (second_edge.node1.attributes["importance_weight"] + + second_edge.node2.attributes["importance_weight"]) / 2 + assert abs(second_avg_weight - 0.5) < 0.01 + + +@pytest.mark.asyncio +async def test_importance_weight_default_value(): + node1 = MockNode("node1", "Entity") + node2 = MockNode("node2", "Entity") + + node1.attributes.pop("importance_weight", None) + node2.attributes.pop("importance_weight", None) + + edge = MockEdge(node1,
node2, "related_to") + + class MockMemoryFragment: + async def calculate_top_triplet_importances(self, k): + return [edge] + + async def map_vector_distances_to_graph_nodes(self, node_distances): + pass + + async def map_vector_distances_to_graph_edges(self, vector_engine, query_vector, edge_distances): + pass + + with patch('cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment', + return_value=MockMemoryFragment()): + retriever = GraphCompletionRetriever() + triplets = await retriever.get_triplets("query test") + + assert len(triplets) == 1 + assert triplets[0] == edge + + assert "importance_weight" not in triplets[0].node1.attributes + assert "importance_weight" not in triplets[0].node2.attributes + + +@pytest.mark.asyncio +async def test_importance_weight_edge_cases(): + node1 = MockNode("node1", "Entity", importance_weight=0.0) + node2 = MockNode("node2", "Entity", importance_weight=1.0) + + edge = MockEdge(node1, node2, "related_to") + + class MockMemoryFragment: + async def calculate_top_triplet_importances(self, k): + return [edge] + + async def map_vector_distances_to_graph_nodes(self, node_distances): + pass + + async def map_vector_distances_to_graph_edges(self, vector_engine, query_vector, edge_distances): + pass + + with patch('cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment', + return_value=MockMemoryFragment()): + retriever = GraphCompletionRetriever() + triplets = await retriever.get_triplets("query test") + + assert len(triplets) == 1 + assert triplets[0].node1.attributes["importance_weight"] == 0.0 + assert triplets[0].node2.attributes["importance_weight"] == 1.0 + + avg_weight = (0.0 + 1.0) / 2 + assert abs(avg_weight - 0.5) < 0.01 + + +@pytest.fixture +def mock_memory_fragment_for_ranking(): + node_a1 = MockNode("A1", "Entity", importance_weight=1.0) + node_a2 = MockNode("A2", "Entity", importance_weight=1.0) + edge_a = MockEdge(node_a1, node_a2, "high_weight", score=0.9) + + node_b1 = MockNode("B1", "Entity", importance_weight=0.1) + node_b2 = MockNode("B2", "Entity", importance_weight=0.1) + edge_b = MockEdge(node_b1, node_b2, "low_weight", score=0.5) + + expected_ranking = [edge_a, edge_b] + + class MockMemoryFragment: + async def calculate_top_triplet_importances(self, k): + return expected_ranking + + async def map_vector_distances_to_graph_nodes(self, node_distances): + pass + + async def map_vector_distances_to_graph_edges(self, vector_engine, query_vector, edge_distances): + pass + + return MockMemoryFragment() + + +@pytest.mark.asyncio +async def test_importance_weight_ranking_override(mock_vector_engine, mock_graph_engine, + mock_memory_fragment_for_ranking): + with patch('cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment', + return_value=mock_memory_fragment_for_ranking): + query = "ranking test" + retriever = GraphCompletionRetriever(top_k=2) + triplets = await retriever.get_triplets(query) + + assert len(triplets) == 2 + + assert triplets[0].node1.attributes["importance_weight"] == 1.0 + assert triplets[0].relationship == "high_weight" + + assert triplets[1].node1.attributes["importance_weight"] == 0.1 + assert triplets[1].relationship == "low_weight" + + assert triplets[0].score > triplets[1].score + assert abs(triplets[0].score - 0.9) < 0.01 + assert abs(triplets[1].score - 0.5) < 0.01 \ No newline at end of file diff --git a/cognee/tests/test_propagate_importance_weights.py b/cognee/tests/test_propagate_importance_weights.py new file mode 100644 index 0000000000..ee69260a6d 
--- /dev/null +++ b/cognee/tests/test_propagate_importance_weights.py @@ -0,0 +1,78 @@ +import pytest +import asyncio +from typing import List, Dict, Any, Optional +from unittest.mock import MagicMock + +# 导入 CogneeGraph 相关的元素 +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge + +from cognee.tasks.memify.propagate_importance_weights import propagate_importance_weights + +class MockNode(Node): + + def __init__(self, node_id: str, importance_weight: Optional[float] = None): + super().__init__(node_id, dimension=1) + if importance_weight is not None: + self.attributes["importance_weight"] = importance_weight + +@pytest.fixture +def mock_memory_fragment() -> CogneeGraph: + + node_a = MockNode("N_A", importance_weight=1.0) + node_b = MockNode("N_B", importance_weight=0.2) + + node_x = MockNode("N_X") + node_y = MockNode("N_Y") + node_z = MockNode("N_Z") + + graph = CogneeGraph(directed=False) + + for node in [node_a, node_b, node_x, node_y, node_z]: + graph.add_node(node) + + edge_ax = Edge(node_a, node_x, directed=False) + graph.add_edge(edge_ax) + + edge_ay = Edge(node_a, node_y, directed=False) + graph.add_edge(edge_ay) + + edge_by = Edge(node_b, node_y, directed=False) + graph.add_edge(edge_by) + + return graph + +@pytest.mark.asyncio +async def test_weight_propagation_and_fusion(mock_memory_fragment: CogneeGraph): + data = [mock_memory_fragment] + updated_data = await propagate_importance_weights(data) + + updated_graph: CogneeGraph = updated_data[0] + + nodes = {node.id: node for node in updated_graph.nodes.values()} + + n_a = nodes['N_A'] + assert abs(n_a.attributes["importance_weight"] - 1.0) < 1e-4 + + n_b = nodes['N_B'] + assert abs(n_b.attributes["importance_weight"] - 0.2) < 1e-4 + + n_x = nodes['N_X'] + assert abs(n_x.attributes["importance_weight"] - 1.0) < 1e-4 + + n_y = nodes['N_Y'] + assert abs(n_y.attributes["importance_weight"] - 0.6) < 1e-4 + + assert "importance_weight" not in nodes['N_Z'].attributes + + edges = updated_graph.get_edges() + edge_map = {(e.node1.id, e.node2.id): e for e in edges if e.node1.id < e.node2.id} + + edge_ax = edge_map[('N_A', 'N_X')] + assert abs(edge_ax.attributes["importance_weight"] - 1.0) < 1e-4 + + edge_ay = edge_map[('N_A', 'N_Y')] + assert abs(edge_ay.attributes["importance_weight"] - 0.8) < 1e-4 + + edge_by = edge_map[('N_B', 'N_Y')] + assert abs(edge_by.attributes["importance_weight"] - 0.4) < 1e-4 \ No newline at end of file diff --git a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py index 44786f79d2..9544ac48c4 100644 --- a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py @@ -47,6 +47,7 @@ async def test_chunk_context_simple(self): raw_data_location="somewhere", external_metadata="", mime_type="text/plain", + importance_weight = 0.5 ) chunk1 = DocumentChunk( @@ -56,6 +57,7 @@ async def test_chunk_context_simple(self): cut_type="sentence_end", is_part_of=document, contains=[], + importance_weight=0.5 ) chunk2 = DocumentChunk( text="Mike Broski", @@ -64,6 +66,7 @@ async def test_chunk_context_simple(self): cut_type="sentence_end", is_part_of=document, contains=[], + importance_weight=1 ) chunk3 = DocumentChunk( text="Christina Mayer", @@ -72,6 +75,7 @@ async def test_chunk_context_simple(self): cut_type="sentence_end", is_part_of=document, contains=[], + importance_weight=0.5 ) entities = 
[chunk1, chunk2, chunk3] @@ -199,3 +203,240 @@ async def test_chunk_context_on_empty_graph(self): context = await retriever.get_context("Christina Mayer") assert len(context) == 0, "Found chunks when none should exist" + + @pytest.mark.asyncio + async def test_importance_weight_default_value(self): + + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_importance_weight_default" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_importance_weight_default" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document = TextDocument( + name="Test Document", + raw_data_location="test", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Test chunk 1 (Missing Weight)", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + chunk2 = DocumentChunk( + text="Test chunk 2 (Explicit Low Weight)", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + importance_weight=0.1 + ) + + entities = [chunk1, chunk2] + await add_data_points(entities) + + retriever = ChunksRetriever() + + with patch.object(retriever.vector_engine, 'search', new_callable=AsyncMock) as mock_search: + mock_search.return_value = [ + type('ScoredPoint', (), {'payload': chunk1.model_dump(), 'score': 0.9}), + type('ScoredPoint', (), {'payload': chunk2.model_dump(), 'score': 0.6}) + ] + + context = await retriever.get_context("test query") + + args, kwargs = mock_search.call_args + assert 'score_threshold' not in kwargs + assert len(context) == 2 + assert context[0]["text"] == "Test chunk 1 (Missing Weight)" + assert context[1]["text"] == "Test chunk 2 (Explicit Low Weight)" + + @pytest.mark.asyncio + async def test_importance_weight_ranking(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_importance_weight_ranking" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_importance_weight_ranking" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document = TextDocument( + name="Test Document", + raw_data_location="test", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="High importance, low score", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + importance_weight=1.0 + ) + + chunk2 = DocumentChunk( + text="Low importance, high score", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + importance_weight=0.1 + ) + + entities = [chunk1, chunk2] + await add_data_points(entities) + + retriever = ChunksRetriever() + + with patch.object(retriever.vector_engine, 'search', new_callable=AsyncMock) as mock_search: + mock_search.return_value = [ + type('ScoredPoint', (), {'payload': chunk1.model_dump(), 'score': 0.6}), + type('ScoredPoint', (), {'payload': chunk2.model_dump(), 'score': 0.9}) + ] + + context = await retriever.get_context("test query") + + assert len(context) == 2 + assert context[0]["text"] == "High importance, low score" + assert 
context[1]["text"] == "Low importance, high score" + + @pytest.mark.asyncio + async def test_importance_weight_boundary_values(self): + + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_importance_weight_boundary_values" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_importance_weight_boundary_values" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document = TextDocument( + name="Test Document", + raw_data_location="test", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Zero weight chunk", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + importance_weight=0.0 + ) + + chunk2 = DocumentChunk( + text="Full weight chunk", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + importance_weight=1.0 + ) + + entities = [chunk1, chunk2] + await add_data_points(entities) + + retriever = ChunksRetriever() + with patch.object(retriever.vector_engine, 'search', new_callable=AsyncMock) as mock_search: + mock_search.return_value = [ + type('ScoredPoint', (), {'payload': chunk1.model_dump(), 'score': 0.8}), # 原始得分高 (0.8) + type('ScoredPoint', (), {'payload': chunk2.model_dump(), 'score': 0.5}) # 原始得分低 (0.5) + ] + + context = await retriever.get_context("test query") + assert len(context) == 2 + assert context[0]["text"] == "Full weight chunk" + assert context[1]["text"] == "Zero weight chunk" + + @pytest.mark.asyncio + async def test_ranking_stability_on_equal_score(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_ranking_stability" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_ranking_stability" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document = TextDocument( + name="Test Document", + raw_data_location="test", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Stable Chunk 1 (High Weight, Low Score)", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + importance_weight=1.0 + ) + + chunk2 = DocumentChunk( + text="Stable Chunk 2 (Low Weight, High Score)", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + importance_weight=0.5 + ) + + entities = [chunk1, chunk2] + await add_data_points(entities) + + retriever = ChunksRetriever() + + with patch.object(retriever.vector_engine, 'search', new_callable=AsyncMock) as mock_search: + mock_search.return_value = [ + type('ScoredPoint', (), {'payload': chunk2.model_dump(), 'score': 1.0}), + type('ScoredPoint', (), {'payload': chunk1.model_dump(), 'score': 3.0}) + ] + + context = await retriever.get_context("test query for stability") + + assert len(context) == 2 + assert context[0]["text"] == "Stable Chunk 2 (Low Weight, High Score)" + assert context[1]["text"] == "Stable Chunk 1 (High Weight, Low Score)"
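Illustrative usage sketch (not part of the patch): how the new importance_weight parameter introduced above is expected to flow from ingestion to retrieval. It assumes the public cognee.add / cognee.cognify / cognee.search API and the CHUNKS search type; the import location of SearchType and the sample texts are assumptions, not something this diff adds.

import asyncio
import cognee
from cognee import SearchType  # assumed import location


async def main():
    # Ingest two documents with different importance weights (the parameter added by this patch).
    await cognee.add("Quarterly security audit findings.", importance_weight=0.9)
    await cognee.add("Informal brainstorming notes.", importance_weight=0.2)
    await cognee.cognify()

    # ChunksRetriever now multiplies vector similarity by importance_weight before
    # taking top_k, so the higher-weighted chunk should rank first when similarities are close.
    results = await cognee.search(query_type=SearchType.CHUNKS, query_text="audit findings")
    print(results)


asyncio.run(main())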