Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
from cognee.shared.logging_utils import get_logger
import math
from typing import List, Optional
import numpy as np
import math
import litellm
import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
Expand Down Expand Up @@ -74,20 +75,34 @@ async def embed_text(self, text: List[str]) -> List[List[float]]:
return [data["embedding"] for data in response.data]

except litellm.exceptions.ContextWindowExceededError as error:
if isinstance(text, list):
if len(text) == 1:
parts = [text]
else:
parts = [text[0 : math.ceil(len(text) / 2)], text[math.ceil(len(text) / 2) :]]

parts_futures = [self.embed_text(part) for part in parts]
embeddings = await asyncio.gather(*parts_futures)

all_embeddings = []
for embeddings_part in embeddings:
all_embeddings.extend(embeddings_part)
if isinstance(text, list) and len(text) > 1:
mid = math.ceil(len(text) / 2)
left, right = text[:mid], text[mid:]
left_vecs, right_vecs = await asyncio.gather(
self.embed_text(left),
self.embed_text(right),
)
return left_vecs + right_vecs

# If the caller passed ONE oversize string, split it into two
# overlapping parts so we can process it
if isinstance(text, list) and len(text) == 1:
logger.debug(f"Pooling embeddings of text string with size: {len(text[0])}")
s = text[0]
third = len(s) // 3
# We are using thirds to intentionally have overlap between split parts
# for better embedding calculation
left_part, right_part = s[: third * 2], s[third:]

# Recursively embed the split parts in parallel
(left_vec,), (right_vec,) = await asyncio.gather(
self.embed_text([left_part]),
self.embed_text([right_part]),
)

return all_embeddings
# POOL the two embeddings into one
pooled = (np.array(left_vec) + np.array(right_vec)) / 2
return [pooled.tolist()]

logger.error("Context window exceeded for embedding text: %s", str(error))
raise error
Expand Down
2 changes: 1 addition & 1 deletion cognee/tasks/storage/index_data_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def index_data_points(data_points: list[DataPoint]):
field_name = index_name_and_field[first_occurence + 1 :]
try:
# In case the amount of indexable points is too large we need to send them in batches
batch_size = 1000
batch_size = 100
for i in range(0, len(indexable_points), batch_size):
batch = indexable_points[i : i + batch_size]
await vector_engine.index_data_points(index_name, field_name, batch)
Expand Down
Loading