diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index b5d780d60f..dce12b318f 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -3,6 +3,7 @@ import math from typing import List, Optional import litellm +import os from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine litellm.set_verbose = False @@ -14,6 +15,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): api_version: str model: str dimensions: int + mock:bool def __init__( self, @@ -29,6 +31,11 @@ def __init__( self.model = model self.dimensions = dimensions + enable_mocking = os.getenv("MOCK_EMBEDDING", "false") + if isinstance(enable_mocking, bool): + enable_mocking= str(enable_mocking).lower() + self.mock = enable_mocking in ("true", "1", "yes") + MAX_RETRIES = 5 retry_count = 0 @@ -38,17 +45,26 @@ async def exponential_backoff(attempt): await asyncio.sleep(wait_time) try: - response = await litellm.aembedding( - self.model, - input = text, - api_key = self.api_key, - api_base = self.endpoint, - api_version = self.api_version - ) - - self.retry_count = 0 - - return [data["embedding"] for data in response.data] + if self.mock: + response = { + "data": [{"embedding": [0.0] * self.dimensions} for _ in text] + } + + self.retry_count = 0 + + return [data["embedding"] for data in response["data"]] + else: + response = await litellm.aembedding( + self.model, + input = text, + api_key = self.api_key, + api_base = self.endpoint, + api_version = self.api_version + ) + + self.retry_count = 0 + + return [data["embedding"] for data in response.data] except litellm.exceptions.ContextWindowExceededError as error: if isinstance(text, list): diff --git a/cognee/tasks/repo_processor/get_repo_file_dependencies.py b/cognee/tasks/repo_processor/get_repo_file_dependencies.py index 221af6cf68..b828707966 100644 --- a/cognee/tasks/repo_processor/get_repo_file_dependencies.py +++ b/cognee/tasks/repo_processor/get_repo_file_dependencies.py @@ -73,7 +73,7 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non yield repo - with ProcessPoolExecutor(max_workers = 12) as executor: + with ProcessPoolExecutor() as executor: loop = asyncio.get_event_loop() tasks = [