diff --git a/ISSUE_REENABLE_PROVIDERS.md b/ISSUE_REENABLE_PROVIDERS.md
new file mode 100644
index 0000000000..5df7f13761
--- /dev/null
+++ b/ISSUE_REENABLE_PROVIDERS.md
@@ -0,0 +1,59 @@
+Title: Finish live integration for google/azure/llm translation providers
+
+Summary
+-------
+This PR temporarily unregisters the `google`, `azure`, and `llm` translation
+providers to keep the current change small and reviewable. Those providers
+require additional live integration testing, dependency pin resolution, and
+secrets configuration before they can be merged. This issue tracks the work
+needed to fully enable them.
+
+Files of interest
+-----------------
+- `cognee/tasks/translation/translation_providers/google_provider.py`
+- `cognee/tasks/translation/translation_providers/azure_provider.py`
+- `cognee/tasks/translation/translation_providers/llm_provider.py`
+
+Repro steps (local)
+-------------------
+1. From the project root, activate your venv and install the following (adjust as
+   necessary to avoid dependency conflicts):
+
+   ```powershell
+   C:/path/to/venv/Scripts/python.exe -m pip install googletrans==4.0.0rc1
+   C:/path/to/venv/Scripts/python.exe -m pip install azure-ai-translation-text
+   # The LLM provider uses your existing LLM API key; no extra pip package is needed.
+   ```
+
+2. Set the required environment variables in a `.env` file or in your shell:
+   - `LLM_API_KEY` for the LLM provider (provider-specific keys may be required)
+   - `AZURE_TRANSLATE_KEY`, `AZURE_TRANSLATE_ENDPOINT`, `AZURE_TRANSLATE_REGION` (if required)
+
+3. Run the probe script to verify provider instantiation:
+
+   ```powershell
+   $env:PYTHONPATH = 'C:\Users\DELL\Desktop\open\cognee'
+   C:/path/to/venv/Scripts/python.exe scripts/list_translation_providers.py
+   ```
+
+4. If the Google provider fails with an httpcore/httpx compatibility error,
+   experiment with compatible `httpx`/`httpcore` versions and re-run
+   `pip install` until `googletrans` can import.
+
+Acceptance criteria
+-------------------
+- Each provider instantiates without ImportError.
+- For Google and Azure: detection and translation methods succeed when valid
+  credentials are provided.
+- For LLM: the smoke test `scripts/smoke_gemini_test.py` runs successfully with a
+  valid `LLM_API_KEY` (set as a GitHub secret for CI), and does not leak secrets
+  in logs.
+
+Notes
+-----
+- When re-enabling providers, prefer adding integration tests that run only in
+  CI with secrets (via GitHub Actions secrets) and are gated behind a
+  `RUN_LIVE_PROVIDER_TESTS` flag or similar to avoid accidental local runs
+  (see the first sketch below).
+- If the Google dependency resolution proves fragile, consider wrapping the
+  google provider imports in a lightweight adapter that falls back to a
+  stable REST-based translation API or external microservice (see the second
+  sketch below).
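+
+Sketches for the re-enable work
+-------------------------------
+Two illustrative sketches for the Notes above. Both are assumptions about the
+final shape, not code in this PR; names such as `RestFallbackTranslateProvider`
+and the REST endpoint URL are hypothetical.
+
+Gating live provider tests behind `RUN_LIVE_PROVIDER_TESTS`, in the same style
+as `cognee/tests/test_translation_providers.py`:
+
+   ```python
+   import os
+
+   import pytest
+
+   from cognee.tasks.translation import translate_content
+
+   # Skip the whole module unless the gate flag is set explicitly.
+   pytestmark = pytest.mark.skipif(
+       os.getenv("RUN_LIVE_PROVIDER_TESTS") != "1",
+       reason="Set RUN_LIVE_PROVIDER_TESTS=1 (plus provider secrets) to run live tests.",
+   )
+
+   @pytest.mark.asyncio
+   async def test_google_provider_live():
+       # Assumes the 'google' provider has been re-registered by this point.
+       chunk = type("Chunk", (), {"text": "Hola mundo", "metadata": {}})()
+       result = await translate_content(
+           chunk, translation_provider="google", target_language="en"
+       )
+       assert "translation" in result.metadata
+   ```
+
+A lightweight REST-backed adapter that satisfies the `TranslationProvider`
+protocol, as a fallback if the `googletrans` dependency stays fragile:
+
+   ```python
+   from typing import Optional, Tuple
+
+   import httpx  # any async HTTP client works; httpx is only an example
+
+   class RestFallbackTranslateProvider:
+       """Adapter sketch: provider protocol backed by a plain REST endpoint."""
+
+       def __init__(self, endpoint: str = "https://translate.example.internal/v1/translate"):
+           # Placeholder endpoint; point this at the real service or microservice.
+           self.endpoint = endpoint
+
+       async def detect_language(self, text: str) -> Optional[Tuple[str, float]]:
+           # Defer detection to the pipeline's langdetect fallback.
+           return None
+
+       async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]:
+           async with httpx.AsyncClient() as client:
+               response = await client.post(
+                   self.endpoint, json={"text": text, "target": target_language}
+               )
+           response.raise_for_status()
+           body = response.json()
+           return body["translated_text"], float(body.get("confidence", 0.0))
+   ```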
diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py
index 1292d243af..14fda62d54 100644
--- a/cognee/api/v1/cognify/cognify.py
+++ b/cognee/api/v1/cognify/cognify.py
@@ -1,23 +1,31 @@
-import asyncio
+import os
+
 from pydantic import BaseModel
-from typing import Union, Optional
+from typing import Union, Optional, Type
 from uuid import UUID
+
 from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
 from cognee.shared.logging_utils import get_logger
 from cognee.shared.data_models import KnowledgeGraph
-from cognee.infrastructure.llm import get_max_chunk_tokens
-from cognee.modules.pipelines import run_pipeline
+from cognee.infrastructure.llm.utils import get_max_chunk_tokens
+from cognee.modules.pipelines.operations.pipeline import run_pipeline
 from cognee.modules.pipelines.tasks.task import Task
 from cognee.modules.chunking.TextChunker import TextChunker
 from cognee.modules.ontology.ontology_config import Config
+from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
 from cognee.modules.ontology.get_default_ontology_resolver import (
     get_default_ontology_resolver,
     get_ontology_resolver_from_env,
 )
 from cognee.modules.users.models import User
 from cognee.tasks.documents import (
     check_permissions_on_dataset,
     classify_documents,
@@ -26,31 +34,68 @@
 from cognee.tasks.graph import extract_graph_from_data
 from cognee.tasks.storage import add_data_points
 from cognee.tasks.summarization import summarize_text
+from cognee.tasks.translation import (
+    translate_content,
+    get_available_providers,
+    validate_provider,
+    get_available_detectors,
+)
 from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
-from cognee.tasks.temporal_graph.extract_events_and_entities import extract_events_and_timestamps
-from cognee.tasks.temporal_graph.extract_knowledge_graph_from_events import (
-    extract_knowledge_graph_from_events,
-)
-logger = get_logger("cognify")
-
-update_status_lock = asyncio.Lock()
+
+logger = get_logger()
+
+
+class TranslationProviderError(ValueError):
+    """Error related to translation provider initialization."""
+
+
+class UnknownTranslationProviderError(TranslationProviderError):
+    """Unknown translation provider name."""
+
+
+class ProviderInitializationError(TranslationProviderError):
+    """Provider failed to initialize (likely missing dependency or bad config)."""
-async def cognify(
-    datasets: Union[str, list[str], list[UUID]] = None,
-    user: User = None,
-    graph_model: BaseModel = KnowledgeGraph,
+
+
+_WARNED_ENV_VARS: set[str] = set()
+
+
+def _parse_batch_env(var: str, default: int = 10) -> int:
+    """
+    Parse an environment variable as a positive integer (minimum 1), falling back to a default.
+
+    If the environment variable named `var` is unset, the provided `default` is returned.
+    If the variable is set but cannot be parsed as an integer, `default` is returned and a
+    one-time warning is logged for that variable (the variable name is recorded in
+    `_WARNED_ENV_VARS` to avoid repeated warnings).
+
+    Parameters:
+        var: Name of the environment variable to read.
+        default: Fallback integer value returned when the variable is missing or invalid.
+
+    Returns:
+        An integer >= 1 representing the parsed value or the fallback `default`.
+    """
+    raw = os.getenv(var)
+    if raw is None:
+        return default
+    try:
+        return max(1, int(raw))
+    except (TypeError, ValueError):
+        if var not in _WARNED_ENV_VARS:
+            logger.warning("Invalid int for %s=%r; using default=%d", var, raw, default)
+            _WARNED_ENV_VARS.add(var)
+        return default
+
+
+# Constants for batch processing
+DEFAULT_BATCH_SIZE = _parse_batch_env("COGNEE_DEFAULT_BATCH_SIZE", 10)
+
+
+async def cognify(  # pylint: disable=too-many-arguments,too-many-positional-arguments
+    datasets: Optional[Union[str, UUID, list[str], list[UUID]]] = None,
+    user: Optional[User] = None,
+    graph_model: Type[BaseModel] = KnowledgeGraph,
     chunker=TextChunker,
-    chunk_size: int = None,
-    config: Config = None,
-    vector_db_config: dict = None,
-    graph_db_config: dict = None,
+    chunk_size: Optional[int] = None,
+    vector_db_config: Optional[dict] = None,
+    graph_db_config: Optional[dict] = None,
+    config: Optional[Config] = None,
+    run_in_background: bool = False,
     incremental_loading: bool = True,
     custom_prompt: Optional[str] = None,
-    temporal_cognify: bool = False,
 ):
     """
     Transform ingested data into a structured knowledge graph.
@@ -83,6 +128,9 @@ async def cognify(
     6. **Graph Construction**: Builds semantic knowledge graph with embeddings
     7. **Content Summarization**: Creates hierarchical summaries for navigation

+    Note: To include a Translation step after chunking, use
+    `get_default_tasks_with_translation(...)`.
+
     Graph Model Customization:
         The `graph_model` parameter allows custom knowledge structures:
         - **Default**: General-purpose KnowledgeGraph for any domain
@@ -116,6 +164,7 @@ async def cognify(
         knowledge graph extraction. The prompt should guide the LLM on how
         to extract entities and relationships from the text content.
+
     Returns:
         Union[dict, list[PipelineRunInfo]]:
         - **Blocking mode**: Dictionary mapping dataset_id -> PipelineRunInfo with:
@@ -132,7 +181,7 @@ async def cognify(
         ```python
         import cognee
-        from cognee import SearchType
+        from cognee.api.v1.search import SearchType

         # Process your data into knowledge graph
         await cognee.cognify()
@@ -190,7 +239,15 @@ class ScientificPaper(DataPoint):
         - LLM_PROVIDER, LLM_MODEL, VECTOR_DB_PROVIDER, GRAPH_DATABASE_PROVIDER
         - LLM_RATE_LIMIT_ENABLED: Enable rate limiting (default: False)
         - LLM_RATE_LIMIT_REQUESTS: Max requests per interval (default: 60)
+
+    New in this version:
+    - COGNEE_DEFAULT_BATCH_SIZE: Default batch size for processing (default: 10)
     """
+
     if config is None:
         ontology_config = get_ontology_env_config()
         if (
@@ -208,12 +265,10 @@ class ScientificPaper(DataPoint):
                 "ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
             }

-    if temporal_cognify:
-        tasks = await get_temporal_tasks(user, chunker, chunk_size)
-    else:
-        tasks = await get_default_tasks(
-            user, graph_model, chunker, chunk_size, config, custom_prompt
-        )
+    tasks = get_default_tasks(
+        user, graph_model, chunker, chunk_size, config, custom_prompt
+    )

+    # get_pipeline_executor returns either a background runner or a blocking runner for run_pipeline
     pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)

@@ -231,14 +286,45 @@ class ScientificPaper(DataPoint):
     )


-async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's comment)
-    user: User = None,
-    graph_model: BaseModel = KnowledgeGraph,
+def get_default_tasks(  # pylint: disable=too-many-arguments,too-many-positional-arguments
+    user: Optional[User] = None,
+    graph_model: Type[BaseModel] = KnowledgeGraph,
     chunker=TextChunker,
-    chunk_size: int = None,
-    config: Config = None,
+    chunk_size: Optional[int] = None,
+    config: Optional[Config] = None,
     custom_prompt: Optional[str] = None,
 ) -> list[Task]:
+    """
+    Return the standard, non-translation Task list used by the cognify pipeline.
+
+    This builds the default processing pipeline (no automatic translation) and returns
+    a list of Task objects in execution order:
+    1. classify_documents
+    2. check_permissions_on_dataset (enforces write permission for `user`)
+    3. extract_chunks_from_documents (uses `chunker` and `chunk_size`)
+    4. extract_graph_from_data (uses `graph_model`, the ontology `config`, and `custom_prompt`)
+    5. summarize_text
+    6. add_data_points
+
+    Notes:
+    - Batch sizes for downstream tasks use the module-level DEFAULT_BATCH_SIZE.
+    - If `chunk_size` is not provided, the token limit from get_max_chunk_tokens() is used.
+
+    Parameters:
+        user: Optional user context used for the permission check.
+        graph_model: Model class used to construct knowledge graph instances.
+        chunker: Chunking strategy or class used to split documents into chunks.
+        chunk_size: Optional max tokens per chunk; if omitted, defaults to get_max_chunk_tokens().
+        config: Optional ontology configuration; resolved from the environment when omitted.
+        custom_prompt: Optional custom prompt applied during graph extraction.
+
+    Returns:
+        List[Task]: Ordered list of Task objects for the cognify pipeline (no translation).
+ """ + # Precompute max_chunk_size for stability + max_chunk = chunk_size or get_max_chunk_tokens() + if config is None: ontology_config = get_ontology_env_config() if ( @@ -246,13 +332,13 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's and ontology_config.ontology_resolver and ontology_config.matching_strategy ): - config: Config = { + config = { "ontology_config": { "ontology_resolver": get_ontology_resolver_from_env(**ontology_config.to_dict()) } } else: - config: Config = { + config = { "ontology_config": {"ontology_resolver": get_default_ontology_resolver()} } @@ -261,7 +347,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's Task(check_permissions_on_dataset, user=user, permissions=["write"]), Task( extract_chunks_from_documents, - max_chunk_size=chunk_size or get_max_chunk_tokens(), + max_chunk_size=max_chunk, chunker=chunker, ), # Extract text chunks based on the document type. Task( @@ -269,51 +355,102 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's graph_model=graph_model, config=config, custom_prompt=custom_prompt, - task_config={"batch_size": 10}, + task_config={"batch_size": DEFAULT_BATCH_SIZE}, ), # Generate knowledge graphs from the document chunks. Task( summarize_text, - task_config={"batch_size": 10}, + task_config={"batch_size": DEFAULT_BATCH_SIZE}, ), - Task(add_data_points, task_config={"batch_size": 10}), + Task(add_data_points, task_config={"batch_size": DEFAULT_BATCH_SIZE}), ] return default_tasks -async def get_temporal_tasks( - user: User = None, chunker=TextChunker, chunk_size: int = None +def get_default_tasks_with_translation( # pylint: disable=too-many-arguments,too-many-positional-arguments + user: Optional[User] = None, + graph_model: Type[BaseModel] = KnowledgeGraph, + chunker=TextChunker, + chunk_size: Optional[int] = None, + ontology_file_path: Optional[str] = None, + custom_prompt: Optional[str] = None, + translation_provider: str = "noop", + detection_provider: str = "langdetect", + target_language: str = "en", ) -> list[Task]: """ - Builds and returns a list of temporal processing tasks to be executed in sequence. - - The pipeline includes: - 1. Document classification. - 2. Dataset permission checks (requires "write" access). - 3. Document chunking with a specified or default chunk size. - 4. Event and timestamp extraction from chunks. - 5. Knowledge graph extraction from events. - 6. Batched insertion of data points. - - Args: - user (User, optional): The user requesting task execution, used for permission checks. - chunker (Callable, optional): A text chunking function/class to split documents. Defaults to TextChunker. - chunk_size (int, optional): Maximum token size per chunk. If not provided, uses system default. - + Return the default Cognify pipeline task list with an added translation step. + + Constructs the standard processing pipeline (classify -> permission check -> chunk extraction -> translate -> graph extraction -> summarize -> add data points), + validates and initializes the named translation provider, and applies module DEFAULT_BATCH_SIZE to downstream batchable tasks. + + Parameters: + translation_provider (str): Name of a registered translation provider (case-insensitive). Defaults to `"noop"` which is a no-op provider. + Returns: - list[Task]: A list of Task objects representing the temporal processing pipeline. + list[Task]: Ordered Task objects ready to be executed by the pipeline executor. 
+
+    Raises:
+        UnknownTranslationProviderError: If the given provider name is not in get_available_providers().
+        ProviderInitializationError: If the provider fails to initialize or validate via validate_provider().
+    """
-    temporal_tasks = [
+    # Fail fast on unknown providers (keeps errors close to the API surface)
+    translation_provider = (translation_provider or "noop").strip().lower()
+    detection_provider = (detection_provider or "langdetect").strip().lower()
+
+    # Validate provider using public API
+    if translation_provider not in get_available_providers():
+        available = ", ".join(get_available_providers())
+        logger.error("Unknown provider '%s'. Available: %s", translation_provider, available)
+        raise UnknownTranslationProviderError(
+            f"Unknown provider '{translation_provider}'. Available: {available}"
+        )
+
+    # Validate detection provider is a known detector
+    if detection_provider not in get_available_detectors():
+        available_detectors = ", ".join(get_available_detectors())
+        raise ValueError(
+            f"Invalid detection provider '{detection_provider}'. "
+            f"Available detectors: {available_detectors}"
+        )
+
+    # Instantiate to validate dependencies; surface provider-specific config errors
+    try:
+        validate_provider(translation_provider)
+    except Exception as e:  # convert provider init errors to a consistent exception type
+        available = ", ".join(get_available_providers())
+        logger.error(
+            "Provider '%s' failed to initialize (available: %s).",
+            translation_provider,
+            available,
+            exc_info=True,
+        )
+        raise ProviderInitializationError(
+            f"Provider '{translation_provider}' failed to initialize. Available: {available}"
+        ) from e
+
+    # Precompute max_chunk_size for stability
+    max_chunk = chunk_size or get_max_chunk_tokens()
+
+    default_tasks = [
         Task(classify_documents),
         Task(check_permissions_on_dataset, user=user, permissions=["write"]),
         Task(
             extract_chunks_from_documents,
-            max_chunk_size=chunk_size or get_max_chunk_tokens(),
+            max_chunk_size=max_chunk,
             chunker=chunker,
+        ),  # Extract text chunks based on the document type.
+        Task(
+            translate_content,
+            target_language=target_language,
+            translation_provider=translation_provider,
+            detection_provider=detection_provider,
+            task_config={"batch_size": DEFAULT_BATCH_SIZE},
+        ),  # Auto-translate non-English content and attach metadata
+        Task(
+            extract_graph_from_data,
+            graph_model=graph_model,
+            ontology_adapter=RDFLibOntologyResolver(ontology_file=ontology_file_path),
+            custom_prompt=custom_prompt,
+            task_config={"batch_size": DEFAULT_BATCH_SIZE},
+        ),  # Generate knowledge graphs from the document chunks.
+        Task(
+            summarize_text,
+            task_config={"batch_size": DEFAULT_BATCH_SIZE},
         ),
-        Task(extract_events_and_timestamps, task_config={"chunk_size": 10}),
-        Task(extract_knowledge_graph_from_events),
-        Task(add_data_points, task_config={"batch_size": 10}),
+        Task(add_data_points, task_config={"batch_size": DEFAULT_BATCH_SIZE}),
     ]
-    return temporal_tasks
+    return default_tasks
diff --git a/cognee/tasks/translation/__init__.py b/cognee/tasks/translation/__init__.py
new file mode 100644
index 0000000000..85b18710d7
--- /dev/null
+++ b/cognee/tasks/translation/__init__.py
@@ -0,0 +1,22 @@
+from .translate_content import (
+    translate_content,
+    register_translation_provider,
+    TranslationProvider,
+    validate_provider,
+)
+from .models import TranslatedContent, LanguageMetadata
+from .translation_registry import get_available_providers, get_available_detectors
+
+# Backwards-compatible alias expected by tests and older code
+get_available_translators = get_available_providers
+
+__all__ = (
+    "get_available_providers",
+    "get_available_detectors",
+    "get_available_translators",
+    "LanguageMetadata",
+    "register_translation_provider",
+    "translate_content",
+    "TranslatedContent",
+    "TranslationProvider",
+    "validate_provider",
+)
diff --git a/cognee/tasks/translation/models.py b/cognee/tasks/translation/models.py
new file mode 100644
index 0000000000..298d281e32
--- /dev/null
+++ b/cognee/tasks/translation/models.py
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+from datetime import datetime, timezone
+
+from pydantic import BaseModel, Field
+
+from cognee.infrastructure.engine import DataPoint
+
+
+class TranslationResponse(BaseModel):
+    """Response model for structured LLM translation output."""
+    translated_text: str
+
+
+class TranslatedContent(DataPoint):
+    """Represents translated content with quality metrics.
+
+    Stores the original and translated text, provider used, a confidence
+    score and a timestamp. Intended to be stored as metadata on the
+    originating DocumentChunk so the original and translation live
+    together.
+    """
+    original_chunk_id: str
+    original_text: str
+    translated_text: str
+    source_language: str = Field(..., pattern=r"^[a-z]{2}(-[A-Z]{2})?$|^(unknown|und)$")
+    target_language: str = Field("en", pattern=r"^[a-z]{2}(-[A-Z]{2})?$")
+    translation_provider: str = "noop"
+    confidence_score: float = Field(0.0, ge=0.0, le=1.0)
+    translation_timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+    # Inherit `metadata` from DataPoint to keep typing and defaults consistent.
+
+
+class LanguageMetadata(DataPoint):
+    """Language information for content.
+
+    Records the detected language, detection confidence, whether the
+    chunk requires translation and a simple character count.
+    """
+    content_id: str
+    detected_language: str = Field(..., pattern=r"^[a-z]{2}(-[A-Z]{2})?$|^(unknown|und)$")
+    language_confidence: float = Field(0.0, ge=0.0, le=1.0)
+    requires_translation: bool = False
+    character_count: int = Field(0, ge=0)
+    # Inherit `metadata` from DataPoint to keep typing and defaults consistent.
\ No newline at end of file
diff --git a/cognee/tasks/translation/test_translation.py b/cognee/tasks/translation/test_translation.py
new file mode 100644
index 0000000000..d76d616ad6
--- /dev/null
+++ b/cognee/tasks/translation/test_translation.py
@@ -0,0 +1,492 @@
+"""
+Unit tests for translation functionality.
+
+Tests cover:
+- Translation provider registry and discovery
+- Language detection across providers
+- Translation functionality
+- Error handling and fallbacks
+- Model validation and serialization
+"""
+
+import pytest  # type: ignore[import-untyped]
+from typing import Tuple, Optional, Dict
+from pydantic import ValidationError
+
+from cognee.tasks.translation import translation_registry
+from cognee.tasks.translation.translate_content import (
+    translate_content,
+    register_translation_provider,
+    get_available_providers,
+    TranslationProvider,
+    _get_provider,
+    snapshot_registry,
+    restore_registry,
+)
+from cognee.tasks.translation.translation_providers.noop_provider import NoopProvider
+from cognee.tasks.translation.models import TranslatedContent, LanguageMetadata
+
+
+class DetectionFailure(Exception):  # pylint: disable=too-few-public-methods
+    """Exception used to simulate detection failures in tests."""
+
+
+class TranslationFailure(Exception):  # pylint: disable=too-few-public-methods
+    """Exception used to simulate translation failures in tests."""
+
+
+# Ensure registry isolation across tests using public helpers
+@pytest.fixture(autouse=True)
+def _restore_registry():
+    """
+    Pytest fixture that snapshots the translation provider registry and restores it after the test.
+
+    Use to isolate tests that register or modify providers: the current registry state is captured before the test runs, and always restored when the fixture completes (including on exceptions).
+    """
+    snapshot = snapshot_registry()
+    try:
+        yield
+    finally:
+        restore_registry(snapshot)
+
+
+class MockDocumentChunk:  # pylint: disable=too-few-public-methods
+    """Mock document chunk for testing."""
+
+    def __init__(self, text: str, chunk_id: str = "test_chunk", metadata: Optional[Dict] = None):
+        """
+        Initialize a mock document chunk used in tests.
+
+        Parameters:
+            text (str): Chunk text content.
+            chunk_id (str): Identifier for the chunk; also used as chunk_index for tests. Defaults to "test_chunk".
+            metadata (Optional[Dict]): Optional mapping of metadata values; defaults to an empty dict.
+ """ + self.text = text + self.id = chunk_id + self.chunk_index = chunk_id + self.metadata = metadata or {} + + +class MockTranslationProvider: + """Mock provider for testing custom provider registration.""" + + async def detect_language(self, text: str) -> Tuple[str, float]: + """ + Detect the language of the given text and return an ISO 639-1 language code with a confidence score. + + This mock implementation uses simple keyword heuristics: returns ("es", 0.95) if the text contains "hola", + ("fr", 0.90) if it contains "bonjour", and ("en", 0.85) otherwise. + + Parameters: + text (str): Input text to analyze. + + Returns: + Tuple[str, float]: A tuple of (language_code, confidence) where language_code is an ISO 639-1 code and + confidence is a float between 0.0 and 1.0 indicating detection confidence. + """ + if "hola" in text.lower(): + return "es", 0.95 + if "bonjour" in text.lower(): + return "fr", 0.90 + return "en", 0.85 + + async def translate(self, text: str, target_language: str) -> Tuple[str, float]: + """ + Simulate translating `text` into `target_language` and return a mock translated string with a confidence score. + + If `target_language` is "en", returns the input prefixed with "[MOCK TRANSLATED]" and a confidence of 0.88. For any other target language, returns the original `text` and a confidence of 0.0. + + Parameters: + text (str): The text to translate. + target_language (str): The target language code (e.g., "en"). + + Returns: + Tuple[str, float]: A pair of (translated_text, confidence) where confidence is in [0.0, 1.0]. + """ + if target_language == "en": + return f"[MOCK TRANSLATED] {text}", 0.88 + return text, 0.0 + + +class TestProviderRegistry: + """Test translation provider registration and discovery.""" + + def test_get_available_providers_includes_builtin(self): + """Test that built-in providers are included in available list.""" + providers = get_available_providers() + assert "noop" in providers + assert "langdetect" in providers + + def test_register_custom_provider(self): + """Test custom provider registration.""" + register_translation_provider("mock", MockTranslationProvider) + providers = get_available_providers() + assert "mock" in providers + + # Test provider can be retrieved + provider = _get_provider("mock") + assert isinstance(provider, MockTranslationProvider) + + def test_provider_name_normalization(self): + """Test provider names are normalized to lowercase.""" + register_translation_provider("CUSTOM_PROVIDER", MockTranslationProvider) + providers = get_available_providers() + assert "custom_provider" in providers + + # Should be retrievable with different casing + provider1 = _get_provider("CUSTOM_PROVIDER") + provider2 = _get_provider("custom_provider") + assert provider1.__class__ is provider2.__class__ + + def test_unknown_provider_raises(self): + """Test unknown providers raise ValueError.""" + with pytest.raises(ValueError): + _get_provider("nonexistent_provider") + + +class TestNoopProvider: + """Test NoOp provider functionality.""" + + @pytest.mark.parametrize("text", ["Hello world", "Hëllo wörld"]) + @pytest.mark.asyncio + async def test_detect_language_noop(self, text): + """NoOp returns (None, 0.0) for any text.""" + provider = NoopProvider() + lang, conf = await provider.detect_language(text) + assert lang is None + assert conf == 0.0 + + @pytest.mark.asyncio + async def test_translate_returns_original(self): + """Test translation returns original text with zero confidence.""" + provider = NoopProvider() + text = "Test text" + 
translated, conf = await provider.translate(text, "es")
+        assert translated == text
+        assert conf == 0.0
+
+
+class TestTranslationModels:
+    """Test Pydantic models for translation data."""
+
+    def test_translated_content_validation(self):
+        """Test TranslatedContent model validation."""
+        content = TranslatedContent(
+            original_chunk_id="chunk_1",
+            original_text="Hello",
+            translated_text="Hola",
+            source_language="en",
+            target_language="es",
+            translation_provider="test",
+            confidence_score=0.9
+        )
+        assert content.original_chunk_id == "chunk_1"
+        assert content.confidence_score == 0.9
+
+    def test_translated_content_confidence_validation(self):
+        """Test confidence score validation bounds."""
+        # Valid confidence scores
+        TranslatedContent(
+            original_chunk_id="test",
+            original_text="test",
+            translated_text="test",
+            source_language="en",
+            confidence_score=0.0
+        )
+        TranslatedContent(
+            original_chunk_id="test",
+            original_text="test",
+            translated_text="test",
+            source_language="en",
+            confidence_score=1.0
+        )
+
+        # Invalid confidence scores should raise validation error
+        with pytest.raises(ValidationError):
+            TranslatedContent(
+                original_chunk_id="test",
+                original_text="test",
+                translated_text="test",
+                source_language="en",
+                confidence_score=-0.1
+            )
+
+        with pytest.raises(ValidationError):
+            TranslatedContent(
+                original_chunk_id="test",
+                original_text="test",
+                translated_text="test",
+                source_language="en",
+                confidence_score=1.1
+            )
+
+    def test_language_metadata_validation(self):
+        """Test LanguageMetadata model validation."""
+        metadata = LanguageMetadata(
+            content_id="chunk_1",
+            detected_language="es",
+            language_confidence=0.95,
+            requires_translation=True,
+            character_count=100
+        )
+        assert metadata.content_id == "chunk_1"
+        assert metadata.requires_translation is True
+        assert metadata.character_count == 100
+
+    def test_language_metadata_character_count_validation(self):
+        """Test character count cannot be negative."""
+        with pytest.raises(ValidationError):
+            LanguageMetadata(
+                content_id="test",
+                detected_language="en",
+                character_count=-1
+            )
+
+
+class TestTranslateContentFunction:
+    """Test main translate_content function."""
+
+    @pytest.mark.asyncio
+    async def test_noop_provider_processing(self):
+        """Test processing with noop provider."""
+        chunks = [
+            MockDocumentChunk("Hello world", "chunk_1"),
+            MockDocumentChunk("Test content", "chunk_2")
+        ]
+
+        result = await translate_content(
+            chunks,
+            target_language="en",
+            translation_provider="noop",
+            confidence_threshold=0.8
+        )
+
+        assert len(result) == 2
+        for chunk in result:
+            assert "language" in chunk.metadata
+            # Detection falls back to langdetect when it is installed, so the
+            # detected language is environment-dependent; the key guarantee is
+            # that the noop provider never translates.
+            assert "translation" not in chunk.metadata
+
+    @pytest.mark.asyncio
+    async def test_translation_with_custom_provider(self):
+        """Test translation with custom registered provider."""
+        # Register mock provider
+        register_translation_provider("test_provider", MockTranslationProvider)
+
+        chunks = [MockDocumentChunk("Hola mundo", "chunk_1")]
+
+        result = await translate_content(
+            chunks,
+            target_language="en",
+            translation_provider="test_provider",
+            confidence_threshold=0.8
+        )
+
+        chunk = result[0]
+        assert "language" in chunk.metadata
+        assert "translation" in chunk.metadata
+
+        # Check language metadata
+        lang_meta = chunk.metadata["language"]
+        assert lang_meta["detected_language"] == "es"
+        assert lang_meta["requires_translation"] is True
+
+        # Check 
translation metadata
+        trans_meta = chunk.metadata["translation"]
+        assert trans_meta["original_text"] == "Hola mundo"
+        assert "[MOCK TRANSLATED]" in trans_meta["translated_text"]
+        assert trans_meta["source_language"] == "es"
+        assert trans_meta["target_language"] == "en"
+        assert trans_meta["translation_provider"] == "test_provider"
+
+        # Check chunk text was updated
+        assert "[MOCK TRANSLATED]" in chunk.text
+
+    @pytest.mark.asyncio
+    async def test_low_confidence_no_translation(self):
+        """Test that low confidence detection doesn't trigger translation."""
+        register_translation_provider("low_conf", MockTranslationProvider)
+
+        chunks = [MockDocumentChunk("Hello world", "chunk_1")]  # English text
+
+        result = await translate_content(
+            chunks,
+            target_language="en",
+            translation_provider="low_conf",
+            confidence_threshold=0.9  # High threshold
+        )
+
+        chunk = result[0]
+        assert "language" in chunk.metadata
+        # Should not translate due to high threshold and English detection
+        assert "translation" not in chunk.metadata
+
+    @pytest.mark.asyncio
+    async def test_error_handling_in_detection(self):
+        """Test graceful error handling in language detection."""
+        class FailingProvider:
+            async def detect_language(self, _text: str) -> Tuple[str, float]:
+                """
+                Simulate a language detection failure by always raising DetectionFailure.
+
+                This async method is used in tests to emulate a provider that fails during language detection. It accepts a text string but does not return; it always raises DetectionFailure.
+                """
+                raise DetectionFailure()
+
+            async def translate(self, text: str, _target_language: str) -> Tuple[str, float]:
+                """
+                Return the input text unchanged and a translation confidence of 0.0.
+
+                This no-op translator performs no translation; the supplied target language is ignored.
+
+                Parameters:
+                    text (str): Source text to "translate".
+                    _target_language (str): Target language (ignored).
+
+                Returns:
+                    Tuple[str, float]: A tuple containing the original text and a confidence score (always 0.0).
+                """
+                return text, 0.0
+
+        register_translation_provider("failing", FailingProvider)
+
+        chunks = [MockDocumentChunk("Test text", "chunk_1")]
+
+        # Remove the 'langdetect' fallback from the shared registry to force "unknown";
+        # the autouse fixture restores the full registry afterwards.
+        ld = translation_registry._provider_registry.pop("langdetect", None)
+        try:
+            result = await translate_content(chunks, translation_provider="failing")
+        finally:
+            if ld is not None:
+                translation_registry._provider_registry["langdetect"] = ld
+
+        chunk = result[0]
+        assert "language" in chunk.metadata
+        # Should have unknown language due to detection failure
+        lang_meta = chunk.metadata["language"]
+        assert lang_meta["detected_language"] == "unknown"
+        assert lang_meta["language_confidence"] == 0.0
+
+    @pytest.mark.asyncio
+    async def test_error_handling_in_translation(self):
+        """Test graceful error handling in translation."""
+        class PartialProvider:
+            async def detect_language(self, _text: str) -> Tuple[str, float]:
+                """
+                Mock language detection used in tests.
+
+                Parameters:
+                    _text (str): Input text (ignored by this mock).
+
+                Returns:
+                    Tuple[str, float]: A fixed detected language code ("es") and confidence (0.9).
+                """
+                return "es", 0.9
+
+            async def translate(self, _text: str, _target_language: str) -> Tuple[str, float]:
+                """
+                Simulate a failing translation by always raising TranslationFailure.
+
+                This async method ignores its inputs and is used in tests to emulate a provider-side failure during translation.
+
+                Parameters:
+                    _text (str): Unused input text.
+                    _target_language (str): Unused target language code.
+
+                Raises:
+                    TranslationFailure: Always raised to simulate a translation failure.
+                """
+                raise TranslationFailure()
+
+        register_translation_provider("partial", PartialProvider)
+
+        chunks = [MockDocumentChunk("Hola", "chunk_1")]
+
+        result = await translate_content(
+            chunks,
+            translation_provider="partial",
+            confidence_threshold=0.8
+        )
+
+        chunk = result[0]
+        # Should have detected Spanish but failed translation
+        assert chunk.metadata["language"]["detected_language"] == "es"
+        # Should still create translation metadata with original text
+        assert "translation" in chunk.metadata
+        trans_meta = chunk.metadata["translation"]
+        assert trans_meta["translated_text"] == "Hola"  # Original text due to failure
+        assert trans_meta["confidence_score"] == 0.0
+
+    @pytest.mark.asyncio
+    async def test_no_translation_when_same_language(self):
+        """Test no translation occurs when source equals target language."""
+        register_translation_provider("same_lang", MockTranslationProvider)
+
+        chunks = [MockDocumentChunk("Hello world", "chunk_1")]
+
+        result = await translate_content(
+            chunks,
+            target_language="en",  # Same as detected language
+            translation_provider="same_lang"
+        )
+
+        chunk = result[0]
+        assert "language" in chunk.metadata
+        # No translation should occur for same language
+        assert "translation" not in chunk.metadata
+
+    @pytest.mark.asyncio
+    async def test_metadata_serialization(self):
+        """Test that metadata is properly serialized to dicts."""
+        register_translation_provider("serialize_test", MockTranslationProvider)
+
+        chunks = [MockDocumentChunk("Hola", "chunk_1")]
+
+        result = await translate_content(
+            chunks,
+            translation_provider="serialize_test",
+            confidence_threshold=0.8
+        )
+
+        chunk = result[0]
+
+        # Metadata should be plain dicts, not Pydantic models
+        assert isinstance(chunk.metadata["language"], dict)
+        if "translation" in chunk.metadata:
+            assert isinstance(chunk.metadata["translation"], dict)
+
+    def test_model_serialization_compatibility(self):
+        """
+        Verify that a TranslatedContent instance can be dumped to a JSON-serializable dict.
+
+        Creates a TranslatedContent with sample fields, calls model_dump(mode="json"), and asserts:
+        - the result is a dict,
+        - required fields like `original_chunk_id`, `translation_timestamp`, and `metadata` are present and preserved,
+        - the dict can be round-tripped through json.dumps/json.loads without losing `original_chunk_id`.
+        """
+        content = TranslatedContent(
+            original_chunk_id="test",
+            original_text="Hello",
+            translated_text="Hola",
+            source_language="en",
+            target_language="es"
+        )
+
+        # Should serialize to dict (JSON mode converts datetimes/UUIDs to strings)
+        data = content.model_dump(mode="json")
+        assert isinstance(data, dict)
+        assert data["original_chunk_id"] == "test"
+        assert "translation_timestamp" in data
+        assert "metadata" in data
+
+        # Should be JSON serializable
+        import json
+        json_str = json.dumps(data)
+        parsed = json.loads(json_str)
+        assert parsed["original_chunk_id"] == "test"
diff --git a/cognee/tasks/translation/translate_content.py b/cognee/tasks/translation/translate_content.py
new file mode 100644
index 0000000000..8fccdd5b3d
--- /dev/null
+++ b/cognee/tasks/translation/translate_content.py
@@ -0,0 +1,353 @@
+# pylint: disable=R0903, W0221
+"""This module provides content translation capabilities for the Cognee framework."""
+import asyncio
+import math
+import os
+from dataclasses import dataclass, field
+from typing import Any, Dict, Tuple, Optional
+
+from cognee.shared.logging_utils import get_logger
+
+from .models import TranslatedContent, LanguageMetadata
+from .translation_providers_enum import TranslationProvider
+from .translation_registry import (
+    register_translation_provider,
+    get_available_providers,
+    get_provider_class,
+    snapshot_registry,
+    restore_registry,
+    validate_provider,
+)
+from .translation_providers.langdetect_provider import LangDetectProvider
+from .translation_providers.noop_provider import NoopProvider
+
+logger = get_logger(__name__)
+
+
+# Environment variables for configuration
+TARGET_LANGUAGE = os.getenv("COGNEE_TRANSLATION_TARGET_LANGUAGE", "en")
+try:
+    CONFIDENCE_THRESHOLD = float(os.getenv("COGNEE_TRANSLATION_CONFIDENCE_THRESHOLD", "0.80"))
+except (TypeError, ValueError):
+    logger.warning(
+        "Invalid float for COGNEE_TRANSLATION_CONFIDENCE_THRESHOLD=%r; defaulting to 0.80",
+        os.getenv("COGNEE_TRANSLATION_CONFIDENCE_THRESHOLD"),
+    )
+    CONFIDENCE_THRESHOLD = 0.80
+
+
+def _normalize_confidence(confidence: Any) -> float:
+    """Normalize confidence value to float in [0.0, 1.0] range."""
+    try:
+        confidence = float(confidence)
+    except (TypeError, ValueError):
+        return 0.0
+    if math.isnan(confidence):
+        return 0.0
+    return max(0.0, min(1.0, confidence))
+
+
+@dataclass
+class TranslationContext:  # pylint: disable=too-many-instance-attributes
+    """A context object to hold data for a single translation operation."""
+    provider: TranslationProvider
+    chunk: Any
+    text: str
+    target_language: str
+    confidence_threshold: float
+    content_id: str = field(init=False)
+    detected_language: str = "unknown"
+    detection_confidence: float = 0.0
+    requires_translation: bool = False
+
+    def __post_init__(self):
+        # Try to set content_id from chunk attributes
+        for attr in ("id", "content_id", "uuid", "pk"):
+            if hasattr(self.chunk, attr):
+                self.content_id = str(getattr(self.chunk, attr))
+                break
+        else:
+            self.content_id = "unknown"
+
+class TranslationProviderError(ValueError):
+    """Error related to translation provider initialization."""
+
+class UnknownTranslationProviderError(TranslationProviderError):
+    """Unknown translation provider name."""
+
+class ProviderInitializationError(TranslationProviderError):
+    """Provider failed to initialize (likely missing dependency or bad config)."""
+
+def _get_provider(translation_provider: str) -> TranslationProvider:
+    """Resolve and instantiate a registered translation provider by name."""
+    provider_cls = 
get_provider_class(translation_provider)
+    return provider_cls()
+
+def _normalize_lang_code(code: Optional[str]) -> str:
+    """Normalize a language code to a canonical form or return 'unknown'."""
+    if not isinstance(code, str) or not code.strip():
+        return "unknown"
+    c = code.strip().replace("_", "-")
+    parts = c.split("-")
+    lang = parts[0]
+    if len(lang) == 2 and lang.isalpha():
+        if len(parts) >= 2:
+            region = parts[1]
+            if len(region) == 2 and region.isalpha():
+                return f"{lang.lower()}-{region.upper()}"
+        return lang.lower()
+    # Three-letter or otherwise malformed codes are treated as undetermined.
+    return "unknown"
+
+async def _detect_language_with_fallback(provider: TranslationProvider, text: str) -> Tuple[str, float]:
+    """Detect language via the given provider, falling back to langdetect, then to 'unknown'."""
+    try:
+        detection = await provider.detect_language(text)
+    except Exception:
+        detection = None
+    if detection is None:
+        try:
+            fallback_cls = get_provider_class("langdetect")
+            fallback_provider = fallback_cls()
+            if not isinstance(provider, fallback_cls):
+                detection = await fallback_provider.detect_language(text)
+        except Exception:
+            detection = None
+    if detection is None:
+        return "unknown", 0.0
+    lang_code, confidence = detection
+    return _normalize_lang_code(lang_code), _normalize_confidence(confidence)
+
+def _decide_if_translation_is_required(translation_context: TranslationContext) -> None:
+    """Decide whether a translation should be performed and update translation_context.requires_translation."""
+    target_language = _normalize_lang_code(translation_context.target_language)
+
+    if translation_context.detected_language == "unknown":
+        # Decide purely from input; provider capability is handled via fallbacks.
+        translation_context.requires_translation = bool(translation_context.text.strip())
+    else:
+        same_base = translation_context.detected_language.split("-")[0] == target_language.split("-")[0]
+        translation_context.requires_translation = (
+            (not same_base) and translation_context.detection_confidence >= translation_context.confidence_threshold
+        )
+
+def _attach_language_metadata(translation_context: TranslationContext) -> None:
+    """Attach language detection and translation decision metadata to the context's chunk."""
+    translation_context.chunk.metadata = getattr(translation_context.chunk, "metadata", {}) or {}
+    lang_meta = LanguageMetadata(
+        content_id=str(translation_context.content_id),
+        detected_language=translation_context.detected_language,
+        language_confidence=translation_context.detection_confidence,
+        requires_translation=translation_context.requires_translation,
+        character_count=len(translation_context.text),
+    )
+    translation_context.chunk.metadata["language"] = lang_meta.model_dump()
+
+def _build_provider_plan(translation_provider_name, fallback_input):
+    """Normalize the primary provider key and filter fallbacks to known, deduplicated keys."""
+    primary_key = (translation_provider_name or "noop").lower()
+    raw = fallback_input or []
+    fallback_providers = []
+    seen = {primary_key}
+    invalid_providers = []
+    available_providers = set(get_available_providers())
+    for p in raw:
+        if isinstance(p, str) and p.strip():
+            key = p.strip().lower()
+            if key in available_providers:
+                if key not in seen:
+                    fallback_providers.append(key)
+                    seen.add(key)
+            else:
+                invalid_providers.append(p)
+        else:
+            invalid_providers.append(p)
+    if invalid_providers:
+        logger.warning("Ignoring unknown fallback providers: %s", invalid_providers)
+    return primary_key, fallback_providers
+
+
+async def _translate_and_update(translation_context: TranslationContext, provider_name: str) -> bool:
+    """Translate the context's text and, on success, update the chunk text and metadata.
+
+    Returns True when the provider produced a changed translation, False otherwise
+    (including provider errors), so the caller can decide whether to try fallbacks.
+    """
+    try:
+        result = await translation_context.provider.translate(
+            translation_context.text, translation_context.target_language
+        )
+    except Exception as e:
+        logger.exception("Translation failed: %s", e)
+        return False
+    if result is None:
+        return False
+    translated_text, confidence = result
+    if not translated_text or translated_text == translation_context.text:
+        return False
+    record = TranslatedContent(
+        original_chunk_id=str(translation_context.content_id),
+        original_text=translation_context.text,
+        translated_text=translated_text,
+        source_language=translation_context.detected_language,
+        target_language=_normalize_lang_code(translation_context.target_language),
+        translation_provider=provider_name,
+        confidence_score=_normalize_confidence(confidence),
+    )
+    translation_context.chunk.text = translated_text
+    translation_context.chunk.metadata = getattr(translation_context.chunk, "metadata", {}) or {}
+    translation_context.chunk.metadata["translation"] = record.model_dump()
+    return True
+
+async def _process_chunk(chunk, plan, provider_cache):
+    """Detect the language of one chunk and translate it according to the provider plan."""
+    # Unpack plan: (target_language, primary_key, fallback_providers, confidence_threshold, detection_provider_name)
+    target_language, primary_key, fallback_providers, confidence_threshold, detection_provider_name = plan
+    try:
+        provider = provider_cache.get(primary_key)
+        if provider is None:
+            provider = _get_provider(primary_key)
+            provider = provider_cache.setdefault(primary_key, provider)
+    except asyncio.CancelledError:
+        raise
+    except (ImportError, ValueError) as e:
+        logger.error("Provider import/value error for %s: %s", primary_key, e)
+        return chunk
+    except Exception:
+        logger.exception("Failed to initialize translation provider: %s", primary_key)
+        return chunk
+
+    text_to_translate = getattr(chunk, "text", "")
+    if not isinstance(text_to_translate, str) or not text_to_translate.strip():
+        return chunk
+
+    translation_context = TranslationContext(
+        provider=provider,
+        chunk=chunk,
+        text=text_to_translate,
+        target_language=target_language,
+        confidence_threshold=confidence_threshold,
+    )
+
+    # Attempt detection using the requested detection provider; fall back to the
+    # translation provider's own detection, then to langdetect.
+    detection = None
+    try:
+        detector_cls = get_provider_class(detection_provider_name)
+        detector = detector_cls()
+        if hasattr(detector, "detect_language"):
+            detection = await detector.detect_language(text_to_translate)
+    except Exception:
+        detection = None
+
+    if detection is None:
+        translation_context.detected_language, translation_context.detection_confidence = (
+            await _detect_language_with_fallback(provider, text_to_translate)
+        )
+    else:
+        lang_code, confidence = detection
+        translation_context.detected_language = _normalize_lang_code(lang_code)
+        translation_context.detection_confidence = _normalize_confidence(confidence)
+
+    _decide_if_translation_is_required(translation_context)
+    _attach_language_metadata(translation_context)
+
+    if translation_context.requires_translation:
+        # Short-circuit: the noop provider cannot translate and no fallbacks were given.
+        if primary_key == "noop" and not fallback_providers:
+            return translation_context.chunk
+        translated = await _translate_and_update(translation_context, primary_key)
+        # Try fallbacks, in order, until one produces a translation.
+        if not translated:
+            for alternative_provider_name in fallback_providers:
+                try:
+                    alternative_provider = provider_cache.get(alternative_provider_name)
+                    if alternative_provider is None:
+                        alternative_provider = _get_provider(alternative_provider_name)
+                        alternative_provider = provider_cache.setdefault(alternative_provider_name, alternative_provider)
+                except asyncio.CancelledError:
+                    raise
+                except (ImportError, ValueError) as e:
+                    logger.error("Fallback provider import/value error for %s: %s", alternative_provider_name, e)
+                    continue
+                except Exception:
+                    logger.exception("Failed to initialize fallback translation provider: %s", alternative_provider_name)
+                    continue
+                translation_context.provider = alternative_provider
+                translated = await _translate_and_update(translation_context, alternative_provider_name)
+                if translated:
+                    break
+        if not translated:
+            # Record the failed attempt so downstream consumers see the original
+            # text together with a zero confidence score.
+            failure = TranslatedContent(
+                original_chunk_id=str(translation_context.content_id),
+                original_text=translation_context.text,
+                translated_text=translation_context.text,
+                source_language=translation_context.detected_language,
+                target_language=_normalize_lang_code(target_language),
+                translation_provider=primary_key,
+                confidence_score=0.0,
+            )
+            translation_context.chunk.metadata["translation"] = failure.model_dump()
+
+    return translation_context.chunk
+
+
+async def translate_content(*chunks: Any, **kwargs) -> Any:
+    """
+    Translate the content of one or more chunks when translation is required.
+
+    This function detects the language of each chunk's text, decides if translation is needed,
+    and if so, translates the text to the target language using the specified provider.
+    It updates the chunk with the translated text and adds metadata about the translation process.
+
+    Args:
+        *chunks: The chunk(s) of content to be processed. Each chunk must have a 'text' attribute.
+            Can be called as:
+            - translate_content(chunk) - single chunk
+            - translate_content(chunk1, chunk2, ...) - multiple chunks
+            - translate_content([chunk1, chunk2, ...]) - list of chunks
+        **kwargs: Additional arguments:
+            target_language (str): Target language code (default from COGNEE_TRANSLATION_TARGET_LANGUAGE).
+            translation_provider (str): Primary provider key (e.g., "llm", "google", "azure", "noop").
+                Defaults to "noop".
+            detection_provider (str): Provider key used for language detection.
+                Defaults to "langdetect".
+            fallback_providers (List[str]): Ordered list of provider keys to try if the primary
+                fails or returns unchanged text. Defaults to empty list.
+            confidence_threshold (float): Minimum confidence threshold for language detection
+                (default from COGNEE_TRANSLATION_CONFIDENCE_THRESHOLD).
+
+    Returns:
+        Any: For single chunk input - returns the processed chunk directly.
+            For multiple chunks input - returns List[Any] of processed chunks.
+            Each returned chunk may have its text translated and metadata updated with:
+            - language: detected language and confidence
+            - translation: translated text and provider information (if translation occurred)
+    """
+    # Always work with a list internally for consistency
+    if len(chunks) == 1 and isinstance(chunks[0], list):
+        # Single list argument: translate_content([chunk1, chunk2, ...])
+        batch = chunks[0]
+        return_single = False
+    elif len(chunks) == 1:
+        # Single chunk argument: translate_content(chunk)
+        batch = list(chunks)
+        return_single = True
+    else:
+        # Multiple chunk arguments: translate_content(chunk1, chunk2, ...)
+ batch = list(chunks) + return_single = False + + target_language = kwargs.get("target_language", TARGET_LANGUAGE) + translation_provider_name = kwargs.get("translation_provider", "noop") + primary_key, fallback_providers = _build_provider_plan( + translation_provider_name, kwargs.get("fallback_providers", []) + ) + detection_provider_name = kwargs.get("detection_provider", "langdetect") + confidence_threshold = kwargs.get("confidence_threshold", CONFIDENCE_THRESHOLD) + + # Provider cache for this batch to reduce instantiation overhead + provider_cache: Dict[str, Any] = {} + + # Bundle plan parameters to reduce argument count + plan = (target_language, primary_key, fallback_providers, confidence_threshold, detection_provider_name) + + # Parse concurrency with error handling + try: + max_concurrency = int(os.getenv("COGNEE_TRANSLATION_MAX_CONCURRENCY", "8")) + except (TypeError, ValueError): + logger.warning("Invalid COGNEE_TRANSLATION_MAX_CONCURRENCY; defaulting to 8") + max_concurrency = 8 + if max_concurrency < 1: + logger.warning("COGNEE_TRANSLATION_MAX_CONCURRENCY < 1; clamping to 1") + max_concurrency = 1 + + sem = asyncio.Semaphore(max_concurrency) + async def _wrapped(c): + async with sem: + return await _process_chunk(c, plan, provider_cache) + results = await asyncio.gather(*(_wrapped(c) for c in batch)) + + return results[0] if return_single else results + + +# Initialize providers +register_translation_provider("noop", NoopProvider) +register_translation_provider("langdetect", LangDetectProvider) +# The following providers are temporarily unregistered in this PR per maintainer +# feedback: they require live integration testing, extra dependencies, or +# credentials before they should be merged. Keep the implementation files in +# place (so they can be re-enabled later) and re-enable by uncommenting the +# matching register_translation_provider(...) calls once the provider is +# validated in CI or a local dev environment. +# register_translation_provider("llm", LLMProvider) +# register_translation_provider("google", GoogleTranslateProvider) +# register_translation_provider("azure", AzureTranslateProvider) \ No newline at end of file diff --git a/cognee/tasks/translation/translation_errors.py b/cognee/tasks/translation/translation_errors.py new file mode 100644 index 0000000000..cb7f21fae9 --- /dev/null +++ b/cognee/tasks/translation/translation_errors.py @@ -0,0 +1,31 @@ +class TranslationDependencyError(ImportError): + """Raised when a required translation dependency is missing.""" + +class LangDetectError(TranslationDependencyError): + """LangDetect library required.""" + def __init__(self, message="langdetect is not installed. Please install it with `pip install langdetect`"): + super().__init__(message) + +class GoogleTranslateError(TranslationDependencyError): + """GoogleTrans library required.""" + def __init__(self, message="googletrans is not installed. Please install it with `pip install googletrans==4.0.0-rc1`"): + super().__init__(message) + +class AzureTranslateError(TranslationDependencyError): + """Azure Translate library required.""" + def __init__(self, message="azure-ai-translation-text is not installed. 
Please install it with `pip install azure-ai-translation-text`"):
+        super().__init__(message)
+
+class AzureConfigError(ValueError):
+    """Azure configuration error."""
+    def __init__(self, message="Azure Translate key (AZURE_TRANSLATE_KEY) is required."):
+        super().__init__(message)
+
+class UnknownProviderError(ValueError):
+    """Unknown translation provider error."""
+    def __init__(self, provider_name=None):
+        if provider_name:
+            message = f"Unknown translation provider: {provider_name}."
+        else:
+            message = "Unknown translation provider."
+        super().__init__(message)
diff --git a/cognee/tasks/translation/translation_providers/langdetect_provider.py b/cognee/tasks/translation/translation_providers/langdetect_provider.py
new file mode 100644
index 0000000000..f3e218ffcd
--- /dev/null
+++ b/cognee/tasks/translation/translation_providers/langdetect_provider.py
@@ -0,0 +1,35 @@
+from typing import Optional, Tuple, Any
+
+from cognee.shared.logging_utils import get_logger
+
+from ..translation_providers_enum import TranslationProvider
+from ..translation_errors import LangDetectError
+
+logger = get_logger(__name__)
+
+class LangDetectProvider:
+    """A provider that uses the 'langdetect' library for offline language detection.
+
+    This provider does not support translation.
+    """
+    # Cached at class level so the langdetect import runs once per process.
+    _detector: Any = None
+
+    def __init__(self):
+        if type(self)._detector is None:
+            try:
+                from langdetect import DetectorFactory, detect_langs
+                from langdetect.lang_detect_exception import LangDetectException
+                DetectorFactory.seed = 0
+                type(self)._detector = (detect_langs, LangDetectException)
+            except ImportError as e:
+                raise LangDetectError() from e
+
+    async def detect_language(self, text: str) -> Optional[Tuple[str, float]]:
+        """Detect the language of the provided text using the langdetect library."""
+        detect_langs, LangDetectException = self._detector
+        try:
+            langs = detect_langs(text)
+            if langs:
+                return langs[0].lang, langs[0].prob
+        except LangDetectException:
+            logger.debug("Langdetect failed (text_len=%d)", len(text) if isinstance(text, str) else -1)
+        return None
+
+    async def translate(self, text: str, _target_language: str) -> Optional[Tuple[str, float]]:
+        """This provider does not support translation. 
It returns the original text.""" + return text, 0.0 diff --git a/cognee/tasks/translation/translation_providers/noop_provider.py b/cognee/tasks/translation/translation_providers/noop_provider.py new file mode 100644 index 0000000000..bb826b1b19 --- /dev/null +++ b/cognee/tasks/translation/translation_providers/noop_provider.py @@ -0,0 +1,12 @@ +from typing import Optional, Tuple +from ..translation_providers_enum import TranslationProvider + +class NoopProvider: + """A no-op translation provider that does not perform detection or translation.""" + async def detect_language(self, _text: str) -> Optional[Tuple[str, float]]: + """No-op language detection: intentionally performs no detection and always returns None.""" + return None + + async def translate(self, text: str, _target_language: str) -> Optional[Tuple[str, float]]: + """Return the input text unchanged and a confidence score of 0.0.""" + return text, 0.0 diff --git a/cognee/tasks/translation/translation_providers_enum.py b/cognee/tasks/translation/translation_providers_enum.py new file mode 100644 index 0000000000..106e661c71 --- /dev/null +++ b/cognee/tasks/translation/translation_providers_enum.py @@ -0,0 +1,15 @@ +from enum import Enum +from typing import Protocol, Optional, Tuple + +class TranslationProviderEnum(Enum): + LLM = "llm" + GOOGLE = "google" + AZURE = "azure" + LANGDETECT = "langdetect" + NOOP = "noop" + +class TranslationProvider(Protocol): + async def detect_language(self, text: str) -> Optional[Tuple[str, float]]: + ... + async def translate(self, text: str, target_language: str) -> Optional[Tuple[str, float]]: + ... diff --git a/cognee/tasks/translation/translation_registry.py b/cognee/tasks/translation/translation_registry.py new file mode 100644 index 0000000000..f016bd755c --- /dev/null +++ b/cognee/tasks/translation/translation_registry.py @@ -0,0 +1,41 @@ +from typing import Dict, Type +from .translation_providers_enum import TranslationProvider + +_provider_registry: Dict[str, Type[TranslationProvider]] = {} + +def register_translation_provider(name: str, provider_cls: Type[TranslationProvider]) -> None: + """Register a translation provider under a canonical lowercase key.""" + _provider_registry[name.lower()] = provider_cls + +def get_available_providers() -> list: + """Return a sorted list of available provider keys.""" + return sorted(_provider_registry.keys()) + + +def get_available_detectors() -> list: + """Return a sorted list of registered providers that implement language detection. + + This inspects the registered provider classes and returns those whose class + defines a `detect_language` attribute/method. Detectors are a subset of + providers and may be used independently for detection tasks. 
+ """ + detectors = [name for name, cls in _provider_registry.items() if hasattr(cls, "detect_language")] + return sorted(detectors) + +def get_provider_class(name: str) -> Type[TranslationProvider]: + """Get a provider class by name, or raise KeyError if not found.""" + return _provider_registry[name.lower()] + +def snapshot_registry() -> Dict[str, Type[TranslationProvider]]: + """Return a shallow copy snapshot of the provider registry (for tests).""" + return dict(_provider_registry) + +def restore_registry(snapshot: Dict[str, Type[TranslationProvider]]) -> None: + """Restore the global translation provider registry from a previously captured snapshot.""" + _provider_registry.clear() + _provider_registry.update(snapshot) + +def validate_provider(name: str) -> None: + """Ensure a provider is registered or raise ValueError.""" + if name.lower() not in _provider_registry: + raise ValueError(f"Unknown provider: {name}") diff --git a/cognee/tests/test_translation_providers.py b/cognee/tests/test_translation_providers.py new file mode 100644 index 0000000000..175a56b7cc --- /dev/null +++ b/cognee/tests/test_translation_providers.py @@ -0,0 +1,32 @@ +""" +Basic tests for translation providers and detectors in Cognee. +These tests ensure that all registered providers and detectors can be instantiated and used for simple detection/translation tasks. +""" +import pytest +import asyncio +from cognee.tasks.translation import translate_content, get_available_translators, get_available_detectors + +@pytest.mark.asyncio +@pytest.mark.parametrize("provider", get_available_translators()) +async def test_translation_provider_basic(provider): + # Noop should not translate, others may require API keys + chunk = type("Chunk", (), {"text": "Hello world!", "metadata": {}})() + try: + result = await translate_content(chunk, translation_provider=provider, target_language="fr") + assert hasattr(result, "text") + assert hasattr(result, "metadata") + except Exception as e: + # If provider requires config/API key, skip + pytest.skip(f"Provider '{provider}' not fully configured: {e}") + +@pytest.mark.asyncio +@pytest.mark.parametrize("detector", get_available_detectors()) +async def test_language_detector_basic(detector): + chunk = type("Chunk", (), {"text": "Hello world!", "metadata": {}})() + try: + result = await translate_content(chunk, translation_provider="noop", detection_provider=detector) + assert hasattr(result, "text") + assert hasattr(result, "metadata") + assert "language" in result.metadata + except Exception as e: + pytest.skip(f"Detector '{detector}' not fully configured: {e}") diff --git a/evals/README.md b/evals/README.md index 2ad9d64c2f..61ae885104 100644 --- a/evals/README.md +++ b/evals/README.md @@ -58,9 +58,6 @@ The following charts visualize the benchmark results and performance comparisons A comprehensive comparison of all evaluated systems across multiple metrics, showing Cognee's performance relative to Mem0, Graphiti, and LightRAG. 
diff --git a/evals/README.md b/evals/README.md
index 2ad9d64c2f..61ae885104 100644
--- a/evals/README.md
+++ b/evals/README.md
@@ -58,9 +58,6 @@ The following charts visualize the benchmark results and performance comparisons
 
 A comprehensive comparison of all evaluated systems across multiple metrics, showing Cognee's performance relative to Mem0, Graphiti, and LightRAG.
 
-
-
-
 #### Optimized Cognee Configurations
 
 ![Optimized Cognee Configurations](optimized_cognee_configurations.png)
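Before the full pipeline example below, note that the provider contract itself is small: two async methods, each returning a `(text, confidence)` tuple or `None`. A minimal sketch of calling a provider directly (module path as added by this patch):

```python
import asyncio

from cognee.tasks.translation.translation_providers.noop_provider import NoopProvider

async def demo() -> None:
    provider = NoopProvider()
    # Noop never detects a language...
    assert await provider.detect_language("Hola mundo") is None
    # ...and "translates" by returning the input unchanged with 0.0 confidence.
    text, confidence = await provider.translate("Hola mundo", "en")
    print(text, confidence)  # Hola mundo 0.0

asyncio.run(demo())
```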
diff --git a/examples/python/translation_example.py b/examples/python/translation_example.py
new file mode 100644
index 0000000000..9fa9e59178
--- /dev/null
+++ b/examples/python/translation_example.py
@@ -0,0 +1,91 @@
+import asyncio
+import os
+import cognee
+from cognee.api.v1.search import SearchType
+from cognee.api.v1.cognify.cognify import get_default_tasks_with_translation
+from cognee.modules.pipelines.operations.pipeline import run_pipeline
+from cognee.tasks.translation import get_available_providers
+
+# Prerequisites:
+# 1. Set up your environment with API keys for your chosen translation provider.
+#    - For the LLM provider: LLM_API_KEY
+#    - For Azure: AZURE_TRANSLATE_KEY, AZURE_TRANSLATE_ENDPOINT
+# 2. Specify the translation provider via an environment variable (optional, defaults to "noop"):
+#    COGNEE_TRANSLATION_PROVIDER="llm"  # Or "google", "azure", "langdetect"
+# 3. Install any required libraries for your provider:
+#    - pip install langdetect googletrans==4.0.0rc1 azure-ai-translation-text
+
+async def main():
+    """
+    Demonstrates an end-to-end translation-enabled Cognify workflow using the Cognee SDK.
+
+    Performs three main steps:
+    1. Resets the demo workspace by pruning stored data and system metadata.
+    2. Seeds three multilingual documents, builds translation-enabled Cognify tasks using the
+       provider specified by the COGNEE_TRANSLATION_PROVIDER environment variable (defaults to "noop"),
+       and executes the pipeline to translate and process the documents.
+       - If the selected provider is missing or invalid, the function prints the error and returns early.
+    3. Issues an English search query (using SearchType.INSIGHTS) against the processed index and
+       prints any returned result texts.
+
+    Side effects:
+    - Mutates persistent Cognee state (prune, add, cognify pipeline execution).
+    - Prints status and result messages to stdout.
+
+    Notes:
+    - No return value.
+    - ValueError, ImportError, and RuntimeError are caught and handled by printing an error and exiting the function.
+    """
+    # 1. Set up cognee and add multilingual content
+    print("Setting up demo environment...")
+    await cognee.prune.prune_data()
+    await cognee.prune.prune_system(metadata=True)
+
+    multilingual_texts = [
+        "El procesamiento de lenguaje natural (PLN) es un subcampo de la IA.",
+        "Le traitement automatique du langage naturel (TALN) est un sous-domaine de l'IA.",
+        "Natural language processing (NLP) is a subfield of AI.",
+    ]
+
+    print("Adding multilingual texts...")
+    for text in multilingual_texts:
+        await cognee.add(text)
+    print("Texts added successfully.\n")
+
+    # 2. Run the cognify pipeline with translation enabled
+    provider = os.getenv('COGNEE_TRANSLATION_PROVIDER', 'noop').lower()
+    print(f"Running cognify with translation provider: {provider}")
+
+    if provider not in get_available_providers():
+        print(f"Unknown provider: {provider}. Available: {', '.join(get_available_providers())}")
+        return
+
+    try:
+        # Build translation-enabled tasks and execute the pipeline
+        translation_enabled_tasks = get_default_tasks_with_translation(
+            translation_provider=provider
+        )
+        async for _ in run_pipeline(tasks=translation_enabled_tasks):
+            pass
+        print("Cognify pipeline with translation completed successfully.")
+    except (ValueError, ImportError, RuntimeError) as e:
+        print(f"Error during cognify: {e}")
+        print("Please ensure the selected provider is installed and configured correctly.")
+        return
+
+    # 3. Search for content in English
+    query_text = "Tell me about NLP"
+    print(f"\nSearching for: '{query_text}'")
+
+    # The search should now return results from all documents, as they have been translated.
+    search_results = await cognee.search(query_text, query_type=SearchType.INSIGHTS)
+
+    print("\nSearch Results:")
+    if search_results:
+        for result in search_results:
+            print(f"- {result.text}")
+    else:
+        print("No results found.")
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/pyproject.toml b/pyproject.toml
index bbe3732e92..668f0d83a8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,6 +39,10 @@ dependencies = [
     "langfuse>=2.32.0,<3",
     "filetype>=1.2.0,<2.0.0",
     "aiohttp>=3.11.14,<4.0.0",
+    # httpx/httpcore pinned for compatibility with googletrans (prevents
+    # "module 'httpcore' has no attribute 'SyncHTTPTransport'" errors)
+    "httpx==0.23.3",
+    "httpcore==0.13.7",
     "aiofiles>=23.2.1,<24.0.0",
     "rdflib>=7.1.4,<7.2.0",
     "pypdf>=4.1.0,<7.0.0",
@@ -121,6 +125,11 @@ gui = [
     "qasync>=0.27.1,<0.28",
 ]
 graphiti = ["graphiti-core>=0.7.0,<0.8"]
+translation = [
+    "langdetect>=1.0.9,<2.0.0",
+    "googletrans==4.0.0rc1",
+    "azure-ai-translation-text>=1.0.0,<2.0.0",
+]
 # Note: New s3fs and boto3 versions don't work well together
 # Always use compatible fixed versions of these two dependencies
 aws = ["s3fs[boto3]==2025.3.2"]
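The `translation` extra and the httpx/httpcore pins above can be sanity-checked locally, e.g. after `pip install "cognee[translation]"` (assuming `cognee` is installed from this branch), with a minimal import probe:

```python
# Minimal probe: verifies that the pinned httpx/httpcore versions are in
# effect and that googletrans imports without the
# "module 'httpcore' has no attribute 'SyncHTTPTransport'" error.
import httpcore
import httpx

print("httpx", httpx.__version__)        # expect 0.23.3 per the pin above
print("httpcore", httpcore.__version__)  # expect 0.13.7 per the pin above

from googletrans import Translator  # fails at import time if the pins drift

print("googletrans OK:", type(Translator()).__name__)
```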