Self-RAG——自我“PUA”的RAG
date
Mar 10, 2025
slug
self-rag
status
Published
tags
Tech
summary
a note on rag paper
type
Newsletter
背景
由于大模型的训练数据具有一定的时间滞后性,使得大模型的“记忆”中不包含近期发生的事情、信息。为了解决这个问题,人们提出了所谓RAG的概念。但是对于大模型来说,并不是每个用户的任务都需要从数据库中获取相关数据,对于这些任务来说,从数据库拉取数据作为上下文交给大模型处理可能会平白无故地导致大模型效果变差。当然这只是一方面,即使有些任务确实需要外部数据,但是从外部获取的数据依然可能会误导大模型(如果拉取外部数据的“姿势”不对的话。。。)
因此,该工作认为,能不能设计一个新的RAG系统,使其能在serving时对自己拉取的外部资料进行反思,根据反思结果不断迭代,直到得到一个相对合理的结果后收敛。Self-RAG应运而生。
Self-RAG
首先Self-RAG的工作基于这样一个认识:
当前的LLM并没有针对RAG场景进行训练过,怎么保证给大模型一段外部材料,LLM就会按照外部资料中事实性信息来作答呢?
因此,Self-RAG决定fine-tuning大模型,使其能更好利用Retrieve的外部信息。在记录如何fine-tuning大模型之前,我们先介绍一下Self-RAG如何处理用户请求的。
Self-RAG Serving
那么如何让在模型在推理的过程中自我反思呢?一个很直观的思路就是在模型的回答中加入一些特殊的token(可以通过扩展模型的词汇表来达到这个目的),其中不同的token对应着模型对于该回答的一个评价结果,基于这样的评价结果,模型继续修改自己的回答,以达到更好的效果。具体的token定义如下:
有了这些基本的token之后,我们简单叙述Self-RAG serving流程:
我们用x表示用户输入,用y(t)表示LLM输出的第t个segment (第t次迭代输出的segment)
- LLM基于x、y(t)预测Retrieve token
- Retrieve token为No或者continue,则不获取外部数据直接生成y(t+1),并根据生成的结果使用LLM生成IsUse token。如果Retrieve token为yes,则从外部数据库中获取relative passages。
- 将relative passages划分成一个个chunk,对于每个chunk并行调用LLM生成IsRel token(input: x + chunk)和y(t+1)(input: x + chunk + y(<t+1))。
- 使用LLM基于x、chunk、y(t+1)预测IsSup和IsUse Token。
- 基于IsRel、IsSup、IsUse对并行生成的y(t+1)进行排序,然后选择Rank最高的作为y(t+1)进入下一轮迭代。
那么如何对y(t+1)排序呢?
论文中提到对生成的多个y(t+1)进行Beam Search,我很是疑惑,Beam Search应该是一个类似图或者树形结构优化算法,对于这种评分都有的情况,难道不是遍历一遍就能找到最好吗?
个人理解:Beam Search需要一个grade function(见下图)
所谓Beam Search会不会就是当IsRel的得分已经很差的时候直接排除对应的这个y(t+1),这样就不需要将y(t+1)的所有Special Token都计算出来了
从上面Beam Search的算法中我们也可以直到,w(i)其实是用户可以调整的超参数,用户可以根据自己对模型输出的要求动态调整对应的超参数,在不重新训练模型的情况下改变LLM处理逻辑。
类似的还有Retrieve Token的设置,也存在类似的超参数设置机制
Self-RAG的推理过程还算简单,但是对LLM能力的要求却不低。那么我们怎么fine-tuning LLM达到这样的效果呢?
Self-RAG Training
我们需要训练两个模型:the critic C和the generator M
训练generator M
M需要运用到以下场景:
- 基于x、y(t)预测Retrieve token
- 基于relative passage和用户输入预测IsRel
- 基于relative passage和用户输入和y(<t+1)预测y(t+1)
- 基于x、chunk、y(t+1)预测IsSup和IsUse Token
针对每个运用场景,该工作构造了合适的prompt使大模型正常处理对应的情况。同时,该工作使用critic模型生成了fine tuning所需要的数据集。形式如下所示:
这些数据的产生流程如下(原先的数据集包含问题和对应答案):
- 首先让critic模型预测Retrieve结果
- 如果Retrieve结果为No,则直接基于output预测IsUse
- 否则从外部数据库中获取对应的passages
- 使用Spacy工具对问题答案进行chunk,得到多个segment
- 对于每个segment,给定用户输入,之前已经出现过的segments以及之前获取的passages,让critic模型预测该segment是否需要Retrieve。
- 如果不需要Retrieve,则不向该segment中插入<paragraph>
- 否则,我们将passages分成多块,对每个passage,我们预测IsRel和IsSup,最后综合得到一个评分,选择评分最高的passage作为这个segment对应的<paragraph>
- 在最后,生成对应的IsUse
上面我们讲了M Model如何训练,那么critic模型怎么训练呢?
训练critic C
这里采用了知识蒸馏的方式训练critic模型。
首先给定一些问答集合构成的训练集,该工作使用GPT-4重新清洗了这些训练数据。该工作使用GPT-4针对这些数据生成对应的Retrieve Token、IsUse Token等等。使用这些数据去训练M。
使用这种训练方式一方面避免了人工标注带来的巨大成本,同时避免了直接使用GPT4带来了的高额的API费用。
总结:工作思路比较清晰,解决的问题我认为也确实是RAG场景中存在的,比如为什么问题和relative passage的embedding在向量空间是接近的?既然这样的假设不一定成立的话,那么使用向量近似度从vector database拉取到的数据为什么一定能否起到加强回答的作用呢?因此本工作给出了新的思路,让LLM具备反思能力不就好了吗?美中不足的是,该工作的测试并不能很好地使我信服(但是并不是说这个方法就是不好的),同时该工作中需要重新训练LLM本身也就让这篇工作的effort开销巨大。但是不可否认的是这确实是我认为的相对solid的工作了。