Skip to content

ResearchRAG/aipm-self-rag

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

45 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

自我反思检索增强生成模型(SELF-RAG):通过自我反思学习检索、生成和批评

这包括了自我反思检索增强生成模型(SELF-RAG):通过自我反思学习检索、生成和批评(ICLR 2024,口头报告前1%)的原始实现,作者为Akari Asai、Zeqiu Wu、Yizhong Wang、Avirup Sil和Hannaneh Hajishirzi。

网站 | 7B模型 | 13B模型 | 论文 | 训练数据 | 推特摘要 | 更新

自我反思检索增强生成模型(Self-RAG)(图右)是一个新框架,用于训练任意大型语言模型(LM)学习检索、生成和批评,以提高生成事实性和质量,同时不损害大型语言模型(LLM)的多功能性。

与传统的检索增强生成(RAG;图左)方法不同,自我反思检索增强生成模型根据多样化的查询按需检索(例如,可以多次检索或完全跳过检索),并通过预测反思标记从多个细粒度方面批评自己的生成,将反思标记作为生成的一个组成部分。

我们进行逐段的束搜索,以选择最大化多样化偏好的效用的输出。

如果你发现我们的代码、数据、模型或论文有用,请引用论文:

@inproceedings{
asai2024selfrag,
author={Asai, Akari and Wu, Zeqiu and Wang, Yizhong and Sil, Avirup and Hajishirzi, Hannaneh},
title={Self-{RAG}: Learning to Retrieve, Generate, and Critique through Self-Reflection},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=hSyW5go0v8} 
}

更新

  • 2023.10:代码、模型和论文的初始发布。

内容

  1. 安装
  2. 快速开始
  3. 检索器设置
  4. 训练
  5. 推理
  6. 基线
  7. 常见问题解答
  8. 联系方式

安装

通过运行下面的命令安装依赖的Python库。

pip install -r requirements.txt

请使用最新版本的vllm,因为旧版本可能无法通过SamplingParam设置skip_special_tokens,这是通过(this PR)添加的。

你也可以通过运行下面的命令创建一个conda环境。

conda env create -f environment.yml

快速开始

你可以从HuggingFace Hub下载Self-RAG。对于推理,我们建议使用vllm,因为它显著加快了推理速度。

from vllm import LLM, SamplingParams

model = LLM("selfrag/selfrag_llama2_7b", download_dir="/gscratch/h2lab/akari/model_cache", dtype="half")
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)

def format_prompt(input, paragraph=None):
  prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
  if paragraph is not None:
    prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
  return prompt

query_1 = "Leave odd one out: twitter, instagram, whatsapp."
query_2 = "Can you tell me the difference between llamas and alpacas?"
queries = [query_1, query_2]

# 对于不需要检索的查询
preds = model.generate([format_prompt(query) for query in queries], sampling_params)
for pred in preds:
  print("Model prediction: {0}".format(pred.outputs[0].text))

输出:

Model prediction: Twitter, Instagram, and WhatsApp are all social media platforms. [No Retrieval]WhatsApp is the odd one out because it is a messaging app, while Twitter and # Instagram are primarily used for sharing photos and videos.[Utility:5]</s>
Model prediction: Sure![Retrieval]<paragraph><paragraph>

如你所见,当第一个查询不需要检索时,Self-RAG开始生成不包含检索的响应。另一方面,对于第二个查询,Self-RAG输出了[Retrieve]标记,因为这个问题需要更细粒度的事实基础。

对于需要事实基础的查询,你可以插入一个段落。Self-RAG可以在生成过程中随时检索和插入段落,并在它们被上下文标记特殊标记<paragraph></paragraph>包围时识别它们。

# 对于需要事实基础的查询
prompt = format_prompt("Can you tell me the difference between llamas and alpacas?", "The alpaca (Lama pacos) is a species of South American camelid mammal. It is similar to, and often confused with, the llama. Alpacas are considerably smaller than llamas, and unlike llamas, they were not bred to be working animals, but were bred specifically for their fiber.")
preds = model.generate([prompt], sampling_params)
print([pred.outputs[0].text for pred in preds])
# ['[Relevant]Alpacas are considerably smaller than llamas, and unlike llamas, they were not bred to be working animals, but were bred specifically for their fiber.[Fully supported][Utility:5]</s>']

Self-RAG找到了相关插入的文档,并生成了完全由证据支持的答案。

请注意,这个演示使用了较小的语料库和具有完整推理算法的Self-RAG。对于完整评估,你需要设置检索器或下载我们检索的结果。请按照推理中的说明操作。

检索器设置

默认情况下,我们使用Contriever作为我们的检索组件。

下载数据

下载DPR使用的预处理段落数据。

cd retrieval_lm
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz 

然后,下载生成的段落。我们使用Contriever-MSMARCO

wget https://dl.fbaipublicfiles.com/contriever/embeddings/contriever-msmarco/wikipedia_embeddings.tar 

运行检索器

你可以通过运行下面的命令来运行段落检索。

cd retrieval_lm
python passage_retrieval.py \
    --model_name_or_path facebook/contriever-msmarco --passages psgs_w100.tsv \
    --passages_embeddings "wikipedia_embeddings/*" \
    --data YOUR_INPUT_FILE  \
    --output_dir YOUR_OUTPUT_FILE \
    --n_docs 20

你的输入文件应该是一个jsonjsonl。每个实例必须包含questioninstruction,这些将在检索期间作为查询使用。

为你自己的数据生成嵌入

你可以通过运行以下命令为你自己的数据生成嵌入。(该脚本改编自Contriever仓库。)请注意,从大规模语料库(>10M文档)生成嵌入可能需要时间,我们建议在多个GPU上运行。

cd retrieval_lm
for i in {0..3}; do
  export CUDA_VISIBLE_DEVICES=${i}
  python generate_passage_embeddings.py  --model_name_or_path facebook/contriever-msmarco \
  --output_dir YOUR_OUTPUT_DIR \
  --passages YOUR_PASSAGE_DATA --shard_id ${i}  --num_shards 4 > ./log/nohup.my_embeddings.${i} 2>&1 &

训练

自我反思检索增强生成模型训练两个模型,CriticGenerator,这两个模型都通过标准下一个标记预测目标扩展标记词汇表。

或者,你可以下载我们包含150K实例的训练数据这里

收集反思标记

我们从GPT-4收集训练数据。调用GPT-4的脚本每种特殊标记类型都在data_creation/critic

或者,你可以在这里下载我们的训练数据。

Critic训练

一旦你创建或下载了训练数据,运行下面的命令对Llama2-7B进行Critic训练。

cd data_creation
torchrun --nproc_per_node=2 \
  --master_port=2568 train_special_tokens.py \
  --model_name_or_path meta-llama/Llama-2-7b-hf \
  --data_path PATH_TO_TRAIN_DATA_FILE \
  --bf16  True \
  --output_dir PATH_TO_CRITIC_MODEL \
  --num_train_epochs 3  \
  --per_device_train_batch_size 1 --per_device_eval_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 300 \
  --save_total_limit 1 \
  --learning_rate 2e-5 \
  --weight_decay 0. \
  --warmup_ratio 0.01 \
  --lr_scheduler_type "cosine" \
  --logging_steps 10 \
  --fsdp "full_shard auto_wrap"

生成器数据创建

生成器训练数据的代码在data_creation/generator下。请参阅README.md中的说明。

或者,你可以在HuggingFace dataset这里下载我们的训练数据

生成器训练

对于生成器训练,我们使用DeepSpeed使训练更高效。设置训练数据路径后,通过运行下面的脚本来运行训练。

cd retrieval_lm
bash script_finetune_7b.sh

对于13B模型训练,使用training_13b。我们使用8 A100与40 GRAM进行7B模型训练,使用4 a100与80 GB GRAM进行13B训练。7B应该适合1-2 A100,尽管训练可能会很慢。

推理

对于论文中进行的任务评估,请下载数据这里

每个文件已经附带了由Contriever离线检索的文档,所以如果你不想在推理过程中运行检索器,你可以直接在contexts加载检索的文档。

下面,我们描述了Self-RAG和基线。

  • 短文本:运行短文本生成的评估。
  • 长文本:运行长文本生成的评估。

短文本(PubHealth、ARC-Challenge、TriviaQA、PopQA)

由于我们通常只对短文本生成任务检索一次,我们提供了一个易于运行的评估脚本,利用Contriever预先检索的文档。请参阅下面的个别命令。

问答

python run_short_form.py \
--model_name selfrag/selfrag_llama2_7b \
--input_file eval_data/popqa_longtail_w_gs.jsonl \
--mode MODE --max_new_tokens 100 \
--threshold 0.2 \
--output_file YOUR_OUTPUT_FILE \
--metric match --ndocs 10 --use_groundness --use_utility --use_seqscore \
--dtype half

mode指定了推理时间模式中的['adaptive_retrieval', 'no_retrieval', 'always_retrieve']

  • adaptive_retrieval根据threshold或Self-RAG预测进行检索
  • no_retrieval在推理时禁用检索
  • always_retrieve总是检索。

对于13B,如果你使用单个GPU与24 GRAM,可能会有OOM问题。你可以通过设置--world_size在多个GPU上运行推理。

ARC Challenge

python run_short_form.py \
  --model_name selfrag/selfrag_llama2_7b \
  --input_file eval_data/arc_challenge_processed.jsonl \
  --max_new_tokens 50 --threshold 0.2 \
  --output_file OUTPUT_FILE_NAME \
  --metric match --ndocs 5 --use_groundness --use_utility --use_seqscore \
  --task arc_c

PubHealth

python run_short_form.py \
  --model_name selfrag/selfrag_llama2_7b \
  --input_file eval_data/health_claims_processed.jsonl \
  --max_new_tokens 50 \
  --threshold 0.2 --output_file OUTPUT_FILE_NAME \
  --metric match --ndocs 5 \
  --use_groundness --use_utility --use_seqscore \
  --task fever

长文本(ASQA、FactScore)

对于长文本QA,你可以使用检索模型进行评估,也可以使用预先给定的段落进行评估。 目前,我们正在努力减少运行时内存需求(DPR / Contriever使用整个英文维基百科嵌入需要100 GB RAM)加速长文本生成,并发布使用小批量初始检索文档的推理代码(~20)。

请注意:我们当前的实现专门针对目标任务数据集的评估。我们计划更新我们的代码库,使界面更简单、更易于使用。我们会在发布另一个版本时宣布。

使用预先检索的段落运行推理

对于ASQA,请运行以下命令,

python run_long_form_static.py \
  --model_name selfrag/selfrag_llama2_7b \
  --ndocs 5 --max_new_tokens 300 --threshold 0.2 \
  --use_grounding --use_utility --use_seqscore \
  --task asqa --input_file eval_data/asqa_eval_gtr_top100.json \
  --output_file YOUR_OUTPUT_FILE_NAME --max_depth 7 --mode always_retrieve \

对于FactScore,

python run_long_form_static.py \
  --model_name selfrag/selfrag_llama2_7b \
  --ndocs 5 --max_new_tokens 300 --threshold 0.2 \
  --use_grounding --use_utility --use_seqscore \
  --task factscore --input_file eval_data/factscore_unlabeled_alpaca_13b_retrieval.jsonl \
  --output_file YOUR_OUTPUT_FILE_NAME --max_depth 7 \

长文本生成的关键参数

有几个关键参数与Self-RAG的推理有关。

  • w_rel(默认1.0):w_rel在束搜索中控制对isRel(关于检索段落是否相关的批评标记)标记概率的强调。
  • w_sup(默认1.0):w_sup在束搜索中控制对isSup(关于生成是否由文档支持的批评标记)标记概率的强调。
  • w_use(默认0.5):w_use在束搜索中控制对isUse(关于整体质量的批评标记)标记概率的强调。
  • threshold(默认0.2):这个阈值控制自适应检索的频率。
  • max_depth(默认6):这对应于论文中的T,它定义了搜索的最大深度。
  • beam_width(默认2):这控制段级束搜索中的束大小。

有关更多详细信息,请参阅我们论文中的详细说明(第3节)和分析(第5节)。

运行评估

对于长文本评估,设置外部库或存储库以运行评估。

  • factscore==v0.1.5(生物) 请按照FactScore官方存储库的说明设置你的环境。
python -m factscore.factscorer --data_path YOUR_OUTPUT_FILE  --model_name retrieval+ChatGPT --cache_dir YOUR_CACHE_DIR --openai_key YOUR_OPEN_AI_KEY --verbose

ALCE为长文本QA提供了使用多种不同指标的综合评估。对于你的第一次评估,安装ALCE存储库并下载数据。

git clone https://github.com/princeton-nlp/ALCE.git 
python3 -m alce_env
cd ALCE
bash download_data.sh

对于ASQA,你可以按如下方式运行评估。请注意,ASQA评估需要T5-XXL(11B)基于NLI模块。

python eval.py --f YOUR_OUTPUT_FILE --citations --qa --mauve

基线

重新运行基线的代码可在run_baseline_lm.py。 要运行检索增强基线,请确保下载带有检索段落的任务输入文件。

纯LM基线

  • Huggingface模型
python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
 --max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH --task qa --prompt_name "prompt_no_input"

例如,PubHealth

python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file eval_data/health_claims_processed.jsonl \
--max_new_tokens 20 \
--metric accuracy \
--result_fp llama2_7b_pubhealth_results.json \
--task fever

注意:对于PubHealth和ARC,请传递任务名称(ARC = arc_c 和 PubHealth = fever)以正确设置指令。

  • OpenAI API

对于OpenAI API模型,你还需要在这里设置组织密钥这里。 你还需要有一个包含你的OpenAI API密钥的txt文件。

python run_baseline_lm.py \
--model_name gpt-3.5-turbo-0301 \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH \
 --task qa \
--api_key YOUR_OPEN_AI_API_KEY_FILE \
--prompt_name "prompt_no_input"

检索增强基线

  • Huggingface模型
python run_baseline_refactor.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
 --max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH --task qa \
--mode retrieval \
--prompt_name "prompt_no_input_retrieval"
  • OpenAI API
python run_baseline_lm.py \
--model_name gpt-3.5-turbo-0301 \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH \
 --task qa \
--api_key YOUR_OPEN_AI_API_KEY_FILE \
--mode retrieval \
--prompt_name "prompt_no_input_retrieval"

常见问题解答

Q1: 我如何使用Self-RAG方案训练一个新的预训练LM? -- 如果你使用hugging face transformers,你可以简单地在训练脚本中更改model_name_or_pathtokenizer_namescript_finetune_7b.sh。如果你想使用你自己的微调脚本,请确保添加特殊标记并遮蔽段落上下文,如这个问题中讨论的

Q2: 你们计划发布基于Mirstral-7B的Self-RAG吗? -- 现在,我有限的带宽无法做到这一点,但有一个社区训练版本的Self-RAG SciPhi-Self-RAG-Mistral-7B-32k 在Mistral-7B之上。如果我们能在Mistral-7B上训练Self-RAG并发布检查点,我们会宣布。

联系方式

如果你有问题,请提出一个提及@AkariAsai的问题,或发送电子邮件至akari[at]cs.washington.edu。

About

This includes the original implementation of SELF-RAG: Learning to Retrieve, Generate and Critique through self-reflection

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages

  • Python 99.1%
  • Shell 0.9%