Commit 4a12a8e

Category Hard Data Labeling Script (#3334)
1 parent 34eca62 commit 4a12a8e

1 file changed: +214 −0 lines changed

@@ -0,0 +1,214 @@
import argparse
import json
import pandas as pd
import os
import re
import ast
import time
import concurrent.futures
import tqdm
import random
import threading

LOCK = threading.RLock()

## Configs
SYSTEM_PROMPT = "Your task is to evaluate how well the following input prompts can assess the capabilities of advanced AI assistants.\n\nFor the input prompt, please analyze it based on the following 7 criteria.\n1. Specificity: Does the prompt ask for a specific output, such as code, a mathematical solution, a logical simplification, a problem-solving strategy, or a hardware setup recommendation? This specificity allows the AI to demonstrate its ability to understand and generate precise responses.\n2. Domain Knowledge: Does the prompt cover a specific domain, such as programming, mathematics, logic, problem-solving, or hardware setup? Prompts spanning a range of topics test the AI's breadth of knowledge and its ability to apply that knowledge to different domains.\n3. Complexity: Does the prompt vary in complexity, from straightforward tasks to more complex, multi-step problems? This allows evaluators to assess the AI's capability to handle problems of varying difficulty.\n4. Problem-Solving Skills: Does the prompt directly require the AI to demonstrate active problem-solving skills, such as systematically coming up with a solution for a specific setup instead of regurgitating an existing fact? This tests the AI's ability to apply logical reasoning and provide practical solutions.\n5. Creativity: Does the prompt involve a level of creativity in approaching the problem? This criterion tests the AI's ability to provide tailored solutions that take into account the user's specific needs and limitations.\n6. Technical Accuracy: Does the prompt require technical accuracy in the response? This allows evaluators to assess the AI's precision and correctness in technical fields.\n7. Real-world Application: Does the prompt relate to real-world applications, such as setting up a functional system or writing code for a practical use case? This tests the AI's ability to provide practical and actionable information that could be implemented in real-life scenarios.\n\nYou must list the criteria numbers that the prompt satisfies in the format of a Python array. For example, \"[...]\". Do not explain your choice."
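# The judge must reply with just a bracketed array of criteria numbers,
# e.g. "[1, 3, 7]"; get_score() below extracts and parses that array.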

ENDPOINT_INFO = {
    "model_name": "META-LLAMA/LLAMA-3-70B-CHAT-HF",
    "name": "llama-3-70b-instruct",
    "endpoints": [{"api_base": "-", "api_key": "-"}],
    "parallel": 8,
    "temperature": 0.0,
    "max_token": 512,
}  # Modify this
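# NOTE: "api_base" and "api_key" are placeholders ("-"); point them at a real
# OpenAI-compatible endpoint serving the judge model before running.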

TAGS = {
    1: "specificity",
    2: "domain_knowledge",
    3: "complexity",
    4: "problem_solving",
    5: "creativity",
    6: "technical_accuracy",
    7: "real_world",
}

# API setting constants
API_MAX_RETRY = 3
API_RETRY_SLEEP = 10
API_ERROR_OUTPUT = "$ERROR$"


def get_endpoint(endpoint_list):
    if endpoint_list is None:
        return None
    # Randomly pick one endpoint so requests are spread across replicas.
    return random.choice(endpoint_list)


pattern = re.compile(r"(\[\d(?:\,\s\d)*\])")
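# The regex matches a bracketed list of single digits, e.g. "[2]" or "[1, 3, 7]".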


def get_score(judgment):
    matches = pattern.findall(judgment)
    matches = [m for m in matches if m != ""]
    if len(set(matches)) == 0:
        return []
    elif len(set(matches)) == 1:
        try:
            return ast.literal_eval(matches[0])
        except (SyntaxError, ValueError):
            print(matches[0])
            return []
    else:
        # Conflicting arrays in the judgment; treat it as unusable.
        return []
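# Example: a judgment of "[1, 3, 7]" returns [1, 3, 7]; a reply with no array
# (or two conflicting arrays) returns [], so every criteria tag ends up False.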


def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=None):
    import openai

    if api_dict:
        client = openai.OpenAI(
            base_url=api_dict["api_base"],
            api_key=api_dict["api_key"],
        )
    else:
        client = openai.OpenAI()

    output = API_ERROR_OUTPUT
    for _ in range(API_MAX_RETRY):
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            output = completion.choices[0].message.content
            break
        except openai.RateLimitError as e:
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)
        except openai.BadRequestError as e:
            print(messages)
            print(type(e), e)
            break
        except openai.APIConnectionError as e:
            print(messages)
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)
        except openai.InternalServerError as e:
            print(messages)
            print(type(e), e)
            time.sleep(1)
        except KeyError as e:
            print(type(e), e)
            break

    return output
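# Retry policy: up to API_MAX_RETRY attempts. Rate-limit, connection, and
# server errors sleep and retry; bad requests fail fast. On failure the
# sentinel "$ERROR$" is returned, which matches no criteria array.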


def get_answer(
    question: pd.Series,
    max_tokens: int,
    temperature: float,
    answer_file: str,
    api_dict: dict,
):
    conv = []
    conv.append({"role": "system", "content": SYSTEM_PROMPT})
    conv.append({"role": "user", "content": question["prompt"]})
    output = chat_completion_openai(
        model=ENDPOINT_INFO["model_name"],
        messages=conv,
        temperature=temperature,
        max_tokens=max_tokens,
        api_dict=api_dict,
    )

    criteria = get_score(output)

    # Dump answers
    question["criteria_tag"] = {name: bool(i in criteria) for i, name in TAGS.items()}
    # Series.drop returns a copy, so reassign to actually remove the field.
    question = question.drop("prompt")

    with LOCK:
        with open(answer_file, "a") as fout:
            fout.write(json.dumps(question.to_dict()) + "\n")
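# Each completed call appends one JSON line under the lock, so the output is
# JSONL and the run is resumable: already-labeled question_ids are skipped.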


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-file", type=str, required=True)
    parser.add_argument("--cache-file", type=str, default=None)
    parser.add_argument("--output-file", type=str, required=True)
    parser.add_argument("--convert-to-json", action="store_true")
    args = parser.parse_args()
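
    # Example invocation (the script name here is illustrative):
    #   python label_category_hard.py --input-file battles.json \
    #       --output-file labels.jsonl --convert-to-json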

    print("loading input data (might take a minute)")
    input_data = pd.read_json(args.input_file)
    print(f"{len(input_data)} input rows loaded")
    if args.cache_file:
        print("loading cache data")
        cache_data = pd.read_json(args.cache_file)
        print(f"{len(cache_data)} cache rows loaded")

        assert "criteria_tag" in cache_data.columns and len(
            cache_data["criteria_tag"].dropna()
        ) == len(cache_data)

        not_labeled = input_data[
            ~input_data["question_id"].isin(cache_data["question_id"])
        ].copy()
    else:
        not_labeled = input_data.copy()

    if os.path.isfile(args.output_file):
        print("loading existing output")
        output_data = pd.read_json(args.output_file, lines=True)
        print(f"{len(output_data)} existing output rows loaded")

        assert "criteria_tag" in output_data.columns and len(
            output_data["criteria_tag"].dropna()
        ) == len(output_data)

        not_labeled = not_labeled[
            ~not_labeled["question_id"].isin(output_data["question_id"])
        ]

    print(f"{len(not_labeled)} rows still need to be labeled")

    # Join the user turns (the even indices of conversation_a) into one prompt.
    not_labeled["prompt"] = not_labeled.conversation_a.map(
        lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)])
    )

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=ENDPOINT_INFO["parallel"]
    ) as executor:
        futures = []
        for index, row in tqdm.tqdm(not_labeled.iterrows(), total=len(not_labeled)):
            future = executor.submit(
                get_answer,
                row,
                ENDPOINT_INFO["max_token"],
                ENDPOINT_INFO["temperature"],
                args.output_file,
                get_endpoint(ENDPOINT_INFO["endpoints"]),
            )
            futures.append(future)
        for future in tqdm.tqdm(
            concurrent.futures.as_completed(futures), total=len(futures)
        ):
            future.result()

    if args.convert_to_json:
        temp = pd.read_json(args.output_file, lines=True)
        temp.to_json(
            args.output_file[:-1], orient="records", indent=4, force_ascii=False
        )
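
# Note: --convert-to-json rewrites the JSONL output as pretty-printed JSON at
# args.output_file[:-1], which assumes the output path ends in ".jsonl".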
