API Reference

Source code in instructor/client.py
def from_openai(
    client: openai.OpenAI | openai.AsyncOpenAI,
    mode: instructor.Mode = instructor.Mode.TOOLS,
    **kwargs: Any,
) -> Instructor | AsyncInstructor:
    if hasattr(client, "base_url"):
        provider = get_provider(str(client.base_url))
    else:
        provider = Provider.OPENAI

    if not isinstance(client, (openai.OpenAI, openai.AsyncOpenAI)):
        import warnings

        warnings.warn(
            "Client should be an instance of openai.OpenAI or openai.AsyncOpenAI. Unexpected behavior may occur with other client types.",
            stacklevel=2,
        )

    if provider in {Provider.OPENROUTER}:
        assert mode in {
            instructor.Mode.TOOLS,
            instructor.Mode.OPENROUTER_STRUCTURED_OUTPUTS,
            instructor.Mode.JSON,
        }

    if provider in {Provider.ANYSCALE, Provider.TOGETHER}:
        assert mode in {
            instructor.Mode.TOOLS,
            instructor.Mode.JSON,
            instructor.Mode.JSON_SCHEMA,
            instructor.Mode.MD_JSON,
        }

    if provider in {Provider.OPENAI, Provider.DATABRICKS}:
        assert mode in {
            instructor.Mode.TOOLS,
            instructor.Mode.JSON,
            instructor.Mode.FUNCTIONS,
            instructor.Mode.PARALLEL_TOOLS,
            instructor.Mode.MD_JSON,
            instructor.Mode.TOOLS_STRICT,
            instructor.Mode.JSON_O1,
        }

    if isinstance(client, openai.OpenAI):
        return Instructor(
            client=client,
            create=instructor.patch(create=client.chat.completions.create, mode=mode),
            mode=mode,
            provider=provider,
            **kwargs,
        )

    if isinstance(client, openai.AsyncOpenAI):
        return AsyncInstructor(
            client=client,
            create=instructor.patch(create=client.chat.completions.create, mode=mode),
            mode=mode,
            provider=provider,
            **kwargs,
        )
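
A minimal usage sketch (assuming `OPENAI_API_KEY` is set in the environment):

```python
import openai
import instructor

# Patching a synchronous client returns an Instructor wrapper whose
# create() accepts response_model and the configured mode.
client = instructor.from_openai(openai.OpenAI(), mode=instructor.Mode.TOOLS)

# Patching an AsyncOpenAI client returns an AsyncInstructor instead.
aclient = instructor.from_openai(openai.AsyncOpenAI())
```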

Validator

Bases: OpenAISchema

Validate if an attribute is correct and, if not, return a new value with an error message

Source code in instructor/dsl/validators.py
class Validator(OpenAISchema):
    """
    Validate if an attribute is correct and if not,
    return a new value with an error message
    """

    is_valid: bool = Field(
        default=True,
        description="Whether the attribute is valid based on the requirements",
    )
    reason: Optional[str] = Field(
        default=None,
        description="The error message if the attribute is not valid, otherwise None",
    )
    fixed_value: Optional[str] = Field(
        default=None,
        description="If the attribute is not valid, suggest a new value for the attribute",
    )
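
As a rough sketch, these are the shapes the LLM is expected to produce when this class is used as a `response_model` (the field values here are illustrative):

```python
# Valid input: no reason or fix needed.
ok = Validator(is_valid=True)

# Invalid input: the model explains why and may suggest a replacement.
bad = Validator(
    is_valid=False,
    reason="The name is not all lowercase",
    fixed_value="jason liu",
)
```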

llm_validator(statement, client, allow_override=False, model='gpt-3.5-turbo', temperature=0)

Create a validator that uses the LLM to validate an attribute

Usage

from typing import Annotated

from pydantic import AfterValidator, BaseModel, Field, ValidationError

from instructor import llm_validator

# `client` is assumed to be an Instructor instance (see from_openai above).
class User(BaseModel):
    name: Annotated[
        str,
        AfterValidator(
            llm_validator("The name must be a full name all lowercase", client)
        ),
    ]
    age: int = Field(description="The age of the person")

try:
    user = User(name="Jason Liu", age=20)
except ValidationError as e:
    print(e)
1 validation error for User
name
    The name is valid but not all lowercase (type=value_error.llm_validator)

Note that the error message here is written by the LLM, and the error type is `value_error.llm_validator`.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| statement | str | The statement to validate | required |
| client | Instructor | The Instructor client to use for validation | required |
| allow_override | bool | Whether to return the LLM's suggested fixed value instead of raising when validation fails | False |
| model | str | The LLM to use for validation | 'gpt-3.5-turbo' |
| temperature | float | The temperature to use for the LLM | 0 |
Source code in instructor/dsl/validators.py
def llm_validator(
    statement: str,
    client: Instructor,
    allow_override: bool = False,
    model: str = "gpt-3.5-turbo",
    temperature: float = 0,
) -> Callable[[str], str]:
    """
    Create a validator that uses the LLM to validate an attribute

    ## Usage

    ```python
    from typing import Annotated

    from pydantic import AfterValidator, BaseModel, Field, ValidationError

    from instructor import llm_validator

    class User(BaseModel):
        name: Annotated[
            str,
            AfterValidator(
                llm_validator("The name must be a full name all lowercase", client)
            ),
        ]
        age: int = Field(description="The age of the person")

    try:
        user = User(name="Jason Liu", age=20)
    except ValidationError as e:
        print(e)
    ```

    ```
    1 validation error for User
    name
        The name is valid but not all lowercase (type=value_error.llm_validator)
    ```

    Note that here, the error message is written by the LLM, and the error type is `value_error.llm_validator`.

    Parameters:
        statement (str): The statement to validate
        client (Instructor): The Instructor client to use for validation
        allow_override (bool): Whether to return the LLM's suggested fixed value instead of raising (default: False)
        model (str): The LLM to use for validation (default: "gpt-3.5-turbo")
        temperature (float): The temperature to use for the LLM (default: 0)
    """

    def llm(v: str) -> str:
        resp = client.chat.completions.create(
            response_model=Validator,
            messages=[
                {
                    "role": "system",
                    "content": "You are a world class validation model. Capable to determine if the following value is valid for the statement, if it is not, explain why and suggest a new value.",
                },
                {
                    "role": "user",
                    "content": f"Does `{v}` follow the rules: {statement}",
                },
            ],
            model=model,
            temperature=temperature,
        )

        # If the value is not valid but override is allowed, return the LLM's
        # suggested fix. This must happen before the assert below, which would
        # otherwise make the override branch unreachable.
        if allow_override and not resp.is_valid and resp.fixed_value is not None:
            return resp.fixed_value

        # If the response is not valid, raise with the reason; this could be used
        # in the future to generate a better response via a reasking mechanism.
        assert resp.is_valid, resp.reason
        return v

    return llm
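
With `allow_override=True` the validator substitutes the LLM's suggested `fixed_value` instead of raising, turning it into an auto-fixer rather than a hard gate. A hedged sketch, assuming `client` is an `Instructor` instance:

```python
from typing import Annotated

from pydantic import AfterValidator, BaseModel

from instructor import llm_validator

class Comment(BaseModel):
    # When the statement is violated, the LLM's fixed_value is used
    # in place of the original input instead of raising a ValidationError.
    text: Annotated[
        str,
        AfterValidator(
            llm_validator("don't say objectionable things", client, allow_override=True)
        ),
    ]
```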

openai_moderation(client)

Validates a message using the OpenAI moderation model.

Should only be used for monitoring inputs and outputs of OpenAI APIs. Other use cases are disallowed per https://platform.openai.com/docs/guides/moderation/overview.

Example

from typing import Annotated

from openai import OpenAI
from pydantic import AfterValidator, BaseModel

from instructor import openai_moderation

client = OpenAI()

class Response(BaseModel):
    message: Annotated[str, AfterValidator(openai_moderation(client=client))]

Response(message="I hate you")

ValidationError: 1 validation error for Response
message
  Value error, `I hate you` was flagged for harassment [type=value_error, input_value='I hate you', input_type=str]

client (OpenAI): The OpenAI client to use; must be synchronous (default: None)

Source code in instructor/dsl/validators.py
def openai_moderation(client: OpenAI) -> Callable[[str], str]:
    """
    Validates a message using OpenAI moderation model.

    Should only be used for monitoring inputs and outputs of OpenAI APIs
    Other use cases are disallowed as per:
    https://platform.openai.com/docs/guides/moderation/overview

    Example:
    ```python
    from instructor import openai_moderation

    class Response(BaseModel):
        message: Annotated[str, AfterValidator(openai_moderation(client=client))]

    Response(message="I hate you")
    ```

    ```
     ValidationError: 1 validation error for Response
     message
    Value error, `I hate you.` was flagged for ['harassment'] [type=value_error, input_value='I hate you.', input_type=str]
    ```

    client (OpenAI): The OpenAI client to use, must be sync (default: None)
    """

    def validate_message_with_openai_mod(v: str) -> str:
        response = client.moderations.create(input=v)
        out = response.results[0]
        cats = out.categories.model_dump()
        if out.flagged:
            raise ValueError(
                f"`{v}` was flagged for {', '.join(cat for cat in cats if cats[cat])}"
            )

        return v

    return validate_message_with_openai_mod

IterableBase

Source code in instructor/dsl/iterable.py
class IterableBase:
    task_type: ClassVar[Optional[type[BaseModel]]] = None

    @classmethod
    def from_streaming_response(
        cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
    ) -> Generator[BaseModel, None, None]:  # noqa: ARG003
        json_chunks = cls.extract_json(completion, mode)

        if mode in {Mode.MD_JSON, Mode.GEMINI_TOOLS}:
            json_chunks = extract_json_from_stream(json_chunks)

        if mode in {Mode.VERTEXAI_TOOLS, Mode.MISTRAL_TOOLS}:
            response = next(json_chunks)
            if not response:
                return

            json_response = json.loads(response)
            if not json_response["tasks"]:
                return

            for item in json_response["tasks"]:
                yield cls.extract_cls_task_type(json.dumps(item), **kwargs)

        yield from cls.tasks_from_chunks(json_chunks, **kwargs)

    @classmethod
    async def from_streaming_response_async(
        cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
    ) -> AsyncGenerator[BaseModel, None]:
        json_chunks = cls.extract_json_async(completion, mode)

        if mode == Mode.MD_JSON:
            json_chunks = extract_json_from_stream_async(json_chunks)

        if mode in {Mode.MISTRAL_TOOLS, Mode.VERTEXAI_TOOLS}:
            return cls.tasks_from_mistral_chunks(json_chunks, **kwargs)

        return cls.tasks_from_chunks_async(json_chunks, **kwargs)

    @classmethod
    async def tasks_from_mistral_chunks(
        cls, json_chunks: AsyncGenerator[str, None], **kwargs: Any
    ) -> AsyncGenerator[BaseModel, None]:
        """Process streaming chunks from Mistral and VertexAI.

        Handles the specific JSON format used by these providers when streaming."""

        async for chunk in json_chunks:
            if not chunk:
                continue
            json_response = json.loads(chunk)
            if not json_response["tasks"]:
                continue

            for item in json_response["tasks"]:
                obj = cls.extract_cls_task_type(json.dumps(item), **kwargs)
                yield obj

    @classmethod
    def tasks_from_chunks(
        cls, json_chunks: Iterable[str], **kwargs: Any
    ) -> Generator[BaseModel, None, None]:
        started = False
        potential_object = ""
        for chunk in json_chunks:
            potential_object += chunk
            if not started:
                if "[" in chunk:
                    started = True
                    potential_object = chunk[chunk.find("[") + 1 :]

            while True:
                task_json, potential_object = cls.get_object(potential_object, 0)
                if task_json:
                    assert cls.task_type is not None
                    obj = cls.extract_cls_task_type(task_json, **kwargs)
                    yield obj
                else:
                    break

    @classmethod
    async def tasks_from_chunks_async(
        cls, json_chunks: AsyncGenerator[str, None], **kwargs: Any
    ) -> AsyncGenerator[BaseModel, None]:
        started = False
        potential_object = ""
        async for chunk in json_chunks:
            potential_object += chunk
            if not started:
                if "[" in chunk:
                    started = True
                    potential_object = chunk[chunk.find("[") + 1 :]

            while True:
                task_json, potential_object = cls.get_object(potential_object, 0)
                if task_json:
                    assert cls.task_type is not None
                    obj = cls.extract_cls_task_type(task_json, **kwargs)
                    yield obj
                else:
                    break

    @classmethod
    def extract_cls_task_type(
        cls,
        task_json: str,
        **kwargs: Any,
    ):
        assert cls.task_type is not None
        if get_origin(cls.task_type) is Union:
            union_members = get_args(cls.task_type)
            for member in union_members:
                try:
                    obj = member.model_validate_json(task_json, **kwargs)
                    return obj
                except Exception:
                    pass
        else:
            return cls.task_type.model_validate_json(task_json, **kwargs)
        raise ValueError(
            f"Failed to extract task type with {task_json} for {cls.task_type}"
        )

    @staticmethod
    def extract_json(
        completion: Iterable[Any], mode: Mode
    ) -> Generator[str, None, None]:
        for chunk in completion:
            try:
                if mode == Mode.ANTHROPIC_JSON:
                    if json_chunk := chunk.delta.text:
                        yield json_chunk
                if mode == Mode.ANTHROPIC_TOOLS:
                    yield chunk.delta.partial_json
                if mode == Mode.GEMINI_JSON:
                    yield chunk.text
                if mode == Mode.VERTEXAI_JSON:
                    yield chunk.candidates[0].content.parts[0].text
                if mode == Mode.VERTEXAI_TOOLS:
                    yield json.dumps(
                        chunk.candidates[0].content.parts[0].function_call.args
                    )
                if mode == Mode.MISTRAL_STRUCTURED_OUTPUTS:
                    yield chunk.data.choices[0].delta.content
                if mode == Mode.MISTRAL_TOOLS:
                    if not chunk.data.choices[0].delta.tool_calls:
                        continue
                    yield chunk.data.choices[0].delta.tool_calls[0].function.arguments

                if mode in {Mode.GENAI_TOOLS}:
                    yield json.dumps(
                        chunk.candidates[0].content.parts[0].function_call.args
                    )
                if mode in {Mode.GENAI_STRUCTURED_OUTPUTS}:
                    yield chunk.candidates[0].content.parts[0].text

                if mode in {Mode.GEMINI_TOOLS}:
                    resp = chunk.candidates[0].content.parts[0].function_call
                    resp_dict = type(resp).to_dict(resp)  # type:ignore

                    if "args" in resp_dict:
                        yield json.dumps(resp_dict["args"])
                elif chunk.choices:
                    if mode == Mode.FUNCTIONS:
                        Mode.warn_mode_functions_deprecation()
                        if json_chunk := chunk.choices[0].delta.function_call.arguments:
                            yield json_chunk
                    elif mode in {
                        Mode.JSON,
                        Mode.MD_JSON,
                        Mode.JSON_SCHEMA,
                        Mode.CEREBRAS_JSON,
                        Mode.FIREWORKS_JSON,
                        Mode.PERPLEXITY_JSON,
                    }:
                        if json_chunk := chunk.choices[0].delta.content:
                            yield json_chunk
                    elif mode in {
                        Mode.TOOLS,
                        Mode.TOOLS_STRICT,
                        Mode.FIREWORKS_TOOLS,
                        Mode.WRITER_TOOLS,
                    }:
                        if json_chunk := chunk.choices[0].delta.tool_calls:
                            if json_chunk[0].function.arguments is not None:
                                yield json_chunk[0].function.arguments
                    else:
                        raise NotImplementedError(
                            f"Mode {mode} is not supported for MultiTask streaming"
                        )
            except AttributeError:
                pass

    @staticmethod
    async def extract_json_async(
        completion: AsyncGenerator[Any, None], mode: Mode
    ) -> AsyncGenerator[str, None]:
        async for chunk in completion:
            try:
                if mode == Mode.ANTHROPIC_JSON:
                    if json_chunk := chunk.delta.text:
                        yield json_chunk
                if mode == Mode.ANTHROPIC_TOOLS:
                    yield chunk.delta.partial_json
                if mode == Mode.VERTEXAI_JSON:
                    yield chunk.candidates[0].content.parts[0].text
                if mode == Mode.VERTEXAI_TOOLS:
                    yield json.dumps(
                        chunk.candidates[0].content.parts[0].function_call.args
                    )
                if mode == Mode.MISTRAL_STRUCTURED_OUTPUTS:
                    yield chunk.data.choices[0].delta.content
                if mode == Mode.MISTRAL_TOOLS:
                    if not chunk.data.choices[0].delta.tool_calls:
                        continue
                    yield chunk.data.choices[0].delta.tool_calls[0].function.arguments
                if mode == Mode.GENAI_STRUCTURED_OUTPUTS:
                    yield chunk.text
                if mode in {Mode.GENAI_TOOLS}:
                    yield json.dumps(
                        chunk.candidates[0].content.parts[0].function_call.args
                    )
                elif chunk.choices:
                    if mode == Mode.FUNCTIONS:
                        Mode.warn_mode_functions_deprecation()
                        if json_chunk := chunk.choices[0].delta.function_call.arguments:
                            yield json_chunk
                    elif mode in {
                        Mode.JSON,
                        Mode.MD_JSON,
                        Mode.JSON_SCHEMA,
                        Mode.CEREBRAS_JSON,
                        Mode.FIREWORKS_JSON,
                        Mode.PERPLEXITY_JSON,
                    }:
                        if json_chunk := chunk.choices[0].delta.content:
                            yield json_chunk
                    elif mode in {
                        Mode.TOOLS,
                        Mode.TOOLS_STRICT,
                        Mode.FIREWORKS_TOOLS,
                        Mode.WRITER_TOOLS,
                    }:
                        if json_chunk := chunk.choices[0].delta.tool_calls:
                            if json_chunk[0].function.arguments is not None:
                                yield json_chunk[0].function.arguments
                    else:
                        raise NotImplementedError(
                            f"Mode {mode} is not supported for MultiTask streaming"
                        )
            except AttributeError:
                pass

    @staticmethod
    def get_object(s: str, stack: int) -> tuple[Optional[str], str]:
        start_index = s.find("{")
        for i, c in enumerate(s):
            if c == "{":
                stack += 1
            if c == "}":
                stack -= 1
                if stack == 0:
                    return s[start_index : i + 1], s[i + 2 :]
        return None, s
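
To illustrate the buffering contract used by `tasks_from_chunks`, here is a small sketch of `get_object` on a partially streamed array body. It returns the first balanced `{...}` object plus the remainder after the assumed separating comma (the `i + 2` skip), or `(None, s)` while the next object is still incomplete:

```python
buffer = '{"name": "a"},{"name": "b"},{"na'

obj, rest = IterableBase.get_object(buffer, 0)
print(obj)   #> {"name": "a"}
print(rest)  #> {"name": "b"},{"na

obj, rest = IterableBase.get_object(rest, 0)
print(obj)   #> {"name": "b"}

obj, rest = IterableBase.get_object(rest, 0)
print(obj)   #> None (the last object has not closed yet)
```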

tasks_from_mistral_chunks(json_chunks, **kwargs) async classmethod

Process streaming chunks from Mistral and VertexAI.

Handles the specific JSON format used by these providers when streaming.

Source code in instructor/dsl/iterable.py
@classmethod
async def tasks_from_mistral_chunks(
    cls, json_chunks: AsyncGenerator[str, None], **kwargs: Any
) -> AsyncGenerator[BaseModel, None]:
    """Process streaming chunks from Mistral and VertexAI.

    Handles the specific JSON format used by these providers when streaming."""

    async for chunk in json_chunks:
        if not chunk:
            continue
        json_response = json.loads(chunk)
        if not json_response["tasks"]:
            continue

        for item in json_response["tasks"]:
            obj = cls.extract_cls_task_type(json.dumps(item), **kwargs)
            yield obj

IterableModel(subtask_class, name=None, description=None)

Dynamically create an IterableModel OpenAISchema that can be used to segment multiple tasks given a base class. This creates a class that can be used to build a toolkit for a specific task; names and descriptions are generated automatically, but they can be overridden.

Usage

from pydantic import BaseModel, Field
from instructor import IterableModel

class User(BaseModel):
    name: str = Field(description="The name of the person")
    age: int = Field(description="The age of the person")
    role: str = Field(description="The role of the person")

MultiUser = IterableModel(User)

Result

class MultiUser(OpenAISchema, MultiTaskBase):
    tasks: List[User] = Field(
        default_factory=list,
        repr=False,
        description="Correctly segmented list of `User` tasks",
    )

    @classmethod
    def from_streaming_response(cls, completion) -> Generator[User]:
        '''
        Parse the streaming response from OpenAI and yield a `User` object
        for each task in the response
        '''
        json_chunks = cls.extract_json(completion)
        yield from cls.tasks_from_chunks(json_chunks)

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| subtask_class | type[OpenAISchema] | The base class to use for the MultiTask | required |
| name | Optional[str] | The name of the MultiTask class; if None, the name is derived from the subtask class as `Iterable{subtask_class.__name__}` | None |
| description | Optional[str] | The description of the MultiTask class; if None, it is set to ``Correct segmentation of `{subtask_class.__name__}` tasks`` | None |

Returns

| Name | Type | Description |
| --- | --- | --- |
| schema | OpenAISchema | A new class that can be used to segment multiple tasks |

Source code in instructor/dsl/iterable.py
def IterableModel(
    subtask_class: type[BaseModel],
    name: Optional[str] = None,
    description: Optional[str] = None,
) -> type[BaseModel]:
    """
    Dynamically create a IterableModel OpenAISchema that can be used to segment multiple
    tasks given a base class. This creates class that can be used to create a toolkit
    for a specific task, names and descriptions are automatically generated. However
    they can be overridden.

    ## Usage

    ```python
    from pydantic import BaseModel, Field
    from instructor import IterableModel

    class User(BaseModel):
        name: str = Field(description="The name of the person")
        age: int = Field(description="The age of the person")
        role: str = Field(description="The role of the person")

    MultiUser = IterableModel(User)
    ```

    ## Result

    ```python
    class MultiUser(OpenAISchema, MultiTaskBase):
        tasks: List[User] = Field(
            default_factory=list,
            repr=False,
            description="Correctly segmented list of `User` tasks",
        )

        @classmethod
        def from_streaming_response(cls, completion) -> Generator[User]:
            '''
            Parse the streaming response from OpenAI and yield a `User` object
            for each task in the response
            '''
            json_chunks = cls.extract_json(completion)
            yield from cls.tasks_from_chunks(json_chunks)
    ```

    Parameters:
        subtask_class (Type[OpenAISchema]): The base class to use for the MultiTask
        name (Optional[str]): The name of the MultiTask class, if None then the name
            of the subtask class is used as `Iterable{subtask_class.__name__}`
        description (Optional[str]): The description of the MultiTask class, if None
            then the description is set to `Correct segmentation of `{subtask_class.__name__}` tasks`

    Returns:
        schema (OpenAISchema): A new class that can be used to segment multiple tasks
    """
    task_name = subtask_class.__name__ if name is None else name

    name = f"Iterable{task_name}"

    list_tasks = (
        list[subtask_class],
        Field(
            default_factory=list,
            repr=False,
            description=f"Correctly segmented list of `{task_name}` tasks",
        ),
    )

    base_models = cast(tuple[type[BaseModel], ...], (OpenAISchema, IterableBase))
    new_cls = create_model(
        name,
        tasks=list_tasks,
        __base__=base_models,
    )
    new_cls = cast(type[IterableBase], new_cls)

    # set the class constructor BaseModel
    new_cls.task_type = subtask_class

    new_cls.__doc__ = (
        f"Correct segmentation of `{task_name}` tasks"
        if description is None
        else description
    )
    assert issubclass(
        new_cls, OpenAISchema
    ), "The new class should be a subclass of OpenAISchema"
    return new_cls
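
A hedged end-to-end sketch of streaming with the generated class; the client setup and model name are assumptions, not part of this module:

```python
import openai
from pydantic import BaseModel

import instructor
from instructor import IterableModel

class User(BaseModel):
    name: str
    age: int

client = instructor.from_openai(openai.OpenAI())
MultiUser = IterableModel(User)

# With stream=True, each fully parsed User is yielded as soon as its JSON
# completes, via from_streaming_response() above.
users = client.chat.completions.create(
    model="gpt-4o-mini",  # assumed model name
    messages=[{"role": "user", "content": "Jason is 25, Sarah is 30."}],
    response_model=MultiUser,
    stream=True,
)
for user in users:
    print(user)
```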

Partial

Bases: Generic[T_Model]

Generate a new class which has PartialBase as a base class.

Notes

This will enable partial validation of the model while streaming.

Example

Partial[SomeModel]

Source code in instructor/dsl/partial.py
class Partial(Generic[T_Model]):
    """Generate a new class which has PartialBase as a base class.

    Notes:
        This will enable partial validation of the model while streaming.

    Example:
        Partial[SomeModel]
    """

    def __new__(
        cls,
        *args: object,  # noqa
        **kwargs: object,  # noqa
    ) -> Partial[T_Model]:
        """Cannot instantiate.

        Raises:
            TypeError: Direct instantiation not allowed.
        """
        raise TypeError("Cannot instantiate abstract Partial class.")

    def __init_subclass__(
        cls,
        *args: object,
        **kwargs: object,
    ) -> NoReturn:
        """Cannot subclass.

        Raises:
           TypeError: Subclassing not allowed.
        """
        raise TypeError(f"Cannot subclass {cls.__module__}.Partial")

    def __class_getitem__(
        cls,
        wrapped_class: type[T_Model] | tuple[type[T_Model], type[MakeFieldsOptional]],
    ) -> type[T_Model]:
        """Convert model to one that inherits from PartialBase.

        We don't make the fields optional at this point, we just wrap them with `Partial` so the names of the nested models will be
        `Partial{ModelName}`. We want the output of `model_json_schema()` to
        reflect the name change, but everything else should be the same as the
        original model. During validation, we'll generate a true partial model
        to support partially defined fields.

        """

        make_fields_optional = None
        if isinstance(wrapped_class, tuple):
            wrapped_class, make_fields_optional = wrapped_class

        def _wrap_models(field: FieldInfo) -> tuple[object, FieldInfo]:
            tmp_field = deepcopy(field)

            annotation = field.annotation

            # Handle generics (like List, Dict, etc.)
            if get_origin(annotation) is not None:
                # Get the generic base (like List, Dict) and its arguments (like User in List[User])
                generic_base = get_origin(annotation)
                generic_args = get_args(annotation)

                modified_args = tuple(_process_generic_arg(arg) for arg in generic_args)

                # Reconstruct the generic type with modified arguments
                tmp_field.annotation = (
                    generic_base[modified_args] if generic_base else None
                )
            # If the field is a BaseModel, then recursively convert it's
            # attributes to optionals.
            elif isinstance(annotation, type) and issubclass(annotation, BaseModel):
                tmp_field.annotation = Partial[annotation]
            return tmp_field.annotation, tmp_field

        model_name = (
            wrapped_class.__name__
            if wrapped_class.__name__.startswith("Partial")
            else f"Partial{wrapped_class.__name__}"
        )

        return create_model(
            model_name,
            __base__=(wrapped_class, PartialBase),  # type: ignore
            __module__=wrapped_class.__module__,
            **{
                field_name: (
                    _make_field_optional(field_info)
                    if make_fields_optional is not None
                    else _wrap_models(field_info)
                )
                for field_name, field_info in wrapped_class.model_fields.items()
            },  # type: ignore
        )
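
A short sketch of the intended streaming use, assuming `client` is an `Instructor` instance and `User` is a plain Pydantic model:

```python
from instructor import Partial

stream = client.chat.completions.create(
    model="gpt-4o-mini",  # assumed model name
    messages=[{"role": "user", "content": "Jason is 25 years old"}],
    response_model=Partial[User],
    stream=True,
)
# Each yielded object validates against the generated PartialUser model,
# so fields fill in incrementally (e.g. name first, then age).
for partial_user in stream:
    print(partial_user)
```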

__class_getitem__(wrapped_class)

Convert model to one that inherits from PartialBase.

We don't make the fields optional at this point; we just wrap them with `Partial` so the names of the nested models will be `Partial{ModelName}`. We want the output of `model_json_schema()` to reflect the name change, but everything else should be the same as the original model. During validation, we'll generate a true partial model to support partially defined fields.

Source code in instructor/dsl/partial.py
def __class_getitem__(
    cls,
    wrapped_class: type[T_Model] | tuple[type[T_Model], type[MakeFieldsOptional]],
) -> type[T_Model]:
    """Convert model to one that inherits from PartialBase.

    We don't make the fields optional at this point, we just wrap them with `Partial` so the names of the nested models will be
    `Partial{ModelName}`. We want the output of `model_json_schema()` to
    reflect the name change, but everything else should be the same as the
    original model. During validation, we'll generate a true partial model
    to support partially defined fields.

    """

    make_fields_optional = None
    if isinstance(wrapped_class, tuple):
        wrapped_class, make_fields_optional = wrapped_class

    def _wrap_models(field: FieldInfo) -> tuple[object, FieldInfo]:
        tmp_field = deepcopy(field)

        annotation = field.annotation

        # Handle generics (like List, Dict, etc.)
        if get_origin(annotation) is not None:
            # Get the generic base (like List, Dict) and its arguments (like User in List[User])
            generic_base = get_origin(annotation)
            generic_args = get_args(annotation)

            modified_args = tuple(_process_generic_arg(arg) for arg in generic_args)

            # Reconstruct the generic type with modified arguments
            tmp_field.annotation = (
                generic_base[modified_args] if generic_base else None
            )
        # If the field is a BaseModel, then recursively convert it's
        # attributes to optionals.
        elif isinstance(annotation, type) and issubclass(annotation, BaseModel):
            tmp_field.annotation = Partial[annotation]
        return tmp_field.annotation, tmp_field

    model_name = (
        wrapped_class.__name__
        if wrapped_class.__name__.startswith("Partial")
        else f"Partial{wrapped_class.__name__}"
    )

    return create_model(
        model_name,
        __base__=(wrapped_class, PartialBase),  # type: ignore
        __module__=wrapped_class.__module__,
        **{
            field_name: (
                _make_field_optional(field_info)
                if make_fields_optional is not None
                else _wrap_models(field_info)
            )
            for field_name, field_info in wrapped_class.model_fields.items()
        },  # type: ignore
    )

__init_subclass__(*args, **kwargs)

Cannot subclass.

Raises

| Type | Description |
| --- | --- |
| TypeError | Subclassing not allowed. |

Source code in instructor/dsl/partial.py
def __init_subclass__(
    cls,
    *args: object,
    **kwargs: object,
) -> NoReturn:
    """Cannot subclass.

    Raises:
       TypeError: Subclassing not allowed.
    """
    raise TypeError(f"Cannot subclass {cls.__module__}.Partial")

__new__(*args, **kwargs)

Cannot instantiate.

Raises

| Type | Description |
| --- | --- |
| TypeError | Direct instantiation not allowed. |

Source code in instructor/dsl/partial.py
def __new__(
    cls,
    *args: object,  # noqa
    **kwargs: object,  # noqa
) -> Partial[T_Model]:
    """Cannot instantiate.

    Raises:
        TypeError: Direct instantiation not allowed.
    """
    raise TypeError("Cannot instantiate abstract Partial class.")

PartialBase

Bases: Generic[T_Model]

Source code in instructor/dsl/partial.py
class PartialBase(Generic[T_Model]):
    @classmethod
    @cache
    def get_partial_model(cls) -> type[T_Model]:
        """Return a partial model we can use to validate partial results."""
        assert issubclass(cls, BaseModel), (
            f"{cls.__name__} must be a subclass of BaseModel"
        )

        model_name = (
            cls.__name__
            if cls.__name__.startswith("Partial")
            else f"Partial{cls.__name__}"
        )

        return create_model(
            model_name,
            __base__=cls,
            __module__=cls.__module__,
            **{
                field_name: _make_field_optional(field_info)
                for field_name, field_info in cls.model_fields.items()
            },  # type: ignore[all]
        )

    @classmethod
    def from_streaming_response(
        cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
    ) -> Generator[T_Model, None, None]:
        json_chunks = cls.extract_json(completion, mode)

        if mode in {Mode.MD_JSON, Mode.GEMINI_TOOLS}:
            json_chunks = extract_json_from_stream(json_chunks)

        if mode == Mode.WRITER_TOOLS:
            yield from cls.writer_model_from_chunks(json_chunks, **kwargs)
        else:
            yield from cls.model_from_chunks(json_chunks, **kwargs)

    @classmethod
    async def from_streaming_response_async(
        cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
    ) -> AsyncGenerator[T_Model, None]:
        json_chunks = cls.extract_json_async(completion, mode)

        if mode == Mode.MD_JSON:
            json_chunks = extract_json_from_stream_async(json_chunks)
        elif mode == Mode.WRITER_TOOLS:
            return cls.writer_model_from_chunks_async(json_chunks, **kwargs)

        return cls.model_from_chunks_async(json_chunks, **kwargs)

    @classmethod
    def writer_model_from_chunks(
        cls, json_chunks: Iterable[Any], **kwargs: Any
    ) -> Generator[T_Model, None, None]:
        potential_object = ""
        partial_model = cls.get_partial_model()
        partial_mode = (
            "on" if issubclass(cls, PartialLiteralMixin) else "trailing-strings"
        )
        for chunk in json_chunks:
            if len(chunk) > len(potential_object):
                potential_object = chunk
            else:
                potential_object += chunk
            obj = from_json(
                (potential_object.strip() or "{}").encode(), partial_mode=partial_mode
            )
            obj = partial_model.model_validate(obj, strict=None, **kwargs)
            yield obj

    @classmethod
    async def writer_model_from_chunks_async(
        cls, json_chunks: AsyncGenerator[str, None], **kwargs: Any
    ) -> AsyncGenerator[T_Model, None]:
        potential_object = ""
        partial_model = cls.get_partial_model()
        partial_mode = (
            "on" if issubclass(cls, PartialLiteralMixin) else "trailing-strings"
        )
        async for chunk in json_chunks:
            if len(chunk) > len(potential_object):
                potential_object = chunk
            else:
                potential_object += chunk
            obj = from_json(
                (potential_object.strip() or "{}").encode(), partial_mode=partial_mode
            )
            obj = partial_model.model_validate(obj, strict=None, **kwargs)
            yield obj

    @classmethod
    def model_from_chunks(
        cls, json_chunks: Iterable[Any], **kwargs: Any
    ) -> Generator[T_Model, None, None]:
        potential_object = ""
        partial_model = cls.get_partial_model()
        partial_mode = (
            "on" if issubclass(cls, PartialLiteralMixin) else "trailing-strings"
        )
        chunk_buffer = []
        for chunk in json_chunks:
            chunk_buffer += chunk
            if len(chunk_buffer) < 2:
                continue
            potential_object += remove_control_chars("".join(chunk_buffer))
            chunk_buffer = []
            obj = process_potential_object(
                potential_object, partial_mode, partial_model, **kwargs
            )
            yield obj
        if chunk_buffer:
            potential_object += remove_control_chars(chunk_buffer[0])
            obj = process_potential_object(
                potential_object, partial_mode, partial_model, **kwargs
            )
            yield obj

    @classmethod
    async def model_from_chunks_async(
        cls, json_chunks: AsyncGenerator[str, None], **kwargs: Any
    ) -> AsyncGenerator[T_Model, None]:
        potential_object = ""
        partial_model = cls.get_partial_model()
        partial_mode = (
            "on" if issubclass(cls, PartialLiteralMixin) else "trailing-strings"
        )
        async for chunk in json_chunks:
            potential_object += chunk
            obj = from_json(
                (potential_object.strip() or "{}").encode(), partial_mode=partial_mode
            )
            obj = partial_model.model_validate(obj, strict=None, **kwargs)
            yield obj

    @staticmethod
    def extract_json(
        completion: Iterable[Any], mode: Mode
    ) -> Generator[str, None, None]:
        """Extract JSON chunks from various LLM provider streaming responses.

        Each provider has a different structure for streaming responses that needs
        specific handling to extract the relevant JSON data."""
        for chunk in completion:
            try:
                if mode == Mode.MISTRAL_STRUCTURED_OUTPUTS:
                    yield chunk.data.choices[0].delta.content
                if mode == Mode.MISTRAL_TOOLS:
                    if not chunk.data.choices[0].delta.tool_calls:
                        continue
                    yield chunk.data.choices[0].delta.tool_calls[0].function.arguments
                if mode == Mode.ANTHROPIC_JSON:
                    if json_chunk := chunk.delta.text:
                        yield json_chunk
                if mode == Mode.ANTHROPIC_TOOLS:
                    yield chunk.delta.partial_json
                if mode == Mode.VERTEXAI_JSON:
                    yield chunk.candidates[0].content.parts[0].text
                if mode == Mode.VERTEXAI_TOOLS:
                    yield json.dumps(
                        chunk.candidates[0].content.parts[0].function_call.args
                    )
                if mode == Mode.GENAI_STRUCTURED_OUTPUTS:
                    yield chunk.text
                if mode == Mode.GENAI_TOOLS:
                    fc = chunk.candidates[0].content.parts[0].function_call.args
                    yield json.dumps(fc)
                if mode == Mode.GEMINI_JSON:
                    yield chunk.text
                if mode == Mode.GEMINI_TOOLS:
                    resp = chunk.candidates[0].content.parts[0].function_call
                    resp_dict = type(resp).to_dict(resp)  # type:ignore
                    if "args" in resp_dict:
                        yield json.dumps(resp_dict["args"])
                elif chunk.choices:
                    if mode == Mode.FUNCTIONS:
                        Mode.warn_mode_functions_deprecation()
                        if json_chunk := chunk.choices[0].delta.function_call.arguments:
                            yield json_chunk
                    elif mode in {
                        Mode.JSON,
                        Mode.MD_JSON,
                        Mode.JSON_SCHEMA,
                        Mode.CEREBRAS_JSON,
                        Mode.FIREWORKS_JSON,
                        Mode.PERPLEXITY_JSON,
                    }:
                        if json_chunk := chunk.choices[0].delta.content:
                            yield json_chunk
                    elif mode in {
                        Mode.TOOLS,
                        Mode.TOOLS_STRICT,
                        Mode.FIREWORKS_TOOLS,
                        Mode.WRITER_TOOLS,
                    }:
                        if json_chunk := chunk.choices[0].delta.tool_calls:
                            if json_chunk[0].function.arguments:
                                yield json_chunk[0].function.arguments
                    else:
                        raise NotImplementedError(
                            f"Mode {mode} is not supported for MultiTask streaming"
                        )
            except AttributeError:
                pass

    @staticmethod
    async def extract_json_async(
        completion: AsyncGenerator[Any, None], mode: Mode
    ) -> AsyncGenerator[str, None]:
        async for chunk in completion:
            try:
                if mode == Mode.ANTHROPIC_JSON:
                    if json_chunk := chunk.delta.text:
                        yield json_chunk
                if mode == Mode.ANTHROPIC_TOOLS:
                    yield chunk.delta.partial_json
                if mode == Mode.MISTRAL_STRUCTURED_OUTPUTS:
                    yield chunk.data.choices[0].delta.content
                if mode == Mode.MISTRAL_TOOLS:
                    if not chunk.data.choices[0].delta.tool_calls:
                        continue
                    yield chunk.data.choices[0].delta.tool_calls[0].function.arguments
                if mode == Mode.VERTEXAI_JSON:
                    yield chunk.candidates[0].content.parts[0].text
                if mode == Mode.VERTEXAI_TOOLS:
                    yield json.dumps(
                        chunk.candidates[0].content.parts[0].function_call.args
                    )
                if mode == Mode.GENAI_STRUCTURED_OUTPUTS:
                    yield chunk.text
                if mode == Mode.GENAI_TOOLS:
                    fc = chunk.candidates[0].content.parts[0].function_call.args
                    yield json.dumps(fc)
                if mode == Mode.GEMINI_JSON:
                    yield chunk.text
                if mode == Mode.GEMINI_TOOLS:
                    resp = chunk.candidates[0].content.parts[0].function_call
                    resp_dict = type(resp).to_dict(resp)  # type:ignore
                    if "args" in resp_dict:
                        yield json.dumps(resp_dict["args"])
                elif chunk.choices:
                    if mode == Mode.FUNCTIONS:
                        Mode.warn_mode_functions_deprecation()
                        if json_chunk := chunk.choices[0].delta.function_call.arguments:
                            yield json_chunk
                    elif mode in {
                        Mode.JSON,
                        Mode.MD_JSON,
                        Mode.JSON_SCHEMA,
                        Mode.CEREBRAS_JSON,
                        Mode.FIREWORKS_JSON,
                        Mode.PERPLEXITY_JSON,
                    }:
                        if json_chunk := chunk.choices[0].delta.content:
                            yield json_chunk
                    elif mode in {
                        Mode.TOOLS,
                        Mode.TOOLS_STRICT,
                        Mode.FIREWORKS_TOOLS,
                        Mode.WRITER_TOOLS,
                    }:
                        if json_chunk := chunk.choices[0].delta.tool_calls:
                            if json_chunk[0].function.arguments:
                                yield json_chunk[0].function.arguments
                    else:
                        raise NotImplementedError(
                            f"Mode {mode} is not supported for MultiTask streaming"
                        )
            except AttributeError:
                pass

extract_json(completion, mode) staticmethod

Extract JSON chunks from various LLM provider streaming responses.

Each provider has a different structure for streaming responses that needs specific handling to extract the relevant JSON data.

Source code in instructor/dsl/partial.py
@staticmethod
def extract_json(
    completion: Iterable[Any], mode: Mode
) -> Generator[str, None, None]:
    """Extract JSON chunks from various LLM provider streaming responses.

    Each provider has a different structure for streaming responses that needs
    specific handling to extract the relevant JSON data."""
    for chunk in completion:
        try:
            if mode == Mode.MISTRAL_STRUCTURED_OUTPUTS:
                yield chunk.data.choices[0].delta.content
            if mode == Mode.MISTRAL_TOOLS:
                if not chunk.data.choices[0].delta.tool_calls:
                    continue
                yield chunk.data.choices[0].delta.tool_calls[0].function.arguments
            if mode == Mode.ANTHROPIC_JSON:
                if json_chunk := chunk.delta.text:
                    yield json_chunk
            if mode == Mode.ANTHROPIC_TOOLS:
                yield chunk.delta.partial_json
            if mode == Mode.VERTEXAI_JSON:
                yield chunk.candidates[0].content.parts[0].text
            if mode == Mode.VERTEXAI_TOOLS:
                yield json.dumps(
                    chunk.candidates[0].content.parts[0].function_call.args
                )
            if mode == Mode.GENAI_STRUCTURED_OUTPUTS:
                yield chunk.text
            if mode == Mode.GENAI_TOOLS:
                fc = chunk.candidates[0].content.parts[0].function_call.args
                yield json.dumps(fc)
            if mode == Mode.GEMINI_JSON:
                yield chunk.text
            if mode == Mode.GEMINI_TOOLS:
                resp = chunk.candidates[0].content.parts[0].function_call
                resp_dict = type(resp).to_dict(resp)  # type:ignore
                if "args" in resp_dict:
                    yield json.dumps(resp_dict["args"])
            elif chunk.choices:
                if mode == Mode.FUNCTIONS:
                    Mode.warn_mode_functions_deprecation()
                    if json_chunk := chunk.choices[0].delta.function_call.arguments:
                        yield json_chunk
                elif mode in {
                    Mode.JSON,
                    Mode.MD_JSON,
                    Mode.JSON_SCHEMA,
                    Mode.CEREBRAS_JSON,
                    Mode.FIREWORKS_JSON,
                    Mode.PERPLEXITY_JSON,
                }:
                    if json_chunk := chunk.choices[0].delta.content:
                        yield json_chunk
                elif mode in {
                    Mode.TOOLS,
                    Mode.TOOLS_STRICT,
                    Mode.FIREWORKS_TOOLS,
                    Mode.WRITER_TOOLS,
                }:
                    if json_chunk := chunk.choices[0].delta.tool_calls:
                        if json_chunk[0].function.arguments:
                            yield json_chunk[0].function.arguments
                else:
                    raise NotImplementedError(
                        f"Mode {mode} is not supported for MultiTask streaming"
                    )
        except AttributeError:
            pass

get_partial_model() cached classmethod

Return a partial model we can use to validate partial results.

Source code in instructor/dsl/partial.py
@classmethod
@cache
def get_partial_model(cls) -> type[T_Model]:
    """Return a partial model we can use to validate partial results."""
    assert issubclass(cls, BaseModel), (
        f"{cls.__name__} must be a subclass of BaseModel"
    )

    model_name = (
        cls.__name__
        if cls.__name__.startswith("Partial")
        else f"Partial{cls.__name__}"
    )

    return create_model(
        model_name,
        __base__=cls,
        __module__=cls.__module__,
        **{
            field_name: _make_field_optional(field_info)
            for field_name, field_info in cls.model_fields.items()
        },  # type: ignore[all]
    )
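
A sketch of the effect, assuming a plain `User` model with `name`, `age`, and `role` fields: the generated class keeps the `Partial` name prefix and makes every field optional, so incomplete payloads validate instead of raising.

```python
PartialUser = Partial[User].get_partial_model()

# Missing fields fall back to None on the generated partial model.
print(PartialUser.model_validate({"name": "Jason"}))
#> name='Jason' age=None role=None
```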

MaybeBase

Bases: BaseModel, Generic[T]

Extract a result from a model, if any; otherwise set the error and message fields.

Source code in instructor/dsl/maybe.py
class MaybeBase(BaseModel, Generic[T]):
    """
    Extract a result from a model, if any, otherwise set the error and message fields.
    """

    result: Optional[T]
    error: bool = Field(default=False)
    message: Optional[str]

    def __bool__(self) -> bool:
        return self.result is not None

Maybe(model)

Create a Maybe model for a given Pydantic model. This allows you to return a model that includes fields for `result`, `error`, and `message` for situations where the data may not be present in the context.

Usage

from pydantic import BaseModel, Field
from instructor import Maybe

class User(BaseModel):
    name: str = Field(description="The name of the person")
    age: int = Field(description="The age of the person")
    role: str = Field(description="The role of the person")

MaybeUser = Maybe(User)

Result

class MaybeUser(BaseModel):
    result: Optional[User]
    error: bool = Field(default=False)
    message: Optional[str]

    def __bool__(self):
        return self.result is not None

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | type[BaseModel] | The Pydantic model to wrap with Maybe. | required |

Returns

| Name | Type | Description |
| --- | --- | --- |
| MaybeModel | type[BaseModel] | A new Pydantic model that includes fields for `result`, `error`, and `message`. |

Source code in instructor/dsl/maybe.py
def Maybe(model: type[T]) -> type[MaybeBase[T]]:
    """
    Create a Maybe model for a given Pydantic model. This allows you to return a model that includes fields for `result`, `error`, and `message` for situations where the data may not be present in the context.

    ## Usage

    ```python
    from pydantic import BaseModel, Field
    from instructor import Maybe

    class User(BaseModel):
        name: str = Field(description="The name of the person")
        age: int = Field(description="The age of the person")
        role: str = Field(description="The role of the person")

    MaybeUser = Maybe(User)
    ```

    ## Result

    ```python
    class MaybeUser(BaseModel):
        result: Optional[User]
        error: bool = Field(default=False)
        message: Optional[str]

        def __bool__(self):
            return self.result is not None
    ```

    Parameters:
        model (Type[BaseModel]): The Pydantic model to wrap with Maybe.

    Returns:
        MaybeModel (Type[BaseModel]): A new Pydantic model that includes fields for `result`, `error`, and `message`.
    """
    return create_model(
        f"Maybe{model.__name__}",
        __base__=MaybeBase,
        result=(
            Optional[model],
            Field(
                default=None,
                description="Correctly extracted result from the model, if any, otherwise None",
            ),
        ),
        error=(bool, Field(default=False)),
        message=(
            Optional[str],
            Field(
                default=None,
                description="Error message if no result was found, should be short and concise",
            ),
        ),
    )
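
A short sketch of the truthiness contract defined by `MaybeBase.__bool__` (field values are illustrative, using the `User` model from the usage example above):

```python
MaybeUser = Maybe(User)

found = MaybeUser(result=User(name="Jason", age=25, role="admin"))
missing = MaybeUser(error=True, message="No user mentioned in the text")

assert bool(found) is True     # a result was extracted
assert bool(missing) is False  # fall back to error/message
```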

OpenAISchema

Bases: BaseModel

Source code in instructor/function_calls.py
class OpenAISchema(BaseModel):
    # Ignore classproperty, since Pydantic doesn't understand it like it would a normal property.
    model_config = ConfigDict(ignored_types=(classproperty,))

    @classproperty
    def openai_schema(cls) -> dict[str, Any]:
        """
        Return the schema in the format of OpenAI's schema as jsonschema

        Note:
            Its important to add a docstring to describe how to best use this class, it will be included in the description attribute and be part of the prompt.

        Returns:
            model_json_schema (dict): A dictionary in the format of OpenAI's schema as jsonschema
        """
        schema = cls.model_json_schema()
        docstring = parse(cls.__doc__ or "")
        parameters = {
            k: v for k, v in schema.items() if k not in ("title", "description")
        }
        for param in docstring.params:
            if (name := param.arg_name) in parameters["properties"] and (
                description := param.description
            ):
                if "description" not in parameters["properties"][name]:
                    parameters["properties"][name]["description"] = description

        parameters["required"] = sorted(
            k for k, v in parameters["properties"].items() if "default" not in v
        )

        if "description" not in schema:
            if docstring.short_description:
                schema["description"] = docstring.short_description
            else:
                schema["description"] = (
                    f"Correctly extracted `{cls.__name__}` with all "
                    f"the required parameters with correct types"
                )

        return {
            "name": schema["title"],
            "description": schema["description"],
            "parameters": parameters,
        }

    @classproperty
    def anthropic_schema(cls) -> dict[str, Any]:
        # Generate the Anthropic schema based on the OpenAI schema to avoid redundant schema generation
        openai_schema = cls.openai_schema
        return {
            "name": openai_schema["name"],
            "description": openai_schema["description"],
            "input_schema": cls.model_json_schema(),
        }

    @classproperty
    def gemini_schema(cls) -> Any:
        import google.generativeai.types as genai_types

        # Use OpenAI schema
        openai_schema = cls.openai_schema

        # Transform to Gemini format
        function = genai_types.FunctionDeclaration(
            name=openai_schema["name"],
            description=openai_schema["description"],
            parameters=map_to_gemini_function_schema(openai_schema["parameters"]),
        )

        return function

    @classmethod
    def from_response(
        cls,
        completion: ChatCompletion,
        validation_context: Optional[dict[str, Any]] = None,
        strict: Optional[bool] = None,
        mode: Mode = Mode.TOOLS,
    ) -> BaseModel:
        """Execute the function from the response of an openai chat completion

        Parameters:
            completion (openai.ChatCompletion): The response from an openai chat completion
            throw_error (bool): Whether to throw an error if the function call is not detected
            context (dict): The context to use for validating the response
            strict (bool): Whether to use strict json parsing
            mode (Mode): The openai completion mode

        Returns:
            cls (OpenAISchema): An instance of the class
        """

        if mode == Mode.ANTHROPIC_TOOLS or mode == Mode.ANTHROPIC_REASONING_TOOLS:
            return cls.parse_anthropic_tools(completion, validation_context, strict)

        if mode == Mode.ANTHROPIC_JSON:
            return cls.parse_anthropic_json(completion, validation_context, strict)

        if mode == Mode.BEDROCK_JSON:
            return cls.parse_bedrock_json(completion, validation_context, strict)

        if mode in {Mode.VERTEXAI_TOOLS, Mode.GEMINI_TOOLS}:
            return cls.parse_vertexai_tools(completion, validation_context)

        if mode == Mode.VERTEXAI_JSON:
            return cls.parse_vertexai_json(completion, validation_context, strict)

        if mode == Mode.COHERE_TOOLS:
            return cls.parse_cohere_tools(completion, validation_context, strict)

        if mode == Mode.GEMINI_JSON:
            return cls.parse_gemini_json(completion, validation_context, strict)

        if mode == Mode.GENAI_STRUCTURED_OUTPUTS:
            return cls.parse_genai_structured_outputs(
                completion, validation_context, strict
            )

        if mode == Mode.GEMINI_TOOLS:
            return cls.parse_gemini_tools(completion, validation_context, strict)

        if mode == Mode.GENAI_TOOLS:
            return cls.parse_genai_tools(completion, validation_context, strict)

        if mode == Mode.COHERE_JSON_SCHEMA:
            return cls.parse_cohere_json_schema(completion, validation_context, strict)

        if mode == Mode.WRITER_TOOLS:
            return cls.parse_writer_tools(completion, validation_context, strict)

        if not completion.choices:
            # This helps catch errors from OpenRouter
            if hasattr(completion, "error"):
                raise ValueError(completion.error)

            raise ValueError("No completion choices found")

        if completion.choices[0].finish_reason == "length":
            raise IncompleteOutputException(last_completion=completion)

        if mode == Mode.FUNCTIONS:
            Mode.warn_mode_functions_deprecation()
            return cls.parse_functions(completion, validation_context, strict)

        if mode == Mode.MISTRAL_STRUCTURED_OUTPUTS:
            return cls.parse_mistral_structured_outputs(
                completion, validation_context, strict
            )

        if mode in {
            Mode.TOOLS,
            Mode.MISTRAL_TOOLS,
            Mode.TOOLS_STRICT,
            Mode.CEREBRAS_TOOLS,
            Mode.FIREWORKS_TOOLS,
        }:
            return cls.parse_tools(completion, validation_context, strict)

        if mode in {
            Mode.JSON,
            Mode.JSON_SCHEMA,
            Mode.MD_JSON,
            Mode.JSON_O1,
            Mode.CEREBRAS_JSON,
            Mode.FIREWORKS_JSON,
            Mode.PERPLEXITY_JSON,
            Mode.OPENROUTER_STRUCTURED_OUTPUTS,
        }:
            return cls.parse_json(completion, validation_context, strict)

        raise ValueError(f"Invalid patch mode: {mode}")

    @classmethod
    def parse_genai_structured_outputs(
        cls: type[BaseModel],
        completion: ChatCompletion,
        validation_context: Optional[dict[str, Any]] = None,
        strict: Optional[bool] = None,
    ) -> BaseModel:
        return cls.model_validate_json(
            completion.text, context=validation_context, strict=strict
        )

    @classmethod
    def parse_genai_tools(
        cls: type[BaseModel],
        completion: ChatCompletion,
        validation_context: Optional[dict[str, Any]] = None,
        strict: Optional[bool] = None,
    ) -> BaseModel:
        from google.genai import types

        assert isinstance(completion, types.GenerateContentResponse)
        assert len(completion.candidates) == 1

        assert len(completion.candidates[0].content.parts) == 1, (
            "Instructor does not support multiple function calls, use List[Model] instead"
        )
        function_call = completion.candidates[0].content.parts[0].function_call
        assert function_call is not None, (
            f"Please return your response as a function call with the schema {cls.openai_schema} and the name {cls.openai_schema['name']}"
        )

        assert function_call.name == cls.openai_schema["name"]
        return cls.model_validate(
            obj=function_call.args, context=validation_context, strict=strict
        )

    @classmethod
    def parse_cohere_json_schema(
        cls: type[BaseModel],
        completion: ChatCompletion,
        validation_context: Optional[dict[str, Any]] = None,
        strict: Optional[bool] = None,
    ):
        assert hasattr(completion, "text"), (
            "Completion is not of type NonStreamedChatResponse"
        )
        return cls.model_validate_json(
            completion.text, context=validation_context, strict=strict
        )

    @classmethod
    def parse_anthropic_tools(
        cls: type[BaseModel],
        completion: ChatCompletion,
        validation_context: Optional[dict[str, Any]] = None,
        strict: Optional[bool] = None,
    ) -> BaseModel:
        from anthropic.types import Message

        if isinstance(completion, Message) and completion.stop_reason == "max_tokens":
            raise IncompleteOutputException(last_completion=completion)

        # Anthropic returns arguments as a dict, dump to json for model validation below
        tool_calls = [
            json.dumps(c.input) for c in completion.content if c.type == "tool_use"
        ]  # TODO update with anthropic specific types

        tool_calls_validator = TypeAdapter(
            Annotated[list[Any], Field(min_length=1, max_length=1)]
        )
        tool_call = tool_calls_validator.validate_python(tool_calls)[0]

        return cls.model_validate_json(
            tool_call, context=validation_context, strict=strict
        )

    @classmethod
    def parse_anthropic_json(
        cls: type[BaseModel],
        completion: ChatCompletion,
        validation_context: Optional[dict[str, Any]] = None,
        strict: Optional[bool] = None,
    ) -> BaseModel:
        from anthropic.types import Message

        last_block = None

        if hasattr(completion, "choices"):
            completion = completion.choices[0]
            if completion.finish_reason == "length":
                raise IncompleteOutputException(last_completion=completion)
            text = completion.message.content
        else:
            assert isinstance(completion, Message)
            if completion.stop_reason == "max_tokens":
                raise IncompleteOutputException(last_completion=completion)
            # Find the last text block in the completion
            # this is because the completion is a list of blocks
            # and the last block is the one that contains the text ideally
            # this could happen due to things like multiple tool calls
            # read: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/web-search-tool#response
            text_blocks = [c for c in completion.content if c.type == "text"]
            last_block = text_blocks[-1]
            # strip (\u0000-\u001F) control characters
            text = re.sub(r"[\u0000-\u001F]", "", last_block.text)

        extra_text = extract_json_from_codeblock(text)

        if strict:
            model = cls.model_validate_json(
                extra_text, context=validation_context, strict=True
            )
        else:
            # Allow control characters.
            parsed = json.loads(extra_text, strict=False)
            # Pydantic non-strict: https://docs.pydantic.dev/latest/concepts/strict_mode/
            model = cls.model_validate(parsed, context=validation_context, strict=False)

        return model

    @classmethod
    def parse_bedrock_json(
        cls: type[BaseModel],
        completion: Any,
        validation_context: Optional[dict[str, Any]] = None,
        strict: Optional[bool] = None,
    ) -> BaseModel:
        if isinstance(completion, dict):
            text = completion.get("output").get("message").get("content")[0].get("text")

            match = re.search(r"```?json(.*?)```?", text, re.DOTALL)
            if match:
                text = match.group(1).strip()

            text = re.sub(r"```?json|\\n", "", text).strip()
        else:
            text = completion.text
        return cls.model_validate_json(text, context=validation_context, strict=strict)

    @classmethod
    def parse_gemini_json(
        cls: type[BaseModel],
        completion: Any,
        validation_context: Optional[dict[str, Any]] = None,
        strict: Optional[bool] = None,
    ) -> BaseModel:
        try:
            text = completion.text
        except ValueError:
            logger.debug(
                f"Error response: {completion.result.candidates[0].finish_reason}\n\n{completion.result.candidates[0].safety_ratings}"
            )

        try:
            extra_text = extract_json_from_codeblock(text)  # type: ignore
        except UnboundLocalError:
            raise ValueError("Unable to extract JSON from completion text") from None

        if strict:
            return cls.model_validate_json(
                extra_text, context=validation_context, strict=True
            )
        else:
            # Allow control characters.
            parsed = json.loads(extra_text, strict=False)
            # Pydantic non-strict: https://docs.pydantic.dev/latest/concepts/strict_mode/
            return cls.model_validate(parsed, context=validation_context, strict=False)

    @classmethod
    def parse_vertexai_tools(
        cls: type[BaseModel],
        completion: ChatCompletion,
        validation_context: Optional[dict[str, Any]] = None,
    ) -> BaseModel:
        tool_call = completion.candidates[0].content.parts[0].function_call.args  # type: ignore
        model = {}
        for field in tool_call:  # type: ignore
            model[field] = tool_call[field]
        # We enable strict=False because the conversion from protobuf -> dict often results in types like ints being cast to floats, as a result in order for model.validate to work we need to disable strict mode.
        return cls.model_validate(model, context=validation_context, strict=False)

    @classmethod
    def parse_vertexai_json(
        cls: type[BaseModel],
        completion: ChatCompletion,
        validation_context: Optional[dict[str, Any]] = None,
        strict: Optional[bool] = None,
    ) -> BaseModel:
        return cls.model_validate_json(
            completion.text, context=validation_context, strict=strict
        )

    @classmethod
    def parse_cohere_tools(
        cls: type[BaseModel],
        completion: ChatCompletion,
        validation_context: Optional[dict[str, Any]] = None,
        strict: Optional[bool] = None,
    ) -> BaseModel:
        text = cast(str, completion.text)  # type: ignore - TODO update with cohere specific types
        extra_text = extract_json_from_codeblock(text)
        return cls.model_validate_json(
            extra_text, context=validation_context, strict=strict
        )

    @classmethod
    def parse_writer_tools(
        cls: type[BaseModel],
        completion: ChatCompletion,
        validation_context: Optional[dict[str, Any]] = None,
        strict: Optional[bool] = None,
    ) -> BaseModel:
        message = completion.choices[0].message
        tool_calls = message.tool_calls
        assert len(tool_calls) == 1, (
            "Instructor does not support multiple tool calls, use List[Model] instead"
        )
        assert tool_calls[0].function.name == cls.openai_schema["name"], (
            "Tool name does not match"
        )
        return cls.model_validate_json(
            tool_calls[0].function.arguments,
            context=validation_context,
            strict=strict,
        )

    @classmethod
    def parse_functions(
        cls: type[BaseModel],
        completion: ChatCompletion,
        validation_context: Optional[dict[str, Any]] = None,
        strict: Optional[bool] = None,
    ) -> BaseModel:
        message = completion.choices[0].message
        assert (
            message.function_call.name == cls.openai_schema["name"]  # type: ignore[index]
        ), "Function name does not match"
        return cls.model_validate_json(
            message.function_call.arguments,  # type: ignore[attr-defined]
            context=validation_context,
            strict=strict,
        )

    @classmethod
    def parse_tools(
        cls: type[BaseModel],
        completion: ChatCompletion,
        validation_context: Optional[dict[str, Any]] = None,
        strict: Optional[bool] = None,
    ) -> BaseModel:
        message = completion.choices[0].message
        # this field seems to be missing when using instructor with some other tools (e.g. litellm)
        # trying to fix this by adding a check

        if hasattr(message, "refusal"):
            assert message.refusal is None, (
                f"Unable to generate a response due to {message.refusal}"
            )
        assert len(message.tool_calls or []) == 1, (
            "Instructor does not support multiple tool calls, use List[Model] instead"
        )
        tool_call = message.tool_calls[0]  # type: ignore
        assert (
            tool_call.function.name == cls.openai_schema["name"]  # type: ignore[index]
        ), "Tool name does not match"
        return cls.model_validate_json(
            tool_call.function.arguments,  # type: ignore
            context=validation_context,
            strict=strict,
        )

    @classmethod
    def parse_mistral_structured_outputs(
        cls: type[BaseModel],
        completion: ChatCompletion,
        validation_context: Optional[dict[str, Any]] = None,
        strict: Optional[bool] = None,
    ) -> BaseModel:
        if not completion.choices or len(completion.choices) > 1:
            raise ValueError(
                "Instructor does not support multiple tool calls, use list[Model] instead"
            )

        message = completion.choices[0].message

        return cls.model_validate_json(
            message.content, context=validation_context, strict=strict
        )

    @classmethod
    def parse_json(
        cls: type[BaseModel],
        completion: ChatCompletion,
        validation_context: Optional[dict[str, Any]] = None,
        strict: Optional[bool] = None,
    ) -> BaseModel:
        """Parse JSON mode responses using the optimized extraction and validation."""
        # Check for incomplete output
        _handle_incomplete_output(completion)

        # Extract text from the response
        message = _extract_text_content(completion)
        if not message:
            # Fallback for OpenAI format if _extract_text_content doesn't handle it
            message = completion.choices[0].message.content or ""

        # Extract JSON from the text
        json_content = extract_json_from_codeblock(message)

        # Validate the model from the JSON
        return _validate_model_from_json(cls, json_content, validation_context, strict)

from_response(completion, validation_context=None, strict=None, mode=Mode.TOOLS) classmethod

Execute the function from the response of an openai chat completion

Parameters

Name Type Description Default
completion ChatCompletion

The response from an openai chat completion

required
validation_context dict

The context to use for validating the response

None
strict bool

Whether to use strict json parsing

None
mode Mode

The openai completion mode

TOOLS

Returns

Name Type Description
cls OpenAISchema

An instance of the class

Source code in instructor/function_calls.py
@classmethod
def from_response(
    cls,
    completion: ChatCompletion,
    validation_context: Optional[dict[str, Any]] = None,
    strict: Optional[bool] = None,
    mode: Mode = Mode.TOOLS,
) -> BaseModel:
    """Execute the function from the response of an openai chat completion

    Parameters:
        completion (openai.ChatCompletion): The response from an openai chat completion
        throw_error (bool): Whether to throw an error if the function call is not detected
        context (dict): The context to use for validating the response
        strict (bool): Whether to use strict json parsing
        mode (Mode): The openai completion mode

    Returns:
        cls (OpenAISchema): An instance of the class
    """

    if mode in {Mode.ANTHROPIC_TOOLS, Mode.ANTHROPIC_REASONING_TOOLS}:
        return cls.parse_anthropic_tools(completion, validation_context, strict)

    if mode == Mode.ANTHROPIC_JSON:
        return cls.parse_anthropic_json(completion, validation_context, strict)

    if mode == Mode.BEDROCK_JSON:
        return cls.parse_bedrock_json(completion, validation_context, strict)

    if mode in {Mode.VERTEXAI_TOOLS, Mode.GEMINI_TOOLS}:
        return cls.parse_vertexai_tools(completion, validation_context)

    if mode == Mode.VERTEXAI_JSON:
        return cls.parse_vertexai_json(completion, validation_context, strict)

    if mode == Mode.COHERE_TOOLS:
        return cls.parse_cohere_tools(completion, validation_context, strict)

    if mode == Mode.GEMINI_JSON:
        return cls.parse_gemini_json(completion, validation_context, strict)

    if mode == Mode.GENAI_STRUCTURED_OUTPUTS:
        return cls.parse_genai_structured_outputs(
            completion, validation_context, strict
        )

    if mode == Mode.GEMINI_TOOLS:
        return cls.parse_gemini_tools(completion, validation_context, strict)

    if mode == Mode.GENAI_TOOLS:
        return cls.parse_genai_tools(completion, validation_context, strict)

    if mode == Mode.COHERE_JSON_SCHEMA:
        return cls.parse_cohere_json_schema(completion, validation_context, strict)

    if mode == Mode.WRITER_TOOLS:
        return cls.parse_writer_tools(completion, validation_context, strict)

    if not completion.choices:
        # This helps catch errors from OpenRouter
        if hasattr(completion, "error"):
            raise ValueError(completion.error)

        raise ValueError("No completion choices found")

    if completion.choices[0].finish_reason == "length":
        raise IncompleteOutputException(last_completion=completion)

    if mode == Mode.FUNCTIONS:
        Mode.warn_mode_functions_deprecation()
        return cls.parse_functions(completion, validation_context, strict)

    if mode == Mode.MISTRAL_STRUCTURED_OUTPUTS:
        return cls.parse_mistral_structured_outputs(
            completion, validation_context, strict
        )

    if mode in {
        Mode.TOOLS,
        Mode.MISTRAL_TOOLS,
        Mode.TOOLS_STRICT,
        Mode.CEREBRAS_TOOLS,
        Mode.FIREWORKS_TOOLS,
    }:
        return cls.parse_tools(completion, validation_context, strict)

    if mode in {
        Mode.JSON,
        Mode.JSON_SCHEMA,
        Mode.MD_JSON,
        Mode.JSON_O1,
        Mode.CEREBRAS_JSON,
        Mode.FIREWORKS_JSON,
        Mode.PERPLEXITY_JSON,
        Mode.OPENROUTER_STRUCTURED_OUTPUTS,
    }:
        return cls.parse_json(completion, validation_context, strict)

    raise ValueError(f"Invalid patch mode: {mode}")

openai_schema()

Return the schema in the format of OpenAI's schema as jsonschema

Note

It's important to add a docstring describing how best to use this class; it will be included in the description attribute and become part of the prompt.

Returns

Name Type Description
model_json_schema dict

A dictionary in the format of OpenAI's schema as jsonschema

Source code in instructor/function_calls.py
@classproperty
def openai_schema(cls) -> dict[str, Any]:
    """
    Return the schema in the format of OpenAI's schema as jsonschema

    Note:
        It's important to add a docstring to describe how to best use this class; it will be included in the description attribute and be part of the prompt.

    Returns:
        model_json_schema (dict): A dictionary in the format of OpenAI's schema as jsonschema
    """
    schema = cls.model_json_schema()
    docstring = parse(cls.__doc__ or "")
    parameters = {
        k: v for k, v in schema.items() if k not in ("title", "description")
    }
    for param in docstring.params:
        if (name := param.arg_name) in parameters["properties"] and (
            description := param.description
        ):
            if "description" not in parameters["properties"][name]:
                parameters["properties"][name]["description"] = description

    parameters["required"] = sorted(
        k for k, v in parameters["properties"].items() if "default" not in v
    )

    if "description" not in schema:
        if docstring.short_description:
            schema["description"] = docstring.short_description
        else:
            schema["description"] = (
                f"Correctly extracted `{cls.__name__}` with all "
                f"the required parameters with correct types"
            )

    return {
        "name": schema["title"],
        "description": schema["description"],
        "parameters": parameters,
    }
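
A short sketch of what the classproperty produces, following the logic above; the Search model here is hypothetical:

from pydantic import Field
from instructor import OpenAISchema

class Search(OpenAISchema):
    """Search for content by query."""
    query: str = Field(description="The search query string")
    limit: int = 10  # has a default, so it is omitted from "required"

print(Search.openai_schema)
# Expected shape (abridged):
# {
#     "name": "Search",
#     "description": "Search for content by query.",
#     "parameters": {
#         "type": "object",
#         "properties": {...},
#         "required": ["query"],
#     },
# }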

parse_json(completion, validation_context=None, strict=None) classmethod

Parse JSON mode responses using the optimized extraction and validation.

Source code in instructor/function_calls.py
@classmethod
def parse_json(
    cls: type[BaseModel],
    completion: ChatCompletion,
    validation_context: Optional[dict[str, Any]] = None,
    strict: Optional[bool] = None,
) -> BaseModel:
    """Parse JSON mode responses using the optimized extraction and validation."""
    # Check for incomplete output
    _handle_incomplete_output(completion)

    # Extract text from the response
    message = _extract_text_content(completion)
    if not message:
        # Fallback for OpenAI format if _extract_text_content doesn't handle it
        message = completion.choices[0].message.content or ""

    # Extract JSON from the text
    json_content = extract_json_from_codeblock(message)

    # Validate the model from the JSON
    return _validate_model_from_json(cls, json_content, validation_context, strict)
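
For context, the codeblock-stripping step this method relies on can be exercised on its own; this assumes the helper is importable from instructor.utils, as in recent releases:

from instructor.utils import extract_json_from_codeblock

raw = 'Sure! Here is the JSON:\n```json\n{"name": "Jason", "age": 25}\n```'
print(extract_json_from_codeblock(raw))
# {"name": "Jason", "age": 25}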

openai_schema(cls)

Wrap a Pydantic model class to add OpenAISchema functionality.

Source code in instructor/function_calls.py
def openai_schema(cls: type[BaseModel]) -> OpenAISchema:
    """
    Wrap a Pydantic model class to add OpenAISchema functionality.
    """
    if not issubclass(cls, BaseModel):
        raise TypeError("Class must be a subclass of pydantic.BaseModel")

    # Create the wrapped model
    schema = wraps(cls, updated=())(
        create_model(
            cls.__name__ if hasattr(cls, "__name__") else str(cls),
            __base__=(cls, OpenAISchema),
        )
    )

    return cast(OpenAISchema, schema)
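
A minimal usage sketch for the wrapper; the User model is hypothetical:

from pydantic import BaseModel
from instructor.function_calls import openai_schema

class User(BaseModel):
    """A user record."""
    name: str
    age: int

# The wrapped class subclasses both User and OpenAISchema,
# so it gains openai_schema and from_response
UserSchema = openai_schema(User)
print(UserSchema.openai_schema["name"])
# User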