跳到内容

使用 OpenAI 的 GPT-4 Vision 模型从图像中提取表格

首先,我们定义一个自定义类型 MarkdownDataFrame,用于处理以 markdown 表格表示的 pandas DataFrame。该类型使用 Python 的 Annotated 和 pydantic 的 InstanceOf 类型,并通过 BeforeValidator 和 PlainSerializer 来解析和序列化数据(另外用 WithJsonSchema 向模型声明该字段的 JSON schema)。

定义 Table 类

Table 类对于组织提取的数据至关重要。它包含一个标题和一个 dataframe,作为 markdown 表格进行处理。由于大部分复杂性由 MarkdownDataFrame 类型处理,因此 Table 类非常简单直接!

这需要额外安装依赖项:pip install pandas tabulate(DataFrame.to_markdown() 需要 tabulate)。

from openai import OpenAI
from io import StringIO
from typing import Annotated, Any, List
from pydantic import (
    BaseModel,
    BeforeValidator,
    PlainSerializer,
    InstanceOf,
    WithJsonSchema,
)
import instructor
import pandas as pd
from rich.console import Console

# Rich console used to pretty-print the extracted tables.
console = Console()

# OpenAI client wrapped by instructor so that `chat.completions.create`
# accepts a `response_model=` argument; configured with instructor's TOOLS mode.
client = instructor.from_openai(OpenAI(), mode=instructor.Mode.TOOLS)


def md_to_df(data: Any) -> Any:
    if isinstance(data, str):
        return (
            pd.read_csv(
                StringIO(data),  # Get rid of whitespaces
                sep="|",
                index_col=1,
            )
            .dropna(axis=1, how="all")
            .iloc[1:]
            .map(lambda x: x.strip())
        )  # type: ignore
    return data


# DataFrame type that crosses the LLM boundary as a markdown string:
#  - BeforeValidator: strings are parsed with `md_to_df`; other values are
#    passed through unchanged.
#  - InstanceOf: the (possibly parsed) value must then be a pd.DataFrame.
#  - PlainSerializer: dumps the DataFrame back to markdown
#    (DataFrame.to_markdown needs the optional `tabulate` package).
#  - WithJsonSchema: advertises the field to the model as a plain string,
#    since pydantic cannot generate a JSON schema for an arbitrary class
#    such as pd.DataFrame.
MarkdownDataFrame = Annotated[
    InstanceOf[pd.DataFrame],
    BeforeValidator(md_to_df),
    PlainSerializer(lambda df: df.to_markdown()),
    WithJsonSchema(
        {
            "type": "string",
            "description": """
                The markdown representation of the table,
                each one should be tidy, do not try to join tables
                that should be separate""",
        }
    ),
]


class Table(BaseModel):
    # One table extracted from an image. Intentionally no docstring:
    # pydantic uses class docstrings as JSON-schema descriptions, which
    # presumably reach the model via instructor and would change the prompt.
    caption: str  # free-text caption/title for the table
    dataframe: MarkdownDataFrame  # exchanged with the model as a markdown string


class MultipleTables(BaseModel):
    # Top-level response model used by `extract`: every table found in the
    # image is returned as one entry of this list.
    tables: List[Table]


# Hand-built sample `MultipleTables` value holding a single 2x3 table.
# It is not referenced by the code shown here; presumably kept to illustrate
# the expected output shape.
example = MultipleTables(
    tables=[
        Table(
            caption="This is a caption",
            dataframe=pd.DataFrame(
                [[10, 20, 30], [40, 50, 60]],
                columns=["Chart A", "Chart B", "Chart C"],
            ),
        ),
    ],
)


def extract(
    url: str,
    *,
    model: str = "gpt-4-turbo",
    max_tokens: int = 4000,
) -> MultipleTables:
    """Extract the tables shown in an image as structured ``Table`` objects.

    Sends the image at *url* plus extraction instructions to the model via
    the instructor-wrapped ``client`` and has instructor validate the reply
    against ``MultipleTables``.

    Args:
        url: Image URL, sent to the model as an ``image_url`` content part.
        model: Chat model to query; it must accept image input.
        max_tokens: Upper bound on completion tokens; large tables need a
            generous budget.

    Returns:
        The validated ``MultipleTables`` response.
    """
    # NOTE(review): the prompt asks for an h1 title and an overall summary,
    # but MultipleTables only holds per-table `caption` and `dataframe`
    # fields, so that part of the answer has nowhere to go.
    return client.chat.completions.create(
        model=model,
        max_tokens=max_tokens,
        response_model=MultipleTables,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": url},
                    },
                    {
                        "type": "text",
                        "text": """
                            First, analyze the image to determine the most appropriate headers for the tables.
                            Generate a descriptive h1 for the overall image, followed by a brief summary of the data it contains.
                            For each identified table, create an informative h2 title and a concise description of its contents.
                            Finally, output the markdown representation of each table.
                            Make sure to escape the markdown table properly, and make sure to include the caption and the dataframe.
                            including escaping all the newlines and quotes. Only return a markdown table in dataframe, nothing else.
                        """,
                    },
                ],
            }
        ],
    )


# Sample chart/table images to run the extraction against.
urls = [
    "https://a.storyblok.com/f/47007/2400x1260/f816b031cb/uk-ireland-in-three-charts_chart_a.png/m/2880x0",
    "https://a.storyblok.com/f/47007/2400x2000/bf383abc3c/231031_uk-ireland-in-three-charts_table_v01_b.png/m/2880x0",
]

# Runs at import time: each URL triggers one model call through `extract`.
for url in urls:
    for table in extract(url).tables:
        # Print caption and table separately. Passing "\n" as a middle
        # argument made console.print insert its " " separator after the
        # newline, shifting the table's first line one column right.
        # The caption is model-generated text, so render it without rich
        # markup interpretation (square brackets could vanish or raise).
        console.print(table.caption, markup=False)
        console.print(table.dataframe)