22 changes: 22 additions & 0 deletions README.md
@@ -1,2 +1,24 @@
# ChatServer
chatbot server


## Train Alpaca with SkyPilot
1. Install SkyPilot and set up your cloud credentials locally by following the instructions [here](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html).
2. Launch the training job with the following command (it runs on a single node with 4 A100-80GB GPUs):
```
# A WANDB API key is required for logging; we use the key from your local environment.
sky launch -c alpaca -s scripts/train.yaml --env WANDB_API_KEY
```
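If the `alpaca` cluster is already up from a previous launch, later runs can reuse it and skip provisioning and setup with `sky exec` (a sketch; not part of the original workflow above):
```
sky exec alpaca scripts/train.yaml --env WANDB_API_KEY
```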
Alternatively, launch on spot instances (not managed):
```
sky launch -c alpaca-spot -s --use-spot scripts/train.yaml --env WANDB_API_KEY
```
**The following does not work yet, because the Alpaca code does not support multiple nodes.**
We can also launch the training job on multiple nodes and with a different number of GPUs. The gradient accumulation steps are adapted automatically so that the global batch size stays the same (the maximum supported #nodes * #GPUs per node is 32):
```
sky launch -c alpaca-2 -s --num-nodes 2 --gpus A100-80GB:8 scripts/train.yaml --env WANDB_API_KEY
```
Managed spot version TO BE ADDED.
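
After launching, the cluster can be inspected and managed with the standard SkyPilot commands (shown for the `alpaca` cluster; adjust the name as needed):
```
sky status          # list clusters and their state
sky logs alpaca     # stream the training logs
sky down alpaca     # tear down the cluster when you are done
```
Once training finishes, checkpoints are written to the `/artifacts` mount, i.e. the GCS bucket configured in `scripts/train.yaml` (`skypilot-chatbot` by default). A sketch for copying them locally, assuming that bucket name is kept:
```
gsutil cp -r gs://skypilot-chatbot/chatbot/7b/ckpt ./alpaca-7b-ckpt
```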


99 changes: 99 additions & 0 deletions scripts/train.yaml
@@ -0,0 +1,99 @@
resources:
  accelerators: A100-80GB:4

num_nodes: 1

file_mounts:
  /artifacts:
    name: skypilot-chatbot  # Change to your own bucket
    store: gcs
    mode: MOUNT
  /data:
    name: skypilot-chatbot-data  # Change to your own bucket
    store: gcs
    mode: MOUNT
  ~/llama: gs://llama-7b/llama-ckpt  # Change to your own bucket with the LLaMA weights
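
# Note: with MOUNT mode the buckets above are mounted directly at /artifacts and /data
# on the cluster nodes, so checkpoints written under /artifacts are persisted to GCS
# and survive cluster teardown.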

setup: |
  # Download the model weights if they do not exist
  # mkdir -p $HOME/llama
  # if [ ! -f /artifacts/llama-hf/llama-7B/complete ]; then
  #   if [ ! -f $HOME/llama/complete ]; then
  #     bash download.sh $LLAMA_URL 7B $HOME/llama &
  #   fi
  # fi

  # Download the data
  if [ ! -f /data/alpaca-data.json ]; then
    wget https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json -O /data/alpaca-data.json
  fi

  # Set up the environment
  conda create -n chatbot python=3.10 -y
  conda activate chatbot

  # Install PyTorch (CUDA 11.6 build)
  pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116

  # Install HuggingFace transformers at the commit with LLaMA support
  git clone https://github.com/huggingface/transformers.git
  cd transformers
  git checkout 60d51ef  # pin to a specific commit for reproducibility
  pip install .
  cd -

  # Install Alpaca
  git clone https://github.com/tatsu-lab/stanford_alpaca.git
  cd stanford_alpaca
  git checkout eb5b171  # pin to a specific commit for reproducibility
  pip install -r requirements.txt
  cd -

  # wait
  # touch $HOME/llama/complete

  # Convert the raw LLaMA weights to the HuggingFace format once; the `complete`
  # marker file lets later setup runs skip this step.
  mkdir -p /artifacts/llama-hf/llama-7B
  if [ ! -f /artifacts/llama-hf/llama-7B/complete ]; then
    cd transformers
    python src/transformers/models/llama/convert_llama_weights_to_hf.py \
      --input_dir $HOME/llama \
      --model_size 7B \
      --output_dir ~/hf-output || exit 1
    mv ~/hf-output/llama-7b/* ~/hf-output/
    mv ~/hf-output/tokenizer/* ~/hf-output/
    cp -r ~/hf-output/* /artifacts/llama-hf/llama-7B
    touch /artifacts/llama-hf/llama-7B/complete
  fi

run: |
  cd stanford_alpaca
  conda activate chatbot
  NUM_NODES=`echo "$SKYPILOT_NODE_IPS" | wc -l`
  HOST_ADDR=`echo "$SKYPILOT_NODE_IPS" | head -n1`
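  # gradient_accumulation_steps is set to 32 / (#nodes * #GPUs per node), so the global
  # batch size stays fixed at 4 * 32 = 128 sequences regardless of the cluster shape
  # (e.g. 1 node x 4 GPUs -> 8 steps, 2 nodes x 8 GPUs -> 2 steps).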
  torchrun \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
    --master_port=12355 \
    train.py \
    --model_name_or_path /artifacts/llama-hf/llama-7B \
    --data_path /data/alpaca-data.json \
    --bf16 True \
    --output_dir /artifacts/chatbot/7b/ckpt \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps $((32 / NUM_NODES / SKYPILOT_NUM_GPUS_PER_NODE)) \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 2000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True