@@ -3,6 +3,7 @@
 import math
 from typing import List, Optional
 import litellm
+import os
 from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
 
 litellm.set_verbose = False
@@ -14,6 +15,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
     api_version: str
     model: str
     dimensions: int
+    mock: bool
 
     def __init__(
         self,
@@ -28,6 +30,7 @@ def __init__(
         self.api_version = api_version
         self.model = model
         self.dimensions = dimensions
+        self.mock = os.getenv("MOCK_EMBEDDING", "false").lower() in ("true", "1", "yes")
 
     MAX_RETRIES = 5
     retry_count = 0
@@ -38,17 +41,26 @@ async def exponential_backoff(attempt):
             await asyncio.sleep(wait_time)
 
         try:
-            response = await litellm.aembedding(
-                self.model,
-                input = text,
-                api_key = self.api_key,
-                api_base = self.endpoint,
-                api_version = self.api_version
-            )
-
-            self.retry_count = 0
-
-            return [data["embedding"] for data in response.data]
+            if self.mock:
+                response = {
+                    "data": [{"embedding": [0.0] * self.dimensions} for _ in text]
+                }
+
+                self.retry_count = 0
+
+                return [data["embedding"] for data in response["data"]]
+            else:
+                response = await litellm.aembedding(
+                    self.model,
+                    input = text,
+                    api_key = self.api_key,
+                    api_base = self.endpoint,
+                    api_version = self.api_version
+                )
+
+                self.retry_count = 0
+
+                return [data["embedding"] for data in response.data]
 
         except litellm.exceptions.ContextWindowExceededError as error:
             if isinstance(text, list):
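With mock mode enabled, the engine skips the LiteLLM call entirely and returns a zero vector of the configured dimensionality for each input, which makes tests and CI runs deterministic and network-free. A minimal usage sketch; the import path, constructor arguments, and the embedding method name (embed_text) below are assumptions inferred from this diff, not a confirmed API:

import asyncio
import os

# The flag is read once in __init__, so set it before constructing the engine.
os.environ["MOCK_EMBEDDING"] = "true"

# Hypothetical module path, inferred from the EmbeddingEngine import above.
from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import (
    LiteLLMEmbeddingEngine,
)

# Hypothetical arguments; the real __init__ also assigns api_key, endpoint,
# and api_version.
engine = LiteLLMEmbeddingEngine(model="text-embedding-3-large", dimensions=3072)

# No network call is made: each input maps to [0.0] * dimensions.
vectors = asyncio.run(engine.embed_text(["alpha", "beta"]))
assert len(vectors) == 2
assert all(len(v) == engine.dimensions for v in vectors)

The truthiness check accepts "true", "1", or "yes" (case-insensitive); any other value, including an unset variable, leaves mock mode off.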
cognee/tasks/repo_processor/get_repo_file_dependencies.py (2 changes: 1 addition & 1 deletion)
@@ -73,7 +73,7 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, None]:
 
     yield repo
 
-    with ProcessPoolExecutor(max_workers = 12) as executor:
+    with ProcessPoolExecutor() as executor:
         loop = asyncio.get_event_loop()
 
         tasks = [
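Without an explicit max_workers, ProcessPoolExecutor sizes its pool to the host machine (os.cpu_count() workers on most CPython versions), so a small CI runner is no longer forced to spawn 12 processes. A quick sketch of the default behavior, for illustration only:

import os
from concurrent.futures import ProcessPoolExecutor

def square(n: int) -> int:
    return n * n

if __name__ == "__main__":
    # No max_workers argument: the pool defaults to roughly os.cpu_count()
    # workers instead of a hard-coded 12.
    with ProcessPoolExecutor() as executor:
        print("cpus:", os.cpu_count())
        print(list(executor.map(square, range(8))))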