Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
MetadataModel moved to AbstractEmbeddingModel.
Embedding models, that didn't have support for a MetadataModel got an additional constructor with MetadataModel as the second argument.
I did not include it in all the other constructors since this would break code.

AzureOpenAiEmbeddingModel had special handling for empty results, that got lost in the change.
Not sure if that is ok.

TransformersEmbeddingModel has  as default, which differs from the others. Should that be changed to , which seems the most reasonable default to me and is used in the other models?
  • Loading branch information
schauder committed Apr 10, 2025
commit f46708676116fb513f821a792486228619b8c7a9
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel {

private final AzureOpenAiEmbeddingOptions defaultOptions;

private final MetadataMode metadataMode;

/**
* Observation registry used for instrumentation.
*/
Expand Down Expand Up @@ -92,30 +90,16 @@ public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode me
public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode,
AzureOpenAiEmbeddingOptions options, ObservationRegistry observationRegistry) {

super(metadataMode);

Assert.notNull(azureOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
Assert.notNull(metadataMode, "Metadata mode must not be null");
Assert.notNull(options, "Options must not be null");
Assert.notNull(observationRegistry, "Observation registry must not be null");
this.azureOpenAiClient = azureOpenAiClient;
this.metadataMode = metadataMode;
this.defaultOptions = options;
this.observationRegistry = observationRegistry;
}

@Override
public float[] embed(Document document) {
logger.debug("Retrieving embeddings");

EmbeddingResponse response = this
.call(new EmbeddingRequest(List.of(document.getFormattedContent(this.metadataMode)), null));
logger.debug("Embeddings retrieved");

if (CollectionUtils.isEmpty(response.getResults())) {
return new float[0];
}
return response.getResults().get(0).getOutput();
}

@Override
public EmbeddingResponse call(EmbeddingRequest embeddingRequest) {
logger.debug("Retrieving embeddings");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingRequest;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
Expand Down Expand Up @@ -65,17 +66,21 @@ public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedr

public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi,
BedrockCohereEmbeddingOptions options) {

this(cohereEmbeddingBedrockApi, MetadataMode.EMBED, options);
}

public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi, MetadataMode metadataMode,
BedrockCohereEmbeddingOptions options) {

super(metadataMode);

Assert.notNull(cohereEmbeddingBedrockApi, "CohereEmbeddingBedrockApi must not be null");
Assert.notNull(options, "BedrockCohereEmbeddingOptions must not be null");
this.embeddingApi = cohereEmbeddingBedrockApi;
this.defaultOptions = options;
}

@Override
public float[] embed(Document document) {
return embed(document.getText());
}

@Override
public EmbeddingResponse call(EmbeddingRequest request) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
Expand Down Expand Up @@ -57,6 +58,13 @@ public class BedrockTitanEmbeddingModel extends AbstractEmbeddingModel {
private InputType inputType = InputType.TEXT;

public BedrockTitanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi) {
this(titanEmbeddingBedrockApi, MetadataMode.EMBED);
}

public BedrockTitanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi, MetadataMode metadataMode) {

super(metadataMode);

this.embeddingApi = titanEmbeddingBedrockApi;
}

Expand All @@ -69,11 +77,6 @@ public BedrockTitanEmbeddingModel withInputType(InputType inputType) {
return this;
}

@Override
public float[] embed(Document document) {
return embed(document.getText());
}

@Override
public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ public class MiniMaxEmbeddingModel extends AbstractEmbeddingModel {

private final MiniMaxApi miniMaxApi;

private final MetadataMode metadataMode;

/**
* Observation registry used for instrumentation.
*/
Expand Down Expand Up @@ -128,25 +126,20 @@ public MiniMaxEmbeddingModel(MiniMaxApi miniMaxApi, MetadataMode metadataMode,
*/
public MiniMaxEmbeddingModel(MiniMaxApi miniMaxApi, MetadataMode metadataMode, MiniMaxEmbeddingOptions options,
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {

super(metadataMode);

Assert.notNull(miniMaxApi, "MiniMaxApi must not be null");
Assert.notNull(metadataMode, "metadataMode must not be null");
Assert.notNull(options, "options must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");

this.miniMaxApi = miniMaxApi;
this.metadataMode = metadataMode;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
this.observationRegistry = observationRegistry;
}

@Override
public float[] embed(Document document) {
Assert.notNull(document, "Document must not be null");
return this.embed(document.getFormattedContent(this.metadataMode));
}

@Override
public EmbeddingResponse call(EmbeddingRequest request) {
MiniMaxEmbeddingOptions requestOptions = mergeOptions(request.getOptions(), this.defaultOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel {

private final MistralAiEmbeddingOptions defaultOptions;

private final MetadataMode metadataMode;

private final MistralAiApi mistralAiApi;

private final RetryTemplate retryTemplate;
Expand Down Expand Up @@ -94,14 +92,16 @@ public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataM

public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataMode,
MistralAiEmbeddingOptions options, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {

super(metadataMode);

Assert.notNull(mistralAiApi, "mistralAiApi must not be null");
Assert.notNull(metadataMode, "metadataMode must not be null");
Assert.notNull(options, "options must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");

this.mistralAiApi = mistralAiApi;
this.metadataMode = metadataMode;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
this.observationRegistry = observationRegistry;
Expand Down Expand Up @@ -174,12 +174,6 @@ private MistralAiApi.EmbeddingRequest<List<String>> createRequest(EmbeddingReque
requestOptions.getEncodingFormat());
}

@Override
public float[] embed(Document document) {
Assert.notNull(document, "Document must not be null");
return this.embed(document.getFormattedContent(this.metadataMode));
}

/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
Expand Down Expand Up @@ -72,6 +73,14 @@ public OCIEmbeddingModel(GenerativeAiInference genAi, OCIEmbeddingOptions option

public OCIEmbeddingModel(GenerativeAiInference genAi, OCIEmbeddingOptions options,
ObservationRegistry observationRegistry) {
this(genAi, MetadataMode.EMBED, options, observationRegistry);
}

public OCIEmbeddingModel(GenerativeAiInference genAi, MetadataMode metadataMode, OCIEmbeddingOptions options,
ObservationRegistry observationRegistry) {

super(metadataMode);

Assert.notNull(genAi, "com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient must not be null");
Assert.notNull(options, "options must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");
Expand All @@ -98,11 +107,6 @@ public EmbeddingResponse call(EmbeddingRequest request) {
.observe(() -> embedAllWithContext(embedTextRequests, context));
}

@Override
public float[] embed(Document document) {
return embed(document.getText());
}

private EmbeddingResponse embedAllWithContext(List<EmbedTextRequest> embedTextRequests,
EmbeddingModelObservationContext context) {
String modelId = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingModel;
Expand Down Expand Up @@ -78,6 +79,14 @@ public class OllamaEmbeddingModel extends AbstractEmbeddingModel {

public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
this(ollamaApi, MetadataMode.EMBED, defaultOptions, observationRegistry, modelManagementOptions);
}

public OllamaEmbeddingModel(OllamaApi ollamaApi, MetadataMode metadataMode, OllamaOptions defaultOptions,
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {

super(metadataMode);

Assert.notNull(ollamaApi, "ollamaApi must not be null");
Assert.notNull(defaultOptions, "options must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");
Expand All @@ -95,11 +104,6 @@ public static Builder builder() {
return new Builder();
}

@Override
public float[] embed(Document document) {
return embed(document.getText());
}

@Override
public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ public class OpenAiEmbeddingModel extends AbstractEmbeddingModel {

private final OpenAiApi openAiApi;

private final MetadataMode metadataMode;

/**
* Observation registry used for instrumentation.
*/
Expand Down Expand Up @@ -128,25 +126,20 @@ public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode, Open
*/
public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode, OpenAiEmbeddingOptions options,
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {

super(metadataMode);

Assert.notNull(openAiApi, "openAiApi must not be null");
Assert.notNull(metadataMode, "metadataMode must not be null");
Assert.notNull(options, "options must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");

this.openAiApi = openAiApi;
this.metadataMode = metadataMode;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
this.observationRegistry = observationRegistry;
}

@Override
public float[] embed(Document document) {
Assert.notNull(document, "Document must not be null");
return this.embed(document.getFormattedContent(this.metadataMode));
}

@Override
public EmbeddingResponse call(EmbeddingRequest request) {
// Before moving any further, build the final request EmbeddingRequest,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
Expand Down Expand Up @@ -75,6 +76,21 @@ public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingOp
*/
public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingOptions options,
boolean createExtension) {
this(jdbcTemplate, MetadataMode.EMBED, options, createExtension);
}

/**
* a PostgresMlEmbeddingModel constructor
* @param jdbcTemplate JdbcTemplate to use to interact with the database.
* @param metadataMode MetadataMode describing what metadata values are included in
* the embedding.
* @param options PostgresMlEmbeddingOptions to configure the client.
*/
public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate, MetadataMode metadataMode,
PostgresMlEmbeddingOptions options, boolean createExtension) {

super(metadataMode);

Assert.notNull(jdbcTemplate, "jdbc template must not be null.");
Assert.notNull(options, "options must not be null.");
Assert.notNull(options.getTransformer(), "transformer must not be null.");
Expand All @@ -96,11 +112,6 @@ public float[] embed(String text) {
ModelOptionsUtils.toJsonString(this.defaultOptions.getKwargs()));
}

@Override
public float[] embed(Document document) {
return this.embed(document.getFormattedContent(this.defaultOptions.getMetadataMode()));
}

@SuppressWarnings("null")
@Override
public EmbeddingResponse call(EmbeddingRequest request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ public class QianFanEmbeddingModel extends AbstractEmbeddingModel {

private final QianFanApi qianFanApi;

private final MetadataMode metadataMode;

/**
* Observation registry used for instrumentation.
*/
Expand Down Expand Up @@ -126,25 +124,20 @@ public QianFanEmbeddingModel(QianFanApi qianFanApi, MetadataMode metadataMode,
*/
public QianFanEmbeddingModel(QianFanApi qianFanApi, MetadataMode metadataMode, QianFanEmbeddingOptions options,
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {

super(metadataMode);

Assert.notNull(qianFanApi, "QianFanApi must not be null");
Assert.notNull(metadataMode, "metadataMode must not be null");
Assert.notNull(options, "options must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");

this.qianFanApi = qianFanApi;
this.metadataMode = metadataMode;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
this.observationRegistry = observationRegistry;
}

@Override
public float[] embed(Document document) {
Assert.notNull(document, "Document must not be null");
return this.embed(document.getFormattedContent(this.metadataMode));
}

@Override
public EmbeddingResponse call(EmbeddingRequest request) {
QianFanEmbeddingOptions requestOptions = mergeOptions(request.getOptions(), this.defaultOptions);
Expand Down
Loading