diff --git a/docs/my-website/docs/caching/all_caches.md b/docs/my-website/docs/caching/all_caches.md
index 0548c331f805..3889e7369dac 100644
--- a/docs/my-website/docs/caching/all_caches.md
+++ b/docs/my-website/docs/caching/all_caches.md
@@ -211,7 +211,7 @@ response2 = completion(
 
 Install redisvl client
 ```shell
-pip install redisvl==0.4.1
+pip install redisvl==0.12.1
 ```
 
 For the hosted version you can setup your own Redis DB here: https://redis.io/try-free/
@@ -234,6 +234,9 @@ litellm.cache = Cache(
     similarity_threshold=0.8, # similarity threshold for cache hits, 0 == no similarity, 1 = exact matches, 0.5 == 50% similarity
     ttl=120,
     redis_semantic_cache_embedding_model="text-embedding-ada-002", # this model is passed to litellm.embedding(), any litellm.embedding() model is supported here
+    redis_semantic_embedding_cache_enabled=True, # enable caching of the computed embeddings themselves
+    redis_semantic_embedding_cache_ttl=120, # TTL (in seconds) for cached embeddings
+    redis_semantic_embedding_cache_name="litellm-embeddings-cache", # name for the embeddings cache; default is litellm_redis_semantic_embeddings_cache
 )
 response1 = completion(
     model="gpt-3.5-turbo",
@@ -603,6 +606,9 @@ def __init__(
     similarity_threshold: Optional[float] = None,
     redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
     redis_semantic_cache_index_name: Optional[str] = None,
+    redis_semantic_embedding_cache_enabled: bool = False,
+    redis_semantic_embedding_cache_ttl: Optional[int] = None,
+    redis_semantic_embedding_cache_name: Optional[str] = None,
 
     # s3 Bucket, boto3 configuration
     s3_bucket_name: Optional[str] = None,
diff --git a/docs/my-website/docs/proxy/caching.md b/docs/my-website/docs/proxy/caching.md
index 6da977c8b05c..67e35abac0ca 100644
--- a/docs/my-website/docs/proxy/caching.md
+++ b/docs/my-website/docs/proxy/caching.md
@@ -390,6 +390,11 @@ $ litellm --config /path/to/config.yaml
 
 Caching can be enabled by adding the `cache` key in the `config.yaml`
 
+Install the redisvl client if it is not already in your requirements.txt
+```shell
+pip install redisvl==0.12.1
+```
+
 #### Step 1: Add `cache` to the config.yaml
 ```yaml
 model_list:
@@ -410,6 +415,10 @@ litellm_settings:
     type: "redis-semantic"
     similarity_threshold: 0.8 # similarity threshold for semantic cache
     redis_semantic_cache_embedding_model: azure-embedding-model # set this to a model_name set in model_list
+    redis_semantic_cache_index_name: litellm-redis-semantic-cache-index # OPTIONAL, default is litellm_semantic_cache_index
+    redis_semantic_embedding_cache_enabled: True # OPTIONAL, default is False
+    redis_semantic_embedding_cache_ttl: 120 # OPTIONAL, default is None
+    redis_semantic_embedding_cache_name: litellm-embeddings-cache # OPTIONAL, default is litellm_redis_semantic_embeddings_cache
 ```
 
 #### Step 2: Add Redis Credentials to .env
diff --git a/litellm/caching/caching.py b/litellm/caching/caching.py
index 82fc37e0cb48..3679a5818130 100644
--- a/litellm/caching/caching.py
+++ b/litellm/caching/caching.py
@@ -98,6 +98,10 @@ def __init__(
         gcs_path: Optional[str] = None,
         redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
         redis_semantic_cache_index_name: Optional[str] = None,
+        # Embeddings cache (RedisVL) options for redis-semantic
+        redis_semantic_embedding_cache_enabled: bool = False,
+        redis_semantic_embedding_cache_ttl: Optional[int] = None,
+        redis_semantic_embedding_cache_name: Optional[str] = None,
         redis_flush_size: Optional[int] = None,
         redis_startup_nodes: Optional[List] = None,
         disk_cache_dir: Optional[str] = None,
@@ -111,8 +115,7 @@ def __init__(
         gcp_ssl_ca_certs: Optional[str] = None,
         **kwargs,
     ):
-        """
-        Initializes the cache based on the given type.
+ """Initializes the cache based on the given type. Args: type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local". @@ -195,6 +198,9 @@ def __init__( similarity_threshold=similarity_threshold, embedding_model=redis_semantic_cache_embedding_model, index_name=redis_semantic_cache_index_name, + embedding_cache_enabled=redis_semantic_embedding_cache_enabled, + embedding_cache_ttl=redis_semantic_embedding_cache_ttl, + embedding_cache_name=redis_semantic_embedding_cache_name, **kwargs, ) elif type == LiteLLMCacheType.QDRANT_SEMANTIC: diff --git a/litellm/caching/redis_semantic_cache.py b/litellm/caching/redis_semantic_cache.py index c76f27377d8e..e697c8c52697 100644 --- a/litellm/caching/redis_semantic_cache.py +++ b/litellm/caching/redis_semantic_cache.py @@ -35,6 +35,7 @@ class RedisSemanticCache(BaseCache): """ DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index" + DEFAULT_REDIS_EMBEDDINGS_CACHE_NAME: str = "litellm_redis_semantic_embeddings_cache" def __init__( self, @@ -45,10 +46,13 @@ def __init__( similarity_threshold: Optional[float] = None, embedding_model: str = "text-embedding-ada-002", index_name: Optional[str] = None, + # Embeddings cache (RedisVL) configuration + embedding_cache_enabled: bool = False, + embedding_cache_ttl: Optional[int] = None, + embedding_cache_name: Optional[str] = None, **kwargs, ): - """ - Initialize the Redis Semantic Cache. + """Initialize the Redis Semantic Cache. Args: host: Redis host address @@ -59,6 +63,9 @@ def __init__( where 1.0 requires exact matches and 0.0 accepts any match embedding_model: Model to use for generating embeddings index_name: Name for the Redis index + embedding_cache_enabled: Whether to enable RedisVL EmbeddingsCache + embedding_cache_ttl: Default TTL for embeddings cache entries in seconds + embedding_cache_name: Optional name prefix for embeddings cache keys ttl: Default time-to-live for cache entries in seconds **kwargs: Additional arguments passed to the Redis client @@ -66,6 +73,8 @@ def __init__( Exception: If similarity_threshold is not provided or required Redis connection information is missing """ + # Import RedisVL components lazily to avoid hard dependency when not used + from redisvl.extensions.cache.embeddings import EmbeddingsCache from redisvl.extensions.llmcache import SemanticCache from redisvl.utils.vectorize import CustomTextVectorizer @@ -87,6 +96,15 @@ def __init__( self.distance_threshold = 1 - similarity_threshold self.embedding_model = embedding_model + # Embeddings cache configuration + self.embedding_cache_enabled: bool = embedding_cache_enabled + self.embedding_cache_ttl: Optional[int] = embedding_cache_ttl + if embedding_cache_name is None: + self.embedding_cache_name = self.DEFAULT_REDIS_EMBEDDINGS_CACHE_NAME + else: + self.embedding_cache_name = embedding_cache_name + self._embeddings_cache: Optional[EmbeddingsCache] = None + # Set up Redis connection if redis_url is None: try: @@ -106,8 +124,26 @@ def __init__( print_verbose(f"Redis semantic-cache redis_url: {redis_url}") + # Initialize the embeddings cache if enabled + if self.embedding_cache_enabled: + try: + self._embeddings_cache = EmbeddingsCache( + name=self.embedding_cache_name, + redis_url=redis_url, + ttl=self.embedding_cache_ttl, + ) + except Exception as e: # pragma: no cover - defensive, treat as non-fatal + print_verbose( + f"Redis semantic-cache: failed to initialize EmbeddingsCache, " + f"disabling embedding cache. 
Error: {str(e)}" + ) + self.embedding_cache_enabled = False + self._embeddings_cache = None + # Initialize the Redis vectorizer and cache - cache_vectorizer = CustomTextVectorizer(self._get_embedding) + cache_vectorizer = CustomTextVectorizer( + self._get_embedding, cache=self._embeddings_cache + ) self.llmcache = SemanticCache( name=index_name, @@ -117,6 +153,14 @@ def __init__( overwrite=False, ) + @property + def embeddings_cache(self): + """Expose the underlying EmbeddingsCache instance (if any). + + This is primarily for tests and potential reuse by other components. + """ + return self._embeddings_cache + def _get_ttl(self, **kwargs) -> Optional[int]: """ Get the TTL (time-to-live) value for cache entries. @@ -133,16 +177,16 @@ def _get_ttl(self, **kwargs) -> Optional[int]: return ttl def _get_embedding(self, prompt: str) -> List[float]: - """ - Generate an embedding vector for the given prompt using the configured embedding model. + """Generate an embedding vector for the given prompt. - Args: - prompt: The text to generate an embedding for - - Returns: - List[float]: The embedding vector + This is the sync embedding function used by RedisVL's CustomTextVectorizer. + It deliberately bypasses LiteLLM's high-level cache (cache={"no-store": True}) + because EmbeddingsCache is responsible for caching at this layer. """ - # Create an embedding from prompt + + # NOTE: EmbeddingsCache is already wired into CustomTextVectorizer via + # the `cache` parameter in __init__, so this method only needs to + # compute the embedding when there is an EmbeddingsCache miss. embedding_response = cast( EmbeddingResponse, litellm.embedding( @@ -151,8 +195,7 @@ def _get_embedding(self, prompt: str) -> List[float]: cache={"no-store": True, "no-cache": True}, ), ) - embedding = embedding_response["data"][0]["embedding"] - return embedding + return embedding_response["data"][0]["embedding"] def _get_cache_logic(self, cached_response: Any) -> Any: """ @@ -269,16 +312,28 @@ def get_cache(self, key: str, **kwargs) -> Any: print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}") async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]: + """Asynchronously generate an embedding for the given prompt. + + This is used by the async semantic cache paths. It first checks the + RedisVL EmbeddingsCache (if enabled) before falling back to the + underlying embedding provider. """ - Asynchronously generate an embedding for the given prompt. - Args: - prompt: The text to generate an embedding for - **kwargs: Additional arguments that may contain metadata + # Fast path: check embeddings cache if available + if self.embedding_cache_enabled and self._embeddings_cache is not None: + try: + cached = await self._embeddings_cache.aget( + text=prompt, + model_name=self.embedding_model, + ) + if cached is not None: + return cached["embedding"] # type: ignore[index] + except Exception as e: # pragma: no cover - defensive + print_verbose( + f"Redis semantic-cache: EmbeddingsCache.aget failed, " + f"falling back to provider. 
Error: {str(e)}" + ) - Returns: - List[float]: The embedding vector - """ from litellm.proxy.proxy_server import llm_model_list, llm_router # Route the embedding request through the proxy if appropriate @@ -310,8 +365,24 @@ async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]: cache={"no-store": True, "no-cache": True}, ) - # Extract and return the embedding vector - return embedding_response["data"][0]["embedding"] + embedding_vec = embedding_response["data"][0]["embedding"] + + # Store in embeddings cache (best effort) + if self.embedding_cache_enabled and self._embeddings_cache is not None: + try: + await self._embeddings_cache.aset( + text=prompt, + model_name=self.embedding_model, + embedding=embedding_vec, + ttl=self.embedding_cache_ttl, + ) + except Exception as e: # pragma: no cover - defensive + print_verbose( + "Redis semantic-cache: EmbeddingsCache.aset failed; " + f"continuing without cache. Error: {str(e)}" + ) + + return embedding_vec except Exception as e: print_verbose(f"Error generating async embedding: {str(e)}") raise ValueError(f"Failed to generate embedding: {str(e)}") from e diff --git a/tests/test_litellm/caching/test_redis_semantic_cache.py b/tests/test_litellm/caching/test_redis_semantic_cache.py index f9946e266fed..c226bd9047fd 100644 --- a/tests/test_litellm/caching/test_redis_semantic_cache.py +++ b/tests/test_litellm/caching/test_redis_semantic_cache.py @@ -144,3 +144,90 @@ async def test_redis_semantic_cache_async_get_cache(monkeypatch): # Verify methods were called redis_semantic_cache._get_async_embedding.assert_called_once() redis_semantic_cache.llmcache.acheck.assert_called_once() + + + +def test_redis_semantic_cache_embeddings_cache_enabled(monkeypatch): + # Create proper mocks for redisvl modules + mock_embeddings_cache_instance = MagicMock() + mock_embeddings_cache_class = MagicMock(return_value=mock_embeddings_cache_instance) + semantic_cache_mock = MagicMock() + custom_vectorizer_mock = MagicMock() + + with patch.dict( + "sys.modules", + { + "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock), + "redisvl.utils.vectorize": MagicMock( + CustomTextVectorizer=custom_vectorizer_mock + ), + "redisvl.extensions.cache.embeddings": MagicMock( + EmbeddingsCache=mock_embeddings_cache_class + ), + }, + ): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + # Set environment variables + monkeypatch.setenv("REDIS_HOST", "localhost") + monkeypatch.setenv("REDIS_PORT", "6379") + monkeypatch.setenv("REDIS_PASSWORD", "test_password") + + cache = RedisSemanticCache( + similarity_threshold=0.8, + embedding_cache_enabled=True, + embedding_cache_name="test_embed_cache", + embedding_cache_ttl=123, + ) + + # Ensure embeddings cache is initialized when enabled + assert cache.embedding_cache_enabled is True + assert cache.embeddings_cache is not None + assert cache.embeddings_cache is mock_embeddings_cache_instance + + # Verify EmbeddingsCache was called with correct parameters + mock_embeddings_cache_class.assert_called_once_with( + name="test_embed_cache", + redis_url="redis://:test_password@localhost:6379", + ttl=123, + ) + + +@pytest.mark.asyncio +async def test_redis_semantic_cache_async_embedding_uses_cache(monkeypatch): + # Patch redisvl modules and EmbeddingsCache class specifically + embeddings_cache_mock_cls = MagicMock() + embeddings_cache_instance = AsyncMock() + embeddings_cache_mock_cls.return_value = embeddings_cache_instance + + with patch.dict( + "sys.modules", + { + 
"redisvl.extensions.llmcache": MagicMock(), + "redisvl.utils.vectorize": MagicMock(), + "redisvl.extensions.cache.embeddings": MagicMock( + EmbeddingsCache=embeddings_cache_mock_cls + ), + }, + ): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + # Set environment variables + monkeypatch.setenv("REDIS_HOST", "localhost") + monkeypatch.setenv("REDIS_PORT", "6379") + monkeypatch.setenv("REDIS_PASSWORD", "test_password") + + # Embeddings cache returns a cached vector + embeddings_cache_instance.aget.return_value = {"embedding": [0.1, 0.2, 0.3]} + + cache = RedisSemanticCache( + similarity_threshold=0.8, + embedding_cache_enabled=True, + ) + + # Call internal async embedding helper + result = await cache._get_async_embedding("hello world") + + # Should have used the embeddings cache and not fallen back to provider + embeddings_cache_instance.aget.assert_awaited_once() + assert result == [0.1, 0.2, 0.3]