Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions fastchat/serve/gradio_block_arena_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
"""

import os
import json
import numpy as np

import gradio as gr

Expand All @@ -31,14 +33,24 @@
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")


def get_vqa_sample(example_json: gr.components.Textbox):
with open(example_json, "r") as f:
vqa_samples = json.load(f)
random_sample = np.random.choice(vqa_samples)
question, path = random_sample["question"], random_sample["path"]
return question, path


def clear_history_example(request: gr.Request):
ip = get_ip(request)
logger.info(f"clear_history_example. ip: {ip}")
state = None
return (state, []) + (disable_btn,) * 5


def build_single_vision_language_model_ui(models, add_promotion_links=False):
def build_single_vision_language_model_ui(
models, add_promotion_links=False, random_questions=None
):
promotion = (
"""
| [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
Expand Down Expand Up @@ -140,9 +152,8 @@ def build_single_vision_language_model_ui(models, add_promotion_links=False):
flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

if add_promotion_links:
gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
if random_questions:
random_btn = gr.Button(value="🎲 Random", interactive=True)

# Register listeners
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
Expand Down Expand Up @@ -194,4 +205,12 @@ def build_single_vision_language_model_ui(models, add_promotion_links=False):
[state, chatbot] + btn_list,
)

if random_questions:
questions_textbox = gr.Textbox(value=random_questions, visible=False)
random_btn.click(
get_vqa_sample, # First, get the VQA sample
questions_textbox, # Pass the path to the VQA samples
[textbox, imagebox], # Outputs are textbox and imagebox
)

return [state, model_selector]
7 changes: 6 additions & 1 deletion fastchat/serve/gradio_web_server_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file):
with gr.Tab("Vision Direct Chat", id=3, visible=args.multimodal):
single_vision_language_model_list = (
build_single_vision_language_model_ui(
vl_models, add_promotion_links=True
vl_models,
add_promotion_links=True,
random_questions=args.random_questions,
)
)

Expand Down Expand Up @@ -202,6 +204,9 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file):
parser.add_argument(
"--multimodal", action="store_true", help="Show multi modal tabs."
)
parser.add_argument(
"--random-questions", type=str, help="Load random questions from a JSON file"
)
parser.add_argument(
"--register-api-endpoint-file",
type=str,
Expand Down
1 change: 1 addition & 0 deletions vqav2_questions.json

Large diffs are not rendered by default.