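Text Classification with OpenAI and Pydantic
This tutorial shows how to implement text classification with the OpenAI API and Pydantic models, specifically single-label and multi-label classification. For complete examples, see the single-label and multi-label classification examples in our usage guide.
Motivation
Text classification is a common problem in many natural language processing (NLP) applications, such as spam detection or support-ticket classification. The goal is to provide a systematic way to handle these cases by combining OpenAI's GPT models with Python data structures.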
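Single-Label Classification
Defining the Structure
For single-label classification, we define a Pydantic model with a `Literal` field that lists the possible labels.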
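Literals vs. Enums
For classification labels, we prefer `Literal` types over `enum`. Literals give better type checking, because the allowed labels live in the type annotation itself, and they are more straightforward to use with Pydantic models than a separate Enum class.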
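To make the Literal-versus-Enum point concrete, here is a minimal sketch assuming Pydantic v2 (`model_validate`, `ValidationError`). The `SpamLabel` alias and `Prediction` model are hypothetical names used only for illustration. It shows that the allowed labels can be read back from the type with `typing.get_args`, that a valid label parses to a plain `str`, and that any other string is rejected.

```python
from typing import Literal, get_args

from pydantic import BaseModel, ValidationError

# Hypothetical alias: the label set lives in the type itself, no Enum class needed.
SpamLabel = Literal["SPAM", "NOT_SPAM"]


class Prediction(BaseModel):
    """Hypothetical minimal model used only for this illustration."""

    label: SpamLabel


# The allowed values can be recovered directly from the Literal type.
allowed = get_args(SpamLabel)
print(allowed)  # ('SPAM', 'NOT_SPAM')

# A valid label parses to an ordinary string.
ok = Prediction.model_validate({"label": "SPAM"})
print(type(ok.label).__name__, ok.label)  # str SPAM

# A label outside the Literal set fails validation.
try:
    Prediction.model_validate({"label": "HAM"})
    rejected = False
except ValidationError:
    rejected = True
print(rejected)  # True
```

With an Enum, the parsed value would be an enum member rather than a string, which is one reason Literals tend to be simpler to pass around.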
Chain of Thought
Using chain of thought has been shown to improve prediction quality by about 10%.
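One way to see how the chain-of-thought field is meant to work: the sketch below (assuming Pydantic v2) prints the JSON schema of a trimmed copy of the `ClassificationResponse` model defined in this section. Pydantic keeps fields in declaration order, so `chain_of_thought` appears before `label`. The premise that the LLM fills the fields in this order, and therefore writes its reasoning before committing to a label, is an assumption that usually holds in practice but is not guaranteed by the API.

```python
from typing import Literal

from pydantic import BaseModel, Field


class ClassificationResponse(BaseModel):
    # Declared first so that, if fields are generated in order, reasoning precedes the label.
    chain_of_thought: str = Field(
        ..., description="The chain of thought that led to the prediction."
    )
    label: Literal["SPAM", "NOT_SPAM"] = Field(
        ..., description="The predicted class label."
    )


schema = ClassificationResponse.model_json_schema()
field_order = list(schema["properties"])
print(field_order)  # ['chain_of_thought', 'label']
print(schema["required"])  # ['chain_of_thought', 'label']
print(schema["properties"]["label"]["enum"])  # ['SPAM', 'NOT_SPAM']
```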
```python
from typing import Literal

import instructor
from openai import OpenAI
from pydantic import BaseModel, Field

# Apply the patch to the OpenAI client
# enables response_model keyword
client = instructor.from_openai(OpenAI())


class ClassificationResponse(BaseModel):
    """
    A few-shot example of text classification:

    Examples:
    - "Buy cheap watches now!": SPAM
    - "Meeting at 3 PM in the conference room": NOT_SPAM
    - "You've won a free iPhone! Click here": SPAM
    - "Can you pick up some milk on your way home?": NOT_SPAM
    - "Increase your followers by 10000 overnight!": SPAM
    """

    chain_of_thought: str = Field(
        ...,
        description="The chain of thought that led to the prediction.",
    )
    label: Literal["SPAM", "NOT_SPAM"] = Field(
        ...,
        description="The predicted class label.",
    )
```
Classifying Text
The `classify` function performs single-label classification.
```python
def classify(data: str) -> ClassificationResponse:
    """Perform single-label classification on the input text."""
    return client.chat.completions.create(
        model="gpt-4o-mini",
        response_model=ClassificationResponse,
        messages=[
            {
                "role": "user",
                "content": f"Classify the following text: <text>{data}</text>",
            },
        ],
    )
```
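Conceptually, instructor asks the model for output matching the `ClassificationResponse` schema, validates the reply with Pydantic, and can re-ask when validation fails. The sketch below is a hypothetical, simplified illustration of that idea, not instructor's actual implementation: `parse_with_retries` and `fake_llm` are made-up names, the stub returns canned JSON so no API call is made, and the code assumes Pydantic v2's `model_validate_json`.

```python
import json
from typing import Callable, List, Literal, Optional

from pydantic import BaseModel, Field, ValidationError


class ClassificationResponse(BaseModel):
    chain_of_thought: str = Field(..., description="The chain of thought that led to the prediction.")
    label: Literal["SPAM", "NOT_SPAM"] = Field(..., description="The predicted class label.")


def parse_with_retries(
    call_llm: Callable[[str], str], prompt: str, max_retries: int = 2
) -> ClassificationResponse:
    """Hypothetical helper: validate the LLM's raw JSON, re-asking with the error on failure."""
    error: Optional[ValidationError] = None
    for _ in range(max_retries + 1):
        # On a retry, append the previous validation error so the model can correct itself.
        message = prompt if error is None else f"{prompt}\n\nPrevious answer was invalid:\n{error}"
        raw = call_llm(message)
        try:
            return ClassificationResponse.model_validate_json(raw)
        except ValidationError as exc:
            error = exc
    raise RuntimeError(f"No valid response after {max_retries + 1} attempts") from error


# Canned replies: the first uses a label outside the Literal set, the second is valid.
replies = iter([
    json.dumps({"chain_of_thought": "It promises a prize for clicking.", "label": "JUNK"}),
    json.dumps({"chain_of_thought": "It promises a prize for clicking.", "label": "SPAM"}),
])
calls: List[str] = []


def fake_llm(prompt: str) -> str:
    """Hypothetical stand-in for a real model call; records the prompt and returns canned JSON."""
    calls.append(prompt)
    return next(replies)


result = parse_with_retries(fake_llm, "Classify the following text: <text>You've won a free iPhone!</text>")
print(result.label, len(calls))  # SPAM 2
```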
Testing and Evaluation
Let's run the examples and check whether the classifier correctly distinguishes spam from non-spam messages.
```python
if __name__ == "__main__":
    for text, label in [
        ("Hey Jason! You're awesome", "NOT_SPAM"),
        ("I am a nigerian prince and I need your help.", "SPAM"),
    ]:
        prediction = classify(text)
        assert prediction.label == label
        print(f"Text: {text}, Predicted Label: {prediction.label}")
        #> Text: Hey Jason! You're awesome, Predicted Label: NOT_SPAM
        #> Text: I am a nigerian prince and I need your help., Predicted Label: SPAM
```
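Asserting on two examples works as a smoke test; with a larger labeled set it is usually more informative to report accuracy and list the misclassified items. A minimal sketch: `evaluate` and `keyword_stub` are hypothetical names, and `keyword_stub` is a crude offline stand-in for `classify(text).label` so the code runs without an API key. The example texts come from the few-shot docstring above.

```python
from typing import Callable, List, Sequence, Tuple


def evaluate(
    predict: Callable[[str], str], dataset: Sequence[Tuple[str, str]]
) -> Tuple[float, List[Tuple[str, str, str]]]:
    """Return accuracy and the misclassified (text, gold, predicted) triples."""
    if not dataset:
        raise ValueError("dataset must not be empty")
    errors = []
    for text, gold in dataset:
        predicted = predict(text)
        if predicted != gold:
            errors.append((text, gold, predicted))
    accuracy = (len(dataset) - len(errors)) / len(dataset)
    return accuracy, errors


def keyword_stub(text: str) -> str:
    """Hypothetical offline stand-in for classify(text).label."""
    spam_words = ("prince", "free", "click here")
    return "SPAM" if any(word in text.lower() for word in spam_words) else "NOT_SPAM"


examples = [
    ("Hey Jason! You're awesome", "NOT_SPAM"),
    ("I am a nigerian prince and I need your help.", "SPAM"),
    ("You've won a free iPhone! Click here", "SPAM"),
    ("Can you pick up some milk on your way home?", "NOT_SPAM"),
    ("Buy cheap watches now!", "SPAM"),
]
acc, mistakes = evaluate(keyword_stub, examples)
print(acc)  # 0.8
print(mistakes)  # [('Buy cheap watches now!', 'SPAM', 'NOT_SPAM')]
```

With the real classifier you would pass `lambda t: classify(t).label` as `predict`; the list of mistakes is often more useful than the single accuracy number when deciding which few-shot examples to add.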
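Multi-Label Classification
Defining the Structure
For multi-label classification, we again use Literals rather than enums, this time inside a `List` so that a ticket can receive several labels, and we include few-shot examples in the model's docstring.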
```python
from typing import List, Literal

from pydantic import BaseModel, Field


class MultiClassPrediction(BaseModel):
    """
    Class for a multi-class label prediction.

    Examples:
    - "My account is locked": ["TECH_ISSUE"]
    - "I can't access my billing info": ["TECH_ISSUE", "BILLING"]
    - "When do you close for holidays?": ["GENERAL_QUERY"]
    - "My payment didn't go through and now I can't log in": ["BILLING", "TECH_ISSUE"]
    """

    chain_of_thought: str = Field(
        ...,
        description="The chain of thought that led to the prediction.",
    )
    class_labels: List[Literal["TECH_ISSUE", "BILLING", "GENERAL_QUERY"]] = Field(
        ...,
        description="The predicted class labels for the support ticket.",
    )
```
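To check that the few-shot examples in the docstring actually become part of what the model is given, you can inspect the JSON schema Pydantic generates. A minimal sketch assuming Pydantic v2, which uses the class docstring as the schema's top-level `description`; that instructor forwards this schema (description included) to the LLM is our understanding of how it builds the request, not something this sketch verifies. The model here is a shortened copy of `MultiClassPrediction`.

```python
from typing import List, Literal

from pydantic import BaseModel, Field


class MultiClassPrediction(BaseModel):
    """
    Class for a multi-class label prediction.

    Examples:
    - "My account is locked": ["TECH_ISSUE"]
    - "When do you close for holidays?": ["GENERAL_QUERY"]
    """

    chain_of_thought: str = Field(..., description="The chain of thought that led to the prediction.")
    class_labels: List[Literal["TECH_ISSUE", "BILLING", "GENERAL_QUERY"]] = Field(
        ..., description="The predicted class labels for the support ticket."
    )


schema = MultiClassPrediction.model_json_schema()

# The cleaned-up docstring, few-shot examples included, becomes the schema description.
print(schema["description"].splitlines()[0])  # Class for a multi-class label prediction.

# List[Literal[...]] becomes an array whose items are restricted to the label set.
labels_schema = schema["properties"]["class_labels"]
print(labels_schema["type"])  # array
print(labels_schema["items"]["enum"])  # ['TECH_ISSUE', 'BILLING', 'GENERAL_QUERY']
```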
Classifying Text
The `multi_classify` function handles multi-label classification.
```python
import instructor
from openai import OpenAI

client = instructor.from_openai(OpenAI())


def multi_classify(data: str) -> MultiClassPrediction:
    """Perform multi-label classification on the input text."""
    return client.chat.completions.create(
        model="gpt-4o-mini",
        response_model=MultiClassPrediction,
        messages=[
            {
                "role": "user",
                "content": f"Classify the following support ticket: <ticket>{data}</ticket>",
            },
        ],
    )
```
Testing and Evaluation
Finally, we test the multi-label classification function on a sample support ticket.
```python
# Test multi-label classification
ticket = "My account is locked and I can't access my billing info."
prediction = multi_classify(ticket)
assert "TECH_ISSUE" in prediction.class_labels
assert "BILLING" in prediction.class_labels
print(f"Ticket: {ticket}")
#> Ticket: My account is locked and I can't access my billing info.
print(f"Predicted Labels: {prediction.class_labels}")
#> Predicted Labels: ['TECH_ISSUE', 'BILLING']
```
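The assertions above check membership with `in`, so they pass regardless of label order, which the model does not guarantee. When evaluating many tickets, comparing predicted and expected label sets gives a graded score instead of a pass/fail result. A minimal standard-library sketch: `multilabel_scores` is a hypothetical helper, and the convention for empty label sets is our own assumption.

```python
from typing import Dict, List


def multilabel_scores(predicted: List[str], gold: List[str]) -> Dict[str, float]:
    """Order-insensitive comparison of a predicted label list against the gold labels."""
    pred, true = set(predicted), set(gold)
    hits = len(pred & true)  # labels present in both sets
    union = pred | true
    return {
        # 1.0 only when the two label sets are identical
        "exact_match": float(pred == true),
        # Convention (assumption): an empty set scores 1.0 only if the other set is empty too.
        "precision": hits / len(pred) if pred else float(not true),
        "recall": hits / len(true) if true else float(not pred),
        "jaccard": hits / len(union) if union else 1.0,
    }


full = multilabel_scores(["BILLING", "TECH_ISSUE"], ["TECH_ISSUE", "BILLING"])
print(full)  # {'exact_match': 1.0, 'precision': 1.0, 'recall': 1.0, 'jaccard': 1.0}

partial = multilabel_scores(["TECH_ISSUE"], ["TECH_ISSUE", "BILLING"])
print(partial)  # {'exact_match': 0.0, 'precision': 1.0, 'recall': 0.5, 'jaccard': 0.5}
```

Averaging these per-ticket scores over a labeled set shows whether the model tends to miss labels (low recall) or add spurious ones (low precision).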
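By using Literals and including few-shot examples, we have improved both the single-label and the multi-label classification implementations. These changes strengthen type safety and give the model clearer guidance, which can lead to more accurate classifications.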