8 changes: 7 additions & 1 deletion docs/my-website/docs/caching/all_caches.md
@@ -211,7 +211,7 @@ response2 = completion(

Install the redisvl client:
```shell
-pip install redisvl==0.4.1
+pip install redisvl==0.12.1
```

For the hosted version, you can set up your own Redis DB here: https://redis.io/try-free/
@@ -234,6 +234,9 @@ litellm.cache = Cache(
    similarity_threshold=0.8,  # similarity threshold for cache hits; 0 == no similarity required, 1 == exact match, 0.5 == 50% similarity
    ttl=120,
    redis_semantic_cache_embedding_model="text-embedding-ada-002",  # this model is passed to litellm.embedding(); any litellm.embedding() model is supported here
+    embedding_cache_enabled=True,  # enable caching of embeddings
+    embedding_cache_ttl=120,  # TTL for embeddings-cache entries, in seconds
+    embedding_cache_name="litellm-embeddings-cache",  # name for the embeddings cache; default is litellm_redis_semantic_embeddings_cache
)
response1 = completion(
    model="gpt-3.5-turbo",
@@ -603,6 +606,9 @@ def __init__(
    similarity_threshold: Optional[float] = None,
    redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
    redis_semantic_cache_index_name: Optional[str] = None,
+    embedding_cache_enabled: bool = False,
+    embedding_cache_ttl: Optional[int] = None,
+    embedding_cache_name: Optional[str] = None,

    # s3 Bucket, boto3 configuration
    s3_bucket_name: Optional[str] = None,
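Putting the documented knobs together, client-side initialization looks roughly like the following. This is a minimal sketch that follows the parameter names used in `all_caches.md` above; the Redis credentials are placeholders, and note that `litellm/caching/caching.py` in this PR exposes the same knobs on `Cache.__init__` under `redis_semantic_`-prefixed names.

```python
import litellm
from litellm import completion
from litellm.caching.caching import Cache

# Placeholder Redis credentials for a local instance.
litellm.cache = Cache(
    type="redis-semantic",
    host="localhost",
    port="6379",
    password="mypassword",
    similarity_threshold=0.8,
    redis_semantic_cache_embedding_model="text-embedding-ada-002",
    embedding_cache_enabled=True,  # cache prompt embeddings in RedisVL
    embedding_cache_ttl=120,  # embeddings expire after 120 seconds
    embedding_cache_name="litellm-embeddings-cache",
)

# Two semantically similar prompts: the second call can be served from the
# semantic cache, and its embedding lookup from the embeddings cache.
response1 = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
)
response2 = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Tell me the capital city of France"}],
)
```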
9 changes: 9 additions & 0 deletions docs/my-website/docs/proxy/caching.md
@@ -390,6 +390,11 @@ $ litellm --config /path/to/config.yaml

Caching can be enabled by adding the `cache` key to the `config.yaml`.

+Install the redisvl client if it is not already in your requirements.txt:
+```shell
+pip install redisvl==0.12.1
+```

#### Step 1: Add `cache` to the config.yaml
```yaml
model_list:
@@ -410,6 +415,10 @@ litellm_settings:
type: "redis-semantic"
similarity_threshold: 0.8 # similarity threshold for semantic cache
redis_semantic_cache_embedding_model: azure-embedding-model # set this to a model_name set in model_list
redis_semantic_cache_index_name: litellm-redis-semantic-cache-index # OPTIONAL, default is litellm_semantic_cache_index
embedding_cache_enabled: True # OPTIONAL, default is False
embedding_cache_ttl: 120 # OPTIONAL, default is None
embedding_cache_name: litellm-embeddings-cache # OPTIONAL, default is litellm_redis_semantic_embeddings_cache
```

#### Step 2: Add Redis Credentials to .env
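Once the proxy is running with this config, the semantic cache can be exercised from any OpenAI-compatible client. A sketch, assuming the proxy listens on the default `http://0.0.0.0:4000` and `sk-1234` is a valid proxy key:

```python
import openai

# Point the standard OpenAI client at the LiteLLM proxy (placeholder key/URL).
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# Semantically similar prompts should hit the redis-semantic cache on the
# second request; with embedding_cache_enabled, the prompt embedding itself
# is also served from the RedisVL embeddings cache on repeat lookups.
for prompt in [
    "What is the capital of France?",
    "Tell me the capital city of France",
]:
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
    )
    print(response.choices[0].message.content)
```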
10 changes: 8 additions & 2 deletions litellm/caching/caching.py
@@ -98,6 +98,10 @@ def __init__(
        gcs_path: Optional[str] = None,
        redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
        redis_semantic_cache_index_name: Optional[str] = None,
+        # Embeddings cache (RedisVL) options for redis-semantic
+        redis_semantic_embedding_cache_enabled: bool = False,
+        redis_semantic_embedding_cache_ttl: Optional[int] = None,
+        redis_semantic_embedding_cache_name: Optional[str] = None,
        redis_flush_size: Optional[int] = None,
        redis_startup_nodes: Optional[List] = None,
        disk_cache_dir: Optional[str] = None,
@@ -111,8 +115,7 @@ def __init__(
        gcp_ssl_ca_certs: Optional[str] = None,
        **kwargs,
    ):
-        """
-        Initializes the cache based on the given type.
+        """Initializes the cache based on the given type.

        Args:
            type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
@@ -195,6 +198,9 @@ def __init__(
                similarity_threshold=similarity_threshold,
                embedding_model=redis_semantic_cache_embedding_model,
                index_name=redis_semantic_cache_index_name,
+                embedding_cache_enabled=redis_semantic_embedding_cache_enabled,
+                embedding_cache_ttl=redis_semantic_embedding_cache_ttl,
+                embedding_cache_name=redis_semantic_embedding_cache_name,
                **kwargs,
            )
        elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
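For reference, a construction sketch showing how the new `Cache.__init__` knobs reach `RedisSemanticCache` (Redis credentials are placeholders); the `redis_semantic_`-prefixed names are forwarded as `embedding_cache_enabled` / `embedding_cache_ttl` / `embedding_cache_name`, per the diff above:

```python
from litellm.caching.caching import Cache

# Cache.__init__ accepts the prefixed knobs and forwards them to
# RedisSemanticCache under the shorter embedding_cache_* names.
cache = Cache(
    type="redis-semantic",
    host="localhost",  # placeholder
    port="6379",
    password="mypassword",  # placeholder
    similarity_threshold=0.8,
    redis_semantic_cache_embedding_model="text-embedding-ada-002",
    redis_semantic_embedding_cache_enabled=True,
    redis_semantic_embedding_cache_ttl=3600,
    redis_semantic_embedding_cache_name="my-embeddings-cache",
)
```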
115 changes: 93 additions & 22 deletions litellm/caching/redis_semantic_cache.py
@@ -35,6 +35,7 @@ class RedisSemanticCache(BaseCache):
"""

DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index"
DEFAULT_REDIS_EMBEDDINGS_CACHE_NAME: str = "litellm_redis_semantic_embeddings_cache"

def __init__(
self,
@@ -45,10 +46,13 @@ def __init__(
        similarity_threshold: Optional[float] = None,
        embedding_model: str = "text-embedding-ada-002",
        index_name: Optional[str] = None,
+        # Embeddings cache (RedisVL) configuration
+        embedding_cache_enabled: bool = False,
+        embedding_cache_ttl: Optional[int] = None,
+        embedding_cache_name: Optional[str] = None,
        **kwargs,
    ):
-        """
-        Initialize the Redis Semantic Cache.
+        """Initialize the Redis Semantic Cache.

        Args:
            host: Redis host address
@@ -59,13 +63,18 @@
                where 1.0 requires exact matches and 0.0 accepts any match
            embedding_model: Model to use for generating embeddings
            index_name: Name for the Redis index
+            embedding_cache_enabled: Whether to enable RedisVL EmbeddingsCache
+            embedding_cache_ttl: Default TTL for embeddings cache entries in seconds
+            embedding_cache_name: Optional name prefix for embeddings cache keys
            ttl: Default time-to-live for cache entries in seconds
            **kwargs: Additional arguments passed to the Redis client

        Raises:
            Exception: If similarity_threshold is not provided or required Redis
                connection information is missing
        """
+        # Import RedisVL components lazily to avoid hard dependency when not used
+        from redisvl.extensions.cache.embeddings import EmbeddingsCache
        from redisvl.extensions.llmcache import SemanticCache
        from redisvl.utils.vectorize import CustomTextVectorizer

@@ -87,6 +96,15 @@ def __init__(
        self.distance_threshold = 1 - similarity_threshold
        self.embedding_model = embedding_model

+        # Embeddings cache configuration
+        self.embedding_cache_enabled: bool = embedding_cache_enabled
+        self.embedding_cache_ttl: Optional[int] = embedding_cache_ttl
+        if embedding_cache_name is None:
+            self.embedding_cache_name = self.DEFAULT_REDIS_EMBEDDINGS_CACHE_NAME
+        else:
+            self.embedding_cache_name = embedding_cache_name
+        self._embeddings_cache: Optional[EmbeddingsCache] = None
+
        # Set up Redis connection
        if redis_url is None:
            try:
@@ -106,8 +124,26 @@

print_verbose(f"Redis semantic-cache redis_url: {redis_url}")

# Initialize the embeddings cache if enabled
if self.embedding_cache_enabled:
try:
self._embeddings_cache = EmbeddingsCache(
name=self.embedding_cache_name,
redis_url=redis_url,
ttl=self.embedding_cache_ttl,
)
except Exception as e: # pragma: no cover - defensive, treat as non-fatal
print_verbose(
f"Redis semantic-cache: failed to initialize EmbeddingsCache, "
f"disabling embedding cache. Error: {str(e)}"
)
self.embedding_cache_enabled = False
self._embeddings_cache = None

# Initialize the Redis vectorizer and cache
cache_vectorizer = CustomTextVectorizer(self._get_embedding)
cache_vectorizer = CustomTextVectorizer(
self._get_embedding, cache=self._embeddings_cache
)

self.llmcache = SemanticCache(
name=index_name,
@@ -117,6 +153,14 @@
            overwrite=False,
        )

+    @property
+    def embeddings_cache(self):
+        """Expose the underlying EmbeddingsCache instance (if any).
+
+        This is primarily for tests and potential reuse by other components.
+        """
+        return self._embeddings_cache
+
    def _get_ttl(self, **kwargs) -> Optional[int]:
        """
        Get the TTL (time-to-live) value for cache entries.
@@ -133,16 +177,16 @@ def _get_ttl(self, **kwargs) -> Optional[int]:
        return ttl

    def _get_embedding(self, prompt: str) -> List[float]:
-        """
-        Generate an embedding vector for the given prompt using the configured embedding model.
-
-        Args:
-            prompt: The text to generate an embedding for
-
-        Returns:
-            List[float]: The embedding vector
-        """
+        """Generate an embedding vector for the given prompt.
+
+        This is the sync embedding function used by RedisVL's CustomTextVectorizer.
+        It deliberately bypasses LiteLLM's high-level cache (cache={"no-store": True})
+        because EmbeddingsCache is responsible for caching at this layer.
+        """
-        # Create an embedding from prompt
+        # NOTE: EmbeddingsCache is already wired into CustomTextVectorizer via
+        # the `cache` parameter in __init__, so this method only needs to
+        # compute the embedding when there is an EmbeddingsCache miss.
        embedding_response = cast(
            EmbeddingResponse,
            litellm.embedding(
@@ -151,8 +195,7 @@ def _get_embedding(self, prompt: str) -> List[float]:
cache={"no-store": True, "no-cache": True},
),
)
embedding = embedding_response["data"][0]["embedding"]
return embedding
return embedding_response["data"][0]["embedding"]

def _get_cache_logic(self, cached_response: Any) -> Any:
"""
@@ -269,16 +312,28 @@ def get_cache(self, key: str, **kwargs) -> Any:
print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}")

async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
"""Asynchronously generate an embedding for the given prompt.

This is used by the async semantic cache paths. It first checks the
RedisVL EmbeddingsCache (if enabled) before falling back to the
underlying embedding provider.
"""
Asynchronously generate an embedding for the given prompt.

Args:
prompt: The text to generate an embedding for
**kwargs: Additional arguments that may contain metadata
# Fast path: check embeddings cache if available
if self.embedding_cache_enabled and self._embeddings_cache is not None:
try:
cached = await self._embeddings_cache.aget(
text=prompt,
model_name=self.embedding_model,
)
if cached is not None:
return cached["embedding"] # type: ignore[index]
except Exception as e: # pragma: no cover - defensive
print_verbose(
f"Redis semantic-cache: EmbeddingsCache.aget failed, "
f"falling back to provider. Error: {str(e)}"
)

Returns:
List[float]: The embedding vector
"""
from litellm.proxy.proxy_server import llm_model_list, llm_router

# Route the embedding request through the proxy if appropriate
@@ -310,8 +365,24 @@ async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
cache={"no-store": True, "no-cache": True},
)

# Extract and return the embedding vector
return embedding_response["data"][0]["embedding"]
embedding_vec = embedding_response["data"][0]["embedding"]

# Store in embeddings cache (best effort)
if self.embedding_cache_enabled and self._embeddings_cache is not None:
try:
await self._embeddings_cache.aset(
text=prompt,
model_name=self.embedding_model,
embedding=embedding_vec,
ttl=self.embedding_cache_ttl,
)
except Exception as e: # pragma: no cover - defensive
print_verbose(
"Redis semantic-cache: EmbeddingsCache.aset failed; "
f"continuing without cache. Error: {str(e)}"
)

return embedding_vec
except Exception as e:
print_verbose(f"Error generating async embedding: {str(e)}")
raise ValueError(f"Failed to generate embedding: {str(e)}") from e
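The read-through pattern that the sync and async paths implement reduces to the following standalone sketch. It uses only the `EmbeddingsCache` calls exercised in this diff (`aget`/`aset`); the Redis URL and the stand-in vector are assumptions:

```python
import asyncio

from redisvl.extensions.cache.embeddings import EmbeddingsCache


async def get_embedding_with_cache(text: str, model_name: str) -> list:
    cache = EmbeddingsCache(
        name="litellm_redis_semantic_embeddings_cache",
        redis_url="redis://localhost:6379",  # placeholder
        ttl=120,
    )

    # 1. Check the cache first.
    cached = await cache.aget(text=text, model_name=model_name)
    if cached is not None:
        return cached["embedding"]

    # 2. On a miss, compute the embedding (stand-in vector here) ...
    embedding = [0.1, 0.2, 0.3]

    # 3. ... and write it back, best effort, with the configured TTL.
    await cache.aset(text=text, model_name=model_name, embedding=embedding, ttl=120)
    return embedding


asyncio.run(get_embedding_with_cache("hello world", "text-embedding-ada-002"))
```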
87 changes: 87 additions & 0 deletions tests/test_litellm/caching/test_redis_semantic_cache.py
@@ -144,3 +144,90 @@ async def test_redis_semantic_cache_async_get_cache(monkeypatch):
    # Verify methods were called
    redis_semantic_cache._get_async_embedding.assert_called_once()
    redis_semantic_cache.llmcache.acheck.assert_called_once()



+def test_redis_semantic_cache_embeddings_cache_enabled(monkeypatch):
+    # Create proper mocks for redisvl modules
+    mock_embeddings_cache_instance = MagicMock()
+    mock_embeddings_cache_class = MagicMock(return_value=mock_embeddings_cache_instance)
+    semantic_cache_mock = MagicMock()
+    custom_vectorizer_mock = MagicMock()
+
+    with patch.dict(
+        "sys.modules",
+        {
+            "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
+            "redisvl.utils.vectorize": MagicMock(
+                CustomTextVectorizer=custom_vectorizer_mock
+            ),
+            "redisvl.extensions.cache.embeddings": MagicMock(
+                EmbeddingsCache=mock_embeddings_cache_class
+            ),
+        },
+    ):
+        from litellm.caching.redis_semantic_cache import RedisSemanticCache
+
+        # Set environment variables
+        monkeypatch.setenv("REDIS_HOST", "localhost")
+        monkeypatch.setenv("REDIS_PORT", "6379")
+        monkeypatch.setenv("REDIS_PASSWORD", "test_password")
+
+        cache = RedisSemanticCache(
+            similarity_threshold=0.8,
+            embedding_cache_enabled=True,
+            embedding_cache_name="test_embed_cache",
+            embedding_cache_ttl=123,
+        )
+
+        # Ensure embeddings cache is initialized when enabled
+        assert cache.embedding_cache_enabled is True
+        assert cache.embeddings_cache is not None
+        assert cache.embeddings_cache is mock_embeddings_cache_instance
+
+        # Verify EmbeddingsCache was called with correct parameters
+        mock_embeddings_cache_class.assert_called_once_with(
+            name="test_embed_cache",
+            redis_url="redis://:test_password@localhost:6379",
+            ttl=123,
+        )


+@pytest.mark.asyncio
+async def test_redis_semantic_cache_async_embedding_uses_cache(monkeypatch):
+    # Patch redisvl modules and EmbeddingsCache class specifically
+    embeddings_cache_mock_cls = MagicMock()
+    embeddings_cache_instance = AsyncMock()
+    embeddings_cache_mock_cls.return_value = embeddings_cache_instance
+
+    with patch.dict(
+        "sys.modules",
+        {
+            "redisvl.extensions.llmcache": MagicMock(),
+            "redisvl.utils.vectorize": MagicMock(),
+            "redisvl.extensions.cache.embeddings": MagicMock(
+                EmbeddingsCache=embeddings_cache_mock_cls
+            ),
+        },
+    ):
+        from litellm.caching.redis_semantic_cache import RedisSemanticCache
+
+        # Set environment variables
+        monkeypatch.setenv("REDIS_HOST", "localhost")
+        monkeypatch.setenv("REDIS_PORT", "6379")
+        monkeypatch.setenv("REDIS_PASSWORD", "test_password")
+
+        # Embeddings cache returns a cached vector
+        embeddings_cache_instance.aget.return_value = {"embedding": [0.1, 0.2, 0.3]}
+
+        cache = RedisSemanticCache(
+            similarity_threshold=0.8,
+            embedding_cache_enabled=True,
+        )
+
+        # Call internal async embedding helper
+        result = await cache._get_async_embedding("hello world")
+
+        # Should have used the embeddings cache and not fallen back to provider
+        embeddings_cache_instance.aget.assert_awaited_once()
+        assert result == [0.1, 0.2, 0.3]
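A natural companion test, not part of this PR, would pin down the disabled path. A sketch reusing the same mocking pattern as the tests above:

```python
def test_redis_semantic_cache_embeddings_cache_disabled(monkeypatch):
    # Hypothetical companion test: with the flag off (the default),
    # no EmbeddingsCache should be constructed at all.
    mock_embeddings_cache_class = MagicMock()

    with patch.dict(
        "sys.modules",
        {
            "redisvl.extensions.llmcache": MagicMock(),
            "redisvl.utils.vectorize": MagicMock(),
            "redisvl.extensions.cache.embeddings": MagicMock(
                EmbeddingsCache=mock_embeddings_cache_class
            ),
        },
    ):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        monkeypatch.setenv("REDIS_HOST", "localhost")
        monkeypatch.setenv("REDIS_PORT", "6379")
        monkeypatch.setenv("REDIS_PASSWORD", "test_password")

        cache = RedisSemanticCache(similarity_threshold=0.8)

        assert cache.embedding_cache_enabled is False
        assert cache.embeddings_cache is None
        mock_embeddings_cache_class.assert_not_called()
```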