 from sklearn.cluster import KMeans, AgglomerativeClustering
 import torch
 from tqdm import tqdm
+from openai import OpenAI

 from fastchat.utils import detect_language
@@ -46,6 +47,8 @@ def read_texts(input_file, min_length, max_length, english_only):
             line_texts = [
                 x["content"] for x in l["conversation"] if x["role"] == "user"
             ]
+        elif "turns" in l:
+            line_texts = l["turns"]

         for text in line_texts:
             text = text.strip()
@@ -77,14 +80,26 @@ def read_texts(input_file, min_length, max_length, english_only):
 def get_embeddings(texts, model_name, batch_size):
-    model = SentenceTransformer(model_name)
-    embeddings = model.encode(
-        texts,
-        batch_size=batch_size,
-        show_progress_bar=True,
-        device="cuda",
-        convert_to_tensor=True,
-    )
+    if model_name == "text-embedding-ada-002":
+        client = OpenAI()
+        texts = texts.tolist()
+
+        embeddings = []
+        for i in tqdm(range(0, len(texts), batch_size)):
+            text = texts[i : i + batch_size]
+            responses = client.embeddings.create(input=text, model=model_name).data
+            embeddings.extend([data.embedding for data in responses])
+        embeddings = torch.tensor(embeddings)
+    else:
+        model = SentenceTransformer(model_name)
+        embeddings = model.encode(
+            texts,
+            batch_size=batch_size,
+            show_progress_bar=True,
+            device="cuda",
+            convert_to_tensor=True,
+        )
+
     embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
     return embeddings.cpu()
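Note: the new branch above calls the OpenAI embeddings endpoint in batches and stacks the results into a tensor. A minimal standalone sketch of that flow (assumes OPENAI_API_KEY is set and the openai>=1.0 client; the sample texts and batch size are illustrative, and text-embedding-ada-002 returns 1536-dimensional vectors):

import torch
from openai import OpenAI

client = OpenAI()
texts = ["How do I sort a list in Python?", "Summarize attention in one sentence."]
batch_size = 2

embeddings = []
for i in range(0, len(texts), batch_size):
    batch = texts[i : i + batch_size]
    responses = client.embeddings.create(input=batch, model="text-embedding-ada-002").data
    embeddings.extend([d.embedding for d in responses])

# Same post-processing as the script: L2-normalize so cosine similarity reduces to a dot product.
embeddings = torch.nn.functional.normalize(torch.tensor(embeddings), p=2, dim=1)
print(embeddings.shape)  # expected: torch.Size([2, 1536])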
@@ -218,6 +233,8 @@ def get_cluster_info(texts, labels, topk_indices):
     )
     parser.add_argument("--show-top-k", type=int, default=200)
     parser.add_argument("--show-cut-off", type=int, default=512)
+    parser.add_argument("--save-embeddings", action="store_true")
+    parser.add_argument("--embeddings-file", type=str, default=None)
     args = parser.parse_args()

     num_clusters = args.num_clusters
@@ -229,7 +246,15 @@ def get_cluster_info(texts, labels, topk_indices):
     )
     print(f"#text: {len(texts)}")

-    embeddings = get_embeddings(texts, args.model, args.batch_size)
+    if args.embeddings_file is None:
+        embeddings = get_embeddings(texts, args.model, args.batch_size)
+        if args.save_embeddings:
+            # allow saving embedding to save time and money
+            torch.save(embeddings, "embeddings.pt")
+    else:
+        embeddings = torch.load(args.embeddings_file)
+    print(f"embeddings shape: {embeddings.shape}")
+
     if args.cluster_alg == "kmeans":
         centers, labels = run_k_means(embeddings, num_clusters)
     elif args.cluster_alg == "aggcls":
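Note: with the two new flags, a first run can cache the computed embeddings (the script hard-codes the output name embeddings.pt) and later runs can reload them instead of re-querying the embedding model. A small sketch of that round trip (the random tensor stands in for get_embeddings output; the shape is illustrative):

import torch

embeddings = torch.randn(4, 1536)        # stand-in for get_embeddings(...)
torch.save(embeddings, "embeddings.pt")  # what --save-embeddings does
reloaded = torch.load("embeddings.pt")   # what passing --embeddings-file embeddings.pt triggers
assert torch.equal(embeddings, reloaded)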
@@ -249,7 +274,7 @@ def get_cluster_info(texts, labels, topk_indices):
     with open(filename_prefix + "_topk.txt", "w") as fout:
         fout.write(topk_str)

-    with open(filename_prefix + "_all.txt", "w") as fout:
+    with open(filename_prefix + "_all.jsonl", "w") as fout:
         for i in range(len(centers)):
             tmp_indices = labels == i
             tmp_embeddings = embeddings[tmp_indices]