forked from starsuzi/Adaptive-RAG
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: run_gpt_baseline_trivia.sh
More file actions
executable file
·54 lines (47 loc) · 2.17 KB
/
run_gpt_baseline_trivia.sh
File metadata and controls
executable file
·54 lines (47 loc) · 2.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#!/usr/bin/env bash
# Generate GPT baseline predictions for TriviaQA (nor_qa, oner_qa, ircot_qa).
# Outputs go to predictions/test/ so evaluate_selfrag_e2e.py --model gpt --dataset trivia
# can compare GPT baseline vs GPT+Self-RAG.
#
# Prerequisites:
# - OPENAI_API_KEY set in the environment
# - .llm_server_address.jsonnet present (used by predict.py; not used when config is GPT)
# - .retriever_address.jsonnet present and retriever running (required for oner_qa and ircot_qa)
# - processed_data/trivia/test_subsampled.jsonl exists
#
# Usage: ./run_gpt_baseline_trivia.sh [LLM_PORT_NUM]
# Default LLM_PORT_NUM=8010

# -e: abort on unhandled failure; -u: error on unset vars; pipefail: fail broken pipelines.
set -euo pipefail

# die MESSAGE...: print an error to stderr and abort with status 1.
die() { printf 'Error: %s\n' "$*" >&2; exit 1; }

# require_file PATH HINT: abort unless PATH exists as a regular file.
require_file() {
  [[ -f "$1" ]] || die "$1 not found. $2"
}

LLM_PORT="${1:-8010}"
readonly DATASET="trivia"
readonly MODEL="gpt"

# Validate the optional port argument before passing it to runner.py.
[[ "${LLM_PORT}" =~ ^[0-9]+$ ]] || die "LLM_PORT_NUM must be numeric, got '${LLM_PORT}'."

# Precondition checks — fail fast with an actionable message on stderr.
[[ -n "${OPENAI_API_KEY:-}" ]] || die "OPENAI_API_KEY is not set. Export it before running."
require_file .llm_server_address.jsonnet \
  'Create it (e.g. {"host": "http://localhost", "port": "8010"}).'
require_file .retriever_address.jsonnet \
  'Retriever is required for oner_qa and ircot_qa.'
require_file "processed_data/${DATASET}/test_subsampled.jsonl" \
  'Run the data preprocessing step first.'

echo ">>>> Generating GPT baseline predictions for TriviaQA (LLM port ${LLM_PORT}) <<<<"

# For each QA system variant: write the experiment config, then predict on the test set.
for SYSTEM in nor_qa oner_qa ircot_qa; do
  echo ""
  echo ">>> ${SYSTEM} ${MODEL} ${DATASET}"
  echo ">>>> Write configs <<<<"
  python runner.py "${SYSTEM}" "${MODEL}" "${DATASET}" write \
    --prompt_set 1 --llm_port_num "${LLM_PORT}"
  echo ">>>> Predict on test set <<<<"
  python runner.py "${SYSTEM}" "${MODEL}" "${DATASET}" predict \
    --prompt_set 1 --eval_test --llm_port_num "${LLM_PORT}"
done

echo ""
echo "Done. Baseline prediction paths (for evaluate_selfrag_e2e.py --model gpt --dataset trivia):"
echo "  predictions/test/nor_qa_gpt_trivia____prompt_set_1/"
echo "  predictions/test/oner_qa_gpt_trivia____prompt_set_1___bm25_retrieval_count__15___distractor_count__1/"
echo "  predictions/test/ircot_qa_gpt_trivia____prompt_set_1___bm25_retrieval_count__6___distractor_count__1/"