Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
ZhiPu AI embedding model: support the dimensions parameter
spring.ai.zhipuai.embedding.options.dimensions=1024
  • Loading branch information
yp committed Dec 6, 2024
commit 010c5f9276ed5654d909e151b46668f86512901b
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,13 @@ private ZhiPuAiEmbeddingOptions mergeOptions(@Nullable EmbeddingOptions runtimeO

return ZhiPuAiEmbeddingOptions.builder()
.withModel(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
.withDimensions(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getDimensions(),
defaultOptions.getDimensions()))
.build();
}

/**
 * Builds the ZhiPuAI embedding request for a single piece of text.
 * <p>
 * The model id and the vector dimensions are both read from the supplied options,
 * so a configured dimension value is forwarded to the API instead of being dropped.
 * @param text the text to embed.
 * @param requestOptions the options providing the model and the dimensions.
 * @return the request carrying the text, the model and the dimensions.
 */
private ZhiPuAiApi.EmbeddingRequest<String> createEmbeddingRequest(String text, EmbeddingOptions requestOptions) {
    // Exactly one return: the earlier two-argument form ignored the dimensions and
    // left this statement unreachable.
    return new ZhiPuAiApi.EmbeddingRequest<>(text, requestOptions.getModel(), requestOptions.getDimensions());
}

public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package org.springframework.ai.zhipuai;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
Expand All @@ -33,6 +32,12 @@
@JsonInclude(Include.NON_NULL)
public class ZhiPuAiEmbeddingOptions implements EmbeddingOptions {

/**
* The default vector dimension is 2048. The model supports custom vector dimensions,
* and it is recommended to choose dimensions of 256, 512, 1024, or 2048.
*/
private @JsonProperty("dimensions") Integer dimensions;

// @formatter:off
/**
* ID of the model to use.
Expand All @@ -54,9 +59,12 @@ public void setModel(String model) {
}

/**
 * Returns the configured embedding vector dimensions.
 * <p>
 * The getter is deliberately not marked with {@code @JsonIgnore}: the backing field
 * is mapped with {@code @JsonProperty("dimensions")}, and ignoring the accessor
 * would contradict that mapping.
 * @return the configured dimensions, or {@code null} when none has been set.
 */
@Override
public Integer getDimensions() {
    // Return the stored value; the former unconditional "return null" hid it.
    return this.dimensions;
}

/**
 * Stores the embedding vector dimensions on this options object.
 * @param dimensions the value to store; it is assigned as given, without validation.
 */
public void setDimensions(Integer dimensions) {
this.dimensions = dimensions;
}

public static class Builder {
Expand All @@ -72,6 +80,11 @@ public Builder withModel(String model) {
return this;
}

/**
 * Sets the embedding vector dimensions on the options being built.
 * @param dimensions the value passed on to the options' {@code setDimensions}.
 * @return this builder, so calls can be chained.
 */
public Builder withDimensions(Integer dimensions) {
this.options.setDimensions(dimensions);
return this;
}

/**
 * Returns the options instance this builder has been configuring.
 * @return the builder's {@code options} object itself (no copy is made here).
 */
public ZhiPuAiEmbeddingOptions build() {
return this.options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,10 @@

package org.springframework.ai.zhipuai.api;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Predicate;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.ai.model.ChatModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
Expand All @@ -43,6 +32,16 @@
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Predicate;

// @formatter:off
/**
Expand Down Expand Up @@ -956,15 +955,18 @@ public String toString() {
@JsonInclude(Include.NON_NULL)
public record EmbeddingRequest<T>(
@JsonProperty("input") T input,
@JsonProperty("model") String model) {

@JsonProperty("model") String model,

@JsonProperty("dimensions") Integer dimensions) {


/**
* Create an embedding request with the given input. Encoding model is set to 'embedding-2'.
* @param input Input text to embed.
*/
public EmbeddingRequest(T input) {
this(input, DEFAULT_EMBEDDING_MODEL);
public EmbeddingRequest(T input, Integer dimensions) {
this(input,DEFAULT_EMBEDDING_MODEL,dimensions);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void chatCompletionStream() {
@Test
void embeddings() {
ResponseEntity<EmbeddingList<Embedding>> response = this.zhiPuAiApi
.embeddings(new ZhiPuAiApi.EmbeddingRequest<>("Hello world"));
.embeddings(new ZhiPuAiApi.EmbeddingRequest<>("Hello world", ZhiPuAiApi.DEFAULT_EMBEDDING_MODEL, 1024));

assertThat(response).isNotNull();
assertThat(Objects.requireNonNull(response.getBody()).data()).hasSize(1);
Expand Down