Ascend-SACT/PaddleOCR-VL-0.9B

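1. Model Overview and Use Cases

PaddleOCR-VL is an open-source multimodal model with strong text detection, text recognition, and image-text association capabilities. It is widely used in scenarios such as industrial quality inspection, intelligent document review, and in-vehicle multimodal OCR. However, PaddleOCR-VL is not compatible with Ascend servers out of the box, so it cannot be deployed and run on them directly. This document therefore covers the adaptation of PaddleOCR-VL to the Ascend A3 platform, enabling stable deployment and inference on Ascend servers with support for the core multimodal functions: text detection, text recognition, and image-text semantic association. The model is also packaged as a container, giving a standardized deployment that starts with a single command and can be reused across environments, which makes the model much easier to put into production.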

2. Preparing the Runtime Environment

Component           Version
Atlas 900 A3        910C
Ascend HDK          25.2.3
Operating system    openEuler 22.03 (LTS-SP4)
CANN                8.3.RC1

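3. Model Analysis

PaddleOCR-VL consists of two models: the layout analysis model PP-DocLayoutV2 and the vision-language model (VLM) PaddleOCR-VL-0.9B. PaddleOCR-VL-0.9B is the core component.

In more detail, PP-DocLayoutV2 cascades a pointer network containing 6 Transformer layers on top of the base layout detection model PP-DocLayout_plus-L. PaddleOCR-VL-0.9B combines a dynamic-resolution visual encoder with the ERNIE-4.5-0.3B language model.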

4. Quick End-to-End Walkthrough

Environment setup

  1. Install the base environment

Weights download: PaddleOCR-VL · Model Hub

vLLM version: 0.11.1

git clone https://github.com/vllm-project/vllm.git

cd vllm && git checkout v0.11.1 && cd ..

git clone https://github.com/vllm-project/vllm-ascend.git

cd vllm-ascend && git reset --hard 6664a4e && cd ..

  2. Install the Paddle environment

python -m pip install paddlepaddle==3.2.0

python -m pip install -U "paddleocr[doc-parser]"

  3. Install the remaining dependencies

pip install safetensors

The system libraries required by OpenCV may be missing:

apt-get install -y libgl1 libglib2.0-0

Note: the command depends on the operating system. On openEuler, use:

yum install -y mesa-libGL

yum install -y glib2
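Once the packages are installed, a quick import check can confirm that nothing is missing before the adaptation step. The following is a minimal sketch; the helper name `check_imports` is hypothetical, and the import names (`paddle`, `paddleocr`, `safetensors`, `cv2`) are assumptions about how the packages above are exposed in Python. A missing system library such as libGL typically shows up as an `ImportError` when `cv2` is imported, which is why the sketch imports each module instead of only locating it.

```python
import importlib

# Assumed import names for the packages installed above.
REQUIRED_MODULES = ("paddle", "paddleocr", "safetensors", "cv2")


def check_imports(names):
    """Try to import each module; map name -> error text (None on success)."""
    results = {}
    for name in names:
        try:
            importlib.import_module(name)
            results[name] = None
        except ImportError as exc:
            # Covers both a missing package and a missing shared library.
            results[name] = str(exc)
    return results


def report(results):
    """Format the result of check_imports as human-readable lines."""
    lines = []
    for name, error in results.items():
        status = "ok" if error is None else f"FAILED ({error})"
        lines.append(f"{name}: {status}")
    return lines
```

Calling `print("\n".join(report(check_imports(REQUIRED_MODULES))))` inside the target environment lists each module with its status.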

Adaptation

  1. Modify vllm/model_executor/models/ernie45.py

Add at line 26:

from vllm.compilation.decorators import support_torch_compile

Add at line 32:

@support_torch_compile(
    # set dynamic_arg_dims to support mrope
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    }
)

  2. Modify vllm/model_executor/models/registry.py

Add the following registration entry on the line below Ovis2_5:

    "PaddleOCRVLForConditionalGeneration": (
        "paddleocr_vl",
        "PaddleOCRVLForConditionalGeneration",
    ),

  3. Create paddleocr_vl.py under vllm/model_executor/models

See the appendix for the file contents.

Model inference

  1. Start the vLLM service

vllm serve PaddleOCR-VL-0.9B \
    --trust-remote-code \
    --max-num-batched-tokens 16384 \
    --no-enable-prefix-caching \
    --mm-processor-cache-gb 0
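Before running the pipeline it can help to confirm that the server is actually up. The sketch below queries the model-listing route of the OpenAI-compatible API under the base URL used later in this document (`http://127.0.0.1:8000/v1`). The helper names `models_url` and `served_models` are hypothetical, and the sketch assumes the response follows the usual `{"data": [{"id": ...}]}` shape.

```python
import json
import urllib.error
import urllib.request


def models_url(base_url: str) -> str:
    """Build the model-listing URL from an OpenAI-compatible base URL."""
    return base_url.rstrip("/") + "/models"


def served_models(base_url: str, timeout: float = 5.0):
    """Return the served model ids, or None if the server is unreachable."""
    try:
        with urllib.request.urlopen(models_url(base_url), timeout=timeout) as resp:
            payload = json.load(resp)
    except (urllib.error.URLError, OSError, ValueError):
        # Connection refused, timeout, or a non-JSON reply.
        return None
    return [item.get("id") for item in payload.get("data", [])]
```

`served_models("http://127.0.0.1:8000/v1")` returns `None` while the server is still starting, and a list of model ids once it answers.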

  2. Run the Python inference script:

from paddleocr import PaddleOCRVL

pipeline = PaddleOCRVL(
    layout_detection_model_dir='[YOUR_PATH]/PaddleOCR-VL-0.9B/PP-DocLayoutV2',
    use_doc_orientation_classify=False,
    vl_rec_backend="vllm-server",
    vl_rec_server_url="http://127.0.0.1:8000/v1",
    use_doc_unwarping=False,
)

output = pipeline.predict("./image_path")

for res in output:
    res.print()
    res.save_to_json(save_path="output")
    res.save_to_markdown(save_path="output")

Results

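Appendix

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 PaddlePaddle. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.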

import math
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Literal

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import BatchFeature, PretrainedConfig
from transformers.activations import GELUActivation
from transformers.modeling_outputs import (
    BaseModelOutputWithPooling,
)
from transformers.utils import torch_int

from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
from vllm.attention.ops.vit_attn_wrappers import (
    vit_flash_attn_wrapper,
    vit_xformers_attn_wrapper,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
    dispatch_rotary_emb_function,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargs,
)
from vllm.multimodal.parse import (
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .ernie45 import Ernie4_5ForCausalLM
from .interfaces import MultiModalEmbeddings, SupportsMRoPE, SupportsMultiModal
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    WeightsMapper,
    is_pp_missing_parameter,
    maybe_prefix,
)
from .vision import get_vit_attn_backend

def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 28 * 28 * 130,
    max_pixels: int = 28 * 28 * 1280,
):
    """Rescale the image so that the following conditions are met:

    1. Both height and width are divisible by 'factor'.
    2. The total number of pixels is within ['min_pixels', 'max_pixels'].
    3. The original aspect ratio is preserved as closely as possible.
    """
    if height < factor:
        width = round((width * factor) / height)
        height = factor

    if width < factor:
        height = round((height * factor) / width)
        width = factor

    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            "absolute aspect ratio must be smaller than 200, "
            f"got {max(height, width) / min(height, width)}"
        )
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
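As an aside to the listing, the three branches of smart_resize are easier to follow with concrete numbers. The sketch below repeats the same arithmetic as a standalone function (type hints and the docstring dropped for brevity) and runs it on three illustrative image sizes. The sizes are made-up inputs, not values from the model's configuration.

```python
import math


def smart_resize(height, width, factor=28,
                 min_pixels=28 * 28 * 130, max_pixels=28 * 28 * 1280):
    # Same arithmetic as the smart_resize listing, repeated so that the
    # example runs on its own.
    if height < factor:
        width = round((width * factor) / height)
        height = factor
    if width < factor:
        height = round((height * factor) / width)
        width = factor
    if max(height, width) / min(height, width) > 200:
        raise ValueError("absolute aspect ratio must be smaller than 200")
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


# In-range image: each side is only snapped to a multiple of 28.
in_range = smart_resize(700, 1000)
print(in_range)  # (700, 1008)

# Oversized image: scaled down until the pixel budget is respected.
big_h, big_w = smart_resize(3000, 4000)
print(big_h, big_w, big_h * big_w <= 28 * 28 * 1280)

# Tiny image: scaled up until the minimum pixel count is reached.
small_h, small_w = smart_resize(56, 56)
print(small_h, small_w, small_h * small_w >= 28 * 28 * 130)
```

In every case both output sides stay divisible by `factor`, which is what lets the later patch and merge steps divide the image evenly.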

def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)


def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(
        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    sin = repeat(
        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )


def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
    t_ = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    output = rotary_emb_function(t_, cos, sin).type_as(t)
    return output
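As an aside to the listing, the non-interleaved rotation can be reproduced on a single vector with plain Python lists. This is only an illustrative sketch of the same formula, `x * cos + rotate_half(x) * sin`, for the case where the rotary dimension equals the full vector length. The list-based helpers are hypothetical stand-ins and not part of the file.

```python
import math


def rotate_half(x):
    """Non-interleaved variant: [x1 | x2] -> [-x2 | x1]."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + x1


def apply_rotary(x, freqs):
    """Rotate x by the angles in freqs (one angle per pair of components)."""
    # cos and sin are tiled twice, mirroring the "(2 d)" repeat pattern.
    cos = [math.cos(f) for f in freqs] * 2
    sin = [math.sin(f) for f in freqs] * 2
    rotated = rotate_half(x)
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rotated, sin)]


# Component 0 is paired with component 2; a quarter turn moves the
# value from the first slot of the pair to the second one.
out = apply_rotary([1.0, 0.0, 0.0, 0.0], [math.pi / 2, 0.0])
norm = math.sqrt(sum(v * v for v in out))
print([round(v, 6) for v in out], round(norm, 6))
```

Because each pair is rotated by an angle, the length of the vector is unchanged, which the printed norm makes visible.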

class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
    def get_hf_config(self):
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(**kwargs)

    def get_image_processor(self, **kwargs: object):
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self):
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor,
    ) -> int:
        if image_processor is None:
            image_processor = self.get_image_processor()

        do_resize = True
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        grid_t = 1
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_image_tokens = num_patches // (merge_size ** 2)

        return num_image_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        hf_config = self.get_hf_config()
        image_size = hf_config.vision_config.image_size
        return ImageSize(height=image_size, width=image_size)
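As an aside to the listing, the token count computed by get_num_image_tokens can be checked by hand. The sketch below applies the same integer arithmetic to an image that has already been resized. The values `patch_size=14` and `merge_size=2` are assumptions chosen to match the default `factor=28`; the real values come from the model's vision_config.

```python
def num_image_tokens(resized_height: int, resized_width: int,
                     patch_size: int = 14, merge_size: int = 2) -> int:
    """Count the placeholder tokens for one already-resized image."""
    grid_t = 1  # a single image has one temporal slice
    grid_h = resized_height // patch_size
    grid_w = resized_width // patch_size
    num_patches = grid_t * grid_h * grid_w
    # Each merge_size x merge_size block of patches becomes one token.
    return num_patches // (merge_size ** 2)


# A 700 x 1008 image gives a 50 x 72 patch grid, i.e. 3600 patches,
# which the 2 x 2 merge reduces to 900 tokens.
tokens = num_image_tokens(700, 1008)
print(tokens)  # 900
```

This is why larger images consume more of the `--max-num-batched-tokens` budget on the serving side.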

class PaddleOCRVLDummyInputsBuilder(BaseDummyInputsBuilder[PaddleOCRVLProcessingInfo]):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        max_image_size = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=max_image_size.width,
                height=max_image_size.height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }

class PaddleOCRVLMultiModalProcessor(
    BaseMultiModalProcessor[PaddleOCRVLProcessingInfo]
):
    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        if mm_data:
            processed_outputs = self.info.ctx.call_hf_processor(
                self.info.get_hf_processor(**mm_kwargs),
                dict(text=prompt, **mm_data),
                dict(**mm_kwargs, **tok_kwargs),
            )
            num_patches_per_image = processed_outputs["image_grid_thw"].prod(-1)
            processed_outputs["pixel_values"] = processed_outputs["pixel_values"].split(
                num_patches_per_image.tolist()
            )
        else:
            tokenizer = self.info.get_tokenizer()
            processed_outputs = tokenizer(
                prompt, add_special_tokens=True, return_tensors="pt"
            )
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_grid_thw=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_id

        def get_replacement(item_idx: int, image_processor):
            images = mm_items.get_items("image", ImageProcessorItems)

            image_size = images.get_image_size(item_idx)
            num_image_tokens = self.info.get_num_image_tokens(
                image_width=image_size.width,
                image_height=image_size.height,
                image_processor=image_processor,
            )

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=partial(get_replacement, image_processor=image_processor),
            ),
        ]

class Projector(nn.Module):
    def __init__(
        self,
        text_config: PretrainedConfig,
        vision_config: PretrainedConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.text_config = text_config
        self.vision_config = vision_config
        self.merge_kernel_size = (2, 2)

        self.hidden_size = (
            self.vision_config.hidden_size
            * self.merge_kernel_size[0]
            * self.merge_kernel_size[1]
        )

        self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, eps=1e-05)
        self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.act = GELUActivation()
        self.linear_2 = nn.Linear(
            self.hidden_size, self.text_config.hidden_size, bias=True
        )

    def forward(
        self,
        image_features: torch.Tensor,
        image_grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        m1, m2 = self.merge_kernel_size
        if isinstance(image_features, (list, tuple)):
            processed_features = list()
            for image_feature, image_grid in zip(image_features, image_grid_thw):
                image_feature = self.pre_norm(image_feature)
                t, h, w = image_grid

                image_feature = rearrange(
                    image_feature,
                    "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
                    t=t,
                    h=h // m1,
                    p1=m1,
                    w=w // m2,
                    p2=m2,
                )
                hidden_states = self.linear_1(image_feature)
                hidden_states = self.act(hidden_states)
                hidden_states = self.linear_2(hidden_states)
                processed_features.append(hidden_states)

            return processed_features

        dims = image_features.shape[:-1]
        dim = image_features.shape[-1]
        image_features = image_features.view(np.prod(dims), dim)
        hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size)
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)

        return hidden_states.view(*dims, -1)
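As an aside to the listing, the einops pattern `"(t h p1 w p2) d -> (t h w) (p1 p2 d)"` in Projector.forward groups each 2 x 2 block of neighbouring patches into a single token. The sketch below lists, for every merged token, which flat patch indices are concatenated. The helper `merged_patch_groups` is hypothetical and uses plain integers in place of feature vectors.

```python
def merged_patch_groups(t: int, h: int, w: int, m1: int = 2, m2: int = 2):
    """For each merged token, list the flat input patch indices it combines.

    The input patches are assumed to be laid out row-major over
    (frame, row, column), which is what the einops pattern implies.
    """
    groups = []
    for ti in range(t):
        for hi in range(h // m1):
            for wi in range(w // m2):
                group = []
                for p1 in range(m1):
                    for p2 in range(m2):
                        row = hi * m1 + p1
                        col = wi * m2 + p2
                        group.append((ti * h + row) * w + col)
                groups.append(group)
    return groups


# A 2 x 4 patch grid:   0 1 2 3
#                       4 5 6 7
groups = merged_patch_groups(1, 2, 4)
print(groups)  # [[0, 1, 4, 5], [2, 3, 6, 7]]
```

Each group becomes one input row of `linear_1`, which is why that layer's input width is the vision hidden size multiplied by 4.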

class PaddleOCRImagePixelInputs(TensorSchema):
    type: Literal["pixel_values"]
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("bn", "p", 3, "patch_size", "patch_size", dynamic_dims={"p"}),
    ]
    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("bn", 3),
    ]

class SiglipVisionEmbeddings(nn.Module):
    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.cache_position_embedding = dict()
        self.cache_position_count = dict()
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)

        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(
        self,
        embeddings: torch.Tensor,
        height: int,
        width: int,
        is_after_patchify: bool = False,
    ) -> torch.Tensor:
        num_positions = self.position_embedding.weight.shape[0]

        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)

        dim = embeddings.shape[-1]

        if is_after_patchify:
            new_height = height
            new_width = width
        else:
            new_height = height // self.patch_size
            new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions ** 0.5)
        patch_pos_embed = patch_pos_embed.reshape(
            1, sqrt_num_positions, sqrt_num_positions, dim
        )
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bilinear",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def fetch_position_embedding_lfu_cache(
        self, embeddings: torch.Tensor, h: int, w: int, max_cache: int = 20
    ):
        grid = (h, w)
        if grid in self.cache_position_embedding:
            self.cache_position_count[grid] += 1
            return self.cache_position_embedding[grid]

        if len(self.cache_position_embedding) >= max_cache:
            min_hit_grid = min(
                self.cache_position_count,
                key=self.cache_position_count.get,
            )
            self.cache_position_count.pop(min_hit_grid)
            self.cache_position_embedding.pop(min_hit_grid)

        position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True)
        self.cache_position_count[grid] = 1
        self.cache_position_embedding[grid] = position_embedding
        return position_embedding

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        position_ids: torch.Tensor | None = None,
        image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]]
        | None = None,
        interpolate_pos_encoding=False,
    ) -> torch.Tensor:
        if pixel_values.dim() == 4:
            pixel_values = pixel_values.unsqueeze(0)

        if pixel_values.dim() == 5:
            if position_ids is None:
                raise ValueError(
                    "position_ids cannot be None when pixel_values has 5 dimensions."
                )
            (
                batch_size,
                sequence_len,
                channel,
                height,
                width,
            ) = pixel_values.shape
            target_dtype = self.patch_embedding.weight.dtype
            pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
            patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
            embeddings = patch_embeds.flatten(-2).squeeze(-1)

            if interpolate_pos_encoding and image_grid_thw is not None:
                start = 0
                tmp_embeddings = list()
                for image_grid in image_grid_thw:
                    t, h, w = image_grid
                    end = start + t * h * w
                    image_embeddings = embeddings[start:end, :]
                    position_embedding = (
                        self.interpolate_pos_encoding(image_embeddings, h, w, True)
                        .squeeze(0)
                        .repeat(t, 1)
                    )
                    image_embeddings = image_embeddings + position_embedding
                    tmp_embeddings.append(image_embeddings)
                    start = end
                embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0)
            else:
                embeddings = embeddings + self.packing_position_embedding(position_ids)
            return embeddings
        else:
            raise ValueError(
                "Unsupported pixel_values dimension: "
                f"{pixel_values.dim()}. Expected 4 or 5."
            )
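As an aside to the listing, the eviction policy of fetch_position_embedding_lfu_cache can be isolated in a few lines. The sketch below keeps the same two dictionaries (cached values and hit counts) but stores arbitrary Python values in place of interpolated embeddings. The class name `LFUCache` and the `compute` callback are hypothetical.

```python
class LFUCache:
    """Tiny least-frequently-used cache keyed by an (h, w) grid size."""

    def __init__(self, max_cache: int = 20):
        self.max_cache = max_cache
        self.values = {}
        self.hits = {}

    def fetch(self, grid, compute):
        # Cache hit: bump the counter and reuse the stored value.
        if grid in self.values:
            self.hits[grid] += 1
            return self.values[grid]
        # Cache full: evict the entry with the fewest hits.
        if len(self.values) >= self.max_cache:
            coldest = min(self.hits, key=self.hits.get)
            self.hits.pop(coldest)
            self.values.pop(coldest)
        value = compute(grid)
        self.hits[grid] = 1
        self.values[grid] = value
        return value


cache = LFUCache(max_cache=2)
area = lambda grid: grid[0] * grid[1]  # stand-in for the interpolation
cache.fetch((2, 2), area)
cache.fetch((2, 2), area)   # second hit for (2, 2)
cache.fetch((3, 3), area)
cache.fetch((4, 4), area)   # cache is full, so (3, 3) is evicted
print(sorted(cache.values))
```

Grid sizes that recur across requests therefore stay cached, while one-off sizes are the first to be dropped.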

def all_gather_interleave(local_tensor: torch.Tensor, hidden_size: int, tp_size: int):
    """All-gather the input tensor across the model-parallel group, interleaved."""
    import torch.distributed as dist

    gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(
        gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group
    )

    gathered_tensors_split = [
        torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors
    ]
    ordered_tensors = [
        tensor for pair in zip(*gathered_tensors_split) for tensor in pair
    ]
    result_tensor = torch.cat(ordered_tensors, dim=-1)
    return result_tensor
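As an aside to the listing, the reordering done by all_gather_interleave is easier to see with labels than with tensors. Each rank holds its own slice of Q, K and V side by side. After gathering, the `zip` step regroups the slices so that all Q parts come first, then all K parts, then all V parts. The sketch below mimics only the reordering with Python lists; the distributed gather itself is not modelled and the helper name is hypothetical.

```python
def interleave_gathered(gathered, chunk_size):
    """Reorder per-rank [q_r | k_r | v_r] columns into [Q | K | V].

    gathered   -- one list of column labels per rank
    chunk_size -- number of columns per q/k/v slice on a single rank
    """
    split = [
        [cols[i:i + chunk_size] for i in range(0, len(cols), chunk_size)]
        for cols in gathered
    ]
    # zip(*split) walks the ranks in lockstep: first every q slice,
    # then every k slice, then every v slice.
    ordered = [chunk for group in zip(*split) for chunk in group]
    return [col for chunk in ordered for col in chunk]


two_ranks = [["q0", "k0", "v0"], ["q1", "k1", "v1"]]
merged = interleave_gathered(two_ranks, chunk_size=1)
print(merged)  # ['q0', 'q1', 'k0', 'k1', 'v0', 'v1']
```

With the columns in this order, a plain `chunk(3)` along the last dimension yields the full Q, K and V tensors.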

class SiglipAttention(nn.Module):
    """SigLIP vision attention adapted from Qwen2.5-VisionAttention."""

    def __init__(
        self,
        *,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend: _Backend = _Backend.TORCH_SDPA,
        attn_backend_override: _Backend | None = None,
        use_upstream_fa: bool = False,
    ) -> None:
        super().__init__()

        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.attn_backend = attn_backend
        self.use_upstream_fa = use_upstream_fa
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        seq_len, bs, _ = qkv.shape
        if self.tp_size > 1:
            qkv = all_gather_interleave(qkv, self.qkv_proj.hidden_size, self.tp_size)

        q, k, v = qkv.chunk(3, dim=2)

        if self.tp_size > 1:
            splitter = partial(
                dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
            )
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None,
        max_seqlen: torch.Tensor | None,
        seqlens: torch.Tensor | None,
    ) -> torch.Tensor:
        batch_size, _, _ = hidden_states.shape

        x = rearrange(hidden_states, "b s d -> s b d")
        x, _ = self.qkv_proj(x)
        q, k, v = self.split_qkv(x)
        q, k, v = (rearrange(t, "s b h d -> b s h d") for t in (q, k, v))

        if rotary_pos_emb is not None:
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            if max_seqlen is None:
                raise ValueError("Flash attention backend requires max_seqlen.")
            context_layer = vit_flash_attn_wrapper(
                q,
                k,
                v,
                cu_seqlens,
                max_seqlen,
                batch_size,
                self.attn_backend == _Backend.ROCM_AITER_FA,
                self.use_upstream_fa,
            )
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Run attention separately on each image's slice of the packed sequence.
            outputs = []
            for i in range(1, len(cu_seqlens)):
                start_idx = cu_seqlens[i - 1]
                end_idx = cu_seqlens[i]
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (
                    rearrange(tensor, "b s h d -> b h s d")
                    for tensor in (q_i, k_i, v_i)
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        elif self.attn_backend == _Backend.XFORMERS:
            if seqlens is None:
                raise ValueError("xFormers attention backend requires the seqlens tensor.")
            context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
        else:
            raise RuntimeError(
                f"PaddleOCR-VL does not currently support the {self.attn_backend} backend."
            )

        output, _ = self.out_proj(context_layer)
        output = rearrange(output, "s b d -> b s d")
        return output
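As an aside to the listing, the TORCH_SDPA branch relies on `cu_seqlens` to cut the packed patch sequence back into one slice per image, so that patches only attend within their own image. The sketch below shows that bookkeeping with plain lists. How `cu_seqlens` is built is not visible in this part of the file, so `build_cu_seqlens` is an assumption (a cumulative sum of `t * h * w` per image), and both helper names are hypothetical.

```python
from itertools import accumulate


def build_cu_seqlens(grid_thw):
    """Assumed construction: cumulative patch counts with a leading 0."""
    lengths = [t * h * w for t, h, w in grid_thw]
    return [0] + list(accumulate(lengths))


def segments(cu_seqlens):
    """(start, end) slice bounds, mirroring the loop in the SDPA branch."""
    return [
        (cu_seqlens[i - 1], cu_seqlens[i]) for i in range(1, len(cu_seqlens))
    ]


# Two images with 2 x 2 and 2 x 3 patch grids packed into one sequence.
cu = build_cu_seqlens([(1, 2, 2), (1, 2, 3)])
bounds = segments(cu)
# The flash-attention path needs the longest segment as max_seqlen.
max_seqlen = max(end - start for start, end in bounds)
print(cu, bounds, max_seqlen)  # [0, 4, 10] [(0, 4), (4, 10)] 6
```

The same differences, `cu_seqlens[1:] - cu_seqlens[:-1]`, are what the encoder passes on as `seqlens` for the xFormers backend.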

class SigLIPRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.rope_init()

    def rope_init(self):
        inv_freq = 1.0 / (
            self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(
            seqlen,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )
        freqs = torch.outer(seq, self.inv_freq)
        return freqs
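As an aside to the listing, the frequency table returned by SigLIPRotaryEmbedding.forward is an outer product of positions and inverse frequencies. The sketch below reproduces it with plain Python lists for a tiny case. The function name `rotary_freqs` and the chosen `dim=4` are illustrative only.

```python
def rotary_freqs(seqlen: int, dim: int, theta: float = 10000.0):
    """Return a seqlen x (dim / 2) table of rotation angles."""
    # inv_freq[i] = 1 / theta^(2i / dim), one entry per pair of channels.
    inv_freq = [1.0 / (theta ** (i / dim)) for i in range(0, dim, 2)]
    # Outer product: angle = position * inverse frequency.
    return [[pos * f for f in inv_freq] for pos in range(seqlen)]


freqs = rotary_freqs(seqlen=3, dim=4)
for row in freqs:
    print(row)
```

Position 0 always maps to zero angles, and later channel pairs rotate more slowly because their inverse frequency is smaller.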

class SiglipMLP(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        # Special handling for BNB and torchao quantization
        if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
            quantizable = True
        else:
            # For other quantization methods, the hidden sizes must be
            # multiples of 64
            quantizable = (
                config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
            )
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config if quantizable else None,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config if quantizable else None,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states

class SiglipEncoderLayer(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        attn_backend: _Backend = _Backend.TORCH_SDPA,
        attn_backend_override: _Backend | None = None,
        use_upstream_fa: bool = False,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = SiglipAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            projection_size=config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attn_backend=attn_backend,
            attn_backend_override=attn_backend_override,
            use_upstream_fa=use_upstream_fa,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None,
        max_seqlen: torch.Tensor | None,
        seqlens: torch.Tensor | None,
    ) -> torch.Tensor:
        # Pre-norm attention block with a residual connection.
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )

        hidden_states = residual + hidden_states

        # Pre-norm MLP block with a residual connection.
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)

        hidden_states = residual + hidden_states

        return hidden_states

class SiglipEncoder(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: _Backend | None = None,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False
        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        } and check_upstream_fa_availability(torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True
        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.TORCH_SDPA,
            _Backend.XFORMERS,
            _Backend.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"PaddleOCR-VL does not currently support the {self.attn_backend} backend."
            )
        self.layers = nn.ModuleList(
            [
                SiglipEncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                    attn_backend=self.attn_backend,
                    attn_backend_override=attn_backend_override,
                    use_upstream_fa=self.use_upstream_fa,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)

    @staticmethod
    def flatten_list(image_grid_thw):
        tmp_image_grid_thw = list()
        for image_grid in image_grid_thw:
            if isinstance(image_grid, list):
                tmp_image_grid_thw.extend(image_grid)
            else:
                tmp_image_grid_thw.append(image_grid)
        return tmp_image_grid_thw

    def forward(
        self,
        inputs_embeds,
        cu_seqlens: torch.Tensor | None = None,
        image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]]
        | None = None,
        height_position_ids: torch.Tensor | None = None,
        width_position_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        device = inputs_embeds.device
        hidden_states = inputs_embeds

        flatten_image_grid_thw = self.flatten_list(image_grid_thw)

        if width_position_ids is None or height_position_ids is None:
            # Derive the row and column index of every patch from the grid sizes.
            split_hids = list()
            split_wids = list()
            for t, h, w in flatten_image_grid_thw:
                image_pids = torch.arange(t * h * w, device=device) % (h * w)
                sample_hids = image_pids // w
                sample_wids = image_pids % w
                split_hids.append(sample_hids)
                split_wids.append(sample_wids)
            width_position_ids = torch.concat(split_wids, dim=0)
            height_position_ids = torch.concat(split_hids, dim=0)

        pids = torch.stack(
            [height_position_ids, width_position_ids],
            dim=-1,
        )
        max_grid_size = pids.max() + 1
        rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rope_emb_max_grid[pids].flatten(1)

        if cu_seqlens is None:
            raise ValueError("cu_seqlens cannot be None for SiglipEncoder.")
        if not isinstance(cu_seqlens, torch.Tensor):
            cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
        else:
            cu_seqlens = cu_seqlens.to(device=device)

        max_seqlen = None
        seqlens = None
        if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        elif self.attn_backend == _Backend.XFORMERS:
            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )
        return hidden_states
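As an aside to the listing, the encoder derives a (row, column) position for every patch directly from `image_grid_thw` when no position ids are passed in. The sketch below repeats that index arithmetic with plain integers. The helper `grid_position_ids` is hypothetical and handles a single `(t, h, w)` grid.

```python
def grid_position_ids(t: int, h: int, w: int):
    """Return (row ids, column ids) for a row-major t x h x w patch grid."""
    hids, wids = [], []
    for idx in range(t * h * w):
        pid = idx % (h * w)   # position inside the current frame
        hids.append(pid // w)  # row index
        wids.append(pid % w)   # column index
    return hids, wids


hids, wids = grid_position_ids(1, 2, 3)
# The rotary table only needs as many rows as the largest index + 1.
max_grid_size = max(max(hids), max(wids)) + 1
print(hids)  # [0, 0, 0, 1, 1, 1]
print(wids)  # [0, 1, 2, 0, 1, 2]
print(max_grid_size)  # 3
```

In the encoder the two id lists are stacked, used to index the rotary table, and flattened, so each patch receives one set of angles for its row and one for its column.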

class SiglipVisionTransformer(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: _Backend | None = None,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
            attn_backend_override=attn_backend_override,
        )
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool | None = False,
        position_ids: torch.Tensor | None = None,
        height_position_ids: torch.Tensor | None = None,
        width_position_ids: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        image_grid_thw: torch.Tensor | None = None,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            image_grid_thw=image_grid_thw,
        )

        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            cu_seqlens=cu_seqlens,
            image_grid_thw=image_grid_thw,
            height_position_ids=height_position_ids,
            width_position_ids=width_position_ids,
        )

        last_hidden_state = self.post_layernorm(last_hidden_state)
        return last_hidden_state

class SiglipVisionModel(nn.Module):

def __init__(

self,

config,

quant_config: QuantizationConfig | None = None,

prefix: str = "",

attn_backend_override: _Backend | None = None,

):

super().__init__()

self.vision_model = SiglipVisionTransformer(

config,

quant_config=quant_config,

prefix=f"{prefix}.vision_model",

attn_backend_override=attn_backend_override,

)

self.quant_config = quant_config

@property

def dtype(self) -> torch.dtype:

return self.vision_model.embeddings.patch_embedding.weight.dtype

@property

def device(self) -> torch.device:

return self.vision_model.embeddings.patch_embedding.weight.device

def get_input_embeddings(self) -> nn.Module:

return self.vision_model.embeddings.patch_embedding

def forward(

self,

pixel_values,

interpolate_pos_encoding: bool = False,

position_ids: torch.Tensor | None = None,

image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]]

| None = None,

cu_seqlens: torch.Tensor | None = None,

) -> BaseModelOutputWithPooling:

return self.vision_model(

pixel_values=pixel_values,

interpolate_pos_encoding=interpolate_pos_encoding,

position_ids=position_ids,

image_grid_thw=image_grid_thw,

cu_seqlens=cu_seqlens,

)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

stacked_params_mapping = [

("qkv_proj", "q_proj", "q"),

("qkv_proj", "k_proj", "k"),

("qkv_proj", "v_proj", "v"),

]

params_dict = dict(self.named_parameters(remove_duplicate=False))

loaded_params: set[str] = set()

for name, loaded_weight in weights:

if "rotary_emb.inv_freq" in name:

continue

if "head.attention" in name or "head.layernorm" in name:

continue

if "head.mlp" in name or "head.probe" in name:

continue

if self.quant_config is not None and (

scale_name := self.quant_config.get_cache_scale(name)

):

param = params_dict[scale_name]

weight_loader = getattr(

param,

"weight_loader",

default_weight_loader,

)

loaded_weight = (

loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]

)

weight_loader(param, loaded_weight)

loaded_params.add(scale_name)

continue

for (

param_name,

weight_name,

shard_id,

) in stacked_params_mapping:

if weight_name not in name:

continue

name = name.replace(weight_name, param_name)

if name.endswith(".bias") and name not in params_dict:

continue

if is_pp_missing_parameter(name, self):

continue

param = params_dict[name]

weight_loader = param.weight_loader

weight_loader(param, loaded_weight, shard_id)

break

else:

if name.endswith(".bias") and name not in params_dict:

continue

name = maybe_remap_kv_scale_name(name, params_dict)

if name is None:

continue

if is_pp_missing_parameter(name, self):

continue

param = params_dict[name]

weight_loader = getattr(

param,

"weight_loader",

default_weight_loader,

)

weight_loader(param, loaded_weight)

loaded_params.add(name)

return loaded_params

@MULTIMODAL_REGISTRY.register_processor(

PaddleOCRVLMultiModalProcessor,

info=PaddleOCRVLProcessingInfo,

dummy_inputs=PaddleOCRVLDummyInputsBuilder,

)

class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsMRoPE):

merge_by_field_config = True

hf_to_vllm_mapper = WeightsMapper(

orig_to_new_prefix={

"model.": "language_model.model.",

"lm_head.": "language_model.lm_head.",

}

)

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

super().__init__()

config = vllm_config.model_config.hf_config

quant_config = vllm_config.quant_config

multimodal_config = vllm_config.model_config.multimodal_config

self.config = config

self.multimodal_config = multimodal_config

attn_backend_override = (

multimodal_config.mm_encoder_attn_backend

if multimodal_config is not None

else None

)

self.visual = SiglipVisionModel(

config=config.vision_config,

quant_config=quant_config,

prefix=maybe_prefix(prefix, "visual"),

attn_backend_override=attn_backend_override,

)

self.mlp_AR = Projector(config, config.vision_config)

self.language_model = Ernie4_5ForCausalLM(

vllm_config=vllm_config,

prefix=maybe_prefix(prefix, "language_model"),

)

for layer in self.language_model.model.layers:

if not isinstance(layer, PPMissingLayer):

layer.self_attn.rotary_emb.is_neox_style = True

self.make_empty_intermediate_tensors = (

self.language_model.make_empty_intermediate_tensors

)

def compute_logits(

self,

hidden_states: torch.Tensor,

) -> torch.Tensor | None:

return self.language_model.compute_logits(hidden_states)

def get_mrope_input_positions(

self,

input_tokens: list[int],

hf_config: PretrainedConfig,

image_grid_thw: list[list[int]] | torch.Tensor,

video_grid_thw: list[list[int]] | torch.Tensor,

second_per_grid_ts: list[float],

context_len: int = 0,

seq_len: int | None = None,

audio_feature_lengths: torch.Tensor | None = None,

use_audio_in_video: bool = False,

) -> tuple[torch.Tensor, int]:

"""Get the mrope input positions and the position delta."""

image_token_id = hf_config.image_token_id

video_token_id = hf_config.video_token_id

vision_start_token_id = hf_config.vision_start_token_id

spatial_merge_size = hf_config.vision_config.spatial_merge_size

tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)

input_tokens_tensor = torch.tensor(input_tokens)

vision_start_indices = torch.argwhere(

input_tokens_tensor == vision_start_token_id

).squeeze(1)

vision_tokens = input_tokens_tensor[vision_start_indices + 1]

image_nums = (vision_tokens == image_token_id).sum()

video_nums = (vision_tokens == video_token_id).sum()

llm_pos_ids_list: list = []

st = 0

remain_images, remain_videos = image_nums, video_nums

image_index, video_index = 0, 0

for _ in range(image_nums + video_nums):

video_second_per_grid_t = 0.0

if remain_images > 0:

try:

ed_image = input_tokens.index(image_token_id, st)

except ValueError:

ed_image = len(input_tokens) + 1

else:

ed_image = len(input_tokens) + 1

if remain_videos > 0:

try:

ed_video = input_tokens.index(video_token_id, st)

except ValueError:

ed_video = len(input_tokens) + 1

else:

ed_video = len(input_tokens) + 1

if ed_image < ed_video:

t, h, w = (

image_grid_thw[image_index][0],

image_grid_thw[image_index][1],

image_grid_thw[image_index][2],

)

image_index += 1

remain_images -= 1

ed = ed_image

else:

t, h, w = (

video_grid_thw[video_index][0],

video_grid_thw[video_index][1],

video_grid_thw[video_index][2],

)

video_second_per_grid_t = 1.0

if second_per_grid_ts:

video_second_per_grid_t = second_per_grid_ts[video_index]

video_index += 1

remain_videos -= 1

ed = ed_video

llm_grid_t, llm_grid_h, llm_grid_w = (

t,

h // spatial_merge_size,

w // spatial_merge_size,

)

text_len = ed - st

st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0

llm_pos_ids_list.append(

torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx

)

t_index = (

(

torch.arange(llm_grid_t)

.view(-1, 1)

.expand(-1, llm_grid_h * llm_grid_w)

* video_second_per_grid_t

* tokens_per_second

)

.long()

.flatten()

)

h_index = (

torch.arange(llm_grid_h)

.view(1, -1, 1)

.expand(llm_grid_t, -1, llm_grid_w)

.flatten()

)

w_index = (

torch.arange(llm_grid_w)

.view(1, 1, -1)

.expand(llm_grid_t, llm_grid_h, -1)

.flatten()

)

llm_pos_ids_list.append(

torch.stack([t_index, h_index, w_index]) + text_len + st_idx

)

st = ed + llm_grid_t * llm_grid_h * llm_grid_w

if st < len(input_tokens):

st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0

text_len = len(input_tokens) - st

llm_pos_ids_list.append(

torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx

)

llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)

mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()

llm_positions = llm_positions[:, context_len:seq_len]

return llm_positions, mrope_position_delta

def get_language_model(self) -> nn.Module:

return self.language_model

def _parse_and_validate_image_input(

self, **kwargs: object

) -> PaddleOCRImagePixelInputs | None:

pixel_values = kwargs.pop("pixel_values", None)

image_grid_thw = kwargs.pop("image_grid_thw", None)

if pixel_values is None:

return None

return PaddleOCRImagePixelInputs(

type="pixel_values",

pixel_values=pixel_values,

image_grid_thw=image_grid_thw,

)

def forward(

self,

input_ids: torch.Tensor,

positions: torch.Tensor,

intermediate_tensors: IntermediateTensors | None = None,

inputs_embeds: torch.Tensor | None = None,

**kwargs,

):

if intermediate_tensors is not None:

inputs_embeds = None

elif inputs_embeds is None:

vision_embeddings = self.get_multimodal_embeddings(**kwargs)

is_multimodal = kwargs.pop("is_multimodal", None)

handle_oov_mm_token = kwargs.pop("handle_oov_mm_token", False)

inputs_embeds = self.get_input_embeddings(

input_ids,

vision_embeddings,

is_multimodal=is_multimodal,

handle_oov_mm_token=handle_oov_mm_token,

)

input_ids = None

return self.language_model(

input_ids, positions, intermediate_tensors, inputs_embeds

)

@classmethod

def get_placeholder_str(cls, modality: str, i: int) -> str | None:

if modality.startswith("image"):

return "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>"

raise ValueError("Only the image modality is supported")

def encode_image(

self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor

) -> torch.Tensor:

pixel_values = pixel_values.type(self.visual.dtype)

siglip_position_ids = list()

image_grid_hws = list()

cu_seqlens = [0]

thw_tuple = tuple(image_grid_thw.tolist())

numel = np.prod(thw_tuple)

image_grid_hws.append(thw_tuple)

image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])

siglip_position_ids.append(image_position_ids)

cu_seqlens.append(cu_seqlens[-1] + numel)

siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(

pixel_values.device

)

cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(pixel_values.device)

vision_outputs = self.visual(

pixel_values=pixel_values,

image_grid_thw=image_grid_hws,

position_ids=siglip_position_ids,

interpolate_pos_encoding=True,

cu_seqlens=cu_seqlens,

)

return vision_outputs

def _process_image_input(

self, image_input: PaddleOCRImagePixelInputs

) -> MultiModalEmbeddings:

pixel_values = image_input.pixel_values

image_grid_thw = image_input.image_grid_thw

vision_outputs = tuple(

self.encode_image(pixel, grid).squeeze(0)

for pixel, grid in zip(pixel_values, image_grid_thw)

)

image_embeds = self.mlp_AR(vision_outputs, image_grid_thw)

return image_embeds

def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:

image_input = self._parse_and_validate_image_input(**kwargs)

if image_input is None:

return ()

multimodal_embeddings: tuple[torch.Tensor, ...] = ()

image_embeds = self._process_image_input(image_input)

multimodal_embeddings += tuple(image_embeds)

return multimodal_embeddings

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

loader = AutoWeightsLoader(self)

autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

return autoloaded_weights
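
The position layout built by get_mrope_input_positions above can be hard to follow in the flattened listing, so here is a minimal pure-Python sketch of the same idea for the simplest case: one image, no video, and no context_len/seq_len slicing. The function name mrope_positions, the token ids in the usage example, and the default spatial_merge_size=2 are illustrative assumptions, not values taken from the model config.

```python
def mrope_positions(input_tokens, image_token_id, grid_thw, spatial_merge_size=2):
    """Return ([t_ids, h_ids, w_ids], delta) for a prompt with exactly one image.

    Simplified illustration only: hypothetical helper, single image, no video.
    """
    t, h, w = grid_thw
    # The language model sees the grid after spatial merging.
    grid_h, grid_w = h // spatial_merge_size, w // spatial_merge_size
    num_image_tokens = t * grid_h * grid_w

    start = input_tokens.index(image_token_id)  # first image placeholder
    pos = [[], [], []]  # temporal, height, width axes

    # Text before the image: all three axes share the same running index.
    for i in range(start):
        for axis in pos:
            axis.append(i)

    # Image tokens: height/width follow the merged grid, offset by the text
    # length. For still images the listing scales the temporal index by 0.0,
    # so the temporal axis stays at the offset.
    for _ in range(t):
        for hi in range(grid_h):
            for wi in range(grid_w):
                pos[0].append(start)
                pos[1].append(start + hi)
                pos[2].append(start + wi)

    # Text after the image resumes from the largest position used so far.
    next_pos = max(max(axis) for axis in pos) + 1
    for i in range(len(input_tokens) - start - num_image_tokens):
        for axis in pos:
            axis.append(next_pos + i)

    # delta = (largest position + 1) - sequence length, as in the listing.
    delta = max(max(axis) for axis in pos) + 1 - len(input_tokens)
    return pos, delta


# Usage: two text tokens, four image placeholders (id 9 is made up), one text token.
positions, delta = mrope_positions([1, 2, 9, 9, 9, 9, 3], 9, (1, 4, 4))
print(positions, delta)
```

Because a 2x2 merged grid advances the positions by only 2 while occupying 4 tokens, the delta is negative here; the listing computes the same quantity as mrope_position_delta.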