Merged
support openai embedding for topic clustering
CodingWithTim committed Nov 24, 2023
commit 12d32a291ce9fe2827ae09fd0e47b7eba26a85dd
9 changes: 9 additions & 0 deletions fastchat/serve/monitor/summarize_cluster.py
@@ -6,6 +6,8 @@
 import argparse
 import pickle
 
+import pandas as pd
+
 from fastchat.llm_judge.common import (
     chat_compeletion_openai,
     chat_compeletion_openai_azure,
@@ -74,3 +76,10 @@ def truncate_string(s, l):
     print()
     print(f"topics: {topics}")
     print(f"percentages: {percentages}")
+
+    # save the information
+    df = pd.DataFrame()
+    df["topic"] = topics
+    df["percentage"] = percentages
+
+    df.to_json(f"cluster_summary_{len(df)}.jsonl", lines=True, orient="records")
48 changes: 42 additions & 6 deletions fastchat/serve/monitor/topic_clustering.py
@@ -16,6 +16,7 @@
 from sklearn.cluster import KMeans, AgglomerativeClustering
 import torch
 from tqdm import tqdm
+from openai import OpenAI
 
 from fastchat.utils import detect_language

@@ -46,6 +47,8 @@ def read_texts(input_file, min_length, max_length, english_only):
             line_texts = [
                 x["content"] for x in l["conversation"] if x["role"] == "user"
             ]
+        elif "turns" in l:
+            line_texts = l["turns"]
Comment on lines +50 to +51

Member:
sorry could you explain this a bit?

Collaborator Author:
Yes, the input JSON files have different formats. For example, the 54K-prompt JSON file you sent me uses "turns" to store the conversation prompts. I added support for JSON files that use "turns".
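To illustrate the two shapes read_texts now accepts (the record contents are hypothetical; only the field names come from the code):

# Shape already supported: a full conversation, from which user turns are extracted
line_a = {"conversation": [{"role": "user", "content": "What is RLHF?"},
                           {"role": "assistant", "content": "..."}]}
# New shape handled by the elif branch: prompts stored directly under "turns"
line_b = {"turns": ["What is RLHF?", "Explain it like I'm five."]}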


         for text in line_texts:
             text = text.strip()
@@ -89,6 +92,21 @@ def get_embeddings(texts, model_name, batch_size):
     return embeddings.cpu()


+# Support OpenAI Embedding
+def get_openai_embeddings(texts, model_name, batch_size):
+    client = OpenAI()
+    texts = texts.tolist()
+
+    embeddings = []
+    for i in tqdm(range(0, len(texts), batch_size)):
+        text = texts[i : i + batch_size]
Member:
how does batch size affect total runtime?

Collaborator Author:
If you don't batch the requests and just loop over prompts one at a time, it takes several hours. With a batch size of 256, 54K prompts take only around 3-4 minutes.

+        responses = client.embeddings.create(input=text, model=model_name).data
+        embeddings.extend([data.embedding for data in responses])
+
+    embeddings = torch.nn.functional.normalize(torch.tensor(embeddings), p=2, dim=1)
+    return embeddings.cpu()


 def run_k_means(embeddings, num_clusters):
     np.random.seed(42)
     clustering_model = KMeans(n_clusters=num_clusters, n_init="auto")
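To give a rough sense of how the new helper is called, and of the batching point discussed above, here is a minimal sketch (the example texts and the assumption that OPENAI_API_KEY is set in the environment are mine, not part of the diff):

import numpy as np

texts = np.array(["What is the capital of France?", "Write a haiku about autumn."])
# One API request per batch: with batch_size=256, 54K prompts need ~211 requests
# instead of 54K single-prompt requests, which is where the speedup comes from.
emb = get_openai_embeddings(texts, "text-embedding-ada-002", batch_size=256)
print(emb.shape)  # torch.Size([2, 1536]); text-embedding-ada-002 returns 1536-dim vectors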
@@ -218,18 +236,36 @@ def get_cluster_info(texts, labels, topk_indices):
     )
     parser.add_argument("--show-top-k", type=int, default=200)
     parser.add_argument("--show-cut-off", type=int, default=512)
+    parser.add_argument("--save-embeddings", action="store_true")
+    parser.add_argument("--embeddings-file", type=str, default=None)
Member:
this embedding cache is awesome!
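The intended round trip, sketched in terms of the new flags (it mirrors the save/load logic added below; the model name and file path are placeholders):

# First run: compute embeddings once and persist them (--save-embeddings)
embeddings = get_embeddings(texts, "all-mpnet-base-v2", batch_size=256)
torch.save(embeddings, "embeddings.pt")

# Later runs: pass --embeddings-file embeddings.pt and skip re-embedding entirely
embeddings = torch.load("embeddings.pt")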

     args = parser.parse_args()
 
     num_clusters = args.num_clusters
     show_top_k = args.show_top_k
     show_cut_off = args.show_cut_off
 
-    texts = read_texts(
-        args.input_file, args.min_length, args.max_length, args.english_only
-    )
-    print(f"#text: {len(texts)}")
+    if args.embeddings_file is None:
+        if args.model == "text-embedding-ada-002":
+            texts = read_texts(
+                args.input_file, args.min_length, args.max_length, args.english_only
+            )
+            print(f"#text: {len(texts)}")
+            embeddings = get_openai_embeddings(texts, args.model, args.batch_size)
+            print(f"embeddings shape: {embeddings.shape}")
+        else:
+            texts = read_texts(
+                args.input_file, args.min_length, args.max_length, args.english_only
+            )
+            print(f"#text: {len(texts)}")
+            embeddings = get_embeddings(texts, args.model, args.batch_size)
+            print(f"embeddings shape: {embeddings.shape}")
Member:
Suggested change:
-        if args.model == "text-embedding-ada-002":
-            texts = read_texts(
-                args.input_file, args.min_length, args.max_length, args.english_only
-            )
-            print(f"#text: {len(texts)}")
-            embeddings = get_openai_embeddings(texts, args.model, args.batch_size)
-            print(f"embeddings shape: {embeddings.shape}")
-        else:
-            texts = read_texts(
-                args.input_file, args.min_length, args.max_length, args.english_only
-            )
-            print(f"#text: {len(texts)}")
-            embeddings = get_embeddings(texts, args.model, args.batch_size)
-            print(f"embeddings shape: {embeddings.shape}")
+        texts = read_texts(
+            args.input_file, args.min_length, args.max_length, args.english_only
+        )
+        print(f"#text: {len(texts)}")
+        if args.model == "text-embedding-ada-002":
+            embeddings = get_openai_embeddings(texts, args.model, args.batch_size)
+        else:
+            embeddings = get_embeddings(texts, args.model, args.batch_size)
+        print(f"embeddings shape: {embeddings.shape}")

Member:
it'd be even cleaner if we can merge get_embeddings and get_openai_embeddings

Collaborator Author:
OK, I just committed the new changes to merge the two embedding functions.
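For context, one way such a merge could look, purely as an illustrative sketch (this is not the actual follow-up commit; the dispatch-on-model-name approach and the get_local_embeddings name are assumptions):

def get_embeddings(texts, model_name, batch_size):
    # Route OpenAI embedding models to the API path, everything else to the
    # local sentence-transformers path.
    if model_name == "text-embedding-ada-002":
        return get_openai_embeddings(texts, model_name, batch_size)
    return get_local_embeddings(texts, model_name, batch_size)  # hypothetical rename of the old helper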


+        if args.save_embeddings:
+            # allow saving embedding to save time
+            torch.save(embeddings, "embeddings.pt")
+    else:
+        embeddings = torch.load(args.embeddings_file)
 
-    embeddings = get_embeddings(texts, args.model, args.batch_size)
     if args.cluster_alg == "kmeans":
         centers, labels = run_k_means(embeddings, num_clusters)
     elif args.cluster_alg == "aggcls":
@@ -249,7 +285,7 @@ def get_cluster_info(texts, labels, topk_indices):
with open(filename_prefix + "_topk.txt", "w") as fout:
fout.write(topk_str)

with open(filename_prefix + "_all.txt", "w") as fout:
with open(filename_prefix + "_all.jsonl", "w") as fout:
for i in range(len(centers)):
tmp_indices = labels == i
tmp_embeddings = embeddings[tmp_indices]