跳到内容

使用多数投票

不确定性路由思维链（Uncertainty-Routed Chain-of-Thought）提示[1]会生成多个思维链推理序列（在原始论文中为 8 个或 32 个）。

然后，只有当给出多数答案的思维链所占比例达到或超过设定阈值时，才将该多数答案作为最终解决方案；否则，回退为单独生成的一个答案（原论文中回退为贪心解码得到的答案）。

我们可以使用 instructor 实现这一点，如下所示。

from openai import AsyncOpenAI
from pydantic import BaseModel
import instructor
from textwrap import dedent
from typing import Literal
import asyncio
from collections import Counter

# Shared async OpenAI client wrapped by instructor, so that
# `client.chat.completions.create(...)` accepts `response_model=` and returns a
# parsed pydantic instance (see generate_response). AsyncOpenAI() is built with
# no arguments; presumably it reads its API key from the environment
# (OPENAI_API_KEY) — confirm for your setup.
client = instructor.from_openai(AsyncOpenAI())


# Structured output for one sampled reasoning chain; passed to instructor as
# `response_model`, so this class defines the schema the LLM must fill in.
# NOTE(review): field order and any class docstring likely flow into that
# schema (a docstring would presumably become its description), which is why
# documentation here is kept in comments rather than a docstring.
class ChainOfThoughtResponse(BaseModel):
    # Free-text reasoning; declared before the answer, presumably so the model
    # writes its reasoning first — confirm instructor preserves field order.
    chain_of_thought: str
    # The chosen option letter, restricted to the literals "A".."D".
    correct_answer: Literal["A", "B", "C", "D"]


async def generate_response(query: str, options: dict[str, str]):
    """Ask the model for one chain-of-thought answer to a multiple-choice question.

    Args:
        query: The question text.
        options: Mapping from option letter (e.g. "A") to the option's text.

    Returns:
        A ``ChainOfThoughtResponse`` parsed by instructor from the model's reply.
    """
    formatted_options = "\n".join(
        f"{key}:{answer}" for key, answer in options.items()
    )
    # Dedent the static template *before* inserting the question and options.
    # Interpolating first (an f-string inside dedent()) injects lines with less
    # indentation -- every option after the first starts at column 0 -- so
    # dedent() finds no common margin and strips nothing. str.format only
    # parses braces in the template, so braces inside query/options are
    # inserted verbatim.
    system_prompt = dedent(
        """
        You are a world class AI who excels at answering
        complex questions. Choose one of the options below
        that best answers the question you are about to be
        asked
        <question>
        {query}
        </question>

        <options>
        {options}
        </options>
        """
    ).format(query=query, options=formatted_options)
    return await client.chat.completions.create(
        model="gpt-4o",
        response_model=ChainOfThoughtResponse,
        messages=[
            {
                "role": "system",
                "content": system_prompt,
            }
        ],
    )


async def generate_batch_responses(
    query: str, options: dict[str, str], num_chains: int
) -> list[ChainOfThoughtResponse]:
    """Sample ``num_chains`` independent chain-of-thought answers concurrently.

    Every request is scheduled at once via ``asyncio.gather``; the results come
    back as a list in request order, one entry per chain.
    """
    pending = (generate_response(query, options) for _ in range(num_chains))
    results = await asyncio.gather(*pending)
    return results


if __name__ == "__main__":
    question = """In a population of giraffes, an environmental
    change occurs that favors individuals that are tallest. As a
    result, more of the taller individuals are able to obtain
    nutrients and survive to pass along their genetic information.
    This is an example of"""

    options = {
        "A": "directional selection",
        "B": "stabilizing selection",
        "C": "sexual selection",
        "D": "disruptive selection",
    }

    # Ground-truth answer for this example, for reference: "A".
    k = 8  # number of reasoning chains to sample
    threshold = 0.6  # minimum share of chains that must agree with the majority

    async def route_by_uncertainty() -> str:
        """Sample k chains, majority-vote, and fall back to one fresh answer if unsure.

        Everything runs inside a single event loop: a second asyncio.run() for
        the fallback would start a new loop while the module-level AsyncOpenAI
        client may still hold connections bound to the first, already-closed
        loop.
        """
        responses = await generate_batch_responses(question, options, k)
        votes = Counter(response.correct_answer for response in responses)
        print(votes)
        #> Counter({'A': 8})

        majority_answer, majority_count = votes.most_common(1)[0]
        print(majority_answer, majority_count)
        #> A 8
        # Fraction of chains that agree with the majority answer.
        agreement = majority_count / k

        if agreement < threshold:
            # Not enough consensus: use a single freshly generated answer instead.
            fallback = await generate_response(question, options)
            return fallback.correct_answer
        return majority_answer

    print(asyncio.run(route_by_uncertainty()))
    #> A

参考文献

[1]：Gemini: A Family of Highly Capable Multimodal Models（Gemini：功能强大的多模态模型系列）