为您的RAG管道构建基于LLM的重排序器¶
您的检索增强生成(RAG)管道是否正在为不相关的搜索结果而烦恼?
想象一下,拥有一个强大的工具,能够智能地重新评估和重新排序您的搜索结果,显著提高它们与用户查询的相关性。
在这篇博客文章中,我们将向您展示如何使用 Instructor 和 Pydantic 创建一个基于 LLM 的重排序器。这种方法将
- 提升搜索结果的准确性
- 利用大型语言模型(LLMs)的力量
- 利用结构化输出进行精确信息检索
通过本教程的学习,您将能够实现一个 LLM 重排序器,用于标记您的合成数据,以便微调传统的重排序器,或构建您的 RAG 系统评估管道。让我们深入了解吧!
设置环境¶
首先,让我们设置环境并导入必要的库
我们正在使用 instructor
库,它与 OpenAI 的 API 和 Pydantic 无缝集成,以实现结构化输出。
定义重排序模型¶
我们将使用 Pydantic 定义我们的 Label
和 RerankedResults
模型,它们用于结构化 LLM 的输出
请注意,我不仅在 Label 类中引用了 chunk_id,还要求语言模型使用思维链。这对于使用 4o Mini 或 Claude 等模型非常有用,但如果我们计划使用 o1-mini
和 o1-preview
模型,则不一定需要。
class Label(BaseModel):
chunk_id: int = Field(description="The unique identifier of the text chunk")
chain_of_thought: str = Field(
description="The reasoning process used to evaluate the relevance"
)
relevancy: int = Field(
description="Relevancy score from 0 to 10, where 10 is most relevant",
ge=0,
le=10,
)
class RerankedResults(BaseModel):
labels: list[Label] = Field(description="List of labeled and ranked chunks")
@field_validator("labels")
@classmethod
def model_validate(cls, v: list[Label]) -> list[Label]:
return sorted(v, key=lambda x: x.relevancy, reverse=True)
这些模型确保我们的 LLM 输出是结构化的,并包含一个带有相关性分数的已标记块列表。RerankedResults
模型包含一个验证器,该验证器会自动按相关性分数降序排列标签。
创建重排序器函数¶
接下来,我们将创建一个函数,该函数使用我们的 LLM 根据文本块与查询的相关性对其进行重排序
def rerank_results(query: str, chunks: list[dict]) -> RerankedResults:
return client.chat.completions.create(
model="gpt-4o-mini",
response_model=RerankedResults,
messages=[
{
"role": "system",
"content": """
You are an expert search result ranker. Your task is to evaluate the relevance of each text chunk to the given query and assign a relevancy score.
For each chunk:
1. Analyze its content in relation to the query.
2. Provide a chain of thought explaining your reasoning.
3. Assign a relevancy score from 0 to 10, where 10 is most relevant.
Be objective and consistent in your evaluations.
""",
},
{
"role": "user",
"content": """
<query>{{ query }}</query>
<chunks_to_rank>
{% for chunk in chunks %}
<chunk id="{{ chunk.id }}">
{{ chunk.text }}
</chunk>
{% endfor %}
</chunks_to_rank>
Please provide a RerankedResults object with a Label for each chunk.
""",
},
],
context={"query": query, "chunks": chunks},
)
该函数接收查询和文本块列表作为输入,将其与预定义提示一起发送到 LLM,并返回结构化的 RerankedResults
对象。多亏了 instructor,我们可以使用 jinja 模板将查询和文本块注入到提示中,只需传入 context
参数即可。
测试重排序器¶
为了测试我们的基于 LLM 的重排序器,我们可以创建一个示例查询和文本块列表。以下是使用重排序器的一个示例
def main():
query = "What are the health benefits of regular exercise?"
chunks = [
{
"id": 0,
"text": "Regular exercise can improve cardiovascular health and reduce the risk of heart disease.",
},
{
"id": 1,
"text": "The price of gym memberships varies widely depending on location and facilities.",
},
{
"id": 2,
"text": "Exercise has been shown to boost mood and reduce symptoms of depression and anxiety.",
},
{
"id": 3,
"text": "Proper nutrition is essential for maintaining a healthy lifestyle.",
},
{
"id": 4,
"text": "Strength training can increase muscle mass and improve bone density, especially important as we age.",
},
]
results = rerank_results(query, chunks)
print("Reranked results:")
for label in results.labels:
print(f"Chunk {label.chunk_id} (Relevancy: {label.relevancy}):")
print(f"Text: {chunks[label.chunk_id]['text']}")
print(f"Reasoning: {label.chain_of_thought}")
print()
if __name__ == "__main__":
main()
此测试演示了重排序器如何根据文本块与查询的相关性来评估和排序它们。完整实现可以在 examples/reranker/run.py
文件中找到。
如果您想扩展此示例,可以使用 rerank_results
函数来标记合成数据,以便微调传统的重排序器,或构建您的 RAG 系统评估管道。
此外,我们还可以在 Label.chunk_id
字段中添加验证器,以确保 chunk_id 存在于 chunks
列表中。如果标签是 uuids
或复杂字符串,并且我们想确保 chunk_id 是 chunks 列表的有效索引,这将非常有用。
这是一个示例
class Label(BaseModel):
chunk_id: int = Field(description="The unique identifier of the text chunk")
...
@field_validator("chunk_id")
@classmethod
def validate_chunk_id(cls, v: int, info: ValidationInfo) -> int:
context = info.context
chunks = context["chunks"]
if v not in [chunk["id"] for chunk in chunks]:
raise ValueError(
f"Chunk with id {v} not found, must be one of {[chunk['id'] for chunk in chunks]}"
)
return v
这将自动检查 chunk_id
是否存在于 chunks
列表中,如果不存在,则会引发 ValueError
错误。其中,context
是我们传递给 rerank_results
函数的上下文字典。