-
Notifications
You must be signed in to change notification settings - Fork 4.8k
support openai embedding for topic clustering #2729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||
| from sklearn.cluster import KMeans, AgglomerativeClustering | ||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||
| from tqdm import tqdm | ||||||||||||||||||||||||||||||||||||||||||||||||
| from openai import OpenAI | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| from fastchat.utils import detect_language | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -46,6 +47,8 @@ def read_texts(input_file, min_length, max_length, english_only): | |||||||||||||||||||||||||||||||||||||||||||||||
| line_texts = [ | ||||||||||||||||||||||||||||||||||||||||||||||||
| x["content"] for x in l["conversation"] if x["role"] == "user" | ||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||
| elif "turns" in l: | ||||||||||||||||||||||||||||||||||||||||||||||||
| line_texts = l["turns"] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| for text in line_texts: | ||||||||||||||||||||||||||||||||||||||||||||||||
| text = text.strip() | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -89,6 +92,21 @@ def get_embeddings(texts, model_name, batch_size): | |||||||||||||||||||||||||||||||||||||||||||||||
| return embeddings.cpu() | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Support OpenAI Embedding | ||||||||||||||||||||||||||||||||||||||||||||||||
| def get_openai_embeddings(texts, model_name, batch_size): | ||||||||||||||||||||||||||||||||||||||||||||||||
| client = OpenAI() | ||||||||||||||||||||||||||||||||||||||||||||||||
| texts = texts.tolist() | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| embeddings = [] | ||||||||||||||||||||||||||||||||||||||||||||||||
| for i in tqdm(range(0, len(texts), batch_size)): | ||||||||||||||||||||||||||||||||||||||||||||||||
| text = texts[i : i + batch_size] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||
| responses = client.embeddings.create(input=text, model=model_name).data | ||||||||||||||||||||||||||||||||||||||||||||||||
| embeddings.extend([data.embedding for data in responses]) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| embeddings = torch.nn.functional.normalize(torch.tensor(embeddings), p=2, dim=1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return embeddings.cpu() | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def run_k_means(embeddings, num_clusters): | ||||||||||||||||||||||||||||||||||||||||||||||||
| np.random.seed(42) | ||||||||||||||||||||||||||||||||||||||||||||||||
| clustering_model = KMeans(n_clusters=num_clusters, n_init="auto") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -218,18 +236,36 @@ def get_cluster_info(texts, labels, topk_indices): | |||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--show-top-k", type=int, default=200) | ||||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--show-cut-off", type=int, default=512) | ||||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--save-embeddings", action="store_true") | ||||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--embeddings-file", type=str, default=None) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this embedding cache is awesome! |
||||||||||||||||||||||||||||||||||||||||||||||||
| args = parser.parse_args() | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| num_clusters = args.num_clusters | ||||||||||||||||||||||||||||||||||||||||||||||||
| show_top_k = args.show_top_k | ||||||||||||||||||||||||||||||||||||||||||||||||
| show_cut_off = args.show_cut_off | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| texts = read_texts( | ||||||||||||||||||||||||||||||||||||||||||||||||
| args.input_file, args.min_length, args.max_length, args.english_only | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"#text: {len(texts)}") | ||||||||||||||||||||||||||||||||||||||||||||||||
| if args.embeddings_file is None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| if args.model == "text-embedding-ada-002": | ||||||||||||||||||||||||||||||||||||||||||||||||
| texts = read_texts( | ||||||||||||||||||||||||||||||||||||||||||||||||
| args.input_file, args.min_length, args.max_length, args.english_only | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"#text: {len(texts)}") | ||||||||||||||||||||||||||||||||||||||||||||||||
| embeddings = get_openai_embeddings(texts, args.model, args.batch_size) | ||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"embeddings shape: {embeddings.shape}") | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| texts = read_texts( | ||||||||||||||||||||||||||||||||||||||||||||||||
| args.input_file, args.min_length, args.max_length, args.english_only | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"#text: {len(texts)}") | ||||||||||||||||||||||||||||||||||||||||||||||||
| embeddings = get_embeddings(texts, args.model, args.batch_size) | ||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"embeddings shape: {embeddings.shape}") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if args.model == "text-embedding-ada-002": | |
| texts = read_texts( | |
| args.input_file, args.min_length, args.max_length, args.english_only | |
| ) | |
| print(f"#text: {len(texts)}") | |
| embeddings = get_openai_embeddings(texts, args.model, args.batch_size) | |
| print(f"embeddings shape: {embeddings.shape}") | |
| else: | |
| texts = read_texts( | |
| args.input_file, args.min_length, args.max_length, args.english_only | |
| ) | |
| print(f"#text: {len(texts)}") | |
| embeddings = get_embeddings(texts, args.model, args.batch_size) | |
| print(f"embeddings shape: {embeddings.shape}") | |
| texts = read_texts( | |
| args.input_file, args.min_length, args.max_length, args.english_only | |
| ) | |
| print(f"#text: {len(texts)}") | |
| if args.model == "text-embedding-ada-002": | |
| embeddings = get_openai_embeddings(texts, args.model, args.batch_size) | |
| else: | |
| embeddings = get_embeddings(texts, args.model, args.batch_size) | |
| print(f"embeddings shape: {embeddings.shape}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it'd be even cleaner if we can merge get_embeddings and get_openai_embeddings
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok I just commited the new changes to merge the 2 embedding functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry could you explain this a bit?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, input json file has different format. For examples, the 54K prompt json file you sent me uses "turns" to store the conversation prompts. I added support for json files that uses "turns".