Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1702,6 +1702,8 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
model.config.max_sequence_length = min(
model.config.max_position_embeddings, tokenizer.model_max_length
)
model.use_cls_pooling = True
model.eval()
return model, tokenizer

def get_default_conv_template(self, model_path: str) -> Conversation:
Expand Down
51 changes: 45 additions & 6 deletions fastchat/serve/model_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,13 @@ def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict):
data = model_output.hidden_states[-1].transpose(0, 1)
else:
data = model_output.hidden_states[-1]
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
masked_embeddings = data * mask
sum_embeddings = torch.sum(masked_embeddings, dim=1)

if hasattr(self.model, "use_cls_pooling") and self.model.use_cls_pooling:
sum_embeddings = data[:, 0]
else:
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
masked_embeddings = data * mask
sum_embeddings = torch.sum(masked_embeddings, dim=1)
token_num = torch.sum(attention_mask).item()

return sum_embeddings, token_num
Expand Down Expand Up @@ -211,10 +215,14 @@ def get_embeddings(self, params):
base64_encode = params.get("encoding_format", None)

if self.embed_in_truncate:
chunk_embeddings, token_num = self.__process_embed_chunk(
embedding, token_num = self.__process_embed_chunk(
input_ids, attention_mask, **model_type_dict
)
embedding = chunk_embeddings / token_num
if (
not hasattr(self.model, "use_cls_pooling")
or not self.model.use_cls_pooling
):
embedding = embedding / token_num
normalized_embeddings = F.normalize(embedding, p=2, dim=1)
ret["token_num"] = token_num
else:
Expand All @@ -224,10 +232,41 @@ def get_embeddings(self, params):
chunk_input_ids = input_ids[:, i : i + self.context_len]
chunk_attention_mask = attention_mask[:, i : i + self.context_len]

# add cls token and mask to get cls embedding
if (
hasattr(self.model, "use_cls_pooling")
and self.model.use_cls_pooling
):
cls_tokens = (
torch.zeros(
(chunk_input_ids.size(0), 1),
dtype=chunk_input_ids.dtype,
device=chunk_input_ids.device,
)
+ tokenizer.cls_token_id
)
chunk_input_ids = torch.cat(
[cls_tokens, chunk_input_ids], dim=-1
)
mask = torch.ones(
(chunk_attention_mask.size(0), 1),
dtype=chunk_attention_mask.dtype,
device=chunk_attention_mask.device,
)
chunk_attention_mask = torch.cat(
[mask, chunk_attention_mask], dim=-1
)

chunk_embeddings, token_num = self.__process_embed_chunk(
chunk_input_ids, chunk_attention_mask, **model_type_dict
)
all_embeddings.append(chunk_embeddings)
if (
hasattr(self.model, "use_cls_pooling")
and self.model.use_cls_pooling
):
all_embeddings.append(chunk_embeddings * token_num)
else:
all_embeddings.append(chunk_embeddings)
all_token_num += token_num

all_embeddings_tensor = torch.stack(all_embeddings)
Expand Down