
Combine Different Responses with an LLM

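Universal Self-Consistency[1] extends self-consistency by using a second LLM to judge the quality of the individual responses. Standard self-consistency picks the final answer that occurs most often across the reasoning chains. Here we instead prompt the model to choose the response that is most consistent with respect to the prompt.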
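The selection step can be sketched in plain Python. This is a minimal sketch, not the paper's or instructor's implementation: it assumes the judge is any callable that reads a prompt listing numbered candidates and returns the ID of the one it picks. `format_candidates`, `select_with_judge`, and the toy judge are hypothetical stand-ins for the second LLM call.

```python
from typing import Callable


def format_candidates(query: str, answers: list[str]) -> str:
    """Number each candidate so the judge can refer to it by ID."""
    lines = [f"Response {idx}: {answer}" for idx, answer in enumerate(answers)]
    return f"<user query>\n{query}\n</user query>\n\n" + "\n".join(lines)


def select_with_judge(
    query: str, answers: list[str], judge: Callable[[str], int]
) -> str:
    """Universal self-consistency: a judge picks the most consistent answer.

    `judge` is a hypothetical stand-in for the second LLM call. It receives
    the formatted prompt and returns the ID of the response it prefers.
    """
    prompt = format_candidates(query, answers)
    chosen_id = judge(prompt)
    # Reject IDs outside 0..len(answers)-1 instead of indexing blindly
    if not 0 <= chosen_id < len(answers):
        raise ValueError(f"judge returned invalid response ID {chosen_id}")
    return answers[chosen_id]


if __name__ == "__main__":
    answers = ["30", "There are 30 different three-digit numbers.", "30"]
    # Toy judge for illustration only: it always picks the first response
    print(select_with_judge("How many numbers?", answers, lambda prompt: 0))  # 30
```

The bounds check matters because the judge is itself a model and can return an out-of-range index. The instructor version below performs a similar check with a Pydantic validator.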

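Because the judge compares whole responses rather than exact answer strings, the answers no longer have to share a single format. This lets us support a wider variety of response and answer formats, which yields more diverse outputs and improves accuracy.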
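To see why answer format matters, consider the majority vote that standard self-consistency relies on. The sketch below assumes the simplest form of that vote, exact matching on the raw answer strings; `majority_vote` is a hypothetical helper. Real implementations usually parse a number or label out of each answer before voting, and that parsing step is what free-form answers make unreliable.

```python
from collections import Counter


def majority_vote(answers: list[str]) -> tuple[str, int]:
    """Standard self-consistency: return the most frequent answer and its count."""
    return Counter(answers).most_common(1)[0]


# Short answers that match exactly: majority voting works
print(majority_vote(["30", "30", "31"]))  # ('30', 2)

# The same answer written three ways: each string gets one vote, so the
# "winner" is simply whichever string Counter encountered first
free_form = [
    "30",
    "There are 30 different three-digit numbers that 'ab5' can represent.",
    "thirty",
]
print(majority_vote(free_form))  # ('30', 1)
```

A judge model that reads the full responses can recognize that all three agree, which an exact-match vote cannot.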

We can implement this with instructor as follows.

```python
from openai import AsyncOpenAI
from pydantic import BaseModel, Field, ValidationInfo, field_validator
import instructor
from textwrap import dedent
import asyncio

# Patch the async OpenAI client so that `response_model` is supported
client = instructor.from_openai(AsyncOpenAI())


class Response(BaseModel):
    chain_of_thought: str
    answer: str


class SelectedResponse(BaseModel):
    most_consistent_response_id: int = Field(
        description="""The ID of the most consistent response that
        was provided"""
    )

    @field_validator("most_consistent_response_id")
    @classmethod
    def validate_id(cls, v: int, info: ValidationInfo):
        # `number_responses` arrives through `validation_context`
        context = info.context or {}
        number_responses = context.get("number_responses", float("inf"))

        # Valid IDs run from 0 to number_responses - 1
        if not 0 <= v < number_responses:
            raise ValueError(
                f"""Most consistent response ID {v} is out of range for
                {number_responses} responses. Please return a
                valid id between 0 and {number_responses-1}"""
            )
        return v


async def generate_response(query: str) -> Response:
    return await client.chat.completions.create(
        model="gpt-4o",
        response_model=Response,
        messages=[{"role": "user", "content": query}],
    )


async def generate_batch_responses(query: str, no_responses: int):
    # Sample several independent reasoning chains concurrently
    coros = [generate_response(query) for _ in range(no_responses)]
    return await asyncio.gather(*coros)


async def select_consistent_response(responses: list[Response], query: str):
    # Label each candidate with its index so the judge can answer with an ID
    formatted_responses = "\n".join(
        [
            f"Response {idx}: {response.chain_of_thought}. {response.answer}"
            for idx, response in enumerate(responses)
        ]
    )

    return await client.chat.completions.create(
        model="gpt-4o",
        response_model=SelectedResponse,
        messages=[
            {
                "role": "user",
                "content": dedent(
                    f"""
                <user query>
                {query}
                </user query>

                {formatted_responses}

                Evaluate these responses.
                Select the most consistent response based on majority
                consensus.
                """
                ),
            }
        ],
        # Exposed to the validator above as `info.context`
        validation_context={"number_responses": len(responses)},
    )


async def main():
    query = """The three-digit number 'ab5' is divisible by 3. How many different
     three-digit numbers can 'ab5' represent?"""
    responses = await generate_batch_responses(query, 3)

    for response in responses:
        print(response.model_dump_json(indent=2))
        """
        {
          "chain_of_thought": "A number is divisible by 3 if
          the sum of its digits is divisible by 3. Given the
          number 'ab5', we need to check how many different
          values of 'a' and 'b', where both are digits (0-9)
          can make the sum divisible by 3.\n\nThe sum of the
          digits is a + b + 5.\n\nWe need to find pairs (a, b)
          such that (a + b + 5) % 3 == 0.",
          "answer": "30"
        }
        """
        """
        {
          "chain_of_thought": "A number is divisible by 3 if
          the sum of its digits is divisible by 3. Let's
          denote the digits a and b. The number 'ab5' has
          digits a, b, and 5. Therefore, the sum of the
          digits is a + b + 5. Since the number is divisible
          by 3, a + b + 5 must be divisible by 3.\n\nNow,
          since a and b are single digits (0-9), we need to
          find pairs (a, b) such that a + b + 5 is divisible
          by 3. We will evaluate all possible combinations of
          values for a and b to count how many valid pairs
          (a, b) exist.\n\nLet's start by considering b's
          values:\n1. If b = 0, then a + 5 must be divisible
          by 3.\n2. If b = 1, then a + 6 must be divisible by
          3.\n3. If b = 2, then a + 7 must be divisible by
          3.\n4. If b = 3, then a + 8 must be divisible by
          3.\n5. If b = 4, then a + 9 must be divisible by
          3.\n6. If b = 5, then a + 10 must be divisible by
          3.\n7. If b = 6, then a + 11 must be divisible by
          3.\n8. If b = 7, then a + 12 must be divisible by
          3.\n9. If b = 8, then a + 13 must be divisible by
          3.\n10. If b = 9, then a + 14 must be divisible by
          3.\n\nWe will find all corresponding a values for
          each b and count the valid combinations.\n",
          "answer": "There are 30 different three-digit
          numbers that 'ab5' can represent."
        }
        """
        """
        {
          "chain_of_thought": "A number is divisible by 3 if
          the sum of its digits is divisible by 3. The given
          number is in the form 'ab5', where 'a' and 'b' are
          digits from 0 to 9. To find the total number of
          different three-digit numbers that 'ab5' can
          represent, we need to determine all possible digit
          combinations for 'a' and 'b' such that 'a + b + 5'
          is divisible by 3.",
          "answer": "30"
        }
        """

    selected_response = await select_consistent_response(responses, query)
    print(selected_response.model_dump_json(indent=2))
    """
    {
      "most_consistent_response_id": 0
    }
    """

    print(
        responses[selected_response.most_consistent_response_id].model_dump_json(
            indent=2
        )
    )
    """
    {
      "chain_of_thought": "A number is divisible by 3 if the sum of its digits is divisible by 3. Given the number 'ab5', we need to
      check how many different values of 'a' and 'b', where both are digits (0-9) can make the sum divisible by 3.\n\nThe sum of the
      digits is a + b + 5.\n\nWe need to find pairs (a, b) such that (a + b + 5) % 3 == 0.",
      "answer": "30"
    }
    """


if __name__ == "__main__":
    # Run both stages in one event loop so the async client is not
    # reused across separate asyncio.run() calls
    asyncio.run(main())
```
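All three sampled responses answer 30. Because the question is small, that answer can be checked independently of any model by brute force. This check is added for illustration only and is not part of the self-consistency technique; `count_ab5_divisible_by_3` is a hypothetical helper. Note that the hundreds digit a must run from 1 to 9, not 0 to 9 as the sampled reasoning chains state, since 'ab5' has to be a three-digit number.

```python
def count_ab5_divisible_by_3() -> int:
    """Count three-digit numbers of the form ab5 that are divisible by 3."""
    count = 0
    for a in range(1, 10):  # a three-digit number cannot start with 0
        for b in range(10):
            if (100 * a + 10 * b + 5) % 3 == 0:
                count += 1
    return count


print(count_ab5_divisible_by_3())  # 30
```

Of the 90 numbers 105, 115, ..., 995, exactly one in three is divisible by 3, which gives the same count of 30.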

References

1: Universal Self-Consistency for Large Language Model Generation