diff --git a/fastchat/serve/example_images/city.jpeg b/fastchat/serve/example_images/city.jpeg deleted file mode 100644 index d6f12601f..000000000 Binary files a/fastchat/serve/example_images/city.jpeg and /dev/null differ diff --git a/fastchat/serve/example_images/distracted.jpg b/fastchat/serve/example_images/distracted.jpg new file mode 100644 index 000000000..382c888a0 Binary files /dev/null and b/fastchat/serve/example_images/distracted.jpg differ diff --git a/fastchat/serve/example_images/fridge.jpeg b/fastchat/serve/example_images/fridge.jpeg deleted file mode 100644 index 88f5b370e..000000000 Binary files a/fastchat/serve/example_images/fridge.jpeg and /dev/null differ diff --git a/fastchat/serve/example_images/fridge.jpg b/fastchat/serve/example_images/fridge.jpg new file mode 100644 index 000000000..8ed943e8b Binary files /dev/null and b/fastchat/serve/example_images/fridge.jpg differ diff --git a/fastchat/serve/gradio_block_arena_vision.py b/fastchat/serve/gradio_block_arena_vision.py index b5c60b3c3..5ddf138e8 100644 --- a/fastchat/serve/gradio_block_arena_vision.py +++ b/fastchat/serve/gradio_block_arena_vision.py @@ -7,9 +7,11 @@ python3 -m fastchat.serve.gradio_web_server_multi --share --multimodal """ +import json import os import gradio as gr +import numpy as np from fastchat.serve.gradio_web_server import ( upvote_last_response, @@ -31,6 +33,12 @@ logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") +def get_vqa_sample(): + random_sample = np.random.choice(vqa_samples) + question, path = random_sample["question"], random_sample["path"] + return question, path + + def clear_history_example(request: gr.Request): ip = get_ip(request) logger.info(f"clear_history_example. ip: {ip}") @@ -38,7 +46,9 @@ def clear_history_example(request: gr.Request): return (state, []) + (disable_btn,) * 5 -def build_single_vision_language_model_ui(models, add_promotion_links=False): +def build_single_vision_language_model_ui( + models, add_promotion_links=False, random_questions=None +): promotion = ( """ | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | @@ -103,8 +113,8 @@ def build_single_vision_language_model_ui(models, add_promotion_links=False): ) max_output_tokens = gr.Slider( minimum=0, - maximum=1024, - value=512, + maximum=2048, + value=1024, step=64, interactive=True, label="Max output tokens", @@ -113,17 +123,23 @@ def build_single_vision_language_model_ui(models, add_promotion_links=False): examples = gr.Examples( examples=[ [ - f"{cur_dir}/example_images/city.jpeg", - "What is unusual about this image?", + f"{cur_dir}/example_images/fridge.jpg", + "How can I prepare a delicious meal using these ingredients?", ], [ - f"{cur_dir}/example_images/fridge.jpeg", - "What is in this fridge?", + f"{cur_dir}/example_images/distracted.jpg", + "What might the woman on the right be thinking about?", ], ], inputs=[imagebox, textbox], ) + if random_questions: + global vqa_samples + with open(random_questions, "r") as f: + vqa_samples = json.load(f) + random_btn = gr.Button(value="🎲 Random Example", interactive=True) + with gr.Column(scale=8): chatbot = gr.Chatbot( elem_id="chatbot", label="Scroll down and start chatting", height=550 @@ -134,6 +150,7 @@ def build_single_vision_language_model_ui(models, add_promotion_links=False): textbox.render() with gr.Column(scale=1, min_width=50): send_btn = gr.Button(value="Send", variant="primary") + with gr.Row(elem_id="buttons"): upvote_btn = gr.Button(value="👍 Upvote", interactive=False) downvote_btn = gr.Button(value="👎 Downvote", interactive=False) @@ -169,11 +186,12 @@ def build_single_vision_language_model_ui(models, add_promotion_links=False): [state, chatbot] + btn_list, ) clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list) - examples.dataset.click(clear_history_example, None, [state, chatbot] + btn_list) model_selector.change( clear_history, None, [state, chatbot, textbox, imagebox] + btn_list ) + imagebox.upload(clear_history_example, None, [state, chatbot] + btn_list) + examples.dataset.click(clear_history_example, None, [state, chatbot] + btn_list) textbox.submit( add_text, @@ -194,4 +212,11 @@ def build_single_vision_language_model_ui(models, add_promotion_links=False): [state, chatbot] + btn_list, ) + if random_questions: + random_btn.click( + get_vqa_sample, # First, get the VQA sample + [], # Pass the path to the VQA samples + [textbox, imagebox], # Outputs are textbox and imagebox + ) + return [state, model_selector] diff --git a/fastchat/serve/gradio_web_server_multi.py b/fastchat/serve/gradio_web_server_multi.py index 72d8aef75..884bcb7df 100644 --- a/fastchat/serve/gradio_web_server_multi.py +++ b/fastchat/serve/gradio_web_server_multi.py @@ -131,7 +131,9 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file): with gr.Tab("Vision Direct Chat", id=3, visible=args.multimodal): single_vision_language_model_list = ( build_single_vision_language_model_ui( - vl_models, add_promotion_links=True + vl_models, + add_promotion_links=True, + random_questions=args.random_questions, ) ) @@ -202,6 +204,9 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file): parser.add_argument( "--multimodal", action="store_true", help="Show multi modal tabs." ) + parser.add_argument( + "--random-questions", type=str, help="Load random questions from a JSON file" + ) parser.add_argument( "--register-api-endpoint-file", type=str,