使用多数投票
不确定性路由的思维链(Uncertainty-Routed Chain of Thought)<sup>1</sup> 提示会生成多个思维链推理序列(在原始论文中为 8 个或 32 个)。
然后,只有当支持多数答案的思维链所占比例高于特定阈值时,才将该多数答案作为最终答案;否则(如下面的代码所示),改为单独再生成一次回答,并采用它给出的答案。
我们可以使用 `instructor` 实现这一点,如下所示。
from openai import AsyncOpenAI
from pydantic import BaseModel
import instructor
from textwrap import dedent
from typing import Literal
import asyncio
from collections import Counter
# Wrap the async OpenAI client with instructor; generate_response passes
# `response_model=` to this client's `chat.completions.create`.
client = instructor.from_openai(AsyncOpenAI())
class ChainOfThoughtResponse(BaseModel):
    """Structured answer to a four-option multiple-choice question.

    Used as the ``response_model`` for each sampled chain of thought.
    """

    # Free-text reasoning; declared before the answer field.
    chain_of_thought: str
    # Label of the selected option; restricted to the four letters below.
    correct_answer: Literal["A", "B", "C", "D"]
async def generate_response(query: str, options: dict[str, str]):
    """Request a single chain-of-thought answer to a multiple-choice question.

    Args:
        query: The question text; inserted verbatim inside ``<question>``.
        options: Mapping from option label to option text. Each entry is
            rendered as ``label:text`` on its own line inside ``<options>``.

    Returns:
        The awaited result of ``client.chat.completions.create`` called with
        ``response_model=ChainOfThoughtResponse``.
    """
    formatted_options = "\n".join(
        f"{key}:{answer}" for key, answer in options.items()
    )
    # Dedent the template BEFORE substituting values. Interpolating first
    # (as an f-string inside dedent) lets any multi-line query or option list
    # introduce column-0 lines, which makes the common margin empty and turns
    # dedent into a no-op, leaving the source indentation in the prompt.
    system_prompt = dedent(
        """
        You are a world class AI who excels at answering
        complex questions. Choose one of the options below
        that best answers the question you are about to be
        asked
        <question>
        {query}
        </question>
        <options>
        {formatted_options}
        </options>
        """
    ).format(query=query, formatted_options=formatted_options)
    return await client.chat.completions.create(
        model="gpt-4o",
        response_model=ChainOfThoughtResponse,
        messages=[
            {
                "role": "system",
                "content": system_prompt,
            }
        ],
    )
async def generate_batch_responses(
    query: str, options: dict[str, str], num_chains: int
) -> list[ChainOfThoughtResponse]:
    """Sample ``num_chains`` answers to the same question concurrently.

    Args:
        query: The question text forwarded to ``generate_response``.
        options: Option label -> option text, forwarded unchanged.
        num_chains: How many ``generate_response`` calls to schedule.

    Returns:
        The gathered results, one per scheduled call, in scheduling order.
    """
    pending = []
    for _ in range(num_chains):
        # Calling the coroutine function only creates the coroutine object;
        # nothing runs until gather schedules them together below.
        pending.append(generate_response(query, options))
    results = await asyncio.gather(*pending)
    return results
def majority_agreement(votes: Counter) -> tuple[str, int, float]:
    """Summarise a vote tally as (top answer, its count, its share of votes).

    Args:
        votes: Tally mapping each answer to the number of chains that chose it.

    Returns:
        A tuple ``(answer, count, ratio)`` where ``ratio`` is ``count``
        divided by the total number of votes in the tally.

    Raises:
        ValueError: If the tally contains no votes.
    """
    total_votes = sum(votes.values())
    if total_votes == 0:
        raise ValueError("cannot take a majority vote over zero answers")
    answer, count = votes.most_common(1)[0]
    return answer, count, count / total_votes


async def main() -> None:
    """Run the uncertainty-routed chain-of-thought example and print results."""
    question = (
        "In a population of giraffes, an environmental\n"
        "change occurs that favors individuals that are tallest. As a\n"
        "result, more of the taller individuals are able to obtain\n"
        "nutrients and survive to pass along their genetic information.\n"
        "This is an example of"
    )
    # The expected answer for this example question is "A".
    options = {
        "A": "directional selection",
        "B": "stabilizing selection",
        "C": "sexual selection",
        "D": "disruptive selection",
    }
    k = 8  # number of chains of thought to sample
    threshold = 0.6  # minimum share of chains that must agree

    responses = await generate_batch_responses(question, options, k)
    votes = Counter(response.correct_answer for response in responses)
    print(votes)
    #> Counter({'A': 8})

    # The ratio is taken over the votes actually collected rather than `k`.
    majority_vote_element, majority_vote_count, agreement = majority_agreement(
        votes
    )
    print(majority_vote_element, majority_vote_count)
    #> A 8

    if agreement < threshold:
        # Not enough agreement: fall back to one freshly generated answer.
        fallback = await generate_response(question, options)
        final_answer = fallback.correct_answer
    else:
        final_answer = majority_vote_element
    print(final_answer)
    #> A


if __name__ == "__main__":
    # asyncio.run creates and closes a new event loop on every call, so the
    # whole flow (batch + possible fallback) runs inside a single call to keep
    # the module-level async client on one loop.
    asyncio.run(main())