Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added support for multimodal embeddings, with tests
  • Loading branch information
jamesbraza committed Oct 9, 2025
commit 1fc5a0a0ce51511e0998339e5afa779fd9e700fe
46 changes: 36 additions & 10 deletions packages/lmi/src/lmi/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,37 @@
from lmi.cost_tracker import track_costs
from lmi.llms import PassThroughRouter
from lmi.rate_limiter import GLOBAL_LIMITER
from lmi.utils import get_litellm_retrying_config
from lmi.utils import get_litellm_retrying_config, is_encoded_image

# Flat token estimate that estimate_tokens charges for each image, whether it
# arrives as an encoded-image string or as an OpenAI-style "image_url" part.
# NOTE(review): 85 presumably mirrors OpenAI's low-detail image token cost -- confirm.
URL_ENCODED_IMAGE_TOKEN_ESTIMATE = 85 # tokens


def estimate_tokens(
    document: str
    | list[str]
    | list[litellm.ChatCompletionImageObject]
    | list[litellm.types.llms.vertex_ai.PartType],
) -> float:
    """Estimate how many tokens a document consumes, for rate limiting.

    Args:
        document: Either a single string (plain text, or an image as decided
            by is_encoded_image), or a list of parts. A part may itself be a
            string, a mapping whose "type" is "image_url" (OpenAI format), or
            a mapping with a "text" key (Gemini text format).

    Returns:
        The estimated token count. Text costs its character length divided by
        CHARACTERS_PER_TOKEN_ASSUMPTION, and each image costs a flat
        URL_ENCODED_IMAGE_TOKEN_ESTIMATE.
    """
    # Single string: either text or a data URL
    if isinstance(document, str):
        if is_encoded_image(document):
            return URL_ENCODED_IMAGE_TOKEN_ESTIMATE
        return len(document) / CHARACTERS_PER_TOKEN_ASSUMPTION

    # Multimodal or batched content: accumulate a cost per part
    total = 0.0
    for piece in document:
        if isinstance(piece, str):
            # One entry of a batch of text or data URLs
            total += estimate_tokens(piece)
            continue
        if piece.get("type") == "image_url":
            # OpenAI image part
            total += URL_ENCODED_IMAGE_TOKEN_ESTIMATE
            continue
        if "text" in piece:
            # Gemini text part -- https://ai.google.dev/api#text-only-prompt
            total += len(piece["text"]) / CHARACTERS_PER_TOKEN_ASSUMPTION  # type: ignore[typeddict-item]
        # NOTE(review): a part matching none of the formats above adds nothing
        # to the estimate -- confirm that is intended for other part kinds.
    return total


class EmbeddingModes(StrEnum):
Expand All @@ -39,7 +69,7 @@ def set_mode(self, mode: EmbeddingModes) -> None:

@abstractmethod
async def embed_documents(self, texts: list[str]) -> list[list[float]]:
pass
"""Embed a list of documents."""

async def embed_document(self, text: str) -> list[float]:
    """Embed one document by delegating to embed_documents with a one-item batch.

    Args:
        text: The document to embed.

    Returns:
        The first (and only requested) embedding from embed_documents.
    """
    single_item_batch = [text]
    embeddings = await self.embed_documents(single_item_batch)
    return embeddings[0]
Expand Down Expand Up @@ -138,7 +168,7 @@ def _truncate_if_large(self, texts: list[str]) -> list[str]:
# heuristic about ratio of tokens to characters
conservative_char_token_ratio = 3
maybe_too_large = max_tokens * conservative_char_token_ratio
if any(len(t) > maybe_too_large for t in texts):
if any(len(t) > maybe_too_large for t in texts if not is_encoded_image(t)):
try:
enct = tiktoken.encoding_for_model("cl100k_base")
enc_batch = enct.encode_ordinary_batch(texts)
Expand All @@ -154,16 +184,12 @@ async def embed_documents(self, texts: list[str]) -> list[list[float]]:
N = len(texts)
embeddings = []
for i in range(0, N, batch_size):
await self.check_rate_limit(
sum(
len(t) / CHARACTERS_PER_TOKEN_ASSUMPTION
for t in texts[i : i + batch_size]
)
)
batch = texts[i : i + batch_size]
await self.check_rate_limit(sum(estimate_tokens(t) for t in batch))

response = await track_costs(self.router.aembedding)(
model=self.name,
input=texts[i : i + batch_size],
input=batch,
dimensions=self.ndim,
**self.config.get("kwargs", {}),
)
Expand Down
76 changes: 75 additions & 1 deletion packages/lmi/tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import litellm
import pytest
import tiktoken
from litellm.caching import Cache, InMemoryCache
from pytest_subtests import SubTests

Expand All @@ -15,8 +16,34 @@
SentenceTransformerEmbeddingModel,
SparseEmbeddingModel,
embedding_model_factory,
estimate_tokens,
)
from lmi.utils import VCR_DEFAULT_MATCH_ON
from lmi.utils import VCR_DEFAULT_MATCH_ON, encode_image_as_url


def test_estimate_tokens(subtests: SubTests, png_image: bytes) -> None:
    """Check estimate_tokens on plain text and on a mixed text + image batch.

    Args:
        subtests: Fixture used to report each scenario as its own subtest.
        png_image: Raw PNG bytes, encoded into a URL for the multimodal case.
    """
    with subtests.test(msg="text only"):
        text_only = "Hello world"
        text_only_estimated_token_count = estimate_tokens(text_only)
        # 11 characters yielding 2.75 implies 4 characters per token
        assert text_only_estimated_token_count == 2.75, (
            "Expected a reasonable token estimate"
        )
        # Sanity check the heuristic against a real tokenizer
        text_only_actual_token_count = len(
            tiktoken.get_encoding("cl100k_base").encode(text_only)
        )
        assert text_only_estimated_token_count == pytest.approx(
            text_only_actual_token_count, abs=1
        ), "Estimation should be within one token of what tiktoken reports"

    with subtests.test(msg="multimodal"):  # Text + image
        multimodal = [
            "What is in this image?",
            encode_image_as_url(image_type="png", image_data=png_image),
        ]
        # 22 characters of text (5.5) plus the flat per-image estimate (85)
        assert estimate_tokens(multimodal) == 90.5, (
            "Expected a reasonable token estimate"
        )


class TestLiteLLMEmbeddingModel:
Expand Down Expand Up @@ -231,6 +258,53 @@ async def test_router_usage(
# Confirm use of the sentinel timeout in the Router's model_list or pass through
assert mock_aembedding.call_args.kwargs["timeout"] == self.SENTINEL_TIMEOUT

@pytest.mark.asyncio
async def test_multimodal_embedding(
    self, subtests: SubTests, png_image_gcs: str
) -> None:
    """Exercise LiteLLMEmbeddingModel against a Vertex AI multimodal embedding model.

    Covers text-only and image-only inputs, text/image mixing in one call,
    and batching with a batch size of one.

    Args:
        subtests: Fixture used to report each scenario as its own subtest.
        png_image_gcs: String reference to a PNG image, passed as a document.
            NOTE(review): presumably a GCS URI -- confirm against the fixture.
    """
    model_name = f"{litellm.LlmProviders.VERTEX_AI.value}/multimodalembedding@001"
    multimodal_model = LiteLLMEmbeddingModel(name=model_name)
    expected_dim = 1408

    def check_embedding(vector: list[float]) -> None:
        # Every embedding must have the expected width and hold only floats
        assert len(vector) == expected_dim
        assert all(isinstance(value, float) for value in vector)

    with subtests.test(msg="text or image only"):
        embedding_text_only = await multimodal_model.embed_document("Some text")
        check_embedding(embedding_text_only)

        embedding_image_only = await multimodal_model.embed_document(png_image_gcs)
        check_embedding(embedding_image_only)

        assert embedding_image_only != embedding_text_only

    with subtests.test(msg="text and image mixing"):
        # Unpacking a 1-tuple asserts that two inputs yield one embedding
        text_and_image = ["What is in this image?", png_image_gcs]
        (embedding_image_text,) = await multimodal_model.embed_documents(
            text_and_image
        )
        check_embedding(embedding_image_text)

        image_pair = [png_image_gcs, png_image_gcs]
        (embedding_two_images,) = await multimodal_model.embed_documents(image_pair)
        check_embedding(embedding_two_images)

        assert embedding_image_text != embedding_two_images

    with subtests.test(msg="batching"):
        # With one document per batch, each input gets its own embedding
        multimodal_model.config["batch_size"] = 1
        embeddings = await multimodal_model.embed_documents([
            "Some text",
            png_image_gcs,
        ])
        assert len(embeddings) == 2
        for embedding in embeddings:
            check_embedding(embedding)


@pytest.mark.asyncio
async def test_sparse_embedding_model(subtests: SubTests):
Expand Down