Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
from_metadata()
  • Loading branch information
jingyizhu99 committed Jan 30, 2024
commit e7b2c170b0b0ceefd92cf8c308d40dfa3a70b7d5
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections import OrderedDict
from typing import Callable, List, Optional, Union

import cloudpickle
from azure.core.credentials import TokenCredential
from azure.ai.resources._index._embeddings.openai import OpenAIEmbedder
from azure.ai.resources._index._langchain.vendor.embeddings.base import Embeddings as Embedder
Expand Down Expand Up @@ -223,4 +224,24 @@ def from_uri(uri: str, credential: Optional[TokenCredential] = None, **kwargs) -
"""Create an embeddings object from a URI."""
config = parse_model_uri(uri, **kwargs)
kwargs["credential"] = credential
return EmbeddingsContainer(**{**config, **kwargs})
return EmbeddingsContainer(**{**config, **kwargs})

@staticmethod
def from_metadata(metadata: dict) -> "EmbeddingsContainer":
"""Create an embeddings object from metadata."""
schema_version = metadata.get("schema_version", "1")
if schema_version == "1":
embeddings = EmbeddingsContainer(metadata["kind"], **metadata["arguments"])
return embeddings
elif schema_version == "2":
kind = metadata["kind"]
del metadata["kind"]
if kind == "custom":
metadata["embedding_fn"] = cloudpickle.loads(
gzip.decompress(metadata["pickled_embedding_fn"]))
del metadata["pickled_embedding_fn"]

embeddings = EmbeddingsContainer(kind, **metadata)
return embeddings
else:
raise ValueError(f"Schema version {schema_version} is not supported")