fastchat/serve/monitor/basic_stats.py (41 additions, 31 deletions)
@@ -13,50 +13,60 @@


NUM_SERVERS = 14
LOG_ROOT_DIR = "~/fastchat_logs"


def get_log_files(max_num_files=None):
dates = []
for month in range(4, 12):
for day in range(1, 33):
dates.append(f"2023-{month:02d}-{day:02d}")

log_root = os.path.expanduser(LOG_ROOT_DIR)
filenames = []
for d in dates:
for i in range(NUM_SERVERS):
name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
if os.path.exists(name):
filenames.append(name)
for i in range(NUM_SERVERS):
for filename in os.listdir(f"{log_root}/server{i}"):
if filename.endswith("-conv.json"):
filepath = f"{log_root}/server{i}/{filename}"
name_tstamp_tuple = (filepath, os.path.getmtime(filepath))
filenames.append(name_tstamp_tuple)
    # sort by file modification time (oldest first)
filenames = sorted(filenames, key=lambda x: x[1])
filenames = [x[0] for x in filenames]

max_num_files = max_num_files or len(filenames)
filenames = filenames[-max_num_files:]
return filenames


def load_log_files(log_files):
def load_log_files(filename):
data = []
for filename in tqdm(log_files, desc="read files"):
for retry in range(5):
try:
lines = open(filename).readlines()
break
except FileNotFoundError:
time.sleep(2)

for l in lines:
row = json.loads(l)

data.append(
dict(
type=row["type"],
tstamp=row["tstamp"],
model=row.get("model", ""),
models=row.get("models", ["", ""]),
)
for retry in range(5):
try:
lines = open(filename).readlines()
break
except FileNotFoundError:
time.sleep(2)

for l in lines:
row = json.loads(l)
data.append(
dict(
type=row["type"],
tstamp=row["tstamp"],
model=row.get("model", ""),
models=row.get("models", ["", ""]),
            )
        )
return data


def load_log_files_parallel(log_files, num_threads=16):
data_all = []
from multiprocessing import Pool

with Pool(num_threads) as p:
ret_all = list(tqdm(p.imap(load_log_files, log_files), total=len(log_files)))
for ret in ret_all:
data_all.extend(ret)
return data_all
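
A minimal sketch of driving the refactored loaders (my example, assuming logs live under ~/fastchat_logs/server*/ as above):

from fastchat.serve.monitor.basic_stats import get_log_files, load_log_files_parallel

if __name__ == "__main__":  # Pool-based helpers are safest behind a main guard
    log_files = get_log_files(max_num_files=100)  # the 100 most recently modified *-conv.json logs
    records = load_log_files_parallel(log_files)  # flat list of dicts: type, tstamp, model, models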


def get_anony_vote_df(df):
anony_vote_df = df[
df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"])
@@ -77,7 +87,7 @@ def merge_counts(series, on, names):


def report_basic_stats(log_files):
df_all = load_log_files(log_files)
df_all = load_log_files_parallel(log_files)
df_all = pd.DataFrame(df_all)
now_t = df_all["tstamp"].max()
df_1_hour = df_all[df_all["tstamp"] > (now_t - 3600)]
fastchat/serve/monitor/clean_battle_data.py (87 additions, 45 deletions)
@@ -27,6 +27,7 @@
"laion",
"chatglm",
"chatgpt",
"gpt-4",
"openai",
"anthropic",
"claude",
@@ -35,33 +36,26 @@
"lamda",
"google",
"llama",
"qianwan",
"alibaba",
"mistral",
"zhipu",
"KEG lab",
"01.AI",
"AI2",
"Tülu",
"Tulu",
"NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.",
"$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES.",
"API REQUEST ERROR. Please increase the number of max tokens.",
"**API REQUEST ERROR** Reason: The response was blocked.",
"**API REQUEST ERROR**",
]

for i in range(len(IDENTITY_WORDS)):
IDENTITY_WORDS[i] = IDENTITY_WORDS[i].lower()
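
Lowercasing IDENTITY_WORDS once up front lets the leak check later in clean_battle_data be a plain substring match against lowercased transcripts. A quick sketch of the effect (hypothetical message text):

messages = "Hello! I am ChatGPT, how can I help?".lower()
leaked = any(word in messages for word in IDENTITY_WORDS)  # True: "chatgpt" matches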


def get_log_files(max_num_files=None):
dates = []
for month in range(4, 13):
for day in range(1, 33):
dates.append(f"2023-{month:02d}-{day:02d}")

filenames = []
for d in dates:
for i in range(NUM_SERVERS):
name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
if os.path.exists(name):
filenames.append(name)
max_num_files = max_num_files or len(filenames)
filenames = filenames[-max_num_files:]
return filenames


def remove_html(raw):
if raw.startswith("<h3>"):
return raw[raw.find(": ") + 2 : -len("</h3>\n")]
@@ -76,29 +70,54 @@ def to_openai_format(messages):
return ret


def replace_model_name(old_name):
return (
old_name.replace("bard", "palm-2")
.replace("claude-v1", "claude-1")
.replace("claude-instant-v1", "claude-instant-1")
.replace("oasst-sft-1-pythia-12b", "oasst-pythia-12b")
)
def replace_model_name(old_name, tstamp):
replace_dict = {
"bard": "palm-2",
"claude-v1": "claude-1",
"claude-instant-v1": "claude-instant-1",
"oasst-sft-1-pythia-12b": "oasst-pythia-12b",
"claude-2": "claude-2.0",
}
if old_name in ["gpt-4", "gpt-3.5-turbo"]:
if tstamp > 1687849200:
return old_name + "-0613"
else:
return old_name + "-0314"
if old_name in replace_dict:
return replace_dict[old_name]
return old_name
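
For reference, the cutoff 1687849200 is 2023-06-27 00:00 Pacific, which matches the date the bare gpt-4 / gpt-3.5-turbo aliases began pointing at the -0613 snapshots (my reading of the constant, not stated in the diff). A few illustrative calls:

assert replace_model_name("gpt-4", 1687849200) == "gpt-4-0314"  # at or before the cutoff
assert replace_model_name("gpt-4", 1687849201) == "gpt-4-0613"  # after the cutoff
assert replace_model_name("bard", 0) == "palm-2"
assert replace_model_name("vicuna-13b", 0) == "vicuna-13b"      # unmapped names pass through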


def clean_battle_data(log_files, exclude_model_names):
def read_file(filename):
data = []
for filename in tqdm(log_files, desc="read files"):
for retry in range(5):
try:
lines = open(filename).readlines()
break
except FileNotFoundError:
time.sleep(2)

for l in lines:
row = json.loads(l)
if row["type"] in VOTES:
data.append(row)
for retry in range(5):
try:
            # Stream the file line by line instead of loading it all at once.
for l in open(filename):
row = json.loads(l)
if row["type"] in VOTES:
data.append(row)
break
except FileNotFoundError:
time.sleep(2)
return data


def read_file_parallel(log_files, num_threads=16):
data_all = []
from multiprocessing import Pool

with Pool(num_threads) as p:
ret_all = list(tqdm(p.imap(read_file, log_files), total=len(log_files)))
for ret in ret_all:
data_all.extend(ret)
return data_all
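
Worth noting: Pool.imap, unlike imap_unordered, yields results in input order, so the modification-time ordering from get_log_files survives the parallel read. A minimal sketch of that property:

from multiprocessing import Pool

def square(x):
    return x * x

if __name__ == "__main__":
    with Pool(4) as p:
        # imap preserves input order even if workers finish out of order
        assert list(p.imap(square, range(100))) == [n * n for n in range(100)]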


def clean_battle_data(
log_files, exclude_model_names, ban_ip_list=None, sanitize_ip=False
):
data = read_file_parallel(log_files, num_threads=16)

convert_type = {
"leftvote": "model_a",
@@ -112,6 +131,7 @@ def clean_battle_data(log_files, exclude_model_names):
ct_anony = 0
ct_invalid = 0
ct_leaked_identity = 0
ct_banned = 0
battles = []
for row in data:
if row["models"][0] is None or row["models"][1] is None:
@@ -158,7 +178,9 @@ def clean_battle_data(log_files, exclude_model_names):
messages = ""
for i in range(2):
state = row["states"][i]
for role, msg in state["messages"][state["offset"] :]:
for turn_idx, (role, msg) in enumerate(
state["messages"][state["offset"] :]
):
if msg:
messages += msg.lower()
for word in IDENTITY_WORDS:
@@ -171,10 +193,9 @@ def clean_battle_data(log_files, exclude_model_names):
continue

        # Normalize model names (e.g., bard -> palm-2, gpt aliases -> dated snapshots)
models = [replace_model_name(m) for m in models]

models = [replace_model_name(m, row["tstamp"]) for m in models]
# Exclude certain models
if any(x in exclude_model_names for x in models):
if exclude_model_names and any(x in exclude_model_names for x in models):
ct_invalid += 1
continue

@@ -188,8 +209,16 @@ def clean_battle_data(log_files, exclude_model_names):

ip = row["ip"]
if ip not in all_ips:
all_ips[ip] = len(all_ips)
user_id = all_ips[ip]
all_ips[ip] = {"ip": ip, "count": 0, "sanitized_id": len(all_ips)}
all_ips[ip]["count"] += 1
if sanitize_ip:
user_id = f"arena_user_{all_ips[ip]['sanitized_id']}"
else:
user_id = f"{all_ips[ip]['ip']}"

if ban_ip_list is not None and ip in ban_ip_list:
ct_banned += 1
continue

# Save the results
battles.append(
@@ -218,12 +247,19 @@ def clean_battle_data(log_files, exclude_model_names):

print(
f"#votes: {len(data)}, #invalid votes: {ct_invalid}, "
f"#leaked_identity: {ct_leaked_identity}"
f"#leaked_identity: {ct_leaked_identity} "
f"#banned: {ct_banned} "
)
print(f"#battles: {len(battles)}, #anony: {ct_anony}")
print(f"#models: {len(all_models)}, {all_models}")
print(f"last-updated: {last_updated_datetime}")

if ban_ip_list is not None:
for ban_ip in ban_ip_list:
if ban_ip in all_ips:
del all_ips[ban_ip]
print("Top 30 IPs:")
print(sorted(all_ips.values(), key=lambda x: x["count"], reverse=True)[:30])
return battles
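
A hedged usage sketch of the new flags (the IP is illustrative; the exact shape of the saved records is in the collapsed battles.append above). With sanitize_ip=True each distinct IP is replaced by a stable pseudonymous id such as arena_user_0, while votes from banned IPs are skipped and counted in ct_banned:

ban_list = ["203.0.113.7"]  # hypothetical banned address
battles = clean_battle_data(
    log_files,
    exclude_model_names=[],
    ban_ip_list=ban_list,
    sanitize_ip=True,  # store "arena_user_<n>" instead of raw IPs
)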


@@ -234,10 +270,16 @@ def clean_battle_data(log_files, exclude_model_names):
"--mode", type=str, choices=["simple", "conv_release"], default="simple"
)
parser.add_argument("--exclude-model-names", type=str, nargs="+")
parser.add_argument("--ban-ip-file", type=str)
parser.add_argument("--sanitize-ip", action="store_true", default=False)
args = parser.parse_args()

log_files = get_log_files(args.max_num_files)
battles = clean_battle_data(log_files, args.exclude_model_names or [])
ban_ip_list = json.load(open(args.ban_ip_file)) if args.ban_ip_file else None

battles = clean_battle_data(
log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip
)
last_updated_tstamp = battles[-1]["tstamp"]
cutoff_date = datetime.datetime.fromtimestamp(
last_updated_tstamp, tz=timezone("US/Pacific")